Parameter Initialization
Initializing with Built-In Parameters
MindSpore provides a variety of network parameter initialization methods and encapsulates parameter initialization in some operators. This section takes Conv2d as an example to show how to initialize parameters with an Initializer subclass and with a string.
Initializer Initialization
Initializer is the base class of MindSpore's built-in parameter initialization methods; every built-in method inherits from it. The neural network layers in mindspore.nn provide input parameters such as weight_init and bias_init, which accept an instantiated Initializer directly. For example:
import numpy as np
import mindspore.nn as nn
import mindspore as ms
from mindspore.common.initializer import Normal, initializer
input_data = ms.Tensor(np.ones([1, 3, 16, 50], dtype=np.float32))
# Convolution layer with 3 input channels, 64 output channels, and a 3 x 3 kernel.
# The weight is drawn from a normal distribution, Normal(), with sigma=0.2.
net = nn.Conv2d(3, 64, 3, weight_init=Normal(0.2))
# The network output
output = net(input_data)
String Initialization
Besides an instantiated Initializer, MindSpore also accepts the name of an initialization method as a string. In this case the parameter is initialized with the default arguments of the corresponding Initializer. For example:
import numpy as np
import mindspore.nn as nn
import mindspore as ms
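# 'normal' selects the Normal initializer with its default arguments
net = nn.Conv2d(3, 64, 3, weight_init='normal')
# input_data is the tensor defined in the previous example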
output = net(input_data)
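The string form can be thought of as a lookup from a name to an Initializer that is built with its default arguments. The sketch below illustrates that idea in plain Python with the standard library only. The registry `_REGISTRY`, the `resolve` helper, and the two toy classes are hypothetical names chosen for illustration; this is not MindSpore's actual implementation, and the default sigma of 0.01 is an assumption of the sketch.

```python
import random


class ToyNormal:
    """Toy stand-in for a normal-distribution initializer (hypothetical)."""

    def __init__(self, sigma=0.01, mean=0.0):  # default values are assumptions
        self.sigma = sigma
        self.mean = mean

    def __call__(self, size):
        # Return `size` samples drawn from N(mean, sigma^2)
        return [random.gauss(self.mean, self.sigma) for _ in range(size)]


class ToyZeros:
    """Toy stand-in for an all-zeros initializer (hypothetical)."""

    def __call__(self, size):
        return [0.0] * size


# Hypothetical name-to-class registry: a string selects a class, which is
# then constructed with its default arguments.
_REGISTRY = {"normal": ToyNormal, "zeros": ToyZeros}


def resolve(init):
    """Accept either a method name or an already-built initializer object."""
    if isinstance(init, str):
        return _REGISTRY[init.lower()]()
    return init


by_name = resolve("normal")          # built with the default sigma
by_object = resolve(ToyNormal(0.2))  # explicit sigma, passed through unchanged
print(by_name.sigma, by_object.sigma)
print(resolve("zeros")(3))
```

The point of the sketch is the trade-off: a string is shorter but fixes the arguments to their defaults, while an instantiated object lets you choose them.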
Customized Parameter Initialization
In general, the built-in parameter initialization methods provided by MindSpore meet the initialization requirements of common neural network layers. When you need a method that is not built in, you can inherit from Initializer and define a custom parameter initialization method. Take XavierNormal as an example:
import math
import numpy as np
from mindspore.common.initializer import Initializer

def _calculate_fan_in_and_fan_out(arr):
    # Calculate fan_in and fan_out. fan_in is the number of input units in `arr`,
    # and fan_out is the number of output units in `arr`.
    shape = arr.shape
    dimensions = len(shape)
    if dimensions < 2:
        raise ValueError("'fan_in' and 'fan_out' can not be computed for arr with fewer than"
                         " 2 dimensions, but got dimensions {}.".format(dimensions))
    if dimensions == 2:  # Linear
        fan_in = shape[1]
        fan_out = shape[0]
    else:
        num_input_fmaps = shape[1]
        num_output_fmaps = shape[0]
        receptive_field_size = 1
        for i in range(2, dimensions):
            receptive_field_size *= shape[i]
        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size
    return fan_in, fan_out

class XavierNormal(Initializer):
    def __init__(self, gain=1):
        super().__init__()
        # Configure the parameters required for initialization
        self.gain = gain

    def _initialize(self, arr):  # arr is a Tensor to be initialized
        fan_in, fan_out = _calculate_fan_in_and_fan_out(arr)  # Compute fan_in, fan_out
        std = self.gain * math.sqrt(2.0 / float(fan_in + fan_out))  # Calculate std value
        data = np.random.normal(0, std, arr.shape)  # Construct the initialized array with numpy
        arr[:] = data[:]  # Assign the initialized ndarray to arr
The custom class can then be used in the same way as a built-in initialization method:
net = nn.Conv2d(3, 64, 3, weight_init=XavierNormal())
# The network output
output = net(input_data)
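As a quick check of the fan computation, consider the weight of nn.Conv2d(3, 64, 3). Assuming the (out_channels, in_channels, kernel_height, kernel_width) layout that the helper above relies on, its shape is (64, 3, 3, 3). The sketch below repeats the same arithmetic on a plain shape tuple using only the standard library; `fans_from_shape` is an illustrative name, not a MindSpore function.

```python
import math


def fans_from_shape(shape):
    """Return (fan_in, fan_out) for a weight with the given shape."""
    if len(shape) < 2:
        raise ValueError("need at least 2 dimensions, got {}".format(len(shape)))
    receptive_field_size = 1
    for dim in shape[2:]:  # empty for 2-D (linear) weights
        receptive_field_size *= dim
    fan_in = shape[1] * receptive_field_size
    fan_out = shape[0] * receptive_field_size
    return fan_in, fan_out


conv_shape = (64, 3, 3, 3)  # assumed layout: (out, in, kernel_h, kernel_w)
fan_in, fan_out = fans_from_shape(conv_shape)
gain = 1
std = gain * math.sqrt(2.0 / (fan_in + fan_out))
print(fan_in, fan_out)  # 27 576
print(std)
```

Here fan_in = 3 * 9 = 27 and fan_out = 64 * 9 = 576, so the standard deviation is sqrt(2 / 603), roughly 0.058 for gain = 1.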
Cell Traversal Initialization
Besides passing weight_init, bias_init, and similar parameters provided by mindspore.nn, a common practice is to construct the complete neural network first and then manage weight, bias, and other parameters uniformly. To do so, construct and instantiate the network, then traverse its cells and assign values to the parameters. Here is a simple example:
for name, param in net.parameters_and_names():
    if 'weight' in name:
        param.set_data(initializer(Normal(), param.shape, param.dtype))
    if 'bias' in name:
        param.set_data(initializer('zeros', param.shape, param.dtype))
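Because the loop matches on substrings of the parameter name, any parameter whose name contains 'weight' or 'bias' is reinitialized no matter how deeply it is nested, and everything else keeps the value it received at construction. The following standard-library sketch shows that selection logic on a made-up list of names; the names are hypothetical and do not come from a real network, and `pick_rule` is an illustrative helper.

```python
# Hypothetical parameter names, as a nested network might report them.
names = ["conv1.weight", "conv1.bias", "bn1.gamma", "bn1.beta", "fc.weight", "fc.bias"]


def pick_rule(name):
    """Follow the substring checks used in the traversal loop.

    Assumes no name contains both 'weight' and 'bias'; in the original
    loop such a name would hit both branches.
    """
    if "weight" in name:
        return "normal"
    if "bias" in name:
        return "zeros"
    return None  # not matched: keep the value set at construction


plan = {name: pick_rule(name) for name in names}
for name, rule in plan.items():
    print(name, "->", rule or "unchanged")
```

If some layers need different treatment, a stricter test than a substring match (for example, checking the end of the name) may be a safer choice.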