# Copyright 2021 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LeNet5 backbone."""

import mindspore.nn as nn
from mindspore.common.initializer import Normal

from mindvision.engine.class_factory import ClassFactory, ModuleType

__all__ = ["LeNet5"]

@ClassFactory.register(ModuleType.BACKBONE)
class LeNet5(nn.Cell):
    """
    LeNet backbone.

    Two ``valid``-padded 5x5 convolutions, each followed by ReLU and 2x2 max
    pooling, optionally followed by a three-layer fully connected classifier.

    Args:
        num_classes (int): The number of classes, i.e. the width of the last
            dense layer. Only used when ``include_top`` is True. Default: 10.
        num_channel (int): The number of input channels. Default: 1.
        include_top (bool): Whether to append the flatten + fully connected
            classifier head. Default: True.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
          When ``include_top`` is True the head expects ``16 * 5 * 5``
          flattened features, which corresponds to a 32x32 spatial input.

    Outputs:
        Tensor of shape :math:`(N, num\\_classes)` when ``include_top`` is
        True; otherwise the 16-channel feature map produced by the second
        pooling stage.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindvision.classification.models.backbones import LeNet5
        >>>
        >>> net = LeNet5()
        >>> x = ms.Tensor(np.ones([1, 1, 32, 32]), ms.float32)
        >>> output = net(x)
        >>> print(output.shape)
        (1, 10)

    About LeNet5:

    LeNet5 trained with the back-propagation algorithm constitute the best
    example of a successful gradient based learning technique. Given an
    appropriate network architecture, gradient-based learning algorithms can
    be used to synthesize a complex decision surface that can classify
    high-dimensional patterns, such as handwritten characters, with minimal
    preprocessing.

    Citation:

    .. code-block::

        @article{1998Gradient,
          title={Gradient-based learning applied to document recognition},
          author={ Lecun, Y. and Bottou, L. },
          journal={Proceedings of the IEEE},
          volume={86},
          number={11},
          pages={2278-2324},
          year={1998}
        }
    """

    def __init__(self, num_classes=10, num_channel=1, include_top=True):
        super().__init__()
        self.include_top = include_top

        # Feature extractor: no padding, so each 5x5 conv shrinks H and W by 4.
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        # A single ReLU / pooling cell is shared by every stage (both are
        # parameter-free).
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)

        # The classifier head is only built when requested, so a headless
        # backbone carries no dense-layer parameters.
        if self.include_top:
            self.flatten = nn.Flatten()
            # 16 channels * 5 * 5 spatial positions: the feature-map size
            # reached from a 32x32 input (32 -> 28 -> 14 -> 10 -> 5).
            self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
            self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
            self.fc3 = nn.Dense(84, num_classes, weight_init=Normal(0.02))

    def construct(self, x):
        """
        LeNet5 construct.

        Runs the two conv -> ReLU -> max-pool stages and, when
        ``include_top`` is True, the flatten + three dense layers (ReLU after
        the first two, none after the last).
        """
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)

        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)

        if self.include_top:
            x = self.flatten(x)
            x = self.relu(self.fc1(x))
            x = self.relu(self.fc2(x))
            # Final layer returns raw scores; no activation is applied here.
            x = self.fc3(x)
        return x