**交叉熵损失函数** 通常用于解决分类问题。
**pytorch**:torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')
**mindspore**:mindspore.nn.SoftmaxCrossEntropyWithLogits(sparse=False, reduction="none")
虽然 mindspore 官方算子映射表(https://www.mindspore.cn/doc/programming_guide/zh-CN/r1.2/index.html#operator_api) 说这两个算子功能是一致的,但是其实在细节上还是有部分不同。
### 一、应用背景
在语义分割任务中,我们需要为一张图片中的每个像素进行分类。在训练数据集中人为标注的标签(ground truth)中,我们通常事先标注好**X**类标签,包括**X-1**类物体(非背景)和**1**类背景。但在实际研究中(由于分割难度等限制),我们设计的网络可能并不着力于正确分类这**X**类物体,而是只尝试分割其中的**Y**类物体(**Y**<=**X**),因此对于其中的**X-Y**类物体,我们在训练过程中将直接忽视它们,它们不参与网络分割结果的损失计算,网络将这块区域分割成什么物体我们也并不关心。
要达到这一目的,通常我们只需要在读取**label**时将它们的类别全部重写为一个新类别(一般设置为-1,也可以是其他任意值,只要不与待分类物体类别冲突,自己设定即可),然后在和预测结果算损失值时,将这个类别忽略。在**pytorch**中即将 **torch.nn.CrossEntropyLoss** 中的 **ignore_index** 参数指定为自己设定的那个新类别值。不过遗憾的是,**mindspore** 中的交叉熵损失函数**SoftmaxCrossEntropyWithLogits**不含这一参数,那我们应该怎么做来完成这一目标呢?
### 二、mindspore 中的 ignore_index
我这里选取 FastSCNN 的损失函数作为示例。
全部源代码见 https://gitee.com/mindspore/mindspore/tree/r1.2/model_zoo/official/cv/fastscnn
```python
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.tensor import Tensor
from mindspore.nn import SoftmaxCrossEntropyWithLogits
__all__ = ['MixSoftmaxCrossEntropyLoss']
class MixSoftmaxCrossEntropyLoss(nn.Cell):
'''MixSoftmaxCrossEntropyLoss'''
def __init__(self, args, ignore_label=-1, aux=True, aux_weight=0.4, \
sparse=True, reduction='none', one_d_length=2*768*768, **kwargs):
super(MixSoftmaxCrossEntropyLoss, self).__init__()
self.ignore_label = ignore_label
self.weight = aux_weight if aux else 1.0
self.select = ops.Select()
self.reduceSum = ops.ReduceSum(keep_dims=False)
self.div_no_nan = ops.DivNoNan()
self.mul = ops.Mul()
self.reshape = ops.Reshape()
self.cast = ops.Cast()
self.transpose = ops.Transpose()
self.zero_tensor = Tensor([0]*one_d_length, mindspore.float32)
self.SoftmaxCrossEntropyWithLogits = SoftmaxCrossEntropyWithLogits(sparse=sparse, reduction="none")
args.logger.info('using MixSoftmaxCrossEntropyLoss....')
args.logger.info('self.ignore_label:' + str(self.ignore_label))
args.logger.info('self.aux:' + str(aux))
args.logger.info('self.weight:' + str(self.weight))
args.logger.info('one_d_length:' + str(one_d_length))
def construct(self, *inputs, **kwargs):
'''construct'''
preds, target = inputs[:-1], inputs[-1]
target = self.reshape(target, (-1,))
valid_flag = target != self.ignore_label
num_valid = self.reduceSum(self.cast(valid_flag, mindspore.float32))
z = self.transpose(preds[0], (0, 2, 3, 1))#move the C-dim to the last, then reshape it.
#This operation is vital, or the data would be soiled.
loss = self.SoftmaxCrossEntropyWithLogits(self.reshape(z, (-1, 19)), target)
loss = self.select(valid_flag, loss, self.zero_tensor)
loss = self.reduceSum(loss)
loss = self.div_no_nan(loss, num_valid)
for i in range(1, len(preds)):
z = self.transpose(preds[i], (0, 2, 3, 1))
aux_loss = self.SoftmaxCrossEntropyWithLogits(self.reshape(z, (-1, 19)), target)
aux_loss = self.select(valid_flag, aux_loss, self.zero_tensor)
aux_loss = self.reduceSum(aux_loss)
aux_loss = self.div_no_nan(aux_loss, num_valid)
loss += self.mul(self.weight, aux_loss)
return loss
```
请自行忽略掉一些无关代码。
我们从construct开始看,第31行是由于FastSCNN的网络输出有3部分,MixSoftmaxCrossEntropyLoss需要分别计算其与target之间的损失值,这里取出的preds正常情况下(aux=True时)是一个长度为3的元组。
第32行将target展成1维,当输入图片为768x768大小、batch_size=2时,它的1维长度就是2x768x768=1179648 。
valid_flag 用于获取target中需要被正确分类(非忽视类别,需要参与损失计算)的那些类别的位置,它是一个boolean数组;num_valid 是这些类别数总和。
第36行操作的原因可以看我这篇文章:http://luxuff.cn/archives/reshape和transpose的区别 中的第四节。
第38行中self.reshape(z, (-1, 19))会得到一个1179648x19的数组,它将和target(1179648x1)进行交叉熵损失计算并返回一个1179648x1的损失数组,然后利用valid_flag 数组从该损失数组中取出我们想要的损失值,其他的置0,之后的就是求和与计算像素损失均值了,不计算均值的话损失值会比较大,有可能溢出。
### 三、后续
实验时发现交叉熵不指定ignore_index的情况下FastSCNN也可以慢慢收敛,但是我没有将其完整的训练下来,所以不清楚最后他分割的精度和指定了ignore_index的分割精度谁更优,留着后续再测吧。当前这个版本算是契合了原代码的损失函数了。
语义分割任务中还有另一个损失函数**SoftmaxCrossEntropyOHEMLoss**(原代码:https://github.com/Tramac/Fast-SCNN-pytorch/blob/master/utils/loss.py ),计算损失时对不同类别进行了加权,后续有空了用mindspore把这个也重写一下。

如何用mindspore的nn.SoftmaxCrossEntropyWithLogits算子完成pytorch中nn.CrossEntropyLoss算子的ignore_index功能