损失函数
损失函数,又叫目标函数,用于衡量预测值与真实值差异的程度。
在深度学习中,模型训练就是通过不停地迭代来缩小损失函数值的过程,因此,在模型训练过程中损失函数的选择非常重要,定义一个好的损失函数,可以有效提高模型的性能。
mindspore.nn
模块中提供了许多通用损失函数,但这些通用损失函数并不适用于所有场景,很多情况需要用户自定义所需的损失函数。因此,本教程介绍如何自定义损失函数。
内置损失函数
首先介绍mindspore.nn
模块中内置的损失函数。
如下示例以nn.L1Loss
为例,计算预测值和目标值之间的平均绝对误差:
其中N为数据集中的batch_size
值。
nn.L1Loss
中的参数reduction
取值可为mean
,sum
,或none
。如果 reduction
为mean
或sum
,则输出一个标量Tensor;如果reduction
为none
,则输出Tensor的shape为广播后的shape。
[1]:
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
# 输出loss均值
loss = nn.L1Loss()
# 输出loss和
loss_sum = nn.L1Loss(reduction='sum')
# 输出loss原值
loss_none = nn.L1Loss(reduction='none')
input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(np.float32))
target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(np.float32))
print("loss:", loss(input_data, target_data))
print("loss_sum:", loss_sum(input_data, target_data))
print("loss_none:\n", loss_none(input_data, target_data))
loss: 1.5
loss_sum: 9.0
loss_none:
[[1. 0. 2.]
[1. 2. 3.]]
自定义损失函数
自定义损失函数的方法有两种,一种是通过继承网络的基类nn.Cell
来定义损失函数,另一种是通过继承损失函数的基类nn.LossBase
来定义损失函数。nn.LossBase
在nn.Cell
的基础上,提供了get_loss
方法,利用reduction
参数对损失值求和或求均值,输出一个标量。
下面将分别使用继承Cell
和继承LossBase
的方法,来定义平均绝对误差损失函数(Mean Absolute Error,MAE),MAE算法的公式如下所示:
上式中\(f(x)\)为预测值,\(y\)为样本真实值,\(loss\)为预测值与真实值之间距离的平均值。
继承Cell的损失函数
nn.Cell
是MindSpore的基类,可以用于构建网络,也可以用于定义损失函数。使用nn.Cell
定义损失函数的方法与定义一个普通的网络相同,差别在于,其执行逻辑用于计算前向网络输出与真实值之间的误差。
下面通过继承nn.Cell
方法来定义损失函数MAELoss
的方法如下:
[2]:
import mindspore.ops as ops
class MAELoss(nn.Cell):
"""自定义损失函数MAELoss"""
def __init__(self):
"""初始化"""
super(MAELoss, self).__init__()
self.abs = ops.Abs()
self.reduce_mean = ops.ReduceMean()
def construct(self, base, target):
"""调用算子"""
x = self.abs(base - target)
return self.reduce_mean(x)
loss = MAELoss()
input_data = Tensor(np.array([0.1, 0.2, 0.3]).astype(np.float32)) # 生成预测值
target_data = Tensor(np.array([0.1, 0.2, 0.2]).astype(np.float32)) # 生成真实值
output = loss(input_data, target_data)
print(output)
0.033333335
继承LossBase的损失函数
通过继承nn.LossBase来定义损失函数MAELoss
,与nn.Cell
类似,都要重写__init__
方法和construct
方法。
nn.LossBase
可使用方法get_loss
将reduction
应用于损失计算。
[4]:
class MAELoss(nn.LossBase):
"""自定义损失函数MAELoss"""
def __init__(self, reduction="mean"):
"""初始化并求loss均值"""
super(MAELoss, self).__init__(reduction)
self.abs = ops.Abs() # 求绝对值算子
def construct(self, base, target):
x = self.abs(base - target)
return self.get_loss(x) # 返回loss均值
loss = MAELoss()
input_data = Tensor(np.array([0.1, 0.2, 0.3]).astype(np.float32)) # 生成预测值
target_data = Tensor(np.array([0.1, 0.2, 0.2]).astype(np.float32)) # 生成真实值
output = loss(input_data, target_data)
print(output)
0.033333335
损失函数与模型训练
自定义的损失函数MAELoss
完成后,可使用MindSpore的接口Model中train
接口进行模型训练,定义Model
时需要指定前向网络、损失函数和优化器,Model
内部会将它们关联起来,组成一张可用于训练的网络模型。
在Model
中,前向网络和损失函数是通过nn.WithLossCell关联起来的,nn.WithLossCell
支持两个输入,分别为data
和label
。
[5]:
from mindspore import Model
from mindspore import dataset as ds
from mindspore.common.initializer import Normal
from mindvision.engine.callback import LossMonitor
def get_data(num, w=2.0, b=3.0):
"""生成数据及对应标签"""
for _ in range(num):
x = np.random.uniform(-10.0, 10.0)
noise = np.random.normal(0, 1)
y = x * w + b + noise
yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)
def create_dataset(num_data, batch_size=16):
"""加载数据集"""
dataset = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])
dataset = dataset.batch(batch_size)
return dataset
class LinearNet(nn.Cell):
"""定义线性回归网络"""
def __init__(self):
super(LinearNet, self).__init__()
self.fc = nn.Dense(1, 1, Normal(0.02), Normal(0.02))
def construct(self, x):
return self.fc(x)
ds_train = create_dataset(num_data=160)
net = LinearNet()
loss = MAELoss()
opt = nn.Momentum(net.trainable_params(), learning_rate=0.005, momentum=0.9)
# 使用model接口将网络、损失函数和优化器关联起来
model = Model(net, loss, opt)
model.train(epoch=1, train_dataset=ds_train, callbacks=[LossMonitor(0.005)])
Epoch:[ 0/ 1], step:[ 1/ 10], loss:[9.169/9.169], time:365.966 ms, lr:0.00500
Epoch:[ 0/ 1], step:[ 2/ 10], loss:[5.861/7.515], time:0.806 ms, lr:0.00500
Epoch:[ 0/ 1], step:[ 3/ 10], loss:[8.759/7.930], time:0.768 ms, lr:0.00500
Epoch:[ 0/ 1], step:[ 4/ 10], loss:[9.503/8.323], time:1.080 ms, lr:0.00500
Epoch:[ 0/ 1], step:[ 5/ 10], loss:[8.541/8.367], time:0.762 ms, lr:0.00500
Epoch:[ 0/ 1], step:[ 6/ 10], loss:[9.158/8.499], time:0.707 ms, lr:0.00500
Epoch:[ 0/ 1], step:[ 7/ 10], loss:[9.168/8.594], time:0.900 ms, lr:0.00500
Epoch:[ 0/ 1], step:[ 8/ 10], loss:[6.828/8.373], time:1.184 ms, lr:0.00500
Epoch:[ 0/ 1], step:[ 9/ 10], loss:[7.149/8.237], time:0.962 ms, lr:0.00500
Epoch:[ 0/ 1], step:[ 10/ 10], loss:[6.342/8.048], time:1.273 ms, lr:0.00500
Epoch time: 390.358 ms, per step time: 39.036 ms, avg loss: 8.048
多标签损失函数与模型训练
上述定义了一个简单的平均绝对误差损失函数MAELoss
,但许多深度学习应用的数据集较复杂,例如目标检测网络Faster R-CNN的数据中就包含多个标签,而不是简单的一个数据对应一个标签,这时候损失函数的定义和使用略有不同。
本节介绍在多标签数据集场景下,如何定义多标签损失函数(Multi label loss function),并使用Model进行模型训练。
多标签数据集
如下示例通过get_multilabel_data
函数拟合两组线性数据\(y1\)和\(y2\),拟合的目标函数为:
由于最终的数据集应该随机分布于函数周边,这里按以下公式的方式生成,其中noise
为遵循标准正态分布规律的随机数值。get_multilabel_data
函数返回数据\(x\)、\(y1\)和\(y2\):
通过create_multilabel_dataset
生成多标签数据集,并将GeneratorDataset
中的column_names
参数设置为[‘data’, ‘label1’, ‘label2’],最终返回的数据集就有一个数据data
对应两个标签label1
和label2
。
[6]:
import numpy as np
from mindspore import dataset as ds
def get_multilabel_data(num, w=2.0, b=3.0):
for _ in range(num):
x = np.random.uniform(-10.0, 10.0)
noise1 = np.random.normal(0, 1)
noise2 = np.random.normal(-1, 1)
y1 = x * w + b + noise1
y2 = x * w + b + noise2
yield np.array([x]).astype(np.float32), np.array([y1]).astype(np.float32), np.array([y2]).astype(np.float32)
def create_multilabel_dataset(num_data, batch_size=16):
dataset = ds.GeneratorDataset(list(get_multilabel_data(num_data)), column_names=['data', 'label1', 'label2'])
dataset = dataset.batch(batch_size) # 每个batch有16个数据
return dataset
多标签损失函数
针对上一步创建的多标签数据集,定义多标签损失函数MAELossForMultiLabel
。
上式中,\(f(x)\) 为预测值,\(y1\) 和 \(y2\) 为样本真实值,\(loss1\) 为预测值与样本真实值 \(y1\) 之间距离的平均值,\(loss2\) 为预测值与样本真实值 \(y2\) 之间距离的平均值 ,\(loss\) 为损失值 \(loss1\) 与损失值 \(loss2\) 平均值。
在MAELossForMultiLabel
中的construct
方法的输入有三个,预测值base
,真实值target1
和target2
,在construct
中分别计算预测值与真实值target1
,预测值与真实值target2
之间的误差,将这两个误差的均值作为最终的损失函数值.
示例代码如下:
[7]:
class MAELossForMultiLabel(nn.LossBase):
def __init__(self, reduction="mean"):
super(MAELossForMultiLabel, self).__init__(reduction)
self.abs = ops.Abs()
def construct(self, base, target1, target2):
x1 = self.abs(base - target1)
x2 = self.abs(base - target2)
return (self.get_loss(x1) + self.get_loss(x2))/2
多标签模型训练
使用Model
关联指定的前向网络、损失函数和优化器时,由于Model
默认使用的nn.WithLossCell
只有两个输入:data
和label
,不适用于多标签的场景。
在多标签场景下,如果想使用Model
进行模型训练就需要将前向网络与多标签损失函数连接起来,需要自定义损失网络,将前向网络和自定义多标签损失函数关联起来。
定义损失网络
定义损失网络CustomWithLossCell
,其中__init__
方法的输入分别为前向网络backbone
和损失函数loss_fn
,construct
方法的输入分别为数据data
、label1
和label2
,将数据部分data
传给前向网络backbone
,将预测值和两个标签传给损失函数loss_fn
。
[8]:
class CustomWithLossCell(nn.Cell):
def __init__(self, backbone, loss_fn):
super(CustomWithLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._loss_fn = loss_fn
def construct(self, data, label1, label2):
output = self._backbone(data)
return self._loss_fn(output, label1, label2)
定义网络模型并训练
使用Model连接前向网络、多标签损失函数和优化器时,Model
的网络network
指定为自定义的损失网络loss_net
,损失函数loss_fn
不指定,优化器仍使用Momentum
。
由于未指定loss_fn
,Model
则认为network
内部已经实现了损失函数的逻辑,不会用nn.WithLossCell
对前向函数和损失函数进行封装。
[9]:
ds_train = create_multilabel_dataset(num_data=160)
net = LinearNet()
# 定义多标签损失函数
loss = MAELossForMultiLabel()
# 定义损失网络,连接前向网络和多标签损失函数
loss_net = CustomWithLossCell(net, loss)
# 定义优化器
opt = nn.Momentum(net.trainable_params(), learning_rate=0.005, momentum=0.9)
# 定义Model,多标签场景下Model无需指定损失函数
model = Model(network=loss_net, optimizer=opt)
model.train(epoch=1, train_dataset=ds_train, callbacks=[LossMonitor(0.005)])
Epoch:[ 0/ 1], step:[ 1/ 10], loss:[10.329/10.329], time:290.788 ms, lr:0.00500
Epoch:[ 0/ 1], step:[ 2/ 10], loss:[10.134/10.231], time:0.813 ms, lr:0.00500
Epoch:[ 0/ 1], step:[ 3/ 10], loss:[9.862/10.108], time:2.410 ms, lr:0.00500
Epoch:[ 0/ 1], step:[ 4/ 10], loss:[11.182/10.377], time:1.154 ms, lr:0.00500
Epoch:[ 0/ 1], step:[ 5/ 10], loss:[8.571/10.015], time:1.137 ms, lr:0.00500
Epoch:[ 0/ 1], step:[ 6/ 10], loss:[7.763/9.640], time:0.928 ms, lr:0.00500
Epoch:[ 0/ 1], step:[ 7/ 10], loss:[7.542/9.340], time:1.001 ms, lr:0.00500
Epoch:[ 0/ 1], step:[ 8/ 10], loss:[8.644/9.253], time:1.156 ms, lr:0.00500
Epoch:[ 0/ 1], step:[ 9/ 10], loss:[5.815/8.871], time:1.908 ms, lr:0.00500
Epoch:[ 0/ 1], step:[ 10/ 10], loss:[5.086/8.493], time:1.575 ms, lr:0.00500
Epoch time: 323.467 ms, per step time: 32.347 ms, avg loss: 8.493
本章节简单讲解了多标签数据集场景下,如何定义损失函数并使用Model进行模型训练。在很多其他场景中,也可以采用此类方法进行模型训练。