

Building a Network

A neural network model consists of neural network layers and Tensor operations. mindspore.nn provides implementations of common neural network layers. In MindSpore, the Cell class is the base class for building all networks and is the basic unit of a network. A neural network model is itself a Cell composed of sub-Cells. This nested structure lets you construct and manage the network with an object-oriented approach.

In the following, we construct a neural network model to classify the MNIST dataset.

import mindspore
from mindspore import nn, ops

Defining a Model Class

To define a neural network, inherit from the nn.Cell class, instantiate and manage the sub-Cells' state in the __init__ method, and implement the Tensor operations in the construct method.

The construct method builds the neural network (computational graph). For more details, see Accelerating with Static Graphs.

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 512, weight_init="normal", bias_init="zeros"),
            nn.ReLU(),
            nn.Dense(512, 10, weight_init="normal", bias_init="zeros")
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

After defining the class, instantiate a Network object and print its structure.

model = Network()
print(model)
Network<
  (flatten): Flatten<>
  (dense_relu_sequential): SequentialCell<
    (0): Dense<input_channels=784, output_channels=512, has_bias=True>
    (1): ReLU<>
    (2): Dense<input_channels=512, output_channels=512, has_bias=True>
    (3): ReLU<>
    (4): Dense<input_channels=512, output_channels=10, has_bias=True>
    >
  >

We construct an input and call the model directly to obtain a 10-dimensional Tensor output that contains the raw predicted values (logits) for each class.

Do not call the model.construct() method directly; call the model instance itself.

X = ops.ones((1, 28, 28), mindspore.float32)
logits = model(X)
# print logits
logits
Tensor(shape=[1, 10], dtype=Float32, value=
[[-5.08734025e-04,  3.39190010e-04,  4.62840870e-03 ... -1.20305456e-03, -5.05689112e-03,  3.99264274e-03]])

We then obtain the predicted probabilities by passing the logits through an nn.Softmax layer instance, and take the class with the highest probability as the prediction.

pred_probab = nn.Softmax(axis=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")
Predicted class: [4]
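Applying softmax before argmax does not change which class is predicted: exp is strictly increasing and every entry in a row shares the same denominator, so softmax preserves the ordering of the logits. The following NumPy sketch (assuming NumPy is available, with made-up logits rather than the model's actual output) checks this.

```python
import numpy as np

# Made-up logits for one sample and five classes (illustrative values only).
logits = np.array([[-0.5, 0.3, 4.6, -1.2, -5.0]], dtype=np.float32)

# Softmax along axis 1 (subtracting the row max first for numerical stability).
exp = np.exp(logits - logits.max(axis=1, keepdims=True))
probs = exp / exp.sum(axis=1, keepdims=True)

# The highest logit and the highest probability sit at the same index.
print(logits.argmax(1), probs.argmax(1))
```

In practice this means the softmax step is only needed when you want calibrated probabilities; for the predicted label alone, argmax over the logits gives the same answer.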

Model Layers

In this section, we break down each layer of the model constructed in the previous section. First, we construct a batch of input data (3 images of 28x28, filled with ones) with shape (3, 28, 28) and pass it through each layer in turn to observe its effect.

input_image = ops.ones((3, 28, 28), mindspore.float32)
print(input_image.shape)
(3, 28, 28)

nn.Flatten

Initialize an nn.Flatten layer to convert each 28x28 2D image into a contiguous array of 784 values. The batch dimension (axis 0) is kept, so the output shape is (3, 784).

flatten = nn.Flatten()
flat_image = flatten(input_image)
print(flat_image.shape)
(3, 784)
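Flattening is a reshape that keeps axis 0 and merges all remaining axes. The NumPy sketch below (an illustration, not MindSpore's implementation) fills the images with distinct values instead of ones so that the destination of each pixel can be checked: NumPy's reshape uses row-major order, so pixel (r, c) of an image ends up at position r * 28 + c of its row.

```python
import numpy as np

# Three 28x28 images with distinct values so each pixel can be tracked.
images = np.arange(3 * 28 * 28, dtype=np.float32).reshape(3, 28, 28)

# Keep the batch axis, merge the two spatial axes into one of length 784.
flat = images.reshape(images.shape[0], -1)
print(flat.shape)  # (3, 784)

# Row-major order: pixel (3, 7) of image 2 lands at index 3 * 28 + 7 of row 2.
print(flat[2, 3 * 28 + 7] == images[2, 3, 7])  # True
```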

nn.Dense

nn.Dense is the fully connected layer, which applies a linear transformation to the input using its weights and biases.

layer1 = nn.Dense(in_channels=28*28, out_channels=20)
hidden1 = layer1(flat_image)
print(hidden1.shape)
(3, 20)
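Numerically, a fully connected layer computes y = x W^T + b for every sample in the batch. The NumPy sketch below stores the weight as (out_channels, in_channels), which matches the (512, 784) shape reported for dense_relu_sequential.0.weight in the parameter listing later in this tutorial. The normal distribution with standard deviation 0.01 and the zero bias are arbitrary choices for illustration, not necessarily nn.Dense's defaults.

```python
import numpy as np

rng = np.random.default_rng(0)

in_channels, out_channels = 28 * 28, 20
# Weight laid out as (out_channels, in_channels); init scale is an assumption.
weight = rng.normal(0.0, 0.01, size=(out_channels, in_channels)).astype(np.float32)
bias = np.zeros(out_channels, dtype=np.float32)

def dense(x, weight, bias):
    """Affine map y = x @ W^T + b, applied to each row of x."""
    return x @ weight.T + bias

flat_image = np.ones((3, in_channels), dtype=np.float32)
hidden1 = dense(flat_image, weight, bias)
print(hidden1.shape)  # (3, 20)
```

Because the three input rows are identical (all ones), the three output rows are identical as well, which is why the ReLU printout that follows repeats the same row three times.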

nn.ReLU

The nn.ReLU layer adds a nonlinear activation function to the network, which helps the network learn complex features. ReLU replaces negative values with 0 and leaves non-negative values unchanged, as the output below shows.

print(f"Before ReLU: {hidden1}\n\n")
hidden1 = nn.ReLU()(hidden1)
print(f"After ReLU: {hidden1}")
Before ReLU: [[-0.04736331  0.2939465  -0.02713677 -0.30988005 -0.11504349 -0.11661264
   0.18007928  0.43213072  0.12091967 -0.17465964  0.53133243  0.12605792
   0.01825903  0.01287796  0.17238477 -0.1621131  -0.0080034  -0.24523425
  -0.10083733  0.05171938]
 [-0.04736331  0.2939465  -0.02713677 -0.30988005 -0.11504349 -0.11661264
   0.18007928  0.43213072  0.12091967 -0.17465964  0.53133243  0.12605792
   0.01825903  0.01287796  0.17238477 -0.1621131  -0.0080034  -0.24523425
  -0.10083733  0.05171938]
 [-0.04736331  0.2939465  -0.02713677 -0.30988005 -0.11504349 -0.11661264
   0.18007928  0.43213072  0.12091967 -0.17465964  0.53133243  0.12605792
   0.01825903  0.01287796  0.17238477 -0.1621131  -0.0080034  -0.24523425
  -0.10083733  0.05171938]]


After ReLU: [[0.         0.2939465  0.         0.         0.         0.
  0.18007928 0.43213072 0.12091967 0.         0.53133243 0.12605792
  0.01825903 0.01287796 0.17238477 0.         0.         0.
  0.         0.05171938]
 [0.         0.2939465  0.         0.         0.         0.
  0.18007928 0.43213072 0.12091967 0.         0.53133243 0.12605792
  0.01825903 0.01287796 0.17238477 0.         0.         0.
  0.         0.05171938]
 [0.         0.2939465  0.         0.         0.         0.
  0.18007928 0.43213072 0.12091967 0.         0.53133243 0.12605792
  0.01825903 0.01287796 0.17238477 0.         0.         0.
  0.         0.05171938]]
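ReLU is the element-wise function max(x, 0). A one-line NumPy equivalent (an illustration, assuming NumPy is available), applied to a few of the pre-activation values printed above, reproduces the pattern in the output: negative entries become 0 and positive entries pass through unchanged.

```python
import numpy as np

def relu(x):
    """Element-wise max(x, 0): negatives become 0, non-negatives pass through."""
    return np.maximum(x, 0)

# A few of the pre-activation values from the "Before ReLU" printout.
before = np.array([[-0.04736331, 0.2939465, -0.02713677, 0.18007928]], dtype=np.float32)
after = relu(before)
print(after)
```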

nn.SequentialCell

nn.SequentialCell is an ordered container of Cells. The input Tensor passes through all the Cells in the order they are defined, which makes SequentialCell a quick way to build a neural network model.

seq_modules = nn.SequentialCell(
    flatten,
    layer1,
    nn.ReLU(),
    nn.Dense(20, 10)
)

logits = seq_modules(input_image)
print(logits.shape)
(3, 10)
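Conceptually, a sequential container is function composition: the output of each layer becomes the input of the next. The sketch below reproduces the same Flatten, Dense(784, 20), ReLU, Dense(20, 10) pipeline with plain NumPy functions. The sequential helper is hypothetical (it is not MindSpore's API), and the random weights with zero biases are chosen only to make the shapes concrete.

```python
import numpy as np
from functools import reduce

def sequential(*layers):
    """Hypothetical helper: return a callable that feeds its input through layers in order."""
    def forward(x):
        return reduce(lambda acc, layer: layer(acc), layers, x)
    return forward

rng = np.random.default_rng(0)
w1 = rng.normal(0.0, 0.01, size=(20, 784)).astype(np.float32)
w2 = rng.normal(0.0, 0.01, size=(10, 20)).astype(np.float32)

seq = sequential(
    lambda x: x.reshape(x.shape[0], -1),  # Flatten: (3, 28, 28) -> (3, 784)
    lambda x: x @ w1.T,                   # Dense(784, 20) with zero bias
    lambda x: np.maximum(x, 0),           # ReLU
    lambda x: x @ w2.T,                   # Dense(20, 10) with zero bias
)

logits = seq(np.ones((3, 28, 28), dtype=np.float32))
print(logits.shape)  # (3, 10)
```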

nn.Softmax

Finally, nn.Softmax scales the logits returned by the last fully connected layer of the network to [0, 1], representing the predicted probability of each class. The values along the dimension specified by axis sum to 1.

softmax = nn.Softmax(axis=1)
pred_probab = softmax(logits)
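Softmax computes exp(x_i) / sum_j exp(x_j) along the chosen axis. The NumPy sketch below (an illustration, not MindSpore's implementation) uses the common trick of subtracting the row maximum first: it leaves the result unchanged but keeps exp from overflowing, which matters for the second row, where a naive np.exp(1000.0) would give inf.

```python
import numpy as np

def softmax(x, axis=1):
    """Numerically stable softmax along axis."""
    # Subtracting the max along axis does not change the result but avoids overflow.
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

logits = np.array([[1.0, 2.0, 3.0],
                   [1000.0, 1000.0, 1000.0]])
probs = softmax(logits, axis=1)
print(probs)
print(probs.sum(axis=1))  # each row sums to 1
```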

Model Parameters

Layers inside the network, such as nn.Dense, have weight and bias parameters that are continuously optimized during training. The name and details of each parameter can be obtained through model.parameters_and_names().

print(f"Model structure: {model}\n\n")

for name, param in model.parameters_and_names():
    print(f"Layer: {name}\nSize: {param.shape}\nValues : {param[:2]} \n")
Model structure: Network<
  (flatten): Flatten<>
  (dense_relu_sequential): SequentialCell<
    (0): Dense<input_channels=784, output_channels=512, has_bias=True>
    (1): ReLU<>
    (2): Dense<input_channels=512, output_channels=512, has_bias=True>
    (3): ReLU<>
    (4): Dense<input_channels=512, output_channels=10, has_bias=True>
    >
  >


Layer: dense_relu_sequential.0.weight
Size: (512, 784)
Values : [[-0.01491369  0.00353318 -0.00694948 ...  0.01226766 -0.00014423
   0.00544263]
 [ 0.00212971  0.0019974  -0.00624789 ... -0.01214037  0.00118004
  -0.01594325]]

Layer: dense_relu_sequential.0.bias
Size: (512,)
Values : [0. 0.]

Layer: dense_relu_sequential.2.weight
Size: (512, 512)
Values : [[ 0.00565423  0.00354313  0.00637383 ... -0.00352688  0.00262949
   0.01157355]
 [-0.01284141  0.00657666 -0.01217057 ...  0.00318963  0.00319115
  -0.00186801]]

Layer: dense_relu_sequential.2.bias
Size: (512,)
Values : [0. 0.]

Layer: dense_relu_sequential.4.weight
Size: (10, 512)
Values : [[ 0.0087168  -0.00381866 -0.00865665 ... -0.00273731 -0.00391623
   0.00612853]
 [-0.00593031  0.0008721  -0.0060081  ... -0.00271535 -0.00850481
  -0.00820513]]

Layer: dense_relu_sequential.4.bias
Size: (10,)
Values : [0. 0.]
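The dotted parameter names follow the nesting of Cells: the attribute name of the sub-Cell (dense_relu_sequential), the position inside the SequentialCell (0, 2, 4), then the parameter name. The ReLU layers at positions 1 and 3 hold no parameters, so they do not appear. The plain-Python sketch below rebuilds these names and shapes from a hand-copied, hypothetical layer list (not read from the model) and adds up the parameter count.

```python
import math

# Hypothetical spec copied from Network: (position in SequentialCell, in_channels, out_channels).
dense_layers = [(0, 784, 512), (2, 512, 512), (4, 512, 10)]

params = {}
for idx, n_in, n_out in dense_layers:
    prefix = f"dense_relu_sequential.{idx}"
    params[f"{prefix}.weight"] = (n_out, n_in)  # weight shape is (out, in)
    params[f"{prefix}.bias"] = (n_out,)

for name, shape in params.items():
    print(f"{name}: {shape}")

# 784*512 + 512 + 512*512 + 512 + 512*10 + 10
total = sum(math.prod(shape) for shape in params.values())
print(f"Total trainable parameters: {total}")  # 669706
```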

For more built-in neural network layers, see the mindspore.nn API documentation.