Skip to content

Commit

Permalink
[SNOW-902943]: Add support for pd.NamedAgg in DataFrame and Series.agg
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-rdurrani committed May 22, 2024
1 parent 28cc324 commit 17019e9
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 39 deletions.
26 changes: 19 additions & 7 deletions src/snowflake/snowpark/modin/pandas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
from snowflake.snowpark.modin import pandas as pd
from snowflake.snowpark.modin.pandas.utils import (
_doc_binary_op,
extract_validate_and_try_convert_named_aggs_from_kwargs,
get_as_shape_compatible_dataframe_or_series,
is_scalar,
raise_if_native_pandas_objects,
Expand Down Expand Up @@ -703,17 +704,28 @@ def aggregate(
# native pandas raise error with message "no result", here we raise a more readable error.
raise ValueError("No column to aggregate on.")

func = validate_and_try_convert_agg_func_arg_func_to_str(
agg_func=func,
obj=self,
allow_duplication=False,
axis=axis,
)
if func is None:
if axis == 1:
raise ValueError(
"`func` must not be `None` when `axis=1`. Named aggregations are not supported with `axis=1`."
)
func = extract_validate_and_try_convert_named_aggs_from_kwargs(
self, allow_duplication=False, axis=axis, **kwargs
)
else:
func = validate_and_try_convert_agg_func_arg_func_to_str(
agg_func=func,
obj=self,
allow_duplication=False,
axis=axis,
)

# This is to stay consistent with pandas result format, when the func is single
# aggregation function in format of callable or str, reduce the result dimension to
# convert dataframe to series, or convert series to scalar.
need_reduce_dimension = (
# Note: When named aggregations are used, the result is not reduced, even if there
# is only a single function.
need_reduce_dimension = func is not None and (
(callable(func) or isinstance(func, str))
# A Series should be returned when a single scalar string/function aggregation function, or a
# dict of scalar string/functions is specified. In all other cases (including if the function
Expand Down
35 changes: 35 additions & 0 deletions src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,41 @@ def get_agg_func_to_col_map(
return agg_func_to_col_map


def get_pandas_label_to_agg_func_info_map(
col_to_agg_func_map: dict[
PandasLabelToSnowflakeIdentifierPair,
Union[AggFuncWithLabel, list[AggFuncWithLabel]],
]
) -> dict[Hashable, dict[PandasLabelToSnowflakeIdentifierPair, AggFuncInfo]]:
pandas_label_to_agg_func_info_map: dict[
Hashable, dict[PandasLabelToSnowflakeIdentifierPair, AggFuncInfo]
] = {}

def _extract_agg_func_info_from_agg_func_with_label_and_put_in_dictionary(
key: PandasLabelToSnowflakeIdentifierPair, agg_func: AggFuncWithLabel
) -> None:
nonlocal pandas_label_to_agg_func_info_map
pandas_label = agg_func.pandas_label
func = agg_func.func
pandas_label_to_agg_func_info_map[pandas_label] = {
key: AggFuncInfo(func=func, is_dummy_agg=False)
}

for column_label_to_identifier_pair, agg_funcs in col_to_agg_func_map.items():
if is_list_like(agg_funcs) and not is_named_tuple(agg_funcs):
# Here we are dealing with multiple aggregation functions applied to the same column.
for agg_func in agg_funcs:
_extract_agg_func_info_from_agg_func_with_label_and_put_in_dictionary(
column_label_to_identifier_pair, agg_func
)
else:
_extract_agg_func_info_from_agg_func_with_label_and_put_in_dictionary(
column_label_to_identifier_pair, agg_funcs # type: ignore[arg-type]
)

return pandas_label_to_agg_func_info_map


def get_pandas_aggr_func_name(aggfunc: AggFuncTypeBase) -> str:
"""
Returns the friendly name for the aggr function. For example, if it is a callable, it will return __name__
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
generate_rowwise_aggregation_function,
get_agg_func_to_col_map,
get_pandas_aggr_func_name,
get_pandas_label_to_agg_func_info_map,
get_snowflake_agg_func,
)
from snowflake.snowpark.modin.plugin._internal.apply_utils import (
Expand Down Expand Up @@ -4250,27 +4251,6 @@ def agg(
pandas_labels_for_columns_to_exclude_when_agg_on_all=[],
)

# get a map between each aggregation function and the columns needs to apply this aggregation function
agg_func_to_col_map = get_agg_func_to_col_map(column_to_agg_func)

# aggregation creates an index column with the aggregation function names as its values
# For example: with following dataframe
# A B C
# 0 1 2 3
# 1 4 5 6
# 2 7 8 9
# after we call df.aggregate({"A": ["min"], "B": ["max"]}), the result is following
# A B
# min 1 NaN
# max NaN 8
#
# However, if all values in the agg_func dict are scalar strings/functions rather than lists,
# then the result will instead be a Series:
# >>> df.aggregate({"A": "min", "B": "max"})
# 0 1
# 1 8
# dtype: int64

# generate the quoted identifier for the aggregation function name column
agg_name_col_quoted_identifier = (
internal_frame.ordered_dataframe.generate_snowflake_quoted_identifiers(
Expand Down Expand Up @@ -4408,19 +4388,69 @@ def generate_agg_qc(
)

else:
for agg_func, cols in agg_func_to_col_map.items():
col_single_agg_func_map = {
column: AggFuncInfo(
func=agg_func if column in cols else "min",
is_dummy_agg=column not in cols,
if any(
isinstance(agg_func, AggFuncWithLabel)
or (
is_list_like(agg_func)
and isinstance(agg_func[0], AggFuncWithLabel)
)
for agg_func in column_to_agg_func.values()
):
# If this is true, then we are dealing with agg with NamedAggregations. In that case,
# `column_to_agg_func` looks like this:
# {PandasLabelToSnowflakeIdentifierPair(pandas_label='A', snowflake_quoted_identifier='"A"'):
# [AggFuncWithLabel(func='max', pandas_label='x'), AggFuncWithLabel(func='min', pandas_label='y')]}
# We need to go throuh and apply each of the aggregations to the corresponding column, and apply
# the correct index value. First, though, we can convert the dictionary to a different format for
# ease of code.
pandas_label_to_agg_func_info_map = (
get_pandas_label_to_agg_func_info_map(column_to_agg_func)
)
# Now, we have a mapping from the new index value to the dictionary mapping the
# PandasLabelToSnowflakeIdentifierPair to the AggFuncInfo object that we can pass in
# to generate_agg_qc.
for (
new_index_label,
agg_func_info_dict,
) in pandas_label_to_agg_func_info_map.items():
single_agg_func_query_compilers.append(
generate_agg_qc(agg_func_info_dict, new_index_label)
)
for column in column_to_agg_func.keys()
}
single_agg_func_query_compilers.append(
generate_agg_qc(
col_single_agg_func_map, get_pandas_aggr_func_name(agg_func)
else:
# get a map between each aggregation function and the columns needs to apply this aggregation function
agg_func_to_col_map = get_agg_func_to_col_map(column_to_agg_func)

# aggregation creates an index column with the aggregation function names as its values
# For example: with following dataframe
# A B C
# 0 1 2 3
# 1 4 5 6
# 2 7 8 9
# after we call df.aggregate({"A": ["min"], "B": ["max"]}), the result is following
# A B
# min 1 NaN
# max NaN 8
#
# However, if all values in the agg_func dict are scalar strings/functions rather than lists,
# then the result will instead be a Series:
# >>> df.aggregate({"A": "min", "B": "max"})
# 0 1
# 1 8
# dtype: int64
for agg_func, cols in agg_func_to_col_map.items():
col_single_agg_func_map = {
column: AggFuncInfo(
func=agg_func if column in cols else "min",
is_dummy_agg=column not in cols,
)
for column in column_to_agg_func.keys()
}
single_agg_func_query_compilers.append(
generate_agg_qc(
col_single_agg_func_map,
get_pandas_aggr_func_name(agg_func),
)
)
)

assert single_agg_func_query_compilers, "no aggregation result"
if len(single_agg_func_query_compilers) == 1:
Expand Down

0 comments on commit 17019e9

Please sign in to comment.