"""Source code for :mod:`quaterion.loss.group_loss`: base class for group losses."""

from typing import Optional

from torch import LongTensor, Tensor

from quaterion.distances import Distance
from quaterion.loss.similarity_loss import SimilarityLoss


class GroupLoss(SimilarityLoss):
    """Base class for group losses.

    A group loss is computed from a batch of embeddings together with the
    integer group ids associated with them. Subclasses implement
    :meth:`forward` (and optionally :meth:`xbm_loss`); the implementations in
    this base class only raise :class:`NotImplementedError`.

    Args:
        distance_metric_name: Name of the distance function, e.g.,
            :class:`~quaterion.distances.Distance`.
    """

    def __init__(self, distance_metric_name: Distance = Distance.COSINE):
        # Zero-argument super() resolves to the same MRO lookup as
        # super(GroupLoss, self).
        super().__init__(distance_metric_name=distance_metric_name)

    def forward(self, embeddings: Tensor, groups: LongTensor) -> Tensor:
        """Compute the loss for a batch of embeddings and their group ids.

        Must be overridden by subclasses.

        Args:
            embeddings: shape: (batch_size, vector_length)
            groups: shape: (batch_size,) - Groups associated with `embeddings`

        Returns:
            Tensor: 0-dimensional (scalar) tensor, loss value

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def xbm_loss(
        self,
        embeddings: Tensor,
        groups: LongTensor,
        memory_embeddings: Tensor,
        memory_groups: LongTensor,
    ) -> Tensor:
        """Implement XBM loss computation for this loss.

        Subclasses that support XBM override this method; the base
        implementation raises.

        Args:
            embeddings: shape: (batch_size, vector_length) - Output embeddings
                from the encoder.
            groups: shape: (batch_size,) - Group ids associated with
                `embeddings`.
            memory_embeddings: shape: (memory_buffer_size, vector_length) -
                Embeddings stored in a ring buffer.
            memory_groups: shape: (memory_buffer_size,) - Group ids associated
                with `memory_embeddings`.

        Returns:
            Tensor: 0-dimensional (scalar) tensor, XBM loss value.

        Raises:
            NotImplementedError: always, in this base class; the message
                names the concrete subclass.
        """
        raise NotImplementedError(
            f"XBM is not implemented for {self.__class__.__name__}"
        )