diff --git a/README.md b/README.md
index ff811492..341f4aef 100644
--- a/README.md
+++ b/README.md
@@ -215,8 +215,8 @@ athena:
   - Replace the __dbt_tmp suffix used as temporary table name suffix by a unique uuid
   - Useful if you are looking to run multiple dbt build inserting in the same table in parallel
 - `temp_schema` (`default=none`)
-  - For incremental models, it allows to define a schema to hold temporary create statements
-    used in incremental model runs
+  - For incremental models and table materializations, it allows you to define a schema to hold
+    the temporary tables created during model runs
   - Schema will be created in the model target database if does not exist
 - `lf_tags_config` (`default=none`)
   - [AWS Lake Formation](#aws-lake-formation-integration) tags to associate with the table and columns
diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py
index ea422e02..93fa763d 100755
--- a/dbt/adapters/athena/impl.py
+++ b/dbt/adapters/athena/impl.py
@@ -98,7 +98,8 @@ class AthenaConfig(AdapterConfig):
         partitions_limit: Maximum numbers of partitions when batching.
         force_batch: Skip creating the table as ctas and run the operation directly in batch insert mode.
         unique_tmp_table_suffix: Enforce the use of a unique id as tmp table suffix instead of __dbt_tmp.
-        temp_schema: Define in which schema to create temporary tables used in incremental runs.
+        temp_schema: Define the schema in which to create temporary tables used in incremental
+            runs and table materializations.
     """

     work_group: Optional[str] = None
diff --git a/dbt/include/athena/macros/adapters/relation.sql b/dbt/include/athena/macros/adapters/relation.sql
index 611ffc59..ad5671f1 100644
--- a/dbt/include/athena/macros/adapters/relation.sql
+++ b/dbt/include/athena/macros/adapters/relation.sql
@@ -36,6 +36,14 @@
   {%- endcall %}
{%- endmacro %}

+{% macro set_table_relation_schema(relation, schema) %}
+  {%- if schema is not none -%}
+    {%- set relation = relation.incorporate(path={"schema": schema}) -%}
+    {%- do create_schema(relation) -%}
+  {%- endif -%}
+  {{ return(relation) }}
+{% endmacro %}
+
{% macro make_temp_relation(base_relation, suffix='__dbt_tmp', temp_schema=none) %}
   {%- set temp_identifier = base_relation.identifier ~ suffix -%}
   {%- set temp_relation = base_relation.incorporate(path={"identifier": temp_identifier}) -%}
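Note on what `set_table_relation_schema` relies on: dbt relations are immutable, so `incorporate` returns a copy of the relation with part of its path replaced, and the macro hands back a new relation pointing at the temp schema after ensuring that schema exists. A minimal Python sketch of the `incorporate` behavior, illustrative only — it uses dbt-core's generic `BaseRelation` rather than the Athena subclass, and the rendered output shown assumes the default quoting policy:

```python
from dbt.adapters.base.relation import BaseRelation

# Build a relation, then swap only its schema; database and identifier survive.
relation = BaseRelation.create(database="awsdatacatalog", schema="analytics", identifier="orders__ha")
tmp_relation = relation.incorporate(path={"schema": "analytics_tmp"})

print(relation.render())      # "awsdatacatalog"."analytics"."orders__ha"
print(tmp_relation.render())  # "awsdatacatalog"."analytics_tmp"."orders__ha"
```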
diff --git a/dbt/include/athena/macros/materializations/models/table/table.sql b/dbt/include/athena/macros/materializations/models/table/table.sql
index 1f94361c..6970d916 100644
--- a/dbt/include/athena/macros/materializations/models/table/table.sql
+++ b/dbt/include/athena/macros/materializations/models/table/table.sql
@@ -7,13 +7,20 @@
   {%- set lf_grants = config.get('lf_grants') -%}

   {%- set table_type = config.get('table_type', default='hive') | lower -%}
+  {%- set temp_schema = config.get('temp_schema') -%}
   {%- set old_relation = adapter.get_relation(database=database, schema=schema, identifier=identifier) -%}
   {%- set old_tmp_relation = adapter.get_relation(identifier=identifier ~ '__ha', schema=schema, database=database) -%}
+  {%- if temp_schema is not none and old_tmp_relation is not none -%}
+    {%- set old_tmp_relation = set_table_relation_schema(relation=old_tmp_relation, schema=temp_schema) -%}
+  {%- endif -%}
   {%- set old_bkp_relation = adapter.get_relation(identifier=identifier ~ '__bkp', schema=schema, database=database) -%}
+  {%- if temp_schema is not none and old_bkp_relation is not none -%}
+    {%- set old_bkp_relation = set_table_relation_schema(relation=old_bkp_relation, schema=temp_schema) -%}
+  {%- endif -%}
   {%- set is_ha = config.get('ha', default=false) -%}
   {%- set s3_data_dir = config.get('s3_data_dir', default=target.s3_data_dir) -%}
   {%- set s3_data_naming = config.get('s3_data_naming', default='table_unique') -%}
@@ -31,6 +38,9 @@
                                               database=database,
                                               s3_path_table_part=target_relation.identifier,
                                               type='table') -%}
+  {%- if temp_schema is not none -%}
+    {%- set tmp_relation = set_table_relation_schema(relation=tmp_relation, schema=temp_schema) -%}
+  {%- endif -%}

   {%- if (
       table_type == 'hive'
@@ -137,6 +147,9 @@
       -- we cannot use old_bkp_relation, because it returns None if the relation doesn't exist
       -- we need to create a python object via the make_temp_relation instead
       {%- set old_relation_bkp = make_temp_relation(old_relation, '__bkp') -%}
+      {%- if temp_schema is not none -%}
+        {%- set old_relation_bkp = set_table_relation_schema(relation=old_relation_bkp, schema=temp_schema) -%}
+      {%- endif -%}

       {%- if old_relation_table_type == 'iceberg_table' -%}
         {{ rename_relation(old_relation, old_relation_bkp) }}
diff --git a/tests/functional/adapter/test_incremental_tmp_schema.py b/tests/functional/adapter/test_incremental_tmp_schema.py
index d06e95f2..dbd72049 100644
--- a/tests/functional/adapter/test_incremental_tmp_schema.py
+++ b/tests/functional/adapter/test_incremental_tmp_schema.py
@@ -2,7 +2,7 @@
 import yaml
 from tests.functional.adapter.utils.parse_dbt_run_output import (
     extract_create_statement_table_names,
-    extract_running_create_statements,
+    extract_running_ddl_statements,
 )

 from dbt.contracts.results import RunStatus
@@ -61,7 +61,7 @@ def test__schema_tmp(self, project, capsys):
         assert records_count_first_run == 1

         out, _ = capsys.readouterr()
-        athena_running_create_statements = extract_running_create_statements(out, relation_name)
+        athena_running_create_statements = extract_running_ddl_statements(out, relation_name, "create table")

         assert len(athena_running_create_statements) == 1

@@ -95,7 +95,7 @@ def test__schema_tmp(self, project, capsys):
         assert records_count_incremental_run == 2

         out, _ = capsys.readouterr()
-        athena_running_create_statements = extract_running_create_statements(out, relation_name)
+        athena_running_create_statements = extract_running_ddl_statements(out, relation_name, "create table")

         assert len(athena_running_create_statements) == 1
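These tests (and the new table test below) assert against dbt's structured logs: the run is invoked with `--log-level debug --log-format json`, `capsys` captures stdout, and the helper scans each JSON line for Athena DDL. A fabricated example of the payload shape the helper matches on — the field names come from the parser, the values are invented, and real dbt JSON logs carry many more fields per event:

```python
import json

# One invented log line standing in for a captured dbt event.
line = json.dumps(
    {"data": {"conn_name": "model.test.schema_tmp", "sql": "create table `analytics_tmp`.`schema_tmp__dbt_tmp` ..."}}
)

data = json.loads(line).get("data", {})
if "create table" in data.get("sql", ""):
    print("captured DDL:", data["sql"])
```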
diff --git a/tests/functional/adapter/test_table_temp_schema_name.py b/tests/functional/adapter/test_table_temp_schema_name.py
new file mode 100644
index 00000000..28b518b7
--- /dev/null
+++ b/tests/functional/adapter/test_table_temp_schema_name.py
@@ -0,0 +1,132 @@
+import pytest
+import yaml
+from tests.functional.adapter.utils.parse_dbt_run_output import (
+    extract_create_statement_table_names,
+    extract_rename_statement_table_names,
+    extract_running_ddl_statements,
+)
+
+from dbt.contracts.results import RunStatus
+from dbt.tests.util import run_dbt
+
+models__iceberg_table = """
+{{ config(
+        materialized='table',
+        table_type='iceberg',
+        temp_schema=var('temp_schema_name'),
+    )
+}}
+
+select
+    {{ var('test_id') }} as id
+"""
+
+
+class TestTableIcebergTableTempSchema:
+    @pytest.fixture(scope="class")
+    def models(self):
+        return {"models__iceberg_table.sql": models__iceberg_table}
+
+    def test__temp_schema_name_iceberg_table(self, project, capsys):
+        relation_name = "models__iceberg_table"
+        temp_schema_name = f"{project.test_schema}_tmp"
+        drop_temp_schema = f"drop schema if exists `{temp_schema_name}` cascade"
+        model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}"
+        model_run_result_test_id_query = f"select id from {project.test_schema}.{relation_name}"
+
+        vars_dict = {
+            "temp_schema_name": temp_schema_name,
+            "test_id": 1,
+        }
+
+        model_run = run_dbt(
+            [
+                "run",
+                "--select",
+                relation_name,
+                "--vars",
+                yaml.safe_dump(vars_dict),
+                "--log-level",
+                "debug",
+                "--log-format",
+                "json",
+            ]
+        )
+
+        model_run_result = model_run.results[0]
+        assert model_run_result.status == RunStatus.Success
+
+        out, _ = capsys.readouterr()
+        athena_running_create_statements = extract_running_ddl_statements(out, relation_name, "create table")
+        assert len(athena_running_create_statements) == 1
+
+        # On the first run there is no existing relation, so the table is created
+        # directly in the target schema, not in the temp schema.
+        model_run_result_table_name = extract_create_statement_table_names(
+            athena_running_create_statements[0]
+        )[0]
+        assert temp_schema_name not in model_run_result_table_name
+
+        model_records_count = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]
+        assert model_records_count == 1
+
+        model_test_id_in_table = project.run_sql(model_run_result_test_id_query, fetch="all")[0][0]
+        assert model_test_id_in_table == 1
+
+        vars_dict["test_id"] = 2
+
+        model_run = run_dbt(
+            [
+                "run",
+                "--select",
+                relation_name,
+                "--vars",
+                yaml.safe_dump(vars_dict),
+                "--log-level",
+                "debug",
+                "--log-format",
+                "json",
+            ]
+        )
+
+        model_run_result = model_run.results[0]
+        assert model_run_result.status == RunStatus.Success
+
+        model_records_count = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]
+        assert model_records_count == 1
+
+        model_test_id_in_table = project.run_sql(model_run_result_test_id_query, fetch="all")[0][0]
+        assert model_test_id_in_table == 2
+
+        out, _ = capsys.readouterr()
+        athena_running_create_statements = extract_running_ddl_statements(out, relation_name, "create table")
+        assert len(athena_running_create_statements) == 1
+
+        # On the second run the replacement table is built in the temp schema first.
+        model_run_2_result_table_name = extract_create_statement_table_names(athena_running_create_statements[0])[0]
+        assert temp_schema_name in model_run_2_result_table_name
+
+        athena_running_alter_statements = extract_running_ddl_statements(out, relation_name, "alter table")
+        # Two rename statements are expected: the old table is parked as __bkp in the
+        # temp schema, then the freshly built __ha table is promoted, e.g.:
+        #   alter table `<test_schema>`.`models__iceberg_table`
+        #     rename to `<test_schema>_tmp`.`models__iceberg_table__bkp`
+        #   alter table `<test_schema>_tmp`.`models__iceberg_table__ha`
+        #     rename to `<test_schema>`.`models__iceberg_table`
+        assert len(athena_running_alter_statements) == 2
+
+        athena_running_alter_statement_tables = extract_rename_statement_table_names(athena_running_alter_statements[0])
+        athena_running_alter_statement_origin_table = athena_running_alter_statement_tables.get("alter_table_names")[0]
+        athena_running_alter_statement_renamed_to_table = athena_running_alter_statement_tables.get(
+            "rename_to_table_names"
+        )[0]
+        assert project.test_schema in athena_running_alter_statement_origin_table
+        assert athena_running_alter_statement_renamed_to_table == f"`{temp_schema_name}`.`{relation_name}__bkp`"
+
+        athena_running_alter_statement_tables = extract_rename_statement_table_names(athena_running_alter_statements[1])
+        athena_running_alter_statement_origin_table = athena_running_alter_statement_tables.get("alter_table_names")[0]
+        athena_running_alter_statement_renamed_to_table = athena_running_alter_statement_tables.get(
+            "rename_to_table_names"
+        )[0]
+        assert temp_schema_name in athena_running_alter_statement_origin_table
+        assert athena_running_alter_statement_renamed_to_table == f"`{project.test_schema}`.`{relation_name}`"
+
+        project.run_sql(drop_temp_schema)
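In other words, the second run of this test should observe one CTAS into the temp schema followed by the two renames — roughly the sequence sketched below, where the schema names are placeholders rather than captured output:

```python
# Illustrative DDL sequence for the ha/temp_schema swap; identifiers are placeholders.
expected_sequence = [
    "create table `analytics_tmp`.`models__iceberg_table__ha` ...",  # build the new table in the temp schema
    "alter table `analytics`.`models__iceberg_table` rename to `analytics_tmp`.`models__iceberg_table__bkp`",
    "alter table `analytics_tmp`.`models__iceberg_table__ha` rename to `analytics`.`models__iceberg_table`",
]
```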
diff --git a/tests/functional/adapter/test_unique_tmp_table_suffix.py b/tests/functional/adapter/test_unique_tmp_table_suffix.py
index 563e5dcb..6632e83d 100644
--- a/tests/functional/adapter/test_unique_tmp_table_suffix.py
+++ b/tests/functional/adapter/test_unique_tmp_table_suffix.py
@@ -3,7 +3,7 @@
 import pytest
 from tests.functional.adapter.utils.parse_dbt_run_output import (
     extract_create_statement_table_names,
-    extract_running_create_statements,
+    extract_running_ddl_statements,
 )

 from dbt.contracts.results import RunStatus
@@ -55,7 +55,7 @@ def test__unique_tmp_table_suffix(self, project, capsys):
         assert first_model_run_result.status == RunStatus.Success

         out, _ = capsys.readouterr()
-        athena_running_create_statements = extract_running_create_statements(out, relation_name)
+        athena_running_create_statements = extract_running_ddl_statements(out, relation_name, "create table")

         assert len(athena_running_create_statements) == 1

@@ -87,7 +87,7 @@ def test__unique_tmp_table_suffix(self, project, capsys):
         assert incremental_model_run_result.status == RunStatus.Success

         out, _ = capsys.readouterr()
-        athena_running_create_statements = extract_running_create_statements(out, relation_name)
+        athena_running_create_statements = extract_running_ddl_statements(out, relation_name, "create table")

         assert len(athena_running_create_statements) == 1

@@ -119,7 +119,7 @@ def test__unique_tmp_table_suffix(self, project, capsys):
         assert incremental_model_run_result.status == RunStatus.Success

         out, _ = capsys.readouterr()
-        athena_running_create_statements = extract_running_create_statements(out, relation_name)
+        athena_running_create_statements = extract_running_ddl_statements(out, relation_name, "create table")

         incremental_model_run_result_table_name_2 = extract_create_statement_table_names(
             athena_running_create_statements[0]
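The helpers in `parse_dbt_run_output.py` below pull table names out of captured DDL with lookaround regexes. A standalone check of the create-statement pattern, using a fabricated CTAS string:

```python
import re

# Fabricated Athena CTAS; the pattern captures everything between "create table " and "with".
ddl = "create table `analytics_tmp`.`orders__dbt_tmp`\nwith (table_type='iceberg') as select 1 as id"
names = re.findall(r"(?s)(?<=create table ).*?(?=with)", ddl)
print([name.rstrip() for name in names])  # ['`analytics_tmp`.`orders__dbt_tmp`']
```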
diff --git a/tests/functional/adapter/utils/parse_dbt_run_output.py b/tests/functional/adapter/utils/parse_dbt_run_output.py
index 4f448420..611c189d 100644
--- a/tests/functional/adapter/utils/parse_dbt_run_output.py
+++ b/tests/functional/adapter/utils/parse_dbt_run_output.py
@@ -1,10 +1,10 @@
 import json
 import re
-from typing import List
+from typing import Dict, List


-def extract_running_create_statements(dbt_run_capsys_output: str, relation_name: str) -> List[str]:
-    sql_create_statements = []
+def extract_running_ddl_statements(dbt_run_capsys_output: str, relation_name: str, ddl_type: str) -> List[str]:
+    sql_ddl_statements = []
     # Skipping "Invoking dbt with ['run', '--select', 'unique_tmp_table_suffix'..."
     for events_msg in dbt_run_capsys_output.split("\n")[1:]:
         base_msg_data = None
@@ -21,16 +21,26 @@ def extract_running_create_statements(dbt_run_capsys_output: str, relation_name:
         if base_msg_data:
             base_msg = base_msg_data.get("base_msg")
             if "Running Athena query:" in str(base_msg):
-                if "create table" in base_msg:
-                    sql_create_statements.append(base_msg)
+                if ddl_type in base_msg:
+                    sql_ddl_statements.append(base_msg)

             if base_msg_data.get("conn_name") == f"model.test.{relation_name}" and "sql" in base_msg_data:
-                if "create table" in base_msg_data.get("sql"):
-                    sql_create_statements.append(base_msg_data.get("sql"))
+                if ddl_type in base_msg_data.get("sql"):
+                    sql_ddl_statements.append(base_msg_data.get("sql"))

-    return sql_create_statements
+    return sql_ddl_statements


 def extract_create_statement_table_names(sql_create_statement: str) -> List[str]:
     table_names = re.findall(r"(?s)(?<=create table ).*?(?=with)", sql_create_statement)
     return [table_name.rstrip() for table_name in table_names]
+
+
+def extract_rename_statement_table_names(sql_alter_rename_statement: str) -> Dict[str, List[str]]:
+    alter_table_names = re.findall(r"(?s)(?<=alter table ).*?(?= rename)", sql_alter_rename_statement)
+    rename_to_table_names = re.findall(r"(?s)(?<=rename to ).*?(?=$)", sql_alter_rename_statement)
+
+    return {
+        "alter_table_names": [table_name.rstrip() for table_name in alter_table_names],
+        "rename_to_table_names": [table_name.rstrip() for table_name in rename_to_table_names],
+    }
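And a matching sanity check for the new rename parser, again with an invented statement:

```python
import re

# Fabricated Athena rename; the two patterns capture the source and target names.
stmt = "alter table `analytics`.`orders` rename to `analytics_tmp`.`orders__bkp`"
alter_names = re.findall(r"(?s)(?<=alter table ).*?(?= rename)", stmt)
rename_to_names = re.findall(r"(?s)(?<=rename to ).*?(?=$)", stmt)

print([n.rstrip() for n in alter_names])      # ['`analytics`.`orders`']
print([n.rstrip() for n in rename_to_names])  # ['`analytics_tmp`.`orders__bkp`']
```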