Vectorization Acceleration Interface Vmap

Overview

The rapid development of AI-converged computing places new demands on framework capabilities. As problem scenarios and model designs grow more complex, the dimensionality of the data and the nesting depth of operations grow with them. Vectorization can effectively remove the resulting performance bottlenecks, but it is hard for most users to implement by hand. Low-dimensional operations are easy to write. As dimensions are added, however, users must track exactly how each operation maps across the data dimensions, which makes model design and coding considerably harder.

The automatic vectorization (Vmap) feature addresses this problem by separating batch processing from the function itself. When writing a function, users only need to consider the low-dimensional operation; calling the vmap API then automatically extends it to higher dimensions. Calls to vmap can also be nested, which further reduces the complexity of the problem.

This tutorial describes how to use the vmap API to convert highly repetitive operations in models or functions into parallel vector operations, which simplifies the code and improves execution performance.

Vectorization Thinking

Vectorization thinking is common in technologies that improve computing performance. Vectorization thinking can be formalized as follows:

\[\begin{split}\begin{matrix} a_{1} + b_{1} = c_{1} \\ a_{2} + b_{2} = c_{2} \\ a_{3} + b_{3} = c_{3} \\ a_{4} + b_{4} = c_{4} \end{matrix} \Rightarrow \vec{a} + \vec{b} = \vec{c}\end{split}\]

The core idea is to replace the iterations of a for loop with a single vector operation. The same idea applies to a whole function or to a group of operations in a model.
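
The formula above can be illustrated with a minimal sketch. It uses NumPy rather than MindSpore purely so that it is self-contained; the array values are made up for illustration.

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0, 4.0])
b = np.array([10.0, 20.0, 30.0, 40.0])

# Scalar form: one addition per loop iteration (a_i + b_i = c_i).
c_loop = np.empty_like(a)
for i in range(len(a)):
    c_loop[i] = a[i] + b[i]

# Vector form: a single operation over the whole arrays (a + b = c).
c_vec = a + b

print(np.array_equal(c_loop, c_vec))  # True
```

Both forms compute the same values; the vector form simply hands the loop to the library, which is free to execute it in parallel.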

Manual Vectorization

First, we construct a simple convolution function, which is applicable to one-dimensional vector scenarios.

[1]:
import mindspore
from mindspore import Tensor, ops
import mindspore.numpy as mnp

x = mnp.arange(5).astype('float32')
w = mnp.array([1., 2., 3.])

def convolve(x, w):
    output = []
    for i in range(1, len(x) - 1):
        output.append(mnp.dot(x[i - 1 : i + 2], w))
    return mnp.stack(output)

convolve(x, w)
[1]:
Tensor(shape=[3], dtype=Float32, value= [ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01])

To apply this function to a batch of one-dimensional convolutions, the obvious approach is to loop over the batch with a for loop.

[2]:
x_batch = mnp.stack([x, x, x])
w_batch = mnp.stack([w, w, w])

def manually_batch_conv(x_batch, w_batch):
    output = []
    for i in range(x_batch.shape[0]):
        output.append(convolve(x_batch[i], w_batch[i]))
    return mnp.stack(output)

manually_batch_conv(x_batch, w_batch)
[2]:
Tensor(shape=[3, 3], dtype=Float32, value=
[[ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01],
 [ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01],
 [ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01]])

The result is correct, but the computation is inefficient because the batch is processed one sample at a time. We can also manually rewrite the function to vectorize the computation, but this requires handling details such as data indexes and axes.

[3]:
def manually_vectorization_conv(x_batch, w_batch):
    output = []
    for i in range(1, x_batch.shape[-1] - 1):
        output.append(mnp.sum(x_batch[:, i - 1 : i + 2] * w_batch, axis=1))
    return mnp.stack(output, axis=1)

manually_vectorization_conv(x_batch, w_batch)
[3]:
Tensor(shape=[3, 3], dtype=Float32, value=
[[ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01],
 [ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01],
 [ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01]])

In low-dimensional scenarios, the mapping between data indexes is easy to follow. As the number of dimensions increases, however, the computation becomes more complex and the indexing becomes confusing and error-prone. Fortunately, Vmap provides another way to do it.

Vmap

Vmap helps us hide batch dimensions. You only need to call an API to convert a function to a vectorized form.

[4]:
from mindspore import vmap

auto_vectorization_conv = vmap(convolve)
auto_vectorization_conv(x_batch, w_batch)
[4]:
Tensor(shape=[3, 3], dtype=Float32, value=
[[ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01],
 [ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01],
 [ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01]])

Besides giving you a simpler programming experience, Vmap pushes the loop down into each primitive operation of the function and combines this with distributed parallel optimization to achieve higher execution performance.

By default, vmap batches every input and output along axis 0. If an input or output should be batched along a different axis, specify it with the in_axes and out_axes parameters. Each parameter accepts either one batch axis index per input or output, or a single index that applies to all inputs or outputs.

[5]:
w_batch_t = ops.transpose(w_batch, (1, 0))

auto_vectorization_conv = vmap(convolve, in_axes=(0, 1), out_axes=1)
output = auto_vectorization_conv(x_batch, w_batch_t)

ops.transpose(output, (1, 0))
[5]:
Tensor(shape=[3, 3], dtype=Float32, value=
[[ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01],
 [ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01],
 [ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01]])

When a function has multiple inputs, you can batch only some of them. For example, the preceding scenario can be changed to convolve a group of one-dimensional vectors with a single shared weight. To do so, set the corresponding entry of in_axes to None, which indicates that the argument is not batched along any axis.

[6]:
auto_vectorization_conv = vmap(convolve, in_axes=(0, None), out_axes=0)
auto_vectorization_conv(x_batch, w)
[6]:
Tensor(shape=[3, 3], dtype=Float32, value=
[[ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01],
 [ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01],
 [ 8.00000000e+00,  1.40000000e+01,  2.00000000e+01]])

To ensure the correctness of the Vmap operation, Vmap verifies the input dimension, axis index, and batch size. For details about the parameter restrictions, see mindspore.vmap.
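
The meaning of in_axes, out_axes, and the batch-size check can be modeled with a slow reference loop. The sketch below is written in NumPy and is only a mental model: `loop_vmap` is a hypothetical helper, not part of MindSpore, and the real vmap does not execute a Python loop like this.

```python
import numpy as np

def convolve(x, w):
    # Same 1-D convolution as in the tutorial, written with NumPy.
    return np.stack([np.dot(x[i - 1:i + 2], w) for i in range(1, len(x) - 1)])

def loop_vmap(fn, in_axes=0, out_axes=0):
    """Hypothetical loop-based model of vmap semantics (illustration only)."""
    def wrapped(*args):
        # A single value applies to every argument; a tuple/list gives one entry each.
        axes = in_axes if isinstance(in_axes, (tuple, list)) else (in_axes,) * len(args)
        if len(axes) != len(args):
            raise ValueError("in_axes must have one entry per argument")
        # All mapped arguments must agree on the batch size.
        sizes = {arg.shape[ax] for arg, ax in zip(args, axes) if ax is not None}
        if len(sizes) != 1:
            raise ValueError(f"expected exactly one batch size, got {sorted(sizes)}")
        results = []
        for i in range(sizes.pop()):
            # Slice mapped arguments along their axis; pass unmapped (None) ones whole.
            sliced = [arg if ax is None else np.take(arg, i, axis=ax)
                      for arg, ax in zip(args, axes)]
            results.append(fn(*sliced))
        # out_axes decides where the batch axis appears in the result.
        return np.stack(results, axis=out_axes)
    return wrapped

x = np.arange(5, dtype=np.float32)
w = np.array([1., 2., 3.], dtype=np.float32)
x_batch = np.stack([x, x, x])
w_batch = np.stack([w, w, w])

out_default = loop_vmap(convolve)(x_batch, w_batch)
out_axis1 = loop_vmap(convolve, in_axes=(0, 1), out_axes=1)(x_batch, w_batch.T)
out_shared = loop_vmap(convolve, in_axes=(0, None))(x_batch, w)

print(out_default.shape)                      # (3, 3)
print(np.allclose(out_axis1.T, out_default))  # True
print(np.allclose(out_shared, out_default))   # True
```

Passing inputs whose mapped axes have different sizes raises a ValueError in this sketch, which mirrors the kind of batch-size validation described above.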

Nesting of High-Order Functions

Vmap is essentially a high-order function: it takes a function as input and returns a vectorized function that can be applied to batched data. It can be nested with itself and combined with other high-order functions such as grad.

  • Two vmap calls are nested to batch over two or more levels.

[7]:
hyper_x = Tensor([[1., 2., 3., 4., 5.], [2., 3., 4., 5., 6.], [3., 4., 5., 6., 7.]], mindspore.float32)
hyper_w = Tensor([[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]], mindspore.float32)

hyper_vmap_ger = vmap(vmap(convolve, in_axes=[None, 0]), in_axes=[0, None])
hyper_vmap_ger(hyper_x, hyper_w)
[7]:
Tensor(shape=[3, 3, 3], dtype=Float32, value=
[[[ 6.00000000e+00,  9.00000000e+00,  1.20000000e+01],
  [ 1.20000000e+01,  1.80000000e+01,  2.40000000e+01],
  [ 1.80000000e+01,  2.70000000e+01,  3.60000000e+01]],
 [[ 9.00000000e+00,  1.20000000e+01,  1.50000000e+01],
  [ 1.80000000e+01,  2.40000000e+01,  3.00000000e+01],
  [ 2.70000000e+01,  3.60000000e+01,  4.50000000e+01]],
 [[ 1.20000000e+01,  1.50000000e+01,  1.80000000e+01],
  [ 2.40000000e+01,  3.00000000e+01,  3.60000000e+01],
  [ 3.60000000e+01,  4.50000000e+01,  5.40000000e+01]]])
  • vmap is nested inside grad to compute the gradient of the vectorized function.

[8]:
from mindspore import grad

def forward_fn(x, y):
    out = x + 2 * y
    out = ops.sin(out)
    reduce_sum = ops.ReduceSum()
    return reduce_sum(out)

x_hat = Tensor([[1., 2., 3.], [2., 3., 4.]], mindspore.float32)
y_hat = Tensor([[2., 3., 4.], [3., 4., 5.]], mindspore.float32)

grad_vmap_ger = grad(vmap(forward_fn), grad_position=(0, 1))
grad_vmap_ger(x_hat, y_hat)
[8]:
(Tensor(shape=[2, 3], dtype=Float32, value=
 [[ 2.83662200e-01, -1.45500034e-01,  4.42569796e-03],
  [-1.45500034e-01,  4.42569796e-03,  1.36737213e-01]]),
 Tensor(shape=[2, 3], dtype=Float32, value=
 [[ 5.67324400e-01, -2.91000068e-01,  8.85139592e-03],
  [-2.91000068e-01,  8.85139592e-03,  2.73474425e-01]]))
  • grad is nested inside vmap for scenarios such as batched gradient computation and high-order gradient computation, for example computing a Jacobian matrix.

[9]:
vmap_grad_ger = vmap(grad(forward_fn, grad_position=(0, 1)))
vmap_grad_ger(x_hat, y_hat)
[9]:
(Tensor(shape=[2, 3], dtype=Float32, value=
 [[ 2.83662200e-01, -1.45500034e-01,  4.42569796e-03],
  [-1.45500034e-01,  4.42569796e-03,  1.36737213e-01]]),
 Tensor(shape=[2, 3], dtype=Float32, value=
 [[ 5.67324400e-01, -2.91000068e-01,  8.85139592e-03],
  [-2.91000068e-01,  8.85139592e-03,  2.73474425e-01]]))
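
The gradient values printed in the two cells above can be checked against the closed-form derivative. For forward_fn, the output is sum(sin(x + 2y)), so the gradient with respect to x is cos(x + 2y) and the gradient with respect to y is 2cos(x + 2y). The following NumPy sketch (not MindSpore code) evaluates these expressions for the same inputs.

```python
import numpy as np

x_hat = np.array([[1., 2., 3.], [2., 3., 4.]], dtype=np.float32)
y_hat = np.array([[2., 3., 4.], [3., 4., 5.]], dtype=np.float32)

# Closed-form gradients of sum(sin(x + 2 * y)).
grad_x = np.cos(x_hat + 2 * y_hat)
grad_y = 2 * np.cos(x_hat + 2 * y_hat)

# Values copied from the tutorial output for the gradient with respect to x.
expected_x = np.array([[0.2836622, -0.14550003, 0.0044256980],
                       [-0.14550003, 0.0044256980, 0.13673721]])

print(np.allclose(grad_x, expected_x, atol=1e-5))  # True
print(np.allclose(grad_y, 2 * expected_x, atol=1e-5))  # True
```

Both nesting orders print the same numbers here because each row of the output depends only on the matching row of the inputs, so differentiating the batched function and batching the differentiated function coincide for this example.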

This tutorial only shows two levels of nesting of high-order functions. You can nest more levels as your scenario requires.
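
To make the nested vmap example concrete, the sketch below reproduces its result with two explicit NumPy loops. It is a reference model of what the nested call computes, not how MindSpore executes it: the outer loop plays the role of the outer vmap (over rows of hyper_x) and the inner loop plays the role of the inner vmap (over rows of hyper_w).

```python
import numpy as np

def convolve(x, w):
    # Same 1-D convolution as in the tutorial, written with NumPy.
    return np.stack([np.dot(x[i - 1:i + 2], w) for i in range(1, len(x) - 1)])

hyper_x = np.array([[1., 2., 3., 4., 5.],
                    [2., 3., 4., 5., 6.],
                    [3., 4., 5., 6., 7.]], dtype=np.float32)
hyper_w = np.array([[1., 1., 1.],
                    [2., 2., 2.],
                    [3., 3., 3.]], dtype=np.float32)

# result[i, j] is convolve(hyper_x[i], hyper_w[j]).
result = np.stack([
    np.stack([convolve(x_row, w_row) for w_row in hyper_w])
    for x_row in hyper_x
])

print(result.shape)           # (3, 3, 3)
print(result[0, 1].tolist())  # [12.0, 18.0, 24.0]
```

Each additional level of nesting adds one more leading axis to the result in the same way.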

Automatic Vectorization of Cell

In the previous examples, a function object is passed to vmap. The following shows how to pass a Cell object instead, using a simple fully-connected layer as an example.

[10]:
import mindspore.nn as nn
from mindspore import Parameter
from mindspore.common.initializer import initializer

class Dense(nn.Cell):
    def __init__(self, in_channels, out_channels, weight_init='normal', bias_init='zeros'):
        super(Dense, self).__init__()
        self.scalar = 1
        self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
        self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
        self.matmul = ops.MatMul(transpose_b=True)

    def construct(self, x):
        x = self.matmul(x, self.weight)
        output = ops.bias_add(x, self.bias)
        return output

input_a = Tensor([[1, 2, 3], [4, 5, 6]], mindspore.float32)
input_b = Tensor([[2, 3, 4], [5, 6, 7]], mindspore.float32)
input_c = Tensor([[3, 4, 5], [6, 7, 8]], mindspore.float32)

dense_net = Dense(3, 4)
print(dense_net(input_a))
print(dense_net(input_b))
print(dense_net(input_c))

inputs = mnp.stack([input_a, input_b, input_c])

vmap_dense_net = vmap(dense_net)
print(vmap_dense_net(inputs))
[[ 0.0219292  -0.01062493 -0.03378957 -0.02589925]
 [ 0.03091274 -0.04968021 -0.08098207 -0.07896652]]
[[ 0.02492371 -0.02364336 -0.0495204  -0.04358834]
 [ 0.03390725 -0.06269865 -0.09671289 -0.09665561]]
[[ 0.02791822 -0.03666179 -0.06525123 -0.06127743]
 [ 0.03690176 -0.07571708 -0.11244373 -0.1143447 ]]
[[[ 0.0219292  -0.01062493 -0.03378957 -0.02589925]
  [ 0.03091274 -0.04968021 -0.08098207 -0.07896652]]

 [[ 0.02492371 -0.02364336 -0.0495204  -0.04358834]
  [ 0.03390725 -0.06269865 -0.09671289 -0.09665561]]

 [[ 0.02791822 -0.03666179 -0.06525123 -0.06127743]
  [ 0.03690176 -0.07571708 -0.11244373 -0.1143447 ]]]

Using a Cell is basically the same as function-based automatic vectorization: you only need to pass the Cell instance as the first argument of vmap, and Vmap vectorizes its construct method for batch processing. In this example, two Parameter objects (weight and bias) are defined in the initialization function. Vmap treats such free variables of the executed function as if they were passed as arguments with in_axes set to None.
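
The treatment of free variables can be pictured with a small NumPy sketch. It is an analogy rather than MindSpore code: `dense` is a hypothetical stand-in for the layer, with weight and bias passed explicitly so that "in_axes set to None" simply means that every sample reuses the same weight and bias.

```python
import numpy as np

rng = np.random.default_rng(0)
weight = rng.normal(size=(4, 3)).astype(np.float32)  # [out_channels, in_channels]
bias = np.zeros(4, dtype=np.float32)

def dense(x, weight, bias):
    # Hypothetical stand-in for the Dense layer: x @ weight^T + bias.
    return x @ weight.T + bias

inputs = np.stack([
    np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32),
    np.array([[2, 3, 4], [5, 6, 7]], dtype=np.float32),
    np.array([[3, 4, 5], [6, 7, 8]], dtype=np.float32),
])

# Per-sample loop: inputs are sliced along axis 0, parameters are shared (None).
looped = np.stack([dense(x, weight, bias) for x in inputs])

# Batched call: matmul broadcasts the shared parameters over the leading axis.
batched = dense(inputs, weight, bias)

print(batched.shape)                 # (3, 2, 4)
print(np.allclose(looped, batched))  # True
```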

In this way, the same model can be trained or run for inference on batched input. Compared with network models that already accept batched input, the batch dimension provided by Vmap is more flexible and is not limited to input formats such as NCHW.

Model Ensembling Scenario

In the model ensembling scenario, prediction results from multiple models are combined. Traditionally, each model is run on certain inputs, and then the prediction results are combined. If you are running models with the same architecture, you can vectorize them with Vmap for acceleration.

This scenario involves vectorizing the weight data. If the models are implemented in a functional style, that is, the weight parameters are defined outside the model and passed in as arguments, you can batch them directly by configuring in_axes. However, to keep model definition convenient, most nn APIs define and initialize their weight parameters internally. A plain vmap API cannot batch such internal weight parameters, so the model would have to be rewritten as a function that receives its weights as arguments. The MindSpore vmap API handles this scenario for you: pass the model instances to vmap as a CellList, and the framework batches the weight parameters automatically.

The following demonstrates how to use a simple set of CNN models to implement model ensembling inference and training.

[11]:
class LeNet5(nn.Cell):
    """
    LeNet-5 network structure
    """
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

Assume that we want to compare the same model architecture under different weight parameters. We simulate four trained model instances and a minibatch from a dummy image dataset with a batch size of 3 and an image size of 32 x 32.

[12]:
net1 = LeNet5()
net2 = LeNet5()
net3 = LeNet5()
net4 = LeNet5()

minibatch = Tensor(mnp.randn(3, 1, 32, 32), mindspore.float32)

Instead of running each model in a for loop and then combining the prediction results, Vmap obtains the prediction results of all models in a single run.

Note that the Vmap implementation mechanism has requirements on the running memory of the device. Using Vmap may occupy more memory. Please use it based on the actual scenario.

[13]:
nets = nn.CellList([net1, net2, net3, net4])
vmap(nets, in_axes=None)(minibatch)
[13]:
Tensor(shape=[4, 3, 10], dtype=Float32, value=
[[[ 4.66281290e-06, -7.24548045e-06,  8.68147254e-07 ...  1.42438457e-05,  1.49375774e-05, -1.18535736e-05],
  [ 9.10962353e-06, -5.63606591e-06, -7.06250285e-06 ...  1.68580664e-05,  1.41603141e-05, -3.55220163e-06],
  [ 1.11184154e-05, -6.08020900e-06, -5.08124231e-06 ...  1.37913748e-05,  1.20597506e-05, -1.01803471e-05]],
 [[ 3.22165624e-06,  6.22022753e-06,  2.60713023e-07 ... -1.53302244e-05,  2.34313102e-05, -4.16413786e-06],
  [ 2.82950850e-06,  1.54561894e-06,  5.19753303e-06 ... -1.53819674e-05,  1.58681542e-05, -7.10185304e-07],
  [ 1.77780748e-07,  4.33479636e-06, -1.35376536e-06 ... -1.06113021e-05,  1.58355688e-05, -5.78900153e-06]],
 [[ 6.66864071e-06, -1.99870119e-05, -1.30958688e-05 ...  3.68208202e-06, -9.69678968e-06,  9.59075351e-06],
  [ 7.99765985e-06, -1.16931469e-05, -1.06589669e-05 ... -1.24687813e-06, -8.65744005e-06,  6.81729716e-06],
  [ 6.87587362e-06, -1.23972441e-05, -1.05251866e-05 ...  1.44004912e-06, -5.40550172e-06,  6.71799853e-06]],
 [[-3.44783439e-06,  2.32537104e-07, -8.64402864e-06 ...  3.52633970e-06, -6.27670488e-06,  3.27721250e-06],
  [-6.90392517e-06, -9.97693860e-07, -6.48076320e-06 ...  7.61923275e-07, -2.54563452e-06,  3.08638573e-06],
  [-3.78440518e-06,  3.93633945e-06, -7.90367903e-06 ...  5.13138957e-07, -4.50420839e-06,  2.13702333e-06]]])

Alternatively, we may want each model to run on its own minibatch and obtain the prediction results separately.

In the model ensembling scenario, the first argument of vmap must be a CellList, and all models must have the same architecture; otherwise, the computation may be incorrect. If in_axes is not None, the number of models must equal the axis_size of the mapped axis so that models and data slices map one to one.

[14]:
minibatchs = Tensor(mnp.randn(4, 3, 1, 32, 32), mindspore.float32)
vmap(nets, in_axes=0)(minibatchs)
[14]:
Tensor(shape=[4, 3, 10], dtype=Float32, value=
[[[ 6.52808285e-06, -4.15002341e-06, -3.80283609e-06 ...  1.54428089e-05,  1.44425348e-05, -9.00016857e-06],
  [ 7.39091365e-06, -5.19960076e-06,  3.83916813e-07 ...  1.67857870e-05,  1.80104271e-05, -1.56435199e-05],
  [ 1.11604741e-05, -7.59019804e-06,  2.54263796e-07 ...  1.21071571e-05,  1.66683039e-05, -1.09967377e-05]],
 [[ 1.48978233e-06,  1.02267529e-06,  1.33801677e-06 ... -1.32894393e-05,  1.36311328e-05, -3.29658405e-06],
  [ 1.09956818e-06, -5.06103561e-07,  3.04885953e-06 ... -1.76028752e-05,  1.66466998e-05, -1.17290392e-06],
  [ 2.96090502e-06,  1.87074147e-06,  5.76813818e-06 ... -1.09994007e-05,  1.35614964e-05, -2.19983576e-06]],
 [[ 6.74323928e-06, -1.03955799e-05, -6.92168396e-06 ...  4.88165415e-06, -5.40378596e-06,  3.09346888e-06],
  [ 7.28906161e-06, -1.34921102e-05, -1.00995640e-05 ...  9.44596650e-07, -6.40979761e-06,  1.26146606e-05],
  [ 9.43304440e-06, -1.61852931e-05, -1.16265892e-05 ...  5.31926253e-06, -1.28484417e-05,  8.03831313e-07]],
 [[-5.51165886e-06, -1.09487860e-06, -6.07781249e-06 ...  7.51453626e-06, -3.29403338e-06,  3.45475746e-06],
  [-6.27516283e-06,  1.40756754e-06, -9.18502155e-06 ...  4.16079911e-06, -5.30383022e-06,  5.12517454e-06],
  [-6.19608954e-06,  5.12868655e-06, -1.00337056e-05 ...  2.93281119e-07, -6.52256404e-06,  3.62988635e-06]]])
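
The two mapping modes used above (in_axes=None and in_axes=0) can be modeled by stacking the weights of all models along a new leading axis. The NumPy sketch below is a simplified analogy: each "model" is a hypothetical single matrix multiplication rather than LeNet5, and the sizes are chosen only to reproduce the [4, 3, 10] output shape.

```python
import numpy as np

rng = np.random.default_rng(0)
num_models, batch, features, classes = 4, 3, 8, 10

# One weight matrix per model, stacked along a leading "model" axis.
weights = rng.normal(size=(num_models, features, classes))

def predict(x, w):
    # Hypothetical single-layer stand-in for one model.
    return x @ w

# in_axes=None: every model sees the same minibatch.
shared = rng.normal(size=(batch, features))
out_shared = np.einsum('bf,mfc->mbc', shared, weights)

# in_axes=0: model m sees minibatch m, so the leading sizes must match.
per_model = rng.normal(size=(num_models, batch, features))
out_per_model = np.einsum('mbf,mfc->mbc', per_model, weights)

# For-loop references for comparison.
ref_shared = np.stack([predict(shared, w) for w in weights])
ref_per_model = np.stack([predict(x, w) for x, w in zip(per_model, weights)])

print(out_shared.shape, out_per_model.shape)     # (4, 3, 10) (4, 3, 10)
print(np.allclose(out_shared, ref_shared))       # True
print(np.allclose(out_per_model, ref_per_model)) # True
```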

In addition to model ensembling inference, the Vmap feature can also be used to implement model ensembling training.

[15]:
from mindspore import ParameterTuple

class TrainOneStepNet(nn.Cell):
    def __init__(self, net, lr):
        super(TrainOneStepNet, self).__init__()
        self.loss_fn = nn.WithLossCell(net, nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean'))
        self.weight = ParameterTuple(net.trainable_params())
        self.adam_optim = nn.Adam(self.weight, learning_rate=lr, use_amsgrad=True)

    def construct(self, batch, targets):
        loss = self.loss_fn(batch, targets)
        grad_weights = grad(self.loss_fn, grad_position=None, weights=self.weight)(batch, targets)
        self.adam_optim(grad_weights)
        return loss

train_net1 = TrainOneStepNet(net1, lr=1e-2)
train_net2 = TrainOneStepNet(net2, lr=1e-3)
train_net3 = TrainOneStepNet(net3, lr=2e-3)
train_net4 = TrainOneStepNet(net4, lr=3e-3)

train_nets = nn.CellList([train_net1, train_net2, train_net3, train_net4])
model_ensembling_train_one_step = vmap(train_nets)

images = Tensor(mnp.randn(4, 3, 1, 32, 32), mindspore.float32)
labels = Tensor(mnp.randint(1, 10, (4, 3)), mindspore.int32)

for i in range(1, 11):
    loss = model_ensembling_train_one_step(images, labels)
    print("Step {} - loss: {}".format(i, loss))

vmap(nets, in_axes=None)(minibatch)
Step 1 - loss: [2.3025837 2.3025882 2.3025858 2.3025842]
Step 2 - loss: [2.260927  2.301028  2.2992857 2.2976868]
Step 3 - loss: [1.8539654 2.2993202 2.2951114 2.2899477]
Step 4 - loss: [0.77165794 2.2973287  2.288719   2.2726345 ]
Step 5 - loss: [0.9397469 2.2948549 2.2777178 2.2313874]
Step 6 - loss: [0.6747699 2.29158   2.2579656 2.1410708]
Step 7 - loss: [0.64673084 2.2870557  2.2232006  1.966973  ]
Step 8 - loss: [1.0506033 2.2806385 2.1645374 1.6848679]
Step 9 - loss: [0.612196  2.2714498 2.0706694 1.3499321]
Step 10 - loss: [0.8843982 2.258316  1.9299208 1.1472267]
[15]:
Tensor(shape=[4, 3, 10], dtype=Float32, value=
[[[-1.91058636e+01, -1.92182674e+01,  1.06328402e+01 ... -1.87287464e+01, -1.87855473e+01, -2.02504387e+01],
  [-1.94767399e+01, -1.95909595e+01,  1.08379564e+01 ... -1.90921249e+01, -1.91503220e+01, -2.06434765e+01],
  [-1.96521702e+01, -1.97674465e+01,  1.09355783e+01 ... -1.92643051e+01, -1.93227654e+01, -2.08293762e+01]],
 [[-4.07293849e-02, -4.27918807e-02,  5.22112176e-02 ... -4.67570126e-02, -3.88025381e-02,  4.88412194e-02],
  [-3.91553082e-02, -4.11494374e-02,  5.00433967e-02 ... -4.48847115e-02, -3.73134986e-02,  4.68519926e-02],
  [-3.80369201e-02, -3.99325565e-02,  4.84890938e-02 ... -4.35365662e-02, -3.62745039e-02,  4.54225838e-02]],
 [[-5.08784056e-01, -5.05123973e-01,  5.20882547e-01 ...  4.72596169e-01, -5.00697553e-01, -4.60489392e-01],
  [-4.80103493e-01, -4.76664037e-01,  4.91507798e-01 ...  4.46062207e-01, -4.72493649e-01, -4.34652239e-01],
  [-4.81168061e-01, -4.77702975e-01,  4.92583781e-01 ...  4.47029382e-01, -4.73524809e-01, -4.35579300e-01]],
 [[-3.66236401e+00, -3.25362825e+00,  3.51312804e+00 ...  3.77490187e+00, -3.36864424e+00, -3.34358120e+00],
  [-3.49160767e+00, -3.10209608e+00,  3.34935308e+00 ...  3.59894991e+00, -3.21167707e+00, -3.18782210e+00],
  [-3.57623625e+00, -3.17717075e+00,  3.43059325e+00 ...  3.68615556e+00, -3.28948307e+00, -3.26504302e+00]]])

After ensembling training, each model in the ensemble can still run inference independently.

[16]:
net1(minibatch)
[16]:
Tensor(shape=[3, 10], dtype=Float32, value=
[[-1.91058636e+01, -1.92182674e+01,  1.06328402e+01 ... -1.87287483e+01, -1.87855473e+01, -2.02504387e+01],
 [-1.94767399e+01, -1.95909595e+01,  1.08379564e+01 ... -1.90921249e+01, -1.91503220e+01, -2.06434765e+01],
 [-1.96521702e+01, -1.97674465e+01,  1.09355783e+01 ... -1.92643051e+01, -1.93227654e+01, -2.08293762e+01]])
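
The idea behind ensembling training, several models updated in one batched step with a different learning rate each, can be sketched in NumPy. This is a deliberately simplified analogy and makes several assumptions: the models are linear regressors instead of LeNet5, the optimizer is plain gradient descent instead of Adam, and the data is random.

```python
import numpy as np

rng = np.random.default_rng(0)
num_models, n, d = 4, 8, 3

X = rng.normal(size=(num_models, n, d))  # one minibatch per model
y = rng.normal(size=(num_models, n))     # one target vector per model
w = np.zeros((num_models, d))            # stacked weights, one row per model
lr = np.array([1e-2, 1e-3, 2e-3, 3e-3])[:, None]  # per-model learning rates

def losses(w):
    # Mean squared error of every model, computed in one batched expression.
    residual = np.einsum('mnd,md->mn', X, w) - y
    return np.mean(residual ** 2, axis=1)

initial = losses(w)
for step in range(10):
    residual = np.einsum('mnd,md->mn', X, w) - y
    # Gradient of the mean squared error with respect to each model's weights.
    grad = 2.0 / n * np.einsum('mnd,mn->md', X, residual)
    w = w - lr * grad
final = losses(w)

print(final.shape)               # (4,)
print(np.all(final < initial))   # True
```

As in the tutorial output, the model with the largest learning rate moves furthest from its initial loss within the same number of steps.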

Summary

This tutorial focuses on the usage of Vmap. In essence, Vmap does not execute a loop outside the function. Instead, it pushes the loop down into each primitive operation of the function and passes the mapping axis information between primitive operations to keep the computation correct. The performance benefit of Vmap comes mainly from the VmapRule implemented for each primitive operation. Because the loop is pushed down to the operator level, it is easier to optimize performance with parallel techniques. If your function contains custom operators, you can implement a specific VmapRule for them to achieve better performance. If you need the best possible performance, you can also try the graph kernel fusion feature.
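
The idea of a per-primitive batching rule can be sketched as follows. This is a hypothetical NumPy illustration of the concept only; `batched_dot_rule` and its signature are assumptions made for this example and do not reflect the actual VmapRule interface in MindSpore. The rule receives each operand together with its batch axis (or None), performs one batched operation instead of a loop, and reports which axis of the result is the batch axis.

```python
import numpy as np

def batched_dot_rule(x, x_axis, w, w_axis):
    """Hypothetical batching rule for a 1-D dot product.

    x_axis and w_axis give the batch axis of each operand, or None if the
    operand is not batched. Returns (output, out_axis).
    """
    if x_axis is None and w_axis is None:
        return np.dot(x, w), None
    # Normalize: move any batch axis to the front.
    if x_axis is not None:
        x = np.moveaxis(x, x_axis, 0)
    if w_axis is not None:
        w = np.moveaxis(w, w_axis, 0)
    # One batched primitive call replaces the per-sample loop.
    if x_axis is not None and w_axis is not None:
        out = np.einsum('bk,bk->b', x, w)
    elif x_axis is not None:
        out = np.einsum('bk,k->b', x, w)
    else:
        out = np.einsum('k,bk->b', x, w)
    return out, 0  # the batch axis of the result is axis 0

x = np.arange(12.0).reshape(4, 3)
w = np.array([1.0, 2.0, 3.0])

out, out_axis = batched_dot_rule(x, 0, w, None)
print(out.tolist(), out_axis)  # [8.0, 26.0, 44.0, 62.0] 0
```

Chaining such rules, with each one passing its output axis to the next primitive, is the general mechanism by which mapping axis information can flow through a function.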

Currently, the Vmap feature supports the GPU and CPU platforms. More functions are being adapted to the Ascend platform.

If Vmap is applied to a function that contains control flow, every batch-processing branch must perform the same operations, or all variables in the control flow must have no split axis.
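
Why data-dependent control flow is difficult to batch can be seen in a small NumPy sketch. This is a general illustration, not a description of how MindSpore handles control flow, and `double_or_zero` is a made-up function: once the condition depends on a batched variable, different samples may need different branches, so a single `if` can no longer decide for the whole batch.

```python
import numpy as np

def double_or_zero(x):
    # Data-dependent branch: the path taken depends on the values in x.
    if x.sum() < 0:
        return np.zeros_like(x)
    return x * 2

x_batch = np.array([[1.0, 2.0], [-3.0, 1.0]])

# Per-sample loop: each sample picks its own branch.
looped = np.stack([double_or_zero(x) for x in x_batch])

# Batched form: evaluate both branches for all samples, then select per sample.
cond = x_batch.sum(axis=1, keepdims=True) < 0
batched = np.where(cond, np.zeros_like(x_batch), x_batch * 2)

print(looped.tolist())                    # [[2.0, 4.0], [0.0, 0.0]]
print(np.array_equal(looped, batched))    # True
```

If the condition depends only on unbatched variables, all samples take the same branch and this problem does not arise.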

After working through the preceding cases, you should have a general understanding of Vmap. Its application scenarios are not limited to those in this tutorial. You can try more interesting scenarios and join our discussion and work.