DEM-SRNet: 全球3弧秒(90m)海陆高分辨率数字高程模型

DownloadNotebookDownloadCodeViewSource

概述

DEM-SRNet是数字高程模型网络的简称,可以提供准确的基础地理数据,因此在全球气候变化、海洋潮汐运动、地球圈物质交换等研究领域发挥着至关重要的作用。在探索海洋地形的地理分布和板块运动模式之前,必须克服精细建模的关键理论和技术障碍。直到2020年代,分辨率超过100米的全球DEM才出现,当时全球30米分辨率的陆地NASADEM和FABDEM都免费提供。由于数据量大,制作更高分辨率的全球DEM地图需要更多的时间和计算资源。由于分辨率和传感器的不同,这些数据融合存在各种精度困难。一种称为超分辨率(SR)的深度学习技术已用于识别和补偿分辨率和传感器之间的差异。30米分辨率的NASADEM、GEBCO_2021数据和众多高分辨率(HR)区域海洋DEM数据集都是公开可用的,使得生成具有3弧秒分辨率的全球DEM数据集成为可能。

本教程介绍了DEM-SRNet的研究背景和技术路径,并展示了如何通过MindSpore Earth训练和快速推断模型。更多信息详见论文:Super-resolution reconstruction of a 3 arc-second global DEM dataset

技术路径

MindSpore Earth解决这个问题的具体流程如下:

  1. 创建数据集

  2. 模型构建

  3. 损失函数

  4. 模型训练

  5. 模型评估和可视化

DEM-SRNet

基于深度残差网络的DEM-SRNet的预训练数据来源于地面DEM数据。如图所示,设计的预训练结构源自增强型深度超分辨率网络(EDSR)。DEM-SRNet网络的第一个卷积层提取特征集合。EDSR模型默认的残差块数量扩大到32。通过实验比较,最佳残差块数量(ResBlocks)为42,每个残差块由两个卷积层组成,并用ReLU激活函数进行插值。最后,由卷积层和逐元素加法层组成。后者包括用于提取特征图的卷积层,利用比例因子为5的插值层从输入的低分辨率15角秒数据上采样到目标高分辨率3角秒数据,最后,卷积层聚合低分辨率空间中的特征图并生成SR输出数据。与典型的反卷积层相反,使用插值函数层,可以对低分辨率数据进行超分辨率,具备更好的效果。

该网络由88个卷积层、43个元素加法层和1个插值层组成。每个卷积层中的卷积核大小设置为3,填充设置为1。总共有50,734,080个网络参数。在训练阶段,使用初始学习率为0.0001的Adam优化器和指数衰减方法使用大数据集来训练模型。当模型的性能在验证集中开始劣化时,使用为容忍度为6的早停技术来终止训练,以防止过度拟合。小批量梯度下降法通常需要300个epoch才能从头开始构建预训练网络。初始参数源自地面数据预训练网络,预训练网络的冻结层与海洋DEM结合使用,以微调全局DEM-SRNet模型。由于微调样本有限,学习率对收敛过程有显着影响,并将学习率调整为0.00001。

DEM_SR

模型训练分为两个步骤:

  1. 预训练:如上图所示,在预训练步骤中,预训练的DEM-SRNet网络的第一个卷积层提取特征集合。EDSR模型默认残差块数已扩展至32。

  2. 微调:本文的主要贡献是使用GEBCO_2021和NASADEM陆地样本预训练神经网络模型,然后使用有限的区域海洋DEM数据进行额外的微调。

[1]:
import os
import numpy as np
import matplotlib.pyplot as plt

from mindspore import context, nn, Tensor
from mindspore import load_checkpoint, load_param_into_net
from mindspore.train import LossMonitor, TimeMonitor

from mindearth.utils import load_yaml_config, create_logger
from mindearth.module import Trainer
from mindearth.data import DemData, Dataset

src 文件可以从DEM super-resolution/src下载。

[2]:
from src import init_model, plt_dem_data
from src import EvaluateCallBack, InferenceModule

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=0)

model、data和optimizer的参数可以通过加载DEM-SRNet.yaml文件获取。

[3]:
config = load_yaml_config('DEM-SRNet.yaml')

config['train']['distribute'] = False # 是否设置分布式并行
config['train']['amp_level'] = 'O2' # 设置混合精度等级
config["train"]["load_ckpt"] = False # 是否加载预训练checkpoint

config['data']['num_workers'] = 1  # 设置并行计算的进程数量
config['data']['epochs'] = 100 # 设置训练轮数

config['summary']["valid_frequency"] = 100 # 设置验证频率
config['summary']["summary_dir"] = './summary' # 设置模型存储路径

logger = create_logger(path=os.path.join(config['summary']["summary_dir"], "results.log"))

创建数据集

dataset下载训练数据集、验证数据集、测试数据集到当前目录./dataset

修改DEM-SRNet.yaml配置文件中的root_dir参数,该参数设置了数据集的路径。

./dataset中的目录结构如下所示:

.
├── train
│   └── train.h5
├── valid
│   └── valid.h5
├── test
│   └── test.h5

模型构建

初始化Dem模型。

[4]:
model = init_model(config)

损失函数

DEM-SRNet使用均方误差进行模型训练。

[5]:
loss_fn = nn.MSELoss()

模型训练

在本教程中,我们继承Trainer并重写get_callback成员函数,以便我们可以在训练过程中对测试数据集进行推理。

随着MindSpore版本>=1.10.1,我们可以使用函数式编程来训练神经网络。MindEarth提供了模型训练的训练接口。

[6]:
class DemSrTrainer(Trainer):
    r"""
    Self-define forecast model inherited from `Trainer`.

    Args:
        config (dict): parameters for training.
        model (Cell): network for training.
        loss_fn (str): user-defined loss function.
        logger (logging.RootLogger): tools for logging.

    Supported Platforms:
        ``Ascend`` ``GPU``

    """
    def __init__(self, config, model, loss_fn, logger):
        super(DemSrTrainer, self).__init__(config, model, loss_fn, logger, weather_data_source="DemSR")
        self.model = model
        self.optimizer_params = config["optimizer"]
        self.train_dataset, self.valid_dataset = self.get_dataset()
        self.optimizer = self.get_optimizer()
        self.solver = self.get_solver()

    def get_optimizer(self):
        r"""define the optimizer of the model, abstract method."""
        self.steps_per_epoch = self.train_dataset.get_dataset_size()
        if self.logger:
            self.logger.info(f'steps_per_epoch: {self.steps_per_epoch}')
        if self.optimizer_params['name']:
            optimizer = nn.Adam(self.model.trainable_params(),
                                learning_rate=Tensor(self.optimizer_params['learning_rate']))
        else:
            raise NotImplementedError(
                "self.optimizer_params['name'] not implemented, please overwrite get_optimizer()")
        return optimizer

    def get_callback(self):
        r"""define the callback of the model, abstract method."""
        pred_cb = EvaluateCallBack(self.model, self.valid_dataset, self.config, self.logger)
        return pred_cb

    def train(self):
        r""" train """
        callback_lst = [LossMonitor(), TimeMonitor(), self.ckpt_cb]
        if self.pred_cb:
            callback_lst.append(self.pred_cb)
        self.solver.train(epoch=config['data']['epoch_size'],
                          train_dataset=self.train_dataset,
                          callbacks=callback_lst,
                          dataset_sink_mode=True)
trainer = DemSrTrainer(config, model, loss_fn, logger)
2023-09-27 10:33:38,449 - 3395036387.py[line:27] - INFO: steps_per_epoch: 109
[7]:
trainer.train()
epoch: 1 step: 109, loss is 0.0018688203
Train epoch time: 55616.483 ms, per step time: 510.243 ms
epoch: 2 step: 109, loss is 0.0008327974
Train epoch time: 31303.473 ms, per step time: 287.188 ms
epoch: 3 step: 109, loss is 0.00022218125
...
epoch: 98 step: 109, loss is 1.3421039e-05
Train epoch time: 29786.506 ms, per step time: 273.271 ms
epoch: 99 step: 109, loss is 1.113452e-05
Train epoch time: 31082.307 ms, per step time: 285.159 ms
epoch: 100 step: 109, loss is 2.0731915e-05
Train epoch time: 30750.022 ms, per step time: 282.110 ms

模型评估及可视化

[8]:
params = load_checkpoint("./summary/ckpt/step_/DemSrNet_7-100_109.ckpt")
load_param_into_net(model, params)

inference_module = InferenceModule(model, config, logger)
test_dataset_generator = DemData(data_params=config["data"], run_mode='test')
test_dataset = Dataset(test_dataset_generator, distribute=False,
                       num_workers=config["data"]['num_workers'], shuffle=False)
test_dataset = test_dataset.create_dataset(config["data"]['batch_size'])
create_test_data = test_dataset.create_dict_iterator()

下图对第100个模型的真实值和预测值进行了可视化。

[9]:
data = next(create_test_data)

inputs = data['inputs']
labels = data['labels']

low_res = inputs[0].asnumpy()[0].astype(np.float32)
pred = inference_module.forecast(inputs)
pred = pred[0].asnumpy()[0].astype(np.float32)
label = labels[0].asnumpy()[0].astype(np.float32)

plt.figure(num='e_imshow', figsize=(15, 36))
plt.subplot(1, 3, 1)
plt_dem_data(low_res, "Low Reslution")
plt.subplot(1, 3, 2)
plt_dem_data(label, "Ground Truth")
plt.subplot(1, 3, 3)
plt_dem_data(pred, "Prediction")

../_images/dem-super-resolution_DEM-SRNet_20_0.png