feat: Allow different schema for tmp tables created during table materialization #664

Open · wants to merge 12 commits into base: main
4 changes: 2 additions & 2 deletions README.md
@@ -215,8 +215,8 @@ athena:
- Replace the `__dbt_tmp` suffix used as the temporary table name suffix with a unique UUID
- Useful if you are looking to run multiple `dbt build` invocations inserting into the same table in parallel
- `temp_schema` (`default=none`)
- - For incremental models, it allows to define a schema to hold temporary create statements
-   used in incremental model runs
+ - For incremental and table materializations, it allows you to define a schema to hold the
+   temporary tables created during model runs
- Schema will be created in the model target database if it does not exist
- `lf_tags_config` (`default=none`)
- [AWS Lake Formation](#aws-lake-formation-integration) tags to associate with the table and columns
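For example, a table model can opt into a separate schema for its temporary relations directly from its config block (the schema name here is illustrative; the functional test added by this PR uses the same pattern):

```sql
{{ config(
    materialized='table',
    table_type='iceberg',
    temp_schema='analytics_tmp'
) }}

select 1 as id
```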
3 changes: 2 additions & 1 deletion dbt/adapters/athena/impl.py
@@ -98,7 +98,8 @@ class AthenaConfig(AdapterConfig):
partitions_limit: Maximum numbers of partitions when batching.
force_batch: Skip creating the table as ctas and run the operation directly in batch insert mode.
unique_tmp_table_suffix: Enforce the use of a unique id as tmp table suffix instead of __dbt_tmp.
- temp_schema: Define in which schema to create temporary tables used in incremental runs.
+ temp_schema: Define in which schema to create temporary tables used in incremental runs
+     and table materializations.
"""

work_group: Optional[str] = None
8 changes: 8 additions & 0 deletions dbt/include/athena/macros/adapters/relation.sql
@@ -36,6 +36,14 @@
{%- endcall %}
{%- endmacro %}

+ {% macro set_table_relation_schema(relation, schema) %}
Contributor: could you add some comments to this method?

+ {%- if temp_schema is not none -%}
Contributor (suggested change):
- {%- if temp_schema is not none -%}
+ {%- if schema is not none -%}
Or could you explain how you get temp_schema here?

+ {%- set relation = relation.incorporate(path={"schema": schema}) -%}
+ {%- do create_schema(relation) -%}
Contributor: I'm not 100% convinced by this: create_schema will be called by pretty much every model where we set up the temp_schema. It would be good to find a better way to handle that.

+ {% endif %}
+ {{ return(relation) }}
+ {% endmacro %}

{% macro make_temp_relation(base_relation, suffix='__dbt_tmp', temp_schema=none) %}
{%- set temp_identifier = base_relation.identifier ~ suffix -%}
{%- set temp_relation = base_relation.incorporate(path={"identifier": temp_identifier}) -%}
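A commented sketch of what the new macro does, as the first review comment asks; this is an editorial illustration, written with the reviewer's suggested `schema` guard rather than the `temp_schema` one above:

```sql
{% macro set_table_relation_schema(relation, schema) %}
    {#
        Return `relation` re-pointed at `schema`, creating that schema
        first if it does not already exist. The table materialization
        uses this to relocate tmp/ha/bkp relations into the configured
        temp_schema.
    #}
    {%- if schema is not none -%}
        {%- set relation = relation.incorporate(path={"schema": schema}) -%}
        {%- do create_schema(relation) -%}
    {%- endif -%}
    {{ return(relation) }}
{% endmacro %}
```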
@@ -7,13 +7,20 @@
{%- set lf_grants = config.get('lf_grants') -%}

{%- set table_type = config.get('table_type', default='hive') | lower -%}
+ {%- set temp_schema = config.get('temp_schema') -%}
nicor88 marked this conversation as resolved.
{%- set old_relation = adapter.get_relation(database=database, schema=schema, identifier=identifier) -%}
{%- set old_tmp_relation = adapter.get_relation(identifier=identifier ~ '__ha',
schema=schema,
database=database) -%}
+ {%- if temp_schema is not none and old_tmp_relation is not none -%}
Contributor: I don't fully get this check. Why do you need to include a check for old_tmp_relation not being none here? If old_tmp_relation is none, we still want to create the new tmp table in the tmp schema.

+ {%- set old_tmp_relation = set_table_relation_schema(relation=old_tmp_relation, schema=temp_schema) -%}
+ {%- endif -%}
{%- set old_bkp_relation = adapter.get_relation(identifier=identifier ~ '__bkp',
schema=schema,
database=database) -%}
+ {%- if temp_schema is not none and old_bkp_relation is not none -%}
+ {%- set old_bkp_relation = set_table_relation_schema(relation=old_bkp_relation, schema=temp_schema) -%}
+ {%- endif -%}
{%- set is_ha = config.get('ha', default=false) -%}
{%- set s3_data_dir = config.get('s3_data_dir', default=target.s3_data_dir) -%}
{%- set s3_data_naming = config.get('s3_data_naming', default='table_unique') -%}
@@ -31,6 +38,9 @@
database=database,
s3_path_table_part=target_relation.identifier,
type='table') -%}
+ {%- if temp_schema is not none -%}
+ {%- set tmp_relation = set_table_relation_schema(relation=tmp_relation, schema=temp_schema) -%}
+ {%- endif -%}

{%- if (
table_type == 'hive'
@@ -137,6 +147,9 @@
-- we cannot use old_bkp_relation, because it returns None if the relation doesn't exist
-- we need to create a python object via the make_temp_relation instead
{%- set old_relation_bkp = make_temp_relation(old_relation, '__bkp') -%}
+ {%- if temp_schema is not none -%}
+ {%- set old_relation_bkp = set_table_relation_schema(relation=old_relation_bkp, schema=temp_schema) -%}
+ {%- endif -%}

{%- if old_relation_table_type == 'iceberg_table' -%}
{{ rename_relation(old_relation, old_relation_bkp) }}
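With `temp_schema` set, the high-availability swap for an existing Iceberg table boils down to DDL along these lines (schema and model names are illustrative; the two `alter table` statements are the ones asserted in the new functional test below):

```sql
-- 1. build the new data as an __ha table inside the temp schema
create table `analytics_tmp`.`my_model__ha` with (table_type='iceberg') as select ...;

-- 2. move the live table out of the way, into the temp schema as __bkp
alter table `analytics`.`my_model` rename to `analytics_tmp`.`my_model__bkp`;

-- 3. promote the freshly built table into the target schema
alter table `analytics_tmp`.`my_model__ha` rename to `analytics`.`my_model`;
```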
6 changes: 3 additions & 3 deletions tests/functional/adapter/test_incremental_tmp_schema.py
@@ -2,7 +2,7 @@
import yaml
from tests.functional.adapter.utils.parse_dbt_run_output import (
extract_create_statement_table_names,
- extract_running_create_statements,
+ extract_running_ddl_statements,
)

from dbt.contracts.results import RunStatus
@@ -61,7 +61,7 @@ def test__schema_tmp(self, project, capsys):
assert records_count_first_run == 1

out, _ = capsys.readouterr()
- athena_running_create_statements = extract_running_create_statements(out, relation_name)
+ athena_running_create_statements = extract_running_ddl_statements(out, relation_name, "create table")

assert len(athena_running_create_statements) == 1

@@ -95,7 +95,7 @@ def test__schema_tmp(self, project, capsys):
assert records_count_incremental_run == 2

out, _ = capsys.readouterr()
- athena_running_create_statements = extract_running_create_statements(out, relation_name)
+ athena_running_create_statements = extract_running_ddl_statements(out, relation_name, "create table")

assert len(athena_running_create_statements) == 1

132 changes: 132 additions & 0 deletions tests/functional/adapter/test_table_temp_schema_name.py
@@ -0,0 +1,132 @@
import pytest
import yaml
from tests.functional.adapter.utils.parse_dbt_run_output import (
extract_create_statement_table_names,
extract_rename_statement_table_names,
extract_running_ddl_statements,
)

from dbt.contracts.results import RunStatus
from dbt.tests.util import run_dbt

models__iceberg_table = """
{{ config(
materialized='table',
table_type='iceberg',
temp_schema=var('temp_schema_name'),
)
}}

select
{{ var('test_id') }} as id
"""


class TestTableIcebergTableUnique:
@pytest.fixture(scope="class")
def models(self):
return {"models__iceberg_table.sql": models__iceberg_table}

def test__temp_schema_name_iceberg_table(self, project, capsys):
relation_name = "models__iceberg_table"
temp_schema_name = f"{project.test_schema}_tmp"
drop_temp_schema = f"drop schema if exists `{temp_schema_name}` cascade"
model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}"
model_run_result_test_id_query = f"select id from {project.test_schema}.{relation_name}"

vars_dict = {
"temp_schema_name": temp_schema_name,
"test_id": 1,
}

model_run = run_dbt(
[
"run",
"--select",
relation_name,
"--vars",
yaml.safe_dump(vars_dict),
"--log-level",
"debug",
"--log-format",
"json",
]
)

model_run_result = model_run.results[0]
assert model_run_result.status == RunStatus.Success

out, _ = capsys.readouterr()
athena_running_create_statements = extract_running_ddl_statements(out, relation_name, "create table")
assert len(athena_running_create_statements) == 1

incremental_model_run_result_table_name = extract_create_statement_table_names(
athena_running_create_statements[0]
)[0]
assert temp_schema_name not in incremental_model_run_result_table_name

model_records_count = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]
assert model_records_count == 1

model_test_id_in_table = project.run_sql(model_run_result_test_id_query, fetch="all")[0][0]
assert model_test_id_in_table == 1

vars_dict["test_id"] = 2

model_run = run_dbt(
[
"run",
"--select",
relation_name,
"--vars",
yaml.safe_dump(vars_dict),
"--log-level",
"debug",
"--log-format",
"json",
]
)

model_run_result = model_run.results[0]
assert model_run_result.status == RunStatus.Success

model_records_count = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]
assert model_records_count == 1

model_test_id_in_table = project.run_sql(model_run_result_test_id_query, fetch="all")[0][0]
assert model_test_id_in_table == 2

out, _ = capsys.readouterr()
athena_running_create_statements = extract_running_ddl_statements(out, relation_name, "create table")
assert len(athena_running_create_statements) == 1

model_run_2_result_table_name = extract_create_statement_table_names(athena_running_create_statements[0])[0]
assert temp_schema_name in model_run_2_result_table_name

athena_running_alter_statements = extract_running_ddl_statements(out, relation_name, "alter table")
"""2 rename statements: to __bck and from __ha and
# ['alter table `..._test_table_temp_schema_name`.`models__iceberg_table`
rename to `..._test_table_temp_schema_name_tmp`.`models__iceberg_table__bkp`'
, 'alter table ..._test_table_temp_schema_name_tmp`.`models__iceberg_table__ha`
rename to `..._test_table_temp_schema_name`.`models__iceberg_table`']
"""
assert len(athena_running_alter_statements) == 2

athena_running_alter_statement_tables = extract_rename_statement_table_names(athena_running_alter_statements[0])
athena_running_alter_statement_origin_table = athena_running_alter_statement_tables.get("alter_table_names")[0]
athena_running_alter_statement_renamed_to_table = athena_running_alter_statement_tables.get(
"rename_to_table_names"
)[0]
assert project.test_schema in athena_running_alter_statement_origin_table
assert athena_running_alter_statement_renamed_to_table == f"`{temp_schema_name}`.`{relation_name}__bkp`"

athena_running_alter_statement_tables = extract_rename_statement_table_names(athena_running_alter_statements[1])
athena_running_alter_statement_origin_table = athena_running_alter_statement_tables.get("alter_table_names")[0]
athena_running_alter_statement_renamed_to_table = athena_running_alter_statement_tables.get(
"rename_to_table_names"
)[0]

assert temp_schema_name in athena_running_alter_statement_origin_table
assert athena_running_alter_statement_renamed_to_table == f"`{project.test_schema}`.`{relation_name}`"

project.run_sql(drop_temp_schema)
8 changes: 4 additions & 4 deletions tests/functional/adapter/test_unique_tmp_table_suffix.py
@@ -3,7 +3,7 @@
import pytest
from tests.functional.adapter.utils.parse_dbt_run_output import (
extract_create_statement_table_names,
- extract_running_create_statements,
+ extract_running_ddl_statements,
)

from dbt.contracts.results import RunStatus
@@ -55,7 +55,7 @@ def test__unique_tmp_table_suffix(self, project, capsys):
assert first_model_run_result.status == RunStatus.Success

out, _ = capsys.readouterr()
- athena_running_create_statements = extract_running_create_statements(out, relation_name)
+ athena_running_create_statements = extract_running_ddl_statements(out, relation_name, "create table")

assert len(athena_running_create_statements) == 1

@@ -87,7 +87,7 @@ def test__unique_tmp_table_suffix(self, project, capsys):
assert incremental_model_run_result.status == RunStatus.Success

out, _ = capsys.readouterr()
- athena_running_create_statements = extract_running_create_statements(out, relation_name)
+ athena_running_create_statements = extract_running_ddl_statements(out, relation_name, "create table")

assert len(athena_running_create_statements) == 1

@@ -119,7 +119,7 @@ def test__unique_tmp_table_suffix(self, project, capsys):
assert incremental_model_run_result.status == RunStatus.Success

out, _ = capsys.readouterr()
- athena_running_create_statements = extract_running_create_statements(out, relation_name)
+ athena_running_create_statements = extract_running_ddl_statements(out, relation_name, "create table")

incremental_model_run_result_table_name_2 = extract_create_statement_table_names(
athena_running_create_statements[0]
26 changes: 18 additions & 8 deletions tests/functional/adapter/utils/parse_dbt_run_output.py
@@ -1,10 +1,10 @@
import json
import re
- from typing import List
+ from typing import Dict, List


- def extract_running_create_statements(dbt_run_capsys_output: str, relation_name: str) -> List[str]:
- sql_create_statements = []
+ def extract_running_ddl_statements(dbt_run_capsys_output: str, relation_name: str, ddl_type: str) -> List[str]:
+ sql_ddl_statements = []
# Skipping "Invoking dbt with ['run', '--select', 'unique_tmp_table_suffix'..."
for events_msg in dbt_run_capsys_output.split("\n")[1:]:
base_msg_data = None
@@ -21,16 +21,26 @@ def extract_running_create_statements(dbt_run_capsys_output: str, relation_name:
if base_msg_data:
base_msg = base_msg_data.get("base_msg")
if "Running Athena query:" in str(base_msg):
if "create table" in base_msg:
sql_create_statements.append(base_msg)
if ddl_type in base_msg:
sql_ddl_statements.append(base_msg)

if base_msg_data.get("conn_name") == f"model.test.{relation_name}" and "sql" in base_msg_data:
if "create table" in base_msg_data.get("sql"):
sql_create_statements.append(base_msg_data.get("sql"))
if ddl_type in base_msg_data.get("sql"):
sql_ddl_statements.append(base_msg_data.get("sql"))

- return sql_create_statements
+ return sql_ddl_statements


def extract_create_statement_table_names(sql_create_statement: str) -> List[str]:
table_names = re.findall(r"(?s)(?<=create table ).*?(?=with)", sql_create_statement)
return [table_name.rstrip() for table_name in table_names]


+ def extract_rename_statement_table_names(sql_alter_rename_statement: str) -> Dict[str, List[str]]:
+ alter_table_names = re.findall(r"(?s)(?<=alter table ).*?(?= rename)", sql_alter_rename_statement)
+ rename_to_table_names = re.findall(r"(?s)(?<=rename to ).*?(?=$)", sql_alter_rename_statement)
+
+ return {
+ "alter_table_names": [table_name.rstrip() for table_name in alter_table_names],
+ "rename_to_table_names": [table_name.rstrip() for table_name in rename_to_table_names],
+ }
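For reference, the shapes of statement these helpers slice, with the extracted names shown as comments (the sample statements are hypothetical):

```sql
-- extract_create_statement_table_names grabs what sits between "create table" and "with":
create table `analytics_tmp`.`my_model__ha` with (table_type='iceberg') as select 1
-- -> ["`analytics_tmp`.`my_model__ha`"]

-- extract_rename_statement_table_names pairs the source and target of a rename:
alter table `analytics`.`my_model` rename to `analytics_tmp`.`my_model__bkp`
-- alter_table_names     -> ["`analytics`.`my_model`"]
-- rename_to_table_names -> ["`analytics_tmp`.`my_model__bkp`"]
```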