diff --git a/pkg/diff/policy_sql_generator.go b/pkg/diff/policy_sql_generator.go index 00fd5fe..f09878f 100644 --- a/pkg/diff/policy_sql_generator.go +++ b/pkg/diff/policy_sql_generator.go @@ -262,17 +262,17 @@ func (psg *policySQLVertexGenerator) Alter(diff policyDiff) ([]Statement, error) }}, nil } -func (psg *policySQLVertexGenerator) GetSQLVertexId(p schema.Policy) string { - return buildPolicyVertexId(psg.table.SchemaQualifiedName, p.EscapedName) +func (psg *policySQLVertexGenerator) GetSQLVertexId(p schema.Policy, diffType diffType) sqlVertexId { + return buildPolicyVertexId(psg.table.SchemaQualifiedName, p.EscapedName, diffType) } -func buildPolicyVertexId(owningTable schema.SchemaQualifiedName, policyEscapedName string) string { - return buildVertexId("policy", fmt.Sprintf("%s.%s", owningTable.GetFQEscapedName(), policyEscapedName)) +func buildPolicyVertexId(owningTable schema.SchemaQualifiedName, policyEscapedName string, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("policy", fmt.Sprintf("%s.%s", owningTable.GetFQEscapedName(), policyEscapedName), diffType) } func (psg *policySQLVertexGenerator) GetAddAlterDependencies(newPolicy, oldPolicy schema.Policy) ([]dependency, error) { deps := []dependency{ - mustRun(psg.GetSQLVertexId(newPolicy), diffTypeDelete).beforeSchemaObj(psg.GetSQLVertexId(newPolicy), diffTypeAddAlter), + mustRun(psg.GetSQLVertexId(newPolicy, diffTypeDelete)).before(psg.GetSQLVertexId(newPolicy, diffTypeAddAlter)), } newTargetColumns, err := getTargetColumns(newPolicy.Columns, psg.newSchemaColumnsByName) @@ -282,7 +282,7 @@ func (psg *policySQLVertexGenerator) GetAddAlterDependencies(newPolicy, oldPolic // Run afterSchemaObj the new columns are added/altered for _, tc := range newTargetColumns { - deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy), diffTypeAddAlter).afterSchemaObj(buildColumnVertexId(tc.Name), diffTypeAddAlter)) + deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy, diffTypeAddAlter)).after(buildColumnVertexId(tc.Name, diffTypeAddAlter))) } if !cmp.Equal(oldPolicy, schema.Policy{}) { @@ -294,7 +294,7 @@ func (psg *policySQLVertexGenerator) GetAddAlterDependencies(newPolicy, oldPolic for _, tc := range oldTargetColumns { // It only needs to run beforeSchemaObj the delete if the column is actually being deleted if _, stillExists := psg.newSchemaColumnsByName[tc.GetName()]; !stillExists { - deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy), diffTypeAddAlter).beforeSchemaObj(buildColumnVertexId(tc.Name), diffTypeDelete)) + deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy, diffTypeAddAlter)).before(buildColumnVertexId(tc.Name, diffTypeDelete))) } } } @@ -311,8 +311,8 @@ func (psg *policySQLVertexGenerator) GetDeleteDependencies(pol schema.Policy) ([ } // The policy needs to be deleted beforeSchemaObj all the columns it references are deleted or add/altered for _, c := range columns { - deps = append(deps, mustRun(psg.GetSQLVertexId(pol), diffTypeDelete).beforeSchemaObj(buildColumnVertexId(c.Name), diffTypeDelete)) - deps = append(deps, mustRun(psg.GetSQLVertexId(pol), diffTypeDelete).beforeSchemaObj(buildColumnVertexId(c.Name), diffTypeAddAlter)) + deps = append(deps, mustRun(psg.GetSQLVertexId(pol, diffTypeDelete)).before(buildColumnVertexId(c.Name, diffTypeDelete))) + deps = append(deps, mustRun(psg.GetSQLVertexId(pol, diffTypeDelete)).before(buildColumnVertexId(c.Name, diffTypeAddAlter))) } return deps, nil diff --git a/pkg/diff/sql_generator.go b/pkg/diff/sql_generator.go index e1547ae..2100a12 100644 --- a/pkg/diff/sql_generator.go +++ b/pkg/diff/sql_generator.go @@ -9,6 +9,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/stripe/pg-schema-diff/internal/graph" "github.com/stripe/pg-schema-diff/internal/pgidentifier" "github.com/stripe/pg-schema-diff/internal/schema" ) @@ -598,11 +599,18 @@ func (schemaSQLGenerator) Alter(diff schemaDiff) ([]Statement, error) { return nil, fmt.Errorf("resolving trigger diff: %w", err) } partialGraph = concatPartialGraphs(partialGraph, triggersPartialGraph) - graph, err := graphFromPartials(partialGraph) + sqlGraph, err := graphFromPartials(partialGraph) if err != nil { return nil, fmt.Errorf("converting to graph: %w", err) } - graphStatements, err := graph.toOrderedStatements() + + var buf strings.Builder + if err := graph.EncodeDOT[sqlVertex](sqlGraph.Graph, &buf, true); err != nil { + panic(err) + } + fmt.Println(buf.String()) + + graphStatements, err := sqlGraph.toOrderedStatements() if err != nil { return nil, fmt.Errorf("getting ordered statements: %w", err) } @@ -1039,18 +1047,22 @@ func replicaIdentityAlterType(identity schema.ReplicaIdentity) (string, error) { return "", fmt.Errorf("unknown/unsupported replica identity %s: %w", identity, ErrNotImplemented) } -func (t *tableSQLVertexGenerator) GetSQLVertexId(table schema.Table) string { - return buildTableVertexId(table.SchemaQualifiedName) +func (t *tableSQLVertexGenerator) GetSQLVertexId(table schema.Table, diffType diffType) sqlVertexId { + return buildTableVertexId(table.SchemaQualifiedName, diffType) +} + +func buildTableVertexId(name schema.SchemaQualifiedName, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("table", name.GetFQEscapedName(), diffType) } func (t *tableSQLVertexGenerator) GetAddAlterDependencies(table, _ schema.Table) ([]dependency, error) { deps := []dependency{ - mustRun(t.GetSQLVertexId(table), diffTypeAddAlter).afterSchemaObj(t.GetSQLVertexId(table), diffTypeDelete), + mustRun(t.GetSQLVertexId(table, diffTypeAddAlter)).after(t.GetSQLVertexId(table, diffTypeDelete)), } if table.ParentTable != nil { deps = append(deps, - mustRun(t.GetSQLVertexId(table), diffTypeAddAlter).afterSchemaObj(buildTableVertexId(*table.ParentTable), diffTypeAddAlter), + mustRun(t.GetSQLVertexId(table, diffTypeAddAlter)).after(buildTableVertexId(*table.ParentTable, diffTypeAddAlter)), ) } return deps, nil @@ -1118,7 +1130,7 @@ func (t *tableSQLVertexGenerator) GetDeleteDependencies(table schema.Table) ([]d var deps []dependency if table.ParentTable != nil { deps = append(deps, - mustRun(t.GetSQLVertexId(table), diffTypeDelete).afterSchemaObj(buildTableVertexId(*table.ParentTable), diffTypeDelete), + mustRun(t.GetSQLVertexId(table, diffTypeDelete)).after(buildTableVertexId(*table.ParentTable, diffTypeDelete)), ) } return deps, nil @@ -1370,17 +1382,17 @@ func (csg *columnSQLVertexGenerator) alterColumnPrefix(col schema.Column) string return fmt.Sprintf("%s ALTER COLUMN %s", alterTablePrefix(csg.tableName), schema.EscapeIdentifier(col.Name)) } -func (csg *columnSQLVertexGenerator) GetSQLVertexId(column schema.Column) string { - return buildColumnVertexId(column.Name) +func (csg *columnSQLVertexGenerator) GetSQLVertexId(column schema.Column, diffType diffType) sqlVertexId { + return buildColumnVertexId(column.Name, diffType) } -func buildColumnVertexId(columnName string) string { - return buildVertexId("column", columnName) +func buildColumnVertexId(columnName string, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("column", columnName, diffType) } func (csg *columnSQLVertexGenerator) GetAddAlterDependencies(col, _ schema.Column) ([]dependency, error) { return []dependency{ - mustRun(csg.GetSQLVertexId(col), diffTypeDelete).beforeSchemaObj(csg.GetSQLVertexId(col), diffTypeAddAlter), + mustRun(csg.GetSQLVertexId(col, diffTypeDelete)).before(csg.GetSQLVertexId(col, diffTypeAddAlter)), }, nil } @@ -1474,8 +1486,8 @@ func (rsg *renameConflictingIndexSQLVertexGenerator) Alter(_ indexDiff) ([]State return nil, nil } -func (*renameConflictingIndexSQLVertexGenerator) GetSQLVertexId(index schema.Index) string { - return buildRenameConflictingIndexVertexId(index.GetSchemaQualifiedName()) +func (*renameConflictingIndexSQLVertexGenerator) GetSQLVertexId(index schema.Index, diffType diffType) sqlVertexId { + return buildRenameConflictingIndexVertexId(index.GetSchemaQualifiedName(), diffType) } func (rsg *renameConflictingIndexSQLVertexGenerator) GetAddAlterDependencies(_, _ schema.Index) ([]dependency, error) { @@ -1486,8 +1498,8 @@ func (rsg *renameConflictingIndexSQLVertexGenerator) GetDeleteDependencies(_ sch return nil, nil } -func buildRenameConflictingIndexVertexId(indexName schema.SchemaQualifiedName) string { - return buildVertexId("indexrename", indexName.GetName()) +func buildRenameConflictingIndexVertexId(indexName schema.SchemaQualifiedName, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("indexrename", indexName.GetName(), diffType) } type indexSQLVertexGenerator struct { @@ -1738,21 +1750,25 @@ func buildAttachIndex(index schema.Index) Statement { } } -func (*indexSQLVertexGenerator) GetSQLVertexId(index schema.Index) string { - return buildIndexVertexId(index.GetSchemaQualifiedName()) +func (*indexSQLVertexGenerator) GetSQLVertexId(index schema.Index, diffType diffType) sqlVertexId { + return buildIndexVertexId(index.GetSchemaQualifiedName(), diffType) +} + +func buildIndexVertexId(name schema.SchemaQualifiedName, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("index", name.GetFQEscapedName(), diffType) } func (isg *indexSQLVertexGenerator) GetAddAlterDependencies(index, _ schema.Index) ([]dependency, error) { dependencies := []dependency{ - mustRun(isg.GetSQLVertexId(index), diffTypeAddAlter).afterSchemaObj(buildTableVertexId(index.OwningTable), diffTypeAddAlter), + mustRun(isg.GetSQLVertexId(index, diffTypeAddAlter)).after(buildTableVertexId(index.OwningTable, diffTypeAddAlter)), // To allow for online changes to indexes, rename the older version of the index (if it exists) beforeSchemaObj the new version is added - mustRun(isg.GetSQLVertexId(index), diffTypeAddAlter).afterSchemaObj(buildRenameConflictingIndexVertexId(index.GetSchemaQualifiedName()), diffTypeAddAlter), + mustRun(isg.GetSQLVertexId(index, diffTypeAddAlter)).after(buildRenameConflictingIndexVertexId(index.GetSchemaQualifiedName(), diffTypeAddAlter)), } if index.ParentIdx != nil { // Partitions of indexes must be created afterSchemaObj the parent index is created dependencies = append(dependencies, - mustRun(isg.GetSQLVertexId(index), diffTypeAddAlter).afterSchemaObj(buildIndexVertexId(*index.ParentIdx), diffTypeAddAlter)) + mustRun(isg.GetSQLVertexId(index, diffTypeAddAlter)).after(buildIndexVertexId(*index.ParentIdx, diffTypeAddAlter))) } return dependencies, nil @@ -1760,16 +1776,16 @@ func (isg *indexSQLVertexGenerator) GetAddAlterDependencies(index, _ schema.Inde func (isg *indexSQLVertexGenerator) GetDeleteDependencies(index schema.Index) ([]dependency, error) { dependencies := []dependency{ - mustRun(isg.GetSQLVertexId(index), diffTypeDelete).afterSchemaObj(buildTableVertexId(index.OwningTable), diffTypeDelete), + mustRun(isg.GetSQLVertexId(index, diffTypeDelete)).after(buildTableVertexId(index.OwningTable, diffTypeDelete)), // Drop the index afterSchemaObj it has been potentially renamed - mustRun(isg.GetSQLVertexId(index), diffTypeDelete).afterSchemaObj(buildRenameConflictingIndexVertexId(index.GetSchemaQualifiedName()), diffTypeAddAlter), + mustRun(isg.GetSQLVertexId(index, diffTypeDelete)).after(buildRenameConflictingIndexVertexId(index.GetSchemaQualifiedName(), diffTypeAddAlter)), } if index.ParentIdx != nil { // Since dropping the parent index will cause the partition of the index to drop, the parent drop should come // beforeSchemaObj dependencies = append(dependencies, - mustRun(isg.GetSQLVertexId(index), diffTypeDelete).afterSchemaObj(buildIndexVertexId(*index.ParentIdx), diffTypeDelete)) + mustRun(isg.GetSQLVertexId(index, diffTypeDelete)).after(buildIndexVertexId(*index.ParentIdx, diffTypeDelete))) } dependencies = append(dependencies, isg.addDepsOnTableAddAlterIfNecessary(index)...) @@ -1787,14 +1803,14 @@ func (isg *indexSQLVertexGenerator) addDepsOnTableAddAlterIfNecessary(index sche // These dependencies will force the index deletion statement to come beforeSchemaObj the table AddAlter addAlterColumnDeps := []dependency{ - mustRun(isg.GetSQLVertexId(index), diffTypeDelete).beforeSchemaObj(buildTableVertexId(index.OwningTable), diffTypeAddAlter), + mustRun(isg.GetSQLVertexId(index, diffTypeDelete)).before(buildTableVertexId(index.OwningTable, diffTypeAddAlter)), } if parentTable.ParentTable != nil { // If the table is partitioned, columns modifications occur on the base table not the children. Thus, we // need the dependency to also be on the parent table add/alter statements addAlterColumnDeps = append( addAlterColumnDeps, - mustRun(isg.GetSQLVertexId(index), diffTypeDelete).beforeSchemaObj(buildTableVertexId(*parentTable.ParentTable), diffTypeAddAlter), + mustRun(isg.GetSQLVertexId(index, diffTypeDelete)).before(buildTableVertexId(*parentTable.ParentTable, diffTypeAddAlter)), ) } @@ -1914,13 +1930,13 @@ func (csg *checkConstraintSQLVertexGenerator) Alter(diff checkConstraintDiff) ([ return stmts, nil } -func (*checkConstraintSQLVertexGenerator) GetSQLVertexId(con schema.CheckConstraint) string { - return buildVertexId("checkconstraint", con.Name) +func (*checkConstraintSQLVertexGenerator) GetSQLVertexId(con schema.CheckConstraint, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("checkconstraint", con.Name, diffType) } func (csg *checkConstraintSQLVertexGenerator) GetAddAlterDependencies(con, _ schema.CheckConstraint) ([]dependency, error) { deps := []dependency{ - mustRun(csg.GetSQLVertexId(con), diffTypeDelete).beforeSchemaObj(csg.GetSQLVertexId(con), diffTypeAddAlter), + mustRun(csg.GetSQLVertexId(con, diffTypeDelete)).before(csg.GetSQLVertexId(con, diffTypeAddAlter)), } targetColumns, err := getTargetColumns(con.KeyColumns, csg.newSchemaColumnsByName) @@ -1939,10 +1955,10 @@ func (csg *checkConstraintSQLVertexGenerator) GetAddAlterDependencies(con, _ sch if isOnValidNotNullPreExistingColumn { // If the NOT NULL check constraint is on a pre-existing column, then we should ensure it is added beforeSchemaObj // the column alter. - deps = append(deps, mustRun(csg.GetSQLVertexId(con), diffTypeAddAlter).beforeSchemaObj(buildColumnVertexId(targetColumns[0].Name), diffTypeAddAlter)) + deps = append(deps, mustRun(csg.GetSQLVertexId(con, diffTypeAddAlter)).before(buildColumnVertexId(targetColumns[0].Name, diffTypeAddAlter))) } else { for _, tc := range targetColumns { - deps = append(deps, mustRun(csg.GetSQLVertexId(con), diffTypeAddAlter).afterSchemaObj(buildColumnVertexId(tc.Name), diffTypeAddAlter)) + deps = append(deps, mustRun(csg.GetSQLVertexId(con, diffTypeAddAlter)).after(buildColumnVertexId(tc.Name, diffTypeAddAlter))) } } return deps, nil @@ -1962,22 +1978,22 @@ func (csg *checkConstraintSQLVertexGenerator) GetDeleteDependencies(con schema.C // are backed with a check constraint. // // For all other check constraints, they can rely on the type of the column. Thus, we should drop these - // check constraint beforeSchemaObj any columns are altered because the new type might not be compatible with the old + // check constraint before any columns are altered because the new type might not be compatible with the old // check constraint. if isValidNotNullCC(con) { tc := targetColumns[0] if _, ok := csg.deletedColumnsByName[tc.Name]; ok { // If the column is being deleted, we should drop the not null check constraint beforeSchemaObj the column is deleted. - deps = append(deps, mustRun(csg.GetSQLVertexId(con), diffTypeDelete).beforeSchemaObj(buildColumnVertexId(tc.Name), diffTypeDelete)) + deps = append(deps, mustRun(csg.GetSQLVertexId(con, diffTypeDelete)).before(buildColumnVertexId(tc.Name, diffTypeDelete))) } else { // Otherwise, we should drop the not null check constraint afterSchemaObj the column is altered. This dependency // doesn't need to be explicitly, since our topological sort prioritizes adds/alters over deletes. Nevertheless, // we'll add it for clarity and to ensure that an error is returned if the delete is not placed afterSchemaObj the alter. - deps = append(deps, mustRun(csg.GetSQLVertexId(con), diffTypeDelete).afterSchemaObj(buildColumnVertexId(tc.Name), diffTypeAddAlter)) + deps = append(deps, mustRun(csg.GetSQLVertexId(con, diffTypeDelete)).after(buildColumnVertexId(tc.Name, diffTypeAddAlter))) } } else { for _, tc := range targetColumns { - deps = append(deps, mustRun(csg.GetSQLVertexId(con), diffTypeDelete).beforeSchemaObj(buildColumnVertexId(tc.Name), diffTypeAddAlter)) + deps = append(deps, mustRun(csg.GetSQLVertexId(con, diffTypeDelete)).before(buildColumnVertexId(tc.Name, diffTypeAddAlter))) // This is a weird quirk of our graph system, where if a -> b and b -> c and b does-not-exist, b will be // implicitly created s.t. a -> b -> c (https://github.com/stripe/pg-schema-diff/issues/84) // @@ -1985,7 +2001,7 @@ func (csg *checkConstraintSQLVertexGenerator) GetDeleteDependencies(con schema.C // the column, and "c" is the alter/addition of the column. We do not want this behavior. We only want // a -> b -> c iff the column is being deleted. if _, ok := csg.deletedColumnsByName[tc.Name]; ok { - deps = append(deps, mustRun(csg.GetSQLVertexId(con), diffTypeDelete).beforeSchemaObj(buildColumnVertexId(tc.Name), diffTypeDelete)) + deps = append(deps, mustRun(csg.GetSQLVertexId(con, diffTypeDelete)).before(buildColumnVertexId(tc.Name, diffTypeDelete))) } } } @@ -2048,8 +2064,8 @@ func (*attachPartitionSQLVertexGenerator) Delete(_ schema.Table) ([]Statement, e return nil, nil } -func (*attachPartitionSQLVertexGenerator) GetSQLVertexId(table schema.Table) string { - return fmt.Sprintf("attachpartition_%s", table.GetName()) +func (*attachPartitionSQLVertexGenerator) GetSQLVertexId(table schema.Table, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("attachpartition", table.GetName(), diffType) } func (a *attachPartitionSQLVertexGenerator) GetAddAlterDependencies(table, old schema.Table) ([]dependency, error) { @@ -2059,7 +2075,7 @@ func (a *attachPartitionSQLVertexGenerator) GetAddAlterDependencies(table, old s } deps := []dependency{ - mustRun(a.GetSQLVertexId(table), diffTypeAddAlter).afterSchemaObj(buildTableVertexId(table.SchemaQualifiedName), diffTypeAddAlter), + mustRun(a.GetSQLVertexId(table, diffTypeAddAlter)).after(buildTableVertexId(table.SchemaQualifiedName, diffTypeAddAlter)), } if _, baseTableIsNew := a.addedTablesByName[table.ParentTable.GetName()]; baseTableIsNew { @@ -2068,14 +2084,14 @@ func (a *attachPartitionSQLVertexGenerator) GetAddAlterDependencies(table, old s // have the PK (this is useful when creating the fresh database schema for migration validation) // If we attach the partition afterSchemaObj the index is built, the index will be automatically built by Postgres for _, idx := range a.indexesInNewSchemaByTableName[table.ParentTable.GetName()] { - deps = append(deps, mustRun(a.GetSQLVertexId(table), diffTypeAddAlter).beforeSchemaObj(buildIndexVertexId(idx.GetSchemaQualifiedName()), diffTypeAddAlter)) + deps = append(deps, mustRun(a.GetSQLVertexId(table, diffTypeAddAlter)).before(buildIndexVertexId(idx.GetSchemaQualifiedName(), diffTypeAddAlter))) } return deps, nil } a.isPartitionAttachedAfterIdxBuildsByTableName[table.GetName()] = true for _, idx := range a.indexesInNewSchemaByTableName[table.GetName()] { - deps = append(deps, mustRun(a.GetSQLVertexId(table), diffTypeAddAlter).afterSchemaObj(buildIndexVertexId(idx.GetSchemaQualifiedName()), diffTypeAddAlter)) + deps = append(deps, mustRun(a.GetSQLVertexId(table, diffTypeAddAlter)).after(buildIndexVertexId(idx.GetSchemaQualifiedName(), diffTypeAddAlter))) } return deps, nil } @@ -2221,25 +2237,25 @@ func (f *foreignKeyConstraintSQLVertexGenerator) Alter(diff foreignKeyConstraint return stmts, nil } -func (*foreignKeyConstraintSQLVertexGenerator) GetSQLVertexId(con schema.ForeignKeyConstraint) string { - return buildVertexId("fkconstraint", con.GetName()) +func (*foreignKeyConstraintSQLVertexGenerator) GetSQLVertexId(con schema.ForeignKeyConstraint, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("fkconstraint", con.GetName(), diffType) } func (f *foreignKeyConstraintSQLVertexGenerator) GetAddAlterDependencies(con, _ schema.ForeignKeyConstraint) ([]dependency, error) { deps := []dependency{ - mustRun(f.GetSQLVertexId(con), diffTypeAddAlter).afterSchemaObj(f.GetSQLVertexId(con), diffTypeDelete), - mustRun(f.GetSQLVertexId(con), diffTypeAddAlter).afterSchemaObj(buildTableVertexId(con.OwningTable), diffTypeAddAlter), - mustRun(f.GetSQLVertexId(con), diffTypeAddAlter).afterSchemaObj(buildTableVertexId(con.ForeignTable), diffTypeAddAlter), + mustRun(f.GetSQLVertexId(con, diffTypeAddAlter)).after(f.GetSQLVertexId(con, diffTypeDelete)), + mustRun(f.GetSQLVertexId(con, diffTypeAddAlter)).after(buildTableVertexId(con.OwningTable, diffTypeAddAlter)), + mustRun(f.GetSQLVertexId(con, diffTypeAddAlter)).after(buildTableVertexId(con.ForeignTable, diffTypeAddAlter)), } // This is the slightly lazy way of ensuring the foreign key constraint is added afterSchemaObj the requisite index is // built and marked as valid. // We __could__ do this just for the index the fk depends on, but that's slightly more wiring than we need right now // because of partitioned indexes, which are only valid when all child indexes have been built for _, i := range f.indexesInNewSchemaByTableName[con.ForeignTable.GetName()] { - deps = append(deps, mustRun(f.GetSQLVertexId(con), diffTypeAddAlter).afterSchemaObj(buildIndexVertexId(i.GetSchemaQualifiedName()), diffTypeAddAlter)) + deps = append(deps, mustRun(f.GetSQLVertexId(con, diffTypeAddAlter)).after(buildIndexVertexId(i.GetSchemaQualifiedName(), diffTypeAddAlter))) // Build a dependency on any child index if the index is partitioned for _, c := range f.childrenInNewSchemaByPartitionedIndexName[i.GetName()] { - deps = append(deps, mustRun(f.GetSQLVertexId(con), diffTypeAddAlter).afterSchemaObj(buildIndexVertexId(c.GetSchemaQualifiedName()), diffTypeAddAlter)) + deps = append(deps, mustRun(f.GetSQLVertexId(con, diffTypeAddAlter)).after(buildIndexVertexId(c.GetSchemaQualifiedName(), diffTypeAddAlter))) } } @@ -2248,17 +2264,17 @@ func (f *foreignKeyConstraintSQLVertexGenerator) GetAddAlterDependencies(con, _ func (f *foreignKeyConstraintSQLVertexGenerator) GetDeleteDependencies(con schema.ForeignKeyConstraint) ([]dependency, error) { deps := []dependency{ - mustRun(f.GetSQLVertexId(con), diffTypeDelete).beforeSchemaObj(buildTableVertexId(con.OwningTable), diffTypeDelete), - mustRun(f.GetSQLVertexId(con), diffTypeDelete).beforeSchemaObj(buildTableVertexId(con.ForeignTable), diffTypeDelete), + mustRun(f.GetSQLVertexId(con, diffTypeDelete)).before(buildTableVertexId(con.OwningTable, diffTypeDelete)), + mustRun(f.GetSQLVertexId(con, diffTypeDelete)).before(buildTableVertexId(con.ForeignTable, diffTypeDelete)), } // This is the slightly lazy way of ensuring the foreign key constraint is deleted beforeSchemaObj the index it depends on is deleted // We __could__ do this just for the index the fk depends on, but that's slightly more wiring than we need right now // because of partitioned indexes, which are only valid when all child indexes have been built for _, i := range f.indexInOldSchemaByTableName[con.ForeignTable.GetName()] { - deps = append(deps, mustRun(f.GetSQLVertexId(con), diffTypeDelete).beforeSchemaObj(buildIndexVertexId(i.GetSchemaQualifiedName()), diffTypeDelete)) + deps = append(deps, mustRun(f.GetSQLVertexId(con, diffTypeDelete)).before(buildIndexVertexId(i.GetSchemaQualifiedName(), diffTypeDelete))) // Build a dependency on any child index if the index is partitioned for _, c := range f.childrenInOldSchemaByPartitionedIndexName[i.GetName()] { - deps = append(deps, mustRun(f.GetSQLVertexId(con), diffTypeDelete).beforeSchemaObj(buildIndexVertexId(c.GetSchemaQualifiedName()), diffTypeDelete)) + deps = append(deps, mustRun(f.GetSQLVertexId(con, diffTypeDelete)).before(buildIndexVertexId(c.GetSchemaQualifiedName(), diffTypeDelete))) } } return deps, nil @@ -2355,17 +2371,21 @@ func (s *sequenceSQLVertexGenerator) buildAddAlterSequenceStatement(seq schema.S } } -func (s *sequenceSQLVertexGenerator) GetSQLVertexId(seq schema.Sequence) string { - return buildSequenceVertexId(seq.SchemaQualifiedName) +func (s *sequenceSQLVertexGenerator) GetSQLVertexId(seq schema.Sequence, diffType diffType) sqlVertexId { + return buildSequenceVertexId(seq.SchemaQualifiedName, diffType) +} + +func buildSequenceVertexId(name schema.SchemaQualifiedName, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("sequence", name.GetFQEscapedName(), diffType) } func (s *sequenceSQLVertexGenerator) GetAddAlterDependencies(new schema.Sequence, _ schema.Sequence) ([]dependency, error) { deps := []dependency{ - mustRun(s.GetSQLVertexId(new), diffTypeAddAlter).afterSchemaObj(s.GetSQLVertexId(new), diffTypeDelete), + mustRun(s.GetSQLVertexId(new, diffTypeAddAlter)).after(s.GetSQLVertexId(new, diffTypeDelete)), } if new.Owner != nil { // Sequences should be added/altered beforeSchemaObj the table they are owned by - deps = append(deps, mustRun(s.GetSQLVertexId(new), diffTypeAddAlter).beforeSchemaObj(buildTableVertexId(new.Owner.TableName), diffTypeAddAlter)) + deps = append(deps, mustRun(s.GetSQLVertexId(new, diffTypeAddAlter)).before(buildTableVertexId(new.Owner.TableName, diffTypeAddAlter))) } return deps, nil } @@ -2379,7 +2399,7 @@ func (s *sequenceSQLVertexGenerator) GetDeleteDependencies(seq schema.Sequence) // old owner column delete (equivalent to add/alter) and the sequence add/alter. We can get away with this because // we, so far, no columns are ever "re-created". If we ever do support that, we'll need to revisit this. if seq.Owner != nil { - deps = append(deps, mustRun(s.GetSQLVertexId(seq), diffTypeDelete).afterSchemaObj(buildTableVertexId(seq.Owner.TableName), diffTypeDelete)) + deps = append(deps, mustRun(s.GetSQLVertexId(seq, diffTypeDelete)).after(buildTableVertexId(seq.Owner.TableName, diffTypeDelete))) } return deps, nil } @@ -2402,10 +2422,6 @@ func (s *sequenceSQLVertexGenerator) isDeletedWithColumns(seq schema.Sequence) b return false } -func buildSequenceVertexId(name schema.SchemaQualifiedName) string { - return buildVertexId("sequence", name.GetFQEscapedName()) -} - type sequenceOwnershipSQLVertexGenerator struct{} func (s sequenceOwnershipSQLVertexGenerator) Add(seq schema.Sequence) ([]Statement, error) { @@ -2440,8 +2456,8 @@ func (s sequenceOwnershipSQLVertexGenerator) buildAlterOwnershipStmt(new schema. } } -func (s sequenceOwnershipSQLVertexGenerator) GetSQLVertexId(seq schema.Sequence) string { - return fmt.Sprintf("%s-ownership", buildSequenceVertexId(seq.SchemaQualifiedName)) +func (s sequenceOwnershipSQLVertexGenerator) GetSQLVertexId(seq schema.Sequence, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("sequence_ownership", seq.SchemaQualifiedName.GetFQEscapedName(), diffType) } func (s sequenceOwnershipSQLVertexGenerator) GetAddAlterDependencies(new schema.Sequence, old schema.Sequence) ([]dependency, error) { @@ -2451,17 +2467,17 @@ func (s sequenceOwnershipSQLVertexGenerator) GetAddAlterDependencies(new schema. deps := []dependency{ // Always change ownership afterSchemaObj the sequence has been added/altered - mustRun(s.GetSQLVertexId(new), diffTypeAddAlter).afterSchemaObj(buildSequenceVertexId(new.SchemaQualifiedName), diffTypeAddAlter), + mustRun(s.GetSQLVertexId(new, diffTypeAddAlter)).after(buildSequenceVertexId(new.SchemaQualifiedName, diffTypeAddAlter)), } if old.Owner != nil { // Always update ownership beforeSchemaObj the old owner has been deleted - deps = append(deps, mustRun(s.GetSQLVertexId(new), diffTypeAddAlter).beforeSchemaObj(buildTableVertexId(old.Owner.TableName), diffTypeDelete)) + deps = append(deps, mustRun(s.GetSQLVertexId(new, diffTypeAddAlter)).before(buildTableVertexId(old.Owner.TableName, diffTypeDelete))) } if new.Owner != nil { // Always update ownership afterSchemaObj the new owner has been created - deps = append(deps, mustRun(s.GetSQLVertexId(new), diffTypeAddAlter).afterSchemaObj(buildTableVertexId(new.Owner.TableName), diffTypeAddAlter)) + deps = append(deps, mustRun(s.GetSQLVertexId(new, diffTypeAddAlter)).after(buildTableVertexId(new.Owner.TableName, diffTypeAddAlter))) } return deps, nil @@ -2591,8 +2607,12 @@ func canFunctionDependenciesBeTracked(function schema.Function) bool { return function.Language == "sql" } -func (f *functionSQLVertexGenerator) GetSQLVertexId(function schema.Function) string { - return buildFunctionVertexId(function.SchemaQualifiedName) +func (f *functionSQLVertexGenerator) GetSQLVertexId(function schema.Function, diffType diffType) sqlVertexId { + return buildFunctionVertexId(function.SchemaQualifiedName, diffType) +} + +func buildFunctionVertexId(name schema.SchemaQualifiedName, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("function", name.GetFQEscapedName(), diffType) } func (f *functionSQLVertexGenerator) GetAddAlterDependencies(newFunction, oldFunction schema.Function) ([]dependency, error) { @@ -2601,7 +2621,7 @@ func (f *functionSQLVertexGenerator) GetAddAlterDependencies(newFunction, oldFun // because there won't be one if it is being added/altered var deps []dependency for _, depFunction := range newFunction.DependsOnFunctions { - deps = append(deps, mustRun(f.GetSQLVertexId(newFunction), diffTypeAddAlter).afterSchemaObj(buildFunctionVertexId(depFunction), diffTypeAddAlter)) + deps = append(deps, mustRun(f.GetSQLVertexId(newFunction, diffTypeAddAlter)).after(buildFunctionVertexId(depFunction, diffTypeAddAlter))) } if !cmp.Equal(oldFunction, schema.Function{}) { @@ -2609,7 +2629,7 @@ func (f *functionSQLVertexGenerator) GetAddAlterDependencies(newFunction, oldFun // If the old version of the function calls other functions that are being deleted come, those deletions // must come afterSchemaObj the function is altered, so it is no longer dependent on those dropped functions for _, depFunction := range oldFunction.DependsOnFunctions { - deps = append(deps, mustRun(f.GetSQLVertexId(newFunction), diffTypeAddAlter).beforeSchemaObj(buildFunctionVertexId(depFunction), diffTypeDelete)) + deps = append(deps, mustRun(f.GetSQLVertexId(newFunction, diffTypeAddAlter)).before(buildFunctionVertexId(depFunction, diffTypeDelete))) } } @@ -2619,15 +2639,11 @@ func (f *functionSQLVertexGenerator) GetAddAlterDependencies(newFunction, oldFun func (f *functionSQLVertexGenerator) GetDeleteDependencies(function schema.Function) ([]dependency, error) { var deps []dependency for _, depFunction := range function.DependsOnFunctions { - deps = append(deps, mustRun(f.GetSQLVertexId(function), diffTypeDelete).beforeSchemaObj(buildFunctionVertexId(depFunction), diffTypeDelete)) + deps = append(deps, mustRun(f.GetSQLVertexId(function, diffTypeDelete)).before(buildFunctionVertexId(depFunction, diffTypeDelete))) } return deps, nil } -func buildFunctionVertexId(name schema.SchemaQualifiedName) string { - return buildVertexId("function", name.GetFQEscapedName()) -} - type triggerSQLVertexGenerator struct { // functionsInNewSchemaByName is a map of function new to functions in the new schema. // These functions are not necessarily new @@ -2666,8 +2682,8 @@ func (t *triggerSQLVertexGenerator) Alter(diff triggerDiff) ([]Statement, error) }}, nil } -func (t *triggerSQLVertexGenerator) GetSQLVertexId(trigger schema.Trigger) string { - return buildVertexId("trigger", trigger.GetName()) +func (t *triggerSQLVertexGenerator) GetSQLVertexId(trigger schema.Trigger, diffType diffType) sqlVertexId { + return buildSchemaObjVertexId("trigger", trigger.GetName(), diffType) } func (t *triggerSQLVertexGenerator) GetAddAlterDependencies(newTrigger, oldTrigger schema.Trigger) ([]dependency, error) { @@ -2675,8 +2691,8 @@ func (t *triggerSQLVertexGenerator) GetAddAlterDependencies(newTrigger, oldTrigg // added and dropped in the same migration. Thus, we don't need a dependency on the delete node of a function // because there won't be one if it is being added/altered deps := []dependency{ - mustRun(t.GetSQLVertexId(newTrigger), diffTypeAddAlter).afterSchemaObj(buildFunctionVertexId(newTrigger.Function), diffTypeAddAlter), - mustRun(t.GetSQLVertexId(newTrigger), diffTypeAddAlter).afterSchemaObj(buildTableVertexId(newTrigger.OwningTable), diffTypeAddAlter), + mustRun(t.GetSQLVertexId(newTrigger, diffTypeAddAlter)).after(buildFunctionVertexId(newTrigger.Function, diffTypeAddAlter)), + mustRun(t.GetSQLVertexId(newTrigger, diffTypeAddAlter)).after(buildTableVertexId(newTrigger.OwningTable, diffTypeAddAlter)), } if !cmp.Equal(oldTrigger, schema.Trigger{}) { @@ -2684,7 +2700,7 @@ func (t *triggerSQLVertexGenerator) GetAddAlterDependencies(newTrigger, oldTrigg // If the old version of the trigger called a function being deleted, the function deletion must come afterSchemaObj the // trigger is altered, so the trigger no longer has a dependency on the function deps = append(deps, - mustRun(t.GetSQLVertexId(newTrigger), diffTypeAddAlter).beforeSchemaObj(buildFunctionVertexId(oldTrigger.Function), diffTypeDelete), + mustRun(t.GetSQLVertexId(newTrigger, diffTypeAddAlter)).before(buildFunctionVertexId(oldTrigger.Function, diffTypeDelete)), ) } @@ -2693,15 +2709,11 @@ func (t *triggerSQLVertexGenerator) GetAddAlterDependencies(newTrigger, oldTrigg func (t *triggerSQLVertexGenerator) GetDeleteDependencies(trigger schema.Trigger) ([]dependency, error) { return []dependency{ - mustRun(t.GetSQLVertexId(trigger), diffTypeDelete).beforeSchemaObj(buildFunctionVertexId(trigger.Function), diffTypeDelete), - mustRun(t.GetSQLVertexId(trigger), diffTypeDelete).beforeSchemaObj(buildTableVertexId(trigger.OwningTable), diffTypeDelete), + mustRun(t.GetSQLVertexId(trigger, diffTypeDelete)).before(buildFunctionVertexId(trigger.Function, diffTypeDelete)), + mustRun(t.GetSQLVertexId(trigger, diffTypeDelete)).before(buildTableVertexId(trigger.OwningTable, diffTypeDelete)), }, nil } -func buildVertexId(objType string, id string) string { - return fmt.Sprintf("%s_%s", objType, id) -} - func stripMigrationHazards(stmts ...Statement) []Statement { var noHazardsStmts []Statement for _, stmt := range stmts { diff --git a/pkg/diff/sql_graph.go b/pkg/diff/sql_graph.go index 7277282..bd5a10b 100644 --- a/pkg/diff/sql_graph.go +++ b/pkg/diff/sql_graph.go @@ -4,7 +4,6 @@ import ( "fmt" "github.com/stripe/pg-schema-diff/internal/graph" - "github.com/stripe/pg-schema-diff/internal/schema" ) // sqlVertexId is an interface for a vertex id in the SQL graph @@ -32,6 +31,14 @@ type schemaObjSqlVertexId struct { diffType diffType } +func buildSchemaObjVertexId(objType string, id string, diffType diffType) sqlVertexId { + return schemaObjSqlVertexId{ + // todo(bplunkett) - maybe change to include an obj type property and move the format to the stringify + schemaObjId: fmt.Sprintf("%s_%s", objType, id), + diffType: diffType, + } +} + func (s schemaObjSqlVertexId) String() string { return fmt.Sprintf("%s_%s", s.schemaObjId, s.diffType) } @@ -58,14 +65,6 @@ func (s sqlVertex) GetPriority() int { return len(s.statements) * int(s.priority) } -func buildTableVertexId(name schema.SchemaQualifiedName) string { - return fmt.Sprintf("table_%s", name) -} - -func buildIndexVertexId(name schema.SchemaQualifiedName) string { - return fmt.Sprintf("index_%s", name) -} - // dependency indicates an edge between the SQL to resolve a diff for a source schema object and the SQL to resolve // the diff of a target schema object // @@ -81,19 +80,16 @@ type dependencyBuilder struct { base sqlVertexId } -func mustRun(schemaObjId string, schemaDiffType diffType) dependencyBuilder { +func mustRun(id sqlVertexId) dependencyBuilder { return dependencyBuilder{ - base: schemaObjSqlVertexId{ - schemaObjId: schemaObjId, - diffType: schemaDiffType, - }, + base: id, } } func (d dependencyBuilder) before(id sqlVertexId) dependency { return dependency{ - target: id, source: d.base, + target: id, } } diff --git a/pkg/diff/sql_vertex_generator.go b/pkg/diff/sql_vertex_generator.go index 0c915ac..60cc378 100644 --- a/pkg/diff/sql_vertex_generator.go +++ b/pkg/diff/sql_vertex_generator.go @@ -136,7 +136,7 @@ func generatePartialGraph[S schema.Object, Diff diff[S]](generator sqlVertexGene type legacySqlVertexGenerator[S schema.Object, Diff diff[S]] interface { sqlGenerator[S, Diff] // GetSQLVertexId gets the canonical vertex id to represent the schema object - GetSQLVertexId(S) string + GetSQLVertexId(S, diffType) sqlVertexId // GetAddAlterDependencies gets the dependencies of the SQL generated to resolve the AddAlter diff for the // schema objects. Dependencies can be formed on any other nodes in the SQL graph, even if the node has @@ -180,10 +180,7 @@ func (s *wrappedLegacySqlVertexGenerator[S, Diff]) Add(o S) (partialSQLGraph, er return partialSQLGraph{ vertices: []sqlVertex{{ - id: schemaObjSqlVertexId{ - schemaObjId: s.generator.GetSQLVertexId(o), - diffType: diffTypeAddAlter, - }, + id: s.generator.GetSQLVertexId(o, diffTypeAddAlter), priority: sqlPrioritySooner, statements: statements, }}, @@ -203,10 +200,7 @@ func (s *wrappedLegacySqlVertexGenerator[S, Diff]) Delete(o S) (partialSQLGraph, return partialSQLGraph{ vertices: []sqlVertex{{ - id: schemaObjSqlVertexId{ - schemaObjId: s.generator.GetSQLVertexId(o), - diffType: diffTypeDelete, - }, + id: s.generator.GetSQLVertexId(o, diffTypeDelete), priority: sqlPriorityLater, statements: statements, }}, @@ -226,10 +220,7 @@ func (s *wrappedLegacySqlVertexGenerator[S, Diff]) Alter(d Diff) (partialSQLGrap return partialSQLGraph{ vertices: []sqlVertex{{ - id: schemaObjSqlVertexId{ - schemaObjId: s.generator.GetSQLVertexId(d.GetNew()), - diffType: diffTypeAddAlter, - }, + id: s.generator.GetSQLVertexId(d.GetNew(), diffTypeAddAlter), priority: sqlPrioritySooner, statements: statements, }},