Skip to content

Commit

Permalink
Use a subset of pixels when creating a margin cache. (#323)
Browse files Browse the repository at this point in the history
* Resume functionality for margin cache pipeline.

* Use a subset of pixels when creating a margin cache.

* Add healpix pixel list to provenance info.

* Formatting

* Bad merge addition.
  • Loading branch information
delucchi-cmu authored May 31, 2024
1 parent 53b4b0f commit 277d801
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies = [
"dask[complete]>=2024.3.0", # Includes dask expressions.
"deprecated",
"healpy",
"hipscat >=0.3.0",
"hipscat >=0.3.4",
"ipykernel", # Support for Jupyter notebooks
"numpy",
"pandas",
Expand Down
15 changes: 12 additions & 3 deletions src/hipscat_import/margin_cache/margin_cache_arguments.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import warnings
from dataclasses import dataclass
from typing import Any, Dict, Union
from dataclasses import dataclass, field
from typing import Any, Dict, List, Union

import healpy as hp
from hipscat.catalog import Catalog
from hipscat.catalog.margin_cache.margin_cache_catalog_info import MarginCacheCatalogInfo
from hipscat.io.validation import is_valid_catalog
from hipscat.pixel_math.healpix_pixel import HealpixPixel

from hipscat_import.runtime_arguments import RuntimeArguments

Expand Down Expand Up @@ -35,6 +36,9 @@ class MarginCacheArguments(RuntimeArguments):
"""the path to the hipscat-formatted input catalog."""
input_storage_options: Union[Dict[Any, Any], None] = None
"""optional dictionary of abstract filesystem credentials for the INPUT."""
debug_filter_pixel_list: List[HealpixPixel] = field(default_factory=list)
"""debug setting. if provided, we will first filter the catalog to the pixels
provided. this can be useful for creating a margin over a subset of a catalog."""

def __post_init__(self):
self._check_arguments()
Expand All @@ -49,8 +53,12 @@ def _check_arguments(self):
self.catalog = Catalog.read_from_hipscat(
self.input_catalog_path, storage_options=self.input_storage_options
)
if len(self.debug_filter_pixel_list) > 0:
self.catalog = self.catalog.filter_from_pixel_list(self.debug_filter_pixel_list)
if len(self.catalog.get_healpix_pixels()) == 0:
raise ValueError("debug_filter_pixel_list has created empty catalog")

highest_order = self.catalog.partition_info.get_highest_order()
highest_order = int(self.catalog.partition_info.get_highest_order())
margin_pixel_k = highest_order + 1
if self.margin_order > -1:
if self.margin_order < margin_pixel_k:
Expand Down Expand Up @@ -85,4 +93,5 @@ def additional_runtime_provenance_info(self) -> dict:
"input_catalog_path": self.input_catalog_path,
"margin_threshold": self.margin_threshold,
"margin_order": self.margin_order,
"debug_filter_pixel_list": self.debug_filter_pixel_list,
}
45 changes: 43 additions & 2 deletions tests/hipscat_import/margin_cache/test_arguments_margin_cache.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Tests of margin cache generation arguments"""

import pytest
from hipscat.io import write_metadata
from hipscat.pixel_math.healpix_pixel import HealpixPixel

from hipscat_import.margin_cache.margin_cache_arguments import MarginCacheArguments

# pylint: disable=protected-access


def test_empty_required(tmp_path):
"""*Most* required arguments are provided."""
Expand Down Expand Up @@ -64,6 +64,42 @@ def test_margin_order_invalid(small_sky_source_catalog, tmp_path):
)


def test_debug_filter_pixel_list(small_sky_source_catalog, tmp_path):
    """Check catalog filtering driven by ``debug_filter_pixel_list``.

    Two filters are expected to leave a non-empty catalog (13 and 4
    pixels respectively); a third is expected to leave nothing, in which
    case argument construction must raise a ``ValueError`` that mentions
    ``debug_filter_pixel_list``.
    """
    # Every construction below shares these settings; only the pixel filter varies.
    shared_kwargs = {
        "margin_threshold": 5.0,
        "input_catalog_path": small_sky_source_catalog,
        "output_path": tmp_path,
        "output_artifact_name": "catalog_cache",
        "margin_order": 4,
    }

    # Filtering on an order-0 pixel.
    order_zero_args = MarginCacheArguments(
        **shared_kwargs,
        debug_filter_pixel_list=[HealpixPixel(0, 11)],
    )
    remaining_pixels = order_zero_args.catalog.get_healpix_pixels()
    assert len(remaining_pixels) == 13

    # Filtering on an order-1 pixel.
    order_one_args = MarginCacheArguments(
        **shared_kwargs,
        debug_filter_pixel_list=[HealpixPixel(1, 44)],
    )
    remaining_pixels = order_one_args.catalog.get_healpix_pixels()
    assert len(remaining_pixels) == 4

    # A filter that leaves no pixels must be rejected at construction time.
    with pytest.raises(ValueError, match="debug_filter_pixel_list"):
        MarginCacheArguments(
            **shared_kwargs,
            debug_filter_pixel_list=[HealpixPixel(0, 5)],
        )


def test_margin_threshold_warns(small_sky_source_catalog, tmp_path):
"""Ensure we give a warning when margin_threshold is greater than margin_order resolution"""

Expand Down Expand Up @@ -99,7 +135,12 @@ def test_provenance_info(small_sky_source_catalog, tmp_path):
output_path=tmp_path,
output_artifact_name="catalog_cache",
margin_order=4,
debug_filter_pixel_list=[HealpixPixel(1, 44)],
)

runtime_args = args.provenance_info()["runtime_args"]
assert "margin_threshold" in runtime_args

write_metadata.write_provenance_info(
catalog_base_dir=args.catalog_path, dataset_info=args.to_catalog_info(20_000), tool_args=runtime_args
)

0 comments on commit 277d801

Please sign in to comment.