Skip to content

Commit

Permalink
Start publishing events
Browse files Browse the repository at this point in the history
  • Loading branch information
neekolas committed Dec 31, 2024
1 parent 4231933 commit 8c36a8b
Show file tree
Hide file tree
Showing 7 changed files with 531 additions and 15 deletions.
7 changes: 7 additions & 0 deletions .mockery.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
with-expecter: true
packages:
github.com/xmtp/xmtp-node-go/pkg/proto/identity/api/v1:
config:
dir: ./pkg/mocks
outpkg: mocks
interfaces:
IdentityApi_SubscribeAssociationChangesServer:
config:
github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1:
config:
dir: ./pkg/mocks
Expand Down
2 changes: 1 addition & 1 deletion pkg/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func (s *Server) startGRPC() error {
}
mlsv1pb.RegisterMlsApiServer(grpcServer, s.mlsv1)

s.identityv1, err = identityv1.NewService(s.Log, s.Config.MLSStore, s.Config.MLSValidator, s.natsServer)
s.identityv1, err = identityv1.NewService(s.Log, s.Config.MLSStore, s.Config.MLSValidator, s.natsServer, publishToWakuRelay)
if err != nil {
return errors.Wrap(err, "creating identity service")
}
Expand Down
60 changes: 52 additions & 8 deletions pkg/identity/api/v1/identity_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import (

"github.com/nats-io/nats-server/v2/server"
"github.com/nats-io/nats.go"
wakupb "github.com/waku-org/go-waku/waku/v2/protocol/pb"
"github.com/xmtp/xmtp-node-go/pkg/envelopes"
identityTypes "github.com/xmtp/xmtp-node-go/pkg/identity/types"
mlsstore "github.com/xmtp/xmtp-node-go/pkg/mls/store"
"github.com/xmtp/xmtp-node-go/pkg/mlsvalidate"
api "github.com/xmtp/xmtp-node-go/pkg/proto/identity/api/v1"
Expand All @@ -24,16 +26,24 @@ type Service struct {
store mlsstore.MlsStore
validationService mlsvalidate.MLSValidationService

ctx context.Context
nc *nats.Conn
ctxCancel func()
ctx context.Context
nc *nats.Conn
publishToWakuRelay func(context.Context, *wakupb.WakuMessage) error
ctxCancel func()
}

func NewService(log *zap.Logger, store mlsstore.MlsStore, validationService mlsvalidate.MLSValidationService, natsServer *server.Server) (s *Service, err error) {
func NewService(
log *zap.Logger,
store mlsstore.MlsStore,
validationService mlsvalidate.MLSValidationService,
natsServer *server.Server,
publishToWakuRelay func(context.Context, *wakupb.WakuMessage) error,
) (s *Service, err error) {
s = &Service{
log: log.Named("identity"),
store: store,
validationService: validationService,
log: log.Named("identity"),
store: store,
validationService: validationService,
publishToWakuRelay: publishToWakuRelay,
}
s.ctx, s.ctxCancel = context.WithCancel(context.Background())

Expand Down Expand Up @@ -90,7 +100,17 @@ Start transaction (SERIALIZABLE isolation level)
End transaction
*/
func (s *Service) PublishIdentityUpdate(ctx context.Context, req *api.PublishIdentityUpdateRequest) (*api.PublishIdentityUpdateResponse, error) {
return s.store.PublishIdentityUpdate(ctx, req, s.validationService)
res, err := s.store.PublishIdentityUpdate(ctx, req, s.validationService)
if err != nil {
return nil, err
}

if err = s.PublishAssociationChangesEvent(ctx, res); err != nil {
s.log.Error("error publishing association changes event", zap.Error(err))
// Don't return the erro here because the transaction has already been committed
}

return &api.PublishIdentityUpdateResponse{}, nil
}

func (s *Service) GetIdentityUpdates(ctx context.Context, req *api.GetIdentityUpdatesRequest) (*api.GetIdentityUpdatesResponse, error) {
Expand Down Expand Up @@ -145,6 +165,30 @@ func (s *Service) SubscribeAssociationChanges(req *identity.SubscribeAssociation
return nil
}

func (s *Service) PublishAssociationChangesEvent(ctx context.Context, identityUpdateResult *identityTypes.PublishIdentityUpdateResult) error {
protoEvents := identityUpdateResult.GetChanges()
if len(protoEvents) == 0 {
return nil
}

for _, protoEvent := range protoEvents {
msgBytes, err := pb.Marshal(protoEvent)
if err != nil {
return err
}

if err = s.publishToWakuRelay(ctx, &wakupb.WakuMessage{
ContentTopic: topic.AssociationChangedTopic,
Timestamp: int64(identityUpdateResult.TimestampNs),
Payload: msgBytes,
}); err != nil {
return err
}
}

return nil
}

func buildNatsSubjectForAssociationChanges() string {
return envelopes.BuildNatsSubject(topic.BuildAssociationChangedTopic())
}
Expand Down
43 changes: 41 additions & 2 deletions pkg/identity/api/v1/identity_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,26 @@ package api

import (
"context"
"fmt"
"testing"
"time"

"github.com/nats-io/nats-server/v2/server"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/uptrace/bun"
wakupb "github.com/waku-org/go-waku/waku/v2/protocol/pb"
mlsstore "github.com/xmtp/xmtp-node-go/pkg/mls/store"
"github.com/xmtp/xmtp-node-go/pkg/mlsvalidate"
identity "github.com/xmtp/xmtp-node-go/pkg/proto/identity/api/v1"
associations "github.com/xmtp/xmtp-node-go/pkg/proto/identity/associations"
mlsv1 "github.com/xmtp/xmtp-node-go/pkg/proto/mls/api/v1"
test "github.com/xmtp/xmtp-node-go/pkg/testing"
pb "google.golang.org/protobuf/proto"
)

const INBOX_ID = "test_inbox"

type mockedMLSValidationService struct {
mock.Mock
}
Expand All @@ -38,7 +44,7 @@ func (m *mockedMLSValidationService) GetAssociationState(ctx context.Context, ol
new_members = append(new_members, &associations.MemberIdentifier{Kind: &associations.MemberIdentifier_Address{Address: "0x03"}})

out := mlsvalidate.AssociationStateResult{
AssociationState: &associations.AssociationState{InboxId: "test_inbox", Members: member_map, RecoveryAddress: "recovery", SeenSignatures: [][]byte{[]byte("seen"), []byte("sig")}},
AssociationState: &associations.AssociationState{InboxId: INBOX_ID, Members: member_map, RecoveryAddress: "recovery", SeenSignatures: [][]byte{[]byte("seen"), []byte("sig")}},
StateDiff: &associations.AssociationStateDiff{NewMembers: new_members, RemovedMembers: nil},
}
return &out, nil
Expand Down Expand Up @@ -78,9 +84,16 @@ func newTestService(t *testing.T, ctx context.Context) (*Service, *bun.DB, func(
Port: server.RANDOM_PORT,
})
require.NoError(t, err)
go natsServer.Start()
if !natsServer.ReadyForConnections(4 * time.Second) {
t.Fail()
}
mlsValidationService := newMockedValidationService()
publishToWakuRelay := func(ctx context.Context, msg *wakupb.WakuMessage) error {
return nil
}

svc, err := NewService(log, store, mlsValidationService, natsServer)
svc, err := NewService(log, store, mlsValidationService, natsServer, publishToWakuRelay)
require.NoError(t, err)

return svc, db, func() {
Expand Down Expand Up @@ -265,3 +278,29 @@ func TestInboxSizeLimit(t *testing.T) {
require.Equal(t, res.Responses[0].InboxId, inbox_id)
require.Len(t, res.Responses[0].Updates, 256)
}

func TestPublishAssociationChanges(t *testing.T) {
ctx := context.Background()
svc, _, cleanup := newTestService(t, ctx)
defer cleanup()

inboxId := test.RandomInboxId()
address := "test_address"

numPublishedToWaku := 0
svc.publishToWakuRelay = func(ctx context.Context, wakuMsg *wakupb.WakuMessage) error {
numPublishedToWaku++
var msg identity.SubscribeAssociationChangesResponse
err := pb.Unmarshal(wakuMsg.Payload, &msg)
require.NoError(t, err)
require.Equal(t, msg.GetAccountAddressAssociation().InboxId, inboxId)
require.Equal(t, msg.GetAccountAddressAssociation().AccountAddress, fmt.Sprintf("0x0%d", numPublishedToWaku))
return nil
}

_, err := svc.PublishIdentityUpdate(ctx, publishIdentityUpdateRequest(inboxId, makeCreateInbox(address)))
require.NoError(t, err)

// The mocked GetAssociationState always returns 3 new addresses
require.Equal(t, numPublishedToWaku, 3)
}
56 changes: 56 additions & 0 deletions pkg/identity/types/publish-result.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package types

import (
identity "github.com/xmtp/xmtp-node-go/pkg/proto/identity/api/v1"
)

type PublishIdentityUpdateResult struct {
InboxID string

NewAddresses []string
RevokedAddresses []string
NewInstallations [][]byte
RevokedInstallations [][]byte
TimestampNs uint64
}

func NewPublishIdentityUpdateResult(inboxID string, timestampNs uint64, newAddresses []string, revokedAddresses []string, newInstallations [][]byte, revokedInstallations [][]byte) *PublishIdentityUpdateResult {
return &PublishIdentityUpdateResult{
InboxID: inboxID,
TimestampNs: timestampNs,
NewAddresses: newAddresses,
RevokedAddresses: revokedAddresses,
NewInstallations: newInstallations,
RevokedInstallations: revokedInstallations,
}
}

func (p *PublishIdentityUpdateResult) GetChanges() []*identity.SubscribeAssociationChangesResponse {
out := make([]*identity.SubscribeAssociationChangesResponse, 0)

for _, newAddress := range p.NewAddresses {
out = append(out, &identity.SubscribeAssociationChangesResponse{
TimestampNs: p.TimestampNs,
Change: &identity.SubscribeAssociationChangesResponse_AccountAddressAssociation_{
AccountAddressAssociation: &identity.SubscribeAssociationChangesResponse_AccountAddressAssociation{
InboxId: p.InboxID,
AccountAddress: newAddress,
},
},
})
}

for _, revokedAddress := range p.RevokedAddresses {
out = append(out, &identity.SubscribeAssociationChangesResponse{
TimestampNs: p.TimestampNs,
Change: &identity.SubscribeAssociationChangesResponse_AccountAddressRevocation_{
AccountAddressRevocation: &identity.SubscribeAssociationChangesResponse_AccountAddressRevocation{
InboxId: p.InboxID,
AccountAddress: revokedAddress,
},
},
})
}

return out
}
28 changes: 24 additions & 4 deletions pkg/mls/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/uptrace/bun"
"github.com/uptrace/bun/migrate"
identityTypes "github.com/xmtp/xmtp-node-go/pkg/identity/types"
migrations "github.com/xmtp/xmtp-node-go/pkg/migrations/mls"
queries "github.com/xmtp/xmtp-node-go/pkg/mls/store/queries"
"github.com/xmtp/xmtp-node-go/pkg/mlsvalidate"
Expand All @@ -31,7 +32,7 @@ type Store struct {
}

type IdentityStore interface {
PublishIdentityUpdate(ctx context.Context, req *identity.PublishIdentityUpdateRequest, validationService mlsvalidate.MLSValidationService) (*identity.PublishIdentityUpdateResponse, error)
PublishIdentityUpdate(ctx context.Context, req *identity.PublishIdentityUpdateRequest, validationService mlsvalidate.MLSValidationService) (*identityTypes.PublishIdentityUpdateResult, error)
GetInboxLogs(ctx context.Context, req *identity.GetIdentityUpdatesRequest) (*identity.GetIdentityUpdatesResponse, error)
GetInboxIds(ctx context.Context, req *identity.GetInboxIdsRequest) (*identity.GetInboxIdsResponse, error)
}
Expand Down Expand Up @@ -99,13 +100,21 @@ func (s *Store) GetInboxIds(ctx context.Context, req *identity.GetInboxIdsReques
}, nil
}

func (s *Store) PublishIdentityUpdate(ctx context.Context, req *identity.PublishIdentityUpdateRequest, validationService mlsvalidate.MLSValidationService) (*identity.PublishIdentityUpdateResponse, error) {
func (s *Store) PublishIdentityUpdate(ctx context.Context, req *identity.PublishIdentityUpdateRequest, validationService mlsvalidate.MLSValidationService) (*identityTypes.PublishIdentityUpdateResult, error) {
newUpdate := req.GetIdentityUpdate()
if newUpdate == nil {
return nil, errors.New("IdentityUpdate is required")
}

now := nowNs()
var newAccountAddresses []string
var revokedAccountAddresses []string

if err := s.RunInRepeatableReadTx(ctx, 3, func(ctx context.Context, txQueries *queries.Queries) error {
// Reset these lists to allow for safe retries of the TX
newAccountAddresses = make([]string, 0)
revokedAccountAddresses = make([]string, 0)

inboxId := newUpdate.GetInboxId()
// We use a pg_advisory_lock to lock the inbox_id instead of SELECT FOR UPDATE
// This allows the lock to be enforced even when there are no existing `inbox_log`s
Expand Down Expand Up @@ -144,7 +153,7 @@ func (s *Store) PublishIdentityUpdate(ctx context.Context, req *identity.Publish

sequence_id, err := txQueries.InsertInboxLog(ctx, queries.InsertInboxLogParams{
InboxID: inboxId,
ServerTimestampNs: nowNs(),
ServerTimestampNs: now,
IdentityUpdateProto: protoBytes,
})

Expand All @@ -157,6 +166,7 @@ func (s *Store) PublishIdentityUpdate(ctx context.Context, req *identity.Publish
for _, new_member := range state.StateDiff.NewMembers {
log.Info("New member", zap.Any("member", new_member))
if address, ok := new_member.Kind.(*associations.MemberIdentifier_Address); ok {
newAccountAddresses = append(newAccountAddresses, address.Address)
_, err = txQueries.InsertAddressLog(ctx, queries.InsertAddressLogParams{
Address: address.Address,
InboxID: inboxId,
Expand All @@ -172,6 +182,7 @@ func (s *Store) PublishIdentityUpdate(ctx context.Context, req *identity.Publish
for _, removed_member := range state.StateDiff.RemovedMembers {
log.Info("Removed member", zap.Any("member", removed_member))
if address, ok := removed_member.Kind.(*associations.MemberIdentifier_Address); ok {
revokedAccountAddresses = append(revokedAccountAddresses, address.Address)
err = txQueries.RevokeAddressFromLog(ctx, queries.RevokeAddressFromLogParams{
Address: address.Address,
InboxID: inboxId,
Expand All @@ -188,7 +199,16 @@ func (s *Store) PublishIdentityUpdate(ctx context.Context, req *identity.Publish
return nil, err
}

return &identity.PublishIdentityUpdateResponse{}, nil
result := identityTypes.NewPublishIdentityUpdateResult(
req.IdentityUpdate.InboxId,
uint64(now),
newAccountAddresses,
revokedAccountAddresses,
[][]byte{}, // TODO: Handle installations added
[][]byte{}, // TODO: Handle installations revoked
)

return result, nil
}

func (s *Store) GetInboxLogs(ctx context.Context, batched_req *identity.GetIdentityUpdatesRequest) (*identity.GetIdentityUpdatesResponse, error) {
Expand Down
Loading

0 comments on commit 8c36a8b

Please sign in to comment.