helper_classes

Module Contents

Classes

DatasetTriple

Map-style dataset (subclass of torch.utils.data.Dataset).

HeadAndRelationBatchLoader

Map-style dataset (subclass of torch.utils.data.Dataset).

Reproduce

Attributes

seed

helper_classes.seed = 1
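The module documents only the value `seed = 1`, not how it is consumed. A common pattern, assumed here rather than taken from the module, is to pass such a seed to every random number generator in use so that repeated runs match. `set_seed` is a hypothetical helper.

```python
import random

import torch

seed = 1  # mirrors helper_classes.seed


def set_seed(seed):
    # Hypothetical helper: seed Python's and PyTorch's generators so runs are repeatable.
    random.seed(seed)
    torch.manual_seed(seed)


set_seed(seed)
first = (random.random(), torch.rand(3))
set_seed(seed)
second = (random.random(), torch.rand(3))  # identical to `first` after reseeding
```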
class helper_classes.DatasetTriple(data)

Bases: torch.utils.data.Dataset

A map-style dataset. Description inherited from its abstract base class, torch.utils.data.Dataset:

All datasets that represent a map from keys to data samples should subclass it. All subclasses should override __getitem__(), which fetches a data sample for a given key. Subclasses may also override __len__(), which many Sampler implementations and the default options of DataLoader expect to return the size of the dataset.

Note

DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.

__len__()
__getitem__(idx)
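A minimal sketch of a triple dataset in the spirit of DatasetTriple(data), assuming `data` is a list of integer-indexed (head, relation, tail) triples. The class name `TripleDatasetSketch` and its field layout are illustrative, not the module's actual implementation. Because the keys are plain integers, the default DataLoader index sampler works without a custom sampler.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TripleDatasetSketch(Dataset):
    """Hypothetical map-style dataset over integer-indexed (head, relation, tail) triples."""

    def __init__(self, data):
        # Assumption: data is a list of [head_idx, relation_idx, tail_idx] integer triples.
        triples = torch.tensor(data, dtype=torch.long)
        self.head_idx = triples[:, 0]
        self.rel_idx = triples[:, 1]
        self.tail_idx = triples[:, 2]

    def __len__(self):
        # Number of triples; the default sampler draws indices from range(len(self)).
        return len(self.head_idx)

    def __getitem__(self, idx):
        # Integer key -> one (head, relation, tail) sample.
        return self.head_idx[idx], self.rel_idx[idx], self.tail_idx[idx]


triples = [[0, 0, 1], [1, 1, 2], [2, 0, 0]]
loader = DataLoader(TripleDatasetSketch(triples), batch_size=2, shuffle=False)
heads, rels, tails = next(iter(loader))  # each is a tensor of shape (2,)
```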
class helper_classes.HeadAndRelationBatchLoader(er_vocab, num_e)

Bases: torch.utils.data.Dataset

A map-style dataset. Description inherited from its abstract base class, torch.utils.data.Dataset:

All datasets that represent a map from keys to data samples should subclass it. All subclasses should override __getitem__(), which fetches a data sample for a given key. Subclasses may also override __len__(), which many Sampler implementations and the default options of DataLoader expect to return the size of the dataset.

Note

DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.

__len__()
__getitem__(idx)
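The constructor arguments (er_vocab, num_e) suggest 1-to-N training: each distinct (head, relation) pair becomes one sample whose target is a multi-hot vector over all num_e entities. The sketch below assumes `er_vocab` maps (head_idx, relation_idx) to the list of tail indices observed with it, which is presumably what Reproduce.get_er_vocab builds. `build_er_vocab` and `HeadRelationDatasetSketch` are hypothetical names.

```python
from collections import defaultdict

import torch
from torch.utils.data import Dataset


def build_er_vocab(triples):
    # Hypothetical helper: map each (head, relation) pair to every tail seen with it.
    er_vocab = defaultdict(list)
    for h, r, t in triples:
        er_vocab[(h, r)].append(t)
    return er_vocab


class HeadRelationDatasetSketch(Dataset):
    """Hypothetical 1-to-N dataset: one sample per distinct (head, relation) pair."""

    def __init__(self, er_vocab, num_e):
        self.num_e = num_e
        keys = list(er_vocab.keys())
        self.head_idx = torch.tensor([h for h, _ in keys], dtype=torch.long)
        self.rel_idx = torch.tensor([r for _, r in keys], dtype=torch.long)
        self.tails = [er_vocab[k] for k in keys]

    def __len__(self):
        return len(self.tails)

    def __getitem__(self, idx):
        # Multi-hot target: 1.0 at every known tail of this (head, relation) pair.
        y = torch.zeros(self.num_e)
        y[self.tails[idx]] = 1.0
        return self.head_idx[idx], self.rel_idx[idx], y


er_vocab = build_er_vocab([(0, 0, 1), (0, 0, 2), (1, 1, 0)])
ds = HeadRelationDatasetSketch(er_vocab, num_e=3)
h, r, y = ds[0]  # pair (0, 0) with known tails {1, 2}
```

Scoring all entities at once per (head, relation) pair trades one larger output per sample for far fewer samples than one-triple-at-a-time training.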
class helper_classes.Reproduce
static get_er_vocab(data)
static get_head_tail_vocab(data)
get_data_idxs(data)
get_batch_1_to_N(er_vocab, er_vocab_pairs, idx)
reproduce(model_path, data_path, model_name, per_rel_flag_=False, tail_pred_constraint=False, out_of_vocab_flag=False)
get_embeddings(model_path, data_path, model_name, per_rel_flag_=False, tail_pred_constraint=False, out_of_vocab_flag=False)
load_model(model_path, model_name)
reproduce_ensemble(model, data_path, per_rel_flag_=False, tail_pred_constraint=False, out_of_vocab_flag=False)

per_rel_flag_: reports link prediction results per relation.
flag_of_removal: at test time, removes from the test split every triple containing an entity that did not occur during training.
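The two behaviours described above can be sketched independently of any model. Both helpers are hypothetical, and mean reciprocal rank (MRR) is used only as an illustrative per-relation metric; the text does not say which metrics the module reports.

```python
from collections import defaultdict


def remove_unseen_entity_triples(train, test):
    # Hypothetical helper: keep only test triples whose head and tail both occur in training.
    seen = {h for h, _, _ in train} | {t for _, _, t in train}
    return [(h, r, t) for h, r, t in test if h in seen and t in seen]


def mrr_per_relation(test, ranks):
    # Hypothetical helper: group ranks (1 = best) by relation and average 1/rank per relation.
    reciprocal = defaultdict(list)
    for (_, r, _), rank in zip(test, ranks):
        reciprocal[r].append(1.0 / rank)
    return {r: sum(v) / len(v) for r, v in reciprocal.items()}


train = [("a", "likes", "b"), ("b", "knows", "c")]
test = [("a", "likes", "c"), ("a", "knows", "d")]
kept = remove_unseen_entity_triples(train, test)  # drops the triple with unseen entity "d"
per_rel = mrr_per_relation(
    [("a", "likes", "c"), ("b", "knows", "a"), ("c", "likes", "a")], [1, 2, 4]
)
```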

lp_based_on_head_and_tail_entity_rankings: computes the rank of the missing entities based on both the head and the tail entity.
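A minimal sketch of ranking missing entities in both directions, assuming the model produces one score per candidate entity (higher is better) and that the common "filtered" setting is used, where other known-true answers are ignored. `filtered_rank` and `head_and_tail_rank` are hypothetical helpers; averaging the two reciprocal ranks is one common way to combine them, not necessarily what the module does.

```python
def filtered_rank(scores, target, known_true):
    # scores[i] is the model score for candidate entity i (higher is better).
    # Other known-true answers are skipped (filtered setting) so they cannot push the target down.
    # Only strictly higher scores count, so ties are resolved in the target's favour.
    target_score = scores[target]
    better = sum(
        1
        for i, s in enumerate(scores)
        if i != target and i not in known_true and s > target_score
    )
    return better + 1


def head_and_tail_rank(tail_scores, head_scores, triple, known_tails, known_heads):
    # Rank the missing tail for (h, r, ?) and the missing head for (?, r, t).
    h, _, t = triple
    return filtered_rank(tail_scores, t, known_tails), filtered_rank(head_scores, h, known_heads)


tail_rank, head_rank = head_and_tail_rank(
    tail_scores=[0.1, 0.9, 0.8, 0.7],  # entity 1 is another known tail, so it is filtered
    head_scores=[0.5, 0.2, 0.6, 0.3],  # entity 2 outranks the true head 0
    triple=(0, 0, 2),
    known_tails={1, 2},
    known_heads={0},
)
mrr = (1.0 / tail_rank + 1.0 / head_rank) / 2
```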