Skip to content

Commit

Permalink
Refactor SQL generation - graph-first approach (#176)
Browse files Browse the repository at this point in the history
* Refactor SQL generation such that the SQL generators take a graph-first approach
  • Loading branch information
bplunkett-stripe authored Oct 2, 2024
1 parent f036cb2 commit 9216a8f
Show file tree
Hide file tree
Showing 5 changed files with 509 additions and 402 deletions.
193 changes: 0 additions & 193 deletions pkg/diff/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,12 @@ import (
"fmt"
"sort"

"github.com/stripe/pg-schema-diff/internal/graph"
"github.com/stripe/pg-schema-diff/internal/schema"
)

var ErrNotImplemented = fmt.Errorf("not implemented")
var errDuplicateIdentifier = fmt.Errorf("duplicate identifier")

type diffType string

const (
diffTypeDelete diffType = "DELETE"
diffTypeAddAlter diffType = "ADDALTER"
)

type (
diff[S schema.Object] interface {
GetOld() S
Expand All @@ -32,79 +24,8 @@ type (
// provided diff. Alter, e.g., with a table, might produce add/delete statements
Alter(Diff) ([]Statement, error)
}

// dependency indicates an edge between the SQL to resolve a diff for a source schema object and the SQL to resolve
// the diff of a target schema object
//
// Most SchemaObjects will have two nodes in the SQL graph: a node for delete SQL and a node for add/alter SQL.
// These nodes will almost always be present in the sqlGraph even if the schema object is not being deleted (or added/altered).
// If a node is present for a schema object where the "diffType" is NOT occurring, it will just be a no-op (no SQl statements)
dependency struct {
sourceObjId string
sourceType diffType

targetObjId string
targetType diffType
}
)

type dependencyBuilder struct {
valObjId string
valType diffType
}

func mustRun(schemaObjId string, schemaDiffType diffType) dependencyBuilder {
return dependencyBuilder{
valObjId: schemaObjId,
valType: schemaDiffType,
}
}

func (d dependencyBuilder) before(valObjId string, valType diffType) dependency {
return dependency{
sourceType: d.valType,
sourceObjId: d.valObjId,

targetType: valType,
targetObjId: valObjId,
}
}

func (d dependencyBuilder) after(valObjId string, valType diffType) dependency {
return dependency{
sourceObjId: valObjId,
sourceType: valType,

targetObjId: d.valObjId,
targetType: d.valType,
}
}

// sqlVertexGenerator is used to generate SQL statements for schema objects that have dependency webs
// with other schema objects. The schema object represents a vertex in the graph.
type sqlVertexGenerator[S schema.Object, Diff diff[S]] interface {
sqlGenerator[S, Diff]
// GetSQLVertexId gets the canonical vertex id to represent the schema object
GetSQLVertexId(S) string

// GetAddAlterDependencies gets the dependencies of the SQL generated to resolve the AddAlter diff for the
// schema objects. Dependencies can be formed on any other nodes in the SQL graph, even if the node has
// no statements. If the diff is just an add, then old will be the zero value
//
// These dependencies can also be built in reverse: the SQL returned by the sqlVertexGenerator to resolve the
// diff for the object must always be run before the SQL required to resolve another SQL vertex diff
GetAddAlterDependencies(new S, old S) ([]dependency, error)

// GetDeleteDependencies is the same as above but for deletes.
// Invariant to maintain:
// - If an object X depends on the delete for an object Y (generated by the sqlVertexGenerator), immediately after the
// the (Y, diffTypeDelete) sqlVertex's SQL is run, Y must no longer be present in the schema; either the
// (Y, diffTypeDelete) statements deleted Y or something that vertex depended on deleted Y. In other words, if a
// delete is cascaded by another delete (e.g., index dropped by table drop) and the index SQL is empty,
// the index delete vertex must still have dependency from itself to the object from which the delete cascades down from
GetDeleteDependencies(S) ([]dependency, error)
}

type (
// listDiff represents the differences between two lists.
listDiff[S schema.Object, Diff diff[S]] struct {
Expand Down Expand Up @@ -158,120 +79,6 @@ func (ld listDiff[S, D]) resolveToSQLGroupedByEffect(sqlGenerator sqlGenerator[S
}, nil
}

func (ld listDiff[S, D]) resolveToSQLGraph(generator sqlVertexGenerator[S, D]) (*sqlGraph, error) {
graph := graph.NewGraph[sqlVertex]()

for _, a := range ld.adds {
statements, err := generator.Add(a)
if err != nil {
return nil, fmt.Errorf("generating SQL for add %s: %w", a.GetName(), err)
}

deps, err := generator.GetAddAlterDependencies(a, *new(S))
if err != nil {
return nil, fmt.Errorf("getting dependencies for add %s: %w", a.GetName(), err)
}
if err := addSQLVertexToGraph(graph, sqlVertex{
ObjId: generator.GetSQLVertexId(a),
Statements: statements,
DiffType: diffTypeAddAlter,
}, deps); err != nil {
return nil, fmt.Errorf("adding SQL Vertex for add %s: %w", a.GetName(), err)
}
}

for _, a := range ld.alters {
statements, err := generator.Alter(a)
if err != nil {
return nil, fmt.Errorf("generating SQL for diff %+v: %w", a, err)
}

vertexId := generator.GetSQLVertexId(a.GetOld())
vertexIdAfterAlter := generator.GetSQLVertexId(a.GetNew())
if vertexIdAfterAlter != vertexId {
return nil, fmt.Errorf("an alter lead to a node with a different id: old=%s, new=%s", vertexId, vertexIdAfterAlter)
}

deps, err := generator.GetAddAlterDependencies(a.GetNew(), a.GetOld())
if err != nil {
return nil, fmt.Errorf("getting dependencies for alter %s: %w", a.GetOld().GetName(), err)
}

if err := addSQLVertexToGraph(graph, sqlVertex{
ObjId: vertexId,
Statements: statements,
DiffType: diffTypeAddAlter,
}, deps); err != nil {
return nil, fmt.Errorf("adding SQL Vertex for alter %s: %w", a.GetOld().GetName(), err)
}
}

for _, d := range ld.deletes {
statements, err := generator.Delete(d)
if err != nil {
return nil, fmt.Errorf("generating SQL for delete %s: %w", d.GetName(), err)
}

deps, err := generator.GetDeleteDependencies(d)
if err != nil {
return nil, fmt.Errorf("getting dependencies for delete %s: %w", d.GetName(), err)
}

if err := addSQLVertexToGraph(graph, sqlVertex{
ObjId: generator.GetSQLVertexId(d),
Statements: statements,
DiffType: diffTypeDelete,
}, deps); err != nil {
return nil, fmt.Errorf("adding SQL Vertex for delete %s: %w", d.GetName(), err)
}
}

return (*sqlGraph)(graph), nil
}

func addSQLVertexToGraph(graph *graph.Graph[sqlVertex], vertex sqlVertex, dependencies []dependency) error {
// It's possible the node already exists. merge it if it does
if graph.HasVertexWithId(vertex.GetId()) {
vertex = mergeSQLVertices(graph.GetVertex(vertex.GetId()), vertex)
}
graph.AddVertex(vertex)
for _, dep := range dependencies {
if err := addDependency(graph, dep); err != nil {
return fmt.Errorf("adding dependencies for %s: %w", vertex.GetId(), err)
}
}
return nil
}

func addDependency(graph *graph.Graph[sqlVertex], dep dependency) error {
sourceVertex := sqlVertex{
ObjId: dep.sourceObjId,
DiffType: dep.sourceType,
Statements: nil,
}
targetVertex := sqlVertex{
ObjId: dep.targetObjId,
DiffType: dep.targetType,
Statements: nil,
}

// To maintain the correctness of the graph, we will add a dummy vertex for the missing dependencies
addVertexIfNotExists(graph, sourceVertex)
addVertexIfNotExists(graph, targetVertex)

if err := graph.AddEdge(sourceVertex.GetId(), targetVertex.GetId()); err != nil {
return fmt.Errorf("adding edge from %s to %s: %w", sourceVertex.GetId(), targetVertex.GetId(), err)
}

return nil
}

func addVertexIfNotExists(graph *graph.Graph[sqlVertex], vertex sqlVertex) {
if !graph.HasVertexWithId(vertex.GetId()) {
graph.AddVertex(vertex)
}
}

type schemaObjectEntry[S schema.Object] struct {
index int // index is the index the schema object in the list
obj S
Expand Down
26 changes: 13 additions & 13 deletions pkg/diff/policy_sql_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ type policyDiff struct {
oldAndNew[schema.Policy]
}

func buildPolicyDiffs(psg *policySQLVertexGenerator, old, new []schema.Policy) (listDiff[schema.Policy, policyDiff], error) {
func buildPolicyDiffs(psg sqlVertexGenerator[schema.Policy, policyDiff], old, new []schema.Policy) (listDiff[schema.Policy, policyDiff], error) {
return diffLists(old, new, func(old, new schema.Policy, _, _ int) (_ policyDiff, requiresRecreate bool, _ error) {
diff := policyDiff{
oldAndNew: oldAndNew[schema.Policy]{
Expand All @@ -131,7 +131,7 @@ type policySQLVertexGenerator struct {
oldSchemaColumnsByName map[string]schema.Column
}

func newPolicySQLVertexGenerator(oldTable *schema.Table, table schema.Table) (*policySQLVertexGenerator, error) {
func newPolicySQLVertexGenerator(oldTable *schema.Table, table schema.Table) (sqlVertexGenerator[schema.Policy, policyDiff], error) {
var oldSchemaColumnsByName map[string]schema.Column
if oldTable != nil {
if oldTable.SchemaQualifiedName != table.SchemaQualifiedName {
Expand All @@ -140,12 +140,12 @@ func newPolicySQLVertexGenerator(oldTable *schema.Table, table schema.Table) (*p
oldSchemaColumnsByName = buildSchemaObjByNameMap(oldTable.Columns)
}

return &policySQLVertexGenerator{
return legacyToNewSqlVertexGenerator[schema.Policy, policyDiff](&policySQLVertexGenerator{
table: table,
newSchemaColumnsByName: buildSchemaObjByNameMap(table.Columns),
oldTable: oldTable,
oldSchemaColumnsByName: oldSchemaColumnsByName,
}, nil
}), nil
}

func (psg *policySQLVertexGenerator) Add(p schema.Policy) ([]Statement, error) {
Expand Down Expand Up @@ -262,17 +262,17 @@ func (psg *policySQLVertexGenerator) Alter(diff policyDiff) ([]Statement, error)
}}, nil
}

func (psg *policySQLVertexGenerator) GetSQLVertexId(p schema.Policy) string {
return buildPolicyVertexId(psg.table.SchemaQualifiedName, p.EscapedName)
func (psg *policySQLVertexGenerator) GetSQLVertexId(p schema.Policy, diffType diffType) sqlVertexId {
return buildPolicyVertexId(psg.table.SchemaQualifiedName, p.EscapedName, diffType)
}

func buildPolicyVertexId(owningTable schema.SchemaQualifiedName, policyEscapedName string) string {
return buildVertexId("policy", fmt.Sprintf("%s.%s", owningTable.GetFQEscapedName(), policyEscapedName))
func buildPolicyVertexId(owningTable schema.SchemaQualifiedName, policyEscapedName string, diffType diffType) sqlVertexId {
return buildSchemaObjVertexId("policy", fmt.Sprintf("%s.%s", owningTable.GetFQEscapedName(), policyEscapedName), diffType)
}

func (psg *policySQLVertexGenerator) GetAddAlterDependencies(newPolicy, oldPolicy schema.Policy) ([]dependency, error) {
deps := []dependency{
mustRun(psg.GetSQLVertexId(newPolicy), diffTypeDelete).before(psg.GetSQLVertexId(newPolicy), diffTypeAddAlter),
mustRun(psg.GetSQLVertexId(newPolicy, diffTypeDelete)).before(psg.GetSQLVertexId(newPolicy, diffTypeAddAlter)),
}

newTargetColumns, err := getTargetColumns(newPolicy.Columns, psg.newSchemaColumnsByName)
Expand All @@ -282,7 +282,7 @@ func (psg *policySQLVertexGenerator) GetAddAlterDependencies(newPolicy, oldPolic

// Run after the new columns are added/altered
for _, tc := range newTargetColumns {
deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy), diffTypeAddAlter).after(buildColumnVertexId(tc.Name), diffTypeAddAlter))
deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy, diffTypeAddAlter)).after(buildColumnVertexId(tc.Name, diffTypeAddAlter)))
}

if !cmp.Equal(oldPolicy, schema.Policy{}) {
Expand All @@ -294,7 +294,7 @@ func (psg *policySQLVertexGenerator) GetAddAlterDependencies(newPolicy, oldPolic
for _, tc := range oldTargetColumns {
// It only needs to run before the delete if the column is actually being deleted
if _, stillExists := psg.newSchemaColumnsByName[tc.GetName()]; !stillExists {
deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy), diffTypeAddAlter).before(buildColumnVertexId(tc.Name), diffTypeDelete))
deps = append(deps, mustRun(psg.GetSQLVertexId(newPolicy, diffTypeAddAlter)).before(buildColumnVertexId(tc.Name, diffTypeDelete)))
}
}
}
Expand All @@ -311,8 +311,8 @@ func (psg *policySQLVertexGenerator) GetDeleteDependencies(pol schema.Policy) ([
}
// The policy needs to be deleted before all the columns it references are deleted or add/altered
for _, c := range columns {
deps = append(deps, mustRun(psg.GetSQLVertexId(pol), diffTypeDelete).before(buildColumnVertexId(c.Name), diffTypeDelete))
deps = append(deps, mustRun(psg.GetSQLVertexId(pol), diffTypeDelete).before(buildColumnVertexId(c.Name), diffTypeAddAlter))
deps = append(deps, mustRun(psg.GetSQLVertexId(pol, diffTypeDelete)).before(buildColumnVertexId(c.Name, diffTypeDelete)))
deps = append(deps, mustRun(psg.GetSQLVertexId(pol, diffTypeDelete)).before(buildColumnVertexId(c.Name, diffTypeAddAlter)))
}

return deps, nil
Expand Down
Loading

0 comments on commit 9216a8f

Please sign in to comment.