quaterion.loss.pairwise_loss module
- class PairwiseLoss(distance_metric_name: Distance = Distance.COSINE)
Bases: SimilarityLoss
Base class for pairwise losses.
- Parameters:
distance_metric_name – Name of the distance function, given as a member of Distance (default: Distance.COSINE).
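The constructor's only argument selects the distance function that a concrete pairwise loss applies to the two members of each pair. A minimal NumPy sketch of that pattern follows; the Distance enum here is a hypothetical two-member stand-in (only the name COSINE comes from the signature above), and pairwise_distance and PairwiseLossSketch are illustrative names, not part of the library.

```python
from enum import Enum

import numpy as np


class Distance(str, Enum):
    """Hypothetical stand-in for the library's Distance enum (two members only)."""

    COSINE = "cosine"
    EUCLIDEAN = "euclidean"


def pairwise_distance(x: np.ndarray, y: np.ndarray, metric: Distance) -> np.ndarray:
    """Distance between x[i] and y[i] for every row i; x and y have shape (n, dim)."""
    metric = Distance(metric)  # also accepts the plain string value, e.g. "cosine"
    if metric is Distance.COSINE:
        # Cosine distance = 1 - cosine similarity of the L2-normalized rows.
        x_unit = x / np.linalg.norm(x, axis=1, keepdims=True)
        y_unit = y / np.linalg.norm(y, axis=1, keepdims=True)
        return 1.0 - np.sum(x_unit * y_unit, axis=1)
    # Remaining member: Euclidean (L2) distance.
    return np.linalg.norm(x - y, axis=1)


class PairwiseLossSketch:
    """Hypothetical base class that keeps the metric chosen in the constructor."""

    def __init__(self, distance_metric_name: Distance = Distance.COSINE):
        self.distance_metric_name = Distance(distance_metric_name)

    def distances(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        return pairwise_distance(x, y, self.distance_metric_name)


x = np.array([[1.0, 0.0], [0.0, 1.0]])
y = np.array([[1.0, 0.0], [1.0, 0.0]])
print(PairwiseLossSketch().distances(x, y))  # identical rows -> 0, orthogonal rows -> 1
print(PairwiseLossSketch(Distance.EUCLIDEAN).distances(x, y))
```

Storing the metric once in the constructor keeps forward free of configuration: every subclass can ask for distances without re-deciding which function to use.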
- forward(embeddings: Tensor, pairs: Tensor, labels: Tensor, subgroups: Tensor) → Tensor
Compute the loss value.
- Parameters:
embeddings – embeddings of the objects in the batch; shape: (batch_size, vector_length)
pairs – indices of the objects forming known similarity pairs in the batch; shape: (2 * pairs_count,)
labels – similarity of each pair; shape: (pairs_count,)
subgroups – subgroup ids of the objects; shape: (2 * pairs_count,)
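The four arguments describe one batch together: pairs indexes rows of embeddings, labels holds one value per pair, and subgroups holds one id per element of pairs. The NumPy sketch below checks those shapes and gathers both members of every pair. The helper name gather_pairs is hypothetical, and the flat layout of pairs is an assumption of this sketch: consecutive elements [a0, b0, a1, b1, ...] are taken to form one pair.

```python
import numpy as np


def gather_pairs(embeddings, pairs, labels, subgroups):
    """Return the embedding rows of the first and second member of every pair.

    Assumption of this sketch: consecutive elements of the flat `pairs`
    vector form one pair, i.e. [a0, b0, a1, b1, ...].
    """
    if pairs.ndim != 1 or pairs.shape[0] % 2 != 0:
        raise ValueError("pairs must have shape (2 * pairs_count,)")
    pairs_count = pairs.shape[0] // 2
    if labels.shape != (pairs_count,):
        raise ValueError("labels must have shape (pairs_count,)")
    if subgroups.shape != pairs.shape:
        raise ValueError("subgroups must have shape (2 * pairs_count,)")
    index_pairs = pairs.reshape(pairs_count, 2)  # shape: (pairs_count, 2)
    first = embeddings[index_pairs[:, 0]]  # shape: (pairs_count, vector_length)
    second = embeddings[index_pairs[:, 1]]  # shape: (pairs_count, vector_length)
    return first, second


# Toy batch: 4 objects with 3-dimensional embeddings, forming 2 pairs.
embeddings = np.arange(12, dtype=float).reshape(4, 3)  # (batch_size, vector_length)
pairs = np.array([0, 2, 1, 3])  # pairs (0, 2) and (1, 3) under the assumed layout
labels = np.array([1.0, 0.0])  # one similarity value per pair
subgroups = np.array([0, 0, 1, 1])  # one subgroup id per element of pairs
first, second = gather_pairs(embeddings, pairs, labels, subgroups)
print(first.shape, second.shape)  # (2, 3) (2, 3)
```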
- Returns:
Tensor – zero-dimensional (scalar) tensor holding the loss value
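A concrete subclass has to reduce per-pair distances to that single scalar. As an illustration only, the sketch below uses one common form of contrastive loss (similar pairs are pulled together, dissimilar pairs are pushed at least margin apart) with Euclidean distance, written in NumPy instead of PyTorch. The function name, the margin parameter, the 1.0/0.0 label encoding and the consecutive-element layout of pairs are all assumptions; the base class itself does not prescribe a formula.

```python
import numpy as np


def contrastive_loss(
    embeddings: np.ndarray,
    pairs: np.ndarray,
    labels: np.ndarray,
    margin: float = 1.0,
) -> np.ndarray:
    """Hypothetical pairwise loss returning a zero-dimensional array.

    Assumes consecutive elements of `pairs` form a pair and labels are
    1.0 for similar and 0.0 for dissimilar pairs.
    """
    index_pairs = pairs.reshape(-1, 2)  # shape: (pairs_count, 2)
    first = embeddings[index_pairs[:, 0]]
    second = embeddings[index_pairs[:, 1]]
    distances = np.linalg.norm(first - second, axis=1)  # shape: (pairs_count,)
    # Similar pairs are penalized by their squared distance; dissimilar
    # pairs only when they are closer than `margin`.
    similar_term = labels * distances**2
    dissimilar_term = (1.0 - labels) * np.maximum(margin - distances, 0.0) ** 2
    per_pair = similar_term + dissimilar_term
    return np.asarray(per_pair.mean())  # shape: (), a single scalar


embeddings = np.array([[0.0, 0.0], [3.0, 4.0], [0.0, 0.5], [0.0, 0.0]])
pairs = np.array([0, 2, 1, 3])  # pairs (0, 2) and (1, 3) under the assumed layout
labels = np.array([1.0, 0.0])  # first pair similar, second dissimilar
loss = contrastive_loss(embeddings, pairs, labels)
print(float(loss), loss.ndim)  # 0.125 0
```

Here the similar pair sits 0.5 apart and contributes 0.25, while the dissimilar pair is 5.0 apart, already beyond the margin, and contributes nothing; averaging over the two pairs gives 0.125.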
- training: bool
Whether the module is in training mode (inherited from torch.nn.Module).