PreDiff: Precipitation Nowcasting with Latent Diffusion Models
Overview
Traditional weather forecasting relies on complex physical models that are not only computationally expensive but also demand deep domain expertise. Over the past decade, however, the explosive growth of spatiotemporal Earth observation data has opened a new path: data-driven forecasting models built with deep learning. While such models show great promise across many Earth-system forecasting tasks, they still fall short in managing uncertainty and in incorporating domain-specific prior knowledge, often producing forecasts that are blurry or physically implausible.
To overcome these difficulties, Gao Zhihan from the Hong Kong University of Science and Technology implemented the PreDiff model, designed specifically for probabilistic spatiotemporal forecasting. The pipeline combines a conditional latent diffusion model with an explicit knowledge alignment mechanism, aiming to generate forecasts that satisfy domain-specific physical constraints while accurately capturing spatiotemporal variation. In this way, we expect a marked improvement in the accuracy and reliability of Earth-system forecasts. The model architecture is shown in the figure below (image from the paper PreDiff: Precipitation Nowcasting with Latent Diffusion Models):
During training, a variational autoencoder first extracts the key information from the data into a latent space. A timestep is then sampled at random, and noise corresponding to that timestep is added to the latent data. The noisy data is fed to Earthformer-UNet for denoising; Earthformer-UNet adopts a UNet architecture with cuboid attention and removes the cross-attention structure that connects the encoder and decoder in Earthformer. Finally, the result is passed through the decoder of the variational autoencoder to recover the denoised data. The diffusion model learns the data distribution by reversing a predefined noising process that gradually corrupts the original data.
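The noising step described above can be written down in a few lines. Below is a minimal numpy sketch of the forward process, assuming a standard linear DDPM noise schedule; the schedule PreDiff actually uses is defined in its configuration, and the toy latent layout (n, t, h, w, c) mirrors the training code later in this tutorial.

import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)       # linear noise schedule (an assumption, for illustration)
alpha_bar = np.cumprod(1.0 - betas)      # cumulative signal-retention factor

def q_sample(z0, t, rng=np.random.default_rng(0)):
    """Sample z_t ~ q(z_t | z_0): scale the clean latent and mix in Gaussian noise."""
    eps = rng.standard_normal(z0.shape)
    z_t = np.sqrt(alpha_bar[t]) * z0 + np.sqrt(1.0 - alpha_bar[t]) * eps
    return z_t, eps

z0 = np.zeros((1, 8, 16, 16, 4))         # toy latent with layout (n, t, h, w, c)
z_t, eps = q_sample(z0, t=500)           # the denoiser is trained to recover eps from z_t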
Technology Path
The workflow of MindSpore Earth for solving this problem is as follows:
Create the dataset
Model construction
Loss function
Model training
Model evaluation and visualization
The dataset can be downloaded from PreDiff/dataset and saved locally.
[1]:
import time
import os
import random
import json
from typing import Sequence, Union
import numpy as np
from einops import rearrange
import mindspore as ms
from mindspore import set_seed, context, ops, nn, mint
from mindspore.experimental import optim
from mindspore.train.serialization import save_checkpoint
The src package below can be downloaded from PreDiff/src.
[2]:
from mindearth.utils import load_yaml_config
from src import (
    prepare_output_directory,
    configure_logging_system,
    prepare_dataset,
    init_model,
    PreDiffModule,
    DiffusionTrainer,
    DiffusionInferrence
)
from src.sevir_dataset import SEVIRDataset
from src.visual import vis_sevir_seq
from src.utils import warmup_lambda
[3]:
set_seed(0)
np.random.seed(0)
random.seed(0)
Model, data, optimizer, and other parameters can be configured in the configuration file.
[4]:
config = load_yaml_config("./configs/diffusion.yaml")
context.set_context(mode=ms.PYNATIVE_MODE)
ms.set_device(device_target="Ascend", device_id=0)
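The loaded config is a plain nested dictionary. The sections referenced later in this tutorial (optim, summary, eval) can be inspected directly; the exact keys and values depend on your diffusion.yaml:

print(config["optim"]["max_epochs"])                      # epochs run by DiffusionTrainer
print(config["summary"].get("summary_dir", "./summary"))  # checkpoint and visualization directory
print(config["eval"].get("num_samples_per_context", 1))   # samples drawn per test context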
Model Construction
Model initialization mainly involves initializing the variational autoencoder and the Earthformer.
[5]:
main_module = PreDiffModule(oc_file="./configs/diffusion.yaml")
main_module = init_model(module=main_module, config=config, mode="train")
output_dir = prepare_output_directory(config, "0")
logger = configure_logging_system(output_dir, config)
2025-04-07 10:32:11,466 - utils.py[line:820] - INFO: Process ID: 2231351
2025-04-07 10:32:11,467 - utils.py[line:821] - INFO: {'summary_dir': './summary/prediff/single_device0', 'eval_interval': 10, 'save_ckpt_epochs': 1, 'keep_ckpt_max': 100, 'ckpt_path': '', 'load_ckpt': False}
NoisyCuboidTransformerEncoder param_not_load: []
Cleared previous output directory: ./summary/prediff/single_device0
Create the Dataset
Download the sevir-lr dataset to the ./dataset directory.
[6]:
dm, total_num_steps = prepare_dataset(config, PreDiffModule)
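As a quick sanity check, one training batch can be pulled from the data module to confirm the expected "vil" key; after SEVIRDataset processing, the layout is NHWT. The exact shape depends on your dataset configuration. A small sketch:

# Peek at one batch from the training set (inspection only)
batch = next(dm.sevir_train.create_dict_iterator())
print(batch["vil"].shape)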
Loss Function
PreDiff training uses MSE as the loss, applies gradient clipping, and encapsulates the whole procedure in DiffusionTrainer.
[7]:
class DiffusionTrainer(nn.Cell):
    """
    Class managing the training pipeline for diffusion models. Handles dataset processing,
    optimizer configuration, gradient clipping, checkpoint saving, and logging.
    """

    def __init__(self, main_module, dm, logger, config):
        """
        Initialize trainer with model, data module, logger, and configuration.

        Args:
            main_module: Main diffusion model to be trained
            dm: Data module providing training dataset
            logger: Logging utility for training progress
            config: Configuration dictionary containing hyperparameters
        """
        super().__init__()
        self.main_module = main_module
        self.traindataset = dm.sevir_train
        self.logger = logger
        self.datasetprocessing = SEVIRDataset(
            data_types=["vil"],
            layout="NHWT",
            rescale_method=config.get("rescale_method", "01"),
        )
        self.example_save_dir = config["summary"].get("summary_dir", "./summary")
        self.fs = config["eval"].get("fs", 20)
        self.label_offset = config["eval"].get("label_offset", [-0.5, 0.5])
        self.label_avg_int = config["eval"].get("label_avg_int", False)
        self.current_epoch = 0
        self.learn_logvar = (
            config.get("model", {}).get("diffusion", {}).get("learn_logvar", False)
        )
        self.logvar = main_module.logvar
        self.maeloss = nn.MAELoss()
        self.optim_config = config["optim"]
        self.clip_norm = config.get("clip_norm", 2.0)
        self.ckpt_dir = os.path.join(self.example_save_dir, "ckpt")
        self.keep_ckpt_max = config["summary"].get("keep_ckpt_max", 100)
        self.ckpt_history = []
        self.grad_clip_fn = ops.clip_by_global_norm
        # A fixed-learning-rate Adam optimizer is used here; the AdamW + warmup/cosine
        # schedule built by _get_optimizer below is provided as an alternative.
        self.optimizer = nn.Adam(
            params=self.main_module.main_model.trainable_params(), learning_rate=0.00001
        )
        os.makedirs(self.ckpt_dir, exist_ok=True)

    def train(self, total_steps: int):
        """Execute complete training pipeline."""
        self.main_module.main_model.set_train(True)
        self.logger.info("Initializing training process...")
        loss_processor = Trainonestepforward(self.main_module)
        grad_func = ms.ops.value_and_grad(
            loss_processor, None, self.main_module.main_model.trainable_params()
        )
        for epoch in range(self.optim_config["max_epochs"]):
            epoch_loss = 0.0
            epoch_start = time.time()
            iterator = self.traindataset.create_dict_iterator()
            assert iterator, "dataset is empty"
            batch_idx = 0
            for batch_idx, batch in enumerate(iterator):
                processed_data = self.datasetprocessing.process_data(batch["vil"])
                loss_value, gradients = grad_func(processed_data)
                clipped_grads = self.grad_clip_fn(gradients, self.clip_norm)
                self.optimizer(clipped_grads)
                epoch_loss += loss_value.asnumpy()
                self.logger.info(f"epoch: {epoch} step: {batch_idx}, loss: {loss_value}")
            self._save_ckpt(epoch)
            epoch_time = time.time() - epoch_start
            self.logger.info(
                f"Epoch {epoch} completed in {epoch_time:.2f}s | "
                f"Avg Loss: {epoch_loss/(batch_idx+1):.4f}"
            )

    def _get_optimizer(self, total_steps: int):
        """Configure optimization components"""
        trainable_params = list(self.main_module.main_model.trainable_params())
        if self.learn_logvar:
            self.logger.info("Including log variance parameters")
            trainable_params.append(self.logvar)
        optimizer = optim.AdamW(
            trainable_params,
            lr=self.optim_config["lr"],
            betas=tuple(self.optim_config["betas"]),
        )
        warmup_steps = int(self.optim_config["warmup_percentage"] * total_steps)
        scheduler = self._create_lr_scheduler(optimizer, total_steps, warmup_steps)
        return optimizer, scheduler

    def _create_lr_scheduler(self, optimizer, total_steps: int, warmup_steps: int):
        """Build learning rate scheduler"""
        warmup_scheduler = optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=warmup_lambda(
                warmup_steps=warmup_steps,
                min_lr_ratio=self.optim_config["warmup_min_lr_ratio"],
            ),
        )
        cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=total_steps - warmup_steps,
            eta_min=self.optim_config["min_lr_ratio"] * self.optim_config["lr"],
        )
        return optim.lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, cosine_scheduler],
            milestones=[warmup_steps],
        )

    def _save_ckpt(self, epoch: int):
        """Save model ckpt with rotation policy"""
        ckpt_file = f"diffusion_epoch{epoch}.ckpt"
        ckpt_path = os.path.join(self.ckpt_dir, ckpt_file)
        save_checkpoint(self.main_module.main_model, ckpt_path)
        self.ckpt_history.append(ckpt_path)
        if len(self.ckpt_history) > self.keep_ckpt_max:
            removed_ckpt = self.ckpt_history.pop(0)
            os.remove(removed_ckpt)
            self.logger.info(f"Removed outdated ckpt: {removed_ckpt}")


class Trainonestepforward(nn.Cell):
    """A neural network cell that performs one training step forward pass for a diffusion model.
    This class encapsulates the forward pass computation for training a diffusion model,
    handling the input processing, latent space encoding, conditioning, and loss calculation.

    Args:
        model (nn.Cell): The main diffusion model containing the necessary submodules
            for encoding, conditioning, and loss computation.
    """

    def __init__(self, model):
        super().__init__()
        self.main_module = model

    def construct(self, inputs):
        """Perform one forward training step and compute the loss."""
        x, condition = self.main_module.get_input(inputs)
        x = x.transpose(0, 1, 4, 2, 3)               # (n, t, h, w, c) -> (n, t, c, h, w)
        n, t_, c_, h, w = x.shape
        x = x.reshape(n * t_, c_, h, w)              # fold time into batch for the VAE encoder
        z = self.main_module.encode_first_stage(x)
        _, c_z, h_z, w_z = z.shape
        z = z.reshape(n, -1, c_z, h_z, w_z)
        z = z.transpose(0, 1, 3, 4, 2)               # back to (n, t, h, w, c) in latent space
        t = ops.randint(0, self.main_module.num_timesteps, (n,)).long()  # random diffusion timestep
        zc = self.main_module.cond_stage_forward(condition)
        loss = self.main_module.p_losses(z, zc, t, noise=None)
        return loss
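Trainonestepforward delegates the actual objective to p_losses. Below is a minimal sketch of what such an objective conventionally computes in a DDPM-style latent diffusion model: noise the latent, predict the injected noise, and penalize with MSE. The names denoise_fn and alpha_bar are illustrative placeholders, not the actual src API.

from mindspore import ops

def sketch_p_losses(denoise_fn, z0, zc, t, alpha_bar):
    """z0: clean latent, zc: encoded condition, t: sampled timesteps, alpha_bar: cumulative schedule."""
    noise = ops.randn_like(z0)                           # eps ~ N(0, I)
    a = alpha_bar[t].reshape(-1, 1, 1, 1, 1)             # broadcast over the (n, t, h, w, c) layout
    z_t = ops.sqrt(a) * z0 + ops.sqrt(1.0 - a) * noise   # forward process q(z_t | z_0)
    eps_pred = denoise_fn(z_t, t, zc)                    # Earthformer-UNet predicts the noise
    return ops.mse_loss(eps_pred, noise)                 # MSE between true and predicted noise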
Model Training
In this tutorial, we train the model with DiffusionTrainer.
[8]:
trainer = DiffusionTrainer(
    main_module=main_module, dm=dm, logger=logger, config=config
)
trainer.train(total_steps=total_num_steps)
2025-04-07 10:32:36,351 - 4106154625.py[line:46] - INFO: Initializing training process...
.........2025-04-07 10:34:09,378 - 4106154625.py[line:64] - INFO: epoch: 0 step: 0, loss: 1.0008465
.2025-04-07 10:34:16,871 - 4106154625.py[line:64] - INFO: epoch: 0 step: 1, loss: 1.0023363
2025-04-07 10:34:18,724 - 4106154625.py[line:64] - INFO: epoch: 0 step: 2, loss: 1.0009086
.2025-04-07 10:34:20,513 - 4106154625.py[line:64] - INFO: epoch: 0 step: 3, loss: 0.99787366
2025-04-07 10:34:22,280 - 4106154625.py[line:64] - INFO: epoch: 0 step: 4, loss: 0.9979043
2025-04-07 10:34:24,072 - 4106154625.py[line:64] - INFO: epoch: 0 step: 5, loss: 0.99897844
2025-04-07 10:34:25,864 - 4106154625.py[line:64] - INFO: epoch: 0 step: 6, loss: 1.0021904
2025-04-07 10:34:27,709 - 4106154625.py[line:64] - INFO: epoch: 0 step: 7, loss: 0.9984627
2025-04-07 10:34:29,578 - 4106154625.py[line:64] - INFO: epoch: 0 step: 8, loss: 0.9952746
.2025-04-07 10:34:31,432 - 4106154625.py[line:64] - INFO: epoch: 0 step: 9, loss: 1.0003254
2025-04-07 10:34:33,402 - 4106154625.py[line:64] - INFO: epoch: 0 step: 10, loss: 1.0020428
2025-04-07 10:34:35,218 - 4106154625.py[line:64] - INFO: epoch: 0 step: 11, loss: 0.99563503
2025-04-07 10:34:37,149 - 4106154625.py[line:64] - INFO: epoch: 0 step: 12, loss: 0.99336195
2025-04-07 10:34:38,949 - 4106154625.py[line:64] - INFO: epoch: 0 step: 13, loss: 1.0023757
......2025-04-07 13:39:55,859 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1247, loss: 0.021378823
2025-04-07 13:39:57,754 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1248, loss: 0.01565772
2025-04-07 13:39:59,606 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1249, loss: 0.012067624
2025-04-07 13:40:01,396 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1250, loss: 0.017700804
2025-04-07 13:40:03,181 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1251, loss: 0.06254268
2025-04-07 13:40:04,945 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1252, loss: 0.013293369
.2025-04-07 13:40:06,770 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1253, loss: 0.026906993
2025-04-07 13:40:08,644 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1254, loss: 0.18210539
2025-04-07 13:40:10,593 - 4106154625.py[line:64] - INFO: epoch: 4 step: 1255, loss: 0.024170894
2025-04-07 13:40:12,430 - 4106154625.py[line:69] - INFO: Epoch 4 completed in 2274.61s | Avg Loss: 0.0517
Model Evaluation and Visualization
After training, we use the fifth checkpoint for inference. The output below shows the error between predictions and ground truth, along with the evaluation metrics.
[14]:
def get_alignment_kwargs_avg_x(target_seq):
    """Generate alignment parameters for guided sampling"""
    batch_size = target_seq.shape[0]
    avg_intensity = mint.mean(target_seq.view(batch_size, -1), dim=1, keepdim=True)
    return {"avg_x_gt": avg_intensity * 2.0}


class DiffusionInferrence(nn.Cell):
    """
    Class managing model inference and evaluation processes. Handles loading checkpoints,
    generating predictions, calculating evaluation metrics, and saving visualization results.
    """

    def __init__(self, main_module, dm, logger, config):
        """
        Initialize inference manager with model, data module, logger, and configuration.

        Args:
            main_module: Main diffusion model for inference
            dm: Data module providing test dataset
            logger: Logging utility for evaluation progress
            config: Configuration dictionary containing evaluation parameters
        """
        super().__init__()
        self.num_samples = config["eval"].get("num_samples_per_context", 1)
        self.eval_example_only = config["eval"].get("eval_example_only", True)
        self.alignment_type = (
            config.get("model", {}).get("align", {}).get("alignment_type", "avg_x")
        )
        self.use_alignment = self.alignment_type is not None
        self.eval_aligned = config["eval"].get("eval_aligned", True)
        self.eval_unaligned = config["eval"].get("eval_unaligned", True)
        self.num_samples_per_context = config["eval"].get("num_samples_per_context", 1)
        self.logging_prefix = config["logging"].get("logging_prefix", "PreDiff")
        self.test_example_data_idx_list = [48]
        self.main_module = main_module
        self.testdataset = dm.sevir_test
        self.logger = logger
        self.datasetprocessing = SEVIRDataset(
            data_types=["vil"],
            layout="NHWT",
            rescale_method=config.get("rescale_method", "01"),
        )
        self.example_save_dir = config["summary"].get("summary_dir", "./summary")
        self.fs = config["eval"].get("fs", 20)
        self.label_offset = config["eval"].get("label_offset", [-0.5, 0.5])
        self.label_avg_int = config["eval"].get("label_avg_int", False)
        self.current_epoch = 0
        self.learn_logvar = (
            config.get("model", {}).get("diffusion", {}).get("learn_logvar", False)
        )
        self.logvar = main_module.logvar
        self.maeloss = nn.MAELoss()
        self.test_metrics = {
            "step": 0,
            "mse": 0.0,
            "mae": 0.0,
            "ssim": 0.0,
            "mse_kc": 0.0,
            "mae_kc": 0.0,
        }

    def test(self):
        """Execute complete evaluation pipeline."""
        self.logger.info("============== Start Test ==============")
        self.start_time = time.time()
        for batch_idx, item in enumerate(self.testdataset.create_dict_iterator()):
            self.test_metrics = self._test_onestep(item, batch_idx, self.test_metrics)
        self._finalize_test(self.test_metrics)

    def _test_onestep(self, item, batch_idx, metrics):
        """Process one test batch and update evaluation metrics."""
        data_idx = int(batch_idx * 2)
        if not self._should_test_onestep(data_idx):
            return metrics
        data = item.get("vil")
        data = self.datasetprocessing.process_data(data)
        target_seq, cond, context_seq = self._get_model_inputs(data)
        aligned_preds, unaligned_preds = self._generate_predictions(cond, target_seq)
        metrics = self._update_metrics(aligned_preds, unaligned_preds, target_seq, metrics)
        self._plt_pred(
            data_idx,
            context_seq,
            target_seq,
            aligned_preds,
            unaligned_preds,
            metrics["step"],
        )
        metrics["step"] += 1
        return metrics

    def _should_test_onestep(self, data_idx):
        """Determine if evaluation should be performed on current data index."""
        return (not self.eval_example_only) or (data_idx in self.test_example_data_idx_list)

    def _get_model_inputs(self, data):
        """Extract and prepare model inputs from raw data."""
        target_seq, cond, context_seq = self.main_module.get_input(data, return_verbose=True)
        return target_seq, cond, context_seq

    def _generate_predictions(self, cond, target_seq):
        """Generate both aligned and unaligned predictions from the model."""
        aligned_preds = []
        unaligned_preds = []
        for _ in range(self.num_samples_per_context):
            if self.use_alignment and self.eval_aligned:
                aligned_pred = self._sample_with_alignment(cond, target_seq)
                aligned_preds.append(aligned_pred)
            if self.eval_unaligned:
                unaligned_pred = self._sample_without_alignment(cond)
                unaligned_preds.append(unaligned_pred)
        return aligned_preds, unaligned_preds

    def _sample_with_alignment(self, cond, target_seq):
        """Generate predictions using alignment mechanism."""
        alignment_kwargs = get_alignment_kwargs_avg_x(target_seq)
        pred_seq = self.main_module.sample(
            cond=cond,
            batch_size=cond["y"].shape[0],
            return_intermediates=False,
            use_alignment=True,
            alignment_kwargs=alignment_kwargs,
            verbose=False,
        )
        if pred_seq.dtype != ms.float32:
            pred_seq = pred_seq.float()
        return pred_seq

    def _sample_without_alignment(self, cond):
        """Generate predictions without alignment."""
        pred_seq = self.main_module.sample(
            cond=cond,
            batch_size=cond["y"].shape[0],
            return_intermediates=False,
            verbose=False,
        )
        if pred_seq.dtype != ms.float32:
            pred_seq = pred_seq.float()
        return pred_seq

    def _update_metrics(self, aligned_preds, unaligned_preds, target_seq, metrics):
        """Update evaluation metrics with new predictions."""
        for pred in aligned_preds:
            metrics["mse_kc"] += ops.mse_loss(pred, target_seq)
            metrics["mae_kc"] += self.maeloss(pred, target_seq)
            self.main_module.test_aligned_score.update(pred, target_seq)
        for pred in unaligned_preds:
            metrics["mse"] += ops.mse_loss(pred, target_seq)
            metrics["mae"] += self.maeloss(pred, target_seq)
            self.main_module.test_score.update(pred, target_seq)
            pred_bchw = self._convert_to_bchw(pred)
            target_bchw = self._convert_to_bchw(target_seq)
            metrics["ssim"] += self.main_module.test_ssim(pred_bchw, target_bchw)[0]
        return metrics

    def _convert_to_bchw(self, tensor):
        """Convert tensor to batch-channel-height-width format for metrics."""
        return rearrange(tensor.asnumpy(), "b t h w c -> (b t) c h w")

    def _plt_pred(self, data_idx, context_seq, target_seq, aligned_preds, unaligned_preds, step):
        """Generate and save visualization of predictions."""
        pred_sequences = [pred[0].asnumpy() for pred in aligned_preds + unaligned_preds]
        pred_labels = [
            f"{self.logging_prefix}_aligned_pred_{i}" for i in range(len(aligned_preds))
        ] + [f"{self.logging_prefix}_pred_{i}" for i in range(len(unaligned_preds))]
        self.save_vis_step_end(
            data_idx=data_idx,
            context_seq=context_seq[0].asnumpy(),
            target_seq=target_seq[0].asnumpy(),
            pred_seq=pred_sequences,
            pred_label=pred_labels,
            mode="test",
            suffix=f"_step_{step}",
        )

    def _finalize_test(self, metrics):
        """Complete test process and log final metrics."""
        total_time = (time.time() - self.start_time) * 1000
        self.logger.info(f"test cost: {total_time:.2f} ms")
        self._compute_total_metrics(metrics)
        self.logger.info("============== Test Completed ==============")

    def _compute_total_metrics(self, metrics):
        """log_metrics"""
        step_count = max(metrics["step"], 1)
        if self.eval_unaligned:
            self.logger.info(f"MSE: {metrics['mse'] / step_count}")
            self.logger.info(f"MAE: {metrics['mae'] / step_count}")
            self.logger.info(f"SSIM: {metrics['ssim'] / step_count}")
            test_score = self.main_module.test_score.eval()
            self.logger.info("SCORE:\n%s", json.dumps(test_score, indent=4))
        if self.use_alignment:
            self.logger.info(f"KC_MSE: {metrics['mse_kc'] / step_count}")
            self.logger.info(f"KC_MAE: {metrics['mae_kc'] / step_count}")
            aligned_score = self.main_module.test_aligned_score.eval()
            self.logger.info("KC_SCORE:\n%s", json.dumps(aligned_score, indent=4))

    def save_vis_step_end(
        self,
        data_idx: int,
        context_seq: np.ndarray,
        target_seq: np.ndarray,
        pred_seq: Union[np.ndarray, Sequence[np.ndarray]],
        pred_label: Union[str, Sequence[str]] = None,
        mode: str = "train",
        prefix: str = "",
        suffix: str = "",
    ):
        """Save visualization of predictions with context and target."""
        example_data_idx_list = self.test_example_data_idx_list
        if isinstance(pred_seq, Sequence):
            seq_list = [context_seq, target_seq] + list(pred_seq)
            label_list = ["context", "target"] + pred_label
        else:
            seq_list = [context_seq, target_seq, pred_seq]
            label_list = ["context", "target", pred_label]
        if data_idx in example_data_idx_list:
            png_save_name = f"{prefix}{mode}_data_{data_idx}{suffix}.png"
            vis_sevir_seq(
                save_path=os.path.join(self.example_save_dir, png_save_name),
                seq=seq_list,
                label=label_list,
                interval_real_time=10,
                plot_stride=1,
                fs=self.fs,
                label_offset=self.label_offset,
                label_avg_int=self.label_avg_int,
            )
[15]:
main_module.main_model.set_train(False)
params = ms.load_checkpoint("/PreDiff/summary/prediff/single_device0/ckpt/diffusion_epoch4.ckpt")
# load_param_into_net returns the net parameters not found in the checkpoint (a)
# and the checkpoint parameters not loaded into the net (b); empty lists mean an exact match
a, b = ms.load_param_into_net(main_module.main_model, params)
print(b)
tester = DiffusionInferrence(
    main_module=main_module, dm=dm, logger=logger, config=config
)
tester.test()
2025-04-07 14:04:16,558 - 2610859736.py[line:66] - INFO: ============== Start Test ==============
[]
..2025-04-07 14:10:31,931 - 2610859736.py[line:201] - INFO: test cost: 375371.60 ms
2025-04-07 14:10:31,937 - 2610859736.py[line:215] - INFO: KC_MSE: 0.0036273836
2025-04-07 14:10:31,939 - 2610859736.py[line:216] - INFO: KC_MAE: 0.017427118
2025-04-07 14:10:31,955 - 2610859736.py[line:218] - INFO: KC_SCORE:
{
"16": {
"csi": 0.2715393900871277,
"pod": 0.5063194632530212,
"sucr": 0.369321346282959,
"bias": 3.9119162559509277
},
"74": {
"csi": 0.15696434676647186,
"pod": 0.17386901378631592,
"sucr": 0.6175059080123901,
"bias": 0.16501028835773468
}
}
2025-04-07 14:10:31,956 - 2610859736.py[line:203] - INFO: ============== Test Completed ==============
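The csi, pod, sucr, and bias values in KC_SCORE are standard threshold-based skill scores for precipitation nowcasting, reported here at the pixel thresholds 16 and 74. Below is a self-contained numpy sketch of the conventional contingency-table definitions; the actual computation lives in main_module.test_score and test_aligned_score, and may differ in details.

import numpy as np

def categorical_scores(pred, target, threshold):
    """Contingency-table scores for a binary exceedance event at a given threshold."""
    hits = np.sum((pred >= threshold) & (target >= threshold))
    misses = np.sum((pred < threshold) & (target >= threshold))
    false_alarms = np.sum((pred >= threshold) & (target < threshold))
    eps = 1e-8                                                  # guard against empty categories
    return {
        "csi": hits / (hits + misses + false_alarms + eps),     # critical success index
        "pod": hits / (hits + misses + eps),                    # probability of detection
        "sucr": hits / (hits + false_alarms + eps),             # success ratio
        "bias": (hits + false_alarms) / (hits + misses + eps),  # frequency bias
    }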