From 96c1ff7747ea38e23a3892f74b75ee555ed49c88 Mon Sep 17 00:00:00 2001 From: Arne Luenser Date: Wed, 19 Jul 2023 12:33:58 +0200 Subject: [PATCH] feat: allow extra migrations in NewPersister --- driver/registry.go | 14 +++++++++++--- driver/registry_default.go | 2 +- persistence/sql/persister.go | 5 +++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/driver/registry.go b/driver/registry.go index 6bf98a101c75..38c87baf5c9b 100644 --- a/driver/registry.go +++ b/driver/registry.go @@ -5,6 +5,7 @@ package driver import ( "context" + "io/fs" "github.com/ory/kratos/selfservice/sessiontokenexchange" "github.com/ory/x/contextx" @@ -182,6 +183,7 @@ type options struct { config *config.Config replaceTracer func(*otelx.Tracer) *otelx.Tracer inspect func(Registry) error + extraMigrations []fs.FS } type RegistryOption func(*options) @@ -190,24 +192,30 @@ func SkipNetworkInit(o *options) { o.skipNetworkInit = true } -func WithConfig(config *config.Config) func(o *options) { +func WithConfig(config *config.Config) RegistryOption { return func(o *options) { o.config = config } } -func ReplaceTracer(f func(*otelx.Tracer) *otelx.Tracer) func(o *options) { +func ReplaceTracer(f func(*otelx.Tracer) *otelx.Tracer) RegistryOption { return func(o *options) { o.replaceTracer = f } } -func Inspect(f func(reg Registry) error) func(o *options) { +func Inspect(f func(reg Registry) error) RegistryOption { return func(o *options) { o.inspect = f } } +func WithExtraMigrations(m ...fs.FS) RegistryOption { + return func(o *options) { + o.extraMigrations = append(o.extraMigrations, m...) + } +} + func newOptions(os []RegistryOption) *options { o := new(options) for _, f := range os { diff --git a/driver/registry_default.go b/driver/registry_default.go index f4e9ba3fb040..f1307e310902 100644 --- a/driver/registry_default.go +++ b/driver/registry_default.go @@ -615,7 +615,7 @@ func (m *RegistryDefault) Init(ctx context.Context, ctxer contextx.Contextualize m.Logger().WithError(err).Warnf("Unable to open database, retrying.") return errors.WithStack(err) } - p, err := sql.NewPersister(ctx, m, c) + p, err := sql.NewPersister(ctx, m, c, o.extraMigrations...) if err != nil { m.Logger().WithError(err).Warnf("Unable to initialize persister, retrying.") return err diff --git a/persistence/sql/persister.go b/persistence/sql/persister.go index 9762823b1056..c7c98188dc78 100644 --- a/persistence/sql/persister.go +++ b/persistence/sql/persister.go @@ -6,6 +6,7 @@ package sql import ( "context" "embed" + "io/fs" "time" "github.com/gobuffalo/pop/v6" @@ -53,8 +54,8 @@ type ( } ) -func NewPersister(ctx context.Context, r persisterDependencies, c *pop.Connection) (*Persister, error) { - m, err := popx.NewMigrationBox(mergefs.Merge(migrations, networkx.Migrations), popx.NewMigrator(c, r.Logger(), r.Tracer(ctx), 0)) +func NewPersister(ctx context.Context, r persisterDependencies, c *pop.Connection, extraMigrations ...fs.FS) (*Persister, error) { + m, err := popx.NewMigrationBox(mergefs.Merge(append([]fs.FS{migrations, networkx.Migrations}, extraMigrations...)...), popx.NewMigrator(c, r.Logger(), r.Tracer(ctx), 0)) if err != nil { return nil, err }