# Several Ways to Build a Training Pipeline in MindSpore
## 1. Predefined and Custom Modules
In brief, building a training pipeline takes the following modules:
```tex
1. Dataset processing and loader
2. Network architecture
3. Loss function
4. Optimizer
```
Chain these modules together and, in principle, training can run.
Because every job is different, the four modules above differ in their details. For example:
### (1) Dataset processing and loader
MindSpore pre-packages the loading and processing of common, relatively simple datasets, for instance **MnistDataset**, the handwritten-digit dataset used in this article; pre-packaged support for other dataset types is being added gradually, see [ms.dataset](https://www.mindspore.cn/docs/api/zh-CN/master/api_python/mindspore.dataset.html) for the current list.
More often than not, though, the dataset we need has no such ready-made API, so we define our own. The custom route is in fact more flexible, and it is the one I would recommend. The general pattern looks like this; the example code comes from [BRDNet](https://gitee.com/mindspore/models/blob/master/official/cv/brdnet/src/dataset.py):
```python
import os
import glob
import numpy as np
import PIL.Image as Image
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
class BRDNetDataset:
""" BRDNetDataset.
Args:
data_path: path of images
sigma: noise level
channel: 3 for color, 1 for gray
"""
def __init__(self, data_path, sigma, channel):
images = []
        file_directory = glob.glob(os.path.join(data_path, '*.bmp')) # notice the data format
        for file in file_directory:
            images.append(file)
self.images = images
self.sigma = sigma
self.channel = channel
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
get_batch_x: image with noise
get_batch_y: image without noise
"""
if self.channel == 3:
get_batch_y = np.array(Image.open(self.images[index]), dtype='uint8')
else:
get_batch_y = np.expand_dims(np.array(\
Image.open(self.images[index]).convert('L'), dtype='uint8'), axis=2)
get_batch_y = get_batch_y.astype('float32')/255.0
noise = np.random.normal(0, self.sigma/255.0, get_batch_y.shape).astype('float32') # noise
get_batch_x = get_batch_y + noise # input image = clean image + noise
return get_batch_x, get_batch_y
def __len__(self):
return len(self.images)
def create_BRDNetDataset(data_path, sigma, channel, batch_size, device_num, rank, shuffle):
dataset = BRDNetDataset(data_path, sigma, channel)
hwc_to_chw = CV.HWC2CHW()
data_set = ds.GeneratorDataset(dataset, column_names=["image", "label"], \
num_parallel_workers=8, shuffle=shuffle, num_shards=device_num, shard_id=rank)
data_set = data_set.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=8)
data_set = data_set.map(input_columns=["label"], operations=hwc_to_chw, num_parallel_workers=8)
data_set = data_set.batch(batch_size, drop_remainder=True)
return data_set, data_set.get_dataset_size()
```
We define an **iterable** class whose **\_\_getitem\_\_** method returns **an image and its corresponding label**; any data-augmentation operations can be applied here as well.
The **ds.GeneratorDataset** call in **create_BRDNetDataset** then turns it into a dataset object that sits at the same level as the **MnistDataset** mentioned earlier (both inherit from the same parent class), so subsequent operations on **data_set** are just like those on **MnistDataset**.
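To make the **\_\_getitem\_\_** contract concrete without any image files on disk, here is a toy numpy-only source class following the same (noisy input, clean target) pattern as BRDNetDataset; the class name, shapes and sample count are invented for illustration, and an instance of such a class is exactly the kind of object that ds.GeneratorDataset accepts as its source:

```python
import numpy as np

class ToyNoisyDataset:
    """Same contract as BRDNetDataset: __getitem__ returns (noisy, clean)."""
    def __init__(self, num_samples=4, sigma=25, shape=(16, 16, 3), seed=0):
        rng = np.random.default_rng(seed)
        # stand-in for images loaded from disk, already scaled to [0, 1]
        self.clean = rng.random((num_samples,) + shape).astype('float32')
        self.sigma = sigma
        self.rng = rng

    def __getitem__(self, index):
        clean = self.clean[index]
        noise = self.rng.normal(0, self.sigma / 255.0, clean.shape).astype('float32')
        return clean + noise, clean  # (input image, target image)

    def __len__(self):
        return len(self.clean)

dataset = ToyNoisyDataset()
noisy, clean = dataset[0]
```

Any class with this shape of `__getitem__`/`__len__` can be handed to ds.GeneratorDataset together with matching `column_names`.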
### (2) Network architecture
The concrete network structure varies in implementation difficulty with the network's complexity, but the implementation workflow stays largely the same.
### (3) Loss function
[ms loss functions](https://www.mindspore.cn/docs/api/zh-CN/master/api_python/mindspore.nn.html#loss-functions) already covers most common loss functions, ready to call directly; in practice, however, we usually need a custom loss. As before, we implement a class that takes **the network's outputs** and **the label values**, computes the loss value and returns it. The general pattern looks like this; the example code comes from [FastSCNN](https://gitee.com/mindspore/models/blob/master/official/cv/fastscnn/src/loss.py):
```python
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.tensor import Tensor
from mindspore.nn import SoftmaxCrossEntropyWithLogits
__all__ = ['MixSoftmaxCrossEntropyLoss']
class MixSoftmaxCrossEntropyLoss(nn.Cell):
'''MixSoftmaxCrossEntropyLoss'''
def __init__(self, args, ignore_label=-1, aux=True, aux_weight=0.4, \
sparse=True, reduction='none', one_d_length=2*768*768, **kwargs):
super(MixSoftmaxCrossEntropyLoss, self).__init__()
self.ignore_label = ignore_label
self.weight = aux_weight if aux else 1.0
self.select = ops.Select()
self.reduceSum = ops.ReduceSum(keep_dims=False)
self.div_no_nan = ops.DivNoNan()
self.mul = ops.Mul()
self.reshape = ops.Reshape()
self.cast = ops.Cast()
self.transpose = ops.Transpose()
self.zero_tensor = Tensor([0]*one_d_length, mindspore.float32)
self.SoftmaxCrossEntropyWithLogits = \
SoftmaxCrossEntropyWithLogits(sparse=sparse, reduction="none")
args.logger.info('using MixSoftmaxCrossEntropyLoss....')
args.logger.info('self.ignore_label:' + str(self.ignore_label))
args.logger.info('self.aux:' + str(aux))
args.logger.info('self.weight:' + str(self.weight))
args.logger.info('one_d_length:' + str(one_d_length))
def construct(self, *inputs, **kwargs):
'''construct'''
preds, target = inputs[:-1], inputs[-1]
target = self.reshape(target, (-1,))
valid_flag = target != self.ignore_label
num_valid = self.reduceSum(self.cast(valid_flag, mindspore.float32))
        z = self.transpose(preds[0], (0, 2, 3, 1)) # move the C dim to the last position, then reshape it.
        # This step is vital; otherwise the data would be corrupted.
loss = self.SoftmaxCrossEntropyWithLogits(self.reshape(z, (-1, 19)), target)
loss = self.select(valid_flag, loss, self.zero_tensor)
loss = self.reduceSum(loss)
loss = self.div_no_nan(loss, num_valid)
for i in range(1, len(preds)):
z = self.transpose(preds[i], (0, 2, 3, 1))
aux_loss = self.SoftmaxCrossEntropyWithLogits(self.reshape(z, (-1, 19)), target)
aux_loss = self.select(valid_flag, aux_loss, self.zero_tensor)
aux_loss = self.reduceSum(aux_loss)
aux_loss = self.div_no_nan(aux_loss, num_valid)
loss += self.mul(self.weight, aux_loss)
return loss
```
### (4) Optimizer
[ms optimizers](https://www.mindspore.cn/docs/api/zh-CN/master/api_python/mindspore.nn.html#optimizer-functions) likewise provides many common optimizers. In practice, though, when reading network source code written in frameworks such as **PyTorch**, we find custom optimizer code, e.g. [NFNet](https://github.com/deepmind/deepmind-research/blob/master/nfnets/optim.py). Honestly, I have not yet tried writing a custom optimizer in MindSpore, but a careful read of that part of the **NFNet** code shows that it merely applies further operations to the computed gradients, and that step can currently be done with MindSpore; the concrete method is described in detail later.
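To make the NFNet remark concrete: what its custom optimizer adds is, roughly, adaptive gradient clipping, i.e. scaling a gradient down whenever its norm exceeds a fraction of the matching parameter's norm. Below is a minimal numpy sketch of that idea; the function name and the per-tensor simplification are mine (NFNet applies the rule unit-wise), so treat it as an illustration, not the actual implementation:

```python
import numpy as np

def agc_clip(param, grad, clipping=0.01, eps=1e-3):
    """Scale grad so that ||grad|| <= clipping * max(||param||, eps)."""
    p_norm = max(np.linalg.norm(param), eps)
    g_norm = np.linalg.norm(grad)
    max_norm = clipping * p_norm
    if g_norm > max_norm:
        grad = grad * (max_norm / g_norm)  # shrink, keep direction
    return grad

w = np.ones((4, 4), dtype=np.float32)        # ||w|| = 4
g = np.full((4, 4), 10.0, dtype=np.float32)  # ||g|| = 40, far above 0.01 * 4
clipped = agc_clip(w, g)
```

The same post-processing of gradients can be dropped into a custom TrainOneStepCell, which is exactly the hook described in scheme (2) below.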
## 2. Preparation
Download the dataset from the [MNIST download page](http://yann.lecun.com/exdb/mnist/), unpack it, and organize it as follows:
```tex
# ./dataset
├── val
│ ├── t10k-images-idx3-ubyte
│ └── t10k-labels-idx1-ubyte
└── train
├── train-images-idx3-ubyte
└── train-labels-idx1-ubyte
```
## 3. Training Pipelines
This section uses LeNet and handwritten-digit recognition as the running example to introduce the various ways of building a training pipeline.
### (1) The official minimal style
```python
dataset = create_Dataset()
model = create_model()
loss_fn = nn.loss_function()
optimizer = nn.optimizer()
model = Model(model, loss_fn, optimizer)
callbacks = [...]
model.train(epochs, dataset, callbacks=callbacks, dataset_sink_mode=True)
```
dataset, model, loss_fn and optimizer are all idealized and ready off the shelf; they can be wrapped directly through the **Model** interface, and training is as simple as calling **model.train**, ~~very friendly to beginners~~.
The benefit of this code structure is its clarity: it gives a macro view of how the pipeline executes. We can also enable automatic mixed precision at **model = Model(model, loss_fn, optimizer)** (parameter name: amp_level). Normally it also enables **dataset sink mode (dataset_sink_mode=True)**, which speeds up training (sinking is not yet supported when device_target="CPU"). See [mindspore.Model](https://www.mindspore.cn/docs/api/zh-CN/master/api_python/mindspore/mindspore.Model.html?highlight=model#mindspore.Model.train) for details.
The complete code:
```python
import os
import mindspore.nn as nn
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore import Model
from mindspore import context
from mindspore.nn import Accuracy
from mindspore import dtype as mstype
from mindspore.dataset.vision import Inter
from mindspore.common.initializer import Normal
from mindspore.train.callback import LossMonitor, TimeMonitor, ModelCheckpoint, CheckpointConfig
def create_dataset(data_path, batch_size=32, num_parallel_workers=1):
mnist_ds = ds.MnistDataset(data_path)
type_cast_op = C.TypeCast(mstype.int32)
resize_op = CV.Resize((32, 32), interpolation=Inter.LINEAR)
rescale_op = CV.Rescale(1.0 / 255.0, 0.0)
rescale_nml_op = CV.Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081)
hwc2chw_op = CV.HWC2CHW()
mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.shuffle(buffer_size=10000)
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
return mnist_ds, mnist_ds.get_dataset_size()
class LeNet5(nn.Cell):
def __init__(self, num_class=10, num_channel=1):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
mnist_path = "./dataset"
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
train_dataset, steps_per_epoch = create_dataset(os.path.join(mnist_path, "train"), 32)
net = LeNet5()
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
model = Model(net, loss_fn, optimizer, metrics={"Accuracy": Accuracy()})
time_cb = TimeMonitor(data_size=steps_per_epoch)
loss_cb = LossMonitor(125)
config_ck = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=10)
ckpoint = ModelCheckpoint(directory="./ckpts", prefix="lenet", config=config_ck)
callbacks = [time_cb, loss_cb, ckpoint]
model.train(epoch=1, train_dataset=train_dataset, callbacks=callbacks, dataset_sink_mode=True)
```
Convenient as this style is, however, it is a black box to the developer and **extremely unfriendly to debugging**. Everything is done by **model.train**: what happens inside it, how each step proceeds, how to adjust the individual stages...
If we cannot reach down to a specific operation, then when an error comes up during implementation, 80% of the hair we shed will be shed over locating the problem. This style is therefore **unsuited to debugging and suited only to the final streamlined version**. It is best to slim the code down to the form above only after everything is written and the accuracy and performance targets are met.
### (2) The slightly transparent style
```python
dataset = create_Dataset()
model = create_model()
loss_fn = nn.loss_function()
optimizer = nn.optimizer()
loss_net = nn.WithLossCell(model, loss_fn)
train_net = nn.TrainOneStepCell(loss_net, optimizer)
model = Model(train_net)
callbacks = [...]
model.train(epochs, dataset, callbacks=callbacks, dataset_sink_mode=True)
```
Unlike the previous style, this one splits out **loss computation** and **gradient computation with back-propagation** as independent steps, namely the **nn.WithLossCell** and **nn.TrainOneStepCell** lines.
The **nn.WithLossCell** line's main job is to **append a loss computation to the output of the network we defined**. Its source implementation is [WithLossCell](https://www.mindspore.cn/docs/api/zh-CN/master/_modules/mindspore/nn/wrap/cell_wrapper.html#WithLossCell):
```python
class WithLossCell(Cell):
def __init__(self, backbone, loss_fn):
super(WithLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._loss_fn = loss_fn
def construct(self, data, label):
out = self._backbone(data)
return self._loss_fn(out, label)
```
There is one restriction on the inputs: **def construct(self, data, label)** requires the loss inputs to be passed in the order **data**, then **label**, which is usually not a problem.
But if we want to pass **label** before **data**, or other arguments take part in the loss computation, this stock operator is clearly not enough, and we can define our own (example code from [FastSCNN](https://gitee.com/mindspore/models/blob/master/official/cv/fastscnn/src/fast_scnn.py)):
```python
class FastSCNNWithLossCell(nn.Cell):
"""FastSCNN loss, MixSoftmaxCrossEntropyLoss."""
def __init__(self, network, args):
super(FastSCNNWithLossCell, self).__init__()
self.network = network
self.aux = args.aux
self.loss = MixSoftmaxCrossEntropyLoss(args, aux=args.aux, aux_weight=args.aux_weight,
one_d_length=args.batch_size*args.crop_size[0]*args.crop_size[1])
def construct(self, images, targets):
outputs = self.network(images)
if self.aux:
return self.loss(outputs[0], outputs[1], outputs[2], targets)
return self.loss(outputs, targets)
```
Both inherit from **nn.Cell** and have the same overall structure, so they are used in the same way.
The **nn.TrainOneStepCell** line's main job is to **compute the backward gradients from the loss obtained above and update the network parameters through the optimizer**. Its source implementation is [TrainOneStepCell](https://www.mindspore.cn/docs/api/zh-CN/master/_modules/mindspore/nn/wrap/cell_wrapper.html#TrainOneStepCell):
```python
class TrainOneStepCell(Cell):
def __init__(self, network, optimizer, sens=1.0):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.optimizer = optimizer
self.weights = self.optimizer.parameters
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.grad_reducer = F.identity
self.parallel_mode = _get_parallel_mode()
self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL)
if self.reducer_flag:
self.mean = _get_gradients_mean()
self.degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
def construct(self, *inputs):
loss = self.network(*inputs)
sens = F.fill(loss.dtype, loss.shape, self.sens)
grads = self.grad(self.network, self.weights)(*inputs, sens)
grads = self.grad_reducer(grads)
loss = F.depend(loss, self.optimizer(grads))
return loss
```
The line **grads = self.grad(self.network, self.weights)(\*inputs, sens)** computes the gradients, and **loss = F.depend(loss, self.optimizer(grads))** updates the network parameters through the optimizer.
If we want to post-process the computed gradients (clipping and the like), we can subclass **TrainOneStepCell** and override its **construct** method. A brief sketch follows (for the concrete operations, see scheme **(5) gradient clipping**):
```python
class MyTrainOneStepCell(TrainOneStepCell):
def construct(self, *inputs):
loss = self.network(*inputs)
sens = F.fill(loss.dtype, loss.shape, self.sens)
grads = self.grad(self.network, self.weights)(*inputs, sens)
'''
your operations
'''
grads = self.grad_reducer(grads)
loss = F.depend(loss, self.optimizer(grads))
return loss
```
Of course, readers who have followed MindSpore since its early versions may have seen the following code:
```python
class TrainingWrapper(nn.Cell):
"""Training wrapper."""
def __init__(self, network, optimizer, sens=1.0):
super(TrainingWrapper, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.grad_reducer = None
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
if self.reducer_flag:
mean = context.get_auto_parallel_context("gradients_mean")
if auto_parallel_context().get_device_num_is_set():
degree = context.get_auto_parallel_context("device_num")
else:
degree = get_group_size()
self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
def construct(self, *args):
weights = self.weights
loss = self.network(*args)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(*args, sens)
if self.reducer_flag:
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))
```
Look closely: is it not almost identical to **TrainOneStepCell**?
In fact, **TrainOneStepCell** is a streamlined descendant of **TrainingWrapper**. Before the official **TrainOneStepCell** API was added, **TrainingWrapper** circulated as a community version and was a great help to early model-zoo contributors; to this day I still see quite a few developers using it.
For personal projects, using **TrainingWrapper** is not a big deal, but for **model-zoo contributors** I **strongly advise against continuing to use it**, because it relies on some relatively ~~internal~~ methods, for example **if auto_parallel_context().get_device_num_is_set():**. auto_parallel_context() is imported from **mindspore.parallel._auto_parallel_context**, and **_auto_parallel_context** is an **internal package**, so this part will be rejected when you submit a PR to merge the code. The team discourages external developers from using internal packages (or rather, does not trust external developers to use them properly), so please use **TrainOneStepCell**.
The complete code:
```python
import os
import mindspore.nn as nn
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore import Model
from mindspore import context
from mindspore.nn import Accuracy
from mindspore import dtype as mstype
from mindspore.dataset.vision import Inter
from mindspore.common.initializer import Normal
from mindspore.train.callback import LossMonitor, TimeMonitor, ModelCheckpoint, CheckpointConfig
def create_dataset(data_path, batch_size=32, num_parallel_workers=1):
mnist_ds = ds.MnistDataset(data_path)
type_cast_op = C.TypeCast(mstype.int32)
resize_op = CV.Resize((32, 32), interpolation=Inter.LINEAR)
rescale_op = CV.Rescale(1.0 / 255.0, 0.0)
rescale_nml_op = CV.Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081)
hwc2chw_op = CV.HWC2CHW()
mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.shuffle(buffer_size=10000)
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
return mnist_ds, mnist_ds.get_dataset_size()
class LeNet5(nn.Cell):
def __init__(self, num_class=10, num_channel=1):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
class LeNet5WithLossCell(nn.Cell):
def __init__(self, network, loss_fn):
super(LeNet5WithLossCell, self).__init__()
self.network = network
self.loss = loss_fn
def construct(self, images, label):
outputs = self.network(images)
'''
your operations
'''
return self.loss(outputs, label)
mnist_path = "./dataset"
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
train_dataset, steps_per_epoch = create_dataset(os.path.join(mnist_path, "train"), 32)
net = LeNet5()
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
loss_net = nn.WithLossCell(net, loss_fn)
#loss_net = LeNet5WithLossCell(net, loss_fn)
train_net = nn.TrainOneStepCell(loss_net, optimizer)
model = Model(train_net)
time_cb = TimeMonitor(data_size=steps_per_epoch)
loss_cb = LossMonitor(125)
config_ck = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=10)
ckpoint = ModelCheckpoint(directory="./ckpts", prefix="lenet", config=config_ck)
callbacks = [time_cb, loss_cb, ckpoint]
model.train(epoch=1, train_dataset=train_dataset, callbacks=callbacks, dataset_sink_mode=True)
```
Note the difference between **model = Model(train_net)** here and the **model = Model(net, loss_fn, optimizer, metrics={"Accuracy": Accuracy()})** used in the previous subsection.
Because we give **Model** neither **loss_fn** nor **optimizer** when wrapping here, it cannot add **metrics** (the accuracy computation used for evaluation, which does not affect training) or **automatic mixed precision via amp_level** (amp_level="O0" still runs, but O0 means disabled, so it is effectively unused).
This style costs two more lines than the previous one and loses a few features; are we ~~doing thankless work~~?
It only looks that way because this example is too simple for the benefit to show.
The strength of this style lies in **separating loss computation** from **gradient computation and back-propagation**; by further customizing these two modules, we can pin down many problems.
What is more, quite often, even without customizing either module, merely switching the training style from **scheme (1)** to **scheme (2)** inexplicably makes the code run. Magical, isn't it?
Both **scheme (1)** and **scheme (2)** enable **dataset sink mode** in this example, but while debugging, **do not enable dataset sink mode**: it may become an obstacle to locating problems. Once the code runs correctly, turn sinking back on to save training time.
### (3) The get-to-the-bottom style
```python
dataset = create_Dataset()
model = create_model()
loss_fn = nn.loss_function()
optimizer = nn.optimizer()
loss_net = nn.WithLossCell(model, loss_fn)
train_net = nn.TrainOneStepCell(loss_net, optimizer)
data_loader = dataset.create_dict_iterator()
for k in range(epochs):
train_net.set_train(True)
for i, data in enumerate(data_loader):
loss = train_net(data["image"], data["label"])
```
This style abandons the **Model** wrapper entirely and no longer lets MindSpore take over training through **model.train**; we try to keep every training detail in our own hands. It resembles a PyTorch training pipeline, is very convenient to debug, and also enables many operations the previous two styles cannot.
The complete code:
```python
import os
import mindspore
import mindspore.nn as nn
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore import Model
from mindspore import context, save_checkpoint
from mindspore.nn import Accuracy
from mindspore import dtype as mstype
from mindspore.dataset.vision import Inter
from mindspore.common.initializer import Normal
from mindspore.train.callback import LossMonitor, TimeMonitor, ModelCheckpoint, CheckpointConfig
def create_dataset(data_path, batch_size=32, num_parallel_workers=1):
mnist_ds = ds.MnistDataset(data_path)
type_cast_op = C.TypeCast(mstype.int32)
resize_op = CV.Resize((32, 32), interpolation=Inter.LINEAR)
rescale_op = CV.Rescale(1.0 / 255.0, 0.0)
rescale_nml_op = CV.Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081)
hwc2chw_op = CV.HWC2CHW()
mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.shuffle(buffer_size=10000)
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
return mnist_ds, mnist_ds.get_dataset_size()
class LeNet5(nn.Cell):
def __init__(self, num_class=10, num_channel=1):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
class LeNet5WithLossCell(nn.Cell):
def __init__(self, network, loss_fn):
super(LeNet5WithLossCell, self).__init__()
self.network = network
self.loss = loss_fn
def construct(self, images, label):
outputs = self.network(images)
'''
your operations
'''
return self.loss(outputs, label)
class AverageMeter:
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=':f', tb_writer=None):
self.name = name
self.fmt = fmt
self.reset()
self.tb_writer = tb_writer
self.cur_step = 1
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
if self.tb_writer is not None:
self.tb_writer.add_scalar(self.name, self.val, self.cur_step)
self.cur_step += 1
def __str__(self):
fmtstr = '{name}:{avg' + self.fmt + '}'
return fmtstr.format(**self.__dict__)
def test_net(model, data_path):
eval_dataset, steps_per_epoch = create_dataset(os.path.join(data_path, "val"))
top1 = AverageMeter('top1')
top5 = AverageMeter('top5')
top1_m = nn.TopKCategoricalAccuracy(1)
top5_m = nn.TopKCategoricalAccuracy(5)
data_loader = eval_dataset.create_dict_iterator()
for i, data in enumerate(data_loader):
output = model(data["image"])
top1_m.clear()
top5_m.clear()
top1_m.update(output, data["label"])
top5_m.update(output, data["label"])
top1.update(top1_m.eval(), 32)
top5.update(top5_m.eval(), 32)
acc1, acc5 = top1.avg, top5.avg
print("Eval Accuracy Top1:", acc1, ", Top5:", acc5)
mnist_path = "./dataset"
epochs = 1
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
train_dataset, steps_per_epoch = create_dataset(os.path.join(mnist_path, "train"), 32)
net = LeNet5()
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
loss_net = LeNet5WithLossCell(net, loss_fn)
train_net = nn.TrainOneStepCell(loss_net, optimizer)
data_loader = train_dataset.create_dict_iterator()
loss_meter = AverageMeter('loss')
for k in range(epochs):
train_net.set_train(True)
for i, data in enumerate(data_loader):
loss = train_net(data["image"], data["label"])
loss_meter.update(loss.asnumpy())
if i % 100 == 0:
print("epoch:", k, "step:", i, loss_meter)
loss_meter.reset()
save_checkpoint(train_net, "./ckpts/lenet_" + str(k) +".ckpt")
test_net(net, mnist_path)
```
As the code above shows, during network debugging and training we keep a grip on the following details:
* One-to-one mapping between the dataset loader's outputs and the network's inputs
When we iterate over data_loader, the **data["image"]** and **data["label"]** we pull out were mapped in advance by the **ds.MnistDataset** call in create_dataset. For custom dataset processing and loading, see **(1) Dataset processing and loader** above, where the mapping is done by hand through ds.GeneratorDataset:
```python
dataset = BRDNetDataset(data_path, sigma, channel)
data_set = ds.GeneratorDataset(dataset, column_names=["image", "label"], \
num_parallel_workers=8, shuffle=shuffle, num_shards=device_num, shard_id=rank)
```
**column_names=["image", "label"]** is the mapping onto the return values of **BRDNetDataset**'s **\_\_getitem\_\_**. The **column_names** can be changed to any other names as needed, and no fixed order is imposed: **column_names=["label1", "image1"]** is perfectly fine too, as long as the names correspond one-to-one with the **\_\_getitem\_\_** return values.
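A plain-Python illustration of that positional mapping (the class and values are made up for illustration; conceptually, pairing each column name with the matching position of the \_\_getitem\_\_ return value is what the mapping amounts to):

```python
class PairSource:
    """A source whose __getitem__ returns two values per sample."""
    def __init__(self):
        self.data = [(11, 0), (22, 1), (33, 0)]

    def __getitem__(self, index):
        value, label = self.data[index]
        return value, label  # position 0 -> first column name, position 1 -> second

    def __len__(self):
        return len(self.data)

source = PairSource()
column_names = ["image", "label"]  # what would be passed to ds.GeneratorDataset
# conceptually, each sample becomes a {column_name: value} row:
row = dict(zip(column_names, source[1]))
```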
With this, everything from dataset processing and loading through to feeding inputs into the network is entirely in our hands.
* Pinpointing forward propagation and loss computation
```python
loss_net = LeNet5WithLossCell(net, loss_fn)
train_net = nn.TrainOneStepCell(loss_net, optimizer)
loss = train_net(data["image"], data["label"])
```
These three lines keep a firm hold on forward propagation and loss computation during training, which makes it convenient to localize problems in the network structure or in a custom loss function.
* Printing training information freely
In the print statement inside the training loop, we can print any intermediate information we need: loss, current epoch, current step, elapsed time, throughput, and so on (figures like throughput we simply compute ourselves). The **Model** wrapper can of course print these too, but only through a custom **callback**, which is slightly more work.
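For instance, per-step time and throughput can be tracked by hand in the loop with a small helper like the sketch below (the class name and the simple windowing logic are mine):

```python
import time

class SpeedMeter:
    """Tracks average step time and throughput (images/sec) since the last reset."""
    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.reset()

    def reset(self):
        self.start = time.perf_counter()
        self.steps = 0

    def update(self):
        self.steps += 1

    def stats(self):
        elapsed = max(time.perf_counter() - self.start, 1e-9)  # guard against zero
        per_step = elapsed / max(self.steps, 1)
        return per_step, self.batch_size / per_step  # (sec per step, images per sec)

meter = SpeedMeter(batch_size=32)
for _ in range(10):  # stands in for `loss = train_net(...)` iterations
    meter.update()
per_step, ips = meter.stats()
```

Calling `meter.reset()` at the same point where `loss_meter.reset()` runs keeps the two windows aligned.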
* Checkpoint saving and in-training evaluation
The save_checkpoint call at the end of each epoch saves the network parameters, much like the earlier CheckpointConfig:
```python
config_ck = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=10)
ckpoint = ModelCheckpoint(directory="./ckpts", prefix="lenet", config=config_ck)
```
The test_net call after each epoch runs one evaluation pass and reports accuracy, giving an accurate picture of the network's convergence; loss convergence alone does not tell the whole story. Of course, evaluating once every N epochs also works; for some networks an evaluation pass per epoch is time-consuming and unnecessary.
The **Model** wrapper can likewise insert evaluation into training, again through a **callback**; this is covered in detail below.
Note that in **net = LeNet5()**, **loss_net = LeNet5WithLossCell(net, loss_fn)** and **train_net = nn.TrainOneStepCell(loss_net, optimizer)**, each wrapper holds a reference to the object before it, so during training these three ~~networks~~ play complementary roles: the first handles forward propagation, the second computes the loss on top of it, and the third goes further to compute backward gradients and update the network parameters through the optimizer.
In other words, for **net = LeNet5()**, executing **outputs = net(data["image"])** runs a normal forward computation on the input image using the current network parameters; since the object contains nothing more than that, it computes no loss or gradients and performs no backward optimization, which makes it the counterpart of **PyTorch**'s **with torch.no_grad()**. Mixing **net = LeNet5()** with **train_net = nn.TrainOneStepCell(loss_net, optimizer)** therefore makes self-supervised setups possible; interested readers can learn the details from [Neighbor2Neighbor](https://gitee.com/mindspore/models/tree/master/research/cv/Neighbor2Neighbor).
**Note: because this style uses no dataset sinking and adds other frequent switching overhead, its training performance is far below the two previous pipelines. Unless there is no other way, use it only for locating errors and troubleshooting.**
### (4) Model + callback
Although the **Model** wrapper is too much of a **black box** for new users, it still has irreplaceable advantages, so giving **model.train** transparency and flexibility comparable to **scheme (3)** matters quite a bit.
Generally, what matters most is tracking the network's convergence during training, so this section shows how to insert an evaluation step into **model.train** through a **callback**.
#### Official metrics
The complete code:
```python
import os
import stat
from datetime import datetime
import mindspore.nn as nn
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore import Model
from mindspore import context
from mindspore.nn import Accuracy
from mindspore import log as logger
from mindspore import save_checkpoint
from mindspore import dtype as mstype
from mindspore.dataset.vision import Inter
from mindspore.train.callback import Callback
from mindspore.common.initializer import Normal
from mindspore.train.callback import LossMonitor, TimeMonitor, ModelCheckpoint, CheckpointConfig
def create_dataset(data_path, batch_size=32, num_parallel_workers=1):
mnist_ds = ds.MnistDataset(data_path)
type_cast_op = C.TypeCast(mstype.int32)
resize_op = CV.Resize((32, 32), interpolation=Inter.LINEAR)
rescale_op = CV.Rescale(1.0 / 255.0, 0.0)
rescale_nml_op = CV.Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081)
hwc2chw_op = CV.HWC2CHW()
mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.shuffle(buffer_size=10000)
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
return mnist_ds, mnist_ds.get_dataset_size()
class LeNet5(nn.Cell):
def __init__(self, num_class=10, num_channel=1):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
class EvalCallBack(Callback):
def __init__(self, network, dataloader, interval=1, eval_start_epoch=0, \
save_best_ckpt=True, ckpt_directory="./", besk_ckpt_name="best.ckpt"):
super(EvalCallBack, self).__init__()
self.network = network
self.dataloader = dataloader
self.eval_start_epoch = eval_start_epoch
if interval < 1:
raise ValueError("interval should >= 1.")
self.interval = interval
self.save_best_ckpt = save_best_ckpt
self.best_res = 0
self.best_epoch = 0
if not os.path.isdir(ckpt_directory):
os.makedirs(ckpt_directory)
self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
def remove_ckpoint_file(self, file_name):
"""Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
try:
os.chmod(file_name, stat.S_IWRITE)
os.remove(file_name)
except OSError:
logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
except ValueError:
logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
def epoch_end(self, run_context):
"""Callback when epoch end."""
cb_params = run_context.original_args()
cur_epoch = cb_params.cur_epoch_num
if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
acc = self.network.eval(self.dataloader, dataset_sink_mode=True)
print(datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3],\
":INFO: epoch: {}, {}: {}".format(cur_epoch, "accuracy", acc['Accuracy']*100), flush=True)
if acc['Accuracy'] >= self.best_res:
self.best_res = acc['Accuracy']
self.best_epoch = cur_epoch
print(datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3],\
":INFO: update best result: {}".format(acc['Accuracy']*100), flush=True)
if self.save_best_ckpt:
if os.path.exists(self.bast_ckpt_path):
self.remove_ckpoint_file(self.bast_ckpt_path)
save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
print(datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3],\
":INFO: update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)
def end(self, run_context):
print(datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3],\
":INFO: End training, the best {0} is: {1}, it's epoch is {2}".format("accuracy",\
self.best_res*100, self.best_epoch), flush=True)
mnist_path = "./dataset"
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
train_dataset, steps_per_epoch = create_dataset(os.path.join(mnist_path, "train"), 32)
net = LeNet5()
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
model = Model(net, loss_fn, optimizer)
time_cb = TimeMonitor(data_size=steps_per_epoch)
loss_cb = LossMonitor(125)
config_ck = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=10)
ckpoint = ModelCheckpoint(directory="./ckpts", prefix="lenet", config=config_ck)
callbacks = [time_cb, loss_cb, ckpoint]
val_dataset, val_data_size = create_dataset(os.path.join(mnist_path, "val"), 32)
network_eval = Model(net, loss_fn=loss_fn, metrics={"Accuracy": Accuracy()})
eval_cb = EvalCallBack(network_eval, val_dataset)
callbacks.append(eval_cb)
model.train(epoch=3, train_dataset=train_dataset, callbacks=callbacks, dataset_sink_mode=True)
```
运行日志:
```tex
[WARNING] ME(17320:34064,MainProcess):2022-01-02-17:40:55.985.839 [mindspore\train\model.py:435] The CPU cannot support dataset sink mode currently.So the training process will be performed with dataset not sink.
epoch: 1 step: 125, loss is 2.3245094
epoch: 1 step: 250, loss is 2.3097432
epoch: 1 step: 375, loss is 2.2903495
epoch: 1 step: 500, loss is 2.3132498
epoch: 1 step: 625, loss is 2.3003674
epoch: 1 step: 750, loss is 2.2925541
epoch: 1 step: 875, loss is 2.2658715
epoch: 1 step: 1000, loss is 0.563914
epoch: 1 step: 1125, loss is 0.42866653
epoch: 1 step: 1250, loss is 0.4092283
epoch: 1 step: 1375, loss is 0.4039851
epoch: 1 step: 1500, loss is 0.11572214
epoch: 1 step: 1625, loss is 0.2705834
epoch: 1 step: 1750, loss is 0.068169944
epoch: 1 step: 1875, loss is 0.23778073
epoch time: 9973.363 ms, per step time: 5.319 ms
[WARNING] ME(17320:34064,MainProcess):2022-01-02-17:41:05.961.198 [mindspore\train\model.py:774] CPU cannot support dataset sink mode currently.So the evaluating process will be performed with dataset non-sink mode.
2022-01-02 17:41:06,741 :INFO: epoch: 1, accuracy: 95.99358974358975
2022-01-02 17:41:06,741 :INFO: update best result: 95.99358974358975
2022-01-02 17:41:06,747 :INFO: update best checkpoint at: ./best.ckpt
2022-01-02 17:41:06,747 :INFO: End training, the best accuracy is: 95.99358974358975, it's epoch is 1
(ms1.3) E:\LaNet>python lenet4.py
[WARNING] ME(35744:40004,MainProcess):2022-01-02-17:42:01.678.213 [mindspore\train\model.py:435] The CPU cannot support dataset sink mode currently.So the training process will be performed with dataset not sink.
epoch: 1 step: 125, loss is 2.3140264
epoch: 1 step: 250, loss is 2.3043535
epoch: 1 step: 375, loss is 2.311633
epoch: 1 step: 500, loss is 2.3036153
epoch: 1 step: 625, loss is 2.308304
epoch: 1 step: 750, loss is 2.291089
epoch: 1 step: 875, loss is 2.297149
epoch: 1 step: 1000, loss is 2.312134
epoch: 1 step: 1125, loss is 2.302312
epoch: 1 step: 1250, loss is 1.0658072
epoch: 1 step: 1375, loss is 0.20514262
epoch: 1 step: 1500, loss is 0.3152214
epoch: 1 step: 1625, loss is 0.15839374
epoch: 1 step: 1750, loss is 0.060854506
epoch: 1 step: 1875, loss is 0.0756621
epoch time: 10337.838 ms, per step time: 5.514 ms
[WARNING] ME(35744:40004,MainProcess):2022-01-02-17:42:12.190.44 [mindspore\train\model.py:774] CPU cannot support dataset sink mode currently.So the evaluating process will be performed with dataset non-sink mode.
2022-01-02 17:42:12,820 :INFO: epoch: 1, accuracy: 94.75160256410257
2022-01-02 17:42:12,821 :INFO: update best result: 94.75160256410257
2022-01-02 17:42:12,828 :INFO: update best checkpoint at: ./best.ckpt
epoch: 2 step: 125, loss is 0.040145893
epoch: 2 step: 250, loss is 0.34580463
epoch: 2 step: 375, loss is 0.053992774
epoch: 2 step: 500, loss is 0.010377166
epoch: 2 step: 625, loss is 0.03660379
epoch: 2 step: 750, loss is 0.243122
epoch: 2 step: 875, loss is 0.10487746
epoch: 2 step: 1000, loss is 0.35544384
epoch: 2 step: 1125, loss is 0.039908916
epoch: 2 step: 1250, loss is 0.03180509
epoch: 2 step: 1375, loss is 0.20811962
epoch: 2 step: 1500, loss is 0.3726245
epoch: 2 step: 1625, loss is 0.06820704
epoch: 2 step: 1750, loss is 0.023554254
epoch: 2 step: 1875, loss is 0.051936243
epoch time: 9574.187 ms, per step time: 5.106 ms
[WARNING] ME(35744:40004,MainProcess):2022-01-02-17:42:22.403.887 [mindspore\train\model.py:774] CPU cannot support dataset sink mode currently.So the evaluating process will be performed with dataset non-sink mode.
2022-01-02 17:42:23,147 :INFO: epoch: 2, accuracy: 97.66626602564102
2022-01-02 17:42:23,147 :INFO: update best result: 97.66626602564102
2022-01-02 17:42:23,153 :INFO: update best checkpoint at: ./best.ckpt
epoch: 3 step: 125, loss is 0.06722653
epoch: 3 step: 250, loss is 0.0913366
epoch: 3 step: 375, loss is 0.15591136
epoch: 3 step: 500, loss is 0.0014791153
epoch: 3 step: 625, loss is 0.0742274
epoch: 3 step: 750, loss is 0.0127390465
epoch: 3 step: 875, loss is 0.2486604
epoch: 3 step: 1000, loss is 0.085124485
epoch: 3 step: 1125, loss is 0.03314853
epoch: 3 step: 1250, loss is 0.088092566
epoch: 3 step: 1375, loss is 0.062529296
epoch: 3 step: 1500, loss is 0.18533966
epoch: 3 step: 1625, loss is 0.028387262
epoch: 3 step: 1750, loss is 0.0057865116
epoch: 3 step: 1875, loss is 0.20129825
epoch time: 9733.319 ms, per step time: 5.191 ms
[WARNING] ME(35744:40004,MainProcess):2022-01-02-17:42:32.886.916 [mindspore\train\model.py:774] CPU cannot support dataset sink mode currently.So the evaluating process will be performed with dataset non-sink mode.
2022-01-02 17:42:33,688 :INFO: epoch: 3, accuracy: 98.03685897435898
2022-01-02 17:42:33,688 :INFO: update best result: 98.03685897435898
2022-01-02 17:42:33,695 :INFO: update best checkpoint at: ./best.ckpt
2022-01-02 17:42:33,695 :INFO: End training, the best accuracy is: 98.03685897435898, it's epoch is 3
```
自定义 **callback** 其实很简单:继承 **Callback**,按需重载 **def epoch_end(self, run_context)**、**def end(self, run_context)** 等钩子方法即可。
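callback 的钩子调用机制可以用纯 Python 简单示意如下(**MiniCallback**、**PrintBest**、**run** 等名称均为假设,仅用于说明训练循环在什么时机调用这些钩子,并非 MindSpore 的实际实现):

```python
# 纯 Python 示意:训练循环在每个 epoch 结束时依次调用各 callback 的钩子方法。
# MiniCallback / PrintBest / run 均为示意用的假设名称,并非 MindSpore 实际 API。
class MiniCallback:
    def epoch_end(self, epoch):
        pass

    def end(self):
        pass

class PrintBest(MiniCallback):
    def __init__(self):
        self.best = 0
        self.best_epoch = 0

    def epoch_end(self, epoch):
        acc = 90 + epoch  # 假设的精度(百分数),真实场景中来自 model.eval
        if acc >= self.best:
            self.best = acc
            self.best_epoch = epoch

    def end(self):
        print("best: {}, epoch: {}".format(self.best, self.best_epoch))

def run(callbacks, epochs):
    for epoch in range(1, epochs + 1):
        # 训练完一个 epoch 之后……
        for cb in callbacks:
            cb.epoch_end(epoch)  # 对应 MindSpore 的 epoch_end(run_context)
    for cb in callbacks:
        cb.end()                 # 对应训练结束时的 end(run_context)

cb = PrintBest()
run([cb], epochs=3)  # 打印 best: 93, epoch: 3
```

真实的 **EvalCallBack** 只是把"假设的精度"换成了 `self.network.eval(...)` 的结果,并在刷新最优值时保存 checkpoint。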
#### 自定义 metrics
上述方式使用的是官方提供的 **metrics={"Accuracy": Accuracy()}**;更进一步,如果需要使用自定义的 metrics,具体实现方式如下:
完整代码:
```python
import os
import stat
from datetime import datetime
import mindspore.nn as nn
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore import Model
from mindspore import context
from mindspore.nn import Accuracy
from mindspore import log as logger
from mindspore import save_checkpoint
from mindspore import dtype as mstype
from mindspore.dataset.vision import Inter
from mindspore.train.callback import Callback
from mindspore.common.initializer import Normal
from mindspore.train.callback import LossMonitor, TimeMonitor, ModelCheckpoint, CheckpointConfig
def create_dataset(data_path, batch_size=32, num_parallel_workers=1):
mnist_ds = ds.MnistDataset(data_path)
type_cast_op = C.TypeCast(mstype.int32)
resize_op = CV.Resize((32, 32), interpolation=Inter.LINEAR)
rescale_op = CV.Rescale(1.0 / 255.0, 0.0)
rescale_nml_op = CV.Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081)
hwc2chw_op = CV.HWC2CHW()
mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.shuffle(buffer_size=10000)
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
return mnist_ds, mnist_ds.get_dataset_size()
class LeNet5(nn.Cell):
def __init__(self, num_class=10, num_channel=1):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
class AverageMeter:
"""Computes and stores the average and current value"""
def __init__(self, name, fmt=':f', tb_writer=None):
self.name = name
self.fmt = fmt
self.reset()
self.tb_writer = tb_writer
self.cur_step = 1
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
if self.tb_writer is not None:
self.tb_writer.add_scalar(self.name, self.val, self.cur_step)
self.cur_step += 1
def __str__(self):
fmtstr = '{name}:{avg' + self.fmt + '}'
return fmtstr.format(**self.__dict__)
class AccuracyMetric(nn.Metric):
def __init__(self, batch_size):
super(AccuracyMetric, self).__init__()
self.top1 = AverageMeter('top1')
self.top5 = AverageMeter('top5')
self.top1_m = nn.TopKCategoricalAccuracy(1)
self.top5_m = nn.TopKCategoricalAccuracy(5)
self.batch_size = batch_size
def clear(self):
self.top1.reset()
self.top5.reset()
def update(self, outputs, label):
self.top1_m.clear()
self.top5_m.clear()
self.top1_m.update(outputs, label)
self.top5_m.update(outputs, label)
self.top1.update(self.top1_m.eval(), self.batch_size)
self.top5.update(self.top5_m.eval(), self.batch_size)
def eval(self):
return self.top1.avg, self.top5.avg
class EvalCallBack(Callback):
def __init__(self, network, dataloader, interval=1, eval_start_epoch=0, \
save_best_ckpt=True, ckpt_directory="./", besk_ckpt_name="best.ckpt"):
super(EvalCallBack, self).__init__()
self.network = network
self.dataloader = dataloader
self.eval_start_epoch = eval_start_epoch
if interval < 1:
raise ValueError("interval should >= 1.")
self.interval = interval
self.save_best_ckpt = save_best_ckpt
self.best_res = 0
self.best_epoch = 0
if not os.path.isdir(ckpt_directory):
os.makedirs(ckpt_directory)
self.bast_ckpt_path = os.path.join(ckpt_directory, besk_ckpt_name)
def remove_ckpoint_file(self, file_name):
"""Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
try:
os.chmod(file_name, stat.S_IWRITE)
os.remove(file_name)
except OSError:
logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
except ValueError:
logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
def epoch_end(self, run_context):
"""Callback when epoch end."""
cb_params = run_context.original_args()
cur_epoch = cb_params.cur_epoch_num
if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
acc = self.network.eval(self.dataloader, dataset_sink_mode=True)
top1, top5 = acc['AccuracyMetric']
print(datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3],\
":INFO: epoch: {}, {}: {}, {}: {}".format(cur_epoch, "accuracy top1", top1, "top5", top5), flush=True)
if top1 >= self.best_res:
self.best_res = top1
self.best_epoch = cur_epoch
print(datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3],\
":INFO: update best result: {}".format(top1*100), flush=True)
if self.save_best_ckpt:
if os.path.exists(self.bast_ckpt_path):
self.remove_ckpoint_file(self.bast_ckpt_path)
save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
print(datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3],\
":INFO: update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)
def end(self, run_context):
print(datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3],\
":INFO: End training, the best {0} is: {1}, it's epoch is {2}".format("accuracy",\
self.best_res*100, self.best_epoch), flush=True)
mnist_path = "./dataset"
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
train_dataset, steps_per_epoch = create_dataset(os.path.join(mnist_path, "train"), 32)
net = LeNet5()
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
model = Model(net, loss_fn, optimizer)
time_cb = TimeMonitor(data_size=steps_per_epoch)
loss_cb = LossMonitor(125)
config_ck = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=10)
ckpoint = ModelCheckpoint(directory="./ckpts", prefix="lenet", config=config_ck)
callbacks = [time_cb, loss_cb, ckpoint]
val_dataset, val_data_size = create_dataset(os.path.join(mnist_path, "val"), 32)
network_eval = Model(net, loss_fn=loss_fn, metrics={"AccuracyMetric": AccuracyMetric(32)})
eval_cb = EvalCallBack(network_eval, val_dataset)
callbacks.append(eval_cb)
model.train(epoch=3, train_dataset=train_dataset, callbacks=callbacks, dataset_sink_mode=True)
```
运行日志:
```tex
[WARNING] ME(40408:36108,MainProcess):2022-01-02-17:59:06.439.391 [mindspore\train\model.py:435] The CPU cannot support dataset sink mode currently.So the training process will be performed with dataset not sink.
epoch: 1 step: 125, loss is 2.2842762
epoch: 1 step: 250, loss is 2.3017297
epoch: 1 step: 375, loss is 2.3166213
epoch: 1 step: 500, loss is 2.3036542
epoch: 1 step: 625, loss is 2.2868423
epoch: 1 step: 750, loss is 1.8704919
epoch: 1 step: 875, loss is 0.7434879
epoch: 1 step: 1000, loss is 0.30049253
epoch: 1 step: 1125, loss is 0.09520069
epoch: 1 step: 1250, loss is 0.14693508
epoch: 1 step: 1375, loss is 0.049935322
epoch: 1 step: 1500, loss is 0.1242107
epoch: 1 step: 1625, loss is 0.038135473
epoch: 1 step: 1750, loss is 0.2170761
epoch: 1 step: 1875, loss is 0.09290631
epoch time: 10354.831 ms, per step time: 5.523 ms
[WARNING] ME(40408:36108,MainProcess):2022-01-02-17:59:16.796.217 [mindspore\train\model.py:774] CPU cannot support dataset sink mode currently.So the evaluating process will be performed with dataset non-sink mode.
2022-01-02 17:59:17,602 :INFO: epoch: 1, accuracy top1: 0.9688501602564102, top5: 0.9995993589743589
2022-01-02 17:59:17,602 :INFO: update best result: 96.88501602564102
2022-01-02 17:59:17,608 :INFO: update best checkpoint at: ./best.ckpt
epoch: 2 step: 125, loss is 0.19031726
epoch: 2 step: 250, loss is 0.027574457
epoch: 2 step: 375, loss is 0.037040163
epoch: 2 step: 500, loss is 0.015066621
epoch: 2 step: 625, loss is 0.0034439913
epoch: 2 step: 750, loss is 0.038274184
epoch: 2 step: 875, loss is 0.005039366
epoch: 2 step: 1000, loss is 0.0121483
epoch: 2 step: 1125, loss is 0.09210114
epoch: 2 step: 1250, loss is 0.04257711
epoch: 2 step: 1375, loss is 0.114633106
epoch: 2 step: 1500, loss is 0.15046401
epoch: 2 step: 1625, loss is 0.081233054
epoch: 2 step: 1750, loss is 0.08739235
epoch: 2 step: 1875, loss is 0.34953314
epoch time: 9251.442 ms, per step time: 4.934 ms
[WARNING] ME(40408:36108,MainProcess):2022-01-02-17:59:26.860.455 [mindspore\train\model.py:774] CPU cannot support dataset sink mode currently.So the evaluating process will be performed with dataset non-sink mode.
2022-01-02 17:59:27,647 :INFO: epoch: 2, accuracy top1: 0.9788661858974359, top5: 0.9994991987179487
2022-01-02 17:59:27,647 :INFO: update best result: 97.88661858974359
2022-01-02 17:59:27,653 :INFO: update best checkpoint at: ./best.ckpt
epoch: 3 step: 125, loss is 0.018552104
epoch: 3 step: 250, loss is 0.011122282
epoch: 3 step: 375, loss is 0.07484381
epoch: 3 step: 500, loss is 0.27695313
epoch: 3 step: 625, loss is 0.0343815
epoch: 3 step: 750, loss is 0.042258963
epoch: 3 step: 875, loss is 0.13738352
epoch: 3 step: 1000, loss is 0.22029029
epoch: 3 step: 1125, loss is 0.038821213
epoch: 3 step: 1250, loss is 0.011405985
epoch: 3 step: 1375, loss is 0.0070968377
epoch: 3 step: 1500, loss is 0.05930841
epoch: 3 step: 1625, loss is 0.0071509383
epoch: 3 step: 1750, loss is 0.10316856
epoch: 3 step: 1875, loss is 0.022942578
epoch time: 9820.442 ms, per step time: 5.238 ms
[WARNING] ME(40408:36108,MainProcess):2022-01-02-17:59:37.475.401 [mindspore\train\model.py:774] CPU cannot support dataset sink mode currently.So the evaluating process will be performed with dataset non-sink mode.
2022-01-02 17:59:38,329 :INFO: epoch: 3, accuracy top1: 0.98046875, top5: 0.9997996794871795
2022-01-02 17:59:38,329 :INFO: update best result: 98.046875
2022-01-02 17:59:38,335 :INFO: update best checkpoint at: ./best.ckpt
2022-01-02 17:59:38,335 :INFO: End training, the best accuracy is: 98.046875, it's epoch is 3
```
在本例中,我们自定义了一个 **AccuracyMetric**,并计算了推理的 top1 和 top5 精度,这些细节都是我们可以控制的(仔细看看,这里的 **AccuracyMetric** 计算方式是不是和 **方案(3) test操作** 中计算精度的方式很像)。当然,官方 [metrics](https://www.mindspore.cn/docs/api/zh-CN/master/api_python/mindspore.nn.html#metrics) 是有这个实现的,我只是简单举个例子。
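top1/top5 的计算逻辑本身并不复杂,可以脱离框架用 numpy 直接演示(以下 **topk_accuracy** 为示意用的假设函数,与 MindSpore 的 **TopKCategoricalAccuracy** 实现无关,仅说明其计算含义):

```python
import numpy as np

def topk_accuracy(logits, labels, k):
    """top-k 精度:真实标签落在预测得分前 k 名之内即算正确"""
    topk = np.argsort(logits, axis=1)[:, -k:]          # 每个样本得分最高的 k 个类别
    correct = [labels[i] in topk[i] for i in range(len(labels))]
    return sum(correct) / len(labels)

logits = np.array([[0.1, 0.5, 0.2, 0.2],   # 预测为类 1
                   [0.6, 0.1, 0.2, 0.1],   # 预测为类 0
                   [0.1, 0.2, 0.3, 0.4]])  # 预测为类 3
labels = np.array([1, 2, 3])

print(topk_accuracy(logits, labels, k=1))  # top1:3 个样本对 2 个
print(topk_accuracy(logits, labels, k=2))  # top2:3 个样本全对
```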
其他需要在训练过程中插入的操作,也都可以通过 **callback** 实现。这种方式既能保留训练性能,又具备足够的操作灵活性,现在用起来已经比较顺手了。
### (5) 梯度裁剪
到现在为止,我们还没有把 **TrainOneStepCell** 拆开来用过,本节就以梯度裁剪为例,介绍对 **TrainOneStepCell** 的重载。
完整代码:
```python
import os
import mindspore.nn as nn
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore import Model
from mindspore import context
from mindspore.nn import Accuracy
from mindspore.ops import composite
from mindspore import dtype as mstype
from mindspore.ops import functional as F
from mindspore.dataset.vision import Inter
from mindspore.common.initializer import Normal
from mindspore.train.callback import LossMonitor, TimeMonitor, ModelCheckpoint, CheckpointConfig
def create_dataset(data_path, batch_size=32, num_parallel_workers=1):
mnist_ds = ds.MnistDataset(data_path)
type_cast_op = C.TypeCast(mstype.int32)
resize_op = CV.Resize((32, 32), interpolation=Inter.LINEAR)
rescale_op = CV.Rescale(1.0 / 255.0, 0.0)
rescale_nml_op = CV.Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081)
hwc2chw_op = CV.HWC2CHW()
mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.shuffle(buffer_size=10000)
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
return mnist_ds, mnist_ds.get_dataset_size()
class LeNet5(nn.Cell):
def __init__(self, num_class=10, num_channel=1):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
class LeNet5WithLossCell(nn.Cell):
def __init__(self, network, loss_fn):
super(LeNet5WithLossCell, self).__init__()
self.network = network
self.loss = loss_fn
def construct(self, images, label):
outputs = self.network(images)
'''
your operations
'''
return self.loss(outputs, label)
GRADIENT_CLIP_TYPE = 1
GRADIENT_CLIP_VALUE = 10000.0
clip_grad = composite.MultitypeFuncGraph("clip_grad")
@clip_grad.register("Number", "Number", "Tensor")
def _clip_grad(clip_type, clip_value, grad):
"""
Clip gradients.
Inputs:
clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
clip_value (float): Specifies how much to clip.
grad (tuple[Tensor]): Gradients.
Outputs:
tuple[Tensor]: clipped gradients.
"""
if clip_type not in (0, 1):
return grad
dt = F.dtype(grad)
if clip_type == 0:
new_grad = composite.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
F.cast(F.tuple_to_array((clip_value,)), dt))
else:
new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
return new_grad
class MyTrainOneStepCell(nn.TrainOneStepCell):
def __init__(self, network, optimizer, sens=1.0):
super(MyTrainOneStepCell, self).__init__(network, optimizer, sens)
self.hyper_map = composite.HyperMap()
def construct(self, *inputs):
loss = self.network(*inputs)
sens = F.fill(loss.dtype, loss.shape, self.sens)
grads = self.grad(self.network, self.weights)(*inputs, sens)
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
grads = self.grad_reducer(grads)
loss = F.depend(loss, self.optimizer(grads))
return loss
mnist_path = "./dataset"
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
train_dataset, steps_per_epoch = create_dataset(os.path.join(mnist_path, "train"), 32)
net = LeNet5()
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
loss_net = LeNet5WithLossCell(net, loss_fn)
train_net = MyTrainOneStepCell(loss_net, optimizer)
model = Model(train_net)
time_cb = TimeMonitor(data_size=steps_per_epoch)
loss_cb = LossMonitor(125)
config_ck = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=10)
ckpoint = ModelCheckpoint(directory="./ckpts", prefix="lenet", config=config_ck)
callbacks = [time_cb, loss_cb, ckpoint]
model.train(epoch=1, train_dataset=train_dataset, callbacks=callbacks, dataset_sink_mode=True)
```
运行日志:
```tex
[WARNING] ME(15452:28784,MainProcess):2022-01-02-20:31:46.177.364 [mindspore\train\model.py:435] The CPU cannot support dataset sink mode currently.So the training process will be performed with dataset not sink.
epoch: 1 step: 125, loss is 2.2990582
epoch: 1 step: 250, loss is 2.2897432
epoch: 1 step: 375, loss is 2.3054984
epoch: 1 step: 500, loss is 2.2883756
epoch: 1 step: 625, loss is 2.3043606
epoch: 1 step: 750, loss is 1.5947446
epoch: 1 step: 875, loss is 0.42944697
epoch: 1 step: 1000, loss is 0.2806945
epoch: 1 step: 1125, loss is 0.33546546
epoch: 1 step: 1250, loss is 0.3850002
epoch: 1 step: 1375, loss is 0.27240688
epoch: 1 step: 1500, loss is 0.1559959
epoch: 1 step: 1625, loss is 0.3568045
epoch: 1 step: 1750, loss is 0.0339943
epoch: 1 step: 1875, loss is 0.43218622
epoch time: 23671.972 ms, per step time: 12.625 ms
```
单步耗时约 12.6 ms,比不使用梯度裁剪时的约 5 ms 慢了很多。
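顺带说明一下 **_clip_grad** 中两种裁剪方式的数学含义,下面用 numpy 做一个简单示意(**clip_by_value**、**clip_by_norm** 为示意用的假设函数,与 MindSpore 算子的实际实现无关):

```python
import numpy as np

def clip_by_value(grad, clip_value):
    """对应 clip_type == 0:逐元素截断到 [-clip_value, clip_value]"""
    return np.clip(grad, -clip_value, clip_value)

def clip_by_norm(grad, clip_norm):
    """对应 clip_type == 1:若 L2 范数超过 clip_norm,则整体等比例缩放"""
    norm = np.linalg.norm(grad)
    if norm <= clip_norm:
        return grad
    return grad * (clip_norm / norm)

g = np.array([3.0, 4.0])      # L2 范数为 5
print(clip_by_value(g, 3.5))  # 逐元素截断 → [3.0, 3.5]
print(clip_by_norm(g, 2.5))   # 整体缩放到范数 2.5 → [1.5, 2.0]
```

可以注意到,本例代码中 **GRADIENT_CLIP_VALUE** 取 10000.0,按范数裁剪通常很难真正触发,更多是演示写法。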
**需要注意的是,上述几种训练方式是可以互相组合的,不应拘泥于一种形式**。