diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index afa525184f03..4523ce7f2146 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -584,12 +584,24 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... } } - if err := p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { + var succeededIDs []uuid.UUID + + defer func() { + // Report succeeded identities as created. + for _, identID := range succeededIDs { + span.AddEvent(events.NewIdentityCreated(ctx, identID)) + } + }() + + return p.Transaction(ctx, func(ctx context.Context, tx *pop.Connection) error { conn := &batch.TracerConnection{ Tracer: p.r.Tracer(ctx), Connection: tx, } + succeededIDs = make([]uuid.UUID, 0, len(identities)) + failedIdentityIDs := make(map[uuid.UUID]struct{}) + // Don't use batch.WithPartialInserts, because identities have no other // constraints other than the primary key that could cause conflicts. if err := batch.Create(ctx, conn, identities); err != nil { @@ -598,8 +610,6 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... p.normalizeAllAddressess(ctx, identities...) - failedIdentityIDs := make(map[uuid.UUID]struct{}) - if err = p.createVerifiableAddresses(ctx, tx, identities...); err != nil { if paritalErr := new(batch.PartialConflictError[identity.VerifiableAddress]); errors.As(err, &paritalErr) { for _, k := range paritalErr.Failed { @@ -645,10 +655,13 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... if len(failedIdentityIDs) > 0 { partialErr := &identity.CreateIdentitiesError{} failedIDs := make([]uuid.UUID, 0, len(failedIdentityIDs)) + for _, ident := range identities { if _, ok := failedIdentityIDs[ident.ID]; ok { partialErr.AddFailedIdentity(ident, sqlcon.ErrUniqueViolation) failedIDs = append(failedIDs, ident.ID) + } else { + succeededIDs = append(succeededIDs, ident.ID) } } // Manually roll back by deleting the identities that were inserted before the @@ -656,21 +669,17 @@ func (p *IdentityPersister) CreateIdentities(ctx context.Context, identities ... if err := p.DeleteIdentities(ctx, failedIDs); err != nil { return sqlcon.HandleError(err) } - // Wrap the partial error with the first error that occurred, so that the caller - // can continue to handle the error either as a partial error or a full error. + return partialErr + } else { + // No failures: report all identities as created. + for _, ident := range identities { + succeededIDs = append(succeededIDs, ident.ID) + } } return nil - }); err != nil { - return err - } - - for _, ident := range identities { - span.AddEvent(events.NewIdentityCreated(ctx, ident.ID)) - } - - return nil + }) } func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i *identity.Identity, expand identity.Expandables) (err error) {