Skip to content

Commit

Permalink
Add serialbox to netcdf as a tool of ndsl
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianDeconinck committed May 17, 2024
1 parent e7e4c95 commit 480a574
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 0 deletions.
147 changes: 147 additions & 0 deletions ndsl/stencils/testing/serialbox_to_netcdf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import argparse
import os
import shutil
import xarray as xr
import f90nml
import numpy as np
from typing import Optional

try:
import serialbox
except ModuleNotFoundError:
raise ModuleNotFoundError("Serialbox couldn't be imported, make sure it's in your PYTHONPATH or you env")


def get_parser():
parser = argparse.ArgumentParser("Converts Serialbox data to netcdf")
parser.add_argument(
"data_path",
type=str,
help="path of serialbox data to convert",
)
parser.add_argument(
"output_path", type=str, help="output directory where netcdf data will be saved"
)
parser.add_argument(
"-dn", "--data_name", type=str, help="[Optional] Give the name of the data, will default to Generator_rankX"
)
return parser


def read_serialized_data(serializer, savepoint, variable):
data = serializer.read(variable, savepoint)
if len(data.flatten()) == 1:
return data[0]
data[data == 1e40] = 0.0
return data


def get_all_savepoint_names(serializer):
savepoint_names = set()
for savepoint in serializer.savepoint_list():
savepoint_names.add(savepoint.name)
return savepoint_names


def get_serializer(data_path: str, rank:int , data_name:Optional[str] = None):
if data_name:
name = data_name
else:
name = f"Generator_rank{rank}"
return serialbox.Serializer(serialbox.OpenModeKind.Read, data_path, name)


def main(data_path: str, output_path: str, data_name: Optional[str] = None):
os.makedirs(output_path, exist_ok=True)
namelist_filename_in = os.path.join(data_path, "input.nml")

if not os.path.exists(namelist_filename_in):
raise FileNotFoundError(f"Can't find input.nml in {data_path}. Required.")

namelist_filename_out = os.path.join(output_path, "input.nml")
if namelist_filename_out != namelist_filename_in:
shutil.copyfile(os.path.join(data_path, "input.nml"), namelist_filename_out)
namelist = f90nml.read(namelist_filename_out)
total_ranks = (
6 * namelist["fv_core_nml"]["layout"][0] * namelist["fv_core_nml"]["layout"][1]
)

# all ranks have the same names, just look at first one
serializer_0 = get_serializer(data_path, rank=0, data_name=data_name)

savepoint_names = get_all_savepoint_names(serializer_0)
for savepoint_name in sorted(list(savepoint_names)):
rank_list = []
names_list = list(
serializer_0.fields_at_savepoint(serializer_0.get_savepoint(savepoint_name)[0])
)
serializer_list = []
for rank in range(total_ranks):
serializer = get_serializer(data_path, rank, data_name)
serializer_list.append(serializer)
savepoints = serializer.get_savepoint(savepoint_name)
rank_data = {}
for name in set(names_list):
rank_data[name] = []
for savepoint in savepoints:
rank_data[name].append(
read_serialized_data(serializer, savepoint, name)
)
rank_list.append(rank_data)
n_savepoints = len(savepoints) # checking from last rank is fine
data_vars = {}
if n_savepoints > 0:
encoding = {}
for varname in set(names_list).difference(["rank"]):
data_shape = list(rank_list[0][varname][0].shape)
if savepoint_name in ["FVDynamics-In", "FVDynamics-Out", "Driver-In", "Driver-Out"]:
if varname in [
"qvapor",
"qliquid",
"qice",
"qrain",
"qsnow",
"qgraupel",
"qo3mr",
"qsgs_tke",
]:
data_vars[varname] = get_data(
data_shape, total_ranks, n_savepoints, rank_list, varname
)[:, :, 3:-3, 3:-3, :]
else:
data_vars[varname] = get_data(
data_shape, total_ranks, n_savepoints, rank_list, varname
)
else:
data_vars[varname] = get_data(
data_shape, total_ranks, n_savepoints, rank_list, varname
)
if len(data_shape) > 2:
encoding[varname] = {"zlib": True, "complevel": 1}
dataset = xr.Dataset(data_vars=data_vars)
dataset.to_netcdf(
os.path.join(output_path, f"{savepoint_name}.nc"), encoding=encoding
)


def get_data(data_shape, total_ranks, n_savepoints, output_list, varname):
array = np.full([n_savepoints, total_ranks] + data_shape, fill_value=np.nan)
dims = ["savepoint", "rank"] + [
f"dim_{varname}_{i}" for i in range(len(data_shape))
]
data = xr.DataArray(array, dims=dims)
for rank in range(total_ranks):
for i_savepoint in range(n_savepoints):
if len(data_shape) > 0:
data[i_savepoint, rank, :] = output_list[rank][varname][i_savepoint]
else:
data[i_savepoint, rank] = output_list[rank][varname][i_savepoint]
return data

def entry_point():
parser = get_parser()
args = parser.parse_args()
main(data_path=args.data_path, output_path=args.output_path, data_name=args.data_name)

if __name__ == "__main__":
entry_point()
5 changes: 5 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,9 @@ def local_pkg(name: str, relative_path: str) -> str:
url="https://github.com/NOAA-GFDL/NDSL",
version="2024.04.00",
zip_safe=False,
entry_points={
"console_scripts": [
"ndsl-serialbox_to_netcdf = ndsl.stencils.testing.serialbox_to_netcdf:entry_point",
]
},
)

0 comments on commit 480a574

Please sign in to comment.