Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ajout d'un dag pour créer une base de dev automatiquement #574

Open
wants to merge 2 commits into
base: staging
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 164 additions & 0 deletions airflow/dags/provision_dev_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import os

import pendulum
import requests
from airflow.decorators import dag, task
from airflow.models.param import Param
from include.container import Container
from psycopg2 import sql


def download_file_by_chunks(url, local_filename, chunk_size=8192, timeout=60):
    """Stream the resource at ``url`` into ``local_filename`` and return that path.

    The response is requested with ``stream=True`` and written chunk by chunk,
    so the body is never loaded into memory in one piece.

    Args:
        url: Address to download.
        local_filename: Destination path. It is opened in ``"wb"`` mode, so an
            existing file at that path is overwritten.
        chunk_size: Chunk size handed to ``iter_content``.
        timeout: Value forwarded to ``requests.get(timeout=...)``. Without it
            a server that stops answering would block the caller forever.

    Returns:
        ``local_filename``, unchanged.

    Raises:
        requests.HTTPError: When the server answers with an error status
            (raised by ``raise_for_status``).
    """
    with requests.get(url, stream=True, timeout=timeout) as r:
        r.raise_for_status()
        with open(local_filename, "wb") as f:
            for chunk in r.iter_content(chunk_size=chunk_size):
                f.write(chunk)
    return local_filename


@dag(
    start_date=pendulum.datetime(2024, 1, 1),
    schedule="@once",
    catchup=False,
    doc_md=__doc__,
    default_args={"owner": "Alexis Athlani", "retries": 3},
    params={
        # NOTE(review): the default "DB_NAME" contains uppercase letters and an
        # underscore, which check_db_name_is_valid_identifier rejects, and the
        # default "DB_USER" is not one of the enum values. Presumably both are
        # placeholders meant to be overridden at trigger time; confirm this is
        # compatible with schedule="@once".
        "db_name": Param(default="DB_NAME", type="string"),
        "user": Param(default="DB_USER", type="string", enum=["alexis", "sofian"]),
    },
)
def provision_dev_db():  # noqa: C901
    """Create a development database from the latest Scalingo backup.

    Steps, in order: validate the requested name, check the database does not
    exist yet, fetch the backup URL, download and extract the archive, create
    the database, restore the dump into it, grant privileges to the chosen
    user, then remove the temporary files.
    """
    local_backup_path = "/tmp/backup.tar.gz"

    @task.python(retries=0)
    def check_db_name_is_valid_identifier(**context):
        """Fail unless ``db_name`` is non-empty and made only of ``a-z`` and ``-``."""
        db_name = context["params"]["db_name"]
        if not db_name:
            # An empty name has no invalid character but is not a usable name.
            raise ValueError("db_name must not be empty")

        allowed_letters = frozenset("abcdefghijklmnopqrstuvwxyz-")
        invalid_letters_found = [letter for letter in db_name if letter not in allowed_letters]

        if invalid_letters_found:
            raise ValueError(f"Invalid characters found in db_name: {invalid_letters_found}")

        return f"{db_name} is a valid identifier"

    @task.python(retries=0)
    def check_db_name_does_not_exist(**context):
        """Fail if a database with the same name (case-insensitive) already exists."""
        db_name = context["params"]["db_name"]
        query = """
            select exists(
                SELECT datname FROM pg_catalog.pg_database WHERE lower(datname) = lower(%(dbname)s)
            )
        """
        conn = Container().psycopg2_dbt_conn()
        try:
            with conn.cursor() as cursor:
                cursor.execute(query, {"dbname": db_name})
                result = cursor.fetchone()
        finally:
            conn.close()

        if result[0]:
            raise ValueError(f"Database {db_name} already exists")

    @task.python
    def get_download_url():
        """Return the download URL of the latest backup reported by the Scalingo client."""
        return Container().scalingo().get_latest_backup_url()

    @task.python
    def delete_backup_if_exists_before():
        """Remove a leftover archive at ``local_backup_path`` before downloading."""
        if os.path.exists(local_backup_path):
            os.remove(local_backup_path)

    @task.python
    def download_backup(url):
        """Download the backup archive to ``local_backup_path``."""
        download_file_by_chunks(url, local_backup_path)

    @task.bash
    def extract_backup():
        """Return the shell command that extracts the archive into ``/tmp``."""
        return f"tar -xvzf {local_backup_path} -C /tmp"

    @task.python
    def get_extracted_backup_path():
        """Return the path of the first ``.pgsql`` entry listed in ``/tmp``.

        Raises:
            ValueError: When no entry in ``/tmp`` ends with ``.pgsql``.
        """
        for file in os.listdir("/tmp"):
            if file.endswith(".pgsql"):
                return f"/tmp/{file}"

        raise ValueError("No .pgsql file found in /tmp")

    @task.python
    def create_db(**context):
        """Create the database named by the ``db_name`` param."""
        db_name = context["params"]["db_name"]
        conn = Container().psycopg2_dbt_conn()
        # PostgreSQL refuses CREATE DATABASE inside a transaction block.
        conn.autocommit = True
        try:
            with conn.cursor() as cur:
                cur.execute(sql.SQL("CREATE DATABASE {};").format(sql.Identifier(db_name)))
        finally:
            conn.close()

    @task.bash
    def restore_db(extracted_file_path, **context):
        """Return the ``pg_restore`` command that loads the dump into the new database."""
        import shlex

        # The validated db_name only contains a-z and "-", so no further
        # rewriting of the URL is needed; altering the whole URL would also
        # change the user, password and host parts.
        dbname = Container().dbt_conn_url(dbname=context["params"]["db_name"])
        command = [
            "pg_restore",
            "--dbname",
            dbname,
            "--no-owner",
            "--no-privileges",
            "--no-tablespaces",
            "--no-comments",
            "--clean",
            "--if-exists",
            extracted_file_path,
        ]

        # The string is executed by bash: quote every argument.
        return shlex.join(command)

    @task.python
    def grant_permission_to_user(**context):
        """Grant privileges on the database and on the ``public`` tables to ``user``."""
        db_name = context["params"]["db_name"]
        user = context["params"]["user"]
        conn = Container().psycopg2_dbt_conn()
        conn.autocommit = True
        try:
            with conn.cursor() as cur:
                cur.execute(
                    sql.SQL("GRANT ALL PRIVILEGES ON DATABASE {} TO {};").format(
                        sql.Identifier(db_name), sql.Identifier(user)
                    )
                )
                # NOTE(review): this statement runs on the connection returned by
                # psycopg2_dbt_conn(), which is obtained without passing db_name.
                # It presumably targets the public schema of that default
                # database rather than the newly created one; confirm.
                cur.execute(
                    sql.SQL("GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO {};").format(
                        sql.Identifier(user)
                    )
                )
        finally:
            conn.close()

    @task.python
    def delete_backup_if_exists_after(extracted_file_path):
        """Remove the downloaded archive and the extracted dump if they exist."""
        if os.path.exists(local_backup_path):
            os.remove(local_backup_path)
        if os.path.exists(extracted_file_path):
            os.remove(extracted_file_path)

    check_db_name = check_db_name_is_valid_identifier()
    db_does_not_exists = check_db_name_does_not_exist()
    url = get_download_url()
    delete_before = delete_backup_if_exists_before()
    download = download_backup(url)
    extract = extract_backup()
    extracted_file_path = get_extracted_backup_path()
    create = create_db()
    restore = restore_db(extracted_file_path)
    delete_after = delete_backup_if_exists_after(extracted_file_path)
    grant_permission = grant_permission_to_user()

    # extracted_file_path must sit after extract: its only implicit
    # dependencies are downstream (restore, delete_after), so without this
    # link it could run before the archive has been downloaded and extracted.
    (
        check_db_name
        >> db_does_not_exists
        >> url
        >> delete_before
        >> download
        >> extract
        >> extracted_file_path
        >> create
        >> restore
        >> grant_permission
        >> delete_after
    )


provision_dev_db()
25 changes: 22 additions & 3 deletions airflow/include/container.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from os import getenv
from urllib.parse import quote_plus

import pysftp
import sqlalchemy
from airflow.hooks.base import BaseHook
from dependency_injector import containers, providers
from gdaltools import PgConnectionString
Expand All @@ -10,14 +10,15 @@
from s3fs import S3FileSystem

from .mattermost import Mattermost
from .scalingo import ScalingoClient


def db_str_for_ogr2ogr(dbname: str, user: str, password: str, host: str, port: int) -> str:
    """Build the ``PG:`` connection string in keyword/value form.

    Every value is wrapped in single quotes. Backslashes and single quotes
    inside a value are escaped with a backslash so that a value containing
    them (typically a password) cannot terminate the quoted section early.

    Args:
        dbname: Database name.
        user: Role used to connect.
        password: Password of ``user``.
        host: Server host.
        port: Server port; converted with ``str`` before being quoted.

    Returns:
        A string of the form
        ``PG:dbname='..' host='..' port='..' user='..' password='..'``.
    """

    def quoted(value) -> str:
        # Escape the backslash first so the escapes added for quotes survive.
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    return (
        f"PG:dbname={quoted(dbname)} host={quoted(host)} port={quoted(port)} "
        f"user={quoted(user)} password={quoted(password)}"
    )


def create_sql_alchemy_conn(url: str) -> sqlalchemy.engine.base.Engine:
    """Create a SQLAlchemy engine for the database at ``url``.

    Despite the function name, the value returned is the result of
    ``sqlalchemy.create_engine`` (an engine), not an already opened connection.
    """
    return sqlalchemy.create_engine(url)
def db_str_url(dbname: str, user: str, password: str, host: str, port: int) -> str:
    """Assemble a ``postgresql://`` URL from its individual parts.

    The parts are interpolated verbatim: no escaping or percent-encoding is
    applied here, so callers must pass values that are already URL-safe.
    """
    credentials = f"{user}:{password}"
    location = f"{host}:{port}"
    return f"postgresql://{credentials}@{location}/{dbname}"


class Container(containers.DeclarativeContainer):
Expand Down Expand Up @@ -47,6 +48,14 @@ class Container(containers.DeclarativeContainer):
host=getenv("DBT_DB_HOST"),
port=getenv("DBT_DB_PORT"),
)
dbt_conn_url = providers.Factory(
db_str_url,
dbname=getenv("DBT_DB_NAME"),
user=getenv("DBT_DB_USER"),
password=quote_plus(getenv("DBT_DB_PASSWORD")),
host=getenv("DBT_DB_HOST"),
port=int(getenv("DBT_DB_PORT")),
)

# DEV connections
gdal_dev_conn = providers.Factory(
Expand Down Expand Up @@ -120,3 +129,13 @@ class Container(containers.DeclarativeContainer):
mattermost_webhook_url=getenv("MATTERMOST_WEBHOOK_URL"),
channel=getenv("MATTERMOST_CHANNEL"),
)

scalingo = providers.Factory(
ScalingoClient,
api_token=getenv("SCALINGO_API_TOKEN"),
app_name=getenv("SCALINGO_APP_NAME"),
addon_id=getenv("SCALINGO_ADDON_ID"),
backup_dir=getenv("SCALINGO_BACKUP_DIR"),
scalingo_subdomain=getenv("SCALINGO_SUBDOMAIN"),
scalingo_subdomain_db=getenv("SCALINGO_SUBDOMAIN_DB"),
)
67 changes: 67 additions & 0 deletions airflow/include/scalingo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import requests


class ScalingoClient:
    """Small HTTP client used to obtain the download URL of a database backup.

    Tokens are not cached: every access to ``api_bearer_token``, ``headers``,
    ``api_bearer_token_addon`` or ``addons_headers`` performs the underlying
    HTTP request(s) again.
    """

    # Seconds passed as ``timeout`` to every request so that an unresponsive
    # endpoint raises instead of blocking the caller indefinitely.
    REQUEST_TIMEOUT = 30

    def __init__(
        self,
        api_token: str,
        app_name: str,
        addon_id: str,
        scalingo_subdomain: str,
        scalingo_subdomain_db: str,
        backup_dir: "str | None" = None,
    ):
        """Store the settings used to build the API requests.

        Args:
            api_token: Token sent as the basic-auth password when exchanging
                it for a bearer token.
            app_name: Application name used in the add-on token URL.
            addon_id: Add-on identifier used in the add-on and backup URLs.
            scalingo_subdomain: Host of the API serving the add-on token.
            scalingo_subdomain_db: Host of the API serving the backups.
            backup_dir: Optional directory setting. It is only stored; no
                method of this class reads it. Accepted so the client can be
                built with a ``backup_dir`` keyword argument.
        """
        self.api_token = api_token
        self.app_name = app_name
        self.addon_id = addon_id
        self.scalingo_subdomain = scalingo_subdomain
        self.scalingo_subdomain_db = scalingo_subdomain_db
        self.backup_dir = backup_dir

    @property
    def api_bearer_token(self):
        """Exchange the API token for a bearer token and return it."""
        url = "https://auth.scalingo.com/v1/tokens/exchange"
        basic_auth_password = self.api_token
        basic_auth_username = ""

        response = requests.post(
            url,
            auth=(basic_auth_username, basic_auth_password),
            timeout=self.REQUEST_TIMEOUT,
        )
        response.raise_for_status()
        return response.json()["token"]

    @property
    def headers(self):
        """Return JSON headers carrying a freshly exchanged bearer token."""
        return {
            "Authorization": f"Bearer {self.api_bearer_token}",
            "Content-Type": "application/json",
        }

    @property
    def api_bearer_token_addon(self):
        """Request and return the token of the add-on ``addon_id``."""
        url = f"https://{self.scalingo_subdomain}/v1/apps/{self.app_name}/addons/{self.addon_id}/token"

        response = requests.post(url, headers=self.headers, timeout=self.REQUEST_TIMEOUT)
        response.raise_for_status()
        return response.json()["addon"]["token"]

    @property
    def addons_headers(self):
        """Return JSON headers carrying a freshly requested add-on token."""
        return {
            "Authorization": f"Bearer {self.api_bearer_token_addon}",
            "Content-Type": "application/json",
        }

    def get_backups(self):
        """Return the ``database_backups`` entry of the backups listing."""
        url = f"https://{self.scalingo_subdomain_db}/api/databases/{self.addon_id}/backups"
        response = requests.get(url, headers=self.addons_headers, timeout=self.REQUEST_TIMEOUT)
        response.raise_for_status()
        return response.json()["database_backups"]

    def get_backup_url(self, backup_id: str):
        """Return the ``download_url`` of the archive of backup ``backup_id``."""
        url = f"https://{self.scalingo_subdomain_db}/api/databases/{self.addon_id}/backups/{backup_id}/archive"
        response = requests.get(url, headers=self.addons_headers, timeout=self.REQUEST_TIMEOUT)
        response.raise_for_status()
        return response.json()["download_url"]

    def get_latest_backup_url(self):
        """Return the download URL of the first backup in the listing.

        Raises:
            IndexError: When the listing contains no backup.
        """
        backups = self.get_backups()
        if not backups:
            raise IndexError(f"No backup available for addon {self.addon_id}")
        # NOTE(review): assumes the API lists the most recent backup first;
        # confirm against the API documentation.
        last_backup = backups[0]
        last_backup_id = last_backup["id"]
        return self.get_backup_url(last_backup_id)