mindspore.nn.CosineSimilarity
- class mindspore.nn.CosineSimilarity(similarity='cosine', reduction='none', zero_diagonal=True)[源代码]
计算余弦相似度。
- 参数:
similarity (str) - “dot”或”cosine”。”cosine”表示相似度计算逻辑, “dot”表示矩阵点乘矩阵计算逻辑。默认值:”cosine”。
reduction (str) - “none”、”sum”或”mean”。默认值:”none”。
zero_diagonal (bool) - 如果为True,则对角线将设置为零。默认值:True。
- 支持平台:
Ascend
GPU
CPU
样例:
>>> import numpy as np >>> from mindspore import nn >>> >>> test_data = np.array([[1, 3, 4, 7], [2, 4, 2, 5], [3, 1, 5, 8]]) >>> metric = nn.CosineSimilarity() >>> metric.clear() >>> metric.update(test_data) >>> square_matrix = metric.eval() >>> print(square_matrix) [[0. 0.94025615 0.95162452] [0.94025615 0. 0.86146098] [0.95162452 0.86146098 0.]]