Skip to content

Commit

Permalink
Merge pull request #45 from NOAA-GFDL/feature/centralize_translate_test
Browse files Browse the repository at this point in the history
[Refactor] Centralize Translate Test architecture
  • Loading branch information
FlorianDeconinck authored May 31, 2024
2 parents 12cf113 + 2b26863 commit 495953d
Show file tree
Hide file tree
Showing 4 changed files with 310 additions and 42 deletions.
176 changes: 137 additions & 39 deletions ndsl/stencils/testing/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import xarray as xr
import yaml

import ndsl.dsl
from ndsl import CompilationConfig, StencilConfig, StencilFactory
from ndsl.comm.communicator import (
Communicator,
CubedSphereCommunicator,
Expand All @@ -17,11 +17,87 @@
from ndsl.comm.partitioner import CubedSpherePartitioner, TilePartitioner
from ndsl.dsl.dace.dace_config import DaceConfig
from ndsl.namelist import Namelist
from ndsl.stencils.testing.grid import Grid # type: ignore
from ndsl.stencils.testing.parallel_translate import ParallelTranslate
from ndsl.stencils.testing.savepoint import SavepointCase, dataset_to_dict
from ndsl.stencils.testing.translate import TranslateGrid


def pytest_addoption(parser):
"""Option for the Translate Test system
See -h or inline help for details.
"""
parser.addoption(
"--backend",
action="store",
default="numpy",
help="Backend to execute the test with, can only be one.",
)
parser.addoption(
"--which_modules",
action="store",
help="Whitelist of modules to run. Only the part after Translate, e.g. in TranslateXYZ it'd be XYZ",
)
parser.addoption(
"--skip_modules",
action="store",
help="Blacklist of modules to not run. Only the part after Translate, e.g. in TranslateXYZ it'd be XYZ",
)
parser.addoption(
"--which_rank", action="store", help="Restrict test to a single rank"
)
parser.addoption(
"--data_path",
action="store",
default="./",
help="Path of Netcdf input and outputs. Naming pattern needs to be XYZ-In and XYZ-Out for a test class named TranslateXYZ",
)
parser.addoption(
"--threshold_overrides_file",
action="store",
default=None,
help="Path to a yaml overriding the default error threshold for a custom value.",
)
parser.addoption(
"--print_failures",
action="store_true",
help="Print the failures detail. Default to True.",
)
parser.addoption(
"--failure_stride",
action="store",
default=1,
help="How many indices of failures to print from worst to best. Default to 1.",
)
parser.addoption(
"--grid",
action="store",
default="file",
help='Grid loading mode. "file" looks for "Grid-Info.nc", "compute" does the same but recomputes MetricTerms, "default" creates a simple grid with no metrics terms. Default to "file".',
)
parser.addoption(
"--topology",
action="store",
default="cube-sphere",
help='Topology of the grid. "cube-sphere" means a 6-faced grid, "doubly-periodic" means a 1 tile grid. Default to "cube-sphere".',
)


def pytest_configure(config):
# register an additional marker
config.addinivalue_line(
"markers", "sequential(name): mark test as running sequentially on ranks"
)
config.addinivalue_line(
"markers", "parallel(name): mark test as running in parallel across ranks"
)
config.addinivalue_line(
"markers",
"mock_parallel(name): mark test as running in mock parallel across ranks",
)


@pytest.fixture()
def data_path(pytestconfig):
return data_path_and_namelist_filename_from_config(pytestconfig)
Expand Down Expand Up @@ -109,12 +185,14 @@ def get_parallel_savepoint_names(metafunc, data_path):

def get_ranks(metafunc, layout):
only_rank = metafunc.config.getoption("which_rank")
dperiodic = metafunc.config.getoption("dperiodic")
topology = metafunc.config.getoption("topology")
if only_rank is None:
if dperiodic:
if topology == "doubly-periodic":
total_ranks = layout[0] * layout[1]
else:
elif topology == "cube-sphere":
total_ranks = 6 * layout[0] * layout[1]
else:
raise NotImplementedError(f"Topology {topology} is unknown.")
return range(total_ranks)
else:
return [int(only_rank)]
Expand All @@ -125,8 +203,8 @@ def get_namelist(namelist_filename):


def get_config(backend: str, communicator: Optional[Communicator]):
stencil_config = ndsl.dsl.stencil.StencilConfig(
compilation_config=ndsl.dsl.stencil.CompilationConfig(
stencil_config = StencilConfig(
compilation_config=CompilationConfig(
backend=backend, rebuild=False, validate_args=True
),
dace_config=DaceConfig(
Expand All @@ -142,17 +220,17 @@ def sequential_savepoint_cases(metafunc, data_path, namelist_filename, *, backen
namelist = get_namelist(namelist_filename)
stencil_config = get_config(backend, None)
ranks = get_ranks(metafunc, namelist.layout)
compute_grid = metafunc.config.getoption("compute_grid")
dperiodic = metafunc.config.getoption("dperiodic")
grid_mode = metafunc.config.getoption("grid")
topology_mode = metafunc.config.getoption("topology")
return _savepoint_cases(
savepoint_names,
ranks,
stencil_config,
namelist,
backend,
data_path,
compute_grid,
dperiodic,
grid_mode,
topology_mode,
)


Expand All @@ -161,25 +239,40 @@ def _savepoint_cases(
ranks,
stencil_config,
namelist,
backend,
data_path,
compute_grid: bool,
dperiodic: bool,
backend: str,
data_path: str,
grid_mode: str,
topology_mode: bool,
):
return_list = []
ds_grid: xr.Dataset = xr.open_dataset(os.path.join(data_path, "Grid-Info.nc")).isel(
savepoint=0
)
for rank in ranks:
grid = TranslateGrid(
dataset_to_dict(ds_grid.isel(rank=rank)),
rank=rank,
layout=namelist.layout,
backend=backend,
).python_grid()
if compute_grid:
compute_grid_data(grid, namelist, backend, namelist.layout, dperiodic)
stencil_factory = ndsl.dsl.stencil.StencilFactory(
if grid_mode == "default":
grid = Grid._make(
namelist.npx + 1,
namelist.npy + 1,
namelist.npz,
namelist.layout,
rank,
backend,
)
elif grid_mode == "file" or grid_mode == "compute":
ds_grid: xr.Dataset = xr.open_dataset(
os.path.join(data_path, "Grid-Info.nc")
).isel(savepoint=0)
grid = TranslateGrid(
dataset_to_dict(ds_grid.isel(rank=rank)),
rank=rank,
layout=namelist.layout,
backend=backend,
).python_grid()
if grid_mode == "compute":
compute_grid_data(
grid, namelist, backend, namelist.layout, topology_mode
)
else:
raise NotImplementedError(f"Grid mode {grid_mode} is unknown.")

stencil_factory = StencilFactory(
config=stencil_config,
grid_indexing=grid.grid_indexing,
)
Expand All @@ -204,12 +297,12 @@ def _savepoint_cases(
return return_list


def compute_grid_data(grid, namelist, backend, layout, dperiodic):
def compute_grid_data(grid, namelist, backend, layout, topology_mode):
grid.make_grid_data(
npx=namelist.npx,
npy=namelist.npy,
npz=namelist.npz,
communicator=get_communicator(MPI.COMM_WORLD, layout, dperiodic),
communicator=get_communicator(MPI.COMM_WORLD, layout, topology_mode),
backend=backend,
)

Expand All @@ -218,20 +311,20 @@ def parallel_savepoint_cases(
metafunc, data_path, namelist_filename, mpi_rank, *, backend: str, comm
):
namelist = get_namelist(namelist_filename)
dperiodic = metafunc.config.getoption("dperiodic")
communicator = get_communicator(comm, namelist.layout, dperiodic)
topology_mode = metafunc.config.getoption("topology")
communicator = get_communicator(comm, namelist.layout, topology_mode)
stencil_config = get_config(backend, communicator)
savepoint_names = get_parallel_savepoint_names(metafunc, data_path)
compute_grid = metafunc.config.getoption("compute_grid")
grid_mode = metafunc.config.getoption("grid")
return _savepoint_cases(
savepoint_names,
[mpi_rank],
stencil_config,
namelist,
backend,
data_path,
compute_grid,
dperiodic,
grid_mode,
topology_mode,
)


Expand Down Expand Up @@ -276,8 +369,8 @@ def generate_parallel_stencil_tests(metafunc, *, backend: str):
)


def get_communicator(comm, layout, dperiodic):
if (MPI.COMM_WORLD.Get_size() > 1) and (not dperiodic):
def get_communicator(comm, layout, topology_mode):
if (MPI.COMM_WORLD.Get_size() > 1) and (topology_mode == "doubly-periodic"):
partitioner = CubedSpherePartitioner(TilePartitioner(layout))
communicator = CubedSphereCommunicator(comm, partitioner)
else:
Expand All @@ -297,10 +390,15 @@ def failure_stride(pytestconfig):


@pytest.fixture()
def compute_grid(pytestconfig):
return pytestconfig.getoption("compute_grid")
def grid(pytestconfig):
return pytestconfig.getoption("grid")


@pytest.fixture()
def topology_mode(pytestconfig):
return pytestconfig.getoption("topology_mode")


@pytest.fixture()
def dperiodic(pytestconfig):
return pytestconfig.getoption("dperiodic")
def backend(pytestconfig):
return pytestconfig.getoption("backend")
Loading

0 comments on commit 495953d

Please sign in to comment.