mindspore.train.FlopsUtilizationCollector

class mindspore.train.FlopsUtilizationCollector(data_size, computility=1, full_flops=True)[源代码]

FlopsUtilizationCollector接口统计模型利用率信息MFU,硬件利用率信息HFU。 当前接口只统计MatMul、BatchMatMul、FlashAttentionScore、Conv2D算子的正反向flops信息。 只支持静态图静态shape模式。

参数:
  • data_size (int) - 表示每隔多少个step打印一次信息。

  • computility (int) - 表示每张计算卡的峰值算力。默认值: 1

  • full_flops (bool) - 表示是否统计完整的模型信息,如果设置为True,会统计完整的模型信息,如果设置为False,将会统计对应每张卡的分片模型信息。默认值: True

异常:
  • TypeError - data_size 不是正整数。

  • TypeError - full_flops 不是布尔类型。

  • AssertionError - 不是静态图或者不是静态shape。

样例:

>>> import numpy as np
>>> import mindspore.dataset as ds
>>> from mindspore import nn
>>> from mindspore.train import Model, FlopsUtilizationCollector
>>> from mindspore import context
>>> context.set_context(mode=context.GRAPH_MODE)
>>> data = {"x": np.float32(np.random.rand(64, 10)), "y": np.random.randint(0, 5, (64,))}
>>> train_dataset = ds.NumpySlicesDataset(data=data).batch(32)
>>> net = nn.Dense(10, 5)
>>> crit = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
>>> opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)
>>> flops_callback = FlopsUtilizationCollector(train_dataset.get_dataset_size(), computility=10e6)
>>> model = Model(network=net, optimizer=opt, loss_fn=crit, metrics={"recall"})
>>> model.train(2, train_dataset, callbacks=[flops_callback])
Full model flops is 6400, Full hardware flops is 6400, Shard model flops is 6400, Shard hardware flops is 6400
Train per step time: 135.572 ms, mfu:0.47% hfu:0.47%
Train per step time: 1.317 ms, mfu:48.59% hfu:48.59%
epoch_begin(run_context)

在epoch开始时记录时间。

参数:
epoch_end(run_context)

在epoch结束时打印模型利用率信息MFU,硬件利用率信息HFU。

参数: