From e4913293737713fe49060d0f24233be2565222c0 Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Fri, 13 Sep 2024 12:58:48 +0200 Subject: [PATCH] add batch persister test --- persistence/sql/batch/test_persister.go | 90 +++++++++++++++++++++++++ persistence/sql/persister_test.go | 5 ++ 2 files changed, 95 insertions(+) create mode 100644 persistence/sql/batch/test_persister.go diff --git a/persistence/sql/batch/test_persister.go b/persistence/sql/batch/test_persister.go new file mode 100644 index 000000000000..7d476ca60150 --- /dev/null +++ b/persistence/sql/batch/test_persister.go @@ -0,0 +1,90 @@ +package batch + +import ( + "context" + "testing" + + "github.com/gobuffalo/pop/v6" + "github.com/gofrs/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/kratos/identity" + "github.com/ory/kratos/persistence" + "github.com/ory/x/dbal" + "github.com/ory/x/otelx" + "github.com/ory/x/sqlcon" +) + +func TestPersister(ctx context.Context, tracer *otelx.Tracer, p persistence.Persister) func(t *testing.T) { + return func(t *testing.T) { + t.Run("method=batch.Create", func(t *testing.T) { + + ident1 := identity.NewIdentity("") + ident1.NID = p.NetworkID(ctx) + ident2 := identity.NewIdentity("") + ident2.NID = p.NetworkID(ctx) + + // Create two identities + _ = p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { + conn := &TracerConnection{ + Tracer: tracer, + Connection: tx, + } + + err := Create(ctx, conn, []*identity.Identity{ident1, ident2}) + require.NoError(t, err) + + return nil + }) + + require.NotEqual(t, uuid.Nil, ident1.ID) + require.NotEqual(t, uuid.Nil, ident2.ID) + + // Create conflicting verifiable addresses + _ = p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { + conn := &TracerConnection{ + Tracer: tracer, + Connection: tx, + } + + err := Create(ctx, conn, []*identity.VerifiableAddress{{ + Value: "foo.1@bar.de", + IdentityID: ident1.ID, + NID: ident1.NID, + }, { + Value: "foo.2@bar.de", + IdentityID: ident1.ID, + NID: ident1.NID, + }, { + Value: "conflict@bar.de", + IdentityID: ident1.ID, + NID: ident1.NID, + }, { + Value: "foo.3@bar.de", + IdentityID: ident1.ID, + NID: ident1.NID, + }, { + Value: "conflict@bar.de", + IdentityID: ident1.ID, + NID: ident1.NID, + }, { + Value: "foo.4@bar.de", + IdentityID: ident1.ID, + NID: ident1.NID, + }}) + + assert.ErrorIs(t, err, sqlcon.ErrUniqueViolation) + + if conn.Connection.Dialect.Name() != dbal.DriverMySQL { + // MySQL does not support partial errors. + partialErr := new(PartialConflictError[identity.VerifiableAddress]) + require.ErrorAs(t, err, &partialErr) + assert.Len(t, partialErr.Failed, 1) + } + + return nil + }) + }) + } +} diff --git a/persistence/sql/persister_test.go b/persistence/sql/persister_test.go index 57593577c8c7..3029cdc51ef0 100644 --- a/persistence/sql/persister_test.go +++ b/persistence/sql/persister_test.go @@ -29,6 +29,7 @@ import ( "github.com/ory/kratos/internal" "github.com/ory/kratos/internal/testhelpers" "github.com/ory/kratos/persistence/sql" + "github.com/ory/kratos/persistence/sql/batch" sqltesthelpers "github.com/ory/kratos/persistence/sql/testhelpers" "github.com/ory/kratos/schema" errorx "github.com/ory/kratos/selfservice/errorx/test" @@ -264,6 +265,10 @@ func TestPersister(t *testing.T) { t.Parallel() continuity.TestPersister(ctx, p)(t) }) + t.Run("contract=batch.TestPersister", func(t *testing.T) { + t.Parallel() + batch.TestPersister(ctx, reg.Tracer(ctx), p)(t) + }) }) } }