diff --git a/Makefile b/Makefile
index eb55eb11..2e9cbc23 100644
--- a/Makefile
+++ b/Makefile
@@ -20,6 +20,9 @@ print_target:
 build:
 	@uv pip install --force-reinstall --no-deps -e .
 
+deps:
+	@uv pip install --upgrade -r requirements.txt
+
 docs:
 	@rm -rf docs/dev/internals docs/_build
 	@tox -e docs
diff --git a/docs/changes/DM-45263.feature.rst b/docs/changes/DM-45263.feature.rst
new file mode 100644
index 00000000..cf377fdd
--- /dev/null
+++ b/docs/changes/DM-45263.feature.rst
@@ -0,0 +1,3 @@
+Added a new ``tap_schema`` module, which is intended to deprecate and eventually replace the ``tap`` module.
+This module provides utilities for translating a Felis schema into a TAP_SCHEMA representation.
+The new ``felis load-tap-schema`` command provides access to this functionality.
diff --git a/docs/dev/internals.rst b/docs/dev/internals.rst
index ada84d7d..b31bdc19 100644
--- a/docs/dev/internals.rst
+++ b/docs/dev/internals.rst
@@ -7,33 +7,41 @@ Python API
 .. automodapi:: felis.datamodel
    :include-all-objects:
 
-.. automodapi:: felis.metadata
+.. automodapi:: felis.db.dialects
    :include-all-objects:
    :no-inheritance-diagram:
 
-.. automodapi:: felis.tap
+.. automodapi:: felis.db.sqltypes
    :include-all-objects:
    :no-inheritance-diagram:
 
-.. automodapi:: felis.types
-   :include-all-objects:
-
-.. automodapi:: felis.db.dialects
+.. automodapi:: felis.db.utils
    :include-all-objects:
    :no-inheritance-diagram:
 
-.. automodapi:: felis.db.sqltypes
+.. automodapi:: felis.db.variants
    :include-all-objects:
    :no-inheritance-diagram:
 
-.. automodapi:: felis.db.utils
+.. automodapi:: felis.metadata
    :include-all-objects:
    :no-inheritance-diagram:
 
-.. automodapi:: felis.db.variants
+.. automodapi:: felis.tap
    :include-all-objects:
    :no-inheritance-diagram:
 
+.. automodapi:: felis.tap_schema
+   :include-all-objects:
+   :no-inheritance-diagram:
+
 .. automodapi:: felis.tests.postgresql
    :include-all-objects:
    :no-inheritance-diagram:
+
+.. automodapi:: felis.tests.utils
+   :include-all-objects:
+   :no-inheritance-diagram:
+
+.. automodapi:: felis.types
+   :include-all-objects:
diff --git a/docs/documenteer.toml b/docs/documenteer.toml
index 188b11a6..ccccf045 100644
--- a/docs/documenteer.toml
+++ b/docs/documenteer.toml
@@ -21,6 +21,8 @@ nitpick_ignore = [
     ["py:class", "sqlalchemy.orm.decl_api.Base"],
     ["py:class", "sqlalchemy.engine.mock.MockConnection"],
     ["py:class", "pydantic.main.BaseModel"],
+    ["py:exc", "pydantic.ValidationError"],
+    ["py:exc", "yaml.YAMLError"]
 ]
 nitpick_ignore_regex = [
     # Bug in autodoc_pydantic.
@@ -29,5 +31,6 @@ nitpick_ignore_regex = [ python_api_dir = "dev/internals" [sphinx.intersphinx.projects] -python = "https://docs.python.org/3/" -sqlalchemy = "https://docs.sqlalchemy.org/en/latest/" +python = "https://docs.python.org/3" +sqlalchemy = "https://docs.sqlalchemy.org/en/latest" +lsst = "https://pipelines.lsst.io/v/weekly" diff --git a/docs/user-guide/datatypes.rst b/docs/user-guide/datatypes.rst index 77ae7040..cd9caa2c 100644 --- a/docs/user-guide/datatypes.rst +++ b/docs/user-guide/datatypes.rst @@ -74,7 +74,7 @@ The following table shows these mapping: +-----------+---------------+----------+------------------+--------------+ | unicode | NVARCHAR | NVARCHAR | VARCHAR | unicodeChar | +-----------+---------------+----------+------------------+--------------+ -| text | TEXT | LONGTEXT | TEXT | uncodeChar | +| text | TEXT | LONGTEXT | TEXT | char | +-----------+---------------+----------+------------------+--------------+ | binary | BLOB | LONGBLOB | BYTEA | unsignedByte | +-----------+---------------+----------+------------------+--------------+ diff --git a/pyproject.toml b/pyproject.toml index 4dea77da..bc1ab460 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,8 @@ dependencies = [ "click >= 7", "pyyaml >= 6", "pydantic >= 2, < 3", - "lsst-utils" + "lsst-utils", + "lsst-resources" ] requires-python = ">=3.11.0" dynamic = ["version"] @@ -55,7 +56,7 @@ zip-safe = true license-files = ["COPYRIGHT", "LICENSE"] [tool.setuptools.package-data] -"felis" = ["py.typed"] +"felis" = ["py.typed", "schemas/*.yaml"] [tool.setuptools.dynamic] version = { attr = "lsst_versions.get_lsst_version" } diff --git a/python/felis/cli.py b/python/felis/cli.py index 36727ebe..fbad0d7d 100644 --- a/python/felis/cli.py +++ b/python/felis/cli.py @@ -23,22 +23,21 @@ from __future__ import annotations -import io import logging from collections.abc import Iterable from typing import IO import click -import yaml from pydantic import ValidationError from sqlalchemy.engine import Engine, create_engine, make_url -from sqlalchemy.engine.mock import MockConnection +from sqlalchemy.engine.mock import MockConnection, create_mock_engine from . import __version__ from .datamodel import Schema -from .db.utils import DatabaseContext +from .db.utils import DatabaseContext, is_mock_url from .metadata import MetaDataBuilder from .tap import Tap11Base, TapLoadingVisitor, init_tables +from .tap_schema import DataLoader, TableManager __all__ = ["cli"] @@ -107,7 +106,7 @@ def create( dry_run: bool, output_file: IO[str] | None, ignore_constraints: bool, - file: IO, + file: IO[str], ) -> None: """Create database objects from the Felis file. @@ -133,8 +132,7 @@ def create( Felis file to read. """ try: - yaml_data = yaml.safe_load(file) - schema = Schema.model_validate(yaml_data, context={"id_generation": ctx.obj["id_generation"]}) + schema = Schema.from_stream(file, context={"id_generation": ctx.obj["id_generation"]}) url = make_url(engine_url) if schema_name: logger.info(f"Overriding schema name with: {schema_name}") @@ -261,7 +259,7 @@ def load_tap( tap_keys_table: str, tap_key_columns_table: str, tap_schema_index: int, - file: io.TextIOBase, + file: IO[str], ) -> None: """Load TAP metadata from a Felis file. @@ -304,8 +302,7 @@ def load_tap( The data will be loaded into the TAP_SCHEMA from the engine URL. The tables must have already been initialized or an error will occur. 
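The CLI changes above replace the manual two-step load-and-validate pattern with the new `Schema.from_stream` helper. A minimal sketch of the equivalence, using only names introduced in this diff (the file name is illustrative):

```python
import yaml

from felis.datamodel import Schema

with open("my_schema.yaml") as f:
    # Old pattern, as removed from cli.py:
    schema = Schema.model_validate(yaml.safe_load(f), context={"id_generation": True})

with open("my_schema.yaml") as f:
    # New pattern, used by the create/validate/load-tap commands:
    schema = Schema.from_stream(f, context={"id_generation": True})
```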
""" - yaml_data = yaml.load(file, Loader=yaml.SafeLoader) - schema = Schema.model_validate(yaml_data) + schema = Schema.from_stream(file) tap_tables = init_tables( tap_schema_name, @@ -345,6 +342,79 @@ def load_tap( tap_visitor.visit_schema(schema) +@cli.command("load-tap-schema", help="Load metadata from a Felis file into a TAP_SCHEMA database") +@click.option("--engine-url", envvar="FELIS_ENGINE_URL", help="SQLAlchemy Engine URL") +@click.option("--tap-schema-name", help="Name of the TAP_SCHEMA schema in the database") +@click.option( + "--tap-tables-postfix", help="Postfix which is applied to standard TAP_SCHEMA table names", default="" +) +@click.option("--tap-schema-index", type=int, help="TAP_SCHEMA index of the schema in this environment") +@click.option("--dry-run", is_flag=True, help="Execute dry run only. Does not insert any data.") +@click.option("--echo", is_flag=True, help="Print out the generated insert statements to stdout") +@click.option("--output-file", type=click.Path(), help="Write SQL commands to a file") +@click.argument("file", type=click.File()) +@click.pass_context +def load_tap_schema( + ctx: click.Context, + engine_url: str, + tap_schema_name: str, + tap_tables_postfix: str, + tap_schema_index: int, + dry_run: bool, + echo: bool, + output_file: str | None, + file: IO[str], +) -> None: + """Load TAP metadata from a Felis file. + + Parameters + ---------- + engine_url + SQLAlchemy Engine URL. + tap_tables_postfix + Postfix which is applied to standard TAP_SCHEMA table names. + tap_schema_index + TAP_SCHEMA index of the schema in this environment. + dry_run + Execute dry run only. Does not insert any data. + echo + Print out the generated insert statements to stdout. + output_file + Output file for writing generated SQL. + file + Felis file to read. + + Notes + ----- + The TAP_SCHEMA database must already exist or the command will fail. This + command will not initialize the TAP_SCHEMA tables. + """ + url = make_url(engine_url) + engine: Engine | MockConnection + if dry_run or is_mock_url(url): + engine = create_mock_engine(url, executor=None) + else: + engine = create_engine(engine_url) + mgr = TableManager( + engine=engine, + apply_schema_to_metadata=False if engine.dialect.name == "sqlite" else True, + schema_name=tap_schema_name, + table_name_postfix=tap_tables_postfix, + ) + + schema = Schema.from_stream(file, context={"id_generation": ctx.obj["id_generation"]}) + + DataLoader( + schema, + mgr, + engine, + tap_schema_index=tap_schema_index, + dry_run=dry_run, + print_sql=echo, + output_path=output_file, + ).load() + + @cli.command("validate", help="Validate one or more Felis YAML files") @click.option( "--check-description", is_flag=True, help="Check that all objects have a description", default=False @@ -372,7 +442,7 @@ def validate( check_redundant_datatypes: bool, check_tap_table_indexes: bool, check_tap_principal: bool, - files: Iterable[io.TextIOBase], + files: Iterable[IO[str]], ) -> None: """Validate one or more felis YAML files. 
@@ -406,9 +476,8 @@ def validate( file_name = getattr(file, "name", None) logger.info(f"Validating {file_name}") try: - data = yaml.load(file, Loader=yaml.SafeLoader) - Schema.model_validate( - data, + Schema.from_stream( + file, context={ "check_description": check_description, "check_redundant_datatypes": check_redundant_datatypes, diff --git a/python/felis/datamodel.py b/python/felis/datamodel.py index 236bc4be..53ccbf2c 100644 --- a/python/felis/datamodel.py +++ b/python/felis/datamodel.py @@ -26,10 +26,12 @@ import logging from collections.abc import Sequence from enum import StrEnum, auto -from typing import Annotated, Any, Literal, TypeAlias, Union +from typing import IO, Annotated, Any, Generic, Literal, TypeAlias, TypeVar, Union +import yaml from astropy import units as units # type: ignore from astropy.io.votable import ucd # type: ignore +from lsst.resources import ResourcePath, ResourcePathExpression from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator, model_validator from .db.dialects import get_supported_dialects @@ -253,7 +255,7 @@ def check_units(self) -> Column: Raises ------ ValueError - Raised If both FITS and IVOA units are provided, or if the unit is + Raised if both FITS and IVOA units are provided, or if the unit is invalid. """ fits_unit = self.fits_tunit @@ -383,6 +385,58 @@ def check_precision(self) -> Column: raise ValueError("Precision is only valid for timestamp columns") return self + @model_validator(mode="before") + @classmethod + def check_votable_arraysize(cls, values: dict[str, Any]) -> dict[str, Any]: + """Set the default value for the ``votable_arraysize`` field, which + corresponds to ``arraysize`` in the IVOA VOTable standard. + + Parameters + ---------- + values + Values of the column. + + Returns + ------- + `dict` [ `str`, `Any` ] + The values of the column. + + Notes + ----- + Following the IVOA VOTable standard, an ``arraysize`` of 1 should not + be used. + """ + if values.get("name", None) is None or values.get("datatype", None) is None: + # Skip bad column data that will not validate + return values + arraysize = values.get("votable:arraysize", None) + if arraysize is None: + length = values.get("length", None) + datatype = values.get("datatype") + if length is not None and length > 1: + # Following the IVOA standard, arraysize of 1 is disallowed + if datatype == "char": + arraysize = str(length) + elif datatype in ("string", "unicode", "binary"): + arraysize = f"{length}*" + elif datatype in ("timestamp", "text"): + arraysize = "*" + if arraysize is not None: + values["votable:arraysize"] = arraysize + logger.debug( + f"Set default 'votable:arraysize' to '{arraysize}' on column '{values['name']}'" + + f" with datatype '{values['datatype']}' and length '{values.get('length', None)}'" + ) + else: + logger.debug(f"Using existing 'votable:arraysize' of '{arraysize}' on column '{values['name']}'") + if isinstance(values["votable:arraysize"], int): + logger.warning( + f"Usage of an integer value for 'votable:arraysize' in column '{values['name']}' is " + + "deprecated" + ) + values["votable:arraysize"] = str(arraysize) + return values + class Constraint(BaseObject): """Table constraint model.""" @@ -700,7 +754,10 @@ def visit_constraint(self, constraint: Constraint) -> None: self.add(constraint) -class Schema(BaseObject): +T = TypeVar("T", bound=BaseObject) + + +class Schema(BaseObject, Generic[T]): """Database schema model. This represents a database schema, which contains one or more tables. 
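To make the `votable:arraysize` defaulting rules above concrete, here is a sketch of what the before-mode validator produces for a few column definitions. This assumes the raw-dictionary key aliases (``@id``, ``votable:arraysize``) used elsewhere in this diff's test data; expected outputs follow the branch logic shown above:

```python
from felis.datamodel import Column

# string/unicode/binary with length > 1 -> variable length "N*"
col = Column.model_validate({"name": "s", "@id": "#t.s", "datatype": "string", "length": 256})
print(col.votable_arraysize)  # "256*"

# char with length > 1 -> fixed length "N"
col = Column.model_validate({"name": "c", "@id": "#t.c", "datatype": "char", "length": 64})
print(col.votable_arraysize)  # "64"

# text (and timestamp) with length > 1 -> unbounded "*"
col = Column.model_validate({"name": "t", "@id": "#t.t", "datatype": "text", "length": 512})
print(col.votable_arraysize)  # "*"
```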
@@ -942,3 +999,118 @@ def __contains__(self, id: str) -> bool:
             The ID of the object to check.
         """
         return id in self.id_map
+
+    def find_object_by_id(self, id: str, obj_type: type[T]) -> T:
+        """Find an object with the given type by its ID.
+
+        Parameters
+        ----------
+        id
+            The ID of the object to find.
+        obj_type
+            The type of the object to find.
+
+        Returns
+        -------
+        BaseObject
+            The object with the given ID and type.
+
+        Raises
+        ------
+        KeyError
+            If the object with the given ID is not found in the schema.
+        TypeError
+            If the object that is found does not have the right type.
+
+        Notes
+        -----
+        The actual return type is the user-specified argument ``T``, which is
+        expected to be a subclass of `BaseObject`.
+        """
+        obj = self[id]
+        if not isinstance(obj, obj_type):
+            raise TypeError(f"Object with ID '{id}' is not of type '{obj_type.__name__}'")
+        return obj
+
+    def get_table_by_column(self, column: Column) -> Table:
+        """Find the table that contains a column.
+
+        Parameters
+        ----------
+        column
+            The column to find.
+
+        Returns
+        -------
+        `Table`
+            The table that contains the column.
+
+        Raises
+        ------
+        ValueError
+            If the column is not found in any table.
+        """
+        for table in self.tables:
+            if column in table.columns:
+                return table
+        raise ValueError(f"Column '{column.name}' not found in any table")
+
+    @classmethod
+    def from_uri(cls, resource_path: ResourcePathExpression, context: dict[str, Any] = {}) -> Schema:
+        """Load a `Schema` from anything that can be converted to a
+        ``ResourcePath``.
+
+        Parameters
+        ----------
+        resource_path
+            The ``ResourcePath`` pointing to a YAML file.
+        context
+            Pydantic context to be used in validation.
+
+        Returns
+        -------
+        `Schema`
+            The Felis schema loaded from the resource.
+
+        Raises
+        ------
+        yaml.YAMLError
+            Raised if there is an error loading the YAML data.
+        ValueError
+            Raised if there is an error reading the resource.
+        pydantic.ValidationError
+            Raised if the schema fails validation.
+        """
+        logger.debug(f"Loading schema from: '{resource_path}'")
+        try:
+            rp_stream = ResourcePath(resource_path).read()
+        except Exception as e:
+            raise ValueError(f"Error reading resource from '{resource_path}' : {e}") from e
+        yaml_data = yaml.safe_load(rp_stream)
+        return Schema.model_validate(yaml_data, context=context)
+
+    @classmethod
+    def from_stream(cls, source: IO[str], context: dict[str, Any] = {}) -> Schema:
+        """Load a `Schema` from a file stream which should contain YAML data.
+
+        Parameters
+        ----------
+        source
+            The file stream to read from.
+        context
+            Pydantic context to be used in validation.
+
+        Returns
+        -------
+        `Schema`
+            The Felis schema loaded from the stream.
+
+        Raises
+        ------
+        yaml.YAMLError
+            Raised if there is an error loading the YAML file.
+        pydantic.ValidationError
+            Raised if the schema fails validation.
+        """
+        logger.debug("Loading schema from: '%s'", source)
+        yaml_data = yaml.safe_load(source)
+        return Schema.model_validate(yaml_data, context=context)
diff --git a/python/felis/db/utils.py b/python/felis/db/utils.py
index 2cb169e6..efa3c3c5 100644
--- a/python/felis/db/utils.py
+++ b/python/felis/db/utils.py
@@ -106,6 +106,43 @@ def string_to_typeengine(
     return type_obj
 
 
+def is_mock_url(url: URL) -> bool:
+    """Check if the engine URL is a mock URL.
+
+    Parameters
+    ----------
+    url
+        The SQLAlchemy engine URL.
+
+    Returns
+    -------
+    bool
+        True if the URL is a mock URL, False otherwise.
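The two class methods added above give callers a uniform way to load schemas; `from_uri` accepts anything `lsst.resources.ResourcePath` understands. A short sketch (the `resource://` URI is the one shipped by this change; the local file name is illustrative):

```python
from felis.datamodel import Schema

# From a local file stream.
with open("my_schema.yaml") as f:
    schema = Schema.from_stream(f)

# From a ResourcePath expression: a plain path, file://, s3://, or the
# packaged resource added by this change.
schema = Schema.from_uri(
    "resource://felis/schemas/tap_schema_std.yaml",
    context={"id_generation": True},
)
```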
+    """
+    return (url.drivername == "sqlite" and url.database is None) or (
+        url.drivername != "sqlite" and url.host is None
+    )
+
+
+def is_valid_engine(engine: Engine | MockConnection | None) -> bool:
+    """Check if the engine is valid.
+
+    The engine cannot be None; it must not be a mock connection; and it must
+    not have a mock URL which is missing a host or, for sqlite, a database
+    name.
+
+    Parameters
+    ----------
+    engine
+        The SQLAlchemy engine or mock connection.
+
+    Returns
+    -------
+    bool
+        True if the engine is valid, False otherwise.
+    """
+    return engine is not None and not isinstance(engine, MockConnection) and not is_mock_url(engine.url)
+
+
 class SQLWriter:
     """Write SQL statements to stdout or a file.
 
@@ -193,12 +230,19 @@ def execute(self, statement: Any) -> ResultProxy:
         """
         if isinstance(statement, str):
             statement = text(statement)
-        if isinstance(self.engine, MockConnection):
+        if isinstance(self.engine, Engine):
+            try:
+                with self.engine.begin() as connection:
+                    # begin() commits on success and rolls back on error.
+                    result = connection.execute(statement)
+                    return result
+            except SQLAlchemyError as e:
+                logger.error(f"Error executing statement: {e}")
+                raise
+        elif isinstance(self.engine, MockConnection):
             return self.engine.connect().execute(statement)
         else:
-            with self.engine.begin() as connection:
-                result = connection.execute(statement)
-                return result
+            raise ValueError("Unsupported engine type: " + str(type(self.engine)))
 
 
 class DatabaseContext:
@@ -218,7 +262,7 @@ def __init__(self, metadata: MetaData, engine: Engine | MockConnection):
         self.engine = engine
         self.dialect_name = engine.dialect.name
         self.metadata = metadata
-        self.conn = ConnectionWrapper(engine)
+        self.connection = ConnectionWrapper(engine)
 
     def initialize(self) -> None:
         """Create the schema in the database if it does not exist.
@@ -240,14 +284,14 @@ def initialize(self) -> None:
         try:
             if self.dialect_name == "mysql":
                 logger.debug(f"Checking if MySQL database exists: {schema_name}")
-                result = self.conn.execute(text(f"SHOW DATABASES LIKE '{schema_name}'"))
+                result = self.execute(text(f"SHOW DATABASES LIKE '{schema_name}'"))
                 if result.fetchone():
                     raise ValueError(f"MySQL database '{schema_name}' already exists.")
                 logger.debug(f"Creating MySQL database: {schema_name}")
-                self.conn.execute(text(f"CREATE DATABASE {schema_name}"))
+                self.execute(text(f"CREATE DATABASE {schema_name}"))
             elif self.dialect_name == "postgresql":
                 logger.debug(f"Checking if PG schema exists: {schema_name}")
-                result = self.conn.execute(
+                result = self.execute(
                     text(
                         f"""
                         SELECT schema_name
@@ -259,7 +303,7 @@ def initialize(self) -> None:
                 if result.fetchone():
                     raise ValueError(f"PostgreSQL schema '{schema_name}' already exists.")
                 logger.debug(f"Creating PG schema: {schema_name}")
-                self.conn.execute(CreateSchema(schema_name))
+                self.execute(CreateSchema(schema_name))
             elif self.dialect_name == "sqlite":
                 # Just silently ignore this operation for SQLite. The database
                 # will still be created if it does not exist and the engine
@@ -285,13 +329,15 @@ def drop(self) -> None:
         schema. For other variants, this is an unsupported operation.
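A quick illustration of the `is_mock_url` heuristic above, assuming standard SQLAlchemy URL parsing (scheme-only URLs parse with `host`/`database` set to None):

```python
from sqlalchemy.engine import make_url

from felis.db.utils import is_mock_url

assert is_mock_url(make_url("sqlite://"))                    # no database -> mock
assert not is_mock_url(make_url("sqlite:///felis.sqlite3"))  # file database -> real
assert is_mock_url(make_url("postgresql://"))                # no host -> mock
assert not is_mock_url(make_url("postgresql://localhost/db"))
```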
""" schema_name = self.metadata.schema + if not self.engine.dialect.name == "sqlite" and self.metadata.schema is None: + raise ValueError("Schema name is required to drop the schema.") try: if self.dialect_name == "mysql": logger.debug(f"Dropping MySQL database if exists: {schema_name}") - self.conn.execute(text(f"DROP DATABASE IF EXISTS {schema_name}")) + self.execute(text(f"DROP DATABASE IF EXISTS {schema_name}")) elif self.dialect_name == "postgresql": logger.debug(f"Dropping PostgreSQL schema if exists: {schema_name}") - self.conn.execute(DropSchema(schema_name, if_exists=True, cascade=True)) + self.execute(DropSchema(schema_name, if_exists=True, cascade=True)) elif self.dialect_name == "sqlite": if isinstance(self.engine, Engine): logger.debug("Dropping tables in SQLite schema") @@ -304,7 +350,21 @@ def drop(self) -> None: def create_all(self) -> None: """Create all tables in the schema using the metadata object.""" - self.metadata.create_all(self.engine) + if isinstance(self.engine, Engine): + # Use a transaction for a real connection. + with self.engine.begin() as conn: + try: + self.metadata.create_all(bind=conn) + conn.commit() + except SQLAlchemyError as e: + conn.rollback() + logger.error(f"Error creating tables: {e}") + raise + elif isinstance(self.engine, MockConnection): + # Mock connection so no need for a transaction. + self.metadata.create_all(self.engine) + else: + raise ValueError("Unsupported engine type: " + str(type(self.engine))) @staticmethod def create_mock_engine(engine_url: str | URL, output_file: IO[str] | None = None) -> MockConnection: @@ -327,3 +387,23 @@ def create_mock_engine(engine_url: str | URL, output_file: IO[str] | None = None engine = create_mock_engine(engine_url, executor=writer.write, paramstyle="pyformat") writer.dialect = engine.dialect return engine + + def execute(self, statement: Any) -> ResultProxy: + """Execute a SQL statement on the engine and return the result. + + Parameters + ---------- + statement + The SQL statement to execute. + + Returns + ------- + ``sqlalchemy.engine.ResultProxy`` + The result of the statement execution. + + Notes + ----- + This is just a wrapper around the execution method of the connection + object, which may execute on a real or mock connection. 
+        """
+        return self.connection.execute(statement)
diff --git a/python/felis/schemas/tap_schema_std.yaml b/python/felis/schemas/tap_schema_std.yaml
new file mode 100644
index 00000000..64a0f6f6
--- /dev/null
+++ b/python/felis/schemas/tap_schema_std.yaml
@@ -0,0 +1,273 @@
+name: TAP_SCHEMA
+version: "1.1"
+description: A TAP-standard-mandated schema to describe tablesets in a TAP 1.1 service
+tables:
+- name: "schemas"
+  description: description of schemas in this tableset
+  primaryKey: "#schemas.schema_name"
+  tap:table_index: 100000
+  mysql:engine: "InnoDB"
+  columns:
+  - name: "schema_name"
+    datatype: "string"
+    description: schema name for reference to tap_schema.schemas
+    length: 64
+    nullable: false
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 1
+  - name: "utype"
+    datatype: "string"
+    description: lists the utypes of schemas in the tableset
+    length: 512
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 2
+  - name: "description"
+    datatype: "string"
+    description: describes schemas in the tableset
+    length: 512
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 3
+  - name: "schema_index"
+    datatype: "int"
+    description: recommended sort order when listing schemas
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 4
+- name: "tables"
+  description: description of tables in this tableset
+  primaryKey: "#tables.table_name"
+  tap:table_index: 101000
+  mysql:engine: "InnoDB"
+  columns:
+  - name: schema_name
+    datatype: string
+    description: the schema this table belongs to
+    length: 64
+    nullable: false
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 1
+  - name: table_name
+    datatype: string
+    description: the fully qualified table name
+    length: 128
+    nullable: false
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 2
+  - name: table_type
+    datatype: string
+    description: "one of: table view"
+    length: 8
+    nullable: false
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 3
+  - name: utype
+    datatype: string
+    description: lists the utype of tables in the tableset
+    length: 512
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 4
+  - name: description
+    datatype: string
+    description: describes tables in the tableset
+    length: 512
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 5
+  - name: table_index
+    datatype: int
+    description: recommended sort order when listing tables
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 6
+  constraints:
+  - name: "k1"
+    "@type": ForeignKey
+    columns: ["#tables.schema_name"]
+    referencedColumns: ["#schemas.schema_name"]
+- name: "columns"
+  description: description of columns in this tableset
+  primaryKey: ["#columns.table_name", "#columns.column_name"]
+  tap:table_index: 102000
+  mysql:engine: "InnoDB"
+  columns:
+  - name: table_name
+    datatype: string
+    description: the table this column belongs to
+    length: 128
+    nullable: false
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 1
+  - name: column_name
+    datatype: string
+    description: the column name
+    length: 64
+    nullable: false
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 2
+  - name: utype
+    datatype: string
+    description: lists the utypes of columns in the tableset
+    length: 512
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 3
+  - name: ucd
+    datatype: string
+    description: lists the UCDs of columns in the tableset
+    length: 64
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 4
+  - name: unit
+    datatype: string
+    description: lists the unit used for column values in the tableset
+    length: 64
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 5
+  - name: description
+    datatype: string
+    description: describes the columns in the tableset
+    length: 512
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 6
+  - name: datatype
+    datatype: string
+    description: lists the ADQL datatype of columns in the tableset
+    length: 64
+    nullable: false
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 7
+  - name: arraysize
+    datatype: string
+    description: lists the size of variable-length columns in the tableset
+    length: 16
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 8
+  - name: xtype
+    datatype: string
+    description: a DALI or custom extended type annotation
+    length: 64
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 9
+  - name: size
+    datatype: int
+    description: "deprecated: use arraysize"
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 10
+  - name: principal
+    datatype: int
+    description: a principal column; 1 means true, 0 means false
+    nullable: false
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 11
+  - name: indexed
+    datatype: int
+    description: an indexed column; 1 means true, 0 means false
+    nullable: false
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 12
+  - name: std
+    datatype: int
+    description: a standard column; 1 means true, 0 means false
+    nullable: false
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 13
+  - name: column_index
+    datatype: int
+    description: recommended sort order when listing columns
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 14
+  constraints:
+  - name: "k2"
+    "@type": ForeignKey
+    columns: ["#columns.table_name"]
+    referencedColumns: ["#tables.table_name"]
+- name: "keys"
+  description: description of foreign keys in this tableset
+  primaryKey: "#keys.key_id"
+  tap:table_index: 103000
+  mysql:engine: "InnoDB"
+  columns:
+  - name: key_id
+    datatype: string
+    description: unique key to join to tap_schema.key_columns
+    length: 64
+    nullable: false
+  - name: from_table
+    datatype: string
+    description: the table with the foreign key
+    length: 128
+    nullable: false
+  - name: target_table
+    datatype: string
+    description: the table with the primary key
+    length: 128
+    nullable: false
+  - name: utype
+    datatype: string
+    description: lists the utype of keys in the tableset
+    length: 512
+  - name: description
+    datatype: string
+    description: describes keys in the tableset
+    length: 512
+  constraints:
+  - name: "k3"
+    "@type": ForeignKey
+    columns: ["#keys.from_table"]
+    referencedColumns: ["#tables.table_name"]
+  - name: "k4"
+    "@type": ForeignKey
+    columns: ["#keys.target_table"]
+    referencedColumns: ["#tables.table_name"]
+- name: "key_columns"
+  description: description of foreign key columns in this tableset
+  tap:table_index: 104000
+  mysql:engine: "InnoDB"
+  columns:
+  - name: key_id
+    datatype: string
+    length: 64
+    nullable: false
+  - name: from_column
+    datatype: string
+    length: 64
+    nullable: false
+  - name: target_column
+    datatype: string
+    length: 64
+    nullable: false
+  constraints:
+  - name: "k5"
+    "@type": ForeignKey
+    columns: ["#key_columns.key_id"]
+    referencedColumns: ["#keys.key_id"]
+  # FIXME: These can't be defined as FK constraints, because they refer
+  # to non-unique columns, e.g., column_name from the columns table.
+  # - name: "k6"
+  #   "@type": ForeignKey
+  #   columns: ["#key_columns.from_column"]
+  #   referencedColumns: ["#columns.column_name"]
+  # - name: "k7"
+  #   "@type": ForeignKey
+  #   columns: ["#key_columns.target_column"]
+  #   referencedColumns: ["#columns.column_name"]
diff --git a/python/felis/tap.py b/python/felis/tap.py
index efb116ba..258e0090 100644
--- a/python/felis/tap.py
+++ b/python/felis/tap.py
@@ -407,11 +407,7 @@ def visit_column(self, column_obj: datamodel.Column, table_obj: Table) -> Tap11B
         felis_type = FelisType.felis_type(felis_datatype.value)
         column.datatype = column_obj.votable_datatype or felis_type.votable_name
 
-        column.arraysize = column_obj.votable_arraysize or (
-            column_obj.length if (column_obj.length is not None and column_obj.length > 1) else None
-        )
-        if (felis_type.is_timestamp or column_obj.datatype == "text") and column.arraysize is None:
-            column.arraysize = "*"
+        column.arraysize = column_obj.votable_arraysize
 
     def _is_int(s: str) -> bool:
         try:
diff --git a/python/felis/tap_schema.py b/python/felis/tap_schema.py
new file mode 100644
index 00000000..74120cc6
--- /dev/null
+++ b/python/felis/tap_schema.py
@@ -0,0 +1,644 @@
+"""Provides utilities for creating and populating the TAP_SCHEMA database."""
+
+# This file is part of felis.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import logging
+import os
+import re
+from typing import Any
+
+from lsst.resources import ResourcePath
+from sqlalchemy import MetaData, Table, text
+from sqlalchemy.engine import Connection, Engine
+from sqlalchemy.engine.mock import MockConnection
+from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.schema import CreateSchema
+from sqlalchemy.sql.dml import Insert
+
+from felis import datamodel
+from felis.datamodel import Schema
+from felis.db.utils import is_valid_engine
+from felis.metadata import MetaDataBuilder
+
+from .types import FelisType
+
+__all__ = ["TableManager", "DataLoader"]
+
+logger = logging.getLogger(__name__)
+
+
+class TableManager:
+    """Manage creation of TAP_SCHEMA tables.
+
+    Parameters
+    ----------
+    engine
+        The SQLAlchemy engine for reflecting the TAP_SCHEMA tables from an
+        existing database.
+        This can be a mock connection or None, in which case the internal
+        TAP_SCHEMA schema will be used by loading an internal YAML file.
+    schema_name
+        The name of the schema to use for the TAP_SCHEMA tables.
+        Leave as None to use the standard name of "TAP_SCHEMA".
+    apply_schema_to_metadata
+        If True, apply the schema to the metadata as well as the tables.
+        If False, these will be set to None, e.g., for sqlite.
+    table_name_postfix
+        A string to append to all the standard table names.
+        This needs to be used in a way such that the resultant table names
+        map to tables within the TAP_SCHEMA database.
+
+    Notes
+    -----
+    The TAP_SCHEMA schema must either have been created already, in which case
+    the ``engine`` should be provided, or the internal TAP_SCHEMA schema will
+    be used if ``engine`` is None or a ``MockConnection``.
+    """
+
+    _TABLE_NAMES_STD = ["schemas", "tables", "columns", "keys", "key_columns"]
+    """The standard table names for the TAP_SCHEMA tables."""
+
+    _SCHEMA_NAME_STD = "TAP_SCHEMA"
+    """The standard schema name for the TAP_SCHEMA tables."""
+
+    def __init__(
+        self,
+        engine: Engine | MockConnection | None = None,
+        schema_name: str | None = None,
+        apply_schema_to_metadata: bool = True,
+        table_name_postfix: str = "",
+    ):
+        """Initialize the table manager."""
+        self.table_name_postfix = table_name_postfix
+        self.apply_schema_to_metadata = apply_schema_to_metadata
+        self.schema_name = schema_name or TableManager._SCHEMA_NAME_STD
+
+        if is_valid_engine(engine):
+            assert isinstance(engine, Engine)
+            logger.debug(
+                "Reflecting TAP_SCHEMA database from existing database at %s",
+                engine.url._replace(password="***"),
+            )
+            self._reflect(engine)
+        else:
+            self._load_yaml()
+
+        self._create_table_map()
+        self._check_tables()
+
+    def _reflect(self, engine: Engine) -> None:
+        """Reflect the TAP_SCHEMA database tables into the metadata.
+
+        Parameters
+        ----------
+        engine
+            The SQLAlchemy engine to use to reflect the tables.
+        """
+        self._metadata = MetaData(schema=self.schema_name if self.apply_schema_to_metadata else None)
+        try:
+            self.metadata.reflect(bind=engine)
+        except SQLAlchemyError as e:
+            logger.error("Error reflecting TAP_SCHEMA database: %s", e)
+            raise
+
+    def _load_yaml(self) -> None:
+        """Load the standard TAP_SCHEMA schema from a Felis package
+        resource.
+        """
+        self._load_schema()
+        if self.schema_name != TableManager._SCHEMA_NAME_STD:
+            self.schema.name = self.schema_name
+        else:
+            self.schema_name = self.schema.name
+
+        self._metadata = MetaDataBuilder(
+            self.schema,
+            apply_schema_to_metadata=self.apply_schema_to_metadata,
+            apply_schema_to_tables=self.apply_schema_to_metadata,
+        ).build()
+
+        logger.debug("Loaded TAP_SCHEMA '%s' from YAML resource", self.schema_name)
+
+    def __getitem__(self, table_name: str) -> Table:
+        """Get one of the TAP_SCHEMA tables by its standard TAP_SCHEMA name.
+
+        Parameters
+        ----------
+        table_name
+            The name of the table to get.
+
+        Returns
+        -------
+        Table
+            The table with the given name.
+
+        Notes
+        -----
+        This implements dictionary-style access for the table manager,
+        allowing tables to be looked up by their standard TAP_SCHEMA names.
+        """
+        if table_name not in self._table_map:
+            raise KeyError(f"Table '{table_name}' not found in table map")
+        return self.metadata.tables[self._table_map[table_name]]
+
+    @property
+    def schema(self) -> Schema:
+        """Get the TAP_SCHEMA schema.
+
+        Returns
+        -------
+        Schema
+            The TAP_SCHEMA schema.
+
+        Notes
+        -----
+        This will only be set if the TAP_SCHEMA schema was loaded from a
+        Felis package resource; it is not populated when the tables were
+        reflected from an existing database.
+        """
+        return self._schema
+
+    @property
+    def metadata(self) -> MetaData:
+        """Get the metadata for the TAP_SCHEMA tables.
+
+        Returns
+        -------
+        `~sqlalchemy.sql.schema.MetaData`
+            The metadata for the TAP_SCHEMA tables.
+
+        Notes
+        -----
+        This will either be the metadata that was reflected from an existing
+        database or the metadata that was loaded from a Felis package resource.
+        """
+        return self._metadata
+
+    @classmethod
+    def get_tap_schema_std_path(cls) -> str:
+        """Get the path to the standard TAP_SCHEMA schema resource.
+
+        Returns
+        -------
+        str
+            The path to the standard TAP_SCHEMA schema resource.
+        """
+        return os.path.join(os.path.dirname(__file__), "schemas", "tap_schema_std.yaml")
+
+    @classmethod
+    def get_tap_schema_std_resource(cls) -> ResourcePath:
+        """Get the standard TAP_SCHEMA schema resource.
+
+        Returns
+        -------
+        `~lsst.resources.ResourcePath`
+            The standard TAP_SCHEMA schema resource.
+        """
+        return ResourcePath("resource://felis/schemas/tap_schema_std.yaml")
+
+    @classmethod
+    def get_table_names_std(cls) -> list[str]:
+        """Get the standard table names for the TAP_SCHEMA tables.
+
+        Returns
+        -------
+        list
+            The standard table names for the TAP_SCHEMA tables.
+        """
+        return cls._TABLE_NAMES_STD
+
+    @classmethod
+    def get_schema_name_std(cls) -> str:
+        """Get the standard schema name for the TAP_SCHEMA tables.
+
+        Returns
+        -------
+        str
+            The standard schema name for the TAP_SCHEMA tables.
+        """
+        return cls._SCHEMA_NAME_STD
+
+    @classmethod
+    def load_schema_resource(cls) -> Schema:
+        """Load the standard TAP_SCHEMA schema from a Felis package
+        resource into a Felis `~felis.datamodel.Schema`.
+
+        Returns
+        -------
+        Schema
+            The TAP_SCHEMA schema.
+        """
+        rp = cls.get_tap_schema_std_resource()
+        return Schema.from_uri(rp, context={"id_generation": True})
+
+    def _load_schema(self) -> None:
+        """Load the TAP_SCHEMA schema from a Felis package resource."""
+        self._schema = self.load_schema_resource()
+
+    def _create_table_map(self) -> None:
+        """Create a mapping of standard table names to the table names modified
+        with a postfix, as well as the prepended schema name if it is set.
+
+        Notes
+        -----
+        This is a private method that is called during initialization, allowing
+        us to use table names like ``schemas11`` such as those used by the CADC
+        TAP library instead of the standard table names. It also maps between
+        the standard table names and those with the schema name prepended like
+        SQLAlchemy uses. The resulting map is stored on the instance and used
+        by ``__getitem__``.
+        """
+        self._table_map = {
+            table_name: (
+                f"{self.schema_name + '.' if self.apply_schema_to_metadata else ''}"
+                f"{table_name}{self.table_name_postfix}"
+            )
+            for table_name in TableManager.get_table_names_std()
+        }
+        logger.debug(f"Created TAP_SCHEMA table map: {self._table_map}")
+
+    def _check_tables(self) -> None:
+        """Check that there is a valid mapping to each standard table.
+
+        Raises
+        ------
+        KeyError
+            If a table is missing from the table map.
+        """
+        for table_name in TableManager.get_table_names_std():
+            self[table_name]
+
+    def _create_schema(self, engine: Engine) -> None:
+        """Create the database schema for TAP_SCHEMA if it does not already
+        exist.
+
+        Parameters
+        ----------
+        engine
+            The SQLAlchemy engine to use to create the schema.
+
+        Notes
+        -----
+        This method only creates the schema in the database. It does not create
+        the tables.
+        """
+        create_schema_functions = {
+            "postgresql": self._create_schema_postgresql,
+            "mysql": self._create_schema_mysql,
+        }
+
+        dialect_name = engine.dialect.name
+        if dialect_name == "sqlite":
+            # SQLite doesn't have schemas.
+            return
+
+        create_function = create_schema_functions.get(dialect_name)
+
+        if create_function:
+            with engine.begin() as connection:
+                create_function(connection)
+        else:
+            # Some other database engine we don't currently know how to handle.
+            raise NotImplementedError(
+                f"Database engine '{engine.dialect.name}' is not supported for schema creation"
+            )
+
+    def _create_schema_postgresql(self, connection: Connection) -> None:
+        """Create the schema in a PostgreSQL database.
+
+        Parameters
+        ----------
+        connection
+            The SQLAlchemy connection to use to create the schema.
+        """
+        connection.execute(CreateSchema(self.schema_name, if_not_exists=True))
+
+    def _create_schema_mysql(self, connection: Connection) -> None:
+        """Create the schema in a MySQL database.
+
+        Parameters
+        ----------
+        connection
+            The SQLAlchemy connection to use to create the schema.
+        """
+        connection.execute(text(f"CREATE DATABASE IF NOT EXISTS {self.schema_name}"))
+
+    def initialize_database(self, engine: Engine) -> None:
+        """Initialize a database with the TAP_SCHEMA tables.
+
+        Parameters
+        ----------
+        engine
+            The SQLAlchemy engine to use to create the tables.
+        """
+        logger.info("Creating TAP_SCHEMA database '%s'", self.metadata.schema)
+        self._create_schema(engine)
+        self.metadata.create_all(engine)
+
+
+class DataLoader:
+    """Load data into the TAP_SCHEMA tables.
+
+    Parameters
+    ----------
+    schema
+        The Felis ``Schema`` to load into the TAP_SCHEMA tables.
+    mgr
+        The table manager that contains the TAP_SCHEMA tables.
+    engine
+        The SQLAlchemy engine to use to connect to the database.
+    tap_schema_index
+        The index of the schema in the TAP_SCHEMA database.
+    output_path
+        The file to write the SQL statements to. If None, no SQL file will be
+        written.
+    print_sql
+        If True, print the SQL statements that will be executed.
+    dry_run
+        If True, the data will not be loaded into the database.
+    """
+
+    def __init__(
+        self,
+        schema: Schema,
+        mgr: TableManager,
+        engine: Engine | MockConnection,
+        tap_schema_index: int = 0,
+        output_path: str | None = None,
+        print_sql: bool = False,
+        dry_run: bool = False,
+    ):
+        self.schema = schema
+        self.mgr = mgr
+        self.engine = engine
+        self.tap_schema_index = tap_schema_index
+        self.inserts: list[Insert] = []
+        self.output_path = output_path
+        self.print_sql = print_sql
+        self.dry_run = dry_run
+
+    def load(self) -> None:
+        """Load the schema data into the TAP_SCHEMA tables.
+
+        Notes
+        -----
+        This will generate inserts for the data, print the SQL statements if
+        requested, save the SQL statements to a file if requested, and load the
+        data into the database if not in dry run mode. These are done as
+        sequential operations rather than for each insert. The logic is that
+        the user may still want the complete SQL output to be printed or saved
+        to a file even if loading into the database causes errors. If there are
+        errors when inserting into the database, the SQLAlchemy error message
+        should indicate which SQL statement caused the error.
+        """
+        self._generate_all_inserts()
+        if self.print_sql:
+            # Print to stdout.
+            self._print_sql()
+        if self.output_path:
+            # Write to an output file.
+            self._write_sql_to_file()
+        if not self.dry_run:
+            # Execute the inserts if not in dry run mode.
+ self._execute_inserts() + else: + logger.info("Dry run: not loading data into database") + + def _insert_schemas(self) -> None: + """Insert the schema data into the schemas table.""" + schema_record = { + "schema_name": self.schema.name, + "utype": self.schema.votable_utype, + "description": self.schema.description, + "schema_index": self.tap_schema_index, + } + self._insert("schemas", schema_record) + + def _get_table_name(self, table: datamodel.Table) -> str: + """Get the name of the table with the schema name prepended. + + Parameters + ---------- + table + The table to get the name for. + + Returns + ------- + str + The name of the table with the schema name prepended. + """ + return f"{self.schema.name}.{table.name}" + + def _insert_tables(self) -> None: + """Insert the table data into the tables table.""" + for table in self.schema.tables: + table_record = { + "schema_name": self.schema.name, + "table_name": self._get_table_name(table), + "table_type": "table", + "utype": table.votable_utype, + "description": table.description, + "table_index": 0 if table.tap_table_index is None else table.tap_table_index, + } + self._insert("tables", table_record) + + def _insert_columns(self) -> None: + """Insert the column data into the columns table.""" + for table in self.schema.tables: + for column in table.columns: + felis_type = FelisType.felis_type(column.datatype.value) + arraysize = str(column.votable_arraysize) if column.votable_arraysize else None + size = DataLoader._get_size(column) + indexed = DataLoader._is_indexed(column, table) + tap_column_index = column.tap_column_index + unit = column.ivoa_unit or column.fits_tunit + + column_record = { + "table_name": self._get_table_name(table), + "column_name": column.name, + "datatype": felis_type.votable_name, + "arraysize": arraysize, + "size": size, + "xtype": column.votable_xtype, + "description": column.description, + "utype": column.votable_utype, + "unit": unit, + "ucd": column.ivoa_ucd, + "indexed": indexed, + "principal": column.tap_principal, + "std": column.tap_std, + "column_index": tap_column_index, + } + self._insert("columns", column_record) + + def _insert_keys(self) -> None: + """Insert the foreign keys into the keys and key_columns tables.""" + for table in self.schema.tables: + for constraint in table.constraints: + if isinstance(constraint, datamodel.ForeignKeyConstraint): + # Handle keys table + referenced_column = self.schema.find_object_by_id( + constraint.referenced_columns[0], datamodel.Column + ) + referenced_table = self.schema.get_table_by_column(referenced_column) + key_record = { + "key_id": constraint.name, + "from_table": self._get_table_name(table), + "target_table": self._get_table_name(referenced_table), + "description": constraint.description, + "utype": constraint.votable_utype, + } + self._insert("keys", key_record) + + # Handle key_columns table + from_column = self.schema.find_object_by_id(constraint.columns[0], datamodel.Column) + target_column = self.schema.find_object_by_id( + constraint.referenced_columns[0], datamodel.Column + ) + key_columns_record = { + "key_id": constraint.name, + "from_column": from_column.name, + "target_column": target_column.name, + } + self._insert("key_columns", key_columns_record) + + def _generate_all_inserts(self) -> None: + """Generate the inserts for all the data.""" + self.inserts.clear() + self._insert_schemas() + self._insert_tables() + self._insert_columns() + self._insert_keys() + logger.debug("Generated %d insert statements", len(self.inserts)) + + def 
_execute_inserts(self) -> None: + """Load the `~felis.datamodel.Schema` data into the TAP_SCHEMA + tables. + """ + if isinstance(self.engine, Engine): + with self.engine.connect() as connection: + transaction = connection.begin() + try: + for insert in self.inserts: + connection.execute(insert) + transaction.commit() + except Exception as e: + logger.error("Error loading data into database: %s", e) + transaction.rollback() + raise + + def _compiled_inserts(self) -> list[str]: + """Compile the inserts to SQL. + + Returns + ------- + list + A list of the compiled insert statements. + """ + return [ + str(insert.compile(self.engine, compile_kwargs={"literal_binds": True})) + for insert in self.inserts + ] + + def _print_sql(self) -> None: + """Print the generated inserts to stdout.""" + for insert_str in self._compiled_inserts(): + print(insert_str) + + def _write_sql_to_file(self) -> None: + """Write the generated insert statements to a file.""" + if not self.output_path: + raise ValueError("No output path specified") + with open(self.output_path, "w") as outfile: + for insert_str in self._compiled_inserts(): + outfile.write(insert_str + "\n") + + def _insert(self, table_name: str, record: list[Any] | dict[str, Any]) -> None: + """Generate an insert statement for a record. + + Parameters + ---------- + table_name + The name of the table to insert the record into. + record + The record to insert into the table. + """ + table = self.mgr[table_name] + insert_statement = table.insert().values(record) + self.inserts.append(insert_statement) + + @staticmethod + def _get_size(column: datamodel.Column) -> int | None: + """Get the size of the column. + + Parameters + ---------- + column + The column to get the size for. + + Returns + ------- + int or None + The size of the column or None if not applicable. + """ + arraysize = column.votable_arraysize + + if not arraysize: + return None + + arraysize_str = str(arraysize) + if arraysize_str.isdigit(): + return int(arraysize_str) + + match = re.match(r"^([0-9]+)\*$", arraysize_str) + if match and match.group(1) is not None: + return int(match.group(1)) + + return None + + @staticmethod + def _is_indexed(column: datamodel.Column, table: datamodel.Table) -> int: + """Check if the column is indexed in the table. + + Parameters + ---------- + column + The column to check. + table + The table to check. + + Returns + ------- + int + 1 if the column is indexed, 0 otherwise. + """ + if isinstance(table.primary_key, str) and table.primary_key == column.id: + return 1 + for index in table.indexes: + if index.columns and len(index.columns) == 1 and index.columns[0] == column.id: + return 1 + return 0 diff --git a/python/felis/tests/utils.py b/python/felis/tests/utils.py new file mode 100644 index 00000000..f68d0489 --- /dev/null +++ b/python/felis/tests/utils.py @@ -0,0 +1,122 @@ +"""Test utility functions.""" + +# This file is part of felis. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. 
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import logging
+import os
+import shutil
+import tempfile
+from collections.abc import Iterator
+from contextlib import contextmanager
+from typing import IO
+
+__all__ = ["open_test_file", "mk_temp_dir", "rm_temp_dir"]
+
+TEST_DATA_DIR = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "tests", "data"))
+"""The directory containing test data files."""
+
+TEST_TMP_DIR = os.path.normpath(os.path.join(TEST_DATA_DIR, ".."))
+"""The directory for temporary files."""
+
+logger = logging.getLogger(__name__)
+
+
+def get_test_file_path(file_name: str) -> str:
+    """Return the path to a test file.
+
+    Parameters
+    ----------
+    file_name
+        The name of the test file.
+
+    Returns
+    -------
+    str
+        The path to the test file.
+
+    Raises
+    ------
+    FileNotFoundError
+        Raised if the file does not exist.
+    """
+    file_path = os.path.join(TEST_DATA_DIR, file_name)
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(file_path)
+    return file_path
+
+
+@contextmanager
+def open_test_file(file_name: str) -> Iterator[IO[str]]:
+    """Return a file object for a test file using a context manager.
+
+    Parameters
+    ----------
+    file_name
+        The name of the test file.
+
+    Returns
+    -------
+    `Iterator` [ `IO` [ `str` ] ]
+        A file object for the test file.
+
+    Raises
+    ------
+    FileNotFoundError
+        Raised if the file does not exist.
+    """
+    logger.debug("Opening test file: %s", file_name)
+    file_path = get_test_file_path(file_name)
+    file = open(file_path)
+    try:
+        yield file
+    finally:
+        file.close()
+
+
+def mk_temp_dir(parent_dir: str = TEST_TMP_DIR) -> str:
+    """Create a temporary directory for testing.
+
+    Parameters
+    ----------
+    parent_dir
+        The parent directory for the temporary directory.
+
+    Returns
+    -------
+    str
+        The path to the temporary directory.
+    """
+    return tempfile.mkdtemp(dir=parent_dir)
+
+
+def rm_temp_dir(temp_dir: str) -> None:
+    """Remove a temporary directory.
+
+    Parameters
+    ----------
+    temp_dir
+        The path to the temporary directory.
+ """ + logger.debug("Removing temporary directory: %s", temp_dir) + shutil.rmtree(temp_dir, ignore_errors=True) diff --git a/requirements.txt b/requirements.txt index 25a32e61..f679edd8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ click >= 7 pyyaml >= 6 pydantic >= 2, < 3 lsst-utils +lsst-resources diff --git a/tests/data/test_tap_schema.yaml b/tests/data/test_tap_schema.yaml new file mode 100644 index 00000000..71550c8b --- /dev/null +++ b/tests/data/test_tap_schema.yaml @@ -0,0 +1,93 @@ +name: "test_schema" +description: "Test schema" +votable:utype: "Schema" +tables: + - name: "table1" + description: "Test table 1" + primaryKey: "#table1.id" + tap:table_index: 2 + votable:utype: "Table" + columns: + - name: "id" + datatype: "int" + description: "Primary key for this table" + - name: "fk" + datatype: "int" + description: "Foreign key pointing to table2" + - name: "indexed_field" + datatype: "int" + description: "Field with index" + - name: "boolean_field" + datatype: "boolean" + description: "Boolean field" + - name: "byte_field" + datatype: "byte" + description: "Byte field" + - name: "short_field" + datatype: "short" + description: "Short field" + - name: "int_field" + datatype: "int" + description: "Int field" + - name: "long_field" + datatype: "long" + description: "Long field" + - name: "float_field" + datatype: "float" + description: "Float field" + - name: "double_field" + datatype: "double" + description: "Double field" + - name: "char_field" + datatype: "char" + length: 64 + description: "Char field" + - name: "string_field" + datatype: "string" + length: 256 + description: "String field" + - name: "text_field" + datatype: "text" + description: "Text field" + - name: "unicode_field" + datatype: "unicode" + length: 128 + description: "Unicode field" + - name: "timestamp_field" + datatype: "timestamp" + # votable:arraysize: 64 + votable:xtype: "timestamp" + description: "Timestamp field" + votable:utype: "Obs:Timestamp" + ivoa:unit: "s" + ivoa:ucd: "time.epoch" + tap:principal: 1 + tap:std: 1 + tap:column_index: 42 + - name: "binary_field" + datatype: "binary" + length: 1024 + description: "Binary field" + constraints: + - name: "fk_table1_to_table2" + "@type": "ForeignKey" + description: "Foreign key from table1 to table2" + votable:utype: "ForeignKey" + columns: + - "#table1.fk" + referencedColumns: + - "#table2.id" + indexes: + - name: "idx_table1_indexed_field" + columns: + - "#table1.indexed_field" + - name: "table2" + description: "Test table 2" + primaryKey: "#table2.id" + votable:utype: "Table" + tap:table_index: 3 + columns: + - name: "id" + datatype: "int" + description: "Test column" + votable:utype: "Column" diff --git a/tests/data/test_tap_schema_nonstandard.yaml b/tests/data/test_tap_schema_nonstandard.yaml new file mode 100644 index 00000000..7db8bfd0 --- /dev/null +++ b/tests/data/test_tap_schema_nonstandard.yaml @@ -0,0 +1,274 @@ +--- +name: tap_schema11 +version: "1.1" +description: Test of TAP_SCHEMA with non-standard schema and column names +tables: +- name: "schemas11" + description: description of schemas in this tableset + primaryKey: "#schemas11.schema_name" + tap:table_index: 100000 + mysql:engine: "InnoDB" + columns: + - name: "schema_name" + datatype: "string" + description: schema name for reference to tap_schema.schemas + length: 64 + nullable: false + tap:principal: 1 + tap:std: 1 + tap:column_index: 1 + - name: "utype" + datatype: "string" + description: lists the utypes of schemas in the tableset + length: 512 + 
tap:principal: 1
+    tap:std: 1
+    tap:column_index: 2
+  - name: "description"
+    datatype: "string"
+    description: describes schemas in the tableset
+    length: 512
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 3
+  - name: "schema_index"
+    datatype: "int"
+    description: recommended sort order when listing schemas
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 4
+- name: "tables11"
+  description: description of tables in this tableset
+  primaryKey: "#tables11.table_name"
+  tap:table_index: 101000
+  mysql:engine: "InnoDB"
+  columns:
+  - name: schema_name
+    datatype: string
+    description: the schema this table belongs to
+    length: 64
+    nullable: false
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 1
+  - name: table_name
+    datatype: string
+    description: the fully qualified table name
+    length: 128
+    nullable: false
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 2
+  - name: table_type
+    datatype: string
+    description: "one of: table, view"
+    length: 8
+    nullable: false
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 3
+  - name: utype
+    datatype: string
+    description: lists the utype of tables in the tableset
+    length: 512
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 4
+  - name: description
+    datatype: string
+    description: describes tables in the tableset
+    length: 512
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 5
+  - name: table_index
+    datatype: int
+    description: recommended sort order when listing tables
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 6
+  constraints:
+  - name: "k1"
+    "@type": ForeignKey
+    columns: ["#tables11.schema_name"]
+    referencedColumns: ["#schemas11.schema_name"]
+- name: "columns11"
+  description: description of columns in this tableset
+  primaryKey: ["#columns11.table_name", "#columns11.column_name"]
+  tap:table_index: 102000
+  mysql:engine: "InnoDB"
+  columns:
+  - name: table_name
+    datatype: string
+    description: the table this column belongs to
+    length: 128
+    nullable: false
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 1
+  - name: column_name
+    datatype: string
+    description: the column name
+    length: 64
+    nullable: false
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 2
+  - name: utype
+    datatype: string
+    description: lists the utypes of columns in the tableset
+    length: 512
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 3
+  - name: ucd
+    datatype: string
+    description: lists the UCDs of columns in the tableset
+    length: 64
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 4
+  - name: unit
+    datatype: string
+    description: lists the unit used for column values in the tableset
+    length: 64
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 5
+  - name: description
+    datatype: string
+    description: describes the columns in the tableset
+    length: 512
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 6
+  - name: datatype
+    datatype: string
+    description: lists the ADQL datatype of columns in the tableset
+    length: 64
+    nullable: false
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 7
+  - name: arraysize
+    datatype: string
+    description: lists the size of variable-length columns in the tableset
+    length: 16
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 8
+  - name: xtype
+    datatype: string
+    description: a DALI or custom extended type annotation
+    length: 64
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 9
+  - name: size
+    datatype: int
+    description: "deprecated: use arraysize"
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 10
+  - name: principal
+    datatype: int
+    description: a principal column; 1 means true, 0 means false
+    nullable: false
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 11
+  - name: indexed
+    datatype: int
+    description: an indexed column; 1 means true, 0 means false
+    nullable: false
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 12
+  - name: std
+    datatype: int
+    description: a standard column; 1 means true, 0 means false
+    nullable: false
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 13
+  - name: column_index
+    datatype: int
+    description: recommended sort order when listing columns
+    tap:principal: 1
+    tap:std: 1
+    tap:column_index: 14
+  constraints:
+  - name: "k2"
+    "@type": ForeignKey
+    columns: ["#columns11.table_name"]
+    referencedColumns: ["#tables11.table_name"]
+- name: "keys11"
+  description: description of foreign keys in this tableset
+  primaryKey: "#keys11.key_id"
+  tap:table_index: 103000
+  mysql:engine: "InnoDB"
+  columns:
+  - name: key_id
+    datatype: string
+    description: unique key to join to tap_schema.key_columns
+    length: 64
+    nullable: false
+  - name: from_table
+    datatype: string
+    description: the table with the foreign key
+    length: 128
+    nullable: false
+  - name: target_table
+    datatype: string
+    description: the table with the primary key
+    length: 128
+    nullable: false
+  - name: utype
+    datatype: string
+    description: lists the utype of keys in the tableset
+    length: 512
+  - name: description
+    datatype: string
+    description: describes keys in the tableset
+    length: 512
+  constraints:
+  - name: "k3"
+    "@type": ForeignKey
+    columns: ["#keys11.from_table"]
+    referencedColumns: ["#tables11.table_name"]
+  - name: "k4"
+    "@type": ForeignKey
+    columns: ["#keys11.target_table"]
+    referencedColumns: ["#tables11.table_name"]
+- name: "key_columns11"
+  description: description of foreign key columns in this tableset
+  tap:table_index: 104000
+  mysql:engine: "InnoDB"
+  columns:
+  - name: key_id
+    datatype: string
+    length: 64
+    nullable: false
+  - name: from_column
+    datatype: string
+    length: 64
+    nullable: false
+  - name: target_column
+    datatype: string
+    length: 64
+    nullable: false
+  constraints:
+  - name: "k5"
+    "@type": ForeignKey
+    columns: ["#key_columns11.key_id"]
+    referencedColumns: ["#keys11.key_id"]
+  # FIXME: These can't be defined as FK constraints, because they refer
+  # to non-unique columns, e.g., column_name from the columns table.
+  # - name: "k6"
+  #   "@type": ForeignKey
+  #   columns: ["#key_columns.from_column"]
+  #   referencedColumns: ["#columns.column_name"]
+  # - name: "k7"
+  #   "@type": ForeignKey
+  #   columns: ["#key_columns.target_column"]
+  #   referencedColumns: ["#columns.column_name"]
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 16077e58..65949d1b 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -26,6 +26,7 @@
 
 from click.testing import CliRunner
 
+import felis.tap_schema as tap_schema
 from felis.cli import cli
 from felis.db.dialects import get_supported_dialects
 
@@ -167,6 +168,26 @@ def test_initialize_and_drop(self) -> None:
         )
         self.assertTrue(result.exit_code != 0)
 
+    def test_load_tap_schema(self) -> None:
+        """Test for ``load-tap-schema`` command."""
+        # Create the TAP_SCHEMA database.
+        url = f"sqlite:///{self.tmpdir}/tap_schema.sqlite3"
+        runner = CliRunner()
+        tap_schema_path = tap_schema.TableManager.get_tap_schema_std_path()
+        result = runner.invoke(
+            cli,
+            ["--id-generation", "create", f"--engine-url={url}", tap_schema_path],
+            catch_exceptions=False,
+        )
+        self.assertEqual(result.exit_code, 0)
+
+        # Load the TAP_SCHEMA data.
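+        # The create step above initialized the TAP_SCHEMA tables at this URL;
+        # loading would fail if the tables did not already exist.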
+        runner = CliRunner()
+        result = runner.invoke(
+            cli, ["load-tap-schema", f"--engine-url={url}", TEST_YAML], catch_exceptions=False
+        )
+        self.assertEqual(result.exit_code, 0)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_datamodel.py b/tests/test_datamodel.py
index 26e38275..269b3359 100644
--- a/tests/test_datamodel.py
+++ b/tests/test_datamodel.py
@@ -20,10 +20,12 @@
 # along with this program. If not, see <https://www.gnu.org/licenses/>.
 
 import os
+import pathlib
 import unittest
 from collections import defaultdict
 
 import yaml
+from lsst.resources import ResourcePath
 from pydantic import ValidationError
 
 from felis.datamodel import (
@@ -38,6 +40,7 @@
     Table,
     UniqueConstraint,
 )
+from felis.tests.utils import get_test_file_path, open_test_file
 
 TESTDIR = os.path.abspath(os.path.dirname(__file__))
 TEST_YAML = os.path.join(TESTDIR, "data", "test.yml")
@@ -519,6 +522,68 @@
         with self.assertRaises(ValidationError):
             Schema.model_validate(yaml_data, context={"id_generation": False})
 
+    def test_get_table_by_column(self) -> None:
+        """Test the ``get_table_by_column`` method."""
+        # Test that the correct table is returned when searching by column.
+        test_col = Column(name="test_column", id="#test_tbl.test_col", datatype="string", length=256)
+        test_tbl = Table(name="test_table", id="#test_tbl", columns=[test_col])
+        sch = Schema(name="testSchema", id="#test_sch_id", tables=[test_tbl])
+        self.assertEqual(sch.get_table_by_column(test_col), test_tbl)
+
+        # Test that an error is raised when the column is not found.
+        bad_col = Column(name="bad_column", id="#test_tbl.bad_column", datatype="string", length=256)
+        with self.assertRaises(ValueError):
+            sch.get_table_by_column(bad_col)
+
+    def test_find_object_by_id(self) -> None:
+        """Test the ``find_object_by_id`` method."""
+        test_col = Column(name="test_column", id="#test_tbl.test_col", datatype="string", length=256)
+        test_tbl = Table(name="test_table", id="#test_tbl", columns=[test_col])
+        sch = Schema(name="testSchema", id="#test_sch_id", tables=[test_tbl])
+        self.assertEqual(sch.find_object_by_id("#test_tbl.test_col", Column), test_col)
+        with self.assertRaises(KeyError):
+            sch.find_object_by_id("#bad_id", Column)
+        with self.assertRaises(TypeError):
+            sch.find_object_by_id("#test_tbl", Column)
+
+    def test_from_file(self) -> None:
+        """Test loading a schema from a file."""
+        # Test file object.
+        with open_test_file("sales.yaml") as test_file:
+            schema = Schema.from_stream(test_file)
+        self.assertIsInstance(schema, Schema)
+
+        # Test path string, closing the file when the read is done.
+        test_file_str = get_test_file_path("sales.yaml")
+        with open(test_file_str) as test_file:
+            schema = Schema.from_stream(test_file)
+        self.assertIsInstance(schema, Schema)
+
+        # Test Path object.
+        test_file_path = pathlib.Path(test_file_str)
+        schema = Schema.from_uri(test_file_path)
+        self.assertIsInstance(schema, Schema)
+
+    def test_from_resource(self) -> None:
+        """Test loading a schema from a resource."""
+        # Test loading a schema from a resource string.
+        schema = Schema.from_uri(
+            "resource://felis/schemas/tap_schema_std.yaml", context={"id_generation": True}
+        )
+        self.assertIsInstance(schema, Schema)
+
+        # Test loading a schema from a ResourcePath.
+        schema = Schema.from_uri(
+            ResourcePath("resource://felis/schemas/tap_schema_std.yaml"), context={"id_generation": True}
+        )
+        self.assertIsInstance(schema, Schema)
+
+        # Test loading from a nonexistent resource.
+        with self.assertRaises(ValueError):
+            Schema.from_uri("resource://fake/schemas/bad_schema.yaml")
+
+        # Without ID generation enabled, this schema should fail validation.
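+        # (Presumably the bundled standard schema omits explicit IDs, so
+        # validation can only succeed when ID generation is enabled.)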
+ with self.assertRaises(ValidationError): + Schema.from_uri("resource://felis/schemas/tap_schema_std.yaml") + class SchemaVersionTest(unittest.TestCase): """Test the schema version.""" diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 18342950..e5a927b4 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -25,7 +25,6 @@ import yaml from sqlalchemy import ( CheckConstraint, - Connection, Constraint, ForeignKeyConstraint, Index, @@ -53,10 +52,6 @@ def setUp(self) -> None: with open(TEST_YAML) as data: self.yaml_data = yaml.safe_load(data) - def connection(self) -> Connection: - """Return a connection to the database.""" - return self.engine.connect() - def test_create_all(self) -> None: """Create all tables in the schema using the metadata object and a SQLite connection. @@ -76,78 +71,76 @@ def _sorted_constraints(constraints: set[Constraint]) -> list[Constraint]: [c for c in constraints if not isinstance(c, PrimaryKeyConstraint)], key=lambda c: c.name ) - with self.connection() as connection: - schema = Schema.model_validate(self.yaml_data) - schema.name = "main" - builder = MetaDataBuilder(schema) - md = builder.build() - - ctx = DatabaseContext(md, connection) - - ctx.create_all() - - md_db = MetaData() - md_db.reflect(connection, schema=schema.name) - - self.assertEqual(md_db.tables.keys(), md.tables.keys()) - - for md_table_name in md.tables.keys(): - md_table = md.tables[md_table_name] - md_db_table = md_db.tables[md_table_name] - self.assertEqual(md_table.columns.keys(), md_db_table.columns.keys()) - for md_column_name in md_table.columns.keys(): - md_column = md_table.columns[md_column_name] - md_db_column = md_db_table.columns[md_column_name] - self.assertEqual(type(md_column.type), type(md_db_column.type)) - self.assertEqual(md_column.nullable, md_db_column.nullable) - self.assertEqual(md_column.primary_key, md_db_column.primary_key) - self.assertTrue( - (md_table.constraints and md_db_table.constraints) - or (not md_table.constraints and not md_table.constraints), - "Constraints not created correctly", - ) - if md_table.constraints: - self.assertEqual(len(md_table.constraints), len(md_db_table.constraints)) - md_constraints = _sorted_constraints(md_table.constraints) - md_db_constraints = _sorted_constraints(md_db_table.constraints) - for md_constraint, md_db_constraint in zip(md_constraints, md_db_constraints): - self.assertEqual(md_constraint.name, md_db_constraint.name) - self.assertEqual(md_constraint.deferrable, md_db_constraint.deferrable) - self.assertEqual(md_constraint.initially, md_db_constraint.initially) - self.assertEqual( - type(md_constraint), type(md_db_constraint), "Constraint types do not match" - ) - if isinstance(md_constraint, ForeignKeyConstraint) and isinstance( - md_db_constraint, ForeignKeyConstraint - ): - md_fk: ForeignKeyConstraint = md_constraint - md_db_fk: ForeignKeyConstraint = md_db_constraint - self.assertEqual(md_fk.referred_table.name, md_db_fk.referred_table.name) - self.assertEqual(md_fk.column_keys, md_db_fk.column_keys) - elif isinstance(md_constraint, UniqueConstraint) and isinstance( - md_db_constraint, UniqueConstraint - ): - md_uniq: UniqueConstraint = md_constraint - md_db_uniq: UniqueConstraint = md_db_constraint - self.assertEqual(md_uniq.columns.keys(), md_db_uniq.columns.keys()) - elif isinstance(md_constraint, CheckConstraint) and isinstance( - md_db_constraint, CheckConstraint - ): - md_check: CheckConstraint = md_constraint - md_db_check: CheckConstraint = md_db_constraint - 
self.assertEqual(str(md_check.sqltext), str(md_db_check.sqltext))
-                self.assertTrue(
-                    (md_table.indexes and md_db_table.indexes)
-                    or (not md_table.indexes and not md_table.indexes),
-                    "Indexes not created correctly",
-                )
-                if md_table.indexes:
-                    md_indexes = _sorted_indexes(md_table.indexes)
-                    md_db_indexes = _sorted_indexes(md_db_table.indexes)
-                    self.assertEqual(len(md_indexes), len(md_db_indexes))
-                    for md_index, md_db_index in zip(md_table.indexes, md_db_table.indexes):
-                        self.assertEqual(md_index.name, md_db_index.name)
-                        self.assertEqual(md_index.columns.keys(), md_db_index.columns.keys())
+        schema = Schema.model_validate(self.yaml_data)
+        schema.name = "main"
+        builder = MetaDataBuilder(schema)
+        md = builder.build()
+
+        ctx = DatabaseContext(md, self.engine)
+
+        ctx.create_all()
+
+        md_db = MetaData()
+        # Reflect using the engine directly so that no connection is leaked.
+        md_db.reflect(self.engine, schema=schema.name)
+
+        self.assertEqual(md_db.tables.keys(), md.tables.keys())
+
+        for md_table_name in md.tables.keys():
+            md_table = md.tables[md_table_name]
+            md_db_table = md_db.tables[md_table_name]
+            self.assertEqual(md_table.columns.keys(), md_db_table.columns.keys())
+            for md_column_name in md_table.columns.keys():
+                md_column = md_table.columns[md_column_name]
+                md_db_column = md_db_table.columns[md_column_name]
+                self.assertEqual(type(md_column.type), type(md_db_column.type))
+                self.assertEqual(md_column.nullable, md_db_column.nullable)
+                self.assertEqual(md_column.primary_key, md_db_column.primary_key)
+            self.assertTrue(
+                (md_table.constraints and md_db_table.constraints)
+                or (not md_table.constraints and not md_db_table.constraints),
+                "Constraints not created correctly",
+            )
+            if md_table.constraints:
+                self.assertEqual(len(md_table.constraints), len(md_db_table.constraints))
+                md_constraints = _sorted_constraints(md_table.constraints)
+                md_db_constraints = _sorted_constraints(md_db_table.constraints)
+                for md_constraint, md_db_constraint in zip(md_constraints, md_db_constraints):
+                    self.assertEqual(md_constraint.name, md_db_constraint.name)
+                    self.assertEqual(md_constraint.deferrable, md_db_constraint.deferrable)
+                    self.assertEqual(md_constraint.initially, md_db_constraint.initially)
+                    self.assertEqual(
+                        type(md_constraint), type(md_db_constraint), "Constraint types do not match"
+                    )
+                    if isinstance(md_constraint, ForeignKeyConstraint) and isinstance(
+                        md_db_constraint, ForeignKeyConstraint
+                    ):
+                        md_fk: ForeignKeyConstraint = md_constraint
+                        md_db_fk: ForeignKeyConstraint = md_db_constraint
+                        self.assertEqual(md_fk.referred_table.name, md_db_fk.referred_table.name)
+                        self.assertEqual(md_fk.column_keys, md_db_fk.column_keys)
+                    elif isinstance(md_constraint, UniqueConstraint) and isinstance(
+                        md_db_constraint, UniqueConstraint
+                    ):
+                        md_uniq: UniqueConstraint = md_constraint
+                        md_db_uniq: UniqueConstraint = md_db_constraint
+                        self.assertEqual(md_uniq.columns.keys(), md_db_uniq.columns.keys())
+                    elif isinstance(md_constraint, CheckConstraint) and isinstance(
+                        md_db_constraint, CheckConstraint
+                    ):
+                        md_check: CheckConstraint = md_constraint
+                        md_db_check: CheckConstraint = md_db_constraint
+                        self.assertEqual(str(md_check.sqltext), str(md_db_check.sqltext))
+            self.assertTrue(
+                (md_table.indexes and md_db_table.indexes) or (not md_table.indexes and not md_db_table.indexes),
+                "Indexes not created correctly",
+            )
+            if md_table.indexes:
+                md_indexes = _sorted_indexes(md_table.indexes)
+                md_db_indexes = _sorted_indexes(md_db_table.indexes)
+                self.assertEqual(len(md_indexes), len(md_db_indexes))
+                # Compare the sorted lists so that the pairing is deterministic.
+                for md_index, md_db_index in zip(md_indexes, md_db_indexes):
+                    self.assertEqual(md_index.name, md_db_index.name)
+                    self.assertEqual(md_index.columns.keys(), md_db_index.columns.keys())
 
     def test_builder(self) -> None:
         """Test that the information in the metadata object created by the
diff --git a/tests/test_postgresql.py b/tests/test_postgres.py
similarity index 100%
rename from tests/test_postgresql.py
rename to tests/test_postgres.py
diff --git a/tests/test_tap_schema.py b/tests/test_tap_schema.py
new file mode 100644
index 00000000..87fb020b
--- /dev/null
+++ b/tests/test_tap_schema.py
@@ -0,0 +1,327 @@
+# This file is part of felis.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import os
+import unittest
+from typing import Any
+
+from sqlalchemy import Engine, MetaData, create_engine, select
+
+from felis import tap
+from felis.datamodel import Schema
+from felis.tap_schema import DataLoader, TableManager
+from felis.tests.utils import mk_temp_dir, open_test_file, rm_temp_dir
+
+
+class TableManagerTestCase(unittest.TestCase):
+    """Test the `TableManager` class."""
+
+    def setUp(self) -> None:
+        """Set up the test case."""
+        with open_test_file("sales.yaml") as test_file:
+            self.schema = Schema.from_stream(test_file)
+
+    def test_create_table_manager(self) -> None:
+        """Test the TAP table manager class."""
+        mgr = TableManager()
+
+        schema_name = mgr.schema_name
+
+        # Check the created metadata and tables.
+        self.assertNotEqual(len(mgr.metadata.tables), 0)
+        self.assertEqual(mgr.metadata.schema, schema_name)
+        for table_name in mgr.get_table_names_std():
+            self.assertIsNotNone(mgr[table_name])
+
+        # Make sure that creating a new table manager works when one has
+        # already been created.
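+        # (Construction is expected to be repeatable; the second manager
+        # should build its own metadata without clashing with the first.)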
+ mgr = TableManager() + + +class DataLoaderTestCase(unittest.TestCase): + """Test the `DataLoader` class.""" + + def setUp(self) -> None: + """Set up the test case.""" + with open_test_file("sales.yaml") as test_file: + self.schema = Schema.from_stream(test_file) + + self.tmpdir = mk_temp_dir() + + def tearDown(self) -> None: + """Clean up temporary directory.""" + rm_temp_dir(self.tmpdir) + + def test_sqlite(self) -> None: + """Test the `DataLoader` using an in-memory SQLite database.""" + engine = create_engine("sqlite:///:memory:") + + mgr = TableManager(apply_schema_to_metadata=False) + mgr.initialize_database(engine) + + loader = DataLoader(self.schema, mgr, engine) + loader.load() + + def test_sql_output(self) -> None: + """Test printing SQL to stdout and writing SQL to a file.""" + engine = create_engine("sqlite:///:memory:") + mgr = TableManager(apply_schema_to_metadata=False) + loader = DataLoader(self.schema, mgr, engine, dry_run=True, print_sql=True) + loader.load() + + sql_path = os.path.join(self.tmpdir, "test_tap_schema_print_sql.sql") + loader = DataLoader(self.schema, mgr, engine, dry_run=True, print_sql=True, output_path=sql_path) + loader.load() + + self.assertTrue(os.path.exists(sql_path)) + with open(sql_path) as sql_file: + sql_data = sql_file.read() + insert_count = sql_data.count("INSERT INTO") + self.assertEqual( + insert_count, + 21, + f"Expected 21 'INSERT INTO' statements, found {insert_count}", + ) + + +def _find_row(rows: list[dict[str, Any]], column_name: str, value: str) -> dict[str, Any]: + next_row = next( + (row for row in rows if row[column_name] == value), + None, + ) + assert next_row is not None + assert isinstance(next_row, dict) + return next_row + + +def _fetch_results(_engine: Engine, _metadata: MetaData) -> dict: + results: dict[str, Any] = {} + with _engine.connect() as connection: + for table_name in TableManager.get_table_names_std(): + tap_table = _metadata.tables[table_name] + primary_key_columns = tap_table.primary_key.columns + stmt = select(tap_table).order_by(*primary_key_columns) + result = connection.execute(stmt) + column_data = [row._asdict() for row in result] + results[table_name] = column_data + return results + + +class TapSchemaDataTest(unittest.TestCase): + """Test the validity of generated TAP SCHEMA data.""" + + def setUp(self) -> None: + """Set up the test case.""" + with open_test_file("test_tap_schema.yaml") as test_file: + self.schema = Schema.from_stream(test_file, context={"id_generation": True}) + + self.engine = create_engine("sqlite:///:memory:") + + mgr = TableManager(apply_schema_to_metadata=False) + mgr.initialize_database(self.engine) + self.mgr = mgr + + loader = DataLoader(self.schema, mgr, self.engine, tap_schema_index=2) + loader.load() + + self.md = MetaData() + self.md.reflect(self.engine) + + def test_schemas(self) -> None: + schemas_table = self.mgr["schemas"] + with self.engine.connect() as connection: + result = connection.execute(select(schemas_table)) + schema_data = [row._asdict() for row in result] + + self.assertEqual(len(schema_data), 1) + + schema = schema_data[0] + self.assertEqual(schema["schema_name"], "test_schema") + self.assertEqual(schema["description"], "Test schema") + self.assertEqual(schema["utype"], "Schema") + self.assertEqual(schema["schema_index"], 2) + + def test_tables(self) -> None: + tables_table = self.mgr["tables"] + with self.engine.connect() as connection: + result = connection.execute(select(tables_table)) + table_data = [row._asdict() for row in result] + + 
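+        # test_tap_schema.yaml defines exactly two tables, table1 and table2.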
self.assertEqual(len(table_data), 2) + + table = table_data[0] + assert isinstance(table, dict) + self.assertEqual(table["schema_name"], "test_schema") + self.assertEqual(table["table_name"], f"{self.schema.name}.table1") + self.assertEqual(table["table_type"], "table") + self.assertEqual(table["utype"], "Table") + self.assertEqual(table["description"], "Test table 1") + self.assertEqual(table["table_index"], 2) + + def test_columns(self) -> None: + columns_table = self.mgr["columns"] + with self.engine.connect() as connection: + result = connection.execute(select(columns_table)) + column_data = [row._asdict() for row in result] + + table1_rows = [row for row in column_data if row["table_name"] == f"{self.schema.name}.table1"] + self.assertNotEqual(len(table1_rows), 0) + + boolean_col = _find_row(table1_rows, "column_name", "boolean_field") + self.assertEqual(boolean_col["datatype"], "boolean") + self.assertEqual(boolean_col["arraysize"], None) + + byte_col = _find_row(table1_rows, "column_name", "byte_field") + self.assertEqual(byte_col["datatype"], "unsignedByte") + self.assertEqual(byte_col["arraysize"], None) + + short_col = _find_row(table1_rows, "column_name", "short_field") + self.assertEqual(short_col["datatype"], "short") + self.assertEqual(short_col["arraysize"], None) + + int_col = _find_row(table1_rows, "column_name", "int_field") + self.assertEqual(int_col["datatype"], "int") + self.assertEqual(int_col["arraysize"], None) + + float_col = _find_row(table1_rows, "column_name", "float_field") + self.assertEqual(float_col["datatype"], "float") + self.assertEqual(float_col["arraysize"], None) + + double_col = _find_row(table1_rows, "column_name", "double_field") + self.assertEqual(double_col["datatype"], "double") + self.assertEqual(double_col["arraysize"], None) + + long_col = _find_row(table1_rows, "column_name", "long_field") + self.assertEqual(long_col["datatype"], "long") + self.assertEqual(long_col["arraysize"], None) + + unicode_col = _find_row(table1_rows, "column_name", "unicode_field") + self.assertEqual(unicode_col["datatype"], "unicodeChar") + self.assertEqual(unicode_col["arraysize"], "128*") + + binary_col = _find_row(table1_rows, "column_name", "binary_field") + self.assertEqual(binary_col["datatype"], "unsignedByte") + self.assertEqual(binary_col["arraysize"], "1024*") + + ts = _find_row(table1_rows, "column_name", "timestamp_field") + self.assertEqual(ts["datatype"], "char") + self.assertEqual(ts["xtype"], "timestamp") + self.assertEqual(ts["description"], "Timestamp field") + self.assertEqual(ts["utype"], "Obs:Timestamp") + self.assertEqual(ts["unit"], "s") + self.assertEqual(ts["ucd"], "time.epoch") + self.assertEqual(ts["principal"], 1) + self.assertEqual(ts["std"], 1) + self.assertEqual(ts["column_index"], 42) + self.assertEqual(ts["size"], None) + self.assertEqual(ts["arraysize"], "*") + + char_col = _find_row(table1_rows, "column_name", "char_field") + self.assertEqual(char_col["datatype"], "char") + self.assertEqual(char_col["arraysize"], "64") + + str_col = _find_row(table1_rows, "column_name", "string_field") + self.assertEqual(str_col["datatype"], "char") + self.assertEqual(str_col["arraysize"], "256*") + + txt_col = _find_row(table1_rows, "column_name", "text_field") + self.assertEqual(txt_col["datatype"], "char") + self.assertEqual(txt_col["arraysize"], "*") + + def test_keys(self) -> None: + keys_table = self.mgr["keys"] + with self.engine.connect() as connection: + result = connection.execute(select(keys_table)) + key_data = [row._asdict() for row in 
result] + + self.assertEqual(len(key_data), 1) + + key = key_data[0] + assert isinstance(key, dict) + + self.assertEqual(key["key_id"], "fk_table1_to_table2") + self.assertEqual(key["from_table"], f"{self.schema.name}.table1") + self.assertEqual(key["target_table"], f"{self.schema.name}.table2") + self.assertEqual(key["description"], "Foreign key from table1 to table2") + self.assertEqual(key["utype"], "ForeignKey") + + def test_key_columns(self) -> None: + key_columns_table = self.mgr["key_columns"] + with self.engine.connect() as connection: + result = connection.execute(select(key_columns_table)) + key_column_data = [row._asdict() for row in result] + + self.assertEqual(len(key_column_data), 1) + + key_column = key_column_data[0] + assert isinstance(key_column, dict) + + self.assertEqual(key_column["key_id"], "fk_table1_to_table2") + self.assertEqual(key_column["from_column"], "fk") + self.assertEqual(key_column["target_column"], "id") + + def test_bad_table_name(self) -> None: + """Test getting a bad TAP_SCHEMA table name.""" + with self.assertRaises(KeyError): + self.mgr["bad_table"] + + def test_compare_to_tap(self) -> None: + """Test that the generated data matches the records generated by the + ``tap`` module, which ``tap_schema`` is designed to deprecate and + eventually replace. + """ + tap_tables = tap.init_tables() + + # Load the TAP_SCHEMA data using the tap module. + tap_engine = create_engine("sqlite:///:memory:") + tap.Tap11Base.metadata.create_all(tap_engine) + visitor = tap.TapLoadingVisitor( + tap_engine, + tap_tables=tap_tables, + tap_schema_index=2, + ) + visitor.visit_schema(self.schema) + + # Reflect the generated TAP_SCHEMA data generated by the tap module. + tap_md = MetaData() + tap_md.reflect(tap_engine) + + # Gather data generated by tap. + tap_results = _fetch_results(tap_engine, tap_md) + + # Gather data generated by tap_schema. + tap_schema_results = _fetch_results(self.engine, self.md) + + # Table names should match. + self.assertSetEqual(set(tap_results.keys()), set(tap_schema_results.keys())) + + # Perform a row-by-row comparison of the data. + for table_name in tap_results: + print(f"Comparing {table_name}") + tap_data = tap_results[table_name] + tap_schema_data = tap_schema_results[table_name] + + self.assertEqual(len(tap_data), len(tap_schema_data)) + + for tap_row, tap_schema_row in zip(tap_data, tap_schema_data): + print("tap: " + str(tap_row)) + print("tap_schema: " + str(tap_schema_row)) + self.assertDictEqual(tap_row, tap_schema_row) diff --git a/tests/test_tap_schema_postgres.py b/tests/test_tap_schema_postgres.py new file mode 100644 index 00000000..1691f5b5 --- /dev/null +++ b/tests/test_tap_schema_postgres.py @@ -0,0 +1,149 @@ +# This file is part of felis. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. 
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import gc
+import os
+import unittest
+
+from sqlalchemy.engine import create_engine
+from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.schema import CreateSchema
+
+from felis.datamodel import Schema
+from felis.db.utils import DatabaseContext
+from felis.metadata import MetaDataBuilder
+from felis.tap_schema import DataLoader, TableManager
+from felis.tests.utils import open_test_file
+
+try:
+    from testing.postgresql import Postgresql
+except ImportError:
+    Postgresql = None
+
+TESTDIR = os.path.abspath(os.path.dirname(__file__))
+TEST_YAML = os.path.join(TESTDIR, "data", "sales.yaml")
+
+
+class TestTapSchemaPostgresql(unittest.TestCase):
+    """Test TAP_SCHEMA for PostgreSQL."""
+
+    def setUp(self) -> None:
+        """Set up a local PostgreSQL database and a test schema."""
+        # Skip the test if the testing.postgresql package is not installed.
+        if not Postgresql:
+            self.skipTest("testing.postgresql not installed")
+
+        # Start a PostgreSQL database for testing.
+        self.postgresql = Postgresql()
+        url = self.postgresql.url()
+        self.engine = create_engine(url)
+
+        # Set up a test schema.
+        self.test_schema = Schema.from_uri(TEST_YAML)
+
+    def test_create_metadata(self) -> None:
+        """Test loading of data into a PostgreSQL TAP_SCHEMA database created
+        by the `~felis.tap_schema.TableManager`.
+        """
+        # Create the manager before the try block so that the cleanup code in
+        # the finally clause always has valid metadata to drop.
+        mgr = TableManager()
+        try:
+            # Create the TAP_SCHEMA database.
+            mgr.initialize_database(self.engine)
+
+            # Load the test data into the database.
+            loader = DataLoader(self.test_schema, mgr, self.engine, 1)
+            loader.load()
+        finally:
+            # Drop the schema.
+            DatabaseContext(metadata=mgr.metadata, engine=self.engine).drop()
+
+    def test_reflect_database(self) -> None:
+        """Test reflecting an existing PostgreSQL TAP_SCHEMA database into a
+        `~felis.tap_schema.TableManager`.
+        """
+        mgr: TableManager | None = None
+        try:
+            # Build the TAP_SCHEMA database independently of the TableManager.
+            schema = TableManager.load_schema_resource()
+            md = MetaDataBuilder(schema).build()
+            with self.engine.connect() as conn:
+                trans = conn.begin()
+                try:
+                    print(f"Creating schema '{schema.name}'")
+                    conn.execute(CreateSchema(schema.name, if_not_exists=False))
+                    trans.commit()
+                except SQLAlchemyError as e:
+                    trans.rollback()
+                    self.fail(f"Failed to create schema: {e}")
+            try:
+                print(f"Creating tables in schema: {md.schema}")
+                md.create_all(self.engine)
+            except SQLAlchemyError as e:
+                self.fail(f"Failed to create database: {e}")
+
+            # Reflect the existing database into a TableManager.
+            mgr = TableManager(engine=self.engine)
+            self.assertIsNotNone(mgr.metadata)
+            self.assertGreater(len(mgr.metadata.tables), 0)
+            table_names = set(
+                [table_name.replace(f"{schema.name}.", "") for table_name in mgr.metadata.tables.keys()]
+            )
+            self.assertEqual(table_names, set(TableManager.get_table_names_std()))
+
+            # See if test data can be loaded successfully using the existing
+            # database.
+            loader = DataLoader(self.test_schema, mgr, self.engine, 1)
+            loader.load()
+        finally:
+            # Drop the schema, but only if the manager was actually created.
+            if mgr is not None:
+                DatabaseContext(metadata=mgr.metadata, engine=self.engine).drop()
+
+    def test_nonstandard_names(self) -> None:
+        """Test the TAP table manager class with non-standard names for the
+        schema and columns, which are present in the test YAML file used
+        to create the TAP_SCHEMA database.
+        """
+        ctx: DatabaseContext | None = None
+        try:
+            with open_test_file("test_tap_schema_nonstandard.yaml") as file:
+                sch = Schema.from_stream(file, context={"id_generation": True})
+            md = MetaDataBuilder(sch).build()
+            ctx = DatabaseContext(md, self.engine)
+            ctx.initialize()
+            ctx.create_all()
+
+            postfix = "11"
+            mgr = TableManager(engine=self.engine, table_name_postfix=postfix, schema_name=sch.name)
+            for table_name in mgr.get_table_names_std():
+                table = mgr[table_name]
+                self.assertEqual(table.name, f"{table_name}{postfix}".replace(f"{sch.name}", ""))
+        finally:
+            # Guard against ctx being unset if schema loading failed above.
+            if ctx:
+                ctx.drop()
+
+    def test_bad_engine(self) -> None:
+        """Test the TableManager class with an invalid engine."""
+        bad_engine = create_engine("postgresql+psycopg2://fake_user:fake_password@fake_host:5555")
+        with self.assertRaises(SQLAlchemyError):
+            TableManager(engine=bad_engine)
+
+    def tearDown(self) -> None:
+        """Tear down the test case."""
+        gc.collect()
+        self.engine.dispose()
diff --git a/ups/felis.table b/ups/felis.table
index 32e27fb7..e7e94bd1 100644
--- a/ups/felis.table
+++ b/ups/felis.table
@@ -1,5 +1,6 @@
 setupRequired(sconsUtils)
 setupRequired(utils)
+setupRequired(resources)
 
 envPrepend(PATH, ${PRODUCT_DIR}/bin)
 envPrepend(PYTHONPATH, ${PRODUCT_DIR}/python)