diff --git a/README.md b/README.md index a21afd9..ffe8dbd 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ pg-schema-diff plan --dsn "postgres://postgres:postgres@localhost:5432/postgres" - Partitions - Functions/Triggers (functions created by extensions are ignored) - Sequences +- Extensions *A comprehensive set of features to ensure the safety of planned migrations:* - Dangerous operations are flagged as hazards and must be approved before a migration can be applied. @@ -142,5 +143,3 @@ an object, it will be treated as a drop and an add # Contributing This project is in its early stages. We appreciate all the feature/bug requests we receive, but we have limited cycles to review direct code contributions at this time. See [Contributing](CONTRIBUTING.md) to learn more. - - diff --git a/internal/migration_acceptance_tests/acceptance_test.go b/internal/migration_acceptance_tests/acceptance_test.go index caaa88a..3329f16 100644 --- a/internal/migration_acceptance_tests/acceptance_test.go +++ b/internal/migration_acceptance_tests/acceptance_test.go @@ -82,12 +82,6 @@ func (suite *acceptanceTestSuite) runTestCases(acceptanceTestCases []acceptanceT } func (suite *acceptanceTestSuite) runSubtest(tc acceptanceTestCase, expects expectations, planOpts []diff.PlanOpt) { - // onDbInitQueries will be run on both the old database before the migration and the new database before pg_dump - onDbInitQueries := []string{ - // Enable an extension to enforce that diffing works with extensions enabled - `CREATE EXTENSION amcheck;`, - } - // normalize the subtest if expects.outputState == nil { expects.outputState = tc.newSchemaDDL @@ -98,7 +92,7 @@ func (suite *acceptanceTestSuite) runSubtest(tc acceptanceTestCase, expects expe suite.Require().NoError(err) defer oldDb.DropDB() // Apply the old schema - suite.Require().NoError(applyDDL(oldDb, append(onDbInitQueries, tc.oldSchemaDDL...))) + suite.Require().NoError(applyDDL(oldDb, tc.oldSchemaDDL)) // Migrate the old DB oldDBConnPool, err := sql.Open("pgx", oldDb.GetDSN()) @@ -144,7 +138,7 @@ func (suite *acceptanceTestSuite) runSubtest(tc acceptanceTestCase, expects expe oldDbDump, err := pgdump.GetDump(oldDb, pgdump.WithSchemaOnly()) suite.Require().NoError(err) - newDbDump := suite.directlyRunDDLAndGetDump(append(onDbInitQueries, expects.outputState...)) + newDbDump := suite.directlyRunDDLAndGetDump(expects.outputState) suite.Equal(newDbDump, oldDbDump, prettySprintPlan(plan)) // Make sure no diff is found if we try to regenerate a plan diff --git a/internal/migration_acceptance_tests/extensions_cases_test.go b/internal/migration_acceptance_tests/extensions_cases_test.go new file mode 100644 index 0000000..f7998ac --- /dev/null +++ b/internal/migration_acceptance_tests/extensions_cases_test.go @@ -0,0 +1,72 @@ +package migration_acceptance_tests + +import "github.com/stripe/pg-schema-diff/pkg/diff" + +var extensionAcceptanceTestCases = []acceptanceTestCase{ + { + name: "no-op", + oldSchemaDDL: []string{ + ` + CREATE EXTENSION pg_trgm; + CREATE EXTENSION amcheck; + `, + }, + newSchemaDDL: []string{ + ` + CREATE EXTENSION pg_trgm; + CREATE EXTENSION amcheck; + `, + }, + vanillaExpectations: expectations{ + empty: true, + }, + dataPackingExpectations: expectations{ + empty: true, + }, + }, + { + name: "create multiple extensions", + oldSchemaDDL: []string{}, + newSchemaDDL: []string{ + ` + CREATE EXTENSION pg_trgm; + CREATE EXTENSION amcheck; + `, + }, + }, + { + name: "drop one extension", + oldSchemaDDL: []string{ + ` + CREATE EXTENSION pg_trgm; + CREATE EXTENSION amcheck; + `, + }, + newSchemaDDL: []string{ + ` + CREATE EXTENSION pg_trgm; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, + }, + { + name: "upgrade an extension implicitly and explicitly", + oldSchemaDDL: []string{ + ` + CREATE EXTENSION pg_trgm WITH VERSION '1.5'; + CREATE EXTENSION amcheck WITH VERSION '1.3'; + `, + }, + newSchemaDDL: []string{ + ` + CREATE EXTENSION pg_trgm WITH VERSION '1.6'; + CREATE EXTENSION AMCHECK; + `, + }, + expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeExtensionVersionUpgrade}, + }, +} + +func (suite *acceptanceTestSuite) TestExtensionAcceptanceTestCases() { + suite.runTestCases(extensionAcceptanceTestCases) +} diff --git a/internal/migration_acceptance_tests/function_cases_test.go b/internal/migration_acceptance_tests/function_cases_test.go index 504b2b1..9c4bfaf 100644 --- a/internal/migration_acceptance_tests/function_cases_test.go +++ b/internal/migration_acceptance_tests/function_cases_test.go @@ -125,6 +125,25 @@ var functionAcceptanceTestCases = []acceptanceTestCase{ `}, expectedHazardTypes: []diff.MigrationHazardType{diff.MigrationHazardTypeHasUntrackableDependencies}, }, + { + name: "Create function with an extensinon that also creates functions installed", + oldSchemaDDL: []string{ + ` + CREATE EXTENSION amcheck; + `, + }, + newSchemaDDL: []string{ + ` + CREATE EXTENSION amcheck; + + CREATE FUNCTION add(a integer, b integer) RETURNS integer + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT + RETURN a + b; + `, + }, + }, { name: "Drop functions (with conflicting names)", oldSchemaDDL: []string{ diff --git a/internal/migration_acceptance_tests/schema_cases_test.go b/internal/migration_acceptance_tests/schema_cases_test.go index 0570d55..63d0203 100644 --- a/internal/migration_acceptance_tests/schema_cases_test.go +++ b/internal/migration_acceptance_tests/schema_cases_test.go @@ -10,6 +10,8 @@ var schemaAcceptanceTests = []acceptanceTestCase{ name: "No-op", oldSchemaDDL: []string{ ` + CREATE EXTENSION amcheck; + CREATE TABLE fizz( ); @@ -71,6 +73,8 @@ var schemaAcceptanceTests = []acceptanceTestCase{ }, newSchemaDDL: []string{ ` + CREATE EXTENSION amcheck; + CREATE TABLE fizz( ); @@ -138,9 +142,11 @@ var schemaAcceptanceTests = []acceptanceTestCase{ }, }, { - name: "Drop table, Add Table, Drop seq, Add Seq, Drop Funcs, Add Funcs, Drop Triggers, Add Triggers", + name: "Drop table, Add Table, Drop Seq, Add Seq, Drop Funcs, Add Funcs, Drop Triggers, Add Triggers, Create Extension, Drop Extension, Create Index Using Extension", oldSchemaDDL: []string{ ` + CREATE EXTENSION amcheck; + CREATE TABLE fizz( ); @@ -206,6 +212,8 @@ var schemaAcceptanceTests = []acceptanceTestCase{ }, newSchemaDDL: []string{ ` + CREATE EXTENSION pg_trgm; + CREATE TABLE fizz( ); @@ -262,12 +270,14 @@ var schemaAcceptanceTests = []acceptanceTestCase{ foo INT, bar DOUBLE PRECISION NOT NULL DEFAULT 8.8, fizz timestamptz DEFAULT CURRENT_TIMESTAMP, - buzz REAL NOT NULL + buzz REAL NOT NULL, + quux TEXT ); ALTER TABLE bar ADD CONSTRAINT "FOO_CHECK" CHECK ( foo < bar ); CREATE INDEX bar_normal_idx ON bar(bar); CREATE INDEX bar_another_normal_id ON bar(bar DESC, fizz DESC); - CREATE UNIQUE INDEX bar_unique_idx on bar(fizz, buzz); + CREATE UNIQUE INDEX bar_unique_idx ON bar(fizz, buzz); + CREATE INDEX gin_index ON bar USING gin (quux gin_trgm_ops); CREATE FUNCTION check_content() RETURNS TRIGGER AS $$ BEGIN @@ -286,15 +296,18 @@ var schemaAcceptanceTests = []acceptanceTestCase{ expectedHazardTypes: []diff.MigrationHazardType{ diff.MigrationHazardTypeDeletesData, diff.MigrationHazardTypeHasUntrackableDependencies, + diff.MigrationHazardTypeIndexBuild, }, dataPackingExpectations: expectations{ outputState: []string{ ` + CREATE EXTENSION pg_trgm; + CREATE TABLE fizz( ); CREATE SEQUENCE new_foobar_sequence - AS SMALLINT + AS SMALLINT INCREMENT BY 4 MINVALUE 10 MAXVALUE 200 START WITH 20 CACHE 10 NO CYCLE @@ -346,12 +359,14 @@ var schemaAcceptanceTests = []acceptanceTestCase{ foo INT, bar DOUBLE PRECISION NOT NULL DEFAULT 8.8, fizz timestamptz DEFAULT CURRENT_TIMESTAMP, - buzz REAL NOT NULL + buzz REAL NOT NULL, + quux TEXT ); ALTER TABLE bar ADD CONSTRAINT "FOO_CHECK" CHECK ( foo < bar ); CREATE INDEX bar_normal_idx ON bar(bar); CREATE INDEX bar_another_normal_id ON bar(bar DESC, fizz DESC); CREATE UNIQUE INDEX bar_unique_idx on bar(fizz, buzz); + CREATE INDEX gin_index ON bar USING gin (quux gin_trgm_ops); CREATE FUNCTION check_content() RETURNS TRIGGER AS $$ BEGIN diff --git a/internal/queries/queries.sql b/internal/queries/queries.sql index 96e7a72..15f2523 100644 --- a/internal/queries/queries.sql +++ b/internal/queries/queries.sql @@ -226,3 +226,15 @@ AND NOT EXISTS ( AND ext_depend.objid = pg_seq.seqrelid AND ext_depend.deptype = 'e' ); + +-- name: GetExtensions :many +SELECT + ext.oid, + ext.extname::TEXT AS extension_name, + ext.extversion AS extension_version, + extension_namespace.nspname::TEXT AS schema_name +FROM pg_catalog.pg_namespace AS extension_namespace +INNER JOIN + pg_catalog.pg_extension AS ext + ON ext.extnamespace = extension_namespace.oid +WHERE extension_namespace.nspname = 'public'; diff --git a/internal/queries/queries.sql.go b/internal/queries/queries.sql.go index a2adac6..d0aee76 100644 --- a/internal/queries/queries.sql.go +++ b/internal/queries/queries.sql.go @@ -221,6 +221,54 @@ func (q *Queries) GetDependsOnFunctions(ctx context.Context, arg GetDependsOnFun return items, nil } +const getExtensions = `-- name: GetExtensions :many +SELECT + ext.oid, + ext.extname::TEXT AS extension_name, + ext.extversion AS extension_version, + extension_namespace.nspname::TEXT AS schema_name +FROM pg_catalog.pg_namespace AS extension_namespace +INNER JOIN + pg_catalog.pg_extension AS ext + ON ext.extnamespace = extension_namespace.oid +WHERE extension_namespace.nspname = 'public' +` + +type GetExtensionsRow struct { + Oid interface{} + ExtensionName string + ExtensionVersion string + SchemaName string +} + +func (q *Queries) GetExtensions(ctx context.Context) ([]GetExtensionsRow, error) { + rows, err := q.db.QueryContext(ctx, getExtensions) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetExtensionsRow + for rows.Next() { + var i GetExtensionsRow + if err := rows.Scan( + &i.Oid, + &i.ExtensionName, + &i.ExtensionVersion, + &i.SchemaName, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getFunctions = `-- name: GetFunctions :many SELECT pg_proc.oid, diff --git a/internal/schema/schema.go b/internal/schema/schema.go index 2d7f636..df82b1e 100644 --- a/internal/schema/schema.go +++ b/internal/schema/schema.go @@ -44,8 +44,9 @@ func (o SchemaQualifiedName) IsEmpty() bool { } type Schema struct { - Tables []Table - Indexes []Index + Extensions []Extension + Tables []Table + Indexes []Index Sequences []Sequence Functions []Function @@ -70,8 +71,8 @@ func (s Schema) Normalize() Schema { s.Tables = normTables s.Indexes = sortSchemaObjectsByName(s.Indexes) - s.Sequences = sortSchemaObjectsByName(s.Sequences) + s.Extensions = sortSchemaObjectsByName(s.Extensions) var normFunctions []Function for _, function := range sortSchemaObjectsByName(s.Functions) { @@ -104,6 +105,11 @@ func (s Schema) Hash() (string, error) { return fmt.Sprintf("%x", hashVal), nil } +type Extension struct { + SchemaQualifiedName + Version string +} + type Table struct { Name string Columns []Column @@ -283,10 +289,15 @@ func (t Trigger) GetName() string { return t.OwningTable.GetFQEscapedName() + "_" + t.EscapedName } -// GetPublicSchema fetches the "public" schema. It is a non-atomic operation +// GetPublicSchema fetches the "public" schema. It is a non-atomic operation. func GetPublicSchema(ctx context.Context, db queries.DBTX) (Schema, error) { q := queries.New(db) + extensions, err := fetchExtensions(ctx, q) + if err != nil { + return Schema{}, fmt.Errorf("fetchExtensions: %w", err) + } + tables, err := fetchTables(ctx, q) if err != nil { return Schema{}, fmt.Errorf("fetchTables: %w", err) @@ -313,14 +324,34 @@ func GetPublicSchema(ctx context.Context, db queries.DBTX) (Schema, error) { } return Schema{ - Tables: tables, - Indexes: indexes, - Sequences: sequences, - Functions: functions, - Triggers: triggers, + Extensions: extensions, + Tables: tables, + Indexes: indexes, + Sequences: sequences, + Functions: functions, + Triggers: triggers, }, nil } +func fetchExtensions(ctx context.Context, q *queries.Queries) ([]Extension, error) { + rawExtensions, err := q.GetExtensions(ctx) + if err != nil { + return nil, fmt.Errorf("GetExtensions(): %w", err) + } + + var extensions []Extension + for _, e := range rawExtensions { + extensions = append(extensions, Extension{ + SchemaQualifiedName: SchemaQualifiedName{ + EscapedName: EscapeIdentifier(e.ExtensionName), + SchemaName: e.SchemaName, + }, + Version: e.ExtensionVersion, + }) + } + return extensions, nil +} + func fetchTables(ctx context.Context, q *queries.Queries) ([]Table, error) { rawTables, err := q.GetTables(ctx) if err != nil { diff --git a/internal/schema/schema_test.go b/internal/schema/schema_test.go index 3b9141a..76d62fe 100644 --- a/internal/schema/schema_test.go +++ b/internal/schema/schema_test.go @@ -35,6 +35,8 @@ var ( { name: "Simple test", ddl: []string{` + CREATE EXTENSION pg_trgm WITH VERSION '1.6'; + CREATE SEQUENCE foobar_sequence AS BIGINT INCREMENT BY 2 @@ -72,6 +74,7 @@ var ( ALTER TABLE foo ADD CONSTRAINT author_check CHECK (author IS NOT NULL AND LENGTH(author) > 0) NO INHERIT NOT VALID; CREATE INDEX some_idx ON foo USING hash (content); CREATE UNIQUE INDEX some_unique_idx ON foo (created_at DESC, author ASC); + CREATE INDEX some_gin_idx ON foo USING GIN (author gin_trgm_ops); CREATE FUNCTION increment_version() RETURNS TRIGGER AS $$ BEGIN @@ -86,8 +89,17 @@ var ( WHEN (OLD.* IS DISTINCT FROM NEW.*) EXECUTE PROCEDURE increment_version(); `}, - expectedHash: "bfe91c68ad532e0b", + expectedHash: "59c2e1ea7460fbbb", expectedSchema: schema.Schema{ + Extensions: []schema.Extension{ + { + SchemaQualifiedName: schema.SchemaQualifiedName{ + EscapedName: schema.EscapeIdentifier("pg_trgm"), + SchemaName: "public", + }, + Version: "1.6", + }, + }, Tables: []schema.Table{ { Name: "foo", @@ -142,8 +154,12 @@ var ( }, Indexes: []schema.Index{ { - TableName: "foo", - Name: "foo_pkey", Columns: []string{"id"}, IsPk: true, IsUnique: true, ConstraintName: "foo_pkey", + TableName: "foo", + Name: "foo_pkey", + Columns: []string{"id"}, + IsPk: true, + IsUnique: true, + ConstraintName: "foo_pkey", GetIndexDefStmt: "CREATE UNIQUE INDEX foo_pkey ON public.foo USING btree (id)", }, { @@ -152,10 +168,18 @@ var ( GetIndexDefStmt: "CREATE INDEX some_idx ON public.foo USING hash (content)", }, { - TableName: "foo", - Name: "some_unique_idx", Columns: []string{"created_at", "author"}, IsPk: false, IsUnique: true, + TableName: "foo", + Name: "some_unique_idx", + Columns: []string{"created_at", "author"}, + IsUnique: true, GetIndexDefStmt: "CREATE UNIQUE INDEX some_unique_idx ON public.foo USING btree (created_at DESC, author)", }, + { + TableName: "foo", + Name: "some_gin_idx", + Columns: []string{"author"}, + GetIndexDefStmt: "CREATE INDEX some_gin_idx ON public.foo USING gin (author gin_trgm_ops)", + }, }, Functions: []schema.Function{ { @@ -243,7 +267,7 @@ var ( EXECUTE PROCEDURE increment_version(); `}, - expectedHash: "fcc62a48df881935", + expectedHash: "58b60e7d949ba226", expectedSchema: schema.Schema{ Tables: []schema.Table{ { @@ -441,7 +465,7 @@ var ( PRIMARY KEY (author, id) ) FOR VALUES IN ('some author 1'); `}, - expectedHash: "2019e87411d440c4", + expectedHash: "6b678f2eae7b824d", expectedSchema: schema.Schema{ Tables: []schema.Table{ { @@ -490,7 +514,7 @@ var ( "serial" SERIAL NOT NULL ); `}, - expectedHash: "d9ce32695f0154de", + expectedHash: "fe72d1b3a50d54b9", expectedSchema: schema.Schema{ Tables: []schema.Table{ { @@ -550,7 +574,7 @@ var ( ALTER TABLE foobar ADD CONSTRAINT foobar_id_check CHECK (id > 0) NOT VALID; CREATE UNIQUE INDEX foobar_idx ON foobar(content); `}, - expectedHash: "4bb4902b92bd1baf", + expectedHash: "ef43a41b2ac96e18", expectedSchema: schema.Schema{ Tables: []schema.Table{ { @@ -698,7 +722,7 @@ var ( WHEN (OLD.* IS DISTINCT FROM NEW.*) EXECUTE PROCEDURE test.increment_version(); `}, - expectedHash: "41ccaa5beac7cde0", + expectedHash: "72c5e264fb96ed86", expectedSchema: schema.Schema{ Tables: []schema.Table{ { @@ -790,7 +814,7 @@ var ( { name: "Empty Schema", ddl: nil, - expectedHash: "651bc0b5adc6120e", + expectedHash: "4cc7f4f2dd81ec29", expectedSchema: schema.Schema{ Tables: nil, }, @@ -802,7 +826,7 @@ var ( value TEXT ); `}, - expectedHash: "9116df7b20ebce8b", + expectedHash: "24b2cd8d56d1cedd", expectedSchema: schema.Schema{ Tables: []schema.Table{ { diff --git a/pkg/diff/plan.go b/pkg/diff/plan.go index abe1f13..77025fc 100644 --- a/pkg/diff/plan.go +++ b/pkg/diff/plan.go @@ -17,6 +17,7 @@ const ( MigrationHazardTypeIndexDropped MigrationHazardType = "INDEX_DROPPED" MigrationHazardTypeImpactsDatabasePerformance MigrationHazardType = "IMPACTS_DATABASE_PERFORMANCE" MigrationHazardTypeIsUserGenerated MigrationHazardType = "IS_USER_GENERATED" + MigrationHazardTypeExtensionVersionUpgrade MigrationHazardType = "UPGRADING_EXTENSION_VERSION" ) // MigrationHazard represents a hazard that a statement poses to a database diff --git a/pkg/diff/sql_generator.go b/pkg/diff/sql_generator.go index 750156a..d7e867d 100644 --- a/pkg/diff/sql_generator.go +++ b/pkg/diff/sql_generator.go @@ -51,7 +51,15 @@ var ( } migrationHazardSequenceCannotTrackDependencies = MigrationHazard{ Type: MigrationHazardTypeHasUntrackableDependencies, - Message: "sequence has no owner, so it can't be tracked. It may be in use by a table or function", + Message: "This sequence has no owner, so it cannot be tracked. It may be in use by a table or function.", + } + migrationHazardExtensionDroppedCannotTrackDependencies = MigrationHazard{ + Type: MigrationHazardTypeHasUntrackableDependencies, + Message: "This extension may be in use by tables, indexes, functions, triggers, etc. Tihs statement will be ran last, so this may be OK.", + } + migrationHazardExtensionAlteredVersionUpgraded = MigrationHazard{ + Type: MigrationHazardTypeExtensionVersionUpgrade, + Message: "This extension's version is being upgraded. Be sure the newer version is backwards compatible with your use case.", } ) @@ -100,15 +108,20 @@ type ( triggerDiff struct { oldAndNew[schema.Trigger] } + + extensionDiff struct { + oldAndNew[schema.Extension] + } ) type schemaDiff struct { oldAndNew[schema.Schema] - tableDiffs listDiff[schema.Table, tableDiff] - indexDiffs listDiff[schema.Index, indexDiff] - sequenceDiffs listDiff[schema.Sequence, sequenceDiff] - functionDiffs listDiff[schema.Function, functionDiff] - triggerDiffs listDiff[schema.Trigger, triggerDiff] + extensionDiffs listDiff[schema.Extension, extensionDiff] + tableDiffs listDiff[schema.Table, tableDiff] + indexDiffs listDiff[schema.Index, indexDiff] + sequenceDiffs listDiff[schema.Sequence, sequenceDiff] + functionDiffs listDiff[schema.Function, functionDiff] + triggerDiffs listDiff[schema.Trigger, triggerDiff] } func (sd schemaDiff) resolveToSQL() ([]Statement, error) { @@ -146,6 +159,21 @@ func (sd schemaDiff) resolveToSQL() ([]Statement, error) { // on other schema objects func buildSchemaDiff(old, new schema.Schema) (schemaDiff, bool, error) { + extensionDiffs, err := diffLists( + old.Extensions, + new.Extensions, + func(old, new schema.Extension, _, _ int) (extensionDiff, bool, error) { + return extensionDiff{ + oldAndNew[schema.Extension]{ + old: old, + new: new, + }, + }, false, nil + }) + if err != nil { + return schemaDiff{}, false, fmt.Errorf("diffing extensions: %w", err) + } + tableDiffs, err := diffLists(old.Tables, new.Tables, buildTableDiff) if err != nil { return schemaDiff{}, false, fmt.Errorf("diffing tables: %w", err) @@ -215,11 +243,12 @@ func buildSchemaDiff(old, new schema.Schema) (schemaDiff, bool, error) { old: old, new: new, }, - tableDiffs: tableDiffs, - indexDiffs: indexesDiff, - sequenceDiffs: sequencesDiffs, - functionDiffs: functionDiffs, - triggerDiffs: triggerDiffs, + extensionDiffs: extensionDiffs, + tableDiffs: tableDiffs, + indexDiffs: indexesDiff, + sequenceDiffs: sequencesDiffs, + functionDiffs: functionDiffs, + triggerDiffs: triggerDiffs, }, false, nil } @@ -350,6 +379,11 @@ func (schemaSQLGenerator) Alter(diff schemaDiff) ([]Statement, error) { return nil, fmt.Errorf("resolving table sql graphs: %w", err) } + extensionStatements, err := diff.extensionDiffs.resolveToSQLGroupedByEffect(&extensionSQLGenerator{}) + if err != nil { + return nil, fmt.Errorf("resolving extension sql graphs: %w", err) + } + indexesInNewSchemaByTableName := make(map[string][]schema.Index) for _, idx := range diff.new.Indexes { indexesInNewSchemaByTableName[idx.TableName] = append(indexesInNewSchemaByTableName[idx.TableName], idx) @@ -394,18 +428,16 @@ func (schemaSQLGenerator) Alter(diff schemaDiff) ([]Statement, error) { functionsInNewSchemaByName := buildSchemaObjByNameMap(diff.new.Functions) - functionSQLVertexGenerator := functionSQLVertexGenerator{ + functionGraphs, err := diff.functionDiffs.resolveToSQLGraph(&functionSQLVertexGenerator{ functionsInNewSchemaByName: functionsInNewSchemaByName, - } - functionGraphs, err := diff.functionDiffs.resolveToSQLGraph(&functionSQLVertexGenerator) + }) if err != nil { return nil, fmt.Errorf("resolving function sql graphs: %w", err) } - triggerSQLVertexGenerator := triggerSQLVertexGenerator{ + triggerGraphs, err := diff.triggerDiffs.resolveToSQLGraph(&triggerSQLVertexGenerator{ functionsInNewSchemaByName: functionsInNewSchemaByName, - } - triggerGraphs, err := diff.triggerDiffs.resolveToSQLGraph(&triggerSQLVertexGenerator) + }) if err != nil { return nil, fmt.Errorf("resolving trigger sql graphs: %w", err) } @@ -432,7 +464,19 @@ func (schemaSQLGenerator) Alter(diff schemaDiff) ([]Statement, error) { return nil, fmt.Errorf("unioning table and trigger graphs: %w", err) } - return tableGraphs.toOrderedStatements() + graphStatements, err := tableGraphs.toOrderedStatements() + if err != nil { + return nil, fmt.Errorf("getting ordered statements from tableGraph: %w", err) + } + + // We enable extensions first and disable them last since their dependencies may span across + // all other entities in the database. + var statements []Statement + statements = append(statements, extensionStatements.Adds...) + statements = append(statements, extensionStatements.Alters...) + statements = append(statements, graphStatements...) + statements = append(statements, extensionStatements.Deletes...) + return statements, nil } func buildSchemaObjByNameMap[S schema.Object](s []S) map[string]S { @@ -552,7 +596,7 @@ func (t *tableSQLVertexGenerator) Alter(diff tableDiff) ([]Statement, error) { checkConSQLGenerator := checkConstraintSQLGenerator{tableName: diff.new.Name} checkConGeneratedSQL, err := diff.checkConstraintDiff.resolveToSQLGroupedByEffect(&checkConSQLGenerator) if err != nil { - return nil, fmt.Errorf("Resolving check constraints diff: %w", err) + return nil, fmt.Errorf("resolving check constraints diff: %w", err) } var stmts []Statement @@ -1461,6 +1505,61 @@ func (s sequenceOwnershipSQLVertexGenerator) GetDeleteDependencies(_ schema.Sequ return nil } +type extensionSQLGenerator struct{} + +func (e *extensionSQLGenerator) Add(extension schema.Extension) ([]Statement, error) { + s := fmt.Sprintf( + "CREATE EXTENSION %s WITH SCHEMA %s", + extension.EscapedName, + schema.EscapeIdentifier(extension.SchemaName), + ) + + if len(extension.Version) != 0 { + s += fmt.Sprintf(" VERSION %s", schema.EscapeIdentifier(extension.Version)) + } + + return []Statement{{ + DDL: s, + Timeout: statementTimeoutDefault, + Hazards: nil, + }}, nil +} + +func (e *extensionSQLGenerator) Delete(extension schema.Extension) ([]Statement, error) { + return []Statement{{ + DDL: fmt.Sprintf("DROP EXTENSION %s", extension.EscapedName), + Timeout: statementTimeoutDefault, + Hazards: []MigrationHazard{migrationHazardExtensionDroppedCannotTrackDependencies}, + }}, nil +} + +func (e *extensionSQLGenerator) Alter(diff extensionDiff) ([]Statement, error) { + var statements []Statement + if diff.new.Version != diff.old.Version { + if len(diff.new.Version) == 0 { + // This is an implicit upgrade to the latest extension version. + statements = append(statements, Statement{ + DDL: fmt.Sprintf("ALTER EXTENSION %s UPDATE", diff.new.EscapedName), + Timeout: statementTimeoutDefault, + Hazards: []MigrationHazard{migrationHazardExtensionAlteredVersionUpgraded}, + }) + } else { + // We optimistically assume an update path from the old to new version exists. When we + // validate the plan later, any issues will be caught and an error will be thrown. + statements = append(statements, Statement{ + DDL: fmt.Sprintf( + "ALTER EXTENSION %s UPDATE TO %s", + diff.new.EscapedName, + schema.EscapeIdentifier(diff.new.Version), + ), + Timeout: statementTimeoutDefault, + Hazards: []MigrationHazard{migrationHazardExtensionAlteredVersionUpgraded}, + }) + } + } + return statements, nil +} + type functionSQLVertexGenerator struct { // functionsInNewSchemaByName is a map of function new to functions in the new schema. // These functions are not necessarily new diff --git a/pkg/schema/schema_test.go b/pkg/schema/schema_test.go index 8b76bc0..8864fd0 100644 --- a/pkg/schema/schema_test.go +++ b/pkg/schema/schema_test.go @@ -28,6 +28,8 @@ func (suite *schemaTestSuite) TearDownSuite() { func (suite *schemaTestSuite) TestGetPublicSchemaHash() { const ( ddl = ` + CREATE EXTENSION pg_trgm WITH VERSION '1.6'; + CREATE FUNCTION add(a integer, b integer) RETURNS integer LANGUAGE SQL IMMUTABLE @@ -73,7 +75,7 @@ func (suite *schemaTestSuite) TestGetPublicSchemaHash() { EXECUTE PROCEDURE increment_version(); ` - expectedHash = "5fc27d73cebea55" + expectedHash = "7c9c30dde1b65875" ) db, err := suite.pgEngine.CreateDatabase() suite.Require().NoError(err)