mindformers.core.PromptAccMetric
- class mindformers.core.PromptAccMetric[source]
Computes the prompt accuracy of each entity. Prompt accuracy is the accuracy of text classification based on constructed prompts. The predicted index is the index of the prompt with the minimum perplexity.
The prompts for this metric are built as follows:
This is an article about **sports**: $passage This is an article about **culture**: $passage
Compute the perplexity of each context generated from a prompt. Perplexity measures how well a probability distribution or a model predicts a sample; a low perplexity indicates that the model predicts the sample well. It is defined as follows:
\[PP(W)=P(w_{1}w_{2}...w_{N})^{-\frac{1}{N}}=\sqrt[N]{\frac{1}{P(w_{1}w_{2}...w_{N})}}\]
where \(w\) represents the words in the corpus and \(N\) is the number of words in the sequence \(W\).
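The perplexity formula above can be illustrated with a minimal NumPy sketch. It assumes the joint probability factorizes into per-token probabilities, so the N-th root becomes the exponential of the negative mean log-probability. The helper name `perplexity` is hypothetical and is not part of mindformers.

```python
import numpy as np


def perplexity(token_probs):
    """Perplexity of one sequence from per-token probabilities (hypothetical helper).

    Assumes P(w_1 ... w_N) is the product of the given per-token probabilities.
    """
    token_probs = np.asarray(token_probs, dtype=np.float64)
    # PP(W) = P(w_1 ... w_N)^(-1/N), evaluated in log space to avoid underflow.
    return float(np.exp(-np.mean(np.log(token_probs))))


# Four tokens, each predicted with probability 0.5: the perplexity is about 2.
ppl_uniform = perplexity([0.5, 0.5, 0.5, 0.5])
print(ppl_uniform)
```

Working in log space gives the same value as the formula while staying numerically stable for long sequences, where the raw product of probabilities would underflow.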
Compute the classification result by choosing the index of the prompt with the minimum perplexity.
Count the number of correctly classified samples and the total number of samples, then compute the accuracy as follows:
\[\text{accuracy} =\frac{\text{correct_sample_nums}}{\text{total_sample_nums}}\]
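The last two steps can be sketched as follows. This assumes the perplexities have already been arranged as one row per sample and one column per candidate prompt, and that each sample has a single integer class label. The function `prompt_accuracy` and the sample values are illustrative only and do not reproduce the library's internal implementation.

```python
import numpy as np


def prompt_accuracy(ppl, labels):
    """Accuracy of min-perplexity classification (hypothetical helper).

    ppl:    array of shape [num_samples, num_prompts], perplexity per prompt.
    labels: array of shape [num_samples], index of the correct prompt.
    """
    ppl = np.asarray(ppl, dtype=np.float64)
    labels = np.asarray(labels)
    # The predicted class is the prompt with the lowest perplexity.
    predictions = np.argmin(ppl, axis=1)
    correct_sample_nums = int(np.sum(predictions == labels))
    total_sample_nums = int(labels.shape[0])
    return correct_sample_nums / total_sample_nums


# Made-up perplexities for three samples and two prompts (e.g. sports, culture).
ppl = [[3.1, 7.4], [9.0, 2.5], [4.0, 5.0]]
labels = [0, 1, 1]
acc = prompt_accuracy(ppl, labels)  # predictions are [0, 1, 0], so 2 of 3 are correct
print(acc)
```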
Examples
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindformers.core.metric.metric import PromptAccMetric
>>> logits = Tensor(np.array([[[[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]]]))
>>> input_ids = Tensor(np.array([[15, 16, 17]]))
>>> labels = Tensor(np.array([[1, 0, 1]]))
>>> mask = Tensor(np.array([[1, 1, 1]]))
>>> metric = PromptAccMetric()
>>> metric.clear()
>>> metric.update(logits, input_ids, mask, labels)
>>> result = metric.eval()
>>> print(result)
Current data num is 1, total acc num is 1.0, ACC is 1.000
Acc: 1.000, total_acc_num: 1.0, total_num: 1
{'Acc': 1.0}
- eval()[source]
Computes the evaluation result.
- Returns
A dict of evaluation results containing the Acc score.
- update(*inputs)[source]
Updates the internal evaluation result.
- Parameters
*inputs (List) – Logits, input_ids, input_mask, and labels. Logits is a tensor of shape \([N,C,S,W]\) with data type Float16 or Float32; input_ids, input_mask, and labels are tensors of shape \([N*C,S]\) with data type Int32 or Int64. Here \(N\) is the batch size, \(C\) is the total number of entity types, \(S\) is the sequence length, and \(W\) is the vocabulary size.
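As a shape-only illustration of the inputs described above, the sketch below builds placeholder NumPy arrays with the documented dimensions. The sizes are arbitrary, the arrays carry no meaningful data, and the final reshape reflects an assumption about how the \(N\) and \(C\) axes pair with the rows of input_ids rather than confirmed library behaviour.

```python
import numpy as np

# Arbitrary sizes: batch size, entity types, sequence length, vocabulary size.
N, C, S, W = 2, 3, 4, 5

logits = np.zeros((N, C, S, W), dtype=np.float32)   # shape [N, C, S, W]
input_ids = np.zeros((N * C, S), dtype=np.int32)    # shape [N*C, S]
input_mask = np.ones((N * C, S), dtype=np.int32)    # shape [N*C, S]
labels = np.zeros((N * C, S), dtype=np.int32)       # shape [N*C, S]

# Assumption: merging the first two logits axes yields one row of
# per-token scores for each row of input_ids.
flat_logits = logits.reshape(N * C, S, W)
print(flat_logits.shape, input_ids.shape)
```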