Skip to content

Commit

Permalink
SNOW-1617634: Update large query breakdown to not break and duplicate…
Browse files Browse the repository at this point in the history
… CTEs (#2419)
  • Loading branch information
sfc-gh-aalam authored Oct 24, 2024
1 parent cd8f160 commit 0409865
Show file tree
Hide file tree
Showing 11 changed files with 371 additions and 134 deletions.
33 changes: 20 additions & 13 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections import UserDict, defaultdict
from copy import copy, deepcopy
from enum import Enum
from functools import cached_property
from functools import cached_property, reduce
from typing import (
TYPE_CHECKING,
AbstractSet,
Expand All @@ -34,7 +34,10 @@
TableFunctionRelation,
)
from snowflake.snowpark._internal.analyzer.window_expression import WindowExpression
from snowflake.snowpark._internal.compiler.cte_utils import encode_node_id_with_query
from snowflake.snowpark._internal.compiler.cte_utils import (
encode_node_id_with_query,
merge_referenced_ctes,
)
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark.types import DataType

Expand Down Expand Up @@ -370,8 +373,9 @@ def column_states(self, value: ColumnStateDict):

@property
@abstractmethod
def referenced_ctes(self) -> Set[WithQueryBlock]:
"""Return the set of ctes referenced by the whole selectable subtree, includes its-self and children"""
def referenced_ctes(self) -> Dict[WithQueryBlock, int]:
"""Return the dict of ctes referenced by the whole selectable subtree and the
reference count of the cte. Includes itself and its children"""
pass


Expand Down Expand Up @@ -422,10 +426,10 @@ def query_params(self) -> Optional[Sequence[Any]]:
return None

@property
def referenced_ctes(self) -> Set[WithQueryBlock]:
def referenced_ctes(self) -> Dict[WithQueryBlock, int]:
# the SelectableEntity only allows select from base table. No
# CTE table will be referred.
return set()
return dict()


class SelectSQL(Selectable):
Expand Down Expand Up @@ -513,10 +517,10 @@ def to_subqueryable(self) -> "SelectSQL":
return new

@property
def referenced_ctes(self) -> Set[WithQueryBlock]:
def referenced_ctes(self) -> Dict[WithQueryBlock, int]:
# SelectSQL directly calls sql query, there will be no
# auto created CTE tables referenced
return set()
return dict()


class SelectSnowflakePlan(Selectable):
Expand Down Expand Up @@ -588,7 +592,7 @@ def reset_cumulative_node_complexity(self) -> None:
self.snowflake_plan.reset_cumulative_node_complexity()

@property
def referenced_ctes(self) -> Set[WithQueryBlock]:
def referenced_ctes(self) -> Dict[WithQueryBlock, int]:
return self._snowflake_plan.referenced_ctes


Expand Down Expand Up @@ -863,7 +867,7 @@ def cumulative_node_complexity(self, value: Dict[PlanNodeCategory, int]):
self._cumulative_node_complexity = value

@property
def referenced_ctes(self) -> Set[WithQueryBlock]:
def referenced_ctes(self) -> Dict[WithQueryBlock, int]:
return self.from_.referenced_ctes

def to_subqueryable(self) -> "Selectable":
Expand Down Expand Up @@ -1311,7 +1315,7 @@ def reset_cumulative_node_complexity(self) -> None:
self.snowflake_plan.reset_cumulative_node_complexity()

@property
def referenced_ctes(self) -> Set[WithQueryBlock]:
def referenced_ctes(self) -> Dict[WithQueryBlock, int]:
return self._snowflake_plan.referenced_ctes


Expand Down Expand Up @@ -1402,9 +1406,12 @@ def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
return {PlanNodeCategory.SET_OPERATION: len(self.set_operands) - 1}

@property
def referenced_ctes(self) -> Set[WithQueryBlock]:
def referenced_ctes(self) -> Dict[WithQueryBlock, int]:
# get a union of referenced cte tables from all child nodes
return set().union(*[node.referenced_ctes for node in self._nodes])
# and sum up the reference counts
return reduce(
merge_referenced_ctes, [node.referenced_ctes for node in self._nodes]
)


class DeriveColumnDependencyError(Exception):
Expand Down
25 changes: 14 additions & 11 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
List,
Optional,
Sequence,
Set,
Tuple,
Union,
)
Expand Down Expand Up @@ -96,6 +95,7 @@
from snowflake.snowpark._internal.compiler.cte_utils import (
encode_node_id_with_query,
find_duplicate_subtrees,
merge_referenced_ctes,
)
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark._internal.utils import (
Expand Down Expand Up @@ -216,10 +216,10 @@ def __init__(
df_aliased_col_name_to_real_col_name: Optional[
DefaultDict[str, Dict[str, str]]
] = None,
# This field records all the CTE tables that are referred by the
# current SnowflakePlan tree. This is needed for the final query
# This field records all the WithQueryBlocks and their reference count that are
# referred by the current SnowflakePlan tree. This is needed for the final query
# generation to generate the correct sql query with CTE definition.
referenced_ctes: Optional[Set[WithQueryBlock]] = None,
referenced_ctes: Optional[Dict[WithQueryBlock, int]] = None,
*,
session: "snowflake.snowpark.session.Session",
) -> None:
Expand Down Expand Up @@ -248,8 +248,8 @@ def __init__(
# query, query parameters and the node type. We use this id for equality
# comparison to determine if two plans are the same.
self.encoded_node_id_with_query = encode_node_id_with_query(self)
self.referenced_ctes: Set[WithQueryBlock] = (
referenced_ctes.copy() if referenced_ctes else set()
self.referenced_ctes: Dict[WithQueryBlock, int] = (
referenced_ctes.copy() if referenced_ctes else dict()
)
self._cumulative_node_complexity: Optional[Dict[PlanNodeCategory, int]] = None
# UUID for the plan to uniquely identify the SnowflakePlan object. We also use this
Expand Down Expand Up @@ -588,7 +588,7 @@ def build_binary(
if post_action not in post_actions:
post_actions.append(copy.copy(post_action))

referenced_ctes: Set[WithQueryBlock] = set()
referenced_ctes: Dict[WithQueryBlock, int] = dict()
if (
self.session.cte_optimization_enabled
and self.session._query_compilation_stage_enabled
Expand All @@ -597,8 +597,9 @@ def build_binary(
# the referred cte tables are propagated from left and right can have
# duplicated queries if there is a common CTE block referenced by
# both left and right.
referenced_ctes.update(select_left.referenced_ctes)
referenced_ctes.update(select_right.referenced_ctes)
referenced_ctes = merge_referenced_ctes(
select_left.referenced_ctes, select_right.referenced_ctes
)

queries = merged_queries + [
Query(
Expand Down Expand Up @@ -1610,8 +1611,10 @@ def with_query_block(
# the query parameter will be propagated along with the definition during
# query generation stage.
queries = child.queries[:-1] + [Query(sql=new_query)]
# propagate the cte table
referenced_ctes = {with_query_block}.union(child.referenced_ctes)
# propagate the WithQueryBlock references
referenced_ctes = merge_referenced_ctes(
child.referenced_ctes, {with_query_block: 1}
)

return SnowflakePlan(
queries,
Expand Down
16 changes: 15 additions & 1 deletion src/snowflake/snowpark/_internal/compiler/cte_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import hashlib
import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Optional, Set
from typing import TYPE_CHECKING, Dict, Optional, Set

from snowflake.snowpark._internal.analyzer.snowflake_plan_node import WithQueryBlock
from snowflake.snowpark._internal.utils import is_sql_select_statement

if TYPE_CHECKING:
Expand Down Expand Up @@ -131,3 +132,16 @@ def encode_node_id_with_query(node: "TreeNode") -> str:
return f"{query_id}_{node_type_name}"
else:
return str(id(node))


def merge_referenced_ctes(
    ref1: Dict[WithQueryBlock, int], ref2: Dict[WithQueryBlock, int]
) -> Dict[WithQueryBlock, int]:
    """Combine two referenced-CTE dictionaries into a new dictionary.

    Each input maps a WithQueryBlock to its reference count. A key that
    appears in both inputs gets the sum of the two counts; a key that appears
    in only one input keeps its count. Neither input is modified: the result
    starts as a shallow copy of ``ref1`` and entries from ``ref2`` are folded
    into that copy.
    """
    combined = ref1.copy()
    for cte_block, ref_count in ref2.items():
        # A block not yet seen contributes from a baseline count of zero.
        combined[cte_block] = combined.get(cte_block, 0) + ref_count
    return combined
Loading

0 comments on commit 0409865

Please sign in to comment.