LRP Head Pruning


The MindSpore Golden Stick LRP module is an implementation of the LRP head-pruning method, first presented in the article Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned and explained below.

For a complete tutorial on how to use this module, please refer to the tutorial page.

Introduction

Pruning Attention Heads

The attention mechanism is a crucial component of the Transformer architecture. Several studies have shown that most attention heads are not confident in their decisions and can be pruned. Following these findings, a number of algorithms have been proposed for intelligently pruning attention heads.

LRP method

The LRP method starts from a converged model that was not necessarily specially trained to support head pruning. Then, it fine-tunes the model to support head pruning, and at the end of the fine-tuning phase some heads are pruned.

During the fine-tuning, the original Transformer architecture is modified by multiplying the representation computed by each head with a scalar gate. That is, when moving from the standard Transformer to the modified one, the representation computed by each \({\rm head}_i\) is multiplied by a scalar gate \(g_i\) before being concatenated with results of other attention heads in a layer: \({\rm MultiHead}(Q, K, V) = {\rm Concat}_i({\rm head}_i)W^O \rightarrow {\rm Concat}_i(g_i*{\rm head}_i)W^O \)

Note that the \(g_i\) are head-specific parameters and are independent of the input (i.e., the sentence).

Ideally, the values of the gates are either 1 (“open”) or 0 (“closed”), and the sparsity of the gates is encouraged by applying \(L0\) regularization to the scalar gates \(g_i\). The \(L0\) norm equals the number of non-zero components and would push the model to switch off less important heads.
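
As a hedged illustration only (not the module's actual implementation), the gated multi-head computation can be sketched in plain Python as follows; head_outputs, w_o and g are assumed names:

import numpy as np

def gated_multi_head(head_outputs, w_o, g):
    # head_outputs: list of per-head representations, each of shape (seq_len, d_head)
    # w_o: output projection W^O of shape (num_heads * d_head, d_model)
    # g: per-head scalar gates of shape (num_heads,); independent of the input
    gated = [g[i] * h for i, h in enumerate(head_outputs)]  # g_i * head_i
    return np.concatenate(gated, axis=-1) @ w_o             # Concat_i(g_i * head_i) W^O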

Unfortunately, a direct implementation of this idea leads to an objective function that is non-differentiable with respect to the added gates. So instead, the algorithm uses a stochastic relaxation in which each gate is modeled as a continuous random variable drawn independently from a head-specific Hard Concrete distribution. These distributions have non-zero probability mass at 0 and 1; see the illustration below.

[Figure: the Hard Concrete distribution]

The non-differentiable \(L0\) norm is replaced by the sum of the probabilities of heads being non-zero (\(L_C\)) as a stochastic relaxation, and the resulting training objective is

\(L = L_{xent} + \lambda L_C\)
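
As a minimal sketch (assuming the standard Hard Concrete parameterization with stretch limits gamma < 0 and zeta > 1; the module's internals may differ), \(L_C\) and the total objective can be written as:

import numpy as np

def prob_gate_nonzero(log_a, temperature, gamma=-0.1, zeta=1.1):
    # Probability that a Hard Concrete gate is non-zero ("open"), one value per head.
    return 1.0 / (1.0 + np.exp(-(log_a - temperature * np.log(-gamma / zeta))))

def total_loss(xent_loss, log_a, temperature, l0_penalty):
    # L = L_xent + lambda * L_C, where L_C sums the open probabilities over all heads
    # and lambda corresponds to the l0_penalty coefficient.
    l_c = prob_gate_nonzero(np.asarray(log_a), temperature).sum()
    return xent_loss + l0_penalty * l_c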

By varying the coefficient \(\lambda\) in the optimized objective, we obtain models with different numbers of retained heads. The animation below shows how the probabilities of encoder heads being completely closed, \(P(g_i=0)\), change during training for different values of \(\lambda\) (pruning starts from a converged model). White denotes \(P(g_i=0) = 1\), which means that a head is completely removed from the model.

[Animation: probabilities of encoder heads being completely closed during training]

(Gif is for the model trained on EN-RU WMT. For other datasets, values of \(\lambda\) can be different.)

Empirically, the model converges to solutions where gates are either almost completely closed or completely open. This means that at test time we can treat the model as a standard Transformer and use only a subset of heads.

Overall, there are three main parameters that the user needs to know in order to get a general understanding of what the algorithm does behind the scenes (a sketch of how they interact is given after the list):

  • l0_penalty

    • The regularization coefficient through which we control the number of pruned heads

    • The larger the coefficient is, the more heads are pruned

  • temperature

    • Controls the relaxation of the ideal objective function

    • As the temperature decreases, the degree of approximation improves (at the expense of the derivatives’ stability)

  • log_a

    • A per-head parameter indicating the “openness” of the gate associated with the head

    • These parameters are learned by the algorithm in the fine-tuning phase (unlike the previous two parameters that are set by the user)
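
To make the roles of temperature and log_a concrete, here is a hedged sketch of how a Hard Concrete gate could be obtained from them (again assuming the standard stretch limits; the names and exact formulation are illustrative, not the module's code):

import numpy as np

def sample_gate(log_a, temperature, gamma=-0.1, zeta=1.1, training=True):
    # Training: a stochastic gate stretched to [gamma, zeta] and clipped to [0, 1],
    # so exact 0 ("closed") and exact 1 ("open") have non-zero probability.
    # Test time: a deterministic gate; a value of 0 means the head is pruned.
    if training:
        u = np.random.uniform(1e-6, 1.0 - 1e-6, size=np.shape(log_a))
        s = 1.0 / (1.0 + np.exp(-(np.log(u) - np.log(1.0 - u) + log_a) / temperature))
    else:
        s = 1.0 / (1.0 + np.exp(-np.asarray(log_a, dtype=float)))
    return np.clip(s * (zeta - gamma) + gamma, 0.0, 1.0)

Lower temperatures make the sampled gates closer to exact 0/1 values, which matches the note above about approximation quality versus the stability of the derivatives.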

[Figure: l0_penalty graph (from the Differentiable Subset Pruning of Transformer Heads article)]

Link to the article: Differentiable Subset Pruning of Transformer Heads

HeadPruning API

Our head pruning algorithm is integrated within the standard training loop and returns a pruned model.

Create Dataset:

Create an NLP Dataset for your model.
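
For example, assuming the MNLI data has been pre-processed into TFRecord files with the usual BERT feature columns (the file name and column names below are assumptions, not fixed by the module), a dataset could be created with mindspore.dataset:

import mindspore.dataset as ds

# Hypothetical TFRecord file produced by a standard BERT pre-processing script.
train_dataset = ds.TFRecordDataset(["data/MNLI/train.tf_record"],
                                   columns_list=["input_ids", "input_mask",
                                                 "segment_ids", "label_ids"])
train_dataset = train_dataset.batch(32, drop_remainder=True)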

Create optimizer:

The model parameter ‘log_a’ (which is introduced by the algorithm) should be defined with a high learning rate.

To define a separate learning rate specifically for this parameter, use the following script:

params = network.trainable_params()
# Separate the gate parameters ('log_a') from the remaining network parameters
# and assign them a dedicated learning rate.
pruning_params = list(filter(lambda x: 'log_a' in x.name, params))
other_params = list(filter(lambda x: 'log_a' not in x.name, params))
group_params = [{'params': other_params},
                {'params': pruning_params, 'lr': 0.1},
                {'order_params': params}]
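
The grouped parameters can then be passed to any standard MindSpore optimizer, for instance (the optimizer choice and learning rate below are illustrative):

from mindspore import nn

# 'log_a' parameters use the dedicated learning rate from group_params;
# all other parameters use the optimizer's default learning rate.
optimizer = nn.AdamWeightDecay(group_params, learning_rate=2e-5)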

Load Model:

Load a trained model from MindSpore.

bert_model = BertModel(bert_net_cfg, True)
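
Pre-trained weights can then be loaded with the standard MindSpore checkpoint utilities (the checkpoint path below is an assumption):

from mindspore import load_checkpoint, load_param_into_net

param_dict = load_checkpoint("pretrained_bert.ckpt")  # assumed checkpoint path
load_param_into_net(bert_model, param_dict)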

Create HeadPruner:

Instantiate the HeadPruner:

from mindspore_gs.prune_heads import HeadPruningFactory, PruningType, SupportedModels

model_params = {'config': bert_net_cfg}

l0_penalty = 0.0015
head_pruner_config = {'model_params': model_params,
                      'l0_penalty': l0_penalty}

pruner_config = {"prune_type": PruningType.LRP,
                 "arch_type": SupportedModels.BERT,
                 "config": head_pruner_config}

pruner = HeadPruningFactory.get_pruner(pruner_config)

The parameters required for HeadPruner initialization are:

  • prune_type (Enum) – Type of pruning method.

  • arch_type (Enum) – Type of model to prune.

  • config (Dictionary) – Head pruner config, containing:

    • model_params (Dictionary) – Parameters of the model.

    • l0_penalty (float) – Penalty value for the gate calculation.

Apply model:

Repackage the model with additional functionality that supports the pruning.

bert_model_gated = pruner.apply(bert_model)

The function input is:

  • model – the model to prune, possibly with or without a task-specific head.

Train Gated Model

For the fine-tuning of the model, we add the penalty factor to the loss expression:

 # The gated model additionally returns total_reg, the L0 regularization term of the gates.
 sequence_output, pooled_output, embedding_table, total_reg = \
                                model(input_ids, token_type_id, input_mask)
 loss = ...
 loss += total_reg
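
As a hedged sketch of the task loss for a sequence-classification setup such as MNLI (the classification head, label column and loss choice below are assumptions that depend on your task):

from mindspore import nn

# Hypothetical classification head mapping the pooled output to the three MNLI classes.
classification_head = nn.Dense(bert_net_cfg.hidden_size, 3)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

logits = classification_head(pooled_output)
loss = loss_fn(logits, label_ids)  # label_ids comes from the dataset batch
loss += total_reg                  # add the L0 gate penalty as above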

Train the model for 1-3 epochs.

Get mask:

 mask = bert_model_gated.get_gate_values()
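
Heads whose gate value is zero will be pruned. Assuming get_gate_values returns the gate values grouped per layer (the exact structure is an assumption), the mask can be inspected like this:

for layer_idx, layer_gates in enumerate(mask):
    open_heads = [i for i, g in enumerate(layer_gates) if float(g) > 0]
    print(f"layer {layer_idx}: open heads {open_heads}")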

Prune model:

After training / fine-tuning, prune the model.

 prune_model = pruner.convert(bert_model_gated)

The function can optionally receive another parameter:

  • save_dir_path (string) – path for saving the model before pruning, together with a pickle file containing the dictionary of heads to prune.

 save_dir_path = '...'
 prune_model = pruner.convert(bert_model_gated, save_dir_path)

Run sample

BERT on MNLI

We provide a code example for LRP pruning of the BERT model in the run_sample_bert.py file. It can be run using the following shell commands:

export DEVICE_TARGET="GPU"
export DATA_DIR=data/MNLI/
export TRAINING_MODEL_CKPT="..."
python run_sample_bert.py \
          --device_target $DEVICE_TARGET \
          --distribute "false" \
          --epoch_size 3 \
          --enable_save_ckpt "false" \
          --enable_lossscale "true" \
          --do_shuffle "true" \
          --enable_data_sink "true" \
          --data_sink_steps 20 \
          --load_checkpoint_path $TRAINING_MODEL_CKPT \
          --save_checkpoint_path "" \
          --save_checkpoint_steps 10000 \
          --save_checkpoint_num 1 \
          --data_dir $DATA_DIR