Skip to content

Commit

Permalink
Merge pull request #68 from oelbert/feature/merge_savepoints
Browse files: browse the repository at this point in the history
Option to merge blocks in serialized data
  • Loading branch information
FlorianDeconinck authored Sep 3, 2024
2 parents d0c1703 + cb4ce98 commit 82e5384
Showing 1 changed file with 55 additions and 5 deletions.
60 changes: 55 additions & 5 deletions ndsl/stencils/testing/serialbox_to_netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ def get_parser():
type=str,
help="[Optional] Give the name of the data, will default to Generator_rankX",
)
parser.add_argument(
"-m",
"--merge",
action="store_true",
default=False,
help="merges datastreams blocked into separate savepoints",
)
return parser


Expand All @@ -58,7 +65,12 @@ def get_serializer(data_path: str, rank: int, data_name: Optional[str] = None):
return serialbox.Serializer(serialbox.OpenModeKind.Read, data_path, name)


def main(data_path: str, output_path: str, data_name: Optional[str] = None):
def main(
data_path: str,
output_path: str,
merge_blocks: bool,
data_name: Optional[str] = None,
):
os.makedirs(output_path, exist_ok=True)
namelist_filename_in = os.path.join(data_path, "input.nml")

Expand All @@ -69,8 +81,21 @@ def main(data_path: str, output_path: str, data_name: Optional[str] = None):
if namelist_filename_out != namelist_filename_in:
shutil.copyfile(os.path.join(data_path, "input.nml"), namelist_filename_out)
namelist = f90nml.read(namelist_filename_out)
total_ranks = (
6 * namelist["fv_core_nml"]["layout"][0] * namelist["fv_core_nml"]["layout"][1]
if namelist["fv_core_nml"]["grid_type"] <= 3:
total_ranks = (
6
* namelist["fv_core_nml"]["layout"][0]
* namelist["fv_core_nml"]["layout"][1]
)
else:
total_ranks = (
namelist["fv_core_nml"]["layout"][0] * namelist["fv_core_nml"]["layout"][1]
)
nx = int(
(namelist["fv_core_nml"]["npx"] - 1) / (namelist["fv_core_nml"]["layout"][0])
)
ny = int(
(namelist["fv_core_nml"]["npy"] - 1) / (namelist["fv_core_nml"]["layout"][1])
)

# all ranks have the same names, just look at first one
Expand All @@ -96,8 +121,30 @@ def main(data_path: str, output_path: str, data_name: Optional[str] = None):
rank_data[name].append(
read_serialized_data(serializer, savepoint, name)
)
nblocks = len(rank_data[name])
if merge_blocks and len(rank_data[name]) > 1:
full_data = np.array(rank_data[name])
if len(full_data.shape) > 1:
if nx * ny == full_data.shape[0] * full_data.shape[1]:
# If we have an (i, x) array from each block, reshape it
new_shape = (nx, ny) + full_data.shape[2:]
full_data = full_data.reshape(new_shape)
else:
# We have one array for all blocks
# could be a k-array or something else, so we take one copy
# TODO: is there a decent check for this?
full_data = full_data[0]
elif len(full_data.shape) == 1:
# if it's a scalar from each block then just take one
full_data = full_data[0]
else:
raise IndexError(f"{name} data appears to be empty")
rank_data[name] = [full_data]
rank_list.append(rank_data)
n_savepoints = len(savepoints) # checking from last rank is fine
if merge_blocks:
n_savepoints = 1
else:
n_savepoints = len(savepoints) # checking from last rank is fine
data_vars = {}
if n_savepoints > 0:
encoding = {}
Expand Down Expand Up @@ -166,7 +213,10 @@ def entry_point():
parser = get_parser()
args = parser.parse_args()
main(
data_path=args.data_path, output_path=args.output_path, data_name=args.data_name
data_path=args.data_path,
output_path=args.output_path,
merge_blocks=args.merge,
data_name=args.data_name,
)


Expand Down

0 comments on commit 82e5384

Please sign in to comment.