diff --git a/anndata/_core/merge.py b/anndata/_core/merge.py index 9f53ea3a6..6e7f94172 100644 --- a/anndata/_core/merge.py +++ b/anndata/_core/merge.py @@ -29,7 +29,7 @@ from scipy.sparse import spmatrix from .anndata import AnnData -from ..compat import AwkArray, DaskArray, CupySparseMatrix, CupyArray +from ..compat import AwkArray, DaskArray, CupySparseMatrix, CupyArray, CupyCSRMatrix from ..utils import asarray, dim_len from .index import _subset, make_slice from anndata._warnings import ExperimentalFeatureWarning @@ -153,6 +153,10 @@ def equal_sparse(a, b) -> bool: xp = array_api_compat.array_namespace(a.data) if isinstance(b, (CupySparseMatrix, sparse.spmatrix)): + if isinstance(a, CupySparseMatrix): + # Comparison broken for CSC matrices + # https://github.com/cupy/cupy/issues/7757 + a, b = CupyCSRMatrix(a), CupyCSRMatrix(b) comp = a != b if isinstance(comp, bool): return not comp diff --git a/anndata/tests/test_views.py b/anndata/tests/test_views.py index 06570e672..29fc10e09 100644 --- a/anndata/tests/test_views.py +++ b/anndata/tests/test_views.py @@ -327,8 +327,14 @@ def test_set_scalar_subset_X(matrix_type, subset_func): assert adata_subset.is_view assert np.all(asarray(adata[subset_idx, :].X) == 1) - - assert asarray((orig_X_val != adata.X)).sum() == mul(*adata_subset.shape) + if isinstance(adata.X, CupyCSCMatrix): + # Comparison broken for CSC matrices + # https://github.com/cupy/cupy/issues/7757 + assert asarray((orig_X_val.tocsr() != adata.X.tocsr())).sum() == mul( + *adata_subset.shape + ) + else: + assert asarray((orig_X_val != adata.X)).sum() == mul(*adata_subset.shape) # TODO: Use different kind of subsetting for adata and view