Source code for quaterion.eval.group.retrieval_r_precision
import torch
from quaterion.distances import Distance
from quaterion.eval.group import GroupMetric
class RetrievalRPrecision(GroupMetric):
    """Compute the retrieval R-precision score for group-based data.

    Retrieval R-precision is the ratio `r/R`, where `R` is the number of documents relevant
    to a given query in the collection, and `r` is the number of truly relevant documents
    found among the `R` highest-scored results for that query.

    Args:
        distance_metric_name: name of the distance metric used to calculate distance or
            similarity matrices. Available names can be found in
            :class:`~quaterion.distances.Distance`.

    Example:
        Suppose a collection contains 20 documents relevant to our query, and the model
        retrieves 15 of them among the 20 highest-scored results. Then retrieval
        R-precision is calculated as r/R = 15/20 = 0.75.
    """
def __init__(
self,
distance_metric_name: Distance = Distance.COSINE,
):
super().__init__(
distance_metric_name=distance_metric_name,
)
    def raw_compute(
        self, distance_matrix: torch.Tensor, labels: torch.Tensor
    ) -> torch.Tensor:
        """Compute retrieval R-precision from precomputed distance and labels matrices."""
        return retrieval_r_precision(distance_matrix, labels)
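
# A minimal usage sketch, assuming ``distance_matrix`` and ``labels`` tensors have
# already been computed by the caller; ``raw_compute`` simply delegates to
# ``retrieval_r_precision`` defined below:
#
#     metric = RetrievalRPrecision(distance_metric_name=Distance.COSINE)
#     score = metric.raw_compute(distance_matrix, labels)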
def retrieval_r_precision(distance_matrix: torch.Tensor, labels: torch.Tensor):
    """Calculate retrieval R-precision given a distance matrix and labels.

    Args:
        distance_matrix: distance matrix with the maximal possible distance value on its diagonal
        labels: labels matrix with `False` or `0.` on its diagonal

    Returns:
        torch.Tensor: mean retrieval R-precision
    """
    # number of relevant documents (same-group members) for the query at the i-th position
    relevant_numbers = labels.sum(dim=-1).view(labels.shape[0], 1)
    # indices that sort each row of the distance matrix from nearest to furthest
    nearest_to_furthest_ind = torch.argsort(distance_matrix, dim=-1, descending=False)
    # labels reordered by increasing distance to the query
    sorted_by_distance = torch.gather(labels, dim=-1, index=nearest_to_furthest_ind)
    # boolean mask selecting, for each query, its R nearest results
    top_k_mask = (
        torch.arange(
            0,
            labels.shape[1],
            step=1,
            device=distance_matrix.device,
        ).repeat(labels.shape[0], 1)
        < relevant_numbers
    )
    # fraction of truly relevant documents among the R nearest, averaged over all queries
    metric = sorted_by_distance[top_k_mask].float()
    return metric.mean()
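
# A minimal runnable sketch of the function above; the group assignment [0, 0, 1, 1]
# and the distance values are made-up illustration data, not part of the library.
if __name__ == "__main__":
    # labels[i][j] is 1.0 when samples i and j share a group, with 0.0 on the diagonal
    labels = torch.tensor(
        [
            [0.0, 1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
            [0.0, 0.0, 1.0, 0.0],
        ]
    )
    # distance matrix with the maximal distance (2.0 for cosine) on the diagonal;
    # queries 0 and 1 rank their group-mate first, queries 2 and 3 do not
    distance_matrix = torch.tensor(
        [
            [2.0, 0.1, 0.5, 0.9],
            [0.1, 2.0, 0.3, 0.4],
            [0.5, 0.3, 2.0, 0.7],
            [0.9, 0.4, 0.7, 2.0],
        ]
    )
    print(retrieval_r_precision(distance_matrix, labels))  # tensor(0.5000)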