mindspore_gs.ptq.network_helpers.mf_net_helpers.MFLlama2Helper

class mindspore_gs.ptq.network_helpers.mf_net_helpers.MFLlama2Helper(config: Union[str, MindFormerConfig] = None)

Derived from NetworkHelper; a utility class for the Llama2 network in the MindFormers framework.

Parameters

config (MindFormerConfig) – MindFormerConfig of the network. Default: None.

Raises

TypeError – If the input config is not an instance of MindFormerConfig.
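As a library-independent sketch of the constructor contract above, the stand-in classes below accept a config object and raise TypeError for anything else. StubConfig and StubHelper are hypothetical names, not part of mindspore_gs or MindFormers, and treating a str as a config path is an assumption drawn only from the Union[str, MindFormerConfig] signature.

```python
from typing import Union


class StubConfig:
    """Hypothetical stand-in for MindFormerConfig: holds a source path and settings."""

    def __init__(self, source: str = "", **settings):
        self.source = source
        self.settings = settings


class StubHelper:
    """Hypothetical stand-in mirroring the documented MFLlama2Helper config check."""

    def __init__(self, config: Union[str, StubConfig] = None):
        # Assumption (from the signature only): a str is treated as a path
        # and wrapped in a config object.
        if isinstance(config, str):
            config = StubConfig(source=config)
        # In this stub, anything that is not a config instance,
        # including the default None, raises TypeError.
        if not isinstance(config, StubConfig):
            raise TypeError(
                f"config should be a StubConfig, but got {type(config).__name__}."
            )
        self.config = config


# Usage: both forms accepted by the stub.
helper = StubHelper(StubConfig(batch_size=1))
helper_from_path = StubHelper("llama2_7b.yaml")  # hypothetical path
```

With the real library, construction would instead pass a MindFormerConfig (for example, one loaded from a MindFormers YAML file) to MFLlama2Helper.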

assemble_inputs(input_ids: np.ndarray, **kwargs)

Assemble the network inputs for prediction from input tokens given as a numpy ndarray.

Parameters
  • input_ids (numpy.ndarray) – Input tokens.

  • kwargs (Dict) – Extensible parameter for subclasses.

Returns

A list of mindspore.Tensor to be used as inputs for network prediction.
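The exact tensors MFLlama2Helper produces depend on the MindFormers Llama2 predict signature, which this page does not list. As a hedged, library-independent illustration of the kind of preparation involved, the pure-NumPy sketch below normalizes a 1-D or 2-D token array into a fixed (batch_size, seq_length) int32 batch padded with a pad token id, plus per-row valid lengths. The function name assemble_token_batch and its batch_size, seq_length and pad_token_id parameters are hypothetical, and the real method returns mindspore.Tensor objects rather than NumPy arrays.

```python
import numpy as np


def assemble_token_batch(input_ids: np.ndarray, batch_size: int, seq_length: int,
                         pad_token_id: int = 0):
    """Hypothetical sketch: pad tokens into a fixed-shape int32 batch."""
    ids = np.asarray(input_ids)
    if ids.ndim == 1:
        ids = ids[np.newaxis, :]  # treat a single sequence as a batch of one
    if ids.ndim != 2:
        raise ValueError(f"input_ids should be 1-D or 2-D, got {ids.ndim}-D.")
    rows, length = ids.shape
    if rows > batch_size or length > seq_length:
        raise ValueError(
            f"input_ids shape {ids.shape} exceeds ({batch_size}, {seq_length})."
        )
    # Right-pad every row to seq_length and fill unused rows with padding.
    batch = np.full((batch_size, seq_length), pad_token_id, dtype=np.int32)
    batch[:rows, :length] = ids
    # Valid length per row; assumes the given rows are not already padded.
    valid_length = np.zeros((batch_size,), dtype=np.int32)
    valid_length[:rows] = length
    return [batch, valid_length]


# Usage: one 3-token prompt placed into a (2, 8) batch.
inputs = assemble_token_batch(np.array([1, 306, 4658]), batch_size=2, seq_length=8)
```

With the real helper, the analogous call is assemble_inputs(input_ids) on an MFLlama2Helper instance, with batch and sequence sizes coming from its MindFormerConfig rather than from explicit arguments.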