Source code for quaterion.dataset.similarity_data_loader
from typing import Any, Dict, Generic, List, Tuple, Union
import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torch.utils.data.dataloader import T_co
from quaterion.dataset.indexing_dataset import IndexingDataset, IndexingIterableDataset
from quaterion.dataset.label_cache_dataset import (
LabelCacheDataset,
LabelCacheIterableDataset,
LabelCacheMode,
)
from quaterion.dataset.similarity_samples import (
SimilarityGroupSample,
SimilarityPairSample,
)
class SimilarityDataLoader(DataLoader, Generic[T_co]):
"""Special version of :class:`~torch.utils.data.DataLoader` which works with similarity samples.
SimilarityDataLoader will automatically assign dummy collate_fn for debug purposes,
it will be overwritten once dataloader is used for training.
Required collate function should be defined individually for each encoder
by overwriting :meth:`~quaterion_models.encoders.encoder.Encoder.get_collate_fn`
Args:
dataset: Dataset which outputs similarity samples
**kwargs: Parameters passed directly into :meth:`~torch.utils.data.DataLoader.__init__`
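
    Example:
        A minimal debugging sketch (illustrative only; ``dataset`` is assumed to be any
        :class:`~torch.utils.data.Dataset` yielding similarity samples, and iterating the
        loader before training applies only the dummy :meth:`pre_collate_fn`)::

            loader = GroupSimilarityDataLoader(dataset, batch_size=4)
            ids, features, labels = next(iter(loader))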
"""
def __init__(self, dataset: Dataset, **kwargs):
if "collate_fn" not in kwargs:
kwargs["collate_fn"] = self.__class__.pre_collate_fn
self._original_dataset = dataset
self._original_params = kwargs
self._indexing_dataset_layer = self._wrap_indexing_dataset(dataset)
self._label_cache_layer = self._wrap_label_cache_dataset(
self._indexing_dataset_layer
)
super().__init__(self._label_cache_layer, **kwargs)
@property
def full_cache_used(self):
return self._label_cache_layer.mode != LabelCacheMode.transparent
    def set_salt(self, salt):
"""Assigns a new salt to the IndexingDataset.
Might be useful to distinguish cache sequential keys for train and validation datasets.
Args:
salt: salt for index generation
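
        Example:
            An illustrative sketch; ``train_loader`` and ``valid_loader`` are assumed to be
            :class:`SimilarityDataLoader` instances built over separate datasets::

                train_loader.set_salt("train")
                valid_loader.set_salt("valid")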
"""
self._indexing_dataset_layer.set_salt(salt)
    def set_skip_read(self, skip: bool):
"""Disable reading items in IndexingDataset.
If cache is already filled and sequential key is used -
it is not necessary to read dataset items the second time
Args:
            skip: if True, do not read items, only indexes
"""
self._indexing_dataset_layer.set_skip_read(skip)
    def set_label_cache_mode(self, mode: LabelCacheMode):
"""Manges how label caching works"""
self._label_cache_layer.set_mode(mode)
    def save_label_cache(self, path: str):
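        """Saves the label cache to the given path."""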
self._label_cache_layer.save(path)
    def load_label_cache(self, path: str):
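        """Loads the label cache from the given path."""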
self._label_cache_layer.load(path)
@classmethod
def _wrap_label_cache_dataset(
cls, dataset: Union[IndexingIterableDataset, IndexingDataset]
) -> Union[LabelCacheDataset, LabelCacheIterableDataset]:
if isinstance(dataset, IndexingDataset):
return LabelCacheDataset(dataset)
if isinstance(dataset, IndexingIterableDataset):
return LabelCacheIterableDataset(dataset)
raise NotImplementedError()
@classmethod
def _wrap_indexing_dataset(
cls, dataset: Dataset
) -> Union[IndexingIterableDataset, IndexingDataset]:
if isinstance(dataset, IterableDataset):
return IndexingIterableDataset(dataset)
else:
return IndexingDataset(dataset)
@property
def original_params(self) -> Dict[str, Any]:
"""Initialization params of the original dataset."""
return self._original_params
    @classmethod
def pre_collate_fn(cls, batch: List[T_co]):
"""Function applied to batch before actual collate.
Splits batch into features - arguments of prediction and labels - targets.
Encoder-specific `collate_fn` will then be applied to feature list only.
Loss functions consumes labels from this function without any additional transformations.
Args:
batch: List of similarity samples
Returns:
- ids of the features
- features batch
- labels batch
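
        Examples:
            An illustrative sketch using the group loader; the integer ids are arbitrary
            stand-ins for the hashes assigned by the indexing layer.

            >>> GroupSimilarityDataLoader.pre_collate_fn(
            ...     [
            ...         (101, SimilarityGroupSample(obj="orange", group=0)),
            ...         (102, SimilarityGroupSample(obj="apple", group=1)),
            ...     ]
            ... )
            ([101, 102], ['orange', 'apple'], {'groups': tensor([0, 1])})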
"""
sample_ids, similarity_samples = list(zip(*batch))
sample_ids = list(sample_ids)
labels = cls.collate_labels(similarity_samples)
features, feature_ids = cls.flatten_objects(
batch=similarity_samples, hash_ids=sample_ids
)
return feature_ids, features, labels
    @classmethod
def collate_labels(cls, batch: List[T_co]) -> Dict[str, torch.Tensor]:
"""Collate function for labels
Convert labels into tensors, suitable for loss passing directly into loss functions and
metric estimators.
Args:
batch: List of similarity samples
Returns:
Collated labels
"""
raise NotImplementedError()
    @classmethod
def flatten_objects(
cls, batch: List[T_co], hash_ids: List[int]
) -> Tuple[List[Any], List[int]]:
"""Retrieve and enumerate objects from similarity samples.
Each individual object should be used as input for the encoder.
Additionally, associates hash_id with each feature, if there are more than one feature in
the sample - generates new unique ids based on input one.
Args:
batch: List of similarity samples
hash_ids: pseudo-random ids of the similarity samples
Returns:
- List of input features for encoder collate
- List of ids, associated with each feature
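
        Examples:
            An illustrative sketch for the group case, where each sample carries exactly one
            object and the input ids are returned unchanged.

            >>> GroupSimilarityDataLoader.flatten_objects(
            ...     batch=[SimilarityGroupSample(obj="orange", group=0)],
            ...     hash_ids=[7],
            ... )
            (['orange'], [7])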
"""
raise NotImplementedError()
class GroupSimilarityDataLoader(SimilarityDataLoader[SimilarityGroupSample]):
"""DataLoader designed to work with data represented as
:class:`~quaterion.dataset.similarity_samples.SimilarityGroupSample`.
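
    Example:
        An illustrative sketch; ``dataset`` is assumed to be any map-style
        :class:`~torch.utils.data.Dataset` yielding
        :class:`~quaterion.dataset.similarity_samples.SimilarityGroupSample`::

            dataloader = GroupSimilarityDataLoader(dataset, batch_size=128, shuffle=True)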
"""
def __init__(self, dataset: Dataset[SimilarityGroupSample], **kwargs):
super().__init__(dataset, **kwargs)
    @classmethod
def collate_labels(
cls, batch: List[SimilarityGroupSample]
) -> Dict[str, torch.Tensor]:
"""Collate function for labels
Convert labels into tensors, suitable for loss passing directly into loss functions and
metric estimators.
Args:
batch: List of :class:`~quaterion.dataset.similarity_samples.SimilarityGroupSample`
Returns:
Collated labels:
- groups -- id of the group for each feature object
Examples:
>>> GroupSimilarityDataLoader.collate_labels(
... [
... SimilarityGroupSample(obj="orange", group=0),
... SimilarityGroupSample(obj="lemon", group=0),
... SimilarityGroupSample(obj="apple", group=1)
... ]
... )
{'groups': tensor([0, 0, 1])}
"""
labels = {"groups": torch.LongTensor([record.group for record in batch])}
return labels
    @classmethod
def flatten_objects(
cls, batch: List[SimilarityGroupSample], hash_ids: List[int]
) -> Tuple[List[Any], List[int]]:
return [sample.obj for sample in batch], hash_ids