Source code for mindelec.loss.losses

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
add metric function
"""
from __future__ import absolute_import

import mindspore.nn as nn


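# Maps user-facing loss names (including aliases such as 'l2'/'mse') to the
# corresponding MindSpore loss cells; get_loss_metric() below looks names up here.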
_loss_metric = {
    'l1_loss': nn.L1Loss,
    'l1': nn.L1Loss,
    'l2_loss': nn.MSELoss,
    'l2': nn.MSELoss,
    'mse_loss': nn.MSELoss,
    'mse': nn.MSELoss,
    'rmse_loss': nn.RMSELoss,
    'rmse': nn.RMSELoss,
    'mae_loss': nn.MAELoss,
    'mae': nn.MAELoss,
    'smooth_l1_loss': nn.SmoothL1Loss,
    'smooth_l1': nn.SmoothL1Loss,
}


def get_loss_metric(name):
    """
    Gets the loss function.

    Args:
        name (str): The name of the loss function.

    Returns:
        Function, the loss function.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import numpy as np
        >>> from mindelec.loss import get_loss_metric
        >>> import mindspore
        >>> from mindspore import Tensor
        >>> l1_loss = get_loss_metric('l1_loss')
        >>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
        >>> labels = Tensor(np.array([[1, 1, 1], [1, 2, 2]]), mindspore.float32)
        >>> output = l1_loss(logits, labels)
        >>> print(output)
        0.6666667
    """
    if not isinstance(name, str):
        raise TypeError("the type of name should be str but got {}".format(type(name)))
    if name not in _loss_metric:
        raise ValueError("Unknown loss function type: {}".format(name))
    return _loss_metric[name]()
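

# A minimal usage sketch (not part of the original module), assuming a working
# MindSpore installation: it illustrates alias resolution and argument validation
# in get_loss_metric above.
if __name__ == '__main__':
    import numpy as np
    import mindspore
    from mindspore import Tensor

    # 'rmse' and 'rmse_loss' both resolve to nn.RMSELoss.
    rmse_loss = get_loss_metric('rmse')
    logits = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
    labels = Tensor(np.array([1.0, 2.0, 2.0]), mindspore.float32)
    print(rmse_loss(logits, labels))  # scalar Tensor holding the root-mean-square error

    # A non-string name raises TypeError; an unregistered name raises ValueError.
    try:
        get_loss_metric('huber')
    except ValueError as err:
        print(err)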