Skip to content

Commit

Permalink
Fixes related to cupy/cupy#7757
Browse files Browse the repository at this point in the history
  • Loading branch information
ivirshup committed Jul 26, 2023
1 parent bac0111 commit 2075a70
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
6 changes: 5 additions & 1 deletion anndata/_core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from scipy.sparse import spmatrix

from .anndata import AnnData
from ..compat import AwkArray, DaskArray, CupySparseMatrix, CupyArray
from ..compat import AwkArray, DaskArray, CupySparseMatrix, CupyArray, CupyCSRMatrix
from ..utils import asarray, dim_len
from .index import _subset, make_slice
from anndata._warnings import ExperimentalFeatureWarning
Expand Down Expand Up @@ -153,6 +153,10 @@ def equal_sparse(a, b) -> bool:
xp = array_api_compat.array_namespace(a.data)

if isinstance(b, (CupySparseMatrix, sparse.spmatrix)):
if isinstance(a, CupySparseMatrix):
# Comparison broken for CSC matrices
# https://github.com/cupy/cupy/issues/7757
a, b = CupyCSRMatrix(a), CupyCSRMatrix(b)
comp = a != b
if isinstance(comp, bool):
return not comp
Expand Down
10 changes: 8 additions & 2 deletions anndata/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,14 @@ def test_set_scalar_subset_X(matrix_type, subset_func):

assert adata_subset.is_view
assert np.all(asarray(adata[subset_idx, :].X) == 1)

assert asarray((orig_X_val != adata.X)).sum() == mul(*adata_subset.shape)
if isinstance(adata.X, CupyCSCMatrix):
# Comparison broken for CSC matrices
# https://github.com/cupy/cupy/issues/7757
assert asarray((orig_X_val.tocsr() != adata.X.tocsr())).sum() == mul(
*adata_subset.shape
)
else:
assert asarray((orig_X_val != adata.X)).sum() == mul(*adata_subset.shape)


# TODO: Use different kind of subsetting for adata and view
Expand Down

0 comments on commit 2075a70

Please sign in to comment.