Source code for quaterion.loss.online_contrastive_loss
from typing import Optional

import torch
import torch.nn.functional as F
from torch import LongTensor, Tensor

from quaterion.distances import Distance
from quaterion.loss.group_loss import GroupLoss
from quaterion.utils import get_anchor_positive_mask, max_value_of_dtype


class OnlineContrastiveLoss(GroupLoss):
    """Implements Contrastive Loss as defined in
    http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    Unlike :class:`~quaterion.loss.contrastive_loss.ContrastiveLoss`, this one supports
    online pair mining, i.e., it forms positive and negative pairs on-the-fly, so you
    don't need to form such pairs yourself. Instead, it first enumerates all possible
    pairs in a batch, and then filters valid positive pairs and valid negative pairs
    separately. Batch-all and batch-hard strategies for online pair mining are supported.

    Args:
        margin: Margin value to push negative examples apart.
            Optional, defaults to `0.5`.
        distance_metric_name: Name of the distance function, e.g.,
            :class:`~quaterion.distances.Distance`. Optional, defaults to
            :attr:`~quaterion.distances.Distance.COSINE`.
        mining (str, optional): Pair mining strategy. One of `"all"`, `"hard"`.
            Defaults to `"hard"`.
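
    Example:

        A minimal usage sketch; the random inputs below are illustrative
        assumptions, not part of this module::

            import torch

            loss_fn = OnlineContrastiveLoss(margin=0.5, mining="hard")
            embeddings = torch.randn(8, 128)  # batch of 8 embeddings
            groups = torch.LongTensor([0, 0, 1, 1, 2, 2, 3, 3])  # 4 groups of 2
            loss = loss_fn(embeddings, groups)  # scalar loss tensor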
"""

    def __init__(
        self,
        margin: Optional[float] = 0.5,
        distance_metric_name: Distance = Distance.COSINE,
        mining: Optional[str] = "hard",
    ):
        mining_types = ["all", "hard"]
        if mining not in mining_types:
            raise ValueError(
                f"Unrecognized mining strategy: {mining}. "
                f"Must be one of {', '.join(mining_types)}"
            )
        super().__init__(distance_metric_name=distance_metric_name)
        self._margin = margin
        self._mining = mining

    def get_config_dict(self):
        config = super().get_config_dict()
        config.update(
            {
                "margin": self._margin,
                "mining": self._mining,
            }
        )
        return config

    def forward(
        self,
        embeddings: Tensor,
        groups: LongTensor,
    ) -> Tensor:
        """Calculates Contrastive Loss by making pairs on-the-fly.

        Args:
            embeddings: Shape: (batch_size, vector_length) - Batch of embeddings.
            groups: Shape: (batch_size,) - Batch of labels associated with `embeddings`.

        Returns:
            torch.Tensor: Scalar loss value.
        """
        # Shape: (batch_size, batch_size)
        dists = self.distance_metric.distance_matrix(embeddings)
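        # dists[i][j] is the distance between embeddings[i] and embeddings[j]
        # under the configured metric; smaller values mean more similar vectors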

        # get a mask for valid anchor-positive pairs and apply it to the distance
        # matrix to set invalid ones to 0
        anchor_positive_mask = get_anchor_positive_mask(groups, groups)
        anchor_positive_dists = (
            anchor_positive_mask.float() * dists
        )  # invalid pairs set to 0

        # get a mask for valid anchor-negative pairs, and apply it to the distance
        # matrix to set invalid ones to the maximum value of the dtype
        anchor_negative_mask = ~anchor_positive_mask
        anchor_negative_dists = dists
        anchor_negative_dists[~anchor_negative_mask] = max_value_of_dtype(
            anchor_negative_dists.dtype
        )
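        # setting invalid entries to the dtype maximum keeps them out of the
        # row-wise min in hard mining and makes their hinge term
        # relu(margin - d) vanish in batch-all mining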

        if self._mining == "all":
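            # batch-all mining: every valid pair contributes to the loss;
            # positives are pulled together by their distance, negatives are
            # pushed apart by the hinge relu(margin - distance)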
            num_positive_pairs = anchor_positive_mask.sum()
            positive_loss = anchor_positive_dists.sum() / torch.max(
                num_positive_pairs, torch.tensor(1e-16)
            )

            num_negative_pairs = anchor_negative_mask.float().sum()
            negative_loss = F.relu(
                self._margin - anchor_negative_dists
            ).sum() / torch.max(num_negative_pairs, torch.tensor(1e-16))

        else:  # batch-hard pair mining
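            # batch-hard mining: for each anchor, keep only the most informative
            # pair - the farthest positive and the closest negative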
            # get the hardest positive for each anchor
            # shape: (batch_size,)
            hardest_positive_dists = anchor_positive_dists.max(dim=1)[0]
            num_positive_pairs = torch.count_nonzero(hardest_positive_dists)
            positive_loss = hardest_positive_dists.sum() / torch.max(
                num_positive_pairs, torch.tensor(1e-16)
            )

            # get the hardest negative for each anchor
            # shape: (batch_size,)
            hardest_negative_dists = anchor_negative_dists.min(dim=1)[0]
            num_negative_pairs = torch.sum(
                (
                    # True where the entry was not set to the dtype maximum,
                    # i.e., the anchor has at least one valid negative
                    hardest_negative_dists
                    < max_value_of_dtype(hardest_negative_dists.dtype)
                ).float()
            )
            negative_loss = F.relu(
                self._margin - hardest_negative_dists
            ).sum() / torch.max(num_negative_pairs, torch.tensor(1e-16))
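
        # weight the positive and negative terms equally, mirroring the 1/2
        # factors in the original contrastive loss formulation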
        total_loss = 0.5 * (positive_loss + negative_loss)

        return total_loss
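
The sketch below compares the two mining strategies on the same batch. It is an
illustrative assumption, not part of the module; it assumes OnlineContrastiveLoss
is importable from quaterion.loss:

import torch
from quaterion.loss import OnlineContrastiveLoss

embeddings = torch.randn(8, 128)
groups = torch.LongTensor([0, 0, 1, 1, 2, 2, 3, 3])

# batch-all averages over every valid pair, while batch-hard keeps only the
# farthest positive and the closest negative per anchor
loss_all = OnlineContrastiveLoss(mining="all")(embeddings, groups)
loss_hard = OnlineContrastiveLoss(mining="hard")(embeddings, groups)
print(loss_all.item(), loss_hard.item())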