(feat): feature set for backed views of views #4

Status: Open. Wants to merge 3 commits into base branch zarr-sparse-array.
Showing changes from all commits.
4 changes: 3 additions & 1 deletion anndata/_core/anndata.py
@@ -43,7 +43,7 @@
as_view,
_resolve_idxs,
)
from .sparse_dataset import sparse_dataset
from .sparse_dataset import sparse_dataset, BaseCompressedSparseDataset
from .. import utils
from ..utils import convert_to_dict, ensure_df_homogeneous, dim_len
from ..logging import anndata_logger as logger
@@ -609,6 +609,8 @@ def X(self) -> Optional[Union[np.ndarray, sparse.spmatrix, ArrayView]]:
# indices that aren’t strictly increasing
if self.is_view:
X = _subset(X, (self._oidx, self._vidx))
if isinstance(X, BaseCompressedSparseDataset):
X = X.to_memory()
elif self.is_view and self._adata_ref.X is None:
X = None
elif self.is_view:
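With this change, subsetting a view's X can hand back a lazy BaseCompressedSparseDataset, which the property then materializes before returning it to the caller. A minimal sketch of that control flow, using a hypothetical toy stand-in rather than the real classes:

```python
import numpy as np
import scipy.sparse as ss

class ToyBackedDataset:
    """Toy stand-in for BaseCompressedSparseDataset (hypothetical, not anndata's API)."""

    def __init__(self, mtx: ss.spmatrix, rows=slice(None)):
        self._mtx, self._rows = mtx, rows

    def __getitem__(self, idx):
        # subsetting stays lazy: return another dataset carrying the indexer
        return ToyBackedDataset(self._mtx, idx[0])

    def to_memory(self) -> ss.spmatrix:
        return self._mtx[self._rows, :]

def view_X(X, oidx, vidx):
    X = X[oidx, vidx]                    # _subset may now return a lazy dataset
    if isinstance(X, ToyBackedDataset):  # the new branch in the .X property
        X = X.to_memory()
    return X

X = ToyBackedDataset(ss.random(10, 5, density=0.5, format="csr"))
assert view_X(X, np.array([0, 2, 4]), slice(None)).shape == (3, 5)
```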
179 changes: 159 additions & 20 deletions anndata/_core/sparse_dataset.py
@@ -13,15 +13,19 @@
from abc import ABC
import collections.abc as cabc
from itertools import accumulate, chain
from pathlib import Path
from typing import Union, NamedTuple, Tuple, Sequence, Iterable, Type
from warnings import warn

import h5py
import zarr
import numpy as np
import scipy.sparse as ss
from scipy.sparse import _sparsetools

from ..compat import _read_attr
from anndata._core.views import _resolve_idx, as_view

from ..compat import _read_attr, ZarrArray

try:
# Not really important, just for IDEs to be more helpful
@@ -46,9 +50,17 @@ class BackedSparseMatrix(_cs_matrix):
since that calls copy on `.data`, `.indices`, and `.indptr`.
"""

_cached_indptr = None

def copy(self) -> ss.spmatrix:
if isinstance(self.data, h5py.Dataset):
return sparse_dataset(self.data.parent).to_memory()
if isinstance(self.data, ZarrArray):
return sparse_dataset(
zarr.open(
store=self.data.store, path=Path(self.data.path).parent, mode="r"
)
).to_memory()
else:
return super().copy()
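In the zarr branch of copy(), the parent group of the data array is reopened from the same store and materialized. A small sketch of that parent lookup (zarr-python 2.x assumed; note zarr.open expects a string path, so the Path from the diff is converted explicitly here):

```python
from pathlib import Path

import zarr

store = zarr.storage.MemoryStore()
root = zarr.group(store=store)
root.create_group("X").zeros("data", shape=(5,))

arr = root["X/data"]
# reopen the enclosing group of `arr`, read-only, as copy() does above
parent = zarr.open(store=arr.store, path=str(Path(arr.path).parent), mode="r")
assert list(parent.array_keys()) == ["data"]
```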

@@ -116,6 +128,17 @@ def _offsets(
)
return offsets

@property
def indptr(self):
if self._cached_indptr is None:
self._cached_indptr = self._indptr[:]
return self._cached_indptr

@indptr.setter
def indptr(self, indptr):
self._indptr = indptr
self._cached_indptr = None
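Caching indptr pays off because every row (or column) lookup on a compressed matrix consults indptr; without the cache, each lookup becomes a separate read from the store. A rough before/after using an in-memory HDF5 file:

```python
import h5py
import numpy as np

f = h5py.File("demo.h5", "w", driver="core", backing_store=False)  # in-memory HDF5
indptr = f.create_dataset("indptr", data=np.arange(0, 1001, 10))

starts_uncached = [int(indptr[i]) for i in range(100)]  # one h5py read per access
cached = indptr[:]                                      # single bulk read
starts_cached = [int(cached[i]) for i in range(100)]    # plain numpy from here on
assert starts_uncached == starts_cached
```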


class backed_csr_matrix(BackedSparseMatrix, ss.csr_matrix):
def _get_intXslice(self, row: int, col: slice) -> ss.csr_matrix:
@@ -243,9 +266,66 @@ class BaseCompressedSparseDataset(ABC):
Analogous to :class:`h5py.Dataset <h5py:Dataset>` or `zarr.Array`, but for sparse matrices.
"""

def __init__(self, group: h5py.Group):
def __init__(self, group: Union[h5py.Group, zarr.Group]):
type(self)._check_group_format(group)
self.group = group
self._row_subset_idx = slice(None, None, None)
self._col_subset_idx = slice(None, None, None)

@property
def row_subset_idx(self):
"""cached row subset indexer"""
if isinstance(self._row_subset_idx, np.ndarray):
return self._row_subset_idx.flatten()  # TODO: why is flattening needed here?
return self._row_subset_idx

@property
def has_no_subset_idx(self) -> bool:
"""whether or not a subset indexer is on the object"""
return self.has_no_col_subset_idx and self.has_no_row_subset_idx

@property
def has_no_col_subset_idx(self) -> bool:
"""whether or not a column subset indexer is on the object"""
if isinstance(self.col_subset_idx, slice):
if self.col_subset_idx == slice(
None, None, None
) or self.col_subset_idx == slice(0, self.get_backing_shape()[1], 1):
return True
return False

@property
def has_no_row_subset_idx(self) -> bool:
"""whether or not a row subset indexer is on the object"""
if isinstance(self.row_subset_idx, slice):
if self.row_subset_idx == slice(
None, None, None
) or self.row_subset_idx == slice(0, self.get_backing_shape()[0], 1):
return True
return False

@row_subset_idx.setter
def row_subset_idx(self, new_idx):
self._row_subset_idx = (
new_idx
if self.row_subset_idx is None
else _resolve_idx(self.row_subset_idx, new_idx, self.shape[0])
)

@property
def col_subset_idx(self):
"""cached column subset indexer"""
if isinstance(self._col_subset_idx, np.ndarray):
return self._col_subset_idx.flatten()
return self._col_subset_idx

@col_subset_idx.setter
def col_subset_idx(self, new_idx):
self._col_subset_idx = (
new_idx
if self.col_subset_idx is None
else _resolve_idx(self.col_subset_idx, new_idx, self.shape[1])
)

@property
def dtype(self) -> np.dtype:
@@ -260,14 +340,46 @@ def _check_group_format(cls, group):
def name(self) -> str:
return self.group.name

@property
def shape(self) -> Tuple[int, int]:
def get_backing_shape(self) -> Tuple[int, int]:
"""Generates the shape of the underlying data store i.e., with no indexers.

Returns:
Tuple[int, int]: shape
"""
shape = _read_attr(self.group.attrs, "shape", None)
if shape is None:
# TODO warn
shape = self.group.attrs.get("h5sparse_shape")
return tuple(shape)

@property
def shape(self) -> Tuple[int, int]:
"""Generates the true shape of the object i.e., the shape of the `to_memory` operation, including indexers.

Returns:
Tuple[int, int]: shape
"""
shape = self.get_backing_shape()
if self.has_no_subset_idx:
return tuple(shape)
row_length = 0
col_length = 0
if isinstance(self.row_subset_idx, slice):
if self.row_subset_idx == slice(None, None, None):
row_length = shape[0]
else:
row_length = self.row_subset_idx.stop - self.row_subset_idx.start
else:
row_length = len(self.row_subset_idx)  # assumes a 1-D integer-array indexer
if isinstance(self.col_subset_idx, slice):
if self.col_subset_idx == slice(None, None, None):
col_length = shape[1]
else:
col_length = self.col_subset_idx.stop - self.col_subset_idx.start
else:
col_length = len(self.col_subset_idx)  # assumes a 1-D integer-array indexer
return (row_length, col_length)
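The per-axis arithmetic above reduces to: a no-op slice keeps the backing length, a concrete slice contributes stop minus start, and an integer-array indexer contributes its length. Condensed into one hypothetical helper (assuming, as the code does, normalized step-1 slices):

```python
import numpy as np

def idx_length(idx, dim_len: int) -> int:
    """Mirror of the per-axis branches in `shape` (hypothetical helper)."""
    if isinstance(idx, slice):
        if idx == slice(None):       # no subset on this axis
            return dim_len
        return idx.stop - idx.start  # assumes normalized start/stop with step 1
    return len(idx)                  # integer-array indexer

assert idx_length(slice(None), 7) == 7
assert idx_length(slice(2, 5, 1), 7) == 3
assert idx_length(np.array([0, 4, 6]), 7) == 3
```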

@property
def value(self) -> ss.spmatrix:
return self.to_memory()
@@ -281,14 +393,12 @@ def __repr__(self) -> str:

def __getitem__(self, index: Union[Index, Tuple[()]]) -> Union[float, ss.spmatrix]:
row, col = self._normalize_index(index)
mtx = self.to_backed()
sub = mtx[row, col]
# If indexing is array x array it returns a backed_sparse_matrix
# Not sure what the performance is on that operation
if isinstance(sub, BackedSparseMatrix):
return get_memory_class(self.format_str)(sub)
else:
return sub
new_mtx = sparse_dataset(self.group)
new_mtx.row_subset_idx = self.row_subset_idx
new_mtx.row_subset_idx = row
new_mtx.col_subset_idx = self.col_subset_idx
new_mtx.col_subset_idx = col
return new_mtx
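The doubled assignments are deliberate: each setter call runs through _resolve_idx, so assigning the parent's indexer and then the new one composes them into a single index into the backing store. For integer-array indices the composition is plain fancy-index chaining, as this sketch shows (resolve is a hypothetical stand-in for _resolve_idx):

```python
import numpy as np

def resolve(old_idx: np.ndarray, new_idx: np.ndarray) -> np.ndarray:
    # stand-in for _resolve_idx in the integer-array case
    return old_idx[new_idx]

data = np.arange(10) * 10
first = np.array([1, 3, 5, 7, 9])  # parent view's rows
second = np.array([0, 2])          # indexing that view again
combined = resolve(first, second)  # -> [1, 5]: one index into the original
assert np.array_equal(data[combined], data[first][second])
```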

def _normalize_index(
self, index: Union[Index, Tuple[()]]
@@ -369,20 +479,44 @@ def append(self, sparse_matrix: ss.spmatrix):
indices[orig_data_size:] = sparse_matrix.indices

def to_backed(self) -> BackedSparseMatrix:
"""Generates a `BackedSparseMatrix` whose data arrays are zarr/hdf5 arrays.

Returns:
BackedSparseMatrix: backed data object
"""
format_class = get_backed_class(self.format_str)
mtx = format_class(self.shape, dtype=self.dtype)
mtx = format_class(self.get_backing_shape(), dtype=self.dtype)
mtx.data = self.group["data"]
mtx.indices = self.group["indices"]
mtx.indptr = self.group["indptr"][:]
mtx.indptr = self.group["indptr"]
return mtx

def to_memory(self) -> ss.spmatrix:
format_class = get_memory_class(self.format_str)
mtx = format_class(self.shape, dtype=self.dtype)
mtx.data = self.group["data"][...]
mtx.indices = self.group["indices"][...]
mtx.indptr = self.group["indptr"][...]
return mtx
"""Applies indexers to the `BackedSparseMatrix` to return data.

Returns:
ss.spmatrix: data matrix
"""

# Row indexing on CSC (and column indexing on CSR) could not be made to work
# without reading into memory first, but that shouldn't matter here.
if (self.format_str == "csr" and self.has_no_row_subset_idx) or (
self.format_str == "csc" and self.has_no_col_subset_idx
):
format_class = get_memory_class(self.format_str)
mtx = format_class(self.get_backing_shape(), dtype=self.dtype)
mtx.data = self.group["data"][...]
mtx.indices = self.group["indices"][...]
mtx.indptr = self.group["indptr"][...]
if self.has_no_subset_idx:
return mtx
else:
mtx = self.to_backed()
if self.format_str == "csr":
return mtx[self.row_subset_idx, :][:, self.col_subset_idx]
return mtx[:, self.col_subset_idx][self.row_subset_idx, :]

def toarray(self) -> np.ndarray:
return self.to_memory().toarray()
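The fast path reads all three arrays when there is no subset along the compressed axis; otherwise the backed matrix subsets lazily. For CSR, a row subset only needs the data/indices ranges that indptr points at, which is why it avoids a full read. An illustration with an in-memory matrix:

```python
import numpy as np
import scipy.sparse as ss

m = ss.random(8, 5, density=0.4, format="csr")
rows = np.array([1, 4, 6])

# per-row reads: only the slices of `data` delimited by `indptr` are touched
parts = [m.data[m.indptr[r]:m.indptr[r + 1]] for r in rows]
assert np.array_equal(np.concatenate(parts), m[rows, :].data)
```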


class CSRDataset(BaseCompressedSparseDataset):
Expand All @@ -405,3 +539,8 @@ def sparse_dataset(group) -> BaseCompressedSparseDataset:
@_subset.register(BaseCompressedSparseDataset)
def subset_sparsedataset(d, subset_idx):
return d[subset_idx]


@as_view.register(BaseCompressedSparseDataset)
def _view_masked(a: BaseCompressedSparseDataset, view_args):
return a
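Taken together, this file turns sparse_dataset outputs into composable lazy views. An end-to-end sketch, assuming this branch is installed (module paths and names as in the diff):

```python
import anndata as ad
import h5py
import numpy as np
import scipy.sparse as ss
from anndata._core.sparse_dataset import sparse_dataset

ad.AnnData(X=ss.random(20, 10, density=0.3, format="csr")).write_h5ad("demo.h5ad")

with h5py.File("demo.h5ad", "r") as f:
    ds = sparse_dataset(f["X"])         # lazy CSRDataset over the group
    v1 = ds[np.array([0, 2, 4, 6]), :]  # still lazy: stores a row indexer
    v2 = v1[np.array([1, 3]), :]        # view of a view: indexers composed
    assert v2.shape == (2, 10)
    mtx = v2.to_memory()                # rows 2 and 6 are read only here
    assert mtx.shape == (2, 10)
```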
4 changes: 2 additions & 2 deletions anndata/_io/specs/methods.py
@@ -513,7 +513,7 @@ def write_sparse_dataset(f, k, elem, _writer, dataset_kwargs=MappingProxyType({})):
write_sparse_compressed(
f,
k,
elem.to_backed(),
elem.to_memory(),  # if the elem carries a subset, to_memory reads only that subset from disk
_writer,
fmt=elem.format_str,
dataset_kwargs=dataset_kwargs,
@@ -536,7 +536,7 @@ def read_sparse(elem, _reader):
@_REGISTRY.register_read_partial(ZarrGroup, IOSpec("csc_matrix", "0.1.0"))
@_REGISTRY.register_read_partial(ZarrGroup, IOSpec("csr_matrix", "0.1.0"))
def read_sparse_partial(elem, *, items=None, indices=(slice(None), slice(None))):
return sparse_dataset(elem)[indices]
return sparse_dataset(elem)[indices].to_memory()


#################
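On the write side, a subsetted dataset is materialized with to_memory() (which reads only the subset) before being written like any sparse matrix. A hedged sketch of that round trip, assuming this branch plus anndata.experimental.write_elem:

```python
import anndata as ad
import h5py
import numpy as np
import scipy.sparse as ss
from anndata._core.sparse_dataset import sparse_dataset
from anndata.experimental import write_elem

ad.AnnData(X=ss.random(10, 4, density=0.5, format="csr")).write_h5ad("src.h5ad")

with h5py.File("src.h5ad", "r+") as f:
    sub = sparse_dataset(f["X"])[np.array([0, 2]), :]  # lazy two-row view
    write_elem(f, "X_sub", sub)  # should route through write_sparse_dataset
    assert sparse_dataset(f["X_sub"]).shape == (2, 4)
```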
11 changes: 6 additions & 5 deletions anndata/experimental/merge.py
@@ -73,23 +73,24 @@ def _gen_slice_to_append(
fill_value=None,
):
for ds, ri in zip(datasets, reindexers):
n_slices = ds.shape[axis] * ds.shape[1 - axis] // max_loaded_elems
backed = ds.to_backed()  # unlike ds, the backed object returns data immediately when sliced; needed for the else branch
n_slices = backed.shape[axis] * backed.shape[1 - axis] // max_loaded_elems
if n_slices < 2:
yield (csr_matrix, csc_matrix)[axis](
ri(to_memory(ds), axis=1 - axis, fill_value=fill_value)
)
else:
slice_size = max_loaded_elems // ds.shape[1 - axis]
slice_size = max_loaded_elems // backed.shape[1 - axis]
if slice_size == 0:
slice_size = 1
rem_slices = ds.shape[axis]
rem_slices = backed.shape[axis]
idx = 0
while rem_slices > 0:
ds_part = None
if axis == 0:
ds_part = ds[idx : idx + slice_size, :]
ds_part = backed[idx : idx + slice_size, :]
elif axis == 1:
ds_part = ds[:, idx : idx + slice_size]
ds_part = backed[:, idx : idx + slice_size]

yield (csr_matrix, csc_matrix)[axis](
ri(ds_part, axis=1 - axis, fill_value=fill_value)
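The generator above bounds memory by slicing the backed matrix into chunks of at most max_loaded_elems elements along the append axis. Its chunking logic, restated as a self-contained sketch over an in-memory matrix:

```python
from typing import Iterator

import numpy as np
import scipy.sparse as ss

def gen_chunks(mtx: ss.spmatrix, axis: int, max_loaded_elems: int) -> Iterator[ss.spmatrix]:
    """Yield slices along `axis` holding roughly max_loaded_elems elements at most."""
    slice_size = max(max_loaded_elems // mtx.shape[1 - axis], 1)
    for start in range(0, mtx.shape[axis], slice_size):
        sl = slice(start, start + slice_size)
        yield mtx[sl, :] if axis == 0 else mtx[:, sl]

m = ss.random(100, 50, density=0.1, format="csr")
parts = list(gen_chunks(m, axis=0, max_loaded_elems=500))  # 10 rows per chunk
assert sum(p.shape[0] for p in parts) == 100
assert (ss.vstack(parts) != m).nnz == 0  # chunks reassemble the original
```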