-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add serialbox to netcdf as a tool of ndsl
- Loading branch information
1 parent
e7e4c95
commit 480a574
Showing
2 changed files
with
152 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
import argparse | ||
import os | ||
import shutil | ||
import xarray as xr | ||
import f90nml | ||
import numpy as np | ||
from typing import Optional | ||
|
||
try: | ||
import serialbox | ||
except ModuleNotFoundError: | ||
raise ModuleNotFoundError("Serialbox couldn't be imported, make sure it's in your PYTHONPATH or you env") | ||
|
||
|
||
def get_parser(): | ||
parser = argparse.ArgumentParser("Converts Serialbox data to netcdf") | ||
parser.add_argument( | ||
"data_path", | ||
type=str, | ||
help="path of serialbox data to convert", | ||
) | ||
parser.add_argument( | ||
"output_path", type=str, help="output directory where netcdf data will be saved" | ||
) | ||
parser.add_argument( | ||
"-dn", "--data_name", type=str, help="[Optional] Give the name of the data, will default to Generator_rankX" | ||
) | ||
return parser | ||
|
||
|
||
def read_serialized_data(serializer, savepoint, variable): | ||
data = serializer.read(variable, savepoint) | ||
if len(data.flatten()) == 1: | ||
return data[0] | ||
data[data == 1e40] = 0.0 | ||
return data | ||
|
||
|
||
def get_all_savepoint_names(serializer): | ||
savepoint_names = set() | ||
for savepoint in serializer.savepoint_list(): | ||
savepoint_names.add(savepoint.name) | ||
return savepoint_names | ||
|
||
|
||
def get_serializer(data_path: str, rank:int , data_name:Optional[str] = None): | ||
if data_name: | ||
name = data_name | ||
else: | ||
name = f"Generator_rank{rank}" | ||
return serialbox.Serializer(serialbox.OpenModeKind.Read, data_path, name) | ||
|
||
|
||
def main(data_path: str, output_path: str, data_name: Optional[str] = None): | ||
os.makedirs(output_path, exist_ok=True) | ||
namelist_filename_in = os.path.join(data_path, "input.nml") | ||
|
||
if not os.path.exists(namelist_filename_in): | ||
raise FileNotFoundError(f"Can't find input.nml in {data_path}. Required.") | ||
|
||
namelist_filename_out = os.path.join(output_path, "input.nml") | ||
if namelist_filename_out != namelist_filename_in: | ||
shutil.copyfile(os.path.join(data_path, "input.nml"), namelist_filename_out) | ||
namelist = f90nml.read(namelist_filename_out) | ||
total_ranks = ( | ||
6 * namelist["fv_core_nml"]["layout"][0] * namelist["fv_core_nml"]["layout"][1] | ||
) | ||
|
||
# all ranks have the same names, just look at first one | ||
serializer_0 = get_serializer(data_path, rank=0, data_name=data_name) | ||
|
||
savepoint_names = get_all_savepoint_names(serializer_0) | ||
for savepoint_name in sorted(list(savepoint_names)): | ||
rank_list = [] | ||
names_list = list( | ||
serializer_0.fields_at_savepoint(serializer_0.get_savepoint(savepoint_name)[0]) | ||
) | ||
serializer_list = [] | ||
for rank in range(total_ranks): | ||
serializer = get_serializer(data_path, rank, data_name) | ||
serializer_list.append(serializer) | ||
savepoints = serializer.get_savepoint(savepoint_name) | ||
rank_data = {} | ||
for name in set(names_list): | ||
rank_data[name] = [] | ||
for savepoint in savepoints: | ||
rank_data[name].append( | ||
read_serialized_data(serializer, savepoint, name) | ||
) | ||
rank_list.append(rank_data) | ||
n_savepoints = len(savepoints) # checking from last rank is fine | ||
data_vars = {} | ||
if n_savepoints > 0: | ||
encoding = {} | ||
for varname in set(names_list).difference(["rank"]): | ||
data_shape = list(rank_list[0][varname][0].shape) | ||
if savepoint_name in ["FVDynamics-In", "FVDynamics-Out", "Driver-In", "Driver-Out"]: | ||
if varname in [ | ||
"qvapor", | ||
"qliquid", | ||
"qice", | ||
"qrain", | ||
"qsnow", | ||
"qgraupel", | ||
"qo3mr", | ||
"qsgs_tke", | ||
]: | ||
data_vars[varname] = get_data( | ||
data_shape, total_ranks, n_savepoints, rank_list, varname | ||
)[:, :, 3:-3, 3:-3, :] | ||
else: | ||
data_vars[varname] = get_data( | ||
data_shape, total_ranks, n_savepoints, rank_list, varname | ||
) | ||
else: | ||
data_vars[varname] = get_data( | ||
data_shape, total_ranks, n_savepoints, rank_list, varname | ||
) | ||
if len(data_shape) > 2: | ||
encoding[varname] = {"zlib": True, "complevel": 1} | ||
dataset = xr.Dataset(data_vars=data_vars) | ||
dataset.to_netcdf( | ||
os.path.join(output_path, f"{savepoint_name}.nc"), encoding=encoding | ||
) | ||
|
||
|
||
def get_data(data_shape, total_ranks, n_savepoints, output_list, varname): | ||
array = np.full([n_savepoints, total_ranks] + data_shape, fill_value=np.nan) | ||
dims = ["savepoint", "rank"] + [ | ||
f"dim_{varname}_{i}" for i in range(len(data_shape)) | ||
] | ||
data = xr.DataArray(array, dims=dims) | ||
for rank in range(total_ranks): | ||
for i_savepoint in range(n_savepoints): | ||
if len(data_shape) > 0: | ||
data[i_savepoint, rank, :] = output_list[rank][varname][i_savepoint] | ||
else: | ||
data[i_savepoint, rank] = output_list[rank][varname][i_savepoint] | ||
return data | ||
|
||
def entry_point(): | ||
parser = get_parser() | ||
args = parser.parse_args() | ||
main(data_path=args.data_path, output_path=args.output_path, data_name=args.data_name) | ||
|
||
if __name__ == "__main__": | ||
entry_point() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters