Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pint arrays support #26

Open
wants to merge 30 commits into
base: main
Choose a base branch
from
Open

Conversation

kadykov
Copy link
Contributor

@kadykov kadykov commented Jan 13, 2023

This commit fixes the following issue of the wrong behavior of is_cupy property with pint xarrays:

import xarray as xr
import cupy as cp
import pint_xarray
import cupy_xarray
from pint_xarray import unit_registry as ureg

cp_da = xr.DataArray(
    cp.linspace(0, 1, 11),
    coords=dict(
        x=np.linspace(0, 10, 11),
    ),
)

pint_da = cp_da.pint.quantify("meter")

print(f"is_cupy response: {pint_da.cupy.is_cupy}")
print(f"pint_da type {type(pint_da.pint.magnitude)}")

Currently, the output is:

is_cupy response: False
pint_da type <class 'cupy.ndarray'>

This commit fix the issue where `is_cupy` shows wrong status for pint arrays on GPU
@kadykov kadykov changed the title Kadykov pint Fix is_cupy property for pint xarrays Jan 13, 2023
@dcherian
Copy link
Contributor

Thanks!

Can you add a test please? You'll need to have pint as an optional dependency. SO we'll need to copy some of this code from Xarray: https://github.com/pydata/xarray/blob/6c5840e1198707cdcf7dc459f27ea9510eb76388/xarray/tests/__init__.py#L83

The test should look like:

@requires_pint
def test_is_cupy_pint():
	pass

cupy_xarray/accessors.py Outdated Show resolved Hide resolved
cupy_xarray/accessors.py Outdated Show resolved Hide resolved
cupy_xarray/accessors.py Outdated Show resolved Hide resolved
@kadykov kadykov changed the title Fix is_cupy property for pint xarrays Add pint arrays support Jan 24, 2023
@negin513 negin513 self-requested a review February 17, 2023 22:46
cupy_xarray/accessors.py Outdated Show resolved Hide resolved
@@ -1,10 +1,18 @@
import numpy as np
import pytest
import xarray as xr
from xarray.core.pycompat import dask_array_type
from xarray.core.pycompat import DuckArrayModule
from xarray.tests import requires_pint
Copy link
Contributor

@dcherian dcherian Feb 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's copy this over to tests/__init__.py from xarray and similarly from DuckArrayModule



@requires_pint
def test_data_array_accessor_pint(tutorial_da_air_pint):
Copy link
Contributor

@negin513 negin513 Mar 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test checks this case:

pint + xr.DataArray --> cupy --> numpy



@requires_pint
def test_data_array_accessor_pint_dask(tutorial_da_air_pint_dask):
Copy link
Contributor

@negin513 negin513 Mar 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test checks the following:

 pint + Dask.DataArray --> cupy --> numpy

Copy link
Contributor

@negin513 negin513 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pull request addresses an issue where the is_cupy property returns an incorrect value (False) for pint xarrays that are backed by CuPy arrays. The proposed modifications effectively resolve this issue and the changes are clear. Great job. 👍

Recommendations
I have two recommendations regarding the tests:

  • To ensure all the edge cases are adequately resolved and to prevent potential future problems, please ensure that the test cases cover all relevant edge cases and scenarios, as this will help to guarantee the robustness of the implemented fix. I added more comments to go over this.

  • I suggest improving the naming of test functions in the accessors module to better reflect their purpose and the functionality being tested. I also want to recommend adding docstrings to each test function, explaining the specific case or scenario that each test covers. Currently the name of the tests (and tests) are very confusing and this caused me some confusions on which cases are covered.

The followings are the lists of tests of their purpose:

Test Function Name Description
test_data_set_accessor xr.dataset --> cupy ---> np
test_data_array_accessor xr.DataArray --> cupy --> np
test_data_array_accessor_dask dask.DataArray --> cupy --> np
test_data_array_accessor_pint pint + xr.DataArray --> cupy --> np
test_data_array_accessor_pint_dask pint + Dask.DataArray --> cupy --> np

For example, the name test_xr_dataset_accessor might be a better name than test_data_set_accessor. This way we can have test_dask_dataarray_accessor , etc.

I also suggest covering the following cases that are not covered:
pint + xr.dataset --> cupy --> np

We have tutorial_ds_air_pint already defined above that can be used in this test.

@kadykov
Copy link
Contributor Author

kadykov commented Apr 5, 2023

Thank you for your comments and for the good review of the changes.
I re-implemented the tests with the clearer naming and wider coverage of input data.

@kadykov
Copy link
Contributor Author

kadykov commented May 12, 2023

Also closes this issue #31

@negin513
Copy link
Contributor

Hey @dcherian , Do you have any additional comments for this PR? I think we can merge this PR.

@dcherian
Copy link
Contributor

Sorry for the massive delay here. @kadykov

I cleaned up the combinatorial explosion of variables ;) but I don't have a machine to actually run the tests on. Can either you or @negin513 confirm that the tests work please?

@dcherian
Copy link
Contributor

Actually I have access now. fixing things...

Comment on lines +55 to +58
as_dask = as_pint.chunk()
if isinstance(as_dask, xr.DataArray):
assert isinstance(as_dask.data, pint.Quantity)
assert isinstance(as_dask.data.magnitude._meta, np.ndarray)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@keewis is this expectation right for DataArray(pint(np.ndarray)).chunk()

Copy link
Contributor

@keewis keewis Oct 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

be aware that DataArray.chunk() simply passes the wrapped duck array to dask.array.from_array if it was not already a dask collection, which means that it would produce a dask(pint(np.ndarray)). This is not what we want, so there's DataArray.pint.chunk to work around this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants