mindspore_gs.ptq.network_helpers.mf_net_helpers.MFLlama2Helper

class mindspore_gs.ptq.network_helpers.mf_net_helpers.MFLlama2Helper(config: Union[str, MindFormerConfig] = None)

Derived from 'NetworkHelper'. A utility class for the Llama2 network of the MindFormers framework.

Parameters

config (MindFormerConfig) – MindFormerConfig of the network.

Raises

TypeError – If config is not an instance of MindFormerConfig.

analysis_decoder_groups(network)

Analyze the decoder group information of network.

Parameters

network (Cell) – Network whose decoder group information is analyzed.
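This page does not describe what the analysis produces. As a rough illustration of what "decoder groups" can mean, the sketch below groups sub-layer names by the index of the decoder layer that owns them. It assumes MindFormers-style names of the form model.layers.<i>.<sub-layer>; the function group_by_decoder and the naming pattern are hypothetical and are not part of the Golden Stick API.

```python
import re
from collections import defaultdict
from typing import Dict, Iterable, List

# Hypothetical naming pattern: "model.layers.<index>.<sub-layer path>".
_DECODER_RE = re.compile(r"^model\.layers\.(\d+)\.(.+)$")


def group_by_decoder(layer_names: Iterable[str]) -> Dict[int, List[str]]:
    """Group sub-layer names by the index of the decoder layer that owns them.

    Names that do not match the assumed pattern (embeddings, LM head, ...)
    are skipped.
    """
    groups: Dict[int, List[str]] = defaultdict(list)
    for name in layer_names:
        match = _DECODER_RE.match(name)
        if match:
            groups[int(match.group(1))].append(match.group(2))
    return dict(groups)


names = [
    "model.tok_embeddings",
    "model.layers.0.attention.wq",
    "model.layers.0.feed_forward.w1",
    "model.layers.1.attention.wq",
    "lm_head",
]
print(group_by_decoder(names))
# {0: ['attention.wq', 'feed_forward.w1'], 1: ['attention.wq']}
```

Grouping by decoder index lets per-layer work (for example, calibrating one decoder at a time) iterate over a small, self-contained set of sub-layers.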

assemble_inputs(input_ids: np.ndarray, **kwargs)

Assemble the network inputs for prediction from input tokens given as a numpy.ndarray.

Parameters
  • input_ids (numpy.ndarray) – Input tokens.

  • kwargs (Dict) – Extensible parameter for subclasses.

Returns

A list of mindspore.Tensor to be used as the inputs of network prediction.
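A common step when assembling prediction inputs is to promote a single token sequence to a batch and pad it to the network's fixed sequence length. The pure-Python sketch below shows only that step and is not the library's implementation: pad_batch, seq_length, and pad_token_id are illustrative assumptions, whereas the real method takes a numpy.ndarray, returns mindspore.Tensor objects, and may build further inputs that are not shown here.

```python
from typing import List, Sequence, Union

TokenIds = Union[Sequence[int], Sequence[Sequence[int]]]


def pad_batch(input_ids: TokenIds, seq_length: int, pad_token_id: int = 0) -> List[List[int]]:
    """Promote 1-D token ids to a batch of one and right-pad every row to seq_length."""
    if not input_ids:
        raise ValueError("input_ids must not be empty")
    # A flat list of ints is treated as a single sequence (batch size 1).
    if isinstance(input_ids[0], int):
        rows = [list(input_ids)]
    else:
        rows = [list(row) for row in input_ids]
    padded = []
    for row in rows:
        if len(row) > seq_length:
            raise ValueError(f"sequence of length {len(row)} exceeds seq_length={seq_length}")
        # Right-pad with the (assumed) pad token id.
        padded.append(row + [pad_token_id] * (seq_length - len(row)))
    return padded


print(pad_batch([5, 6, 7], seq_length=6))
# [[5, 6, 7, 0, 0, 0]]
```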

get_pre_layer(linear_name: str)

Get information about the layer that precedes the linear layer named linear_name.

Parameters

linear_name (str) – Name of the linear layer.

Returns

A dict of information about the preceding layer, including its name, the layer itself, and its type.
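In a Llama2-style decoder, the layer that precedes a linear layer is usually the one whose output feeds it, which is where smoothing or AWQ-style quantization can fold per-channel scales. The sketch below illustrates such a lookup for one decoder layer. Several details are assumptions rather than the Golden Stick implementation: the sub-layer names follow MindFormers Llama conventions, the pairings follow the common AWQ recipe, the dict keys name, layer, and type were chosen only to mirror the contents documented above, and lookup_pre_layer is hypothetical.

```python
from typing import Any, Dict

# Hypothetical mapping for one Llama2 decoder layer:
# linear sub-layer -> (preceding sub-layer, type of the preceding layer).
_PRE_LAYERS = {
    "attention.wq": ("attention_norm", "norm"),
    "attention.wk": ("attention_norm", "norm"),
    "attention.wv": ("attention_norm", "norm"),
    "attention.wo": ("attention.wv", "linear"),
    "feed_forward.w1": ("ffn_norm", "norm"),
    "feed_forward.w3": ("ffn_norm", "norm"),
    "feed_forward.w2": ("feed_forward.w3", "linear"),
}


def lookup_pre_layer(linear_name: str, layers: Dict[str, Any]) -> Dict[str, Any]:
    """Return the name, object, and type of the layer that feeds linear_name.

    linear_name is a full name such as "model.layers.0.attention.wq";
    layers maps full layer names to layer objects.
    """
    for suffix, (pre_suffix, pre_type) in _PRE_LAYERS.items():
        if linear_name.endswith("." + suffix):
            # Swap the linear sub-layer suffix for the preceding sub-layer's suffix.
            pre_name = linear_name[: -len(suffix)] + pre_suffix
            return {"name": pre_name, "layer": layers[pre_name], "type": pre_type}
    raise KeyError(f"no pre layer known for {linear_name!r}")


# Placeholder strings stand in for real layer objects.
layers = {
    "model.layers.0.attention_norm": "RMSNorm placeholder",
    "model.layers.0.attention.wv": "Linear placeholder",
}
info = lookup_pre_layer("model.layers.0.attention.wo", layers)
print(info["name"], info["type"])
# model.layers.0.attention.wv linear
```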