[SNOW-902943]: Add support for pd.NamedAgg in DataFrame and Series.agg #1652

Merged
merged 20 commits on May 29, 2024
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -14,12 +14,17 @@

- Added support for `if_not_exists` parameter during udf and sproc registration.

#### Bug Fixes

- Fixed a bug that caused the columns of `GroupBy.aggregate`'s output to be ordered incorrectly.

#### Improvements

- Added partial support for `DataFrame.pivot_table` with no `index` parameter, as well as for the `margins` parameter.
- Aligned error experience when calling udf and sprocs.
- Added appropriate error messages for `is_permanent`/`anonymous` udf/sproc registration to make it clearer that those features are not yet supported.
- Updated the signature of `DataFrame.shift`/`Series.shift`/`DataFrameGroupBy.shift`/`SeriesGroupBy.shift` to match pandas 2.2.1. Snowpark pandas does not yet support the newly-added `suffix` argument, or sequence values of `periods`.
- Added support for named aggregations in `DataFrame.aggregate` and `Series.aggregate`.

### Snowpark Local Testing Updates

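The changelog entries above refer to pandas-style named aggregations. As a minimal sketch of what that API shape looks like, shown with stock pandas (the data and labels here are illustrative, not taken from the PR):

```python
import pandas as pd

df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})

# Each keyword is new_label=(column, aggfunc); pd.NamedAgg is an
# equivalent, more readable namedtuple form of the same 2-tuple.
result = df.agg(a_min=pd.NamedAgg(column="A", aggfunc="min"), b_max=("B", "max"))

# The result uses the new labels as its index, with NaN in cells
# that a given aggregation does not populate.
```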
36 changes: 30 additions & 6 deletions src/snowflake/snowpark/modin/pandas/base.py
@@ -76,6 +76,7 @@
from snowflake.snowpark.modin import pandas as pd
from snowflake.snowpark.modin.pandas.utils import (
_doc_binary_op,
extract_validate_and_try_convert_named_aggs_from_kwargs,
get_as_shape_compatible_dataframe_or_series,
is_scalar,
raise_if_native_pandas_objects,
@@ -703,16 +704,30 @@ def aggregate(
# native pandas raise error with message "no result", here we raise a more readable error.
raise ValueError("No column to aggregate on.")

func = validate_and_try_convert_agg_func_arg_func_to_str(
agg_func=func,
obj=self,
allow_duplication=False,
axis=axis,
)
if func is None:
if axis == 1:
raise ValueError(
"`func` must not be `None` when `axis=1`. Named aggregations are not supported with `axis=1`."
)
func = extract_validate_and_try_convert_named_aggs_from_kwargs(
self, allow_duplication=False, axis=axis, **kwargs
)
else:
func = validate_and_try_convert_agg_func_arg_func_to_str(
Review comment (Collaborator): @sfc-gh-joshi with the modin upstreaming work, I think we should probably start moving many of these conversions to the backend, leaving just basic checking at the frontend.

Reply (Contributor): Yes, though our frontend implementation of aggregate already diverges pretty heavily from the modin version. We might want to use the extension API to overwrite aggregate when we start the migration.
agg_func=func,
obj=self,
allow_duplication=False,
axis=axis,
)

# This is to stay consistent with pandas result format, when the func is single
# aggregation function in format of callable or str, reduce the result dimension to
# convert dataframe to series, or convert series to scalar.
# Note: When named aggregations are used, the result is not reduced, even if there
# is only a single function.
# needs_reduce_dimension cannot be True if we are using named aggregations, since
# the values for func in that case are either NamedTuples (AggFuncWithLabels) or
# lists of NamedTuples, both of which are list like.
need_reduce_dimension = (
(callable(func) or isinstance(func, str))
# A Series should be returned when a single scalar string/function aggregation function, or a
@@ -768,7 +783,16 @@ def aggregate(
# >>> pd.DataFrame([[np.nan], [0]]).count(skipna=True, axis=0)
# TypeError: got an unexpected keyword argument 'skipna'
if is_dict_like(func):
order_of_aggregations = list(kwargs.keys())
formatted_kwargs = ", ".join(
[f"{key}={value}" for key, value in kwargs.items()]
)
kwargs.clear()
# Used to make error message formatting a little cleaner
# when using named aggregations.
kwargs["_formatted_named_kwargs"] = formatted_kwargs
Review comment (Collaborator): Do we need to pass that here? I thought you had a way to do it at the query compiler layer, like what you did for the groupby not-implemented error.

Reply (Contributor Author): In groupby, the agg_kwargs are passed directly to the query compiler, so we can do the formatting there. Here the kwargs are cleared and not passed to the QC, so we would have to parse them out of the agg func, which means they may not be in order in the error message, and the code to parse them looks a little messier.
# Used to correctly order the aggregations when using named aggregations.
kwargs["_correct_aggregation_order"] = order_of_aggregations

result = self.__constructor__(
query_compiler=self._query_compiler.agg(
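The kwargs bookkeeping in the hunk above can be sketched standalone. This is an illustrative reimplementation, not the actual Snowpark pandas code; it relies on Python's guarantee (PEP 468) that keyword arguments preserve call-site order, which is what makes recording the ordering before clearing the kwargs sound:

```python
def stash_order_and_formatting(kwargs):
    # Record the call-site ordering and a formatted rendering of the named
    # aggregations, then clear kwargs so the raw values are not forwarded
    # as ordinary aggregation keyword arguments.
    order_of_aggregations = list(kwargs.keys())
    formatted_kwargs = ", ".join(f"{key}={value}" for key, value in kwargs.items())
    kwargs.clear()
    # Used to make error message formatting cleaner with named aggregations.
    kwargs["_formatted_named_kwargs"] = formatted_kwargs
    # Used to correctly order the aggregations afterwards.
    kwargs["_correct_aggregation_order"] = order_of_aggregations
    return kwargs

stashed = stash_order_and_formatting({"new_col1": ("A", "min"), "new_col2": ("B", "max")})
```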
32 changes: 31 additions & 1 deletion src/snowflake/snowpark/modin/pandas/groupby.py
@@ -585,8 +585,37 @@ def aggregate(
"axis other than 0 is not supported"
) # pragma: no cover
if func is None:
# When func is None, we assume that the aggregation functions have been passed in via named aggregations,
# which can be of the form named_agg=('col_name', 'agg_func') or named_agg=pd.NamedAgg('col_name', 'agg_func').
# We need to parse out the following three things:
# 1. The new label to apply to the result of the aggregation.
# 2. The column to apply the aggregation over.
# 3. The aggregation to apply.
# This function checks that:
# 1. The kwargs contain named aggregations.
# 2. The kwargs do not contain anything besides named aggregations. (for pandas compatibility - see function for more details.)
# If both of these things are true, it then extracts the named aggregations from the kwargs, and returns a dictionary that contains
# a mapping from the column pandas labels to apply the aggregation over (2 above) to a tuple containing the aggregation to apply
# and the new label to assign it (1 and 3 above). Take for example, the following call:
# df.groupby(...).agg(new_col1=('A', 'min'), new_col2=('B', 'max'), new_col3=('A', 'max'))
# After this function returns, func will look like this:
# {
# "A": [AggFuncWithLabel(func="min", pandas_label="new_col1"), AggFuncWithLabel(func="max", pandas_label="new_col3")],
# "B": AggFuncWithLabel(func="max", pandas_label="new_col2")
# }
# This remapping causes an issue with ordering though - the dictionary above will be processed in the following order:
# 1. apply "min" to "A" and name it "new_col1"
# 2. apply "max" to "A" and name it "new_col3"
# 3. apply "max" to "B" and name it "new_col2"
# In other words, the order is slightly shifted so that named aggregations on the same column are contiguous in the ordering,
# although the ordering of the kwargs is still used to determine the relative order of named aggregations on the same column. Since
# the reordering for groupby agg is a reordering of columns, it's relatively cheap to do after the aggregation is over,
# rather than attempting to preserve the order of the named aggregations internally.
func = extract_validate_and_try_convert_named_aggs_from_kwargs(
obj=self, allow_duplication=True, axis=self._axis, **kwargs
obj=self,
allow_duplication=True,
axis=self._axis,
**kwargs,
)
else:
func = validate_and_try_convert_agg_func_arg_func_to_str(
Expand Down Expand Up @@ -615,6 +644,7 @@ def aggregate(
how="axis_wise",
is_result_dataframe=is_result_dataframe,
)

return result

agg = aggregate
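The long comment in the hunk above describes remapping keyword named aggregations into a per-column dictionary. Here is a standalone sketch of that remapping, using a stand-in namedtuple for the internal `AggFuncWithLabel` type (the stand-in and the function name are assumptions for illustration, not the actual internals):

```python
from collections import namedtuple

# Stand-in for the internal AggFuncWithLabel type used by Snowpark pandas.
AggFuncWithLabel = namedtuple("AggFuncWithLabel", ["func", "pandas_label"])

def remap_named_aggs(**kwargs):
    """Remap new_label=(column, func) kwargs into a column -> func-with-label
    mapping, grouping multiple aggregations on the same column into a list."""
    named_aggs = {}
    for new_label, (column, func) in kwargs.items():
        entry = AggFuncWithLabel(func=func, pandas_label=new_label)
        if column in named_aggs:
            # Promote a single entry to a list on the second aggregation
            # targeting the same column.
            if not isinstance(named_aggs[column], list):
                named_aggs[column] = [named_aggs[column]]
            named_aggs[column].append(entry)
        else:
            named_aggs[column] = entry
    return named_aggs

remapped = remap_named_aggs(
    new_col1=("A", "min"), new_col2=("B", "max"), new_col3=("A", "max")
)
```

Note how `new_col1` and `new_col3` end up contiguous under `"A"`, which is the ordering shift the comment describes.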
55 changes: 40 additions & 15 deletions src/snowflake/snowpark/modin/pandas/utils.py
@@ -553,21 +553,47 @@ def extract_validate_and_try_convert_named_aggs_from_kwargs(
A dictionary mapping columns to a tuple containing the aggregation to perform, as well
as the pandas label to give the aggregated column.
"""
from snowflake.snowpark.modin.pandas import Series
from snowflake.snowpark.modin.pandas.groupby import SeriesGroupBy

is_series_like = isinstance(obj, (Series, SeriesGroupBy))
named_aggs = {}
accepted_keys = []
columns = obj._query_compiler.columns
for key, value in kwargs.items():
if isinstance(value, pd.NamedAgg) or (
isinstance(value, tuple) and len(value) == 2
):
if is_series_like:
# pandas does not allow pd.NamedAgg or 2-tuples for named aggregations
# when the base object is a Series, but has different errors depending
# on whether we are doing a Series.agg or Series.groupby.agg.
if isinstance(obj, Series):
raise SpecificationError("nested renamer is not supported")
else:
value_type_str = (
"NamedAgg" if isinstance(value, pd.NamedAgg) else "tuple"
)
raise TypeError(
f"func is expected but received {value_type_str} in **kwargs."
)
if axis == 0:
# If axis == 1, we would need a query to materialize the index to check its existence
# so we defer the error checking to later.
if value[0] not in columns:
raise KeyError(f"Column(s) ['{value[0]}'] do not exist")

# This function converts our named aggregations dictionary from a mapping of
# new_label -> tuple[column_name, agg_func] to a mapping of
# column_name -> tuple[agg_func, new_label] in order to process
# the aggregation functions internally. One issue with this is that the order
# of the named aggregations can change - say we have the following aggregations:
# {new_col: ('A', min), new_col1: ('B', max), new_col2: ('A', max)}
# The output of this function will look like this:
# {A: [AggFuncWithLabel(func=min, label=new_col), AggFuncWithLabel(func=max, label=new_col2)]
# B: AggFuncWithLabel(func=max, label=new_col1)}
# And so our final dataframe will have the wrong order. We handle the reordering of the generated
# labels at the QC layer.
if value[0] in named_aggs:
if not isinstance(named_aggs[value[0]], list):
named_aggs[value[0]] = [named_aggs[value[0]]]
@@ -577,8 +603,11 @@
else:
named_aggs[value[0]] = AggFuncWithLabel(func=value[1], pandas_label=key)
accepted_keys += [key]
elif isinstance(obj, SeriesGroupBy):
col_name = obj._df._query_compiler.columns[0]
elif is_series_like:
if isinstance(obj, SeriesGroupBy):
col_name = obj._df._query_compiler.columns[0]
else:
col_name = obj._query_compiler.columns[0]
if col_name not in named_aggs:
named_aggs[col_name] = AggFuncWithLabel(func=value, pandas_label=key)
else:
@@ -587,14 +616,15 @@
named_aggs[col_name] += [AggFuncWithLabel(func=value, pandas_label=key)]
accepted_keys += [key]

if len(named_aggs.keys()) == 0:
ErrorMessage.not_implemented(
"Must provide value for 'func' argument, func=None is currently not supported with Snowpark pandas"
)

if any(key not in accepted_keys for key in kwargs.keys()):
# For compatibility with pandas errors. Otherwise, we would just ignore
# those kwargs.
if len(named_aggs.keys()) == 0 or any(
key not in accepted_keys for key in kwargs.keys()
):
# First check makes sure that some functions have been passed. If nothing has been passed,
# we raise the TypeError.
# The second check is for compatibility with pandas errors. Say the user does something like this:
# df.agg(x=pd.NamedAgg('A', 'min'), random_extra_kwarg=14). pandas errors out, since func is None
# and not every kwarg is a named aggregation. Without this check explicitly, we would just ignore
# the extraneous kwargs, so we include this check for parity with pandas.
raise TypeError("Must provide 'func' or tuples of '(column, aggfunc).")

validated_named_aggs = {}
@@ -659,11 +689,6 @@ def validate_and_try_convert_agg_func_arg_func_to_str(
If nested dict configuration is used when agg_func is dict like or functions with duplicated names.

"""
if agg_func is None:
ErrorMessage.not_implemented(
"Must provide value for 'func' argument, func=None is currently not supported with Snowpark pandas"
)

if callable(agg_func):
result_agg_func = try_convert_builtin_func_to_str(agg_func, obj)
elif is_dict_like(agg_func):
@@ -1099,3 +1099,17 @@ def generate_column_agg_info(
new_data_column_index_names += [None]

return column_agg_ops, new_data_column_index_names


def using_named_aggregations_for_func(func: Any) -> bool:
"""
Helper method to check if func is formatted in a way that indicates that we are using named aggregations.
"""
return is_dict_like(func) and any(
isinstance(value, AggFuncWithLabel)
or (
isinstance(value, list)
and any(isinstance(v, AggFuncWithLabel) for v in value)
)
for value in func.values()
)
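A sketch of how this predicate behaves, again using a stand-in for the internal `AggFuncWithLabel` namedtuple (the logic mirrors the helper above; the stand-in type is an assumption for illustration):

```python
from collections import namedtuple

from pandas.api.types import is_dict_like

# Stand-in for the internal AggFuncWithLabel type used by Snowpark pandas.
AggFuncWithLabel = namedtuple("AggFuncWithLabel", ["func", "pandas_label"])

def using_named_aggregations_for_func(func) -> bool:
    # A dict-like func whose values contain AggFuncWithLabel entries,
    # either directly or inside a list, signals that named aggregations
    # were used; anything else (strings, callables, plain dicts) does not.
    return is_dict_like(func) and any(
        isinstance(value, AggFuncWithLabel)
        or (
            isinstance(value, list)
            and any(isinstance(v, AggFuncWithLabel) for v in value)
        )
        for value in func.values()
    )
```

For example, `{"A": AggFuncWithLabel("min", "new_col")}` is recognized, while a plain `{"A": "min"}` mapping or a bare `"min"` string is not.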