Skip to content

Commit

Permalink
add batch persister test
Browse files Browse the repository at this point in the history
  • Loading branch information
hperl committed Sep 13, 2024
1 parent 17ed6fe commit e491329
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 0 deletions.
90 changes: 90 additions & 0 deletions persistence/sql/batch/test_persister.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package batch

import (
"context"
"testing"

"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/ory/kratos/identity"
"github.com/ory/kratos/persistence"
"github.com/ory/x/dbal"
"github.com/ory/x/otelx"
"github.com/ory/x/sqlcon"
)

func TestPersister(ctx context.Context, tracer *otelx.Tracer, p persistence.Persister) func(t *testing.T) {
return func(t *testing.T) {
t.Run("method=batch.Create", func(t *testing.T) {

ident1 := identity.NewIdentity("")
ident1.NID = p.NetworkID(ctx)
ident2 := identity.NewIdentity("")
ident2.NID = p.NetworkID(ctx)

// Create two identities
_ = p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {
conn := &TracerConnection{
Tracer: tracer,
Connection: tx,
}

err := Create(ctx, conn, []*identity.Identity{ident1, ident2})
require.NoError(t, err)

return nil
})

require.NotEqual(t, uuid.Nil, ident1.ID)
require.NotEqual(t, uuid.Nil, ident2.ID)

// Create conflicting verifiable addresses
_ = p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error {
conn := &TracerConnection{
Tracer: tracer,
Connection: tx,
}

err := Create(ctx, conn, []*identity.VerifiableAddress{{
Value: "[email protected]",
IdentityID: ident1.ID,
NID: ident1.NID,
}, {
Value: "[email protected]",
IdentityID: ident1.ID,
NID: ident1.NID,
}, {
Value: "[email protected]",
IdentityID: ident1.ID,
NID: ident1.NID,
}, {
Value: "[email protected]",
IdentityID: ident1.ID,
NID: ident1.NID,
}, {
Value: "[email protected]",
IdentityID: ident1.ID,
NID: ident1.NID,
}, {
Value: "[email protected]",
IdentityID: ident1.ID,
NID: ident1.NID,
}})

assert.ErrorIs(t, err, sqlcon.ErrUniqueViolation)

if conn.Connection.Dialect.Name() != dbal.DriverMySQL {
// MySQL does not support partial errors.
partialErr := new(PartialConflictError[identity.VerifiableAddress])
require.ErrorAs(t, err, &partialErr)
assert.Len(t, partialErr.Failed, 1)
}

return nil
})
})
}
}
5 changes: 5 additions & 0 deletions persistence/sql/persister_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/ory/kratos/internal"
"github.com/ory/kratos/internal/testhelpers"
"github.com/ory/kratos/persistence/sql"
"github.com/ory/kratos/persistence/sql/batch"
sqltesthelpers "github.com/ory/kratos/persistence/sql/testhelpers"
"github.com/ory/kratos/schema"
errorx "github.com/ory/kratos/selfservice/errorx/test"
Expand Down Expand Up @@ -264,6 +265,10 @@ func TestPersister(t *testing.T) {
t.Parallel()
continuity.TestPersister(ctx, p)(t)
})
t.Run("contract=batch.TestPersister", func(t *testing.T) {
t.Parallel()
batch.TestPersister(ctx, reg.Tracer(ctx), p)(t)
})
})
}
}
Expand Down

0 comments on commit e491329

Please sign in to comment.