
This article was first published on Zhihu:
https://zhuanlan.zhihu/p/676305234

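Related article:
Run GraphCast with one click (on AutoDL or any other fresh environment) [Hands-on]
GraphCast is a weather forecasting system based on machine learning and graph neural networks (GNNs). It has been tested by meteorological agencies, including the European Centre for Medium-Range Weather Forecasts (ECMWF). This advanced AI model produces medium-range weather forecasts with unprecedented accuracy: it can forecast weather conditions up to 10 days ahead, more accurately and faster than the industry's gold-standard simulation system, the High-Resolution Forecast (HRES) produced by ECMWF.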

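The model cleverly replaces the Earth's latitude-longitude grid with a multi-mesh built by recursively refining a regular icosahedron six times. At the same (0.25°) resolution, the number of nodes the GNN processes drops from about one million (1,038,240 grid points) to about 40,000 (40,962 mesh nodes), which lets the model learn from large-scale, multi-feature data within a GNN framework.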
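Both node counts can be checked with a short sketch. The code below is an illustrative stand-alone reconstruction, not GraphCast's own mesh code: it builds a regular icosahedron, splits every triangle into four by inserting edge midpoints projected back onto the unit sphere, and counts the vertices after each refinement. `icosahedron` and `refine` are hypothetical helper names. It also reproduces the 0.25° latitude-longitude count as 721 × 1440.

```python
import numpy as np


def icosahedron():
    """Return the 12 vertices (on the unit sphere) and 20 faces of a regular icosahedron."""
    phi = (1 + 5 ** 0.5) / 2  # golden ratio
    raw = [(-1, phi, 0), (1, phi, 0), (-1, -phi, 0), (1, -phi, 0),
           (0, -1, phi), (0, 1, phi), (0, -1, -phi), (0, 1, -phi),
           (phi, 0, -1), (phi, 0, 1), (-phi, 0, -1), (-phi, 0, 1)]
    verts = [np.array(v) / np.linalg.norm(v) for v in raw]
    faces = [(0, 11, 5), (0, 5, 1), (0, 1, 7), (0, 7, 10), (0, 10, 11),
             (1, 5, 9), (5, 11, 4), (11, 10, 2), (10, 7, 6), (7, 1, 8),
             (3, 9, 4), (3, 4, 2), (3, 2, 6), (3, 6, 8), (3, 8, 9),
             (4, 9, 5), (2, 4, 11), (6, 2, 10), (8, 6, 7), (9, 8, 1)]
    return verts, faces


def refine(verts, faces):
    """Split each triangle into 4; a midpoint shared by two faces is created only once."""
    verts = list(verts)
    midpoint_index = {}

    def midpoint(i, j):
        key = (min(i, j), max(i, j))
        if key not in midpoint_index:
            m = verts[i] + verts[j]
            verts.append(m / np.linalg.norm(m))  # project back onto the sphere
            midpoint_index[key] = len(verts) - 1
        return midpoint_index[key]

    new_faces = []
    for a, b, c in faces:
        ab, bc, ca = midpoint(a, b), midpoint(b, c), midpoint(c, a)
        new_faces += [(a, ab, ca), (b, bc, ab), (c, ca, bc), (ab, bc, ca)]
    return verts, new_faces


verts, faces = icosahedron()
node_counts = [len(verts)]
for _ in range(6):
    verts, faces = refine(verts, faces)
    node_counts.append(len(verts))

latlon_nodes = (180 * 4 + 1) * (360 * 4)  # 0.25 deg grid: 721 latitudes x 1440 longitudes
print(node_counts)   # [12, 42, 162, 642, 2562, 10242, 40962]
print(latlon_nodes)  # 1038240
```

After r refinements the vertex count follows 10·4^r + 2, so six refinements give the 40,962 nodes quoted above.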

Figure 1: The mesh after 6 refinement iterations


Figure 2: Weather forecast generated by the example code

Citation:
Remi Lam et al., Learning skillful medium-range global weather forecasting. Science 382, 1416-1421 (2023). DOI: 10.1126/science.adi2336


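This project is a reproduction of https://google-deepmind/graphcast, rewritten and debugged by https://github/sfsun67.

AutoDL is a cloud computing platform in China: https://www.autodl

Your directory should have a structure similar to the one below. The data come from the Google Cloud Bucket (https://console.cloud.google/storage/browser/dm_graphcast): model weights, normalization statistics and example inputs are all available there. Full model training requires downloading the ERA5 dataset, which is available from ECMWF.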

.
├── code
│   ├── GraphCast-from-Ground-Zero
│       ├──graphcast
│       ├──tree
│       ├──wrapt
│       ├──graphcast_demo.ipynb
│       ├──README.md
│       ├──setup.py
│       ├──...
├── data
│   ├── dataset
│       ├──dataset-source-era5_date-2022-01-01_res-1.0_levels-13_steps-01.nc
│       ├──dataset-source-era5_date-2022-01-01_res-1.0_levels-13_steps-04.nc
│       ├──dataset-source-era5_date-2022-01-01_res-1.0_levels-13_steps-12.nc
│       ├──...
│   ├── params
│       ├──params-GraphCast - ERA5 1979-2017 - resolution 0.25 - pressure levels 37 - mesh 2to6 - precipitation input and output.npz
│       ├──params-GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz
│       ├──...
│   ├── stats
│       ├──stats-mean_by_level.nc
│       ├──...
└────── 
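Before running the notebook it can help to confirm this layout. The sketch below is a hypothetical helper (`check_data_layout` is not part of GraphCast); it assumes the `data/` root shown above and checks for at least one dataset file, at least one parameter file, and the three normalization files that the notebook loads later.

```python
import glob
import os

# The three statistics files opened when the model is built.
REQUIRED_STATS = [
    "stats-diffs_stddev_by_level.nc",
    "stats-mean_by_level.nc",
    "stats-stddev_by_level.nc",
]


def check_data_layout(data_root):
    """Return a list of human-readable problems; an empty list means the layout looks usable."""
    problems = []
    if not glob.glob(os.path.join(data_root, "dataset", "dataset-*.nc")):
        problems.append("no dataset-*.nc files in dataset/")
    if not glob.glob(os.path.join(data_root, "params", "params-*.npz")):
        problems.append("no params-*.npz files in params/")
    for name in REQUIRED_STATS:
        if not os.path.isfile(os.path.join(data_root, "stats", name)):
            problems.append(f"missing stats/{name}")
    return problems


if __name__ == "__main__":
    # Replace with your own data directory.
    for problem in check_data_layout("/root/autodl-fs/data"):
        print("WARNING:", problem)
```

It only checks file names, not whether the files are complete or readable.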

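PS:

  1. Use Python 3.10. Older versions break some function calls (for example, str.removeprefix, which is used below, only exists from Python 3.9 on).
  2. Check package versions carefully to avoid unexpected errors. For example, xarray must be version 2023.7.0; other versions raise errors.
  3. Check carefully that every package is installed correctly, because a missing package causes unexpected errors. For example, tree and wrapt are required by GraphCast but are not included in the source files, and if the compiled .so files in tree and wrapt are not imported, a circular import occurs. The original files can be found in the Colaboratory environment (https://colab.research.google/github/deepmind/graphcast/blob/master/graphcast_demo.ipynb).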
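The notes above can be turned into a quick environment check. This is a minimal sketch under stated assumptions: the pins come from this article, the `tree` package is assumed to be distributed under the name `dm-tree`, and `installed_versions` and `find_problems` are hypothetical helpers. It only reads installed package metadata through `importlib.metadata`, so it cannot tell whether the compiled parts of tree and wrapt actually import; running `import tree, wrapt` remains the real test.

```python
import sys
from importlib import metadata

MIN_PYTHON = (3, 10)
PINNED = {"xarray": "2023.7.0"}               # version pin recommended in this article
REQUIRED = ["graphcast", "dm-tree", "wrapt"]  # assumption: "tree" ships as dm-tree


def installed_versions(names):
    """Map each distribution name to its installed version, or None if it is missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def find_problems(python_version, versions):
    """Compare a Python version tuple and a {name: version} dict against the notes above."""
    problems = []
    if tuple(python_version[:2]) < MIN_PYTHON:
        problems.append(f"Python {python_version[0]}.{python_version[1]} < 3.10")
    for name, wanted in PINNED.items():
        if versions.get(name) != wanted:
            problems.append(f"{name}=={versions.get(name)} (expected {wanted})")
    for name in REQUIRED:
        if versions.get(name) is None:
            problems.append(f"{name} is not installed")
    return problems


versions = installed_versions(list(PINNED) + REQUIRED)
for problem in find_problems(sys.version_info, versions):
    print("WARNING:", problem)
```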

The code was tested on the following machines (the last field is the AutoDL image: framework / version / Python (OS) / CUDA):
  1. GPU: TITAN Xp 12GB; CPU: Xeon® E5-2680 v4; JAX / 0.3.10 / 3.8(ubuntu18.04) / 11.1
  2. GPU: V100-SXM2-32GB 32GB; CPU: Xeon® Platinum 8255C; JAX / 0.3.10 / 3.8(ubuntu18.04) / 11.1
  3. GPU: RTX 2080 Ti(11GB); CPU: Xeon® Platinum 8255C; JAX / 0.3.10 / 3.8(ubuntu18.04) / 11.1

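Original copyright notice:

Copyright 2023 DeepMind Technologies Limited.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.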

-------------------------------------------------------------------

Update Python to 3.10

GraphCast requires Python >= 3.10; Python 3.10 is recommended.

In a terminal, create a new environment named GraphCast.

Reference code:

# Update conda (optional)
conda update -n base -c defaults conda

# Create a new environment named GraphCast with python=3.10
conda create -n GraphCast python=3.10

# Update the environment variables in .bashrc
conda init bash && source /root/.bashrc

# Activate the new environment
conda activate GraphCast

# Verify the version
python --version

# Register the Python 3.10 environment with Jupyter
# Install the ipykernel package
conda install ipykernel

# Register the Python 3.10 environment as a kernel under this name
python -m ipykernel install --user --name=GraphCast-python3.10

Note: after registering the Python 3.10 environment with Jupyter, restart Jupyter and switch to the new kernel GraphCast-python3.10.

Installation and initialization

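# AutoDL "academic resource acceleration" (network proxy): https://www.autodl/docs/network_turbo/  .
import subprocess
import os

# Source AutoDL's proxy script in a bash subshell and print the resulting proxy variables.
result = subprocess.run('bash -c "source /etc/network_turbo && env | grep proxy"', shell=True, capture_output=True, text=True)
output = result.stdout
# Copy each NAME=value pair into this kernel's environment so that later downloads use the proxy.
for line in output.splitlines():
    if '=' in line:
        var, value = line.split('=', 1)
        os.environ[var] = value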

This step installs shapely through conda, to avoid the error "ERROR: Could not build wheels for shapely, which is required to install pyproject.toml-based projects".

!pip uninstall -y shapely
!conda install -y shapely
!pip uninstall -y shapely

Install graphcast and the other dependencies with pip

%pip install --upgrade https://github.com/deepmind/graphcast/archive/master.zip

Workaround for cartopy crashes

!pip uninstall -y shapely
!pip install shapely --no-binary shapely

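Install the remaining dependencies and fix the xarray version problem.

Here xarray has to be downgraded from 2023.12.0 (the version installed on December 30, 2023) to 2023.7.0; otherwise errors occur.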

!conda install -y -c conda-forge ipywidgets
!pip uninstall -y xarray
!pip install xarray==2023.7.0

Import libraries

import dataclasses
import datetime
import functools
import math
import re
from typing import Optional

import cartopy.crs as ccrs
#from google.cloud import storage
from graphcast import autoregressive
from graphcast import casting
from graphcast import checkpoint
from graphcast import data_utils
from graphcast import graphcast
from graphcast import normalization
from graphcast import rollout
from graphcast import xarray_jax
from graphcast import xarray_tree
from IPython.display import HTML
import ipywidgets as widgets
import haiku as hk
import jax
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np
import xarray




def parse_file_parts(file_name):
  return dict(part.split("-", 1) for part in file_name.split("_"))
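`parse_file_parts` splits a file name on `_` and then each piece on its first `-`. The example below (a copy of the function plus a file name taken from the directory listing above) shows why the notebook strips the `dataset-` prefix and the `.nc` suffix before parsing: without that, the first key would be `dataset` instead of `source`.

```python
def parse_file_parts(file_name):
  return dict(part.split("-", 1) for part in file_name.split("_"))


name = "dataset-source-era5_date-2022-01-01_res-1.0_levels-13_steps-01.nc"

print(parse_file_parts(name.removeprefix("dataset-").removesuffix(".nc")))
# {'source': 'era5', 'date': '2022-01-01', 'res': '1.0', 'levels': '13', 'steps': '01'}

print(parse_file_parts(name.removesuffix(".nc"))["dataset"])
# source-era5
```

Because only the first `-` of each piece is split, dates such as `2022-01-01` survive intact.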

Load the plotting functions

def select(
    data: xarray.Dataset,
    variable: str,
    level: Optional[int] = None,
    max_steps: Optional[int] = None
    ) -> xarray.Dataset:
  data = data[variable]
  if "batch" in data.dims:
    data = data.isel(batch=0)
  if max_steps is not None and "time" in data.sizes and max_steps < data.sizes["time"]:
    data = data.isel(time=range(0, max_steps))
  if level is not None and "level" in data.coords:
    data = data.sel(level=level)
  return data

def scale(
    data: xarray.Dataset,
    center: Optional[float] = None,
    robust: bool = False,
    ) -> tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:
  vmin = np.nanpercentile(data, (2 if robust else 0))
  vmax = np.nanpercentile(data, (98 if robust else 100))
  if center is not None:
    diff = max(vmax - center, center - vmin)
    vmin = center - diff
    vmax = center + diff
  return (data, matplotlib.colors.Normalize(vmin, vmax),
          ("RdBu_r" if center is not None else "viridis"))

def plot_data(
    data: dict[str, xarray.Dataset],
    fig_title: str,
    plot_size: float = 5,
    robust: bool = False,
    cols: int = 4
    ) -> tuple[xarray.Dataset, matplotlib.colors.Normalize, str]:

  first_data = next(iter(data.values()))[0]
  max_steps = first_data.sizes.get("time", 1)
  assert all(max_steps == d.sizes.get("time", 1) for d, _, _ in data.values())

  cols = min(cols, len(data))
  rows = math.ceil(len(data) / cols)
  figure = plt.figure(figsize=(plot_size * 2 * cols,
                               plot_size * rows))
  figure.suptitle(fig_title, fontsize=16)
  figure.subplots_adjust(wspace=0, hspace=0)
  figure.tight_layout()

  images = []
  for i, (title, (plot_data, norm, cmap)) in enumerate(data.items()):
    ax = figure.add_subplot(rows, cols, i+1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
    im = ax.imshow(
        plot_data.isel(time=0, missing_dims="ignore"), norm=norm,
        origin="lower", cmap=cmap)
    plt.colorbar(
        mappable=im,
        ax=ax,
        orientation="vertical",
        pad=0.02,
        aspect=16,
        shrink=0.75,
        cmap=cmap,
        extend=("both" if robust else "neither"))
    images.append(im)

  def update(frame):
    if "time" in first_data.dims:
      td = datetime.timedelta(microseconds=first_data["time"][frame].item() / 1000)
      figure.suptitle(f"{fig_title}, {td}", fontsize=16)
    else:
      figure.suptitle(fig_title, fontsize=16)
    for im, (plot_data, norm, cmap) in zip(images, data.values()):
      im.set_data(plot_data.isel(time=frame, missing_dims="ignore"))

  ani = animation.FuncAnimation(
      fig=figure, func=update, frames=max_steps, interval=250)
  plt.close(figure.number)
  return HTML(ani.to_jshtml())
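The color limits chosen by `scale` are easiest to see on a small array. The NumPy-only sketch below re-implements just its limit logic under a hypothetical name, `color_limits`: with `robust=True` the 2nd and 98th percentiles replace the minimum and maximum, and passing `center` makes the range symmetric around that value, which is what the "Diff" panel later relies on (`center=0`).

```python
import numpy as np


def color_limits(values, center=None, robust=False):
    """Return (vmin, vmax) following the same rules as scale() above."""
    vmin = np.nanpercentile(values, 2 if robust else 0)
    vmax = np.nanpercentile(values, 98 if robust else 100)
    if center is not None:
        # Symmetric around `center`, wide enough to cover both tails.
        diff = max(vmax - center, center - vmin)
        vmin, vmax = center - diff, center + diff
    return float(vmin), float(vmax)


values = np.arange(101, dtype=float)  # 0, 1, ..., 100
print(color_limits(values))                   # (0.0, 100.0)
print(color_limits(values, robust=True))      # (2.0, 98.0)
print(color_limits(values - 10, center=0.0))  # (-90.0, 90.0)
```

The symmetric range keeps zero at the middle of the diverging `RdBu_r` colormap, so positive and negative errors get equal visual weight.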

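Load data and initialize the model

Load model parameters

Choose one of two ways to obtain the model parameters:

  • random: you get random predictions, but you can change the model architecture, which may make it run faster or fit on your device.
  • checkpoint: you get sensible predictions, but you are limited to the architecture the model was trained with, which may not fit on your device. In particular, computing gradients uses a lot of memory, so you need at least 25 GB of memory (TPUv4 or A100).

The checkpoints differ in a few ways:

  • The mesh size specifies the internal graph representation of the Earth. Smaller meshes run faster but give worse output. The mesh size does not affect the number of model parameters.
  • The resolution and number of pressure levels must match the data. Lower resolution and fewer levels run faster. The data resolution only affects the encoder/decoder.
  • All our models predict precipitation. However, ERA5 includes precipitation while HRES does not. Models labeled "ERA5" take precipitation as input and expect ERA5 data as input, while models labeled "ERA5-HRES" do not take precipitation as input and are trained specifically with HRES-fc0 as input (see the data section below).

We provide three pretrained models:

  1. GraphCast, the high-resolution model used in the GraphCast paper (0.25° resolution, 37 pressure levels), trained on ERA5 data from 1979 to 2017;

  2. GraphCast_small, a smaller, low-resolution version of GraphCast (1° resolution, 13 pressure levels and a smaller mesh), trained on ERA5 data from 1979 to 2015, useful for running the model under lower memory and compute constraints;

  3. GraphCast_operational, a high-resolution model (0.25° resolution, 13 pressure levels) pretrained on ERA5 data from 1979 to 2017 and fine-tuned on HRES data from 2016 to 2021. This model can be initialized from HRES data (it does not need precipitation as input).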
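The three checkpoints encode these properties in their file names. The hypothetical helper below, `describe_params_file`, parses such a name; it assumes the exact ` - `-separated pattern of the three files listed in this article (with or without the `params-` prefix used in the directory tree) and would need adjusting for any other naming.

```python
import re


def describe_params_file(file_name):
    """Turn a checkpoint file name into a small dict of its properties."""
    stem = file_name.removeprefix("params-").removesuffix(".npz")
    model, training_data, resolution, levels, mesh, precipitation = stem.split(" - ")
    mesh_min, mesh_max = re.fullmatch(r"mesh (\d+)to(\d+)", mesh).groups()
    return {
        "model": model,
        "training_data": training_data,
        "resolution_deg": float(resolution.removeprefix("resolution ")),
        "pressure_levels": int(levels.removeprefix("pressure levels ")),
        "mesh_levels": (int(mesh_min), int(mesh_max)),
        "precipitation_input": precipitation == "precipitation input and output",
    }


for name in [
    "GraphCast - ERA5 1979-2017 - resolution 0.25 - pressure levels 37 - mesh 2to6 - precipitation input and output.npz",
    "GraphCast_operational - ERA5-HRES 1979-2021 - resolution 0.25 - pressure levels 13 - mesh 2to6 - precipitation output only.npz",
    "GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz",
]:
    print(describe_params_file(name))
```

`precipitation_input` comes out `False` only for the operational model, matching the description above.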

Choose a model

# Rewrite by S.F. Sune, https://github/sfsun67.
'''
    Three trained models are available; download them from https://console.cloud.google/storage/browser/dm_graphcast:
    GraphCast - ERA5 1979-2017 - resolution 0.25 - pressure levels 37 - mesh 2to6 - precipitation input and output.npz
    GraphCast_operational - ERA5-HRES 1979-2021 - resolution 0.25 - pressure levels 13 - mesh 2to6 - precipitation output only.npz
    GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz
'''
# Look in /root/autodl-fs/data/params and list the names of all files in "params/", without the directory prefix.

import os
import glob

# Data directory; replace it with your own.
dir_path_params = "/root/autodl-fs/data/params"

# Use glob to get all file paths in the directory
file_paths_params = glob.glob(os.path.join(dir_path_params, "*"))

# Keep only the file name (drop the ".../params/" directory prefix)
params_file_options = [os.path.basename(path) for path in file_paths_params]


random_mesh_size = widgets.IntSlider(
    value=4, min=4, max=6, description="Mesh size:")
random_gnn_msg_steps = widgets.IntSlider(
    value=4, min=1, max=32, description="GNN message steps:")
random_latent_size = widgets.Dropdown(
    options=[int(2**i) for i in range(4, 10)], value=32,description="Latent size:")
random_levels = widgets.Dropdown(
    options=[13, 37], value=13, description="Pressure levels:")


params_file = widgets.Dropdown(
    options=params_file_options,
    description="Params file:",
    layout={"width": "max-content"})

source_tab = widgets.Tab([
    widgets.VBox([
        random_mesh_size,
        random_gnn_msg_steps,
        random_latent_size,
        random_levels,
    ]),
    params_file,
])
source_tab.set_title(0, "Random")
source_tab.set_title(1, "Checkpoint")
widgets.VBox([
    source_tab,
    widgets.Label(value="Run the next cell to load the model. Rerunning this cell clears your selection.")
])

Load the model

source = source_tab.get_title(source_tab.selected_index)

if source == "Random":
  params = None  # Filled in below
  state = {}
  model_config = graphcast.ModelConfig(
      resolution=0,
      mesh_size=random_mesh_size.value,
      latent_size=random_latent_size.value,
      gnn_msg_steps=random_gnn_msg_steps.value,
      hidden_layers=1,
      radius_query_fraction_edge_length=0.6)
  task_config = graphcast.TaskConfig(
      input_variables=graphcast.TASK.input_variables,
      target_variables=graphcast.TASK.target_variables,
      forcing_variables=graphcast.TASK.forcing_variables,
      pressure_levels=graphcast.PRESSURE_LEVELS[random_levels.value],
      input_duration=graphcast.TASK.input_duration,
  )
else:
  assert source == "Checkpoint"
  # Original notebook: read the checkpoint from the GCS bucket.
  '''with gcs_bucket.blob(f"params/{params_file.value}").open("rb") as f:
    ckpt = checkpoint.load(f, graphcast.CheckPoint)'''

  # This version: read the checkpoint from the local params directory.
  with open(f"{dir_path_params}/{params_file.value}", "rb") as f:
    ckpt = checkpoint.load(f, graphcast.CheckPoint)

  params = ckpt.params
  state = {}

  model_config = ckpt.model_config
  task_config = ckpt.task_config
  print("Model description:\n", ckpt.description, "\n")
  print("Model license:\n", ckpt.license, "\n")

model_config

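Load example data

Several example datasets are available, varying along a few axes:

  • Source: fake, era5, hres
  • Resolution: 0.25°, 1°, 6°
  • Levels: 13, 37
  • Steps: how many time steps are included

Not all combinations are available.

  • Higher resolutions are only available with fewer steps, because of the memory needed to load them.
  • HRES is only available at 0.25° with 13 pressure levels.

The data resolution must match the loaded model.

Some transformations were applied to the base datasets:

  • Precipitation is accumulated over 6 hours instead of the default 1 hour.
  • For HRES data, each time step corresponds to the HRES forecast at lead time 0, which effectively provides an "initialization" from HRES. See HRES-fc0 in the GraphCast paper for a detailed description. Note that HRES cannot provide 6-hour accumulated precipitation, so our model that takes HRES input does not depend on precipitation. However, because our models predict precipitation, ERA5 precipitation is included in the example data to serve as an example of ground truth.
  • ERA5 "toa_incident_solar_radiation" is included in the data. Our models use the radiation at -6h, 0h and +6h as a forcing term for each 1-step prediction. When running operationally, if the +6h radiation is not readily available, it can be computed with a package such as pysolar.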
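The 6-hour accumulation in the first bullet can be sketched with NumPy. This assumes hourly precipitation totals and that each 6-hour value covers the six hours ending at its timestamp; the exact window alignment used to build the example files is an assumption here, and `accumulate_6h` is a hypothetical name. With xarray, the same running sum could be written as `DataArray.rolling(time=6).sum()`.

```python
import numpy as np


def accumulate_6h(hourly, window=6):
    """Sum each hourly value with the (window - 1) values before it.

    Entry t of the result is the total over hours t-window+1 .. t; the first
    window-1 entries are NaN because a full window is not available yet.
    """
    hourly = np.asarray(hourly, dtype=float)
    csum = np.concatenate([[0.0], np.cumsum(hourly)])  # csum[k] = sum of hourly[:k]
    out = np.full(hourly.shape, np.nan)
    out[window - 1:] = csum[window:] - csum[:-window]
    return out


hourly = np.arange(1, 13)  # 12 hours of made-up precipitation totals
acc = accumulate_6h(hourly)
print(acc[5::6])  # [21. 57.]  (values at the end of each 6-hour block)
```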

Get and filter the list of available example data

# Rewrite by S.F. Sune, https://github/sfsun67.
# Look in /root/autodl-fs/data/dataset and list the names of all files in "dataset/", without the directory prefix.

# Data directory; replace it with your own.
dir_path_dataset = "/root/autodl-fs/data/dataset"

# Use glob to get all file paths in the directory
file_paths_dataset = glob.glob(os.path.join(dir_path_dataset, "*"))

# Keep only the file name (drop the ".../dataset/" directory prefix)
dataset_file_options = [os.path.basename(path) for path in file_paths_dataset]
#print("dataset_file_options: ", dataset_file_options)

# Remove "dataset-" prefix from each file name
dataset_file_options = [name.removeprefix("dataset-") for name in dataset_file_options]


def data_valid_for_model(
    file_name: str, model_config: graphcast.ModelConfig, task_config: graphcast.TaskConfig):
  file_parts = parse_file_parts(file_name.removesuffix(".nc"))
  #print("file_parts: ", file_parts)
  return (
      model_config.resolution in (0, float(file_parts["res"])) and
      len(task_config.pressure_levels) == int(file_parts["levels"]) and
      (
          ("total_precipitation_6hr" in task_config.input_variables and
           file_parts["source"] in ("era5", "fake")) or
          ("total_precipitation_6hr" not in task_config.input_variables and
           file_parts["source"] in ("hres", "fake"))
      )
  )


dataset_file = widgets.Dropdown(
    options=[
        (", ".join([f"{k}: {v}" for k, v in parse_file_parts(option.removesuffix(".nc")).items()]), option)
        for option in dataset_file_options
        if data_valid_for_model(option, model_config, task_config)
    ],
    description="Dataset file:",
    layout={"width": "max-content"})
widgets.VBox([
    dataset_file,
    widgets.Label(value="Run the next cell to load the dataset. Rerunning this cell clears your selection and re-filters the datasets that match your model.")
])

Load the weather data

if not data_valid_for_model(dataset_file.value, model_config, task_config):
  raise ValueError(
      "Invalid dataset file, rerun the cell above and choose a valid dataset file.")

'''with gcs_bucket.blob(f"dataset/{dataset_file.value}").open("rb") as f:
  example_batch = xarray.load_dataset(f).compute()'''

with open(f"{dir_path_dataset}/dataset-{dataset_file.value}", "rb") as f:
  example_batch = xarray.load_dataset(f).compute()

assert example_batch.dims["time"] >= 3  # 2 for input, >=1 for targets

print(", ".join([f"{k}: {v}" for k, v in parse_file_parts(dataset_file.value.removesuffix(".nc")).items()]))

example_batch

Choose data to plot

plot_example_variable = widgets.Dropdown(
    options=example_batch.data_vars.keys(),
    value="2m_temperature",
    description="Variable")
plot_example_level = widgets.Dropdown(
    options=example_batch.coords["level"].values,
    value=500,
    description="Level")
plot_example_robust = widgets.Checkbox(value=True, description="Robust")
plot_example_max_steps = widgets.IntSlider(
    min=1, max=example_batch.dims["time"], value=example_batch.dims["time"],
    description="Max steps")

widgets.VBox([
    plot_example_variable,
    plot_example_level,
    plot_example_robust,
    plot_example_max_steps,
    widgets.Label(value="Run the next cell to plot the data. Rerunning this cell clears your selection.")
])

Plot example data

plot_size = 7

data = {
    " ": scale(select(example_batch, plot_example_variable.value, plot_example_level.value, plot_example_max_steps.value),
              robust=plot_example_robust.value),
}
fig_title = plot_example_variable.value
if "level" in example_batch[plot_example_variable.value].coords:
  fig_title += f" at {plot_example_level.value} hPa"

plot_data(data, fig_title, plot_size, plot_example_robust.value)

Choose the training and evaluation data to extract

train_steps = widgets.IntSlider(
    value=1, min=1, max=example_batch.sizes["time"]-2, description="Train steps")
eval_steps = widgets.IntSlider(
    value=example_batch.sizes["time"]-2, min=1, max=example_batch.sizes["time"]-2, description="Eval steps")

widgets.VBox([
    train_steps,
    eval_steps,
    widgets.Label(value="Run the next cell to extract the data. Rerunning this cell clears your selection.")
])

Extract training and evaluation data

train_inputs, train_targets, train_forcings = data_utils.extract_inputs_targets_forcings(
    example_batch, target_lead_times=slice("6h", f"{train_steps.value*6}h"),
    **dataclasses.asdict(task_config))

eval_inputs, eval_targets, eval_forcings = data_utils.extract_inputs_targets_forcings(
    example_batch, target_lead_times=slice("6h", f"{eval_steps.value*6}h"),
    **dataclasses.asdict(task_config))

print("All examples:  ", example_batch.dims.mapping)
print("Train inputs:  ", train_inputs.dims.mapping)
print("Train targets: ", train_targets.dims.mapping)
print("Train forcings:", train_forcings.dims.mapping)
print("Eval inputs:   ", eval_inputs.dims.mapping)
print("Eval targets:  ", eval_targets.dims.mapping)
print("Eval forcings: ", eval_forcings.dims.mapping)
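The sliders stop at `time - 2` because the first two time slices are used as inputs. The plain-Python sketch below illustrates that split, assuming 6-hourly data and two input frames (as the `assert` on the loaded batch requires); `split_times` is a hypothetical helper that only mimics which lead times `target_lead_times=slice("6h", f"{steps*6}h")` selects, not the graphcast implementation.

```python
def split_times(num_times, steps, step_hours=6, num_inputs=2):
    """Return (input lead times, target lead times) in hours.

    Times are measured relative to the last input frame (0h), so the inputs are
    -6h and 0h and the targets are 6h, 12h, ..., steps*6h.
    """
    times = [(i - (num_inputs - 1)) * step_hours for i in range(num_times)]
    inputs = times[:num_inputs]
    targets = [t for t in times if 0 < t <= steps * step_hours]
    if len(targets) != steps:
        raise ValueError(f"only {num_times - num_inputs} target steps available")
    return inputs, targets


# For example, a batch with 14 time slices leaves at most 12 target steps.
print(split_times(14, steps=1))  # ([-6, 0], [6])
print(split_times(14, steps=4))  # ([-6, 0], [6, 12, 18, 24])
```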

Load the normalization data

# Rewrite by S.F. Sune, https://github/sfsun67.
dir_path_stats = "/root/autodl-fs/data/stats"

with open(f"{dir_path_stats}/stats-diffs_stddev_by_level.nc", "rb") as f:
  diffs_stddev_by_level = xarray.load_dataset(f).compute()
with open(f"{dir_path_stats}/stats-mean_by_level.nc", "rb") as f:
  mean_by_level = xarray.load_dataset(f).compute()
with open(f"{dir_path_stats}/stats-stddev_by_level.nc", "rb") as f:
  stddev_by_level = xarray.load_dataset(f).compute()

Build jitted functions and possibly initialize random weights

# Build the model and initialize the weights

# Assemble the model
def construct_wrapped_graphcast(
    model_config: graphcast.ModelConfig,
    task_config: graphcast.TaskConfig):
  """Constructs and wraps the GraphCast Predictor."""
  # Deeper one-step predictor.
  predictor = graphcast.GraphCast(model_config, task_config)

  # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to
  # from/to float32 to/from BFloat16.
  predictor = casting.Bfloat16Cast(predictor)

  # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from
  # BFloat16 happens after applying normalization to the inputs/targets.
  predictor = normalization.InputsAndResiduals(
      predictor,
      diffs_stddev_by_level=diffs_stddev_by_level,
      mean_by_level=mean_by_level,
      stddev_by_level=stddev_by_level)

  # Wraps everything so the one-step model can produce trajectories.
  predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)
  return predictor

# Forward pass
@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
  predictor = construct_wrapped_graphcast(model_config, task_config)
  return predictor(inputs, targets_template=targets_template, forcings=forcings)

# Loss function
@hk.transform_with_state    # turns a function that uses Haiku modules into a pair of pure functions (init, apply) with explicit state
def loss_fn(model_config, task_config, inputs, targets, forcings):
  predictor = construct_wrapped_graphcast(model_config, task_config)    # build the wrapped GraphCast predictor defined above
  loss, diagnostics = predictor.loss(inputs, targets, forcings)
  return xarray_tree.map_structure(
      lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),
      (loss, diagnostics))

# Gradient computation
def grads_fn(params, state, model_config, task_config, inputs, targets, forcings):
  def _aux(params, state, i, t, f):
    (loss, diagnostics), next_state = loss_fn.apply(
        params, state, jax.random.PRNGKey(0), model_config, task_config,
        i, t, f)
    return loss, (diagnostics, next_state)
  (loss, (diagnostics, next_state)), grads = jax.value_and_grad(
      _aux, has_aux=True)(params, state, inputs, targets, forcings)
  return loss, diagnostics, next_state, grads

# Jax doesn't seem to like passing configs as args through the jit. Passing it
# in via partial (instead of capture by closure) forces jax to invalidate the
# jit cache if you change configs.
def with_configs(fn):
  return functools.partial(
      fn, model_config=model_config, task_config=task_config)

# Always pass params and state, so the usage below is simpler
def with_params(fn):
  return functools.partial(fn, params=params, state=state)

# Our models aren't stateful, so the state is always empty, so just return the
# predictions. This is required by our rollout code, and generally simpler.
def drop_state(fn):
  return lambda **kw: fn(**kw)[0]

init_jitted = jax.jit(with_configs(run_forward.init))

if params is None:
  params, state = init_jitted(
      rng=jax.random.PRNGKey(0),
      inputs=train_inputs,
      targets_template=train_targets,
      forcings=train_forcings)

loss_fn_jitted = drop_state(with_params(jax.jit(with_configs(loss_fn.apply))))
grads_fn_jitted = with_params(jax.jit(with_configs(grads_fn)))
run_forward_jitted = drop_state(with_params(jax.jit(with_configs(
    run_forward.apply))))
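The three small wrappers only bind keyword arguments with `functools.partial` and, for `drop_state`, keep the first element of the `(output, state)` pair. The plain-Python sketch below mirrors that composition with a hypothetical toy `apply` function in place of a Haiku-transformed model, and leaves out `jax.jit`, so it illustrates only the argument plumbing.

```python
import functools

# Toy stand-ins (hypothetical, no JAX/Haiku): `apply` mimics the
# (params, state, ...) -> (output, new_state) signature of a Haiku apply function.
def apply(params, state, rng, model_config, x):
    return params["w"] * x + model_config["bias"], state

model_config = {"bias": 1.0}
params, state = {"w": 3.0}, {}

def with_configs(fn):
    return functools.partial(fn, model_config=model_config)

def with_params(fn):
    return functools.partial(fn, params=params, state=state)

def drop_state(fn):
    return lambda **kw: fn(**kw)[0]

# Same nesting as run_forward_jitted above, minus jax.jit.
run = drop_state(with_params(with_configs(apply)))
print(run(rng=None, x=2.0))  # 7.0
```

Because `params` and `state` are bound by keyword, the composed function has to be called with keyword arguments only; a positional argument would land in the `params` slot and raise a `TypeError`. This is why every call in the notebook spells out `rng=`, `inputs=` and so on.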

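Run the model

Note that the first run of the cell below may take a while (possibly a few minutes), because it includes the time to compile the code. The second run is noticeably faster.

This uses a Python loop to iterate over the prediction steps, where the 1-step prediction is jitted. It needs less memory than the training steps below and should be able to make a 4-step prediction on 1-degree data with the small GraphCast model.

Autoregressive rollout (loop in Python)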

assert model_config.resolution in (0, 360. / eval_inputs.sizes["lon"]), (
  "Model resolution doesn't match the data resolution. You likely want to "
  "re-filter the dataset list, and download the correct data.")

print("Inputs:  ", eval_inputs.dims.mapping)
print("Targets: ", eval_targets.dims.mapping)
print("Forcings:", eval_forcings.dims.mapping)

predictions = rollout.chunked_prediction(
    run_forward_jitted,
    rng=jax.random.PRNGKey(0),
    inputs=eval_inputs,
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings)
predictions
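`rollout.chunked_prediction` repeatedly applies the one-step model and feeds its outputs back as inputs. The NumPy sketch below shows only that feedback loop, with a hypothetical linear-extrapolation `one_step` in place of GraphCast and a two-frame input window matching the two input time steps. The real function also handles forcings, the targets template and, as its name suggests, splitting the time axis into chunks; none of that is reproduced here.

```python
import numpy as np


def one_step(prev, curr):
    """Toy one-step model (hypothetical): linear extrapolation from the last two states."""
    return curr + (curr - prev)


def rollout(inputs, num_steps):
    """Feed each prediction back as input, keeping a window of the last two states."""
    prev, curr = inputs
    predictions = []
    for _ in range(num_steps):
        nxt = one_step(prev, curr)
        predictions.append(nxt)
        prev, curr = curr, nxt  # slide the two-frame input window forward by one step
    return np.stack(predictions)


print(rollout((np.array([0.0]), np.array([1.0])), num_steps=3).ravel())  # [2. 3. 4.]
```

Errors compound in such a loop, since each step consumes the previous prediction instead of observed data; that is why the training loss below is computed over multiple autoregressive steps.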

Choose predictions to plot

plot_pred_variable = widgets.Dropdown(
    options=predictions.data_vars.keys(),
    value="2m_temperature",
    description="Variable")
plot_pred_level = widgets.Dropdown(
    options=predictions.coords["level"].values,
    value=500,
    description="Level")
plot_pred_robust = widgets.Checkbox(value=True, description="Robust")
plot_pred_max_steps = widgets.IntSlider(
    min=1,
    max=predictions.dims["time"],
    value=predictions.dims["time"],
    description="Max steps")

widgets.VBox([
    plot_pred_variable,
    plot_pred_level,
    plot_pred_robust,
    plot_pred_max_steps,
    widgets.Label(value="Run the next cell to plot the predictions. Rerunning this cell clears your selection.")
])

Plot the predictions

plot_size = 5
plot_max_steps = min(predictions.dims["time"], plot_pred_max_steps.value)

data = {
    "Targets": scale(select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),
    "Predictions": scale(select(predictions, plot_pred_variable.value, plot_pred_level.value, plot_max_steps), robust=plot_pred_robust.value),
    "Diff": scale((select(eval_targets, plot_pred_variable.value, plot_pred_level.value, plot_max_steps) -
                        select(predictions, plot_pred_variable.value, plot_pred_level.value, plot_max_steps)),
                       robust=plot_pred_robust.value, center=0),
}
fig_title = plot_pred_variable.value
if "level" in predictions[plot_pred_variable.value].coords:
  fig_title += f" at {plot_pred_level.value} hPa"

plot_data(data, fig_title, plot_size, plot_pred_robust.value)

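Train the model

The following operations need a lot of memory and, depending on the accelerator, can only fit a very small "random" model on low-resolution data. They use the number of training steps selected above.

The first execution of each cell takes longer, because it includes the time to jit the functions.

Loss computation (multi-step autoregressive loss)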

loss, diagnostics = loss_fn_jitted(
    rng=jax.random.PRNGKey(0),
    inputs=train_inputs,
    targets=train_targets,
    forcings=train_forcings)

print("Loss:", float(loss))
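According to the GraphCast paper, this loss is a mean squared error averaged over lead times and weighted per latitude, per variable and per pressure level. The NumPy sketch below covers only the latitude part, assuming weights proportional to cos(latitude) normalized to unit mean (an approximation of grid-cell area); `lat_weighted_mse` is a hypothetical name, not the graphcast API.

```python
import numpy as np


def lat_weighted_mse(pred, target, lat_deg):
    """MSE over a (time, lat, lon) array with per-row weights proportional to cos(latitude).

    Weights are normalized to mean 1 over latitude, so a uniform error e gives e**2.
    """
    w = np.cos(np.deg2rad(lat_deg))
    w = w / w.mean()
    sq_err = (pred - target) ** 2
    return float((sq_err * w[None, :, None]).mean())


lat = np.linspace(-90, 90, 181)  # 1-degree grid, poles included
target = np.zeros((2, lat.size, 360))
pred = target + 2.0
print(lat_weighted_mse(pred, target, lat))  # ~4.0: a uniform error of 2 costs 2**2
```

The weighting keeps the many closely spaced grid points near the poles from dominating the loss, since they cover much less of the Earth's surface than equatorial points.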

Gradient computation (backpropagation through time)

loss, diagnostics, next_state, grads = grads_fn_jitted(
    inputs=train_inputs,
    targets=train_targets,
    forcings=train_forcings)
mean_grad = np.mean(jax.tree_util.tree_flatten(jax.tree_util.tree_map(lambda x: np.abs(x).mean(), grads))[0])
print(f"Loss: {loss:.4f}, Mean |grad|: {mean_grad:.6f}")
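Backpropagation through time means that the gradient of a multi-step loss flows back through every autoregressive step. The pure-Python sketch below uses a hypothetical scalar model x_{t+1} = w·x_t (not GraphCast): it computes that gradient with an explicit backward loop and checks it against a finite difference. In the notebook, `jax.value_and_grad` does the equivalent automatically, and `gradient_checkpointing=True` in `autoregressive.Predictor` trades recomputation for memory when storing the intermediate states.

```python
def rollout_loss_and_grad(w, x0, targets):
    """MSE of a T-step rollout of x_{t+1} = w * x_t; returns (loss, dloss/dw)."""
    # Forward pass: keep every state, since the backward pass needs them.
    xs = [x0]
    for _ in targets:
        xs.append(w * xs[-1])
    preds = xs[1:]
    n = len(targets)
    loss = sum((p - y) ** 2 for p, y in zip(preds, targets)) / n

    # Backward pass through time, from the last step to the first.
    grad_w = 0.0
    grad_x = 0.0  # dloss/dx_{t+1}, including contributions from later steps
    for t in reversed(range(n)):
        grad_x += 2 * (preds[t] - targets[t]) / n  # direct loss term on x_{t+1}
        grad_w += grad_x * xs[t]                   # because x_{t+1} = w * x_t
        grad_x *= w                                # pass the gradient on to x_t
    return loss, grad_w


w, x0, targets = 0.9, 1.0, [1.0, 0.5, 0.2]
loss, grad = rollout_loss_and_grad(w, x0, targets)
eps = 1e-6
numeric = (rollout_loss_and_grad(w + eps, x0, targets)[0]
           - rollout_loss_and_grad(w - eps, x0, targets)[0]) / (2 * eps)
print(loss, grad, numeric)  # grad and numeric agree closely
```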

Autoregressive rollout (loop kept inside JAX)

print("Inputs:  ", train_inputs.dims.mapping)
print("Targets: ", train_targets.dims.mapping)
print("Forcings:", train_forcings.dims.mapping)

predictions = run_forward_jitted(
    rng=jax.random.PRNGKey(0),
    inputs=train_inputs,
    targets_template=train_targets * np.nan,
    forcings=train_forcings)
predictions
