aggressive cache pivot #2543

Draft · wants to merge 2 commits into main
```diff
@@ -831,16 +831,16 @@ def pivot(
         See detailed docstring in Snowpark DataFrame's pivot.
         """
         snowpark_dataframe = self.to_projected_snowpark_dataframe()
+        pivot_snowpark_dataframe = snowpark_dataframe.pivot(
+            pivot_col=pivot_col,
+            values=values,
+            default_on_null=default_on_null,
+        ).agg(*agg_exprs)
+        cached_snowpark_dataframe = pivot_snowpark_dataframe.cache_result()
         return OrderedDataFrame(
             # the pivot result columns for dynamic pivot are data dependent, a schema call is required
             # to know all the quoted identifiers for the pivot result.
-            DataFrameReference(
-                snowpark_dataframe.pivot(
-                    pivot_col=pivot_col,
-                    values=values,
-                    default_on_null=default_on_null,
-                ).agg(*agg_exprs)
-            )
+            DataFrameReference(cached_snowpark_dataframe)
         )
 
     def unpivot(
```
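The hunk above computes the dynamic pivot once, materializes it with `cache_result()`, and hands the cached result to `DataFrameReference`, so the schema call needed to discover the data-dependent pivot columns (and any downstream operations) reads from the materialized result instead of re-running the pivot. A minimal sketch of the same pattern in plain Snowpark, assuming connection parameters are configured; the `sales` table, its `MONTH`/`AMOUNT` columns, and the pivot values are hypothetical:

```python
from snowflake.snowpark import Session
from snowflake.snowpark.functions import sum as sum_

# Assumes connection parameters are available to the default builder.
session = Session.builder.getOrCreate()

# Hypothetical table and pivot values, mirroring the pivot(...).agg(...) call above.
pivoted = session.table("sales").pivot("MONTH", ["JAN", "FEB"]).agg(sum_("AMOUNT"))

# cache_result() executes the pivot once into a temporary table and returns a
# DataFrame that reads from it, so the schema lookup below does not trigger a
# second pivot execution.
cached = pivoted.cache_result()
print(cached.schema)
```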
```diff
@@ -6088,14 +6088,27 @@ def get_dummies(
                 "get_dummies with non-default dummy_na, drop_first, and dtype parameters"
                 + " is not supported yet in Snowpark pandas."
             )
+
         if columns is None:
-            columns = [
-                col_name
-                for (col_index, col_name) in enumerate(
-                    self._modin_frame.data_column_pandas_labels
-                )
-                if is_series or is_string_dtype(self.dtypes[col_index])
-            ]
+            if is_series:
+                columns = self._modin_frame.data_column_pandas_labels
+            else:
+                df_dtypes = self.dtypes.to_numpy()
+                columns = [
+                    col_name
+                    for (col_index, col_name) in enumerate(
+                        self._modin_frame.data_column_pandas_labels
+                    )
+                    if is_string_dtype(df_dtypes[col_index])
+                ]
+
+            # columns = [
+            #     col_name
+            #     for (col_index, col_name) in enumerate(
+            #         self._modin_frame.data_column_pandas_labels
+            #     )
+            #     if is_series or is_string_dtype(self.dtypes[col_index])
+            # ]
 
         if not isinstance(columns, list):
             columns = [columns]
@@ -6109,13 +6122,14 @@
             )
 
         if prefix is None and not is_series:
+            df_dtypes = self.dtypes.to_numpy()
             prefix = [
                 col_name
                 for (col_index, col_name) in enumerate(
                     self._modin_frame.data_column_pandas_labels
                 )
                 if self._modin_frame.is_unnamed_series()
-                or is_string_dtype(self.dtypes[col_index])
+                or is_string_dtype(df_dtypes[col_index])
             ]
 
         if not isinstance(prefix, list):
```
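Both `get_dummies` hunks hoist the dtype lookup out of the comprehension: instead of indexing into the `self.dtypes` Series once per column, the dtypes are converted a single time with `.to_numpy()` and each check becomes a cheap array access. A minimal standalone sketch of the same pattern in plain pandas, with hypothetical column names:

```python
import pandas as pd
from pandas.api.types import is_string_dtype

df = pd.DataFrame({"city": ["NY", "SF"], "visits": [3, 5]})

# One dtypes lookup, then cheap positional access per column, mirroring the
# df_dtypes = self.dtypes.to_numpy() hoist in the diff above.
df_dtypes = df.dtypes.to_numpy()
string_cols = [
    name for i, name in enumerate(df.columns) if is_string_dtype(df_dtypes[i])
]
print(string_cols)  # ['city']
```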
14 changes: 10 additions & 4 deletions src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py
```diff
@@ -89,10 +89,16 @@ def is_same_shape(
 
     if is_scalar(x):
         # broadcast scalar x to size of cond
-        object_shape = cond.shape
-        if len(object_shape) == 1:
-            df_scalar = pd.Series(x, index=range(object_shape[0]))
-        elif len(object_shape) == 2:
+        n_dim = cond.ndim
+        if n_dim == 1:
+            df_cond = cond.to_frame()
+            df_cond["new_value"] = x
+            df_scalar = df_cond["new_value"]
+            original_cond_column = df_cond.columns[0]
+            cond = df_cond[original_cond_column]
+            # df_scalar = pd.Series(x, index=range(object_shape[0]))
+        elif n_dim == 2:
+            object_shape = cond.shape
             df_scalar = pd.DataFrame(
                 x, index=range(object_shape[0]), columns=range(object_shape[1])
             )
```
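In the 1-D case, the old code built a fresh `pd.Series(x, index=range(n))`, which in Snowpark pandas creates a new frame that must be joined back to `cond` by position. The new path instead appends the scalar as a constant column on `cond`'s own frame, so the broadcast series and the re-extracted condition share one underlying frame and row order. A sketch of the same dance in plain pandas, with hypothetical names:

```python
import pandas as pd

cond = pd.Series([True, False, True], name="flag")
x = 99  # scalar to broadcast to cond's shape

# Attach the scalar as a constant column on cond's own frame; df_scalar and
# the re-extracted cond now come from the same frame, which in Snowpark
# pandas avoids the positional join a standalone pd.Series(x, ...) would need.
df_cond = cond.to_frame()
df_cond["new_value"] = x
df_scalar = df_cond["new_value"]
cond = df_cond[df_cond.columns[0]]

print(df_scalar.where(cond))  # 99 where cond is True, NaN elsewhere
```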
9 changes: 8 additions & 1 deletion tests/integ/modin/series/test_where.py
```diff
@@ -196,7 +196,7 @@ def test_series_where_with_lambda_cond_returns_singleton_should_fail():
 
 @pytest.mark.parametrize(
     "other, sql_count, join_count",
-    [(lambda x: -x.iloc[0], 4, 6), (lambda x: x**2, 3, 6)],
+    [(lambda x: -x.iloc[0], 4, 7), (lambda x: x**2, 3, 8)],
 )
 def test_series_where_with_lambda_other(other, sql_count, join_count):
     # High join count due to creating a Series with non-Snowpark pandas data
@@ -313,3 +313,10 @@ def perform_where(series):
         native_ser,
         perform_where,
     )
+
+
+def test_scalar():
+    df = pd.DataFrame({"A": [True, False, True], "B": [1, 2, 3]})
+    print("\n\n")
+    result = np.where(df["A"], 1, 2)
+    print(result)
```
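The new `test_scalar` (currently print-based with no assertion, so it reads as a draft smoke test) exercises the scalar-broadcast path changed above: NumPy sees a Snowpark pandas Series, dispatches through the `__array_function__` protocol, and the `np.where` translation in `numpy_to_pandas.py` performs the broadcast. A sketch of the expected behavior, assuming an active Snowpark session; the data is hypothetical and mirrors the test:

```python
import modin.pandas as pd
import numpy as np
import snowflake.snowpark.modin.plugin  # noqa: F401 -- registers the Snowpark pandas backend

# Requires an active Snowpark session.
df = pd.DataFrame({"A": [True, False, True], "B": [1, 2, 3]})

# np.where on a Snowpark pandas Series dispatches via __array_function__ into
# the numpy_to_pandas mapping; with native pandas the result would be
# array([1, 2, 1]).
result = np.where(df["A"], 1, 2)
print(result)
```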