Skip to content

Commit

Permalink
adapted tests for obscore casa-ms and casa module itself 🎶
Browse files Browse the repository at this point in the history
  • Loading branch information
Lukas113 committed Oct 2, 2024
1 parent f11d13e commit 5e368e6
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 52 deletions.
4 changes: 3 additions & 1 deletion karabo/data/casa.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,9 @@ def ms_version(
"""
ms_path
ct = cls._get_casa_table_instance(ms_path=ms_path)
version: MS_VERSION = ct.getkeyword("MS_VERSION")
version: MS_VERSION = str( # type: ignore[assignment]
ct.getkeyword("MS_VERSION")
) # noqa: E501
return version


Expand Down
39 changes: 24 additions & 15 deletions karabo/data/obscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,9 +562,11 @@ def from_visibility(
freq=float(np.max(spectral_window.ref_frequency)),
b=b,
)
ocm.pol_xel = MSPolarizationTable.nrows(ms_path=vis_inode)
corr_types = np.unique(ms_meta.polarization.corr_type.ravel()).tolist()
ocm.set_pol_states(pol_states=corr_types)
pol_states = ocm.get_pol_states()
if pol_states is not None:
ocm.pol_xel = len(pol_states)
elif vis.format == "OSKAR_VIS":
header, _ = VisHeader.read(vis_inode)
ocm.s_ra = header.phase_centre_ra_deg
Expand Down Expand Up @@ -724,7 +726,7 @@ def set_fields(self, **kwargs: Any) -> None:
wmsg = (
f"Skipping `{k}` because it's not a valid field of `ObsCoreMeta`."
)
warn(message=wmsg, category=UserWarning, stacklevel=1)
warn(message=wmsg, category=UserWarning, stacklevel=2)
continue
setattr(self, k, v)

Expand Down Expand Up @@ -939,7 +941,7 @@ def _convert(
if axis == "RA":
if number > 360.0 or number < 0.0:
wmsg = f"Coercing {axis}={number} to {axis}={number}%360"
warn(message=wmsg, category=UserWarning, stacklevel=1)
warn(message=wmsg, category=UserWarning, stacklevel=2)
number = number % 360
elif axis == "DEC":
if number < -90.0 or number > 90.0:
Expand Down Expand Up @@ -994,7 +996,7 @@ def _check_mandatory_fields(self, *, verbose: bool) -> bool:
f"{mandatory_missing=} fields are None in `ObsCoreMeta`, "
+ "but are mandatory to ObsTAP services."
)
warn(message=wmsg, category=UserWarning, stacklevel=1)
warn(message=wmsg, category=UserWarning, stacklevel=2)
return valid

def _check_polarization(self, *, verbose: bool) -> bool:
Expand All @@ -1013,34 +1015,41 @@ def _check_polarization(self, *, verbose: bool) -> bool:
try:
_ = self.get_pol_states()
except ValueError as ve:
valid = False
valid = False # pol-state string is corrupt (e.g. self-set)
if verbose:
warn(message=str(ve), category=UserWarning, stacklevel=1)
warn(message=str(ve), category=UserWarning, stacklevel=2)
if pol_xel is None and pol_states is not None:
valid = False
if verbose:
wmsg = f"`pol_xel` should be specified because {pol_states=}"
warn(message=wmsg, category=UserWarning, stacklevel=1)
warn(message=wmsg, category=UserWarning, stacklevel=2)
elif pol_xel is not None and pol_states is None:
valid = False
if verbose:
wmsg = (
f"{pol_xel=} is specified, but {pol_states=} which isn't consistent"
)
warn(message=wmsg, category=UserWarning, stacklevel=1)
warn(message=wmsg, category=UserWarning, stacklevel=2)
elif pol_xel is not None and pol_states is not None:
if pol_xel != (num_pol_states := len(pol_states)):
valid = False
if verbose:
wmsg = f"{pol_xel=} should be {num_pol_states=}"
warn(message=wmsg, category=UserWarning, stacklevel=1)
try:
pol_states_list = self.get_pol_states()
except ValueError:
valid = False # pol-state string is corrupt (e.g. self-set)
else:
if pol_states_list is not None and pol_xel != (
num_pol_states := len(pol_states_list)
):
valid = False
if verbose:
wmsg = f"{pol_xel=} should be {num_pol_states=}"
warn(message=wmsg, category=UserWarning, stacklevel=2)
if (pol_xel is not None or pol_states is not None) and (
self.o_ucd is None or (ucd_str := "phys.polarisation") not in self.o_ucd
):
valid = False
if verbose:
wmsg = f"`o_ucd` must at least contain '{ucd_str}' but it doesn't"
warn(message=wmsg, category=UserWarning, stacklevel=1)
warn(message=wmsg, category=UserWarning, stacklevel=2)
return valid

def _check_axes(self, *, verbose: bool) -> bool:
Expand All @@ -1067,7 +1076,7 @@ def check_value(value: int | None) -> bool:
valid = False
if verbose:
wmsg = f"Invalid axes-values: {invalid_fields}"
warn(message=wmsg, category=UserWarning, stacklevel=1)
warn(message=wmsg, category=UserWarning, stacklevel=2)
return valid

@classmethod
Expand Down
1 change: 1 addition & 0 deletions karabo/simulation/interferometer.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ def run_simulation(
raise ValueError(
f"{visibility_path} is not a valid path for format {visibility_format}"
)
os.makedirs(os.path.dirname(visibility_path), exist_ok=True)
if backend is SimulatorBackend.OSKAR:
if primary_beam is not None:
warn(
Expand Down
38 changes: 38 additions & 0 deletions karabo/test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Pytest global fixtures needs to be here!"""

import os
import zipfile
from collections.abc import Callable, Generator, Iterable
from dataclasses import dataclass

Expand All @@ -10,6 +11,12 @@
from numpy.typing import NDArray
from pytest import Config, Item, Parser

from karabo.data.external_data import (
SingleFileDownloadObject,
cscs_karabo_public_testing_base_url,
)
from karabo.imaging.image import Image
from karabo.simulation.visibility import Visibility
from karabo.test import data_path
from karabo.util.file_handler import FileHandler

Expand Down Expand Up @@ -197,3 +204,34 @@ def _normalized_norm_diff(img_path_1: str, img_path_2: str) -> float:
return float(np.linalg.norm(img1 - img2) / (img1.shape[0] * img1.shape[1]))

return _normalized_norm_diff


@pytest.fixture(scope="session")
def minimal_oskar_vis() -> Visibility:
vis_path = SingleFileDownloadObject(
remote_file_path="test_minimal_visibility.vis",
remote_base_url=cscs_karabo_public_testing_base_url,
).get()
return Visibility(vis_path)


@pytest.fixture(scope="session")
def minimal_casa_ms() -> Visibility:
vis_zip_path = SingleFileDownloadObject(
remote_file_path="test_minimal_casa.ms.zip",
remote_base_url=cscs_karabo_public_testing_base_url,
).get()
vis_path = vis_zip_path.strip(".zip")
if not os.path.exists(vis_path):
with zipfile.ZipFile(vis_zip_path, "r") as zip_ref:
zip_ref.extractall(os.path.dirname(vis_path))
return Visibility(vis_path)


@pytest.fixture(scope="session")
def minimal_fits_restored() -> Image:
restored_path = SingleFileDownloadObject(
remote_file_path="test_minimal_clean_restored.fits",
remote_base_url=cscs_karabo_public_testing_base_url,
).get()
return Image(path=restored_path)
35 changes: 35 additions & 0 deletions karabo/test/test_casa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

from karabo.data.casa import (
MSAntennaTable,
MSFieldTable,
MSMainTable,
MSMeta,
MSObservationTable,
MSPolarizationTable,
MSSpectralWindowTable,
)
from karabo.simulation.visibility import Visibility


class TestCasaMS:
def test_tables(self, minimal_casa_ms: Visibility) -> None:
"""Minimal table-creation test.
This test currently just calls the table creation function of the particular
to ensure field-name correctness & data loading success.
Args:
minimal_casa_ms: Casa MS fixture.
"""
assert minimal_casa_ms.format == "MS"
ms_path = minimal_casa_ms.path
_ = MSAntennaTable.from_ms(ms_path=ms_path)
_ = MSFieldTable.from_ms(ms_path=ms_path)
_ = MSMainTable.from_ms(ms_path=ms_path)
assert MSMainTable.ms_version(ms_path=ms_path) == "2.0"
_ = MSObservationTable.from_ms(ms_path=ms_path)
_ = MSPolarizationTable.from_ms(ms_path=ms_path)
_ = MSSpectralWindowTable.from_ms(ms_path=ms_path)
ms_meta = MSMeta.from_ms(ms_path=ms_path)
assert ms_meta.ms_version == "2.0"
69 changes: 34 additions & 35 deletions karabo/test/test_obscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,9 @@
import numpy as np
import pytest
from astropy import units as u
from pytest import FixtureRequest
from rfc3986.exceptions import InvalidComponentsError

from karabo.data.external_data import (
SingleFileDownloadObject,
cscs_karabo_public_testing_base_url,
)
from karabo.data.obscore import FitsHeaderAxes, FitsHeaderAxis, ObsCoreMeta
from karabo.data.src import RucioMeta
from karabo.imaging.image import Image
Expand All @@ -24,24 +21,6 @@
from karabo.util.helpers import get_rnd_str


@pytest.fixture(scope="module")
def minimal_visibility() -> Visibility:
vis_path = SingleFileDownloadObject(
remote_file_path="test_minimal_visibility.vis",
remote_base_url=cscs_karabo_public_testing_base_url,
).get()
return Visibility(vis_path)


@pytest.fixture(scope="module")
def minimal_fits_restored() -> Image:
restored_path = SingleFileDownloadObject(
remote_file_path="test_minimal_clean_restored.fits",
remote_base_url=cscs_karabo_public_testing_base_url,
).get()
return Image(path=restored_path)


class TestObsCoreMeta:
def test_sshapes(self) -> None:
assert (
Expand Down Expand Up @@ -119,19 +98,35 @@ def test_ivoid(
fragment=fragment,
)

def test_from_visibility(self, minimal_visibility: Visibility) -> None:
telescope = Telescope.constructor("ASKAP", backend=SimulatorBackend.OSKAR)
observation = Observation( # settings from notebook, of `minimal_visibility`
start_frequency_hz=100e6,
start_date_and_time=datetime(2024, 3, 15, 10, 46, 0),
phase_centre_ra_deg=250.0,
phase_centre_dec_deg=-80.0,
number_of_channels=16,
frequency_increment_hz=1e6,
number_of_time_steps=24,
)
@pytest.mark.parametrize(
"vis_fixture_name",
[
"minimal_oskar_vis",
"minimal_casa_ms",
],
)
def test_from_visibility(
self,
vis_fixture_name: str,
request: FixtureRequest,
) -> None:
visibility: Visibility = request.getfixturevalue(vis_fixture_name)
if visibility.format == "OSKAR_VIS":
telescope = Telescope.constructor("ASKAP", backend=SimulatorBackend.OSKAR)
observation = Observation( # original settings for `minimal_oskar_vis`
start_frequency_hz=100e6,
start_date_and_time=datetime(2024, 3, 15, 10, 46, 0),
phase_centre_ra_deg=250.0,
phase_centre_dec_deg=-80.0,
number_of_channels=16,
frequency_increment_hz=1e6,
number_of_time_steps=24,
)
else:
telescope = None
observation = None
ocm = ObsCoreMeta.from_visibility(
vis=minimal_visibility,
vis=visibility,
calibrated=False,
tel=telescope,
obs=observation,
Expand All @@ -153,9 +148,13 @@ def test_from_visibility(self, minimal_visibility: Visibility) -> None:
assert ocm.em_ucd is not None
assert ocm.o_ucd is not None
assert ocm.calib_level == 1 # because `calibrated` flag set to False
assert ocm.instrument_name == telescope.name
assert ocm.instrument_name is not None
assert ocm.s_resolution is not None and ocm.s_resolution > 0.0

if visibility.format == "MS":
assert ocm.pol_xel is not None and ocm.pol_xel > 0
assert ocm.pol_states is not None and len(ocm.pol_states) > 0

with tempfile.TemporaryDirectory() as tmpdir:
meta_path = os.path.join(tmpdir, "obscore-vis.json")
with pytest.warns(UserWarning): # mandatory fields not set
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ profile = black
[flake8]
max-line-length = 88
doctests = True
ignore = E203, W503
ignore = E203, W503, E704
exclude = .git, .eggs, __pycache__, tests/, docs/, build/, dist/

[mypy]
Expand Down

0 comments on commit 5e368e6

Please sign in to comment.