diff --git a/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py b/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py index b1ba815e5a6..76adec8c491 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py @@ -831,16 +831,16 @@ def pivot( See detailed docstring in Snowpark DataFrame's pivot. """ snowpark_dataframe = self.to_projected_snowpark_dataframe() + pivot_snowpark_dataframe = snowpark_dataframe.pivot( + pivot_col=pivot_col, + values=values, + default_on_null=default_on_null, + ).agg(*agg_exprs) + cached_snowpark_dataframe = pivot_snowpark_dataframe.cache_result() return OrderedDataFrame( # the pivot result columns for dynamic pivot are data dependent, a schema call is required # to know all the quoted identifiers for the pivot result. - DataFrameReference( - snowpark_dataframe.pivot( - pivot_col=pivot_col, - values=values, - default_on_null=default_on_null, - ).agg(*agg_exprs) - ) + DataFrameReference(cached_snowpark_dataframe) ) def unpivot( diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 2aa286d189e..0fd989c3133 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -6088,14 +6088,19 @@ def get_dummies( "get_dummies with non-default dummy_na, drop_first, and dtype parameters" + " is not supported yet in Snowpark pandas." 
) + if columns is None: - columns = [ - col_name - for (col_index, col_name) in enumerate( - self._modin_frame.data_column_pandas_labels - ) - if is_series or is_string_dtype(self.dtypes[col_index]) - ] + if is_series: + columns = self._modin_frame.data_column_pandas_labels + else: + df_dtypes = self.dtypes.to_numpy() + columns = [ + col_name + for (col_index, col_name) in enumerate( + self._modin_frame.data_column_pandas_labels + ) + if is_string_dtype(df_dtypes[col_index]) + ] if not isinstance(columns, list): columns = [columns] @@ -6109,13 +6114,14 @@ def get_dummies( ) if prefix is None and not is_series: + df_dtypes = self.dtypes.to_numpy() prefix = [ col_name for (col_index, col_name) in enumerate( self._modin_frame.data_column_pandas_labels ) if self._modin_frame.is_unnamed_series() - or is_string_dtype(self.dtypes[col_index]) + or is_string_dtype(df_dtypes[col_index]) ] if not isinstance(prefix, list): diff --git a/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py b/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py index c751b4fe550..13aa5e328b8 100644 --- a/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py +++ b/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py @@ -89,10 +89,15 @@ def is_same_shape( if is_scalar(x): # broadcast scalar x to size of cond - object_shape = cond.shape - if len(object_shape) == 1: - df_scalar = pd.Series(x, index=range(object_shape[0])) - elif len(object_shape) == 2: + n_dim = cond.ndim + if n_dim == 1: + df_cond = cond.to_frame() + df_cond["new_value"] = x + df_scalar = df_cond["new_value"] + original_cond_column = df_cond.columns[0] + cond = df_cond[original_cond_column] + elif n_dim == 2: + object_shape = cond.shape df_scalar = 
pd.DataFrame( x, index=range(object_shape[0]), columns=range(object_shape[1]) ) diff --git a/tests/integ/modin/series/test_where.py b/tests/integ/modin/series/test_where.py index 3e0cffd263b..9c1b40b2bcf 100644 --- a/tests/integ/modin/series/test_where.py +++ b/tests/integ/modin/series/test_where.py @@ -196,7 +196,7 @@ def test_series_where_with_lambda_cond_returns_singleton_should_fail(): @pytest.mark.parametrize( "other, sql_count, join_count", - [(lambda x: -x.iloc[0], 4, 6), (lambda x: x**2, 3, 6)], + [(lambda x: -x.iloc[0], 4, 7), (lambda x: x**2, 3, 8)], ) def test_series_where_with_lambda_other(other, sql_count, join_count): # High join count due to creatinga Series with non-Snowpark pandas data @@ -313,3 +313,10 @@ def perform_where(series): native_ser, perform_where, ) + + +def test_np_where_scalar_broadcast(): + df = pd.DataFrame({"A": [True, False, True], "B": [1, 2, 3]}) + result = np.where(df["A"], 1, 2) + # np.where with scalar x/y should broadcast against the 1-D cond Series + assert list(result) == [1, 2, 1]