Skip to content

Commit

Permalink
Support IO
Browse files Browse the repository at this point in the history
  • Loading branch information
ivirshup committed Jul 26, 2023
1 parent 5d69e2d commit bac0111
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 29 deletions.
80 changes: 54 additions & 26 deletions anndata/_io/specs/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from os import PathLike
from collections.abc import Mapping
from functools import partial
from itertools import product
from functools import partial, wraps
from typing import Union, Literal
from types import MappingProxyType
from warnings import warn
Expand All @@ -28,7 +29,7 @@
)
from anndata._io.utils import report_write_key_on_error, check_key, H5PY_V3
from anndata._warnings import OldFormatWarning
from anndata.compat import AwkArray
from anndata.compat import AwkArray, CupyArray, CupyCSRMatrix, CupyCSCMatrix

from .registry import (
_REGISTRY,
Expand Down Expand Up @@ -68,6 +69,26 @@
# return False


def _to_cpu_mem_wrapper(write_func):
    """
    Wrapper to bring cupy types into cpu memory before writing.

    The returned function calls ``.get()`` on the value it receives and hands
    the result to ``write_func`` in its place; every other argument is passed
    through unchanged.

    Ideally we do direct writing at some point.

    Parameters
    ----------
    write_func
        Writer called as
        ``write_func(f, k, elem, _writer, dataset_kwargs=...)``.

    Returns
    -------
    A function taking ``(f, k, cupy_val, _writer, *, dataset_kwargs)`` that
    returns whatever ``write_func`` returns.
    """

    def wrapper(
        f,
        k,
        cupy_val: CupyArray | CupyCSCMatrix | CupyCSRMatrix,
        _writer,
        *,
        # Must be an empty read-only mapping, not the MappingProxyType class:
        # writers such as ``write_basic`` unpack it with ``**dataset_kwargs``.
        dataset_kwargs=MappingProxyType({}),
    ):
        return write_func(f, k, cupy_val.get(), _writer, dataset_kwargs=dataset_kwargs)

    return wrapper


################################
# Fallbacks / backwards compat #
################################
Expand Down Expand Up @@ -308,6 +329,14 @@ def write_basic(f, k, elem, _writer, dataset_kwargs=MappingProxyType({})):
f.create_dataset(k, data=elem, **dataset_kwargs)


# Register the dense cupy writer once per store type. Each registration gets
# its own wrapper, which calls ``.get()`` on the value before ``write_basic``
# sees it.
for store_type in (H5Group, ZarrGroup):
    _REGISTRY.register_write(store_type, CupyArray, IOSpec("array", "0.2.0"))(
        _to_cpu_mem_wrapper(write_basic)
    )


@_REGISTRY.register_write(ZarrGroup, DaskArray, IOSpec("array", "0.2.0"))
@_REGISTRY.register_write(H5Group, DaskArray, IOSpec("array", "0.2.0"))
def write_basic_dask(f, k, elem, _writer, dataset_kwargs=MappingProxyType({})):
Expand Down Expand Up @@ -451,30 +480,29 @@ def write_sparse_compressed(

# One implementation serves both compressed layouts; ``partial`` pins the
# ``fmt`` keyword so each name can be registered as a ready-made writer.
write_csr = partial(write_sparse_compressed, fmt="csr")
write_csc = partial(write_sparse_compressed, fmt="csc")
# Explicit registration, one call per (store type, class) pair: the scipy
# sparse class and the matching class from ``views`` are each registered for
# H5Group and ZarrGroup with the "csr_matrix" / "csc_matrix" spec at "0.1.0".
# NOTE(review): the hunk header above (-451,30 +480,29) suggests these eight
# calls are the removed side of the diff, superseded by the ``product`` loop
# that follows -- confirm against the repository before relying on them.
_REGISTRY.register_write(H5Group, sparse.csr_matrix, IOSpec("csr_matrix", "0.1.0"))(
write_csr
)
_REGISTRY.register_write(H5Group, views.SparseCSRView, IOSpec("csr_matrix", "0.1.0"))(
write_csr
)
_REGISTRY.register_write(H5Group, sparse.csc_matrix, IOSpec("csc_matrix", "0.1.0"))(
write_csc
)
_REGISTRY.register_write(H5Group, views.SparseCSCView, IOSpec("csc_matrix", "0.1.0"))(
write_csc
)
_REGISTRY.register_write(ZarrGroup, sparse.csr_matrix, IOSpec("csr_matrix", "0.1.0"))(
write_csr
)
_REGISTRY.register_write(ZarrGroup, views.SparseCSRView, IOSpec("csr_matrix", "0.1.0"))(
write_csr
)
_REGISTRY.register_write(ZarrGroup, sparse.csc_matrix, IOSpec("csc_matrix", "0.1.0"))(
write_csc
)
_REGISTRY.register_write(ZarrGroup, views.SparseCSCView, IOSpec("csc_matrix", "0.1.0"))(
write_csc
)

# Register every sparse writer against both store types in a single pass.
# ``product`` pairs each store type with each (class, spec, writer) entry.
# The entry list is evaluated once, so a given spec and wrapper object is
# shared by the H5Group and ZarrGroup registrations.
for store_type, (cls, spec, func) in product(
(H5Group, ZarrGroup),
[
# scipy sparse matrices and the corresponding classes from ``views``.
(sparse.csr_matrix, IOSpec("csr_matrix", "0.1.0"), write_csr),
(views.SparseCSRView, IOSpec("csr_matrix", "0.1.0"), write_csr),
(sparse.csc_matrix, IOSpec("csc_matrix", "0.1.0"), write_csc),
(views.SparseCSCView, IOSpec("csc_matrix", "0.1.0"), write_csc),
# Cupy counterparts use the same specs; the wrapper calls ``.get()`` on
# the value before handing it to the CPU writer.
(CupyCSRMatrix, IOSpec("csr_matrix", "0.1.0"), _to_cpu_mem_wrapper(write_csr)),
(
views.CupySparseCSRView,
IOSpec("csr_matrix", "0.1.0"),
_to_cpu_mem_wrapper(write_csr),
),
(CupyCSCMatrix, IOSpec("csc_matrix", "0.1.0"), _to_cpu_mem_wrapper(write_csc)),
(
views.CupySparseCSCView,
IOSpec("csc_matrix", "0.1.0"),
_to_cpu_mem_wrapper(write_csc),
),
],
):
_REGISTRY.register_write(store_type, cls, spec)(func)


@_REGISTRY.register_write(H5Group, SparseDataset, IOSpec("", "0.1.0"))
Expand Down
4 changes: 2 additions & 2 deletions anndata/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,9 +639,9 @@ def as_cupy_type(val, typ=None):
import cupy as cp

if isinstance(val, np.ndarray):
return cpsparse.csr_matrix(cp.array(val))
return cpsparse.csc_matrix(cp.array(val))
else:
return cpsparse.csr_matrix(val)
return cpsparse.csc_matrix(val)
else:
raise NotImplementedError(
f"Conversion from {type(val)} to {typ} not implemented"
Expand Down
29 changes: 28 additions & 1 deletion anndata/tests/test_io_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from anndata._io.utils import AnnDataReadError
from anndata.compat import _read_attr, H5Group, ZarrGroup
from anndata._io.specs import write_elem, read_elem
from anndata.tests.helpers import assert_equal, gen_adata
from anndata.tests.helpers import assert_equal, gen_adata, as_cupy_type


@pytest.fixture(params=["h5ad", "zarr"])
Expand Down Expand Up @@ -94,6 +94,33 @@ def test_io_spec(store, value, encoding_type):
assert get_spec(store[key]) == _REGISTRY.get_spec(value)


# Can't instantiate cupy types at the top level, so converting them within the test
@pytest.mark.gpu
@pytest.mark.parametrize(
    "value,encoding_type",
    [
        (np.array([1, 2, 3]), "array"),
        (np.arange(12).reshape(4, 3), "array"),
        (sparse.random(5, 3, format="csr", density=0.5), "csr_matrix"),
        (sparse.random(5, 3, format="csc", density=0.5), "csc_matrix"),
    ],
)
def test_io_spec_cupy(store, value, encoding_type):
    """Check that cupy values can be written and read back element-wise.

    The CPU ``value`` is converted with ``as_cupy_type`` and written to
    ``store`` under a key derived from ``encoding_type``. The test then checks
    that:

    * the stored "encoding-type" attribute equals ``encoding_type``,
    * the element read back (and converted with ``as_cupy_type`` again)
      compares equal to the value that was written,
    * the spec found on disk equals the spec the registry reports for the
      value.
    """
    key = f"key_for_{encoding_type}"
    value = as_cupy_type(value)

    write_elem(store, key, value, dataset_kwargs={})

    assert encoding_type == _read_attr(store[key].attrs, "encoding-type")

    from_disk = as_cupy_type(read_elem(store[key]))
    assert_equal(value, from_disk)
    assert get_spec(store[key]) == _REGISTRY.get_spec(value)


def test_io_spec_raw(store):
adata = gen_adata((3, 2))
adata.raw = adata
Expand Down

0 comments on commit bac0111

Please sign in to comment.