LSTM+CRF Sequence Labeling
This example cannot be run on the Windows operating system.
Overview
Sequence labeling refers to the process of labeling each token for a given input sequence. Sequence labeling is usually used to extract information from text, including word segmentation, part-of-speech tagging, and named entity recognition (NER). The following uses NER as an example:
Input Sequence | the | wall | street | journal | reported | today | that | apple | corporation | made | money
---|---|---|---|---|---|---|---|---|---|---|---
Output Labeling | B | I | I | I | O | O | O | B | I | O | O
As shown in the preceding table, the wall street journal and apple corporation are organization names (named entities) that need to be identified. We predict the label of each input word and then identify the entities from the labels.
This example uses BIO labeling, a common labeling scheme for NER. The first token of an entity is labeled B, the remaining tokens of the entity are labeled I, and tokens outside any entity are labeled O.
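Once every token has a label, entities are read off by grouping each B with the I labels that follow it. The plain-Python sketch below illustrates this step. The function name bio_to_spans and the (text, start, end) output format are illustrative choices and are not part of the tutorial code.

```python
def bio_to_spans(tokens, tags):
    """Group BIO tags into (entity_text, start, end) tuples, with end exclusive.

    Hypothetical helper for illustration; an I with no open entity is ignored.
    """
    spans = []
    start = None  # index where the currently open entity began, if any
    for i, tag in enumerate(tags):
        if tag == "I" and start is not None:
            continue  # the open entity extends to this token
        # Any other label closes the open entity (if there is one).
        if start is not None:
            spans.append((" ".join(tokens[start:i]), start, i))
        # B opens a new entity; O or a stray I leaves no entity open.
        start = i if tag == "B" else None
    if start is not None:
        spans.append((" ".join(tokens[start:]), start, len(tokens)))
    return spans


tokens = "the wall street journal reported today that apple corporation made money".split()
tags = "B I I I O O O B I O O".split()
print(bio_to_spans(tokens, tags))
# [('the wall street journal', 0, 4), ('apple corporation', 7, 9)]
```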
Conditional Random Field (CRF)
The preceding example shows that labeling a sequence means predicting a label for each token in the sequence. It could therefore be treated directly as a simple multi-class classification problem. However, sequence labeling must not only classify each token but also account for the dependencies between the labels of adjacent tokens. Take the wall street journal as an example.
Input Sequence | the | wall | street | journal | Valid
---|---|---|---|---|---
Output Labeling | B | I | I | I | √
Output Labeling | O | I | I | I | ×
As shown in the preceding table, the labels of the four tokens in the correct entity depend on each other: the label immediately before an I must be B or I. In the incorrect output, however, the token the is labeled O, which violates this dependency. If NER is treated as a multi-class classification problem, each word is predicted independently, so errors like this can occur. An algorithm that can learn the relationships between labels is therefore introduced to keep the predicted sequence consistent. CRF is a probabilistic graphical model suited to this scenario. The definition and parametric form of the conditional random field are briefly analyzed below.
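The dependency can be stated as a rule on adjacent labels: an I may only follow a B or another I. The plain-Python sketch below checks this rule. Its function name, is_valid_bio, is a hypothetical helper and not part of the tutorial. A per-token classifier gives no guarantee that its outputs pass this check. A CRF, by contrast, learns transition scores that can make such label pairs unlikely.

```python
def is_valid_bio(tags):
    """Return True if every 'I' directly follows a 'B' or another 'I'."""
    prev = "O"  # the position before the first token counts as outside any entity
    for tag in tags:
        if tag == "I" and prev not in ("B", "I"):
            return False
        prev = tag
    return True


print(is_valid_bio("B I I I".split()))  # True  (the correct output in the table)
print(is_valid_bio("O I I I".split()))  # False (O followed by I)
```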
Considering the linear sequence feature of the sequence labeling problem, the CRF described in this section refers to the linear chain CRF.
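Assume that $x=\{x_0, \dots, x_n\}$ is the input sequence and $y=\{y_0, \dots, y_n\}$, $y \in Y$, is the output label sequence, where $Y$ is the set of all possible label sequences for $x$. The probability of the output sequence $y$ is:
$$P(y|x) = \frac{\exp(\text{Score}(x, y))}{\sum_{y' \in Y} \exp(\text{Score}(x, y'))} \qquad (1)$$
Let $x_i$ be the $i$th token in the sequence and $y_i$ its label. $\text{Score}$ must capture both the association between $x_i$ and $y_i$ and the association between adjacent labels $y_{i-1}$ and $y_i$. Two functions are therefore defined:
The emission probability function $\psi_\text{EMIT}$ indicates the probability of $x_i \rightarrow y_i$. The transition probability function $\psi_\text{TRANS}$ indicates the probability of $y_{i-1} \rightarrow y_i$.
The formula for calculating $\text{Score}$ is as follows:
$$\text{Score}(x,y) = \sum_i \left[\log\psi_\text{EMIT}(x_i \rightarrow y_i) + \log\psi_\text{TRANS}(y_{i-1} \rightarrow y_i)\right] \qquad (2)$$
Assume that the label set is $T$. A matrix $\mathbf{P}$ of size $|T| \times |T|$ stores the transition scores between labels. The hidden state $h$ output by the encoding layer (for example, Dense or LSTM) can be used directly as the emission score, where $h_i[t]$ is the score of label $t$ at token $i$. Formula (2) then becomes:
$$\text{Score}(x,y) = \sum_i \left(h_i[y_i] + \mathbf{P}_{y_{i-1}, y_i}\right) \qquad (3)$$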
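The linear-chain score can be made concrete with a tiny example. The score sums, over tokens, the emission score h_i[y_i] and the transition score P[y_{i-1}][y_i]. The plain-Python sketch below scores two label sequences for a three-token input. Its assumptions are:

- The values of h and P are hand-picked toy numbers, not learned values.
- Labels are indexed B=0, I=1, O=2, matching tag_to_idx later in this tutorial.
- The transition term is skipped for the first token, because the implementation handles the sequence boundaries with separate start and end transition scores.

```python
def crf_score(h, P, y):
    """Linear-chain score: sum of emission and transition scores.

    h[i][t]: emission score of label t at token i
    P[a][b]: transition score from label a to label b
    y: list of label indices; no transition term for the first token
    """
    score = h[0][y[0]]
    for i in range(1, len(y)):
        score += P[y[i - 1]][y[i]] + h[i][y[i]]
    return score


# Toy scores (hand-picked); label indices: B=0, I=1, O=2.
h = [[2.0, 0.0, 1.0],
     [0.0, 3.0, 1.0],
     [0.0, 1.0, 2.0]]
P = [[0.0, 1.0, 0.0],
     [0.0, 1.0, 0.5],
     [0.0, -5.0, 0.0]]  # O -> I gets a large negative transition score

print(crf_score(h, P, [0, 1, 2]))  # B I O: 2.0 + (1.0 + 3.0) + (0.5 + 2.0) = 8.5
print(crf_score(h, P, [2, 1, 2]))  # O I O: 1.0 + (-5.0 + 3.0) + (0.5 + 2.0) = 1.5
```

The invalid O to I step costs the second sequence most of its score. This is how the transition matrix encodes the label dependencies discussed earlier.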
For details about the complete CRF-based deduction, see Log-Linear Models, MEMMs, and CRFs.
Next, we use MindSpore to implement the parameterized CRF based on the preceding formulas. First, the forward (training) part of the CRF layer is implemented, and the CRF is combined with a loss function. The loss is the negative log-likelihood (NLL), which is commonly used for classification problems.
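$$\text{Loss} = -\log P(y|x) \qquad (4)$$
According to formula (1):
$$\text{Loss} = -\log\frac{\exp(\text{Score}(x, y))}{\sum_{y' \in Y} \exp(\text{Score}(x, y'))} = \log\sum_{y' \in Y} \exp(\text{Score}(x, y')) - \text{Score}(x, y) \qquad (5)$$
In formula (5), the minuend is called the Normalizer and the subtrahend is called the Score. The final Loss is their difference.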
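On a tiny input, the loss can be checked by brute force: enumerate every label sequence y', compute the Normalizer log Σ exp(Score(x, y')), and subtract the score of the correct sequence. The plain-Python sketch below does exactly that. Its scores are hand-picked assumptions, and the function names are illustrative. The enumeration visits |T| to the power of the sequence length label sequences, which is why the implementation that follows uses dynamic programming instead.

```python
import itertools
import math


def crf_score(h, P, y):
    # Linear-chain score; no transition term for the first token.
    score = h[0][y[0]]
    for i in range(1, len(y)):
        score += P[y[i - 1]][y[i]] + h[i][y[i]]
    return score


def brute_force_nll(h, P, y):
    """NLL = log(sum over y' of exp(Score(x, y'))) - Score(x, y)."""
    num_tags = len(P)
    # Enumerate all num_tags ** len(h) label sequences.
    all_scores = [crf_score(h, P, list(path))
                  for path in itertools.product(range(num_tags), repeat=len(h))]
    normalizer = math.log(sum(math.exp(s) for s in all_scores))
    return normalizer - crf_score(h, P, y)


# Toy scores (hand-picked); label indices: B=0, I=1, O=2.
h = [[2.0, 0.0, 1.0],
     [0.0, 3.0, 1.0],
     [0.0, 1.0, 2.0]]
P = [[0.0, 1.0, 0.0],
     [0.0, 1.0, 0.5],
     [0.0, -5.0, 0.0]]

loss = brute_force_nll(h, P, [0, 1, 2])
print(loss, math.exp(-loss))  # the NLL loss and P(y|x) for the sequence B I O
```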
Score Calculation
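First, the score of the correct label sequence is calculated according to formula (3). In addition to the label transition matrix (trans), two vectors with one entry per label (start_trans and end_trans) are maintained. They hold the transition scores at the beginning and the end of the sequence. A mask matrix, mask, is also introduced. When multiple sequences are packed into a batch, it marks the padded positions so that they are ignored and the Score calculation includes only valid tokens.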
import mindspore.numpy as mnp
def compute_score(emissions, tags, seq_ends, mask, trans, start_trans, end_trans):
    # emissions: (seq_length, batch_size, num_tags)
    # tags: (seq_length, batch_size)
    # seq_ends: (batch_size,), index of the last valid token in each sequence
    # mask: (seq_length, batch_size)
    seq_length, batch_size = tags.shape
    mask = mask.astype(emissions.dtype)
    # Initialize the score with the start transition score of the first label.
    # shape: (batch_size,)
    score = start_trans[tags[0]]
    # score += emission score of the first token
    # shape: (batch_size,)
    score += emissions[0, mnp.arange(batch_size), tags[0]]
    for i in range(1, seq_length):
        # Transition score from label i-1 to label i (counted only where mask == 1).
        # shape: (batch_size,)
        score += trans[tags[i - 1], tags[i]] * mask[i]
        # Emission score of label tags[i] at token i (counted only where mask == 1).
        # shape: (batch_size,)
        score += emissions[i, mnp.arange(batch_size), tags[i]] * mask[i]
    # Label of the last valid token in each sequence.
    # shape: (batch_size,)
    last_tags = tags[seq_ends, mnp.arange(batch_size)]
    # score += end transition score
    # shape: (batch_size,)
    score += end_trans[last_tags]
    return score
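The masking behaviour of compute_score can be illustrated without MindSpore. The sketch below is a hypothetical plain-Python counterpart, masked_score, for a single sequence. It uses lists instead of tensors and assumes padding occurs only at the end. Padded positions contribute nothing, and the end transition is taken from the last real token, so padding a sequence leaves its score unchanged.

```python
def masked_score(emissions, tags, mask, trans, start_trans, end_trans):
    """Single-sequence counterpart of compute_score (hypothetical helper).

    emissions[i][t]: emission score; tags[i]: label index; mask[i]: 1 real, 0 padding.
    Assumes all padding comes after the real tokens.
    """
    score = start_trans[tags[0]] + emissions[0][tags[0]]
    for i in range(1, len(tags)):
        # Transition and emission terms count only where mask == 1.
        score += (trans[tags[i - 1]][tags[i]] + emissions[i][tags[i]]) * mask[i]
    last = sum(mask) - 1  # plays the role of seq_ends
    return score + end_trans[tags[last]]


trans = [[0.5, 1.0], [0.0, 0.25]]
start_trans = [0.25, 0.0]
end_trans = [0.0, 0.5]

# Two real tokens, then one padded position with arbitrary scores and label.
padded = masked_score([[1.0, 0.0], [0.0, 2.0], [9.0, 9.0]], [0, 1, 0], [1, 1, 0],
                      trans, start_trans, end_trans)
unpadded = masked_score([[1.0, 0.0], [0.0, 2.0]], [0, 1], [1, 1],
                        trans, start_trans, end_trans)
print(padded, unpadded)  # 4.75 4.75
```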
Normalizer Calculation
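According to formula (5), the Normalizer is the log-sum-exp of the scores of all possible output sequences for $x$. With $|T|$ labels and $n+1$ tokens, enumerating them would require computing $|T|^{n+1}$ sequence scores. Dynamic programming is used instead to reuse intermediate results.
Let $\alpha_i(t)$ be the log-sum-exp of the scores of all partial label sequences from token $0$ to token $i$ that end with label $t$. Here $h_i[t]$ is the emission score of label $t$ at token $i$, and $\mathbf{P}_{t',t}$ is the transition score from label $t'$ to label $t$. Then $\alpha_i$ can be computed from $\alpha_{i-1}$:
$$\alpha_i(t) = \log\sum_{t' \in T}\exp\left(\alpha_{i-1}(t') + \mathbf{P}_{t', t} + h_i[t]\right) \qquad (6)$$
Because $h_i[t]$ does not depend on the previous label $t'$, it can be moved out of the sum:
$$\alpha_i(t) = \log\sum_{t' \in T}\exp\left(\alpha_{i-1}(t') + \mathbf{P}_{t', t}\right) + h_i[t] \qquad (7)$$
The recursion starts from $\alpha_0(t) = s_t + h_0[t]$, where $s$ holds the start transition scores. The Normalizer is $\log\sum_{t \in T}\exp\left(\alpha_n(t) + e_t\right)$, where $e$ holds the end transition scores.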
According to formula (7), the Normalizer is implemented as follows:
import mindspore.ops as ops
import mindspore.numpy as mnp
def compute_normalizer(emissions, mask, trans, start_trans, end_trans):
    # emissions: (seq_length, batch_size, num_tags)
    # mask: (seq_length, batch_size)
    seq_length = emissions.shape[0]
    # Initialize with the start transition scores plus the first emission scores.
    # shape: (batch_size, num_tags)
    score = start_trans + emissions[0]
    for i in range(1, seq_length):
        # Expand score so that axis 1 indexes the previous label.
        # shape: (batch_size, num_tags, 1)
        broadcast_score = score.expand_dims(2)
        # Expand the emissions so that axis 2 indexes the current label.
        # shape: (batch_size, 1, num_tags)
        broadcast_emissions = emissions[i].expand_dims(1)
        # Calculate the score for token i according to formula (7).
        # broadcast_score holds the log-sum-exp of the scores of all paths from
        # token 0 to token i-1 that end in each previous label.
        # shape: (batch_size, num_tags, num_tags)
        next_score = broadcast_score + trans + broadcast_emissions
        # Log-sum-exp over the previous label gives the score for token i.
        # shape: (batch_size, num_tags)
        next_score = ops.logsumexp(next_score, axis=1)
        # Update the score only where mask == 1; padded positions keep the previous value.
        # shape: (batch_size, num_tags)
        score = mnp.where(mask[i].expand_dims(1), next_score, score)
    # Add the end transition scores.
    # shape: (batch_size, num_tags)
    score += end_trans
    # Log-sum-exp over the last label gives the Normalizer.
    # shape: (batch_size,)
    return ops.logsumexp(score, axis=1)
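The recursion in formula (7) can be cross-checked against brute-force enumeration on a toy input. The plain-Python sketch below is a single-sequence, unpadded stand-in for compute_normalizer; its function names and numbers are illustrative assumptions. It uses a numerically stable log-sum-exp and, like the MindSpore code, includes the start and end transition scores.

```python
import itertools
import math


def logsumexp(values):
    # Stable log(sum(exp(v))): subtract the maximum before exponentiating.
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def forward_normalizer(emissions, trans, start_trans, end_trans):
    num_tags = len(start_trans)
    # Token 0: start transition score plus first emission score.
    alpha = [start_trans[t] + emissions[0][t] for t in range(num_tags)]
    for i in range(1, len(emissions)):
        # Formula (7): log-sum-exp over the previous label, then add the emission.
        alpha = [logsumexp([alpha[p] + trans[p][t] for p in range(num_tags)])
                 + emissions[i][t] for t in range(num_tags)]
    return logsumexp([alpha[t] + end_trans[t] for t in range(num_tags)])


def path_score(emissions, tags, trans, start_trans, end_trans):
    # Full score of one label sequence, including start and end transitions.
    score = start_trans[tags[0]] + emissions[0][tags[0]]
    for i in range(1, len(tags)):
        score += trans[tags[i - 1]][tags[i]] + emissions[i][tags[i]]
    return score + end_trans[tags[-1]]


def brute_force_normalizer(emissions, trans, start_trans, end_trans):
    num_tags = len(start_trans)
    paths = itertools.product(range(num_tags), repeat=len(emissions))
    return logsumexp([path_score(emissions, list(p), trans, start_trans, end_trans)
                      for p in paths])


emissions = [[1.0, 0.5, -0.5], [0.2, 1.5, 0.3], [-1.0, 0.0, 2.0], [0.3, 0.3, 0.3]]
trans = [[0.1, 0.8, -0.2], [0.0, 0.6, 0.4], [0.3, -2.0, 0.1]]
start_trans = [0.2, -1.0, 0.1]
end_trans = [0.0, 0.1, 0.2]

print(forward_normalizer(emissions, trans, start_trans, end_trans))
print(brute_force_normalizer(emissions, trans, start_trans, end_trans))  # same up to rounding
```

The dynamic-programming version needs one log-sum-exp per token and label pair instead of one score per complete sequence.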
Viterbi Algorithm
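After the forward training part is completed, the decoding part needs to be implemented. Here we select the Viterbi algorithm, which finds the optimal path through the sequence. As with the Normalizer calculation, dynamic programming is used to score all possible prediction sequences. The difference is that the maximum replaces the log-sum-exp. In addition, for each token $i$, the previous label that achieves the maximum score is saved, and the Viterbi algorithm uses these saved labels to recover the optimal prediction sequence.
Using the same emission scores $h$ and transition matrix $\mathbf{P}$, let $\delta_i(t)$ be the maximum score of any partial label sequence from token $0$ to token $i$ that ends with label $t$. The recursion starts from $\delta_0(t) = s_t + h_0[t]$, where $s$ holds the start transition scores. Then:
$$\delta_i(t) = \max_{t' \in T}\left(\delta_{i-1}(t') + \mathbf{P}_{t', t}\right) + h_i[t], \qquad \text{History}_i(t) = \mathop{\arg\max}_{t' \in T}\left(\delta_{i-1}(t') + \mathbf{P}_{t', t}\right)$$
That is, the best path from token $0$ to token $i$ depends only on the best paths from token $0$ to token $i-1$ and on the step from token $i-1$ to token $i$. After the end transition scores are added, the label with the maximum final score is chosen for the last token. The history is then followed in reverse order to obtain the optimal prediction sequence.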
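The max-and-backtrack procedure can likewise be sketched in plain Python for a single unpadded sequence and checked against brute force. The function names and toy numbers below are illustrative assumptions. The recursion mirrors viterbi_decode, which takes the max over previous labels and saves the argmax per token. The backtracking loop mirrors post_decode, which starts from the best last label.

```python
import itertools


def path_score(emissions, tags, trans, start_trans, end_trans):
    # Full score of one label sequence, including start and end transitions.
    score = start_trans[tags[0]] + emissions[0][tags[0]]
    for i in range(1, len(tags)):
        score += trans[tags[i - 1]][tags[i]] + emissions[i][tags[i]]
    return score + end_trans[tags[-1]]


def viterbi(emissions, trans, start_trans, end_trans):
    """Return (best score, best label sequence) for one unpadded sequence."""
    num_tags = len(start_trans)
    delta = [start_trans[t] + emissions[0][t] for t in range(num_tags)]
    history = []  # history[i - 1][t]: best previous label if token i has label t
    for i in range(1, len(emissions)):
        best_prev, new_delta = [], []
        for t in range(num_tags):
            candidates = [delta[p] + trans[p][t] for p in range(num_tags)]
            p_best = max(range(num_tags), key=candidates.__getitem__)
            best_prev.append(p_best)
            new_delta.append(candidates[p_best] + emissions[i][t])
        history.append(best_prev)
        delta = new_delta
    final = [delta[t] + end_trans[t] for t in range(num_tags)]
    best_last = max(range(num_tags), key=final.__getitem__)
    # Backtrack from the last token, then reverse into forward order.
    best = [best_last]
    for best_prev in reversed(history):
        best.append(best_prev[best[-1]])
    best.reverse()
    return final[best_last], best


emissions = [[1.0, 0.5, -0.5], [0.2, 1.5, 0.3], [-1.0, 0.0, 2.0], [0.3, 0.3, 0.3]]
trans = [[0.1, 0.8, -0.2], [0.0, 0.6, 0.4], [0.3, -2.0, 0.1]]
start_trans = [0.2, -1.0, 0.1]
end_trans = [0.0, 0.1, 0.2]

best_score, best_path = viterbi(emissions, trans, start_trans, end_trans)
brute_best = max(path_score(emissions, list(p), trans, start_trans, end_trans)
                 for p in itertools.product(range(3), repeat=len(emissions)))
print(best_score, best_path, brute_best)  # best_score matches brute_best up to rounding
```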
Due to the syntax restrictions of static graphs, the Viterbi algorithm is used to solve the optimal prediction sequence as a post-processing function and is not included in the implementation of the CRF layer.
import mindspore.numpy as mnp
def viterbi_decode(emissions, mask, trans, start_trans, end_trans):
    # emissions: (seq_length, batch_size, num_tags)
    # mask: (seq_length, batch_size)
    seq_length = mask.shape[0]
    score = start_trans + emissions[0]
    history = ()
    for i in range(1, seq_length):
        broadcast_score = score.expand_dims(2)
        broadcast_emission = emissions[i].expand_dims(1)
        next_score = broadcast_score + trans + broadcast_emission
        # For each current label, save the previous label with the maximum score.
        indices = next_score.argmax(axis=1)
        history += (indices,)
        next_score = next_score.max(axis=1)
        score = mnp.where(mask[i].expand_dims(1), next_score, score)
    score += end_trans
    return score, history
def post_decode(score, history, seq_length):
    # Use the scores and history to backtrack the optimal prediction sequence.
    batch_size = seq_length.shape[0]
    # shape: (batch_size,)
    seq_ends = seq_length - 1
    best_tags_list = []
    # Decode each sample in the batch one by one.
    for idx in range(batch_size):
        # Find the label with the maximum score at the last token
        # and add it to the best prediction sequence.
        best_last_tag = score[idx].argmax(axis=0)
        best_tags = [int(best_last_tag.asnumpy())]
        # Walk the history backwards, each time looking up the best previous label.
        for hist in reversed(history[:seq_ends[idx]]):
            best_last_tag = hist[idx][best_tags[-1]]
            best_tags.append(int(best_last_tag.asnumpy()))
        # The labels were collected from last to first; reverse them into forward order.
        best_tags.reverse()
        best_tags_list.append(best_tags)
    return best_tags_list
CRF Layer
With the forward training part and the decoding part complete, a complete CRF layer can be assembled. Because input sequences may be padded, the CRF needs the actual length of each input sequence. In addition to the emission matrix and the labels, a seq_length parameter therefore passes the length of each sequence before padding. A sequence_mask function generates the mask matrix from these lengths.
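To show what the mask looks like, here is a plain-Python sketch of the same idea. The name sequence_mask_py is hypothetical; the MindSpore sequence_mask below works on tensors. In the example, two sequences of lengths 3 and 1 are padded to 4 tokens. Real tokens are marked 1 and padding 0. By default the result is transposed to (max_length, batch_size), the sequence-first layout that compute_score and compute_normalizer expect.

```python
def sequence_mask_py(seq_length, max_length, batch_first=False):
    # 1 for positions before each sequence's real length, 0 for padding.
    mask = [[1 if j < n else 0 for j in range(max_length)] for n in seq_length]
    if batch_first:
        return mask  # (batch_size, max_length)
    return [list(column) for column in zip(*mask)]  # (max_length, batch_size)


print(sequence_mask_py([3, 1], 4, batch_first=True))  # [[1, 1, 1, 0], [1, 0, 0, 0]]
print(sequence_mask_py([3, 1], 4))  # [[1, 1], [1, 0], [1, 0], [0, 0]]
```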
Based on the preceding functions (compute_score, compute_normalizer, and viterbi_decode), the CRF layer is encapsulated as an nn.Cell. The complete CRF layer is implemented as follows:
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.numpy as mnp
from mindspore.common.initializer import initializer, Uniform
def sequence_mask(seq_length, max_length, batch_first=False):
    """Generate the mask matrix based on the actual length and maximum length of the sequence."""
    range_vector = mnp.arange(0, max_length, 1, seq_length.dtype)
    result = range_vector < seq_length.view(seq_length.shape + (1,))
    if batch_first:
        return result.astype(ms.int64)
    return result.astype(ms.int64).swapaxes(0, 1)
class CRF(nn.Cell):
    def __init__(self, num_tags: int, batch_first: bool = False, reduction: str = 'sum') -> None:
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'invalid reduction: {reduction}')
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.reduction = reduction
        self.start_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='start_transitions')
        self.end_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='end_transitions')
        self.transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags, num_tags)), name='transitions')
    def construct(self, emissions, tags=None, seq_length=None):
        # Without labels, decode; with labels, return the training loss.
        if tags is None:
            return self._decode(emissions, seq_length)
        return self._forward(emissions, tags, seq_length)
    def _forward(self, emissions, tags=None, seq_length=None):
        if self.batch_first:
            batch_size, max_length = tags.shape
            emissions = emissions.swapaxes(0, 1)
            tags = tags.swapaxes(0, 1)
        else:
            max_length, batch_size = tags.shape
        if seq_length is None:
            seq_length = mnp.full((batch_size,), max_length, ms.int64)
        mask = sequence_mask(seq_length, max_length)
        # shape: (batch_size,)
        numerator = compute_score(emissions, tags, seq_length-1, mask, self.transitions, self.start_transitions, self.end_transitions)
        # shape: (batch_size,)
        denominator = compute_normalizer(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)
        # Negative log-likelihood of each sequence: Normalizer - Score.
        # shape: (batch_size,)
        llh = denominator - numerator
        if self.reduction == 'none':
            return llh
        if self.reduction == 'sum':
            return llh.sum()
        if self.reduction == 'mean':
            return llh.mean()
        # 'token_mean': divide by the number of valid (unpadded) tokens.
        return llh.sum() / mask.astype(emissions.dtype).sum()
    def _decode(self, emissions, seq_length=None):
        if self.batch_first:
            batch_size, max_length = emissions.shape[:2]
            emissions = emissions.swapaxes(0, 1)
        else:
            # Sequence-first layout: (seq_length, batch_size, num_tags).
            max_length, batch_size = emissions.shape[:2]
        if seq_length is None:
            seq_length = mnp.full((batch_size,), max_length, ms.int64)
        mask = sequence_mask(seq_length, max_length)
        return viterbi_decode(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)
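The reduction modes accepted by the CRF layer differ only in how the per-sequence losses are combined. The plain-Python sketch below mirrors the branches at the end of _forward. reduce_loss is a hypothetical helper, and the loss values are made up for illustration. Note that 'token_mean' divides by the number of real tokens rather than by the number of sequences.

```python
def reduce_loss(losses, mask, reduction="sum"):
    """Combine per-sequence losses the way CRF._forward does (hypothetical helper).

    losses: one loss value per sequence; mask: rows of 1 (real token) / 0 (padding).
    """
    if reduction == "none":
        return losses
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        return sum(losses) / len(losses)
    if reduction == "token_mean":
        return sum(losses) / sum(sum(row) for row in mask)
    raise ValueError(f"invalid reduction: {reduction}")


losses = [3.0, 1.0]  # made-up per-sequence losses
mask = [[1, 1, 1, 1], [1, 1, 0, 0]]  # 4 and 2 real tokens
for mode in ("none", "sum", "mean", "token_mean"):
    print(mode, reduce_loss(losses, mask, mode))
```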
BiLSTM+CRF Model
After CRF is implemented, a bidirectional LSTM+CRF model is designed to train NER tasks. The model structure is as follows:
nn.Embedding -> nn.LSTM -> nn.Dense -> CRF
The bidirectional LSTM extracts sequence features. A Dense layer then maps them to an emission score matrix with one score per label for each token, and this matrix is finally passed to the CRF layer. The sample code is as follows:
class BiLSTM_CRF(nn.Cell):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_tags, padding_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        # Each direction has hidden_dim // 2 units, so the concatenated output has hidden_dim features.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, bidirectional=True, batch_first=True)
        self.hidden2tag = nn.Dense(hidden_dim, num_tags, 'he_uniform')
        self.crf = CRF(num_tags, batch_first=True)
    def construct(self, inputs, seq_length, tags=None):
        embeds = self.embedding(inputs)
        outputs, _ = self.lstm(embeds, seq_length=seq_length)
        feats = self.hidden2tag(outputs)
        crf_outs = self.crf(feats, tags, seq_length)
        return crf_outs
After the model design is complete, two examples and corresponding labels are generated, and a vocabulary and a label table are built.
embedding_dim = 16
hidden_dim = 32
training_data = [(
"the wall street journal reported today that apple corporation made money".split(),
"B I I I O O O B I O O".split()
), (
"georgia tech is a university in georgia".split(),
"B I O O O O B".split()
)]
word_to_idx = {}
word_to_idx['<pad>'] = 0
for sentence, tags in training_data:
    for word in sentence:
        if word not in word_to_idx:
            word_to_idx[word] = len(word_to_idx)
tag_to_idx = {"B": 0, "I": 1, "O": 2}
len(word_to_idx)
18
Instantiate the model, select an optimizer, and use ms.value_and_grad to build a training step function.
The NLL loss is already computed inside the CRF layer, so no separate loss function needs to be set.
model = BiLSTM_CRF(len(word_to_idx), embedding_dim, hidden_dim, len(tag_to_idx))
optimizer = nn.SGD(model.trainable_params(), learning_rate=0.01, weight_decay=1e-4)
grad_fn = ms.value_and_grad(model, None, optimizer.parameters)
def train_step(data, seq_length, label):
    loss, grads = grad_fn(data, seq_length, label)
    optimizer(grads)
    return loss
Pack the generated data into a batch. Sequences shorter than the maximum sequence length are padded, and the function returns tensors containing the input sequences, the output labels, and the sequence lengths.
def prepare_sequence(seqs, word_to_idx, tag_to_idx):
    seq_outputs, label_outputs, seq_length = [], [], []
    max_len = max([len(i[0]) for i in seqs])
    for seq, tag in seqs:
        seq_length.append(len(seq))
        idxs = [word_to_idx[w] for w in seq]
        labels = [tag_to_idx[t] for t in tag]
        idxs.extend([word_to_idx['<pad>'] for i in range(max_len - len(seq))])
        labels.extend([tag_to_idx['O'] for i in range(max_len - len(seq))])
        seq_outputs.append(idxs)
        label_outputs.append(labels)
    return ms.Tensor(seq_outputs, ms.int64), \
        ms.Tensor(label_outputs, ms.int64), \
        ms.Tensor(seq_length, ms.int64)
data, label, seq_length = prepare_sequence(training_data, word_to_idx, tag_to_idx)
data.shape, label.shape, seq_length.shape
((2, 11), (2, 11), (2,))
Next, train the model for 500 steps.
Visualizing the training progress requires the tqdm library, which can be installed by running the pip install tqdm command.
from tqdm import tqdm
steps = 500
with tqdm(total=steps) as t:
    for i in range(steps):
        loss = train_step(data, seq_length, label)
        t.set_postfix(loss=loss)
        t.update(1)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:23<00:00, 21.13it/s, loss=0.3487625]
Finally, let's examine the model after 500 training steps. First, call the model without labels to obtain the Viterbi scores for the last token and the label history for each token.
score, history = model(data, seq_length)
score
Tensor(shape=[2, 3], dtype=Float32, value=
[[ 3.15928860e+01, 3.63119812e+01, 3.17248516e+01],
[ 2.81416149e+01, 2.61749763e+01, 3.24760780e+01]])
Post-process the scores and history to obtain the best label sequence for each sample.
predict = post_decode(score, history, seq_length)
predict
[[0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 2], [0, 1, 2, 2, 2, 2, 0]]
Then convert the predicted index sequences into label sequences, print the result, and compare it with the labels in training_data.
idx_to_tag = {idx: tag for tag, idx in tag_to_idx.items()}
def sequence_to_tag(sequences, idx_to_tag):
    outputs = []
    for seq in sequences:
        outputs.append([idx_to_tag[i] for i in seq])
    return outputs
sequence_to_tag(predict, idx_to_tag)
[['B', 'I', 'I', 'I', 'O', 'O', 'O', 'B', 'I', 'O', 'O'],
 ['B', 'I', 'O', 'O', 'O', 'O', 'B']]