diff --git a/anndata/_core/anndata.py b/anndata/_core/anndata.py index b74659c21..83a1b158c 100644 --- a/anndata/_core/anndata.py +++ b/anndata/_core/anndata.py @@ -43,7 +43,7 @@ as_view, _resolve_idxs, ) -from .sparse_dataset import sparse_dataset +from .sparse_dataset import sparse_dataset, BaseCompressedSparseDataset from .. import utils from ..utils import convert_to_dict, ensure_df_homogeneous, dim_len from ..logging import anndata_logger as logger @@ -609,6 +609,8 @@ def X(self) -> Optional[Union[np.ndarray, sparse.spmatrix, ArrayView]]: # indices that aren’t strictly increasing if self.is_view: X = _subset(X, (self._oidx, self._vidx)) + if isinstance(X, BaseCompressedSparseDataset): + X = X.to_memory() elif self.is_view and self._adata_ref.X is None: X = None elif self.is_view: diff --git a/anndata/_core/sparse_dataset.py b/anndata/_core/sparse_dataset.py index e9b5e4824..681005814 100644 --- a/anndata/_core/sparse_dataset.py +++ b/anndata/_core/sparse_dataset.py @@ -13,15 +13,19 @@ from abc import ABC import collections.abc as cabc from itertools import accumulate, chain +from pathlib import Path from typing import Union, NamedTuple, Tuple, Sequence, Iterable, Type from warnings import warn import h5py +import zarr import numpy as np import scipy.sparse as ss from scipy.sparse import _sparsetools -from ..compat import _read_attr +from anndata._core.views import _resolve_idx, as_view + +from ..compat import _read_attr, ZarrArray try: # Not really important, just for IDEs to be more helpful @@ -46,9 +50,17 @@ class BackedSparseMatrix(_cs_matrix): since that calls copy on `.data`, `.indices`, and `.indptr`. """ + _cached_indptr = None + def copy(self) -> ss.spmatrix: if isinstance(self.data, h5py.Dataset): return sparse_dataset(self.data.parent).to_memory() + if isinstance(self.data, ZarrArray): + return sparse_dataset( + zarr.open( + store=self.data.store, path=Path(self.data.path).parent, mode="r" + ) + ).to_memory() else: return super().copy() @@ -116,6 +128,17 @@ def _offsets( ) return offsets + @property + def indptr(self): + if self._cached_indptr is None: + self._cached_indptr = self._indptr[:] + return self._cached_indptr + + @indptr.setter + def indptr(self, indptr): + self._indptr = indptr + self._cached_indptr = None + class backed_csr_matrix(BackedSparseMatrix, ss.csr_matrix): def _get_intXslice(self, row: int, col: slice) -> ss.csr_matrix: @@ -243,9 +266,66 @@ class BaseCompressedSparseDataset(ABC): Analogous to :class:`h5py.Dataset ` or `zarr.Array`, but for sparse matrices. """ - def __init__(self, group: h5py.Group): + def __init__(self, group: Union[h5py.Group, zarr.Group]): type(self)._check_group_format(group) self.group = group + self._row_subset_idx = slice(None, None, None) + self._col_subset_idx = slice(None, None, None) + + @property + def row_subset_idx(self): + """cached row subset indexer""" + if isinstance(self._row_subset_idx, np.ndarray): + return self._row_subset_idx.flatten() # why???? + return self._row_subset_idx + + @property + def has_no_subset_idx(self) -> bool: + """whether or not a subset indexer is on the object""" + return self.has_no_col_subset_idx and self.has_no_row_subset_idx + + @property + def has_no_col_subset_idx(self) -> bool: + """whether or not a column subset indexer is on the object""" + if isinstance(self.col_subset_idx, slice): + if self.col_subset_idx == slice( + None, None, None + ) or self.col_subset_idx == slice(0, self.get_backing_shape()[1], 1): + return True + return False + + @property + def has_no_row_subset_idx(self) -> bool: + """whether or not a row subset indexer is on the object""" + if isinstance(self.row_subset_idx, slice): + if self.row_subset_idx == slice( + None, None, None + ) or self.row_subset_idx == slice(0, self.get_backing_shape()[0], 1): + return True + return False + + @row_subset_idx.setter + def row_subset_idx(self, new_idx): + self._row_subset_idx = ( + new_idx + if self.row_subset_idx is None + else _resolve_idx(self.row_subset_idx, new_idx, self.shape[0]) + ) + + @property + def col_subset_idx(self): + """cached column subset indexer""" + if isinstance(self._col_subset_idx, np.ndarray): + return self._col_subset_idx.flatten() + return self._col_subset_idx + + @col_subset_idx.setter + def col_subset_idx(self, new_idx): + self._col_subset_idx = ( + new_idx + if self.col_subset_idx is None + else _resolve_idx(self.col_subset_idx, new_idx, self.shape[1]) + ) @property def dtype(self) -> np.dtype: @@ -260,14 +340,46 @@ def _check_group_format(cls, group): def name(self) -> str: return self.group.name - @property - def shape(self) -> Tuple[int, int]: + def get_backing_shape(self) -> Tuple[int, int]: + """Generates the shape of the underlying data store i.e., with no indexers. + + Returns: + Tuple[int, int]: shape + """ shape = _read_attr(self.group.attrs, "shape", None) if shape is None: # TODO warn shape = self.group.attrs.get("h5sparse_shape") return tuple(shape) + @property + def shape(self) -> Tuple[int, int]: + """Generates the true shape of the object i.e., the shape of the `to_memory` operation, including indexers. + + Returns: + Tuple[int, int]: shape + """ + shape = self.get_backing_shape() + if self.has_no_subset_idx: + return tuple(shape) + row_length = 0 + col_length = 0 + if isinstance(self.row_subset_idx, slice): + if self.row_subset_idx == slice(None, None, None): + row_length = shape[0] + else: + row_length = self.row_subset_idx.stop - self.row_subset_idx.start + else: + row_length = len(self.row_subset_idx) # can we assume a flatten method? + if isinstance(self.col_subset_idx, slice): + if self.col_subset_idx == slice(None, None, None): + col_length = shape[1] + else: + col_length = self.col_subset_idx.stop - self.col_subset_idx.start + else: + col_length = len(self.col_subset_idx) # can we assume a flatten method? + return (row_length, col_length) + @property def value(self) -> ss.spmatrix: return self.to_memory() @@ -281,14 +393,12 @@ def __repr__(self) -> str: def __getitem__(self, index: Union[Index, Tuple[()]]) -> Union[float, ss.spmatrix]: row, col = self._normalize_index(index) - mtx = self.to_backed() - sub = mtx[row, col] - # If indexing is array x array it returns a backed_sparse_matrix - # Not sure what the performance is on that operation - if isinstance(sub, BackedSparseMatrix): - return get_memory_class(self.format_str)(sub) - else: - return sub + new_mtx = sparse_dataset(self.group) + new_mtx.row_subset_idx = self.row_subset_idx + new_mtx.row_subset_idx = row + new_mtx.col_subset_idx = self.col_subset_idx + new_mtx.col_subset_idx = col + return new_mtx def _normalize_index( self, index: Union[Index, Tuple[()]] @@ -369,20 +479,44 @@ def append(self, sparse_matrix: ss.spmatrix): indices[orig_data_size:] = sparse_matrix.indices def to_backed(self) -> BackedSparseMatrix: + """Generates a `BackedSparseMatrix` whose data arrays are zarr/hdf5 arrays. + + Returns: + BackedSparseMatrix: backed data object + """ format_class = get_backed_class(self.format_str) - mtx = format_class(self.shape, dtype=self.dtype) + mtx = format_class(self.get_backing_shape(), dtype=self.dtype) mtx.data = self.group["data"] mtx.indices = self.group["indices"] - mtx.indptr = self.group["indptr"][:] + mtx.indptr = self.group["indptr"] return mtx def to_memory(self) -> ss.spmatrix: - format_class = get_memory_class(self.format_str) - mtx = format_class(self.shape, dtype=self.dtype) - mtx.data = self.group["data"][...] - mtx.indices = self.group["indices"][...] - mtx.indptr = self.group["indptr"][...] - return mtx + """Applies indexers to the `BackedSparseMatrix` to return data. + + Returns: + ss.spmatrix: data matrix + """ + + # Could not get row idx with csc and vice versa working without reading into memory but shouldn't matter + if (self.format_str == "csr" and self.has_no_row_subset_idx) or ( + self.format_str == "csc" and self.has_no_col_subset_idx + ): + format_class = get_memory_class(self.format_str) + mtx = format_class(self.get_backing_shape(), dtype=self.dtype) + mtx.data = self.group["data"][...] + mtx.indices = self.group["indices"][...] + mtx.indptr = self.group["indptr"][...] + if self.has_no_subset_idx: + return mtx + else: + mtx = self.to_backed() + if self.format_str == "csr": + return mtx[self.row_subset_idx, :][:, self.col_subset_idx] + return mtx[:, self.col_subset_idx][self.row_subset_idx, :] + + def toarray(self) -> np.ndarray: + return self.to_memory().toarray() class CSRDataset(BaseCompressedSparseDataset): @@ -405,3 +539,8 @@ def sparse_dataset(group) -> BaseCompressedSparseDataset: @_subset.register(BaseCompressedSparseDataset) def subset_sparsedataset(d, subset_idx): return d[subset_idx] + + +@as_view.register(BaseCompressedSparseDataset) +def _view_masked(a: BaseCompressedSparseDataset, view_args): + return a diff --git a/anndata/_io/specs/methods.py b/anndata/_io/specs/methods.py index 02421a46e..18ac9c3b5 100644 --- a/anndata/_io/specs/methods.py +++ b/anndata/_io/specs/methods.py @@ -513,7 +513,7 @@ def write_sparse_dataset(f, k, elem, _writer, dataset_kwargs=MappingProxyType({} write_sparse_compressed( f, k, - elem.to_backed(), + elem.to_memory(), # if there is a subset on the elem, to_memory lazily reads in __only__ the subset _writer, fmt=elem.format_str, dataset_kwargs=dataset_kwargs, @@ -536,7 +536,7 @@ def read_sparse(elem, _reader): @_REGISTRY.register_read_partial(ZarrGroup, IOSpec("csc_matrix", "0.1.0")) @_REGISTRY.register_read_partial(ZarrGroup, IOSpec("csr_matrix", "0.1.0")) def read_sparse_partial(elem, *, items=None, indices=(slice(None), slice(None))): - return sparse_dataset(elem)[indices] + return sparse_dataset(elem)[indices].to_memory() ################# diff --git a/anndata/experimental/merge.py b/anndata/experimental/merge.py index 89fd2bdf4..c8c818b04 100644 --- a/anndata/experimental/merge.py +++ b/anndata/experimental/merge.py @@ -73,23 +73,24 @@ def _gen_slice_to_append( fill_value=None, ): for ds, ri in zip(datasets, reindexers): - n_slices = ds.shape[axis] * ds.shape[1 - axis] // max_loaded_elems + backed = ds.to_backed() # backed object returns data immediately but ds does not, needed for the `else` + n_slices = backed.shape[axis] * backed.shape[1 - axis] // max_loaded_elems if n_slices < 2: yield (csr_matrix, csc_matrix)[axis]( ri(to_memory(ds), axis=1 - axis, fill_value=fill_value) ) else: - slice_size = max_loaded_elems // ds.shape[1 - axis] + slice_size = max_loaded_elems // backed.shape[1 - axis] if slice_size == 0: slice_size = 1 - rem_slices = ds.shape[axis] + rem_slices = backed.shape[axis] idx = 0 while rem_slices > 0: ds_part = None if axis == 0: - ds_part = ds[idx : idx + slice_size, :] + ds_part = backed[idx : idx + slice_size, :] elif axis == 1: - ds_part = ds[:, idx : idx + slice_size] + ds_part = backed[:, idx : idx + slice_size] yield (csr_matrix, csc_matrix)[axis]( ri(ds_part, axis=1 - axis, fill_value=fill_value)