From b5624a69cbf2922e40299b4f6e9b291e9633cfc7 Mon Sep 17 00:00:00 2001
From: "Alexis A."
Date: Thu, 12 Sep 2024 15:35:05 +0200
Subject: [PATCH 1/2] feat(airflow): add provision_dev_db dag

---
 airflow/dags/provision_dev_db.py | 164 +++++++++++++++++++++++++++++++
 airflow/include/container.py     |  25 ++++-
 airflow/include/scalingo.py      |  69 +++++++++++++
 3 files changed, 255 insertions(+), 3 deletions(-)
 create mode 100644 airflow/dags/provision_dev_db.py
 create mode 100644 airflow/include/scalingo.py

diff --git a/airflow/dags/provision_dev_db.py b/airflow/dags/provision_dev_db.py
new file mode 100644
index 000000000..89af2e737
--- /dev/null
+++ b/airflow/dags/provision_dev_db.py
@@ -0,0 +1,164 @@
+import os
+
+import pendulum
+import requests
+from airflow.decorators import dag, task
+from airflow.models.param import Param
+from include.container import Container
+from psycopg2 import sql
+
+
+def download_file_by_chunks(url, local_filename, chunk_size=8192):
+    # Stream the response to disk so large backups never have to fit in memory.
+    with requests.get(url, stream=True) as r:
+        r.raise_for_status()
+        with open(local_filename, "wb") as f:
+            for chunk in r.iter_content(chunk_size=chunk_size):
+                f.write(chunk)
+    return local_filename
+
+
+@dag(
+    start_date=pendulum.datetime(2024, 1, 1),
+    schedule="@once",
+    catchup=False,
+    doc_md=__doc__,
+    default_args={"owner": "Alexis Athlani", "retries": 3},
+    params={
+        "db_name": Param(default="DB_NAME", type="string"),
+        "user": Param(default="DB_USER", type="string", enum=["alexis", "sofian"]),
+    },
+)
+def provision_dev_db():  # noqa: C901
+    local_backup_path = "/tmp/backup.tar.gz"
+
+    @task.python(retries=0)
+    def check_db_name_is_valid_identifier(**context):
+        db_name = context["params"]["db_name"]
+        allowed_letters = "abcdefghijklmnopqrstuvwxyz-"
+        invalid_letters_found = []
+        for letter in db_name:
+            if letter not in allowed_letters:
+                invalid_letters_found.append(letter)
+
+        if invalid_letters_found:
+            raise ValueError(f"Invalid characters found in db_name: {invalid_letters_found}")
+
+        return f"{db_name} is a valid identifier"
+
+    @task.python(retries=0)
+    def check_db_name_does_not_exist(**context):
+        db_name = context["params"]["db_name"]
+        conn = Container().psycopg2_dbt_conn()
+        cursor = conn.cursor()
+        params = {"dbname": db_name}
+        query = """
+        select exists(
+            SELECT datname FROM pg_catalog.pg_database WHERE lower(datname) = lower(%(dbname)s)
+        )
+        """
+        cursor.execute(query, params)
+        result = cursor.fetchone()
+        if result[0]:
+            raise ValueError(f"Database {db_name} already exists")
+
+    @task.python
+    def get_download_url():
+        return Container().scalingo().get_latest_backup_url()
+
+    @task.python
+    def delete_backup_if_exists_before():
+        if os.path.exists(local_backup_path):
+            os.remove(local_backup_path)
+
+    @task.python
+    def download_backup(url):
+        download_file_by_chunks(url, local_backup_path)
+
+    @task.bash
+    def extract_backup():
+        return f"tar -xvzf {local_backup_path} -C /tmp"
+
+    @task.python
+    def get_extracted_backup_path():
+        files_in_tmp = os.listdir("/tmp")
+        for file in files_in_tmp:
+            if file.endswith(".pgsql"):
+                return f"/tmp/{file}"
+
+        raise ValueError("No .pgsql file found in /tmp")
+
+    @task.python
+    def create_db(**context):
+        db_name = context["params"]["db_name"]
+        conn = Container().psycopg2_dbt_conn()
+        conn.autocommit = True
+
+        cur = conn.cursor()
+        cur.execute(sql.SQL("CREATE DATABASE {};").format(sql.Identifier(db_name)))
+        cur.close()
+
+    @task.bash
+    def restore_db(extracted_file_path, **context):
+        db_url = Container().dbt_conn_url(dbname=context["params"]["db_name"]).replace("_", "-")
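+        # --dbname accepts a full connection URI; note that the "_" -> "-"
+        # rewrite above applies to the whole URI, not only the database name.
+        # --clean/--if-exists drop existing objects before recreating them;
+        # --no-owner and --no-privileges skip the ownership and GRANT/REVOKE
+        # statements carried over from the production dump.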
+        command = [
+            "pg_restore",
+            "--dbname",
+            db_url,
+            "--no-owner",
+            "--no-privileges",
+            "--no-tablespaces",
+            "--no-comments",
+            "--clean",
+            "--if-exists",
+            extracted_file_path,
+        ]
+
+        return " ".join(command)
+
+    @task.python
+    def grant_permission_to_user(**context):
+        db_name = context["params"]["db_name"]
+        user = context["params"]["user"]
+        conn = Container().psycopg2_dbt_conn()
+        conn.autocommit = True
+        cur = conn.cursor()
+        cur.execute(
+            sql.SQL("GRANT ALL PRIVILEGES ON DATABASE {} TO {};").format(sql.Identifier(db_name), sql.Identifier(user))
+        )
+        # Runs on the maintenance connection, so this grants on the tables of
+        # that connection's current database, not the freshly restored one.
+        cur.execute(sql.SQL("GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO {};").format(sql.Identifier(user)))
+        cur.close()
+
+    @task.python
+    def delete_backup_if_exists_after(extracted_file_path):
+        if os.path.exists(local_backup_path):
+            os.remove(local_backup_path)
+        if os.path.exists(extracted_file_path):
+            os.remove(extracted_file_path)
+
+    check_db_name = check_db_name_is_valid_identifier()
+    db_does_not_exist = check_db_name_does_not_exist()
+    url = get_download_url()
+    delete_before = delete_backup_if_exists_before()
+    download = download_backup(url)
+    extract = extract_backup()
+    extracted_file_path = get_extracted_backup_path()
+    create = create_db()
+    restore = restore_db(extracted_file_path)
+    delete_after = delete_backup_if_exists_after(extracted_file_path)
+    grant_permission = grant_permission_to_user()
+
+    (
+        check_db_name
+        >> db_does_not_exist
+        >> url
+        >> delete_before
+        >> download
+        >> extract
+        >> extracted_file_path
+        >> create
+        >> restore
+        >> grant_permission
+        >> delete_after
+    )
+
+
+provision_dev_db()
diff --git a/airflow/include/container.py b/airflow/include/container.py
index 6306cc4d1..a4c5f569b 100644
--- a/airflow/include/container.py
+++ b/airflow/include/container.py
@@ -1,7 +1,7 @@
 from os import getenv
+from urllib.parse import quote_plus
 
 import pysftp
-import sqlalchemy
 from airflow.hooks.base import BaseHook
 from dependency_injector import containers, providers
 from gdaltools import PgConnectionString
@@ -10,14 +10,15 @@
 from s3fs import S3FileSystem
 
 from .mattermost import Mattermost
+from .scalingo import ScalingoClient
 
 
 def db_str_for_ogr2ogr(dbname: str, user: str, password: str, host: str, port: int) -> str:
     return f"PG:dbname='{dbname}' host='{host}' port='{port}' user='{user}' password='{password}'"
 
 
-def create_sql_alchemy_conn(url: str) -> sqlalchemy.engine.base.Connection:
-    return sqlalchemy.create_engine(url)
+def db_str_url(dbname: str, user: str, password: str, host: str, port: int) -> str:
+    return f"postgresql://{user}:{password}@{host}:{port}/{dbname}"
 
 
 class Container(containers.DeclarativeContainer):
@@ -47,6 +48,14 @@ class Container(containers.DeclarativeContainer):
         host=getenv("DBT_DB_HOST"),
         port=getenv("DBT_DB_PORT"),
     )
+    dbt_conn_url = providers.Factory(
+        db_str_url,
+        dbname=getenv("DBT_DB_NAME"),
+        user=getenv("DBT_DB_USER"),
+        password=quote_plus(getenv("DBT_DB_PASSWORD")),
+        host=getenv("DBT_DB_HOST"),
+        port=int(getenv("DBT_DB_PORT")),
+    )
 
     # DEV connections
     gdal_dev_conn = providers.Factory(
@@ -120,3 +129,13 @@ class Container(containers.DeclarativeContainer):
         mattermost_webhook_url=getenv("MATTERMOST_WEBHOOK_URL"),
         channel=getenv("MATTERMOST_CHANNEL"),
     )
+
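+    # Scalingo API client used by the provision_dev_db dag to locate backups;
+    # configured entirely from environment variables, like the providers above.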
+    scalingo = providers.Factory(
+        ScalingoClient,
+        api_token=getenv("SCALINGO_API_TOKEN"),
+        app_name=getenv("SCALINGO_APP_NAME"),
+        addon_id=getenv("SCALINGO_ADDON_ID"),
+        backup_dir=getenv("SCALINGO_BACKUP_DIR"),
+        scalingo_subdomain=getenv("SCALINGO_SUBDOMAIN"),
+        scalingo_subdomain_db=getenv("SCALINGO_SUBDOMAIN_DB"),
+    )
diff --git a/airflow/include/scalingo.py b/airflow/include/scalingo.py
new file mode 100644
index 000000000..148284c39
--- /dev/null
+++ b/airflow/include/scalingo.py
@@ -0,0 +1,69 @@
+import requests
+
+
+class ScalingoClient:
+    def __init__(
+        self,
+        api_token: str,
+        app_name: str,
+        addon_id: str,
+        backup_dir: str,
+        scalingo_subdomain: str,
+        scalingo_subdomain_db: str,
+    ):
+        self.api_token = api_token
+        self.app_name = app_name
+        self.addon_id = addon_id
+        self.backup_dir = backup_dir
+        self.scalingo_subdomain = scalingo_subdomain
+        self.scalingo_subdomain_db = scalingo_subdomain_db
+
+    @property
+    def api_bearer_token(self):
+        # Exchange the long-lived API token for a bearer token; the endpoint
+        # expects HTTP basic auth with an empty username.
+        url = "https://auth.scalingo.com/v1/tokens/exchange"
+        basic_auth_password = self.api_token
+        basic_auth_username = ""
+
+        response = requests.post(url, auth=(basic_auth_username, basic_auth_password))
+        response.raise_for_status()
+        return response.json()["token"]
+
+    @property
+    def headers(self):
+        return {
+            "Authorization": f"Bearer {self.api_bearer_token}",
+            "Content-Type": "application/json",
+        }
+
+    @property
+    def api_bearer_token_addon(self):
+        # A second, addon-scoped token is needed to query the database API.
+        url = f"https://{self.scalingo_subdomain}/v1/apps/{self.app_name}/addons/{self.addon_id}/token"
+
+        response = requests.post(url, headers=self.headers)
+        response.raise_for_status()
+        return response.json()["addon"]["token"]
+
+    @property
+    def addons_headers(self):
+        return {
+            "Authorization": f"Bearer {self.api_bearer_token_addon}",
+            "Content-Type": "application/json",
+        }
+
+    def get_backups(self):
+        url = f"https://{self.scalingo_subdomain_db}/api/databases/{self.addon_id}/backups"
+        response = requests.get(url, headers=self.addons_headers)
+        response.raise_for_status()
+        return response.json()["database_backups"]
+
+    def get_backup_url(self, backup_id: str):
+        url = f"https://{self.scalingo_subdomain_db}/api/databases/{self.addon_id}/backups/{backup_id}/archive"
+        response = requests.get(url, headers=self.addons_headers)
+        response.raise_for_status()
+        return response.json()["download_url"]
+
+    def get_latest_backup_url(self):
+        # Assumes the backups endpoint returns the most recent backup first.
+        backups = self.get_backups()
+        last_backup = backups[0]
+        last_backup_id = last_backup["id"]
+        return self.get_backup_url(last_backup_id)

From ee140dd39a190a84c6f332fdae04f4f6ddf5deaa Mon Sep 17 00:00:00 2001
From: "Alexis A."
Date: Thu, 12 Sep 2024 15:37:02 +0200
Subject: [PATCH 2/2] feat(scalingo): remove backup_dir from client

---
 airflow/include/container.py | 1 -
 airflow/include/scalingo.py  | 2 --
 2 files changed, 3 deletions(-)

diff --git a/airflow/include/container.py b/airflow/include/container.py
--- a/airflow/include/container.py
+++ b/airflow/include/container.py
@@ -135,7 +135,6 @@ class Container(containers.DeclarativeContainer):
         api_token=getenv("SCALINGO_API_TOKEN"),
         app_name=getenv("SCALINGO_APP_NAME"),
         addon_id=getenv("SCALINGO_ADDON_ID"),
-        backup_dir=getenv("SCALINGO_BACKUP_DIR"),
         scalingo_subdomain=getenv("SCALINGO_SUBDOMAIN"),
         scalingo_subdomain_db=getenv("SCALINGO_SUBDOMAIN_DB"),
     )
diff --git a/airflow/include/scalingo.py b/airflow/include/scalingo.py
index 148284c39..e93e232a0 100644
--- a/airflow/include/scalingo.py
+++ b/airflow/include/scalingo.py
@@ -7,14 +7,12 @@ def __init__(
         api_token: str,
         app_name: str,
         addon_id: str,
-        backup_dir: str,
         scalingo_subdomain: str,
         scalingo_subdomain_db: str,
     ):
         self.api_token = api_token
         self.app_name = app_name
         self.addon_id = addon_id
-        self.backup_dir = backup_dir
         self.scalingo_subdomain = scalingo_subdomain
         self.scalingo_subdomain_db = scalingo_subdomain_db
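
The pieces above compose into a short standalone flow: exchange the API token,
resolve the newest backup to a download URL, and stream it to disk. A minimal
sketch for exercising the client outside Airflow, assuming the same SCALINGO_*
environment variables the container reads and the include package layout from
the patch:

    import os

    import requests

    from include.scalingo import ScalingoClient

    # Build the client from the same environment variables container.py reads.
    client = ScalingoClient(
        api_token=os.environ["SCALINGO_API_TOKEN"],
        app_name=os.environ["SCALINGO_APP_NAME"],
        addon_id=os.environ["SCALINGO_ADDON_ID"],
        scalingo_subdomain=os.environ["SCALINGO_SUBDOMAIN"],
        scalingo_subdomain_db=os.environ["SCALINGO_SUBDOMAIN_DB"],
    )

    # Resolve the newest backup to a download URL, then stream it to disk in
    # chunks, mirroring download_file_by_chunks in the dag.
    url = client.get_latest_backup_url()
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        with open("/tmp/backup.tar.gz", "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)

Note that every call through addons_headers re-runs both token exchanges,
since api_bearer_token and api_bearer_token_addon are plain properties; the
dag fetches a single URL per run, so the patch keeps this simple rather than
caching the bearer tokens.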