# Source code for quaterion.loss.softmax_loss

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import LongTensor, Tensor

from quaterion.loss.group_loss import GroupLoss


class SoftmaxLoss(GroupLoss):
    """Regular cross-entropy loss.

    An implementation of softmax with dot product.
    It is designed to work with the base :class:`~quaterion.loss.group_loss.GroupLoss`.

    Logits are computed as ``embeddings @ kernel / temperature``, where ``kernel`` is
    a learnable ``(embedding_size, num_groups)`` matrix. The loss is the
    cross-entropy between these logits and the group ids.

    Args:
        embedding_size: Output dimension of the encoder.
        num_groups: Number of groups in the dataset.
        temperature: Temperature value to divide logits, defaults to 0.05

    Raises:
        ValueError: if ``temperature`` is not strictly positive.
    """

    def __init__(self, embedding_size: int, num_groups: int, temperature: float = 0.05):
        # Use the normal MRO. ``super(GroupLoss, self)`` would start the lookup
        # *after* GroupLoss and silently skip any ``GroupLoss.__init__``.
        super().__init__()
        if temperature <= 0:
            # Logits are divided by the temperature: 0 produces inf/NaN losses and a
            # negative value would invert the logits.
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.temperature = temperature
        # One learnable column per group. ``torch.empty`` replaces the legacy
        # ``torch.FloatTensor(rows, cols)`` constructor; the contents are
        # uninitialized until the normal init right below.
        self.kernel = nn.Parameter(torch.empty(embedding_size, num_groups))
        nn.init.normal_(self.kernel, std=0.01)

    def forward(
        self,
        embeddings: Tensor,
        groups: LongTensor,
    ) -> Tensor:
        """Compute loss value.

        Args:
            embeddings: shape: (batch_size, vector_length) - Output embeddings from the
                encoder
            groups: shape: (batch_size,) - Group ids, associated with embeddings

        Returns:
            Tensor: zero-dimensional (scalar) tensor, loss value averaged over the batch
            (``F.cross_entropy`` default ``reduction='mean'``)
        """
        # shape: (batch_size, num_groups)
        logits = torch.mm(embeddings, self.kernel) / self.temperature
        return F.cross_entropy(logits, groups)