quaterion.dataset.similarity_data_loader module¶
- class GroupSimilarityDataLoader(dataset: Dataset[SimilarityGroupSample], **kwargs)[source]¶
Bases: SimilarityDataLoader[SimilarityGroupSample]
DataLoader designed to work with data represented as SimilarityGroupSample.
- classmethod collate_labels(batch: List[SimilarityGroupSample]) → Dict[str, Tensor][source]¶
Collate function for labels.
Convert labels into tensors suitable for passing directly into loss functions and metric estimators.
- Parameters:
batch – List of SimilarityGroupSample
- Returns:
Collated labels –
groups – id of the group for each feature object
Examples
>>> GroupSimilarityDataLoader.collate_labels(
...     [
...         SimilarityGroupSample(obj="orange", group=0),
...         SimilarityGroupSample(obj="lemon", group=0),
...         SimilarityGroupSample(obj="apple", group=1)
...     ]
... )
{'groups': tensor([0, 0, 1])}
- classmethod flatten_objects(batch: List[SimilarityGroupSample], hash_ids: List[int]) Tuple[List[Any], List[int]] [source]¶
Retrieve and enumerate objects from similarity samples.
Each individual object should be used as input for the encoder. Additionally, associates a hash_id with each feature; if a sample contains more than one feature, new unique ids are generated based on the input one.
- Parameters:
batch – List of similarity samples
hash_ids – pseudo-random ids of the similarity samples
- Returns:
List of input features for encoder collate
List of ids, associated with each feature
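The flattening described above can be sketched in plain Python. This is a minimal stand-in under stated assumptions, not the library's actual implementation; the function name flatten_group_objects and the dataclass definition are hypothetical re-creations of the documented interface. Since each group sample carries exactly one object, the ids pass through unchanged:

```python
from dataclasses import dataclass
from typing import Any, List, Tuple

@dataclass
class SimilarityGroupSample:
    obj: Any
    group: int

def flatten_group_objects(
    batch: List[SimilarityGroupSample], hash_ids: List[int]
) -> Tuple[List[Any], List[int]]:
    # Each group sample carries a single object, so the feature list is
    # simply the objects, and the input ids pass through unchanged.
    features = [sample.obj for sample in batch]
    return features, list(hash_ids)

batch = [
    SimilarityGroupSample(obj="orange", group=0),
    SimilarityGroupSample(obj="lemon", group=0),
    SimilarityGroupSample(obj="apple", group=1),
]
features, ids = flatten_group_objects(batch, [101, 102, 103])
print(features)  # ['orange', 'lemon', 'apple']
print(ids)       # [101, 102, 103]
```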
- batch_size: int | None¶
- drop_last: bool¶
- num_workers: int¶
- pin_memory: bool¶
- pin_memory_device: str¶
- prefetch_factor: int | None¶
- timeout: float¶
- class PairsSimilarityDataLoader(dataset: Dataset[SimilarityPairSample], **kwargs)[source]¶
Bases: SimilarityDataLoader[SimilarityPairSample]
DataLoader designed to work with data represented as SimilarityPairSample.
- classmethod collate_labels(batch: List[SimilarityPairSample]) → Dict[str, Tensor][source]¶
Collate function for labels of SimilarityPairSample.
Convert labels into tensors suitable for passing directly into loss functions and metric estimators.
- Parameters:
batch – List of SimilarityPairSample
- Returns:
Collated labels –
labels – tensor of scores for each input pair
pairs – pairs of id offsets of the features, associated with the respective labels
subgroups – subgroup id for each feature
Examples
>>> labels_batch = PairsSimilarityDataLoader.collate_labels(
...     [
...         SimilarityPairSample(
...             obj_a="1st_pair_1st_obj", obj_b="1st_pair_2nd_obj", score=1.0, subgroup=0
...         ),
...         SimilarityPairSample(
...             obj_a="2nd_pair_1st_obj", obj_b="2nd_pair_2nd_obj", score=0.0, subgroup=1
...         ),
...     ]
... )
>>> labels_batch['labels']
tensor([1., 0.])
>>> labels_batch['pairs']
tensor([[0, 2],
        [1, 3]])
>>> labels_batch['subgroups']
tensor([0., 1., 0., 1.])
- classmethod flatten_objects(batch: List[SimilarityPairSample], hash_ids: List[int]) Tuple[List[Any], List[int]] [source]¶
Retrieve and enumerate objects from similarity samples.
Each individual object should be used as input for the encoder. Additionally, associates a hash_id with each feature; if a sample contains more than one feature, new unique ids are generated based on the input one.
- Parameters:
batch – List of similarity samples
hash_ids – pseudo-random ids of the similarity samples
- Returns:
List of input features for encoder collate
List of ids, associated with each feature
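For pair samples, each sample contributes two features. The pair offsets shown in the collate_labels example above ([[0, 2], [1, 3]] for two samples) imply an ordering of all obj_a features first, then all obj_b. The sketch below reproduces that ordering; how the second unique id is actually derived from the input hash id is an internal detail, so the derivation here is a hypothetical stand-in, as are the function name and dataclass:

```python
from dataclasses import dataclass
from typing import Any, List, Tuple

@dataclass
class SimilarityPairSample:
    obj_a: Any
    obj_b: Any
    score: float
    subgroup: int = 0

def flatten_pair_objects(
    batch: List[SimilarityPairSample], hash_ids: List[int]
) -> Tuple[List[Any], List[int]]:
    # All obj_a features first, then all obj_b, so pair i maps to
    # offsets (i, len(batch) + i) - matching [[0, 2], [1, 3]] above.
    features = [s.obj_a for s in batch] + [s.obj_b for s in batch]
    # A second unique id is derived from the input one for each obj_b;
    # the formula below is a hypothetical stand-in, not the library's.
    ids = list(hash_ids) + [h * 2 + 1 for h in hash_ids]
    return features, ids

batch = [
    SimilarityPairSample(obj_a="a1", obj_b="b1", score=1.0, subgroup=0),
    SimilarityPairSample(obj_a="a2", obj_b="b2", score=0.0, subgroup=1),
]
features, ids = flatten_pair_objects(batch, [10, 20])
print(features)  # ['a1', 'a2', 'b1', 'b2']
```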
- batch_size: int | None¶
- drop_last: bool¶
- num_workers: int¶
- pin_memory: bool¶
- pin_memory_device: str¶
- prefetch_factor: int | None¶
- timeout: float¶
- class SimilarityDataLoader(dataset: Dataset, **kwargs)[source]¶
Bases: DataLoader, Generic[T_co]
Special version of DataLoader which works with similarity samples.
SimilarityDataLoader automatically assigns a dummy collate_fn for debug purposes; it will be overwritten once the dataloader is used for training.
The required collate function should be defined individually for each encoder by overriding get_collate_fn()
- Parameters:
dataset – Dataset which outputs similarity samples
**kwargs – Parameters passed directly into __init__()
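The collate_fn lifecycle described above (dummy collate for debugging, replaced once training starts) can be illustrated with a plain-Python stand-in. The class and method names here are hypothetical, chosen only to mirror the documented behavior, and are not part of the library's API:

```python
from typing import Any, Callable, List

class SketchSimilarityLoader:
    """Plain-Python stand-in illustrating the collate_fn swap."""

    def __init__(self, dataset: List[Any]):
        self.dataset = dataset
        # A dummy collate_fn is assigned up front for debug purposes...
        self.collate_fn: Callable[[List[Any]], Any] = lambda batch: batch

    def attach_encoder_collate(self, encoder_collate: Callable) -> None:
        # ...and is overwritten once the loader is used for training.
        self.collate_fn = encoder_collate

loader = SketchSimilarityLoader(["orange", "apple"])
print(loader.collate_fn(loader.dataset))  # ['orange', 'apple'] (dummy pass-through)

loader.attach_encoder_collate(lambda batch: [len(x) for x in batch])
print(loader.collate_fn(loader.dataset))  # [6, 5] (encoder-specific collate)
```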
- classmethod collate_labels(batch: List[T_co]) Dict[str, Tensor] [source]¶
Collate function for labels.
Convert labels into tensors suitable for passing directly into loss functions and metric estimators.
- Parameters:
batch – List of similarity samples
- Returns:
Collated labels
- classmethod flatten_objects(batch: List[T_co], hash_ids: List[int]) Tuple[List[Any], List[int]] [source]¶
Retrieve and enumerate objects from similarity samples.
Each individual object should be used as input for the encoder. Additionally, associates a hash_id with each feature; if a sample contains more than one feature, new unique ids are generated based on the input one.
- Parameters:
batch – List of similarity samples
hash_ids – pseudo-random ids of the similarity samples
- Returns:
List of input features for encoder collate
List of ids, associated with each feature
- classmethod pre_collate_fn(batch: List[T_co])[source]¶
Function applied to a batch before the actual collate.
Splits the batch into features (prediction arguments) and labels (targets). The encoder-specific collate_fn will then be applied to the feature list only. Loss functions consume labels from this function without any additional transformations.
- Parameters:
batch – List of similarity samples
- Returns:
ids of the features
features batch
labels batch
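The three-way split performed by this pre-collate step can be sketched for group samples. Sequential ids stand in for the loader's pseudo-random hash ids, and the function name sketch_pre_collate is hypothetical:

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

@dataclass
class SimilarityGroupSample:
    obj: Any
    group: int

def sketch_pre_collate(
    batch: List[SimilarityGroupSample],
) -> Tuple[List[int], List[Any], Dict[str, List[int]]]:
    # Sequential ids stand in for the loader's pseudo-random hash ids.
    ids = list(range(len(batch)))
    # Features go on to the encoder-specific collate_fn...
    features = [sample.obj for sample in batch]
    # ...while labels go straight to the loss function, untransformed.
    labels = {"groups": [sample.group for sample in batch]}
    return ids, features, labels

ids, features, labels = sketch_pre_collate([
    SimilarityGroupSample(obj="orange", group=0),
    SimilarityGroupSample(obj="apple", group=1),
])
print(labels)  # {'groups': [0, 1]}
```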
- set_label_cache_mode(mode: LabelCacheMode)[source]¶
Manages how label caching works
- set_salt(salt)[source]¶
Assigns a new salt to the IndexingDataset. Might be useful to distinguish sequential cache keys for train and validation datasets.
- Parameters:
salt – salt for index generation
- set_skip_read(skip: bool)[source]¶
Disable reading items in the IndexingDataset. If the cache is already filled and a sequential key is used, it is not necessary to read dataset items a second time.
- Parameters:
skip – if True, do not read items, only indexes
- batch_size: int | None¶
- drop_last: bool¶
- property full_cache_used¶
- num_workers: int¶
- property original_params: Dict[str, Any]¶
Initialization params of the original dataset.
- pin_memory: bool¶
- pin_memory_device: str¶
- prefetch_factor: int | None¶
- timeout: float¶