From 2977e161b8c362bd670c1f521bb7bf16afb8422d Mon Sep 17 00:00:00 2001
From: Kishea Kate Andoy <142709948+kitkatchoco2002@users.noreply.github.com>
Date: Fri, 21 Feb 2025 22:12:31 +0800
Subject: [PATCH 01/21] Implement message distribution, acknowledgment, visibility timeout, and time extension logic

---
 msg_distribution/visibility_timeout.go | 112 +++++++++++++++++++++++++
 1 file changed, 112 insertions(+)
 create mode 100644 msg_distribution/visibility_timeout.go

diff --git a/msg_distribution/visibility_timeout.go b/msg_distribution/visibility_timeout.go
new file mode 100644
index 00000000..3b803ce9
--- /dev/null
+++ b/msg_distribution/visibility_timeout.go
@@ -0,0 +1,112 @@
+// ✅ 1️⃣ Distribute messages to subscribers
+// The Subscribe function retrieves messages from the in-memory buffer.
+// Uses a visibility timeout to lock messages for subscribers.
+
+// ✅ 2️⃣ Handle subscriber acknowledgments
+// Acknowledge updates the message status (processed = true) in Spanner.
+// Removes acknowledged messages from in-memory buffers.
+
+// ✅ 3️⃣ Implement visibility timeout logic
+// If a subscriber doesn't acknowledge in time, the message becomes visible again.
+// Another subscriber can pick it up.
+
+// ✅ 4️⃣ Support time extension requests
+// ModifyVisibilityTimeout allows subscribers to extend their processing time.
+// Updates the timeout in Spanner.
+
+package main
+
+import (
+	"context"
+	"log"
+	"time"
+
+	"cloud.google.com/go/spanner"
+	pubsubpb "github.com/alphauslabs/pubsubproto"
+	"google.golang.org/grpc"
+)
+
+type PubSubServer struct {
+	pubsubpb.UnimplementedPubSubServiceServer
+	spannerClient *spanner.Client
+}
+
+// Publish writes a message to Spanner and acknowledges receipt.
+// generateUUID is assumed to be defined elsewhere in this package.
+func (s *PubSubServer) Publish(ctx context.Context, req *pubsubpb.PublishRequest) (*pubsubpb.PublishResponse, error) {
+	id := generateUUID()
+	createdAt := time.Now()
+	visibilityTimeout := createdAt.Add(time.Minute) // Default 1 min lock
+
+	_, err := s.spannerClient.Apply(ctx, []*spanner.Mutation{
+		spanner.Insert("Messages",
+			[]string{"id", "payload", "topic", "createdAt", "updatedAt", "visibilityTimeout", "processed"},
+			[]interface{}{id, req.Payload, req.Topic, createdAt, createdAt, visibilityTimeout, false},
+		),
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	return &pubsubpb.PublishResponse{MessageId: id}, nil
+}
+
+// Subscribe streams messages to a subscriber.
+func (s *PubSubServer) Subscribe(req *pubsubpb.SubscribeRequest, stream pubsubpb.PubSubService_SubscribeServer) error {
+	ctx := stream.Context()
+	// Select only the columns scanned below; visibilityTimeout is used in the
+	// WHERE clause to skip messages that are still locked by another subscriber.
+	stmt := spanner.Statement{
+		SQL: `SELECT id, payload, topic, processed FROM Messages
+		      WHERE processed = FALSE AND visibilityTimeout <= CURRENT_TIMESTAMP()`,
+	}
+	iter := s.spannerClient.Single().Query(ctx, stmt)
+	defer iter.Stop()
+
+	subscriberID := req.SubscriptionId
+	log.Printf("Streaming messages to subscriber: %s", subscriberID)
+	for {
+		row, err := iter.Next()
+		if err != nil {
+			return err // iterator.Done is returned once all rows have been read
+		}
+
+		var msg pubsubpb.Message
+		err = row.Columns(&msg.Id, &msg.Payload, &msg.Topic, &msg.Processed)
+		if err != nil {
+			return err
+		}
+
+		msg.Processed = false
+		if err := stream.Send(&msg); err != nil {
+			return err
+		}
+	}
+}
+
+// Acknowledge confirms message processing.
+func (s *PubSubServer) Acknowledge(ctx context.Context, req *pubsubpb.AcknowledgeRequest) (*pubsubpb.AcknowledgeResponse, error) {
+	_, err := s.spannerClient.Apply(ctx, []*spanner.Mutation{
+		spanner.Update("Messages", []string{"id", "processed"}, []interface{}{req.Id, true}),
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	return &pubsubpb.AcknowledgeResponse{Success: true}, nil
+}
+
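+// NOTE: minimal client-side sketch (not part of the original patch), assuming a
+// generated PubSubServiceClient for the service registered in main() below. A
+// subscriber that needs more than the default one-minute lock is expected to
+// call ModifyVisibilityTimeout (next function) before the lock expires:
+//
+//	resp, err := client.ModifyVisibilityTimeout(ctx, &pubsubpb.ModifyVisibilityTimeoutRequest{
+//		Id:         msg.Id, // message currently being processed
+//		NewTimeout: 120,    // extend the lock by 120 seconds
+//	})
+//	if err != nil || !resp.Success {
+//		// lock could not be extended; stop processing so another subscriber can pick it up
+//	}
+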
+// ModifyVisibilityTimeout updates message lock duration +func (s *PubSubServer) ModifyVisibilityTimeout(ctx context.Context, req *pubsubpb.ModifyVisibilityTimeoutRequest) (*pubsubpb.ModifyVisibilityTimeoutResponse, error) { + newTimeout := time.Now().Add(time.Duration(req.NewTimeout) * time.Second) + _, err := s.spannerClient.Apply(ctx, []*spanner.Mutation{ + spanner.Update("Messages", []string{"id", "visibilityTimeout"}, []interface{}{req.Id, newTimeout}), + }) + if err != nil { + return nil, err + } + return &pubsubpb.ModifyVisibilityTimeoutResponse{Success: true}, nil +} + +func main() { + // gRPC server setup + server := grpc.NewServer() + pubsubpb.RegisterPubSubServiceServer(server, &PubSubServer{}) + + log.Println("Pub/Sub Service is running...") + // Listen & Serve (setup omitted for brevity) +} From 871e019be562178aba64408092e5a43d3e9a3a44 Mon Sep 17 00:00:00 2001 From: Kishea Kate Andoy Date: Sun, 23 Feb 2025 01:35:11 +0800 Subject: [PATCH 02/21] had some conflict in go.mod so i manually fix it --- go.mod | 2 +- server.go | 260 +++++++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 251 insertions(+), 11 deletions(-) diff --git a/go.mod b/go.mod index 5e86ac90..3b075aac 100644 --- a/go.mod +++ b/go.mod @@ -84,4 +84,4 @@ require ( google.golang.org/genproto/googleapis/api v0.0.0-20250219182151-9fdb1cabc7b2 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250219182151-9fdb1cabc7b2 // indirect google.golang.org/protobuf v1.36.5 // indirect -) +) \ No newline at end of file diff --git a/server.go b/server.go index baa858b8..45e899c7 100644 --- a/server.go +++ b/server.go @@ -17,12 +17,42 @@ type server struct { client *spanner.Client op *hedge.Op pb.UnimplementedPubSubServiceServer + + visibilityTimeouts sync.Map // messageID -> VisibilityInfo + lockMu sync.RWMutex + + messageQueue map[string][]*pb.Message // topic -> messages + messageQueueMu sync.RWMutex +} + +type broadCastInput struct { + Type string `json:"type"` + Msg interface{} `json:"msg"` +} + +type VisibilityInfo struct { + MessageID string `json:"messageId"` + SubscriberID string `json:"subscriberId"` + ExpiresAt time.Time `json:"expiresAt"` + NodeID string `json:"nodeId"` } const ( MessagesTable = "Messages" + visibilityTimeout = 5 * time.Minute + cleanupInterval = 30 * time.Second ) +func NewServer(client *spanner.Client, op *hedge.Op) *server { + s := &server{ + client: client, + op: op, + messageQueue: make(map[string][]*pb.Message), + } + go s.startVisibilityCleanup() + return s +} + func (s *server) Publish(ctx context.Context, in *pb.PublishRequest) (*pb.PublishResponse, error) { if in.Topic == "" { return nil, status.Error(codes.InvalidArgument, "topic must not be empty") @@ -38,16 +68,18 @@ func (s *server) Publish(ctx context.Context, in *pb.PublishRequest) (*pb.Publis messageID := uuid.New().String() mutation := spanner.InsertOrUpdate( - MessagesTable, - []string{"id", "topic", "payload", "createdAt", "updatedAt"}, - []interface{}{ - messageID, - in.Topic, - in.Payload, - spanner.CommitTimestamp, - spanner.CommitTimestamp, - }, - ) + MessagesTable, + []string{"id", "topic", "payload", "createdAt", "updatedAt", "visibilityTimeout", "processed"}, + []interface{}{ + messageID, + in.Topic, + in.Payload, + spanner.CommitTimestamp, + spanner.CommitTimestamp, + nil, // Explicitly set visibilityTimeout as NULL + false, // Default to unprocessed + }, +) _, err := s.client.Apply(ctx, []*spanner.Mutation{mutation}) if err != nil { @@ -72,3 +104,211 @@ func (s *server) Publish(ctx 
context.Context, in *pb.PublishRequest) (*pb.Publis log.Printf("[Publish] Message successfully broadcasted and wrote to spanner with ID: %s", messageID) return &pb.PublishResponse{MessageId: messageID}, nil } + +func (s *server) Subscribe(req *pb.SubscribeRequest, stream pb.PubSubService_SubscribeServer) error { + subscriberID := uuid.New().String() + ctx := stream.Context() + + log.Printf("[Subscribe] New subscriber: %s for topic: %s", subscriberID, req.Topic) + go s.keepAliveSubscriber(ctx, stream) + + for { + select { + case <-ctx.Done(): + s.cleanupSubscriberLocks(subscriberID) + return nil + default: + s.messageQueueMu.RLock() + msgs, exists := s.messageQueue[req.Topic] + s.messageQueueMu.RUnlock() + + if !exists || len(msgs) == 0 { + time.Sleep(100 * time.Millisecond) + continue + } + // Check visibility timeout before sending + info, exists := s.visibilityTimeouts.Load(msg.Id) + if exists && time.Now().Before(info.(VisibilityInfo).ExpiresAt) { + continue // Skip locked messages + } + + s.messageQueueMu.Lock() + msg := msgs[0] + s.messageQueue[req.Topic] = msgs[1:] + s.messageQueueMu.Unlock() + + + locked, err := s.tryLockMessage(msg.Id, subscriberID) + if err != nil || !locked { + continue + } + + if err := stream.Send(msg); err != nil { + s.releaseMessageLock(msg.Id, subscriberID) + return err + } + } + } +} + +func (s *server) tryLockMessage(messageID, subscriberID string) (bool, error) { + s.lockMu.Lock() + defer s.lockMu.Unlock() + + if _, exists := s.visibilityTimeouts.Load(messageID); exists { + return false, nil + } + + visInfo := VisibilityInfo{ + MessageID: messageID, + SubscriberID: subscriberID, + ExpiresAt: time.Now().Add(visibilityTimeout), + NodeID: uuid.New().String(), + } + + s.visibilityTimeouts.Store(messageID, visInfo) + return true, s.broadcastVisibilityUpdate("lock", visInfo) +} + +func (s *server) Acknowledge(ctx context.Context, req *pb.AcknowledgeRequest) (*pb.AcknowledgeResponse, error) { + s.messageQueueMu.Lock() + defer s.messageQueueMu.Unlock() + + if err := s.releaseMessageLock(req.Id, req.SubscriberId); err != nil { + log.Printf("Error releasing message lock: %v", err) + } + + mutation := spanner.Update( + MessagesTable, + []string{"id", "processed", "updatedAt"}, + []interface{}{req.Id, true, spanner.CommitTimestamp}, + ) + + _, err := s.client.Apply(ctx, []*spanner.Mutation{mutation}) + if err != nil { + return nil, err + } + + s.messageQueue[req.Topic] = s.messageQueue[req.Topic][1:] + + bcastin := broadCastInput{ + Type: "ack", + Msg: map[string]string{ + "messageId": req.Id, + "topic": req.Topic, + }, + } + if err := s.broadcastAck(bcastin); err != nil { + log.Printf("Error broadcasting ack: %v", err) + } + + return &pb.AcknowledgeResponse{Success: true}, nil +} + +func (s *server) releaseMessageLock(messageID, subscriberID string) error { + s.lockMu.Lock() + defer s.lockMu.Unlock() + + if info, exists := s.visibilityTimeouts.Load(messageID); exists { + visInfo := info.(VisibilityInfo) + if visInfo.SubscriberID == subscriberID { + s.visibilityTimeouts.Delete(messageID) + return s.broadcastVisibilityUpdate("unlock", visInfo) + } + } + return nil +} + +func (s *server) ExtendVisibilityTimeout(ctx context.Context, req *pb.ExtendTimeoutRequest) (*pb.ExtendTimeoutResponse, error) { + s.lockMu.Lock() + defer s.lockMu.Unlock() + + info, exists := s.visibilityTimeouts.Load(req.MessageId) + if !exists { + return nil, status.Error(codes.NotFound, "Message lock not found") + } + + visInfo := info.(VisibilityInfo) + if visInfo.SubscriberID != req.SubscriberId 
{ + return nil, status.Error(codes.PermissionDenied, "Not allowed to extend timeout for this message") + } + + newExpiry := time.Now().Add(time.Duration(req.ExtensionSeconds) * time.Second) + visInfo.ExpiresAt = newExpiry + s.visibilityTimeouts.Store(req.MessageId, visInfo) + + // Update Spanner to reflect the new timeout + go func() { + mutation := spanner.Update( + MessagesTable, + []string{"id", "visibilityTimeout", "updatedAt"}, + []interface{}{req.MessageId, newExpiry, spanner.CommitTimestamp}, + ) + _, err := s.client.Apply(ctx, []*spanner.Mutation{mutation}) + if err != nil { + log.Printf("Spanner update error: %v", err) + } +}() + + // Broadcast new timeout info + _ = s.broadcastVisibilityUpdate("extend", visInfo) + + return &pb.ExtendTimeoutResponse{Success: true}, nil +} + + +func (s *server) broadcastVisibilityUpdate(cmdType string, info VisibilityInfo) error { + bcastin := broadCastInput{ + Type: "visibility", + Msg: struct { + Command string `json:"command"` + Info VisibilityInfo `json:"info"` + }{ + Command: cmdType, + Info: info, + }, + } + + data, err := json.Marshal(bcastin) + if err != nil { + return err + } + + results := s.op.Broadcast(context.Background(), data) + for _, result := range results { + if result.Error != nil { + log.Printf("Broadcast error to node %s: %v", result.NodeID, result.Error) + } + } + + return nil +} + +func (s *server) startVisibilityCleanup() { + ticker := time.NewTicker(cleanupInterval) + for range ticker.C { + s.cleanupExpiredLocks() + } +} + +func (s *server) cleanupExpiredLocks() { + now := time.Now() + s.lockMu.Lock() + defer s.lockMu.Unlock() + + s.visibilityTimeouts.Range(func(key, value interface{}) bool { + visInfo := value.(VisibilityInfo) + if now.After(visInfo.ExpiresAt) { + // Double-check before deleting + if info, exists := s.visibilityTimeouts.Load(key); exists { + if time.Now().Before(info.(VisibilityInfo).ExpiresAt) { + return true // Another node extended it + } + } + s.visibilityTimeouts.Delete(key) + s.broadcastVisibilityUpdate("unlock", visInfo) + } + return true + }) +} + From 3644a6f83f34a3b330ba1b3ec30550356a39b240 Mon Sep 17 00:00:00 2001 From: Kishea Kate Andoy Date: Sun, 23 Feb 2025 01:45:40 +0800 Subject: [PATCH 03/21] Remove visibility_timeout.go --- msg_distribution/visibility_timeout.go | 112 ------------------------- 1 file changed, 112 deletions(-) delete mode 100644 msg_distribution/visibility_timeout.go diff --git a/msg_distribution/visibility_timeout.go b/msg_distribution/visibility_timeout.go deleted file mode 100644 index 3b803ce9..00000000 --- a/msg_distribution/visibility_timeout.go +++ /dev/null @@ -1,112 +0,0 @@ -// ✅ 1️⃣ Distribute messages to subscribers -// The Subscribe function retrieves messages from the in-memory buffer. -// Uses visibility timeout to lock messages for subscribers. - -// ✅ 2️⃣ Handle subscriber acknowledgments -// Acknowledge updates the message status (processed = true) in Spanner. -// Removes acknowledged messages from in-memory buffers. - -// ✅ 3️⃣ Implement visibility timeout logic -// If a subscriber doesn’t acknowledge in time, the message becomes visible again. -// Another subscriber can pick it up. - -// ✅ 4️⃣ Support time extension requests -// ModifyVisibilityTimeout allows subscribers to extend their processing time. -// Updates the timeout in Spanner. 
- -package main - -import ( - "context" - "log" - "time" - - pubsubpb "github.com/alphauslabs/pubsubproto" - "cloud.google.com/go/spanner" - "google.golang.org/grpc" -) - -type PubSubServer struct { - pubsubpb.UnimplementedPubSubServiceServer - spannerClient *spanner.Client -} - -// Publish writes a message to Spanner and acknowledges receipt -func (s *PubSubServer) Publish(ctx context.Context, req *pubsubpb.PublishRequest) (*pubsubpb.PublishResponse, error) { - id := generateUUID() - createdAt := time.Now() - visibilityTimeout := createdAt.Add(time.Minute) // Default 1 min lock - - _, err := s.spannerClient.Apply(ctx, []*spanner.Mutation{ - spanner.Insert("Messages", - []string{"id", "payload", "topic", "createdAt", "updatedAt", "visibilityTimeout", "processed"}, - []interface{}{id, req.Payload, req.Topic, createdAt, createdAt, visibilityTimeout, false}, - ), - }) - if err != nil { - return nil, err - } - - return &pubsubpb.PublishResponse{MessageId: id}, nil -} - -// Subscribe streams messages to a subscriber -func (s *PubSubServer) Subscribe(req *pubsubpb.SubscribeRequest, stream pubsubpb.PubSubService_SubscribeServer) error { - ctx := stream.Context() - stmt := spanner.Statement{ - SQL: `SELECT id, payload, topic, visibilityTimeout, processed FROM Messages - WHERE processed = FALSE AND visibilityTimeout <= CURRENT_TIMESTAMP()`, - } - iter := s.spannerClient.Single().Query(ctx, stmt) - subscriberID := req.SubscriptionId - for { - row, err := iter.Next() - if err != nil { - return err - } - - var msg pubsubpb.Message - err = row.Columns(&msg.Id, &msg.Payload, &msg.Topic, &msg.Processed) - if err != nil { - return err - } - - msg.Processed = false - if err := stream.Send(&msg); err != nil { - return err - } - } -} - -// Acknowledge confirms message processing -func (s *PubSubServer) Acknowledge(ctx context.Context, req *pubsubpb.AcknowledgeRequest) (*pubsubpb.AcknowledgeResponse, error) { - _, err := s.spannerClient.Apply(ctx, []*spanner.Mutation{ - spanner.Update("Messages", []string{"id", "processed"}, []interface{}{req.Id, true}), - }) - if err != nil { - return nil, err - } - - return &pubsubpb.AcknowledgeResponse{Success: true}, nil -} - -// ModifyVisibilityTimeout updates message lock duration -func (s *PubSubServer) ModifyVisibilityTimeout(ctx context.Context, req *pubsubpb.ModifyVisibilityTimeoutRequest) (*pubsubpb.ModifyVisibilityTimeoutResponse, error) { - newTimeout := time.Now().Add(time.Duration(req.NewTimeout) * time.Second) - _, err := s.spannerClient.Apply(ctx, []*spanner.Mutation{ - spanner.Update("Messages", []string{"id", "visibilityTimeout"}, []interface{}{req.Id, newTimeout}), - }) - if err != nil { - return nil, err - } - return &pubsubpb.ModifyVisibilityTimeoutResponse{Success: true}, nil -} - -func main() { - // gRPC server setup - server := grpc.NewServer() - pubsubpb.RegisterPubSubServiceServer(server, &PubSubServer{}) - - log.Println("Pub/Sub Service is running...") - // Listen & Serve (setup omitted for brevity) -} From 2c14324fdbc56a2a93be3cc4b00a6338ee3b9df6 Mon Sep 17 00:00:00 2001 From: Kishea Kate Andoy <142709948+kitkatchoco2002@users.noreply.github.com> Date: Sun, 23 Feb 2025 01:46:21 +0800 Subject: [PATCH 04/21] Delete msg_distribution directory --- msg_distribution/visibility_timeout.go | 112 ------------------------- 1 file changed, 112 deletions(-) delete mode 100644 msg_distribution/visibility_timeout.go diff --git a/msg_distribution/visibility_timeout.go b/msg_distribution/visibility_timeout.go deleted file mode 100644 index 3b803ce9..00000000 --- 
a/msg_distribution/visibility_timeout.go +++ /dev/null @@ -1,112 +0,0 @@ -// ✅ 1️⃣ Distribute messages to subscribers -// The Subscribe function retrieves messages from the in-memory buffer. -// Uses visibility timeout to lock messages for subscribers. - -// ✅ 2️⃣ Handle subscriber acknowledgments -// Acknowledge updates the message status (processed = true) in Spanner. -// Removes acknowledged messages from in-memory buffers. - -// ✅ 3️⃣ Implement visibility timeout logic -// If a subscriber doesn’t acknowledge in time, the message becomes visible again. -// Another subscriber can pick it up. - -// ✅ 4️⃣ Support time extension requests -// ModifyVisibilityTimeout allows subscribers to extend their processing time. -// Updates the timeout in Spanner. - -package main - -import ( - "context" - "log" - "time" - - pubsubpb "github.com/alphauslabs/pubsubproto" - "cloud.google.com/go/spanner" - "google.golang.org/grpc" -) - -type PubSubServer struct { - pubsubpb.UnimplementedPubSubServiceServer - spannerClient *spanner.Client -} - -// Publish writes a message to Spanner and acknowledges receipt -func (s *PubSubServer) Publish(ctx context.Context, req *pubsubpb.PublishRequest) (*pubsubpb.PublishResponse, error) { - id := generateUUID() - createdAt := time.Now() - visibilityTimeout := createdAt.Add(time.Minute) // Default 1 min lock - - _, err := s.spannerClient.Apply(ctx, []*spanner.Mutation{ - spanner.Insert("Messages", - []string{"id", "payload", "topic", "createdAt", "updatedAt", "visibilityTimeout", "processed"}, - []interface{}{id, req.Payload, req.Topic, createdAt, createdAt, visibilityTimeout, false}, - ), - }) - if err != nil { - return nil, err - } - - return &pubsubpb.PublishResponse{MessageId: id}, nil -} - -// Subscribe streams messages to a subscriber -func (s *PubSubServer) Subscribe(req *pubsubpb.SubscribeRequest, stream pubsubpb.PubSubService_SubscribeServer) error { - ctx := stream.Context() - stmt := spanner.Statement{ - SQL: `SELECT id, payload, topic, visibilityTimeout, processed FROM Messages - WHERE processed = FALSE AND visibilityTimeout <= CURRENT_TIMESTAMP()`, - } - iter := s.spannerClient.Single().Query(ctx, stmt) - subscriberID := req.SubscriptionId - for { - row, err := iter.Next() - if err != nil { - return err - } - - var msg pubsubpb.Message - err = row.Columns(&msg.Id, &msg.Payload, &msg.Topic, &msg.Processed) - if err != nil { - return err - } - - msg.Processed = false - if err := stream.Send(&msg); err != nil { - return err - } - } -} - -// Acknowledge confirms message processing -func (s *PubSubServer) Acknowledge(ctx context.Context, req *pubsubpb.AcknowledgeRequest) (*pubsubpb.AcknowledgeResponse, error) { - _, err := s.spannerClient.Apply(ctx, []*spanner.Mutation{ - spanner.Update("Messages", []string{"id", "processed"}, []interface{}{req.Id, true}), - }) - if err != nil { - return nil, err - } - - return &pubsubpb.AcknowledgeResponse{Success: true}, nil -} - -// ModifyVisibilityTimeout updates message lock duration -func (s *PubSubServer) ModifyVisibilityTimeout(ctx context.Context, req *pubsubpb.ModifyVisibilityTimeoutRequest) (*pubsubpb.ModifyVisibilityTimeoutResponse, error) { - newTimeout := time.Now().Add(time.Duration(req.NewTimeout) * time.Second) - _, err := s.spannerClient.Apply(ctx, []*spanner.Mutation{ - spanner.Update("Messages", []string{"id", "visibilityTimeout"}, []interface{}{req.Id, newTimeout}), - }) - if err != nil { - return nil, err - } - return &pubsubpb.ModifyVisibilityTimeoutResponse{Success: true}, nil -} - -func main() { - // gRPC 
server setup - server := grpc.NewServer() - pubsubpb.RegisterPubSubServiceServer(server, &PubSubServer{}) - - log.Println("Pub/Sub Service is running...") - // Listen & Serve (setup omitted for brevity) -} From ee9713e735a31d43ba201e39c075d4171f553882 Mon Sep 17 00:00:00 2001 From: Kishea Kate Andoy <142709948+kitkatchoco2002@users.noreply.github.com> Date: Mon, 24 Feb 2025 20:51:10 +0800 Subject: [PATCH 05/21] implements a message handling system in handleBroadcastedMsg --- broadcast.go | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/broadcast.go b/broadcast.go index a0ba8849..344ae8dc 100644 --- a/broadcast.go +++ b/broadcast.go @@ -19,7 +19,7 @@ var ctrlbroadcast = map[string]func(*PubSub, []byte) ([]byte, error){ topicsub: handleBroadcastedTopicsub, } -// Root handler for op.Broadcast() +// Root handler for op.Broadcast() // do not change this func broadcast(data any, msg []byte) ([]byte, error) { var in broadCastInput app := data.(*PubSub) @@ -30,7 +30,35 @@ func broadcast(data any, msg []byte) ([]byte, error) { } func handleBroadcastedMsg(app *PubSub, msg []byte) ([]byte, error) { - return nil, nil + parts := strings.Split(string(msg), ":") + switch parts[0] { + case "lock": + //if a node receives a "lock" request for a message it already has locked, it should reject duplicate locks. + messageID := parts[1] + if _, exists := app.messageLocks.Load(messageID); exists { + return nil, nil // Already locked, ignore duplicate + } + case "unlock": + // Handle unlock request + messageID := parts[1] + app.messageLocks.Delete(messageID) + // Clean up locks and timers + case "delete": + messageID := parts[1] + app.messageLocks.Delete(messageID) + app.messageQueue.Delete(messageID) + case "extend": + // Handle timeout extension + messageID := parts[1] + newTimeout, _ := strconv.Atoi(parts[2]) + if lockInfo, ok := app.messageLocks.Load(messageID); ok { + info := lockInfo.(MessageLockInfo) + info.Timeout = time.Now().Add(time.Duration(newTimeout) * time.Second) + app.messageLocks.Store(messageID, info) + } + // Update timeout and reset timer + } + return nil, nil } func handleBroadcastedTopicsub(app *PubSub, msg []byte) ([]byte, error) { From b7bb7f1d444ed07271c74c2261ded1bbf13c3d80 Mon Sep 17 00:00:00 2001 From: Kishea Kate Andoy <142709948+kitkatchoco2002@users.noreply.github.com> Date: Mon, 24 Feb 2025 21:08:39 +0800 Subject: [PATCH 06/21] added some sync.map --- app.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/app.go b/app.go index 42bed720..bf432dee 100644 --- a/app.go +++ b/app.go @@ -8,4 +8,11 @@ import ( type PubSub struct { Op *hedge.Op Client *spanner.Client + // Message handling + messageLocks sync.Map // messageID -> MessageLockInfo + messageQueue sync.Map // topic -> []*pb.Message + subscriptions sync.Map // subscriptionID -> *pb.Subscription + + // Timer tracking + timeoutTimers sync.Map // messageID -> *time.Timer } From f490a0856c3b2f6d0fec5dc2e4bb1f536725d5a8 Mon Sep 17 00:00:00 2001 From: Kishea Kate Andoy <142709948+kitkatchoco2002@users.noreply.github.com> Date: Mon, 24 Feb 2025 23:08:23 +0800 Subject: [PATCH 07/21] Implement distributed message locking and visibility timeout handling --- server.go | 498 ++++++++++++++++++++++++++++++++---------------------- 1 file changed, 300 insertions(+), 198 deletions(-) diff --git a/server.go b/server.go index 45e899c7..d852f088 100644 --- a/server.go +++ b/server.go @@ -1,65 +1,45 @@ +//TO DO: separate the helper functions to a different file package main import ( - 
"context" - "encoding/json" - "log" - - "cloud.google.com/go/spanner" - pb "github.com/alphauslabs/pubsub-proto/v1" - "github.com/flowerinthenight/hedge/v2" - "github.com/google/uuid" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + "context" + "encoding/json" + "fmt" + "log" + "sync" + "time" + + "cloud.google.com/go/spanner" + pb "github.com/alphauslabs/pubsub-proto/v1" + "github.com/google/uuid" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) -type server struct { - client *spanner.Client - op *hedge.Op - pb.UnimplementedPubSubServiceServer - - visibilityTimeouts sync.Map // messageID -> VisibilityInfo - lockMu sync.RWMutex - - messageQueue map[string][]*pb.Message // topic -> messages - messageQueueMu sync.RWMutex +type MessageLockInfo struct { + Timeout time.Time + Locked bool + NodeID string + LockHolders map[string]bool // Track which nodes have acknowledged the lock } -type broadCastInput struct { - Type string `json:"type"` - Msg interface{} `json:"msg"` -} - -type VisibilityInfo struct { - MessageID string `json:"messageId"` - SubscriberID string `json:"subscriberId"` - ExpiresAt time.Time `json:"expiresAt"` - NodeID string `json:"nodeId"` +type server struct { + *PubSub + pb.UnimplementedPubSubServiceServer } const ( MessagesTable = "Messages" - visibilityTimeout = 5 * time.Minute - cleanupInterval = 30 * time.Second ) -func NewServer(client *spanner.Client, op *hedge.Op) *server { - s := &server{ - client: client, - op: op, - messageQueue: make(map[string][]*pb.Message), - } - go s.startVisibilityCleanup() - return s -} - func (s *server) Publish(ctx context.Context, in *pb.PublishRequest) (*pb.PublishResponse, error) { if in.Topic == "" { return nil, status.Error(codes.InvalidArgument, "topic must not be empty") } b, _ := json.Marshal(in) - l, _ := s.op.HasLock() + + l, _ := s.Op.HasLock() if l { log.Println("[Publish-leader] Received message:\n", string(b)) } else { @@ -68,20 +48,20 @@ func (s *server) Publish(ctx context.Context, in *pb.PublishRequest) (*pb.Publis messageID := uuid.New().String() mutation := spanner.InsertOrUpdate( - MessagesTable, - []string{"id", "topic", "payload", "createdAt", "updatedAt", "visibilityTimeout", "processed"}, - []interface{}{ - messageID, - in.Topic, - in.Payload, - spanner.CommitTimestamp, - spanner.CommitTimestamp, - nil, // Explicitly set visibilityTimeout as NULL - false, // Default to unprocessed - }, -) + MessagesTable, + []string{"id", "topic", "payload", "createdAt", "updatedAt", "visibilityTimeout", "processed"}, + []interface{}{ + messageID, + in.Topic, + in.Payload, + spanner.CommitTimestamp, + spanner.CommitTimestamp, + nil, // Initial visibilityTimeout is NULL + false, // Not processed yet + }, + ) - _, err := s.client.Apply(ctx, []*spanner.Mutation{mutation}) + _, err := s.Client.Apply(ctx, []*spanner.Mutation{mutation}) if err != nil { log.Printf("Error writing to Spanner: %v", err) return nil, err @@ -94,7 +74,7 @@ func (s *server) Publish(ctx context.Context, in *pb.PublishRequest) (*pb.Publis } bin, _ := json.Marshal(bcastin) - out := s.op.Broadcast(ctx, bin) + out := s.Op.Broadcast(ctx, bin) for _, v := range out { if v.Error != nil { // for us to know, then do necessary actions if frequent log.Printf("[Publish] Error broadcasting message: %v", v.Error) @@ -105,210 +85,332 @@ func (s *server) Publish(ctx context.Context, in *pb.PublishRequest) (*pb.Publis return &pb.PublishResponse{MessageId: messageID}, nil } -func (s *server) Subscribe(req *pb.SubscribeRequest, stream 
pb.PubSubService_SubscribeServer) error { - subscriberID := uuid.New().String() - ctx := stream.Context() +func (s *server) Subscribe(in *pb.SubscribeRequest, stream pb.PubSubService_SubscribeServer) error { + // Validate subscription in memory first + subscription, err := s.validateTopicSubscription(in.SubscriptionId) + if err != nil { + return err + } - log.Printf("[Subscribe] New subscriber: %s for topic: %s", subscriberID, req.Topic) - go s.keepAliveSubscriber(ctx, stream) + log.Printf("[Subscribe] Starting subscription stream for ID: %s", in.SubscriptionId) for { select { - case <-ctx.Done(): - s.cleanupSubscriberLocks(subscriberID) + case <-stream.Context().Done(): return nil default: - s.messageQueueMu.RLock() - msgs, exists := s.messageQueue[req.Topic] - s.messageQueueMu.RUnlock() - - if !exists || len(msgs) == 0 { - time.Sleep(100 * time.Millisecond) + // Request message from the leader instead of querying directly + message, err := s.requestMessageFromLeader(subscription.TopicId) + if err != nil { + log.Printf("[Subscribe] No available messages for subscription: %s", in.SubscriptionId) + time.Sleep(time.Second) // Prevent CPU overuse continue } - // Check visibility timeout before sending - info, exists := s.visibilityTimeouts.Load(msg.Id) - if exists && time.Now().Before(info.(VisibilityInfo).ExpiresAt) { - continue // Skip locked messages + + // Ensure it's not already locked by another node + if _, exists := s.messageLocks.Load(message.Id); exists { + continue // Skip locked messages } - - s.messageQueueMu.Lock() - msg := msgs[0] - s.messageQueue[req.Topic] = msgs[1:] - s.messageQueueMu.Unlock() - - - locked, err := s.tryLockMessage(msg.Id, subscriberID) - if err != nil || !locked { + + // Try to acquire distributed lock + if err := s.broadcastLock(stream.Context(), message.Id, 30*time.Second); err != nil { continue } - if err := stream.Send(msg); err != nil { - s.releaseMessageLock(msg.Id, subscriberID) + // Send message to subscriber + if err := stream.Send(message); err != nil { + s.broadcastUnlock(stream.Context(), message.Id) return err } } } } -func (s *server) tryLockMessage(messageID, subscriberID string) (bool, error) { - s.lockMu.Lock() - defer s.lockMu.Unlock() - - if _, exists := s.visibilityTimeouts.Load(messageID); exists { - return false, nil +func (s *server) Acknowledge(ctx context.Context, in *pb.AcknowledgeRequest) (*pb.AcknowledgeResponse, error) { + // Verify lock exists and is valid + lockInfo, ok := s.messageLocks.Load(in.Id) + if !ok { + return nil, status.Error(codes.NotFound, "message lock not found") } - - visInfo := VisibilityInfo{ - MessageID: messageID, - SubscriberID: subscriberID, - ExpiresAt: time.Now().Add(visibilityTimeout), - NodeID: uuid.New().String(), - } - - s.visibilityTimeouts.Store(messageID, visInfo) - return true, s.broadcastVisibilityUpdate("lock", visInfo) -} - -func (s *server) Acknowledge(ctx context.Context, req *pb.AcknowledgeRequest) (*pb.AcknowledgeResponse, error) { - s.messageQueueMu.Lock() - defer s.messageQueueMu.Unlock() - - if err := s.releaseMessageLock(req.Id, req.SubscriberId); err != nil { - log.Printf("Error releasing message lock: %v", err) + info := lockInfo.(MessageLockInfo) + if !info.Locked || time.Now().After(info.Timeout) { + return nil, status.Error(codes.FailedPrecondition, "message lock expired") } + // Update Spanner mutation := spanner.Update( MessagesTable, []string{"id", "processed", "updatedAt"}, - []interface{}{req.Id, true, spanner.CommitTimestamp}, + []interface{}{in.Id, true, 
spanner.CommitTimestamp}, ) - _, err := s.client.Apply(ctx, []*spanner.Mutation{mutation}) + _, err := s.Client.Apply(ctx, []*spanner.Mutation{mutation}) if err != nil { return nil, err } - s.messageQueue[req.Topic] = s.messageQueue[req.Topic][1:] - - bcastin := broadCastInput{ - Type: "ack", - Msg: map[string]string{ - "messageId": req.Id, - "topic": req.Topic, - }, + // Broadcast delete to all nodes + broadcastData := broadCastInput{ + Type: message, + Msg: []byte(fmt.Sprintf("delete:%s", in.Id)), } - if err := s.broadcastAck(bcastin); err != nil { - log.Printf("Error broadcasting ack: %v", err) + + bin, _ := json.Marshal(broadcastData) + s.Op.Broadcast(ctx, bin) + + // Clean up local state + s.messageLocks.Delete(in.Id) + if timer, ok := s.timeoutTimers.Load(in.Id); ok { + timer.(*time.Timer).Stop() + s.timeoutTimers.Delete(in.Id) } return &pb.AcknowledgeResponse{Success: true}, nil } -func (s *server) releaseMessageLock(messageID, subscriberID string) error { - s.lockMu.Lock() - defer s.lockMu.Unlock() - - if info, exists := s.visibilityTimeouts.Load(messageID); exists { - visInfo := info.(VisibilityInfo) - if visInfo.SubscriberID == subscriberID { - s.visibilityTimeouts.Delete(messageID) - return s.broadcastVisibilityUpdate("unlock", visInfo) - } +func (s *server) ModifyVisibilityTimeout(ctx context.Context, in *pb.ModifyVisibilityTimeoutRequest) (*pb.ModifyVisibilityTimeoutResponse, error) { + lockInfo, ok := s.messageLocks.Load(in.Id) + if !ok { + return nil, status.Error(codes.NotFound, "message lock not found") } - return nil -} -func (s *server) ExtendVisibilityTimeout(ctx context.Context, req *pb.ExtendTimeoutRequest) (*pb.ExtendTimeoutResponse, error) { - s.lockMu.Lock() - defer s.lockMu.Unlock() + info := lockInfo.(MessageLockInfo) + if !info.Locked { + return nil, status.Error(codes.FailedPrecondition, "message not locked") + } - info, exists := s.visibilityTimeouts.Load(req.MessageId) - if !exists { - return nil, status.Error(codes.NotFound, "Message lock not found") + // Ensure the same node is extending the lock + if info.NodeID != s.Op.ID() { + return nil, status.Error(codes.PermissionDenied, "only the original lock holder can extend timeout") } - visInfo := info.(VisibilityInfo) - if visInfo.SubscriberID != req.SubscriberId { - return nil, status.Error(codes.PermissionDenied, "Not allowed to extend timeout for this message") + // Broadcast new timeout + broadcastData := broadCastInput{ + Type: message, + Msg: []byte(fmt.Sprintf("extend:%s:%d", in.Id, in.NewTimeout)), } - newExpiry := time.Now().Add(time.Duration(req.ExtensionSeconds) * time.Second) - visInfo.ExpiresAt = newExpiry - s.visibilityTimeouts.Store(req.MessageId, visInfo) + bin, _ := json.Marshal(broadcastData) + s.Op.Broadcast(ctx, bin) - // Update Spanner to reflect the new timeout - go func() { - mutation := spanner.Update( - MessagesTable, - []string{"id", "visibilityTimeout", "updatedAt"}, - []interface{}{req.MessageId, newExpiry, spanner.CommitTimestamp}, - ) - _, err := s.client.Apply(ctx, []*spanner.Mutation{mutation}) - if err != nil { - log.Printf("Spanner update error: %v", err) + // Update local timer + if timer, ok := s.timeoutTimers.Load(in.Id); ok { + timer.(*time.Timer).Stop() } -}() + newTimer := time.NewTimer(time.Duration(in.NewTimeout) * time.Second) + s.timeoutTimers.Store(in.Id, newTimer) - // Broadcast new timeout info - _ = s.broadcastVisibilityUpdate("extend", visInfo) + // Update lock info + info.Timeout = time.Now().Add(time.Duration(in.NewTimeout) * time.Second) + 
s.messageLocks.Store(in.Id, info) - return &pb.ExtendTimeoutResponse{Success: true}, nil + go func() { + <-newTimer.C + s.handleMessageTimeout(in.Id) + }() + + return &pb.ModifyVisibilityTimeoutResponse{Success: true}, nil } -func (s *server) broadcastVisibilityUpdate(cmdType string, info VisibilityInfo) error { - bcastin := broadCastInput{ - Type: "visibility", - Msg: struct { - Command string `json:"command"` - Info VisibilityInfo `json:"info"` - }{ - Command: cmdType, - Info: info, - }, +// validateTopicSubscription checks if subscription exists in memory +func (s *server) validateTopicSubscription(subscriptionID string) (*pb.Subscription, error) { + if val, ok := s.subscriptions.Load(subscriptionID); ok { + return val.(*pb.Subscription), nil } - data, err := json.Marshal(bcastin) + // Request subscription details from the leader + leaderSubscription, err := s.requestSubscriptionFromLeader(subscriptionID) if err != nil { - return err + return nil, status.Error(codes.NotFound, "subscription not found in memory or leader") + } + + // Store it in memory to prevent duplicate lookups + s.subscriptions.Store(subscriptionID, leaderSubscription) + + return leaderSubscription, nil // Do not store in-memory cache here +} + +// broadcastLock sends lock request to all nodes and waits for acknowledgment +func (s *server) broadcastLock(ctx context.Context, messageID string, timeout time.Duration) error { + lockInfo := MessageLockInfo{ + Timeout: time.Now().Add(timeout), + Locked: true, + NodeID: s.Op.ID(), + LockHolders: make(map[string]bool), + } + + // Store initial lock info before broadcasting + _, loaded := s.messageLocks.LoadOrStore(messageID, lockInfo) + if loaded { + return fmt.Errorf("message already locked by another node") } - results := s.op.Broadcast(context.Background(), data) - for _, result := range results { - if result.Error != nil { - log.Printf("Broadcast error to node %s: %v", result.NodeID, result.Error) + // Broadcast lock request + broadcastData := broadCastInput{ + Type: message, + Msg: []byte(fmt.Sprintf("lock:%s:%d", messageID, timeout.Seconds())), + } + + bin, _ := json.Marshal(broadcastData) + out := s.Op.Broadcast(ctx, bin) + + // Ensure majority of nodes acknowledged + successCount := 0 + for _, v := range out { + if v.Error == nil { + successCount++ } } + if successCount < (len(out)/2 + 1) { + s.messageLocks.Delete(messageID) + return fmt.Errorf("failed to acquire lock across majority of nodes") + } + + // Start local timeout timer + timer := time.NewTimer(timeout) + s.timeoutTimers.Store(messageID, timer) + + go func() { + <-timer.C + s.handleMessageTimeout(messageID) + }() + return nil } -func (s *server) startVisibilityCleanup() { - ticker := time.NewTicker(cleanupInterval) - for range ticker.C { - s.cleanupExpiredLocks() - } -} +func (s *server) handleMessageTimeout(messageID string) { + if lockInfo, ok := s.messageLocks.Load(messageID); ok { + info := lockInfo.(MessageLockInfo) + if info.Locked && time.Now().After(info.Timeout) { + log.Printf("[Timeout] Unlocking expired message: %s", messageID) -func (s *server) cleanupExpiredLocks() { - now := time.Now() - s.lockMu.Lock() - defer s.lockMu.Unlock() - - s.visibilityTimeouts.Range(func(key, value interface{}) bool { - visInfo := value.(VisibilityInfo) - if now.After(visInfo.ExpiresAt) { - // Double-check before deleting - if info, exists := s.visibilityTimeouts.Load(key); exists { - if time.Now().Before(info.(VisibilityInfo).ExpiresAt) { - return true // Another node extended it - } + // Broadcast unlock + 
s.broadcastUnlock(context.Background(), messageID) + + // Remove lock entry + s.messageLocks.Delete(messageID) + + // Notify all nodes to retry processing this message + broadcastData := broadCastInput{ + Type: message, + Msg: []byte(fmt.Sprintf("retry:%s", messageID)), } - s.visibilityTimeouts.Delete(key) - s.broadcastVisibilityUpdate("unlock", visInfo) + + bin, _ := json.Marshal(broadcastData) + s.Op.Broadcast(context.Background(), bin) } - return true - }) + } +} + + +func (s *server) broadcastUnlock(ctx context.Context, messageID string) { + // Ensure only the leader sends the unlock request + if !s.Op.IsLeader() { + log.Printf("[Unlock] Skipping unlock broadcast. Only the leader handles unlocks.") + return + } + + broadcastData := broadCastInput{ + Type: message, + Msg: []byte(fmt.Sprintf("unlock:%s", messageID)), + } + + bin, _ := json.Marshal(broadcastData) + s.Op.Broadcast(ctx, bin) + + // Clean up local state + s.messageLocks.Delete(messageID) + if timer, ok := s.timeoutTimers.Load(messageID); ok { + timer.(*time.Timer).Stop() + s.timeoutTimers.Delete(messageID) + } + + log.Printf("[Unlock] Leader node unlocked message: %s", messageID) +} + +//helper function - checks if the current node is the leader node in your pub/sub system. +func (s *server) IsLeader() bool { + return s.Op.ID() == s.Op.GetLeaderID() // Compare current node ID with leader ID +} + + +//helper function - asks the leader node for messages +func (s *server) requestMessageFromLeader(topicID string) (*pb.Message, error) { + // Simulated request to leader (replace with actual leader communication) + log.Printf("[Leader] Requesting message for topic: %s", topicID) + return nil, status.Error(codes.NotFound, "no messages available from leader") +} + + +func (s *server) ExtendVisibilityTimeout(messageID string, subscriberID string) error { + // Non-leader nodes should not modify state directly + if !s.IsLeader() { + return status.Error(codes.PermissionDenied, "only the leader can extend visibility timeout") + } + + value, exists := s.messageLocks.Load(messageID) + if !exists { + return status.Error(codes.NotFound, "message not locked") + } + + info, ok := value.(VisibilityInfo) + if !ok || info.SubscriberID != subscriberID { + return status.Error(codes.PermissionDenied, "message locked by another subscriber") + } + + // Leader extends visibility timeout + newExpiresAt := time.Now().Add(visibilityTimeout) + info.ExpiresAt = newExpiresAt + s.messageLocks.Store(messageID, info) + + // Create broadcast message + broadcastMsg := broadCastInput{ + Type: "extend", + Msg: fmt.Sprintf("%s:%d", messageID, visibilityTimeout.Seconds()), + } + msgBytes, _ := json.Marshal(broadcastMsg) + + // Leader broadcasts the new timeout + s.Op.Broadcast(context.TODO(), msgBytes) + + log.Printf("[ExtendTimeout] Leader approved timeout extension for message: %s", messageID) + return nil +} + +//helper function - listen for the leader's broadcast and apply the timeout only when received. 
+func (s *server) HandleTimeoutExtension(msg broadCastInput) { + // Parse message + parts := strings.Split(string(msg.Msg), ":") + if len(parts) != 2 { + log.Println("[HandleTimeoutExtension] Invalid message format") + return + } + + messageID := parts[0] + timeoutSeconds, err := strconv.Atoi(parts[1]) + if err != nil { + log.Println("[HandleTimeoutExtension] Failed to parse timeout value") + return + } + + // Apply the extended timeout + value, exists := s.messageLocks.Load(messageID) + if !exists { + log.Printf("[HandleTimeoutExtension] Message %s not found in locks", messageID) + return + } + + info, ok := value.(VisibilityInfo) + if !ok { + log.Println("[HandleTimeoutExtension] Invalid visibility info") + return + } + + info.ExpiresAt = time.Now().Add(time.Duration(timeoutSeconds) * time.Second) + s.messageLocks.Store(messageID, info) + + log.Printf("[HandleTimeoutExtension] Applied timeout extension for message: %s", messageID) } From 15238105d4debe5d60822231008136b826a56e38 Mon Sep 17 00:00:00 2001 From: Kishea Kate Andoy Date: Tue, 25 Feb 2025 12:33:05 +0800 Subject: [PATCH 08/21] Merged main into kate_branch --- helpers.go | 222 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 222 insertions(+) create mode 100644 helpers.go diff --git a/helpers.go b/helpers.go new file mode 100644 index 00000000..9b2af89b --- /dev/null +++ b/helpers.go @@ -0,0 +1,222 @@ +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "strconv" + "strings" + "sync" + "time" + pb "github.com/alphauslabs/pubsub-proto/v1" + "google.golang.org/grpc/status" + "google.golang.org/grpc/codes" +) + + +//HELPERFUNCTIONS// +// validateTopicSubscription checks if subscription exists in memory +func (s *server) validateTopicSubscription(subscriptionID string) (*pb.Subscription, error) { + if val, ok := s.subscriptions.Load(subscriptionID); ok { + return val.(*pb.Subscription), nil + } + + // Request subscription details from the leader + leaderSubscription, err := s.requestSubscriptionFromLeader(subscriptionID) + if err != nil { + return nil, status.Error(codes.NotFound, "subscription not found in memory or leader") + } + + // Store it in memory to prevent duplicate lookups + s.subscriptions.Store(subscriptionID, leaderSubscription) + + return leaderSubscription, nil // Do not store in-memory cache here +} + +// broadcastLock sends lock request to all nodes and waits for acknowledgment +func (s *server) broadcastLock(ctx context.Context, messageID string, timeout time.Duration) error { + lockInfo := MessageLockInfo{ + Timeout: time.Now().Add(timeout), + Locked: true, + NodeID: s.Op.ID(), + LockHolders: make(map[string]bool), + } + + // Store initial lock info before broadcasting + _, loaded := s.messageLocks.LoadOrStore(messageID, lockInfo) + if loaded { + return fmt.Errorf("message already locked by another node") + } + + // Broadcast lock request + broadcastData := broadCastInput{ + Type: message, + Msg: []byte(fmt.Sprintf("lock:%s:%d", messageID, timeout.Seconds())), + } + + bin, _ := json.Marshal(broadcastData) + out := s.Op.Broadcast(ctx, bin) + + // Ensure majority of nodes acknowledged + successCount := 0 + for _, v := range out { + if v.Error == nil { + successCount++ + } + } + + if successCount < (len(out)/2 + 1) { + s.messageLocks.Delete(messageID) + return fmt.Errorf("failed to acquire lock across majority of nodes") + } + + // Start local timeout timer + timer := time.NewTimer(timeout) + s.timeoutTimers.Store(messageID, timer) + + go func() { + <-timer.C + 
s.handleMessageTimeout(messageID) + }() + + return nil +} + +//helper function - function ensures that if a node crashes while holding a lock, other nodes can unlock the message and allow it to be processed again. +func (s *server) handleMessageTimeout(messageID string) { + if lockInfo, ok := s.messageLocks.Load(messageID); ok { + info := lockInfo.(MessageLockInfo) + if info.Locked && time.Now().After(info.Timeout) { + log.Printf("[Timeout] Unlocking expired message: %s", messageID) + + // Broadcast unlock + s.BroadcastUnlock(context.Background(), messageID) + + // Remove lock entry + s.messageLocks.Delete(messageID) + + // Notify all nodes to retry processing this message + broadcastData := broadCastInput{ + Type: message, + Msg: []byte(fmt.Sprintf("retry:%s", messageID)), + } + + bin, _ := json.Marshal(broadcastData) + s.Op.Broadcast(context.Background(), bin) + } + } +} + +//helper function - function ensures that only the leader node is responsible for broadcasting unlock requests: +func (s *server) broadcastUnlock(ctx context.Context, messageID string) { + // Ensure only the leader sends the unlock request + if !s.Op.IsLeader() { + log.Printf("[Unlock] Skipping unlock broadcast. Only the leader handles unlocks.") + return + } + + broadcastData := broadCastInput{ + Type: message, + Msg: []byte(fmt.Sprintf("unlock:%s", messageID)), + } + + bin, _ := json.Marshal(broadcastData) + s.Op.Broadcast(ctx, bin) + + // Clean up local state + s.messageLocks.Delete(messageID) + if timer, ok := s.timeoutTimers.Load(messageID); ok { + timer.(*time.Timer).Stop() + s.timeoutTimers.Delete(messageID) + } + + log.Printf("[Unlock] Leader node unlocked message: %s", messageID) +} + +//helper function - checks if the current node is the leader node in your pub/sub system. +func (s *server) IsLeader() bool { + return s.Op.ID() == s.Op.GetLeaderID() // Compare current node ID with leader ID +} + + +//helper function - asks the leader node for messages +func (s *server) requestMessageFromLeader(topicID string) (*pb.Message, error) { + // Simulated request to leader (replace with actual leader communication) + log.Printf("[Leader] Requesting message for topic: %s", topicID) + return nil, status.Error(codes.NotFound, "no messages available from leader") +} + + +func (s *server) ExtendVisibilityTimeout(messageID string, subscriberID string) error { + // Non-leader nodes should not modify state directly + if !s.IsLeader() { + return status.Error(codes.PermissionDenied, "only the leader can extend visibility timeout") + } + + value, exists := s.messageLocks.Load(messageID) + if !exists { + return status.Error(codes.NotFound, "message not locked") + } + + info, ok := value.(VisibilityInfo) + if !ok || info.SubscriberID != subscriberID { + return status.Error(codes.PermissionDenied, "message locked by another subscriber") + } + + // Leader extends visibility timeout + newExpiresAt := time.Now().Add(visibilityTimeout) + info.ExpiresAt = newExpiresAt + s.messageLocks.Store(messageID, info) + + // Create broadcast message + broadcastMsg := broadCastInput{ + Type: "extend", + Msg: fmt.Sprintf("%s:%d", messageID, visibilityTimeout.Seconds()), + } + msgBytes, _ := json.Marshal(broadcastMsg) + + // Leader broadcasts the new timeout + s.Op.Broadcast(context.TODO(), msgBytes) + + log.Printf("[ExtendTimeout] Leader approved timeout extension for message: %s", messageID) + return nil +} + +//helper function - listen for the leader's broadcast and apply the timeout only when received. 
+func (s *server) HandleTimeoutExtension(msg broadCastInput) { + // Parse message + parts := strings.Split(string(msg.Msg), ":") + if len(parts) != 2 { + log.Println("[HandleTimeoutExtension] Invalid message format") + return + } + + messageID := parts[0] + timeoutSeconds, err := strconv.Atoi(parts[1]) + if err != nil { + log.Println("[HandleTimeoutExtension] Failed to parse timeout value") + return + } + + // Apply the extended timeout + value, exists := s.messageLocks.Load(messageID) + if !exists { + log.Printf("[HandleTimeoutExtension] Message %s not found in locks", messageID) + return + } + + info, ok := value.(VisibilityInfo) + if !ok { + log.Println("[HandleTimeoutExtension] Invalid visibility info") + return + } + + info.ExpiresAt = time.Now().Add(time.Duration(timeoutSeconds) * time.Second) + s.messageLocks.Store(messageID, info) + + log.Printf("[HandleTimeoutExtension] Applied timeout extension for message: %s", messageID) +} + + From 37c89b0aacf06d90796f8513075bf61afdc1bef3 Mon Sep 17 00:00:00 2001 From: Kishea Kate Andoy <142709948+kitkatchoco2002@users.noreply.github.com> Date: Tue, 25 Feb 2025 15:27:26 +0800 Subject: [PATCH 09/21] Create helpers.go helper functions for the server.go --- helpers.go | 244 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 244 insertions(+) create mode 100644 helpers.go diff --git a/helpers.go b/helpers.go new file mode 100644 index 00000000..c5c04463 --- /dev/null +++ b/helpers.go @@ -0,0 +1,244 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "log" + "strconv" + "strings" + "time" + + pb "github.com/alphauslabs/pubsub-proto/v1" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "github.com/alphauslabs/pubsub/app" + "github.com/alphauslabs/pubsub/broadcast" +) + +// MessageLockInfo tracks lock state across nodes +type MessageLockInfo struct { + Timeout time.Time + Locked bool + NodeID string + SubscriberID string // Added to track which subscriber has the lock + LockHolders map[string]bool // Track which nodes have acknowledged the lock +} + +// validateTopicSubscription checks if subscription exists in memory +func (s *server) validateTopicSubscription(subscriptionID string) (*pb.Subscription, error) { + if val, ok := s.subscriptions.Load(subscriptionID); ok { + return val.(*pb.Subscription), nil + } + // Request subscription details from the leader - using topicsub message type + broadcastData := broadCastInput{ + Type: topicsub, + Msg: []byte(fmt.Sprintf("get:%s", subscriptionID)), + } + bin, _ := json.Marshal(broadcastData) + resp, err := s.Op.Request(context.Background(), bin) + if err != nil { + return nil, status.Error(codes.Internal, "failed to request subscription from leader") + } + + var subscription pb.Subscription + if err := json.Unmarshal(resp, &subscription); err != nil { + return nil, status.Error(codes.Internal, "failed to parse subscription response") + } + + if subscription.Id == "" { + return nil, status.Error(codes.NotFound, "subscription not found") + } + + // Store it in memory to prevent duplicate lookups + s.subscriptions.Store(subscriptionID, &subscription) + return &subscription, nil +} + +// broadcastLock sends lock request to all nodes and waits for acknowledgment +func (s *server) broadcastLock(ctx context.Context, messageID string, subscriberID string, timeout time.Duration) error { + lockInfo := MessageLockInfo{ + Timeout: time.Now().Add(timeout), + Locked: true, + NodeID: s.Op.ID(), + SubscriberID: subscriberID, + LockHolders: make(map[string]bool), + } + 
// Store initial lock info before broadcasting + _, loaded := s.messageLocks.LoadOrStore(messageID, lockInfo) + if loaded { + return fmt.Errorf("message already locked by another node") + } + // Broadcast lock request - format matching handleBroadcastedMsg + broadcastData := broadCastInput{ + Type: message, + Msg: []byte(fmt.Sprintf("lock:%s:%d:%s", messageID, int(timeout.Seconds()), subscriberID)), + } + + bin, _ := json.Marshal(broadcastData) + out := s.Op.Broadcast(ctx, bin) + // Ensure majority of nodes acknowledged + successCount := 0 + for _, v := range out { + if v.Error == nil { + successCount++ + } + } + if successCount < (len(out)/2 + 1) { + s.messageLocks.Delete(messageID) + return fmt.Errorf("failed to acquire lock across majority of nodes") + } + // Start local timeout timer + timer := time.NewTimer(timeout) + s.timeoutTimers.Store(messageID, timer) + + go func() { + <-timer.C + s.handleMessageTimeout(messageID) + }() + return nil +} + +// handleMessageTimeout ensures that if a node crashes while holding a lock, +// other nodes can unlock the message and allow it to be processed again. +func (s *server) handleMessageTimeout(messageID string) { + if lockInfo, ok := s.messageLocks.Load(messageID); ok { + info := lockInfo.(MessageLockInfo) + if info.Locked && time.Now().After(info.Timeout) { + log.Printf("[Timeout] Unlocking expired message: %s", messageID) + // Broadcast unlock + s.broadcastUnlock(context.Background(), messageID) + // Remove lock entry + s.messageLocks.Delete(messageID) + // Notify all nodes to retry processing this message + broadcastData := broadCastInput{ + Type: message, + Msg: []byte(fmt.Sprintf("retry:%s", messageID)), + } + bin, _ := json.Marshal(broadcastData) + s.Op.Broadcast(context.Background(), bin) + } + } +} + +// broadcastUnlock ensures that only the leader node is responsible for broadcasting unlock requests +func (s *server) broadcastUnlock(ctx context.Context, messageID string) { + // Ensure only the leader sends the unlock request + if !s.IsLeader() { + log.Printf("[Unlock] Skipping unlock broadcast. 
Only the leader handles unlocks.") + return + } + // Format matching handleBroadcastedMsg + broadcastData := broadCastInput{ + Type: message, + Msg: []byte(fmt.Sprintf("unlock:%s", messageID)), + } + bin, _ := json.Marshal(broadcastData) + s.Op.Broadcast(ctx, bin) + // Clean up local state + s.messageLocks.Delete(messageID) + if timer, ok := s.timeoutTimers.Load(messageID); ok { + timer.(*time.Timer).Stop() + s.timeoutTimers.Delete(messageID) + } + log.Printf("[Unlock] Leader node unlocked message: %s", messageID) +} + +// IsLeader checks if the current node is the leader node in the pub/sub system +func (s *server) IsLeader() bool { + return s.Op.ID() == s.Op.GetLeaderID() // Compare current node ID with leader ID +} + +// requestMessageFromLeader asks the leader node for messages +func (s *server) requestMessageFromLeader(topicID string) (*pb.Message, error) { + // Use the message type with proper format for requesting a message + broadcastData := broadCastInput{ + Type: message, + Msg: []byte(fmt.Sprintf("getmessage:%s", topicID)), + } + + bin, _ := json.Marshal(broadcastData) + resp, err := s.Op.Request(context.Background(), bin) + if err != nil { + return nil, err + } + + if len(resp) == 0 { + return nil, status.Error(codes.NotFound, "no messages available") + } + + var message pb.Message + if err := json.Unmarshal(resp, &message); err != nil { + return nil, err + } + + return &message, nil +} + +// ExtendVisibilityTimeout extends the visibility timeout for a message +func (s *server) ExtendVisibilityTimeout(messageID string, subscriberID string, visibilityTimeout time.Duration) error { + // Non-leader nodes should not modify state directly + if !s.IsLeader() { + return status.Error(codes.PermissionDenied, "only the leader can extend visibility timeout") + } + value, exists := s.messageLocks.Load(messageID) + if !exists { + return status.Error(codes.NotFound, "message not locked") + } + info, ok := value.(MessageLockInfo) + if !ok || info.SubscriberID != subscriberID { + return status.Error(codes.PermissionDenied, "message locked by another subscriber") + } + // Leader extends visibility timeout + newExpiresAt := time.Now().Add(visibilityTimeout) + info.Timeout = newExpiresAt + s.messageLocks.Store(messageID, info) + // Create broadcast message - format matching handleBroadcastedMsg + broadcastData := broadCastInput{ + Type: message, + Msg: []byte(fmt.Sprintf("extend:%s:%d", messageID, int(visibilityTimeout.Seconds()))), + } + msgBytes, _ := json.Marshal(broadcastData) + // Leader broadcasts the new timeout + s.Op.Broadcast(context.TODO(), msgBytes) + log.Printf("[ExtendTimeout] Leader approved timeout extension for message: %s", messageID) + return nil +} + +// HandleBroadcastMessage processes broadcast messages received from other nodes +func (s *server) HandleBroadcastMessage(msgType string, msgData []byte) error { + // This method would be called by your broadcast handler + switch msgType { + case "lock": + parts := strings.Split(string(msgData), ":") + if len(parts) < 3 { + return fmt.Errorf("invalid lock message format") + } + messageID := parts[0] + timeoutSecondsStr := parts[1] + subscriberID := parts[2] + + timeoutSeconds, err := strconv.Atoi(timeoutSecondsStr) + if err != nil { + return err + } + + // Store the lock locally + lockInfo := MessageLockInfo{ + Timeout: time.Now().Add(time.Duration(timeoutSeconds) * time.Second), + Locked: true, + NodeID: s.Op.ID(), // This is the current node + SubscriberID: subscriberID, + LockHolders: make(map[string]bool), + } + 
s.messageLocks.Store(messageID, lockInfo) + + case "unlock": + messageID := string(msgData) + s.messageLocks.Delete(messageID) + + // Add other message types as needed + } + + return nil +} From ba02771566801bebebd3ec4cb49cad1e5ce8e8d3 Mon Sep 17 00:00:00 2001 From: Kishea Kate Andoy <142709948+kitkatchoco2002@users.noreply.github.com> Date: Tue, 25 Feb 2025 15:29:47 +0800 Subject: [PATCH 10/21] removed the helper functions --- server.go | 534 ++++++++++++++++++------------------------------------ 1 file changed, 181 insertions(+), 353 deletions(-) diff --git a/server.go b/server.go index d852f088..c5c04463 100644 --- a/server.go +++ b/server.go @@ -1,416 +1,244 @@ -//TO DO: separate the helper functions to a different file package main import ( - "context" - "encoding/json" - "fmt" - "log" - "sync" - "time" - - "cloud.google.com/go/spanner" - pb "github.com/alphauslabs/pubsub-proto/v1" - "github.com/google/uuid" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + "context" + "encoding/json" + "fmt" + "log" + "strconv" + "strings" + "time" + + pb "github.com/alphauslabs/pubsub-proto/v1" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "github.com/alphauslabs/pubsub/app" + "github.com/alphauslabs/pubsub/broadcast" ) +// MessageLockInfo tracks lock state across nodes type MessageLockInfo struct { - Timeout time.Time - Locked bool - NodeID string - LockHolders map[string]bool // Track which nodes have acknowledged the lock -} - -type server struct { - *PubSub - pb.UnimplementedPubSubServiceServer + Timeout time.Time + Locked bool + NodeID string + SubscriberID string // Added to track which subscriber has the lock + LockHolders map[string]bool // Track which nodes have acknowledged the lock } -const ( - MessagesTable = "Messages" -) - -func (s *server) Publish(ctx context.Context, in *pb.PublishRequest) (*pb.PublishResponse, error) { - if in.Topic == "" { - return nil, status.Error(codes.InvalidArgument, "topic must not be empty") +// validateTopicSubscription checks if subscription exists in memory +func (s *server) validateTopicSubscription(subscriptionID string) (*pb.Subscription, error) { + if val, ok := s.subscriptions.Load(subscriptionID); ok { + return val.(*pb.Subscription), nil + } + // Request subscription details from the leader - using topicsub message type + broadcastData := broadCastInput{ + Type: topicsub, + Msg: []byte(fmt.Sprintf("get:%s", subscriptionID)), + } + bin, _ := json.Marshal(broadcastData) + resp, err := s.Op.Request(context.Background(), bin) + if err != nil { + return nil, status.Error(codes.Internal, "failed to request subscription from leader") } - b, _ := json.Marshal(in) + var subscription pb.Subscription + if err := json.Unmarshal(resp, &subscription); err != nil { + return nil, status.Error(codes.Internal, "failed to parse subscription response") + } - l, _ := s.Op.HasLock() - if l { - log.Println("[Publish-leader] Received message:\n", string(b)) - } else { - log.Printf("[Publish] Received message:\n%v", string(b)) + if subscription.Id == "" { + return nil, status.Error(codes.NotFound, "subscription not found") } - messageID := uuid.New().String() - mutation := spanner.InsertOrUpdate( - MessagesTable, - []string{"id", "topic", "payload", "createdAt", "updatedAt", "visibilityTimeout", "processed"}, - []interface{}{ - messageID, - in.Topic, - in.Payload, - spanner.CommitTimestamp, - spanner.CommitTimestamp, - nil, // Initial visibilityTimeout is NULL - false, // Not processed yet - }, - ) + // Store it in memory to 
prevent duplicate lookups + s.subscriptions.Store(subscriptionID, &subscription) + return &subscription, nil +} - _, err := s.Client.Apply(ctx, []*spanner.Mutation{mutation}) - if err != nil { - log.Printf("Error writing to Spanner: %v", err) - return nil, err +// broadcastLock sends lock request to all nodes and waits for acknowledgment +func (s *server) broadcastLock(ctx context.Context, messageID string, subscriberID string, timeout time.Duration) error { + lockInfo := MessageLockInfo{ + Timeout: time.Now().Add(timeout), + Locked: true, + NodeID: s.Op.ID(), + SubscriberID: subscriberID, + LockHolders: make(map[string]bool), } - - // broadcast message - bcastin := broadCastInput{ + // Store initial lock info before broadcasting + _, loaded := s.messageLocks.LoadOrStore(messageID, lockInfo) + if loaded { + return fmt.Errorf("message already locked by another node") + } + // Broadcast lock request - format matching handleBroadcastedMsg + broadcastData := broadCastInput{ Type: message, - Msg: b, + Msg: []byte(fmt.Sprintf("lock:%s:%d:%s", messageID, int(timeout.Seconds()), subscriberID)), } - bin, _ := json.Marshal(bcastin) + bin, _ := json.Marshal(broadcastData) out := s.Op.Broadcast(ctx, bin) + // Ensure majority of nodes acknowledged + successCount := 0 for _, v := range out { - if v.Error != nil { // for us to know, then do necessary actions if frequent - log.Printf("[Publish] Error broadcasting message: %v", v.Error) + if v.Error == nil { + successCount++ } } - - log.Printf("[Publish] Message successfully broadcasted and wrote to spanner with ID: %s", messageID) - return &pb.PublishResponse{MessageId: messageID}, nil -} - -func (s *server) Subscribe(in *pb.SubscribeRequest, stream pb.PubSubService_SubscribeServer) error { - // Validate subscription in memory first - subscription, err := s.validateTopicSubscription(in.SubscriptionId) - if err != nil { - return err - } - - log.Printf("[Subscribe] Starting subscription stream for ID: %s", in.SubscriptionId) - - for { - select { - case <-stream.Context().Done(): - return nil - default: - // Request message from the leader instead of querying directly - message, err := s.requestMessageFromLeader(subscription.TopicId) - if err != nil { - log.Printf("[Subscribe] No available messages for subscription: %s", in.SubscriptionId) - time.Sleep(time.Second) // Prevent CPU overuse - continue - } - - // Ensure it's not already locked by another node - if _, exists := s.messageLocks.Load(message.Id); exists { - continue // Skip locked messages - } - - // Try to acquire distributed lock - if err := s.broadcastLock(stream.Context(), message.Id, 30*time.Second); err != nil { - continue - } - - // Send message to subscriber - if err := stream.Send(message); err != nil { - s.broadcastUnlock(stream.Context(), message.Id) - return err - } - } - } -} - -func (s *server) Acknowledge(ctx context.Context, in *pb.AcknowledgeRequest) (*pb.AcknowledgeResponse, error) { - // Verify lock exists and is valid - lockInfo, ok := s.messageLocks.Load(in.Id) - if !ok { - return nil, status.Error(codes.NotFound, "message lock not found") - } - info := lockInfo.(MessageLockInfo) - if !info.Locked || time.Now().After(info.Timeout) { - return nil, status.Error(codes.FailedPrecondition, "message lock expired") - } - - // Update Spanner - mutation := spanner.Update( - MessagesTable, - []string{"id", "processed", "updatedAt"}, - []interface{}{in.Id, true, spanner.CommitTimestamp}, - ) - - _, err := s.Client.Apply(ctx, []*spanner.Mutation{mutation}) - if err != nil { - return nil, 
err - } - - // Broadcast delete to all nodes - broadcastData := broadCastInput{ - Type: message, - Msg: []byte(fmt.Sprintf("delete:%s", in.Id)), - } - - bin, _ := json.Marshal(broadcastData) - s.Op.Broadcast(ctx, bin) - - // Clean up local state - s.messageLocks.Delete(in.Id) - if timer, ok := s.timeoutTimers.Load(in.Id); ok { - timer.(*time.Timer).Stop() - s.timeoutTimers.Delete(in.Id) - } - - return &pb.AcknowledgeResponse{Success: true}, nil -} - -func (s *server) ModifyVisibilityTimeout(ctx context.Context, in *pb.ModifyVisibilityTimeoutRequest) (*pb.ModifyVisibilityTimeoutResponse, error) { - lockInfo, ok := s.messageLocks.Load(in.Id) - if !ok { - return nil, status.Error(codes.NotFound, "message lock not found") - } - - info := lockInfo.(MessageLockInfo) - if !info.Locked { - return nil, status.Error(codes.FailedPrecondition, "message not locked") - } - - // Ensure the same node is extending the lock - if info.NodeID != s.Op.ID() { - return nil, status.Error(codes.PermissionDenied, "only the original lock holder can extend timeout") - } - - // Broadcast new timeout - broadcastData := broadCastInput{ - Type: message, - Msg: []byte(fmt.Sprintf("extend:%s:%d", in.Id, in.NewTimeout)), - } - - bin, _ := json.Marshal(broadcastData) - s.Op.Broadcast(ctx, bin) - - // Update local timer - if timer, ok := s.timeoutTimers.Load(in.Id); ok { - timer.(*time.Timer).Stop() - } - newTimer := time.NewTimer(time.Duration(in.NewTimeout) * time.Second) - s.timeoutTimers.Store(in.Id, newTimer) - - // Update lock info - info.Timeout = time.Now().Add(time.Duration(in.NewTimeout) * time.Second) - s.messageLocks.Store(in.Id, info) - - go func() { - <-newTimer.C - s.handleMessageTimeout(in.Id) - }() - - return &pb.ModifyVisibilityTimeoutResponse{Success: true}, nil -} - - -// validateTopicSubscription checks if subscription exists in memory -func (s *server) validateTopicSubscription(subscriptionID string) (*pb.Subscription, error) { - if val, ok := s.subscriptions.Load(subscriptionID); ok { - return val.(*pb.Subscription), nil - } - - // Request subscription details from the leader - leaderSubscription, err := s.requestSubscriptionFromLeader(subscriptionID) - if err != nil { - return nil, status.Error(codes.NotFound, "subscription not found in memory or leader") - } - - // Store it in memory to prevent duplicate lookups - s.subscriptions.Store(subscriptionID, leaderSubscription) - - return leaderSubscription, nil // Do not store in-memory cache here -} - -// broadcastLock sends lock request to all nodes and waits for acknowledgment -func (s *server) broadcastLock(ctx context.Context, messageID string, timeout time.Duration) error { - lockInfo := MessageLockInfo{ - Timeout: time.Now().Add(timeout), - Locked: true, - NodeID: s.Op.ID(), - LockHolders: make(map[string]bool), - } - - // Store initial lock info before broadcasting - _, loaded := s.messageLocks.LoadOrStore(messageID, lockInfo) - if loaded { - return fmt.Errorf("message already locked by another node") - } - - // Broadcast lock request - broadcastData := broadCastInput{ - Type: message, - Msg: []byte(fmt.Sprintf("lock:%s:%d", messageID, timeout.Seconds())), - } - - bin, _ := json.Marshal(broadcastData) - out := s.Op.Broadcast(ctx, bin) - - // Ensure majority of nodes acknowledged - successCount := 0 - for _, v := range out { - if v.Error == nil { - successCount++ - } - } - - if successCount < (len(out)/2 + 1) { - s.messageLocks.Delete(messageID) - return fmt.Errorf("failed to acquire lock across majority of nodes") - } - - // Start local timeout 
timer - timer := time.NewTimer(timeout) - s.timeoutTimers.Store(messageID, timer) - - go func() { - <-timer.C - s.handleMessageTimeout(messageID) - }() - - return nil + if successCount < (len(out)/2 + 1) { + s.messageLocks.Delete(messageID) + return fmt.Errorf("failed to acquire lock across majority of nodes") + } + // Start local timeout timer + timer := time.NewTimer(timeout) + s.timeoutTimers.Store(messageID, timer) + + go func() { + <-timer.C + s.handleMessageTimeout(messageID) + }() + return nil } +// handleMessageTimeout ensures that if a node crashes while holding a lock, +// other nodes can unlock the message and allow it to be processed again. func (s *server) handleMessageTimeout(messageID string) { - if lockInfo, ok := s.messageLocks.Load(messageID); ok { - info := lockInfo.(MessageLockInfo) - if info.Locked && time.Now().After(info.Timeout) { - log.Printf("[Timeout] Unlocking expired message: %s", messageID) - - // Broadcast unlock - s.broadcastUnlock(context.Background(), messageID) - - // Remove lock entry - s.messageLocks.Delete(messageID) - - // Notify all nodes to retry processing this message - broadcastData := broadCastInput{ - Type: message, - Msg: []byte(fmt.Sprintf("retry:%s", messageID)), - } - - bin, _ := json.Marshal(broadcastData) - s.Op.Broadcast(context.Background(), bin) - } - } + if lockInfo, ok := s.messageLocks.Load(messageID); ok { + info := lockInfo.(MessageLockInfo) + if info.Locked && time.Now().After(info.Timeout) { + log.Printf("[Timeout] Unlocking expired message: %s", messageID) + // Broadcast unlock + s.broadcastUnlock(context.Background(), messageID) + // Remove lock entry + s.messageLocks.Delete(messageID) + // Notify all nodes to retry processing this message + broadcastData := broadCastInput{ + Type: message, + Msg: []byte(fmt.Sprintf("retry:%s", messageID)), + } + bin, _ := json.Marshal(broadcastData) + s.Op.Broadcast(context.Background(), bin) + } + } } - +// broadcastUnlock ensures that only the leader node is responsible for broadcasting unlock requests func (s *server) broadcastUnlock(ctx context.Context, messageID string) { - // Ensure only the leader sends the unlock request - if !s.Op.IsLeader() { - log.Printf("[Unlock] Skipping unlock broadcast. Only the leader handles unlocks.") - return - } - - broadcastData := broadCastInput{ - Type: message, - Msg: []byte(fmt.Sprintf("unlock:%s", messageID)), - } - - bin, _ := json.Marshal(broadcastData) - s.Op.Broadcast(ctx, bin) - - // Clean up local state - s.messageLocks.Delete(messageID) - if timer, ok := s.timeoutTimers.Load(messageID); ok { - timer.(*time.Timer).Stop() - s.timeoutTimers.Delete(messageID) - } - - log.Printf("[Unlock] Leader node unlocked message: %s", messageID) + // Ensure only the leader sends the unlock request + if !s.IsLeader() { + log.Printf("[Unlock] Skipping unlock broadcast. Only the leader handles unlocks.") + return + } + // Format matching handleBroadcastedMsg + broadcastData := broadCastInput{ + Type: message, + Msg: []byte(fmt.Sprintf("unlock:%s", messageID)), + } + bin, _ := json.Marshal(broadcastData) + s.Op.Broadcast(ctx, bin) + // Clean up local state + s.messageLocks.Delete(messageID) + if timer, ok := s.timeoutTimers.Load(messageID); ok { + timer.(*time.Timer).Stop() + s.timeoutTimers.Delete(messageID) + } + log.Printf("[Unlock] Leader node unlocked message: %s", messageID) } -//helper function - checks if the current node is the leader node in your pub/sub system. 
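// Illustrative sketch (not part of this patch): one way a receiving node could act on
// the "retry:<messageID>" broadcast emitted by handleMessageTimeout above. The patch's
// HandleBroadcastMessage only covers "lock" and "unlock", so this handler's name and its
// wiring into the broadcast dispatcher are assumptions. Clearing the stale lock and timer
// is enough to make the message deliverable again, because Subscribe skips any message
// still present in messageLocks.
func (s *server) handleRetryBroadcast(msgData []byte) {
	messageID := string(msgData) // payload after the "retry:" prefix
	if v, ok := s.messageLocks.Load(messageID); ok {
		if info, ok := v.(MessageLockInfo); ok && time.Now().After(info.Timeout) {
			s.messageLocks.Delete(messageID) // drop the expired lock on this node
		}
	}
	if t, ok := s.timeoutTimers.Load(messageID); ok {
		t.(*time.Timer).Stop()
		s.timeoutTimers.Delete(messageID)
	}
}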
+// IsLeader checks if the current node is the leader node in the pub/sub system func (s *server) IsLeader() bool { - return s.Op.ID() == s.Op.GetLeaderID() // Compare current node ID with leader ID + return s.Op.ID() == s.Op.GetLeaderID() // Compare current node ID with leader ID } - -//helper function - asks the leader node for messages +// requestMessageFromLeader asks the leader node for messages func (s *server) requestMessageFromLeader(topicID string) (*pb.Message, error) { - // Simulated request to leader (replace with actual leader communication) - log.Printf("[Leader] Requesting message for topic: %s", topicID) - return nil, status.Error(codes.NotFound, "no messages available from leader") -} + // Use the message type with proper format for requesting a message + broadcastData := broadCastInput{ + Type: message, + Msg: []byte(fmt.Sprintf("getmessage:%s", topicID)), + } + + bin, _ := json.Marshal(broadcastData) + resp, err := s.Op.Request(context.Background(), bin) + if err != nil { + return nil, err + } + + if len(resp) == 0 { + return nil, status.Error(codes.NotFound, "no messages available") + } + var message pb.Message + if err := json.Unmarshal(resp, &message); err != nil { + return nil, err + } + + return &message, nil +} -func (s *server) ExtendVisibilityTimeout(messageID string, subscriberID string) error { +// ExtendVisibilityTimeout extends the visibility timeout for a message +func (s *server) ExtendVisibilityTimeout(messageID string, subscriberID string, visibilityTimeout time.Duration) error { // Non-leader nodes should not modify state directly if !s.IsLeader() { return status.Error(codes.PermissionDenied, "only the leader can extend visibility timeout") } - value, exists := s.messageLocks.Load(messageID) if !exists { return status.Error(codes.NotFound, "message not locked") } - - info, ok := value.(VisibilityInfo) + info, ok := value.(MessageLockInfo) if !ok || info.SubscriberID != subscriberID { return status.Error(codes.PermissionDenied, "message locked by another subscriber") } - // Leader extends visibility timeout newExpiresAt := time.Now().Add(visibilityTimeout) - info.ExpiresAt = newExpiresAt + info.Timeout = newExpiresAt s.messageLocks.Store(messageID, info) - - // Create broadcast message - broadcastMsg := broadCastInput{ - Type: "extend", - Msg: fmt.Sprintf("%s:%d", messageID, visibilityTimeout.Seconds()), + // Create broadcast message - format matching handleBroadcastedMsg + broadcastData := broadCastInput{ + Type: message, + Msg: []byte(fmt.Sprintf("extend:%s:%d", messageID, int(visibilityTimeout.Seconds()))), } - msgBytes, _ := json.Marshal(broadcastMsg) - + msgBytes, _ := json.Marshal(broadcastData) // Leader broadcasts the new timeout s.Op.Broadcast(context.TODO(), msgBytes) - log.Printf("[ExtendTimeout] Leader approved timeout extension for message: %s", messageID) return nil } -//helper function - listen for the leader's broadcast and apply the timeout only when received. 
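// Illustrative sketch (not part of this patch): how a node might apply the
// "extend:<messageID>:<seconds>" broadcast produced by ExtendVisibilityTimeout above.
// The new HandleBroadcastMessage handles only "lock" and "unlock", so this function's
// name and its hookup to the broadcast handler are assumptions.
func (s *server) applyExtendBroadcast(msgData []byte) error {
	parts := strings.Split(string(msgData), ":") // payload after the "extend:" prefix
	if len(parts) != 2 {
		return fmt.Errorf("invalid extend message format")
	}
	messageID := parts[0]
	seconds, err := strconv.Atoi(parts[1])
	if err != nil {
		return err
	}
	v, ok := s.messageLocks.Load(messageID)
	if !ok {
		return nil // nothing locked locally; nothing to extend
	}
	info := v.(MessageLockInfo)
	info.Timeout = time.Now().Add(time.Duration(seconds) * time.Second)
	s.messageLocks.Store(messageID, info)
	return nil
}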
-func (s *server) HandleTimeoutExtension(msg broadCastInput) { - // Parse message - parts := strings.Split(string(msg.Msg), ":") - if len(parts) != 2 { - log.Println("[HandleTimeoutExtension] Invalid message format") - return - } +// HandleBroadcastMessage processes broadcast messages received from other nodes +func (s *server) HandleBroadcastMessage(msgType string, msgData []byte) error { + // This method would be called by your broadcast handler + switch msgType { + case "lock": + parts := strings.Split(string(msgData), ":") + if len(parts) < 3 { + return fmt.Errorf("invalid lock message format") + } + messageID := parts[0] + timeoutSecondsStr := parts[1] + subscriberID := parts[2] - messageID := parts[0] - timeoutSeconds, err := strconv.Atoi(parts[1]) - if err != nil { - log.Println("[HandleTimeoutExtension] Failed to parse timeout value") - return - } + timeoutSeconds, err := strconv.Atoi(timeoutSecondsStr) + if err != nil { + return err + } - // Apply the extended timeout - value, exists := s.messageLocks.Load(messageID) - if !exists { - log.Printf("[HandleTimeoutExtension] Message %s not found in locks", messageID) - return - } + // Store the lock locally + lockInfo := MessageLockInfo{ + Timeout: time.Now().Add(time.Duration(timeoutSeconds) * time.Second), + Locked: true, + NodeID: s.Op.ID(), // This is the current node + SubscriberID: subscriberID, + LockHolders: make(map[string]bool), + } + s.messageLocks.Store(messageID, lockInfo) - info, ok := value.(VisibilityInfo) - if !ok { - log.Println("[HandleTimeoutExtension] Invalid visibility info") - return - } + case "unlock": + messageID := string(msgData) + s.messageLocks.Delete(messageID) - info.ExpiresAt = time.Now().Add(time.Duration(timeoutSeconds) * time.Second) - s.messageLocks.Store(messageID, info) + // Add other message types as needed + } - log.Printf("[HandleTimeoutExtension] Applied timeout extension for message: %s", messageID) + return nil } - From 8ffd918b4cb68ce0df3f4ac84c53aebfaf5be568 Mon Sep 17 00:00:00 2001 From: Kishea Kate Andoy <142709948+kitkatchoco2002@users.noreply.github.com> Date: Tue, 25 Feb 2025 15:34:12 +0800 Subject: [PATCH 11/21] removed the helper functions --- server.go | 330 +++++++++++++++++++++++------------------------------- 1 file changed, 138 insertions(+), 192 deletions(-) diff --git a/server.go b/server.go index c5c04463..26e4c44a 100644 --- a/server.go +++ b/server.go @@ -3,242 +3,188 @@ package main import ( "context" "encoding/json" - "fmt" "log" - "strconv" - "strings" "time" + "cloud.google.com/go/spanner" pb "github.com/alphauslabs/pubsub-proto/v1" + "github.com/google/uuid" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "github.com/alphauslabs/pubsub/app" "github.com/alphauslabs/pubsub/broadcast" ) -// MessageLockInfo tracks lock state across nodes -type MessageLockInfo struct { - Timeout time.Time - Locked bool - NodeID string - SubscriberID string // Added to track which subscriber has the lock - LockHolders map[string]bool // Track which nodes have acknowledged the lock +type server struct { + *PubSub + pb.UnimplementedPubSubServiceServer } -// validateTopicSubscription checks if subscription exists in memory -func (s *server) validateTopicSubscription(subscriptionID string) (*pb.Subscription, error) { - if val, ok := s.subscriptions.Load(subscriptionID); ok { - return val.(*pb.Subscription), nil - } - // Request subscription details from the leader - using topicsub message type - broadcastData := broadCastInput{ - Type: topicsub, - Msg: 
[]byte(fmt.Sprintf("get:%s", subscriptionID)), - } - bin, _ := json.Marshal(broadcastData) - resp, err := s.Op.Request(context.Background(), bin) - if err != nil { - return nil, status.Error(codes.Internal, "failed to request subscription from leader") - } +// Constant for table name and message types +const ( + MessagesTable = "Messages" + // message = "message" // Match the constants in broadcast.go + // topicsub = "topicsub" // Match the constants in broadcast.go +) - var subscription pb.Subscription - if err := json.Unmarshal(resp, &subscription); err != nil { - return nil, status.Error(codes.Internal, "failed to parse subscription response") - } - if subscription.Id == "" { - return nil, status.Error(codes.NotFound, "subscription not found") +// Publish a message to a topic +func (s *server) Publish(ctx context.Context, in *pb.PublishRequest) (*pb.PublishResponse, error) { + if in.Topic == "" { + return nil, status.Error(codes.InvalidArgument, "topic must not be empty") + } + b, _ := json.Marshal(in) + l, _ := s.Op.HasLock() + if l { + log.Println("[Publish-leader] Received message:\n", string(b)) + } else { + log.Printf("[Publish] Received message:\n%v", string(b)) + } + + messageID := uuid.New().String() + mutation := spanner.InsertOrUpdate( + MessagesTable, + []string{"id", "topic", "payload", "createdAt", "updatedAt", "visibilityTimeout", "processed"}, + []interface{}{ + messageID, + in.Topic, + in.Payload, + spanner.CommitTimestamp, + spanner.CommitTimestamp, + nil, // Initial visibilityTimeout is NULL + false, // Not processed yet + }, + ) + _, err := s.Client.Apply(ctx, []*spanner.Mutation{mutation}) + if err != nil { + log.Printf("Error writing to Spanner: %v", err) + return nil, err } - // Store it in memory to prevent duplicate lookups - s.subscriptions.Store(subscriptionID, &subscription) - return &subscription, nil -} - -// broadcastLock sends lock request to all nodes and waits for acknowledgment -func (s *server) broadcastLock(ctx context.Context, messageID string, subscriberID string, timeout time.Duration) error { - lockInfo := MessageLockInfo{ - Timeout: time.Now().Add(timeout), - Locked: true, - NodeID: s.Op.ID(), - SubscriberID: subscriberID, - LockHolders: make(map[string]bool), - } - // Store initial lock info before broadcasting - _, loaded := s.messageLocks.LoadOrStore(messageID, lockInfo) - if loaded { - return fmt.Errorf("message already locked by another node") - } - // Broadcast lock request - format matching handleBroadcastedMsg - broadcastData := broadCastInput{ + // broadcast message - using the correct message Type constant + bcastin := broadCastInput{ Type: message, - Msg: []byte(fmt.Sprintf("lock:%s:%d:%s", messageID, int(timeout.Seconds()), subscriberID)), + Msg: b, } - - bin, _ := json.Marshal(broadcastData) + bin, _ := json.Marshal(bcastin) out := s.Op.Broadcast(ctx, bin) - // Ensure majority of nodes acknowledged - successCount := 0 for _, v := range out { - if v.Error == nil { - successCount++ + if v.Error != nil { // for us to know, then do necessary actions if frequent + log.Printf("[Publish] Error broadcasting message: %v", v.Error) } } - if successCount < (len(out)/2 + 1) { - s.messageLocks.Delete(messageID) - return fmt.Errorf("failed to acquire lock across majority of nodes") - } - // Start local timeout timer - timer := time.NewTimer(timeout) - s.timeoutTimers.Store(messageID, timer) - - go func() { - <-timer.C - s.handleMessageTimeout(messageID) - }() - return nil + log.Printf("[Publish] Message successfully broadcasted and wrote to 
spanner with ID: %s", messageID) + return &pb.PublishResponse{MessageId: messageID}, nil } -// handleMessageTimeout ensures that if a node crashes while holding a lock, -// other nodes can unlock the message and allow it to be processed again. -func (s *server) handleMessageTimeout(messageID string) { - if lockInfo, ok := s.messageLocks.Load(messageID); ok { - info := lockInfo.(MessageLockInfo) - if info.Locked && time.Now().After(info.Timeout) { - log.Printf("[Timeout] Unlocking expired message: %s", messageID) - // Broadcast unlock - s.broadcastUnlock(context.Background(), messageID) - // Remove lock entry - s.messageLocks.Delete(messageID) - // Notify all nodes to retry processing this message - broadcastData := broadCastInput{ - Type: message, - Msg: []byte(fmt.Sprintf("retry:%s", messageID)), +// Subscribe to receive messages for a subscription +func (s *server) Subscribe(in *pb.SubscribeRequest, stream pb.PubSubService_SubscribeServer) error { + // Validate subscription in memory first + subscription, err := s.validateTopicSubscription(in.SubscriptionId) + if err != nil { + return err + } + log.Printf("[Subscribe] Starting subscription stream for ID: %s", in.SubscriptionId) + for { + select { + case <-stream.Context().Done(): + return nil + default: + // Request message from the leader instead of querying directly + message, err := s.requestMessageFromLeader(subscription.TopicId) + if err != nil { + log.Printf("[Subscribe] No available messages for subscription: %s", in.SubscriptionId) + time.Sleep(time.Second) // Prevent CPU overuse + continue + } + // Ensure it's not already locked by another node + if _, exists := s.messageLocks.Load(message.Id); exists { + continue // Skip locked messages + } + // Try to acquire distributed lock + if err := s.broadcastLock(stream.Context(), message.Id, in.SubscriptionId, 30*time.Second); err != nil { + continue + } + // Send message to subscriber + if err := stream.Send(message); err != nil { + s.broadcastUnlock(stream.Context(), message.Id) + return err } - bin, _ := json.Marshal(broadcastData) - s.Op.Broadcast(context.Background(), bin) } } } -// broadcastUnlock ensures that only the leader node is responsible for broadcasting unlock requests -func (s *server) broadcastUnlock(ctx context.Context, messageID string) { - // Ensure only the leader sends the unlock request - if !s.IsLeader() { - log.Printf("[Unlock] Skipping unlock broadcast. 
Only the leader handles unlocks.") - return - } - // Format matching handleBroadcastedMsg +// Acknowledge a processed message +func (s *server) Acknowledge(ctx context.Context, in *pb.AcknowledgeRequest) (*pb.AcknowledgeResponse, error) { + // Verify lock exists and is valid + lockInfo, ok := s.messageLocks.Load(in.Id) + if !ok { + return nil, status.Error(codes.NotFound, "message lock not found") + } + info := lockInfo.(MessageLockInfo) + if !info.Locked || time.Now().After(info.Timeout) { + return nil, status.Error(codes.FailedPrecondition, "message lock expired") + } + // Update Spanner + mutation := spanner.Update( + MessagesTable, + []string{"id", "processed", "updatedAt"}, + []interface{}{in.Id, true, spanner.CommitTimestamp}, + ) + _, err := s.Client.Apply(ctx, []*spanner.Mutation{mutation}) + if err != nil { + return nil, err + } + // Broadcast delete to all nodes - format matching handleBroadcastedMsg broadcastData := broadCastInput{ Type: message, - Msg: []byte(fmt.Sprintf("unlock:%s", messageID)), + Msg: []byte(fmt.Sprintf("delete:%s", in.Id)), } + bin, _ := json.Marshal(broadcastData) s.Op.Broadcast(ctx, bin) // Clean up local state - s.messageLocks.Delete(messageID) - if timer, ok := s.timeoutTimers.Load(messageID); ok { + s.messageLocks.Delete(in.Id) + if timer, ok := s.timeoutTimers.Load(in.Id); ok { timer.(*time.Timer).Stop() - s.timeoutTimers.Delete(messageID) + s.timeoutTimers.Delete(in.Id) } - log.Printf("[Unlock] Leader node unlocked message: %s", messageID) -} - -// IsLeader checks if the current node is the leader node in the pub/sub system -func (s *server) IsLeader() bool { - return s.Op.ID() == s.Op.GetLeaderID() // Compare current node ID with leader ID + return &pb.AcknowledgeResponse{Success: true}, nil } -// requestMessageFromLeader asks the leader node for messages -func (s *server) requestMessageFromLeader(topicID string) (*pb.Message, error) { - // Use the message type with proper format for requesting a message - broadcastData := broadCastInput{ - Type: message, - Msg: []byte(fmt.Sprintf("getmessage:%s", topicID)), - } - - bin, _ := json.Marshal(broadcastData) - resp, err := s.Op.Request(context.Background(), bin) - if err != nil { - return nil, err +// ModifyVisibilityTimeout extends message lock timeout +func (s *server) ModifyVisibilityTimeout(ctx context.Context, in *pb.ModifyVisibilityTimeoutRequest) (*pb.ModifyVisibilityTimeoutResponse, error) { + lockInfo, ok := s.messageLocks.Load(in.Id) + if !ok { + return nil, status.Error(codes.NotFound, "message lock not found") } - - if len(resp) == 0 { - return nil, status.Error(codes.NotFound, "no messages available") + info := lockInfo.(MessageLockInfo) + if !info.Locked { + return nil, status.Error(codes.FailedPrecondition, "message not locked") } - - var message pb.Message - if err := json.Unmarshal(resp, &message); err != nil { - return nil, err + // Ensure the same node is extending the lock + if info.NodeID != s.Op.ID() { + return nil, status.Error(codes.PermissionDenied, "only the original lock holder can extend timeout") } - - return &message, nil -} - -// ExtendVisibilityTimeout extends the visibility timeout for a message -func (s *server) ExtendVisibilityTimeout(messageID string, subscriberID string, visibilityTimeout time.Duration) error { - // Non-leader nodes should not modify state directly - if !s.IsLeader() { - return status.Error(codes.PermissionDenied, "only the leader can extend visibility timeout") - } - value, exists := s.messageLocks.Load(messageID) - if !exists { - return 
status.Error(codes.NotFound, "message not locked") - } - info, ok := value.(MessageLockInfo) - if !ok || info.SubscriberID != subscriberID { - return status.Error(codes.PermissionDenied, "message locked by another subscriber") - } - // Leader extends visibility timeout - newExpiresAt := time.Now().Add(visibilityTimeout) - info.Timeout = newExpiresAt - s.messageLocks.Store(messageID, info) - // Create broadcast message - format matching handleBroadcastedMsg + // Broadcast new timeout - format matching handleBroadcastedMsg broadcastData := broadCastInput{ Type: message, - Msg: []byte(fmt.Sprintf("extend:%s:%d", messageID, int(visibilityTimeout.Seconds()))), + Msg: []byte(fmt.Sprintf("extend:%s:%d", in.Id, in.NewTimeout)), } - msgBytes, _ := json.Marshal(broadcastData) - // Leader broadcasts the new timeout - s.Op.Broadcast(context.TODO(), msgBytes) - log.Printf("[ExtendTimeout] Leader approved timeout extension for message: %s", messageID) - return nil -} - -// HandleBroadcastMessage processes broadcast messages received from other nodes -func (s *server) HandleBroadcastMessage(msgType string, msgData []byte) error { - // This method would be called by your broadcast handler - switch msgType { - case "lock": - parts := strings.Split(string(msgData), ":") - if len(parts) < 3 { - return fmt.Errorf("invalid lock message format") - } - messageID := parts[0] - timeoutSecondsStr := parts[1] - subscriberID := parts[2] - - timeoutSeconds, err := strconv.Atoi(timeoutSecondsStr) - if err != nil { - return err - } - - // Store the lock locally - lockInfo := MessageLockInfo{ - Timeout: time.Now().Add(time.Duration(timeoutSeconds) * time.Second), - Locked: true, - NodeID: s.Op.ID(), // This is the current node - SubscriberID: subscriberID, - LockHolders: make(map[string]bool), - } - s.messageLocks.Store(messageID, lockInfo) - - case "unlock": - messageID := string(msgData) - s.messageLocks.Delete(messageID) - - // Add other message types as needed + bin, _ := json.Marshal(broadcastData) + s.Op.Broadcast(ctx, bin) + // Update local timer + if timer, ok := s.timeoutTimers.Load(in.Id); ok { + timer.(*time.Timer).Stop() } - - return nil + newTimer := time.NewTimer(time.Duration(in.NewTimeout) * time.Second) + s.timeoutTimers.Store(in.Id, newTimer) + // Update lock info + info.Timeout = time.Now().Add(time.Duration(in.NewTimeout) * time.Second) + s.messageLocks.Store(in.Id, info) + go func() { + <-newTimer.C + s.handleMessageTimeout(in.Id) + }() + return &pb.ModifyVisibilityTimeoutResponse{Success: true}, nil } From b7f053746f1998fdd17f60adda4f014a4e4933b0 Mon Sep 17 00:00:00 2001 From: Kishea Kate Andoy Date: Tue, 25 Feb 2025 18:15:50 +0800 Subject: [PATCH 12/21] updated broadcast.go --- app.go | 20 + go.sum | 36 +- helpers.go | 545 +++++++++++++++++-------- msg_distribution/visibility_timeout.go | 250 ++++++++++++ 4 files changed, 672 insertions(+), 179 deletions(-) create mode 100644 app.go create mode 100644 msg_distribution/visibility_timeout.go diff --git a/app.go b/app.go new file mode 100644 index 00000000..c961855a --- /dev/null +++ b/app.go @@ -0,0 +1,20 @@ +package main + +import ( + "cloud.google.com/go/spanner" + "github.com/flowerinthenight/hedge/v2" + "sync" +) + +type PubSub struct { + Op *hedge.Op + Client *spanner.Client + // Message handling + messageLocks sync.Map // messageID -> MessageLockInfo + messageQueue sync.Map // topic -> []*pb.Message + subscriptions sync.Map // subscriptionID -> *pb.Subscription + + // Timer tracking + timeoutTimers sync.Map // messageID -> *time.Timer + storage 
*storage.Storage // jansen's storage +} diff --git a/go.sum b/go.sum index f2887028..f0c8b423 100644 --- a/go.sum +++ b/go.sum @@ -38,8 +38,8 @@ cloud.google.com/go v0.104.0/go.mod h1:OO6xxXdJyvuJPcEPBLN9BJPD+jep5G1+2U5B5gkRY cloud.google.com/go v0.105.0/go.mod h1:PrLgOJNe5nfE9UMxKxgXj4mD3voiP+YQ6gdt6KMFOKM= cloud.google.com/go v0.107.0/go.mod h1:wpc2eNrD7hXUTy8EKS10jkxpZBjASrORK7goS+3YX2I= cloud.google.com/go v0.110.0/go.mod h1:SJnCLqQ0FCFGSZMUNUf84MV3Aia54kn7pi8st7tMzaY= -cloud.google.com/go v0.118.1 h1:b8RATMcrK9A4BH0rj8yQupPXp+aP+cJ0l6H7V9osV1E= -cloud.google.com/go v0.118.1/go.mod h1:CFO4UPEPi8oV21xoezZCrd3d81K4fFkDTEJu4R8K+9M= +cloud.google.com/go v0.118.2 h1:bKXO7RXMFDkniAAvvuMrAPtQ/VHrs9e7J5UT3yrGdTY= +cloud.google.com/go v0.118.2/go.mod h1:CFO4UPEPi8oV21xoezZCrd3d81K4fFkDTEJu4R8K+9M= cloud.google.com/go/accessapproval v1.4.0/go.mod h1:zybIuC3KpDOvotz59lFe5qxRZx6C75OtwbisN56xYB4= cloud.google.com/go/accessapproval v1.5.0/go.mod h1:HFy3tuiGvMdcd/u+Cu5b9NkO1pEICJ46IR82PoUdplw= cloud.google.com/go/accessapproval v1.6.0/go.mod h1:R0EiYnwV5fsRFiKZkPHr6mwyk2wxUJ30nL4j2pcFY2E= @@ -319,8 +319,8 @@ cloud.google.com/go/iam v0.8.0/go.mod h1:lga0/y3iH6CX7sYqypWJ33hf7kkfXJag67naqGE cloud.google.com/go/iam v0.11.0/go.mod h1:9PiLDanza5D+oWFZiH1uG+RnRCfEGKoyl6yo4cgWZGY= cloud.google.com/go/iam v0.12.0/go.mod h1:knyHGviacl11zrtZUoDuYpDgLjvr28sLQaG0YB2GYAY= cloud.google.com/go/iam v0.13.0/go.mod h1:ljOg+rcNfzZ5d6f1nAUJ8ZIxOaZUVoS14bKCtaLZ/D0= -cloud.google.com/go/iam v1.3.1 h1:KFf8SaT71yYq+sQtRISn90Gyhyf4X8RGgeAVC8XGf3E= -cloud.google.com/go/iam v1.3.1/go.mod h1:3wMtuyT4NcbnYNPLMBzYRFiEfjKfJlLVLrisE7bwm34= +cloud.google.com/go/iam v1.4.0 h1:ZNfy/TYfn2uh/ukvhp783WhnbVluqf/tzOaqVUPlIPA= +cloud.google.com/go/iam v1.4.0/go.mod h1:gMBgqPaERlriaOV0CUl//XUzDhSfXevn4OEUbg6VRs4= cloud.google.com/go/iap v1.4.0/go.mod h1:RGFwRJdihTINIe4wZ2iCP0zF/qu18ZwyKxrhMhygBEc= cloud.google.com/go/iap v1.5.0/go.mod h1:UH/CGgKd4KyohZL5Pt0jSKE4m3FR51qg6FKQ/z/Ix9A= cloud.google.com/go/iap v1.6.0/go.mod h1:NSuvI9C/j7UdjGjIde7t7HBz+QTwBcapPE07+sSRcLk= @@ -378,8 +378,8 @@ cloud.google.com/go/monitoring v1.7.0/go.mod h1:HpYse6kkGo//7p6sT0wsIC6IBDET0RhI cloud.google.com/go/monitoring v1.8.0/go.mod h1:E7PtoMJ1kQXWxPjB6mv2fhC5/15jInuulFdYYtlcvT4= cloud.google.com/go/monitoring v1.12.0/go.mod h1:yx8Jj2fZNEkL/GYZyTLS4ZtZEZN8WtDEiEqG4kLK50w= cloud.google.com/go/monitoring v1.13.0/go.mod h1:k2yMBAB1H9JT/QETjNkgdCGD9bPF712XiLTVr+cBrpw= -cloud.google.com/go/monitoring v1.23.0 h1:M3nXww2gn9oZ/qWN2bZ35CjolnVHM3qnSbu6srCPgjk= -cloud.google.com/go/monitoring v1.23.0/go.mod h1:034NnlQPDzrQ64G2Gavhl0LUHZs9H3rRmhtnp7jiJgg= +cloud.google.com/go/monitoring v1.24.0 h1:csSKiCJ+WVRgNkRzzz3BPoGjFhjPY23ZTcaenToJxMM= +cloud.google.com/go/monitoring v1.24.0/go.mod h1:Bd1PRK5bmQBQNnuGwHBfUamAV1ys9049oEPHnn4pcsc= cloud.google.com/go/networkconnectivity v1.4.0/go.mod h1:nOl7YL8odKyAOtzNX73/M5/mGZgqqMeryi6UPZTk/rA= cloud.google.com/go/networkconnectivity v1.5.0/go.mod h1:3GzqJx7uhtlM3kln0+x5wyFvuVH1pIBJjhCpjzSt75o= cloud.google.com/go/networkconnectivity v1.6.0/go.mod h1:OJOoEXW+0LAxHh89nXd64uGG+FbQoeH8DtxCHVOMlaM= @@ -526,8 +526,8 @@ cloud.google.com/go/shell v1.6.0/go.mod h1:oHO8QACS90luWgxP3N9iZVuEiSF84zNyLytb+ cloud.google.com/go/spanner v1.41.0/go.mod h1:MLYDBJR/dY4Wt7ZaMIQ7rXOTLjYrmxLE/5ve9vFfWos= cloud.google.com/go/spanner v1.44.0/go.mod h1:G8XIgYdOK+Fbcpbs7p2fiprDw4CaZX63whnSMLVBxjk= cloud.google.com/go/spanner v1.45.0/go.mod h1:FIws5LowYz8YAE1J8fOS7DJup8ff7xJeetWEo5REA2M= -cloud.google.com/go/spanner v1.75.0 
h1:2zrltTJv/4P3pCgpYgde4Eb1vN8Cgy1fNy7pbTnOovg= -cloud.google.com/go/spanner v1.75.0/go.mod h1:TLFZBvPQmx3We7sGh12eTk9lLsRLczzZaiweqfMpR80= +cloud.google.com/go/spanner v1.76.1 h1:vYbVZuXfnFwvNcvH3lhI2PeUA+kHyqKmLC7mJWaC4Ok= +cloud.google.com/go/spanner v1.76.1/go.mod h1:YtwoE+zObKY7+ZeDCBtZ2ukM+1/iPaMfUM+KnTh/sx0= cloud.google.com/go/speech v1.6.0/go.mod h1:79tcr4FHCimOp56lwC01xnt/WPJZc4v3gzyT7FoBkCM= cloud.google.com/go/speech v1.7.0/go.mod h1:KptqL+BAQIhMsj1kOP2la5DSEEerPDuOP/2mmkhHhZQ= cloud.google.com/go/speech v1.8.0/go.mod h1:9bYIl1/tjsAnMgKGHKmBZzXKEkGgtU+MpdDPTE9f7y0= @@ -630,8 +630,8 @@ github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuy github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= -github.com/alphauslabs/pubsub-proto v0.0.0-20250221062210-631fa96417c7 h1:05Vkbm1ajlv8/GxFR4lQrE6qeINnlyuCR6CnMU6vn8g= -github.com/alphauslabs/pubsub-proto v0.0.0-20250221062210-631fa96417c7/go.mod h1:sfDRyxsVmiZtG22g8DUGajSp4qOxCkDrwl655xHNDPg= +github.com/alphauslabs/pubsub-proto v0.0.0-20250224043151-d2fff9627a86 h1:W63Fb2ORG7Ym31iYh+KIP74E0i3yu/JFg4XfQz2re5A= +github.com/alphauslabs/pubsub-proto v0.0.0-20250224043151-d2fff9627a86/go.mod h1:sfDRyxsVmiZtG22g8DUGajSp4qOxCkDrwl655xHNDPg= github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/apache/arrow/go/v10 v10.0.1/go.mod h1:YvhnlEePVnBS4+0z3fhPfUy7W1Ikj0Ih0vcRo/gZ1M0= @@ -1210,8 +1210,8 @@ golang.org/x/oauth2 v0.4.0/go.mod h1:RznEsdpjGAINPTOF0UH/t+xJ75L18YO3Ho6Pyn+uRec golang.org/x/oauth2 v0.5.0/go.mod h1:9/XBHVqLaWO3/BRHs5jbpYCnOZVjj5V0ndyaAM7KB4I= golang.org/x/oauth2 v0.6.0/go.mod h1:ycmewcwgD4Rpr3eZJLSB4Kyyljb3qDh40vJ8STE5HKw= golang.org/x/oauth2 v0.7.0/go.mod h1:hPLQkd9LyjfXTiRohC/41GhcFqxisoUQ99sCUOHO9x4= -golang.org/x/oauth2 v0.25.0 h1:CY4y7XT9v0cRI9oupztF8AgiIu99L/ksR/Xp/6jrZ70= -golang.org/x/oauth2 v0.25.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/oauth2 v0.26.0 h1:afQXWNNaeC4nvZ0Ed9XvCCzXM6UHJG7iCg0W4fPqSBE= +golang.org/x/oauth2 v0.26.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -1348,8 +1348,8 @@ golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20220922220347-f3bd1da661af/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.1.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= -golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4= +golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod 
h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -1490,8 +1490,8 @@ google.golang.org/api v0.108.0/go.mod h1:2Ts0XTHNVWxypznxWOYUeI4g3WdP9Pk2Qk58+a/ google.golang.org/api v0.110.0/go.mod h1:7FC4Vvx1Mooxh8C5HWjzZHcavuS2f6pmJpZx60ca7iI= google.golang.org/api v0.111.0/go.mod h1:qtFHvU9mhgTJegR31csQ+rwxyUTHOKFqCKWp1J0fdw0= google.golang.org/api v0.114.0/go.mod h1:ifYI2ZsFK6/uGddGfAD5BMxlnkBqCmqHSDUVi45N5Yg= -google.golang.org/api v0.219.0 h1:nnKIvxKs/06jWawp2liznTBnMRQBEPpGo7I+oEypTX0= -google.golang.org/api v0.219.0/go.mod h1:K6OmjGm+NtLrIkHxv1U3a0qIf/0JOvAHd5O/6AoyKYE= +google.golang.org/api v0.222.0 h1:Aiewy7BKLCuq6cUCeOUrsAlzjXPqBkEeQ/iwGHVQa/4= +google.golang.org/api v0.222.0/go.mod h1:efZia3nXpWELrwMlN5vyQrD4GmJN1Vw0x68Et3r+a9c= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= @@ -1633,8 +1633,8 @@ google.golang.org/genproto v0.0.0-20230331144136-dcfb400f0633/go.mod h1:UUQDJDOl google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU= google.golang.org/genproto v0.0.0-20250127172529-29210b9bc287 h1:WoUI1G0DQ648FKvSl756SKxHQR/bI+y4HyyIQfxMWI8= google.golang.org/genproto v0.0.0-20250127172529-29210b9bc287/go.mod h1:wkQ2Aj/xvshAUDtO/JHvu9y+AaN9cqs28QuSVSHtZSY= -google.golang.org/genproto/googleapis/api v0.0.0-20250127172529-29210b9bc287 h1:A2ni10G3UlplFrWdCDJTl7D7mJ7GSRm37S+PDimaKRw= -google.golang.org/genproto/googleapis/api v0.0.0-20250127172529-29210b9bc287/go.mod h1:iYONQfRdizDB8JJBybql13nArx91jcUk7zCXEsOofM4= +google.golang.org/genproto/googleapis/api v0.0.0-20250219182151-9fdb1cabc7b2 h1:35ZFtrCgaAjF7AFAK0+lRSf+4AyYnWRbH7og13p7rZ4= +google.golang.org/genproto/googleapis/api v0.0.0-20250219182151-9fdb1cabc7b2/go.mod h1:W9ynFDP/shebLB1Hl/ESTOap2jHd6pmLXPNZC7SVDbA= google.golang.org/genproto/googleapis/rpc v0.0.0-20250219182151-9fdb1cabc7b2 h1:DMTIbak9GhdaSxEjvVzAeNZvyc03I61duqNbnm3SU0M= google.golang.org/genproto/googleapis/rpc v0.0.0-20250219182151-9fdb1cabc7b2/go.mod h1:LuRYeWDFV6WOn90g357N17oMCaxpgCnbi/44qJvDn2I= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= diff --git a/helpers.go b/helpers.go index 9b2af89b..65994129 100644 --- a/helpers.go +++ b/helpers.go @@ -3,220 +3,443 @@ package main import ( "context" "encoding/json" - "errors" "fmt" "log" "strconv" "strings" - "sync" "time" - pb "github.com/alphauslabs/pubsub-proto/v1" - "google.golang.org/grpc/status" - "google.golang.org/grpc/codes" + + pb "github.com/alphauslabs/pubsub-proto/v1" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) +// MessageLockInfo tracks lock state across nodes +type MessageLockInfo struct { + Timeout time.Time + Locked bool + NodeID string + SubscriberID string // Added to track which subscriber has the lock + LockHolders map[string]bool // Track which nodes have acknowledged the lock +} -//HELPERFUNCTIONS// // validateTopicSubscription checks if subscription exists in memory func (s *server) validateTopicSubscription(subscriptionID string) (*pb.Subscription, error) { - if val, ok := s.subscriptions.Load(subscriptionID); ok { - return 
val.(*pb.Subscription), nil - } + if val, ok := s.subscriptions.Load(subscriptionID); ok { + return val.(*pb.Subscription), nil + } + // Request subscription details from the leader - using topicsub message type + broadcastData := broadCastInput{ + Type: topicsub, + Msg: []byte(fmt.Sprintf("get:%s", subscriptionID)), + } + bin, _ := json.Marshal(broadcastData) + resp, err := s.Op.Request(context.Background(), bin) + if err != nil { + return nil, status.Error(codes.Internal, "failed to request subscription from leader") + } - // Request subscription details from the leader - leaderSubscription, err := s.requestSubscriptionFromLeader(subscriptionID) - if err != nil { - return nil, status.Error(codes.NotFound, "subscription not found in memory or leader") - } + var subscription pb.Subscription + if err := json.Unmarshal(resp, &subscription); err != nil { + return nil, status.Error(codes.Internal, "failed to parse subscription response") + } - // Store it in memory to prevent duplicate lookups - s.subscriptions.Store(subscriptionID, leaderSubscription) + if subscription.Id == "" { + return nil, status.Error(codes.NotFound, "subscription not found") + } - return leaderSubscription, nil // Do not store in-memory cache here + // Store it in memory to prevent duplicate lookups + s.subscriptions.Store(subscriptionID, &subscription) + return &subscription, nil } // broadcastLock sends lock request to all nodes and waits for acknowledgment -func (s *server) broadcastLock(ctx context.Context, messageID string, timeout time.Duration) error { - lockInfo := MessageLockInfo{ - Timeout: time.Now().Add(timeout), - Locked: true, - NodeID: s.Op.ID(), - LockHolders: make(map[string]bool), - } - - // Store initial lock info before broadcasting - _, loaded := s.messageLocks.LoadOrStore(messageID, lockInfo) - if loaded { - return fmt.Errorf("message already locked by another node") - } - - // Broadcast lock request - broadcastData := broadCastInput{ - Type: message, - Msg: []byte(fmt.Sprintf("lock:%s:%d", messageID, timeout.Seconds())), - } - - bin, _ := json.Marshal(broadcastData) - out := s.Op.Broadcast(ctx, bin) - - // Ensure majority of nodes acknowledged - successCount := 0 - for _, v := range out { - if v.Error == nil { - successCount++ - } - } - - if successCount < (len(out)/2 + 1) { - s.messageLocks.Delete(messageID) - return fmt.Errorf("failed to acquire lock across majority of nodes") - } - - // Start local timeout timer - timer := time.NewTimer(timeout) - s.timeoutTimers.Store(messageID, timer) - - go func() { - <-timer.C - s.handleMessageTimeout(messageID) - }() +func (s *server) broadcastLock(ctx context.Context, messageID string, subscriberID string, timeout time.Duration) error { + lockInfo := MessageLockInfo{ + Timeout: time.Now().Add(timeout), + Locked: true, + NodeID: s.Op.ID(), + SubscriberID: subscriberID, + LockHolders: make(map[string]bool), + } + // Store initial lock info before broadcasting + _, loaded := s.messageLocks.LoadOrStore(messageID, lockInfo) + if loaded { + return fmt.Errorf("message already locked by another node") + } + // Broadcast lock request - format matching handleBroadcastedMsg + broadcastData := broadCastInput{ + Type: message, + Msg: []byte(fmt.Sprintf("lock:%s:%d:%s", messageID, int(timeout.Seconds()), subscriberID)), + } - return nil + bin, _ := json.Marshal(broadcastData) + out := s.Op.Broadcast(ctx, bin) + // Ensure majority of nodes acknowledged + successCount := 0 + for _, v := range out { + if v.Error == nil { + successCount++ + } + } + if successCount < 
(len(out)/2 + 1) { + s.messageLocks.Delete(messageID) + return fmt.Errorf("failed to acquire lock across majority of nodes") + } + // Start local timeout timer + timer := time.NewTimer(timeout) + s.timeoutTimers.Store(messageID, timer) + + go func() { + <-timer.C + s.handleMessageTimeout(messageID) + }() + return nil } -//helper function - function ensures that if a node crashes while holding a lock, other nodes can unlock the message and allow it to be processed again. +// handleMessageTimeout ensures that if a node crashes while holding a lock, +// other nodes can unlock the message and allow it to be processed again. func (s *server) handleMessageTimeout(messageID string) { - if lockInfo, ok := s.messageLocks.Load(messageID); ok { - info := lockInfo.(MessageLockInfo) - if info.Locked && time.Now().After(info.Timeout) { - log.Printf("[Timeout] Unlocking expired message: %s", messageID) - - // Broadcast unlock - s.BroadcastUnlock(context.Background(), messageID) - - // Remove lock entry - s.messageLocks.Delete(messageID) - - // Notify all nodes to retry processing this message - broadcastData := broadCastInput{ - Type: message, - Msg: []byte(fmt.Sprintf("retry:%s", messageID)), - } - - bin, _ := json.Marshal(broadcastData) - s.Op.Broadcast(context.Background(), bin) - } - } + if lockInfo, ok := s.messageLocks.Load(messageID); ok { + info := lockInfo.(MessageLockInfo) + if info.Locked && time.Now().After(info.Timeout) { + log.Printf("[Timeout] Unlocking expired message: %s", messageID) + // Broadcast unlock + s.broadcastUnlock(context.Background(), messageID) + // Remove lock entry + s.messageLocks.Delete(messageID) + // Notify all nodes to retry processing this message + broadcastData := broadCastInput{ + Type: message, + Msg: []byte(fmt.Sprintf("retry:%s", messageID)), + } + bin, _ := json.Marshal(broadcastData) + s.Op.Broadcast(context.Background(), bin) + } + } } -//helper function - function ensures that only the leader node is responsible for broadcasting unlock requests: +// broadcastUnlock ensures that only the leader node is responsible for broadcasting unlock requests func (s *server) broadcastUnlock(ctx context.Context, messageID string) { - // Ensure only the leader sends the unlock request - if !s.Op.IsLeader() { - log.Printf("[Unlock] Skipping unlock broadcast. Only the leader handles unlocks.") - return - } - - broadcastData := broadCastInput{ - Type: message, - Msg: []byte(fmt.Sprintf("unlock:%s", messageID)), - } - - bin, _ := json.Marshal(broadcastData) - s.Op.Broadcast(ctx, bin) - - // Clean up local state - s.messageLocks.Delete(messageID) - if timer, ok := s.timeoutTimers.Load(messageID); ok { - timer.(*time.Timer).Stop() - s.timeoutTimers.Delete(messageID) - } - - log.Printf("[Unlock] Leader node unlocked message: %s", messageID) + // Ensure only the leader sends the unlock request + if !s.IsLeader() { + log.Printf("[Unlock] Skipping unlock broadcast. Only the leader handles unlocks.") + return + } + // Format matching handleBroadcastedMsg + broadcastData := broadCastInput{ + Type: message, + Msg: []byte(fmt.Sprintf("unlock:%s", messageID)), + } + bin, _ := json.Marshal(broadcastData) + s.Op.Broadcast(ctx, bin) + // Clean up local state + s.messageLocks.Delete(messageID) + if timer, ok := s.timeoutTimers.Load(messageID); ok { + timer.(*time.Timer).Stop() + s.timeoutTimers.Delete(messageID) + } + log.Printf("[Unlock] Leader node unlocked message: %s", messageID) } -//helper function - checks if the current node is the leader node in your pub/sub system. 
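// Illustrative sketch (not part of this patch): the quorum rule used by broadcastLock
// above, pulled out into a standalone helper. Integer division makes len(out)/2+1 a
// strict majority: 3 peers need 2 acks, 4 peers need 3, 5 peers need 3.
func hasMajority(acks, peers int) bool {
	return acks >= peers/2+1
}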
+// IsLeader checks if the current node is the leader node in the pub/sub system func (s *server) IsLeader() bool { - return s.Op.ID() == s.Op.GetLeaderID() // Compare current node ID with leader ID + return s.Op.ID() == s.Op.GetLeaderID() // Compare current node ID with leader ID } - -//helper function - asks the leader node for messages +// requestMessageFromLeader asks the leader node for messages func (s *server) requestMessageFromLeader(topicID string) (*pb.Message, error) { - // Simulated request to leader (replace with actual leader communication) - log.Printf("[Leader] Requesting message for topic: %s", topicID) - return nil, status.Error(codes.NotFound, "no messages available from leader") -} + // Use the message type with proper format for requesting a message + broadcastData := broadCastInput{ + Type: message, + Msg: []byte(fmt.Sprintf("getmessage:%s", topicID)), + } + + bin, _ := json.Marshal(broadcastData) + resp, err := s.Op.Request(context.Background(), bin) + if err != nil { + return nil, err + } + + if len(resp) == 0 { + return nil, status.Error(codes.NotFound, "no messages available") + } + + var message pb.Message + if err := json.Unmarshal(resp, &message); err != nil { + return nil, err + } + return &message, nil +} -func (s *server) ExtendVisibilityTimeout(messageID string, subscriberID string) error { +// ExtendVisibilityTimeout extends the visibility timeout for a message +func (s *server) ExtendVisibilityTimeout(messageID string, subscriberID string, visibilityTimeout time.Duration) error { // Non-leader nodes should not modify state directly if !s.IsLeader() { return status.Error(codes.PermissionDenied, "only the leader can extend visibility timeout") } - value, exists := s.messageLocks.Load(messageID) if !exists { return status.Error(codes.NotFound, "message not locked") } - - info, ok := value.(VisibilityInfo) + info, ok := value.(MessageLockInfo) if !ok || info.SubscriberID != subscriberID { return status.Error(codes.PermissionDenied, "message locked by another subscriber") } - // Leader extends visibility timeout newExpiresAt := time.Now().Add(visibilityTimeout) - info.ExpiresAt = newExpiresAt + info.Timeout = newExpiresAt s.messageLocks.Store(messageID, info) - - // Create broadcast message - broadcastMsg := broadCastInput{ - Type: "extend", - Msg: fmt.Sprintf("%s:%d", messageID, visibilityTimeout.Seconds()), + // Create broadcast message - format matching handleBroadcastedMsg + broadcastData := broadCastInput{ + Type: message, + Msg: []byte(fmt.Sprintf("extend:%s:%d", messageID, int(visibilityTimeout.Seconds()))), } - msgBytes, _ := json.Marshal(broadcastMsg) - + msgBytes, _ := json.Marshal(broadcastData) // Leader broadcasts the new timeout s.Op.Broadcast(context.TODO(), msgBytes) - log.Printf("[ExtendTimeout] Leader approved timeout extension for message: %s", messageID) return nil } -//helper function - listen for the leader's broadcast and apply the timeout only when received. 
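// Illustrative sketch (not part of this patch): a possible leader-side responder for the
// "getmessage:<topicID>" request sent by requestMessageFromLeader above. Only the
// requesting side appears in the patch; this handler's name, its hookup to s.Op's request
// handling, and its use of the messageQueue map from app.go are assumptions. An empty
// reply maps to the codes.NotFound branch in requestMessageFromLeader.
func (s *server) handleGetMessage(topicID string) ([]byte, error) {
	v, ok := s.messageQueue.Load(topicID)
	if !ok {
		return nil, nil
	}
	queue, _ := v.([]*pb.Message)
	if len(queue) == 0 {
		return nil, nil
	}
	msg := queue[0]
	// Hand out one message per request; a real implementation would need to guard this
	// read-modify-write against concurrent requests.
	s.messageQueue.Store(topicID, queue[1:])
	return json.Marshal(msg)
}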
-func (s *server) HandleTimeoutExtension(msg broadCastInput) { - // Parse message - parts := strings.Split(string(msg.Msg), ":") - if len(parts) != 2 { - log.Println("[HandleTimeoutExtension] Invalid message format") - return - } - - messageID := parts[0] - timeoutSeconds, err := strconv.Atoi(parts[1]) - if err != nil { - log.Println("[HandleTimeoutExtension] Failed to parse timeout value") - return +// HandleBroadcastMessage processes broadcast messages received from other nodes +func (s *server) HandleBroadcastMessage(msgType string, msgData []byte) error { + // This method would be called by your broadcast handler + switch msgType { + case "lock": + parts := strings.Split(string(msgData), ":") + if len(parts) < 3 { + return fmt.Errorf("invalid lock message format") + } + messageID := parts[0] + timeoutSecondsStr := parts[1] + subscriberID := parts[2] + + timeoutSeconds, err := strconv.Atoi(timeoutSecondsStr) + if err != nil { + return err + } + + // Store the lock locally + lockInfo := MessageLockInfo{ + Timeout: time.Now().Add(time.Duration(timeoutSeconds) * time.Second), + Locked: true, + NodeID: s.Op.ID(), // This is the current node + SubscriberID: subscriberID, + LockHolders: make(map[string]bool), + } + s.messageLocks.Store(messageID, lockInfo) + + case "unlock": + messageID := string(msgData) + s.messageLocks.Delete(messageID) + + // Add other message types as needed } - // Apply the extended timeout - value, exists := s.messageLocks.Load(messageID) - if !exists { - log.Printf("[HandleTimeoutExtension] Message %s not found in locks", messageID) - return - } - - info, ok := value.(VisibilityInfo) - if !ok { - log.Println("[HandleTimeoutExtension] Invalid visibility info") - return - } - - info.ExpiresAt = time.Now().Add(time.Duration(timeoutSeconds) * time.Second) - s.messageLocks.Store(messageID, info) - - log.Printf("[HandleTimeoutExtension] Applied timeout extension for message: %s", messageID) + return nil } - +// //HELPERFUNCTIONS// +// // validateTopicSubscription checks if subscription exists in memory +// func (s *server) validateTopicSubscription(subscriptionID string) (*pb.Subscription, error) { +// if val, ok := s.subscriptions.Load(subscriptionID); ok { +// return val.(*pb.Subscription), nil +// } + +// // Request subscription details from the leader +// leaderSubscription, err := s.requestSubscriptionFromLeader(subscriptionID) +// if err != nil { +// return nil, status.Error(codes.NotFound, "subscription not found in memory or leader") +// } + +// // Store it in memory to prevent duplicate lookups +// s.subscriptions.Store(subscriptionID, leaderSubscription) + +// return leaderSubscription, nil // Do not store in-memory cache here +// } + +// // broadcastLock sends lock request to all nodes and waits for acknowledgment +// func (s *server) broadcastLock(ctx context.Context, messageID string, timeout time.Duration) error { +// lockInfo := MessageLockInfo{ +// Timeout: time.Now().Add(timeout), +// Locked: true, +// NodeID: s.Op.ID(), +// LockHolders: make(map[string]bool), +// } + +// // Store initial lock info before broadcasting +// _, loaded := s.messageLocks.LoadOrStore(messageID, lockInfo) +// if loaded { +// return fmt.Errorf("message already locked by another node") +// } + +// // Broadcast lock request +// broadcastData := broadCastInput{ +// Type: message, +// Msg: []byte(fmt.Sprintf("lock:%s:%d", messageID, timeout.Seconds())), +// } + +// bin, _ := json.Marshal(broadcastData) +// out := s.Op.Broadcast(ctx, bin) + +// // Ensure majority of nodes acknowledged 
+// successCount := 0 +// for _, v := range out { +// if v.Error == nil { +// successCount++ +// } +// } + +// if successCount < (len(out)/2 + 1) { +// s.messageLocks.Delete(messageID) +// return fmt.Errorf("failed to acquire lock across majority of nodes") +// } + +// // Start local timeout timer +// timer := time.NewTimer(timeout) +// s.timeoutTimers.Store(messageID, timer) + +// go func() { +// <-timer.C +// s.handleMessageTimeout(messageID) +// }() + +// return nil +// } + +// //helper function - function ensures that if a node crashes while holding a lock, other nodes can unlock the message and allow it to be processed again. +// func (s *server) handleMessageTimeout(messageID string) { +// if lockInfo, ok := s.messageLocks.Load(messageID); ok { +// info := lockInfo.(MessageLockInfo) +// if info.Locked && time.Now().After(info.Timeout) { +// log.Printf("[Timeout] Unlocking expired message: %s", messageID) + +// // Broadcast unlock +// s.broadcastUnlock(context.Background(), messageID) + +// // Remove lock entry +// s.messageLocks.Delete(messageID) + +// // Notify all nodes to retry processing this message +// broadcastData := broadCastInput{ +// Type: message, +// Msg: []byte(fmt.Sprintf("retry:%s", messageID)), +// } + +// bin, _ := json.Marshal(broadcastData) +// s.Op.Broadcast(context.Background(), bin) +// } +// } +// } + +// //helper function - function ensures that only the leader node is responsible for broadcasting unlock requests: +// func (s *server) broadcastUnlock(ctx context.Context, messageID string) { +// // Ensure only the leader sends the unlock request +// if !s.Op.IsLeader() { +// log.Printf("[Unlock] Skipping unlock broadcast. Only the leader handles unlocks.") +// return +// } + +// broadcastData := broadCastInput{ +// Type: message, +// Msg: []byte(fmt.Sprintf("unlock:%s", messageID)), +// } + +// bin, _ := json.Marshal(broadcastData) +// s.Op.Broadcast(ctx, bin) + +// // Clean up local state +// s.messageLocks.Delete(messageID) +// if timer, ok := s.timeoutTimers.Load(messageID); ok { +// timer.(*time.Timer).Stop() +// s.timeoutTimers.Delete(messageID) +// } + +// log.Printf("[Unlock] Leader node unlocked message: %s", messageID) +// } + +// //helper function - checks if the current node is the leader node in your pub/sub system. 
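As a side note on the majority check in the commented-out broadcastLock above: with the threshold len(out)/2 + 1, a lock broadcast to four peers needs at least three successful acknowledgments before the lock is considered held. A tiny illustration (hasMajority is an illustrative helper, not part of the PR):

// hasMajority mirrors the broadcastLock check, where the lock is dropped
// whenever successCount < (len(out)/2 + 1).
func hasMajority(peerAcks, peers int) bool {
	return peerAcks >= peers/2+1
}

// hasMajority(3, 4) == true; hasMajority(2, 4) == false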
+// func (s *server) IsLeader() bool { +// return s.Op.ID() == s.Op.GetLeaderID() // Compare current node ID with leader ID +// } + + +// //helper function - asks the leader node for messages +// func (s *server) requestMessageFromLeader(topicID string) (*pb.Message, error) { +// // Simulated request to leader (replace with actual leader communication) +// log.Printf("[Leader] Requesting message for topic: %s", topicID) +// return nil, status.Error(codes.NotFound, "no messages available from leader") +// } + + +// func (s *server) ExtendVisibilityTimeout(messageID string, subscriberID string) error { +// // Non-leader nodes should not modify state directly +// if !s.IsLeader() { +// return status.Error(codes.PermissionDenied, "only the leader can extend visibility timeout") +// } + +// value, exists := s.messageLocks.Load(messageID) +// if !exists { +// return status.Error(codes.NotFound, "message not locked") +// } + +// info, ok := value.(VisibilityInfo) +// if !ok || info.SubscriberID != subscriberID { +// return status.Error(codes.PermissionDenied, "message locked by another subscriber") +// } + +// // Leader extends visibility timeout +// newExpiresAt := time.Now().Add(visibilityTimeout) +// info.ExpiresAt = newExpiresAt +// s.messageLocks.Store(messageID, info) + +// // Create broadcast message +// broadcastMsg := broadCastInput{ +// Type: "extend", +// Msg: fmt.Sprintf("%s:%d", messageID, visibilityTimeout.Seconds()), +// } +// msgBytes, _ := json.Marshal(broadcastMsg) + +// // Leader broadcasts the new timeout +// s.Op.Broadcast(context.TODO(), msgBytes) + +// log.Printf("[ExtendTimeout] Leader approved timeout extension for message: %s", messageID) +// return nil +// } + +// //helper function - listen for the leader's broadcast and apply the timeout only when received. 
+// func (s *server) HandleTimeoutExtension(msg broadCastInput) { +// // Parse message +// parts := strings.Split(string(msg.Msg), ":") +// if len(parts) != 2 { +// log.Println("[HandleTimeoutExtension] Invalid message format") +// return +// } + +// messageID := parts[0] +// timeoutSeconds, err := strconv.Atoi(parts[1]) +// if err != nil { +// log.Println("[HandleTimeoutExtension] Failed to parse timeout value") +// return +// } + +// // Apply the extended timeout +// value, exists := s.messageLocks.Load(messageID) +// if !exists { +// log.Printf("[HandleTimeoutExtension] Message %s not found in locks", messageID) +// return +// } + +// info, ok := value.(VisibilityInfo) +// if !ok { +// log.Println("[HandleTimeoutExtension] Invalid visibility info") +// return +// } + +// info.ExpiresAt = time.Now().Add(time.Duration(timeoutSeconds) * time.Second) +// s.messageLocks.Store(messageID, info) + +// log.Printf("[HandleTimeoutExtension] Applied timeout extension for message: %s", messageID) +// } \ No newline at end of file diff --git a/msg_distribution/visibility_timeout.go b/msg_distribution/visibility_timeout.go new file mode 100644 index 00000000..e6efb687 --- /dev/null +++ b/msg_distribution/visibility_timeout.go @@ -0,0 +1,250 @@ +package main + +import ( + "context" + "encoding/json" + "log" + "sync" + "time" + + "cloud.google.com/go/spanner" + pb "github.com/alphauslabs/pubsub-proto/v1" + "github.com/flowerinthenight/hedge/v2" + "github.com/google/uuid" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type server struct { + client *spanner.Client + op *hedge.Op + pb.UnimplementedPubSubServiceServer + messageLocks sync.Map // Tracks message locks with expiration times +} + +const ( + MessagesTable = "Messages" + VisibilityTimeout = time.Minute // 1-minute lock +) + +func (s *server) Publish(ctx context.Context, in *pb.PublishRequest) (*pb.PublishResponse, error) { + if in.TopicId == "" { + return nil, status.Error(codes.InvalidArgument, "topic must not be empty") + } + + b, _ := json.Marshal(in) + l, _ := s.op.HasLock() + if l { + log.Println("[Publish-leader] Received message:\n", string(b)) + } else { + log.Printf("[Publish] Received message:\n%v", string(b)) + } + + messageID := uuid.New().String() + mutation := spanner.InsertOrUpdate( + MessagesTable, + []string{"id", "topic", "payload", "createdAt", "updatedAt", "visibilityTimeout", "processed"}, + []interface{}{ + messageID, + in.TopicId, + in.Payload, + spanner.CommitTimestamp, + spanner.CommitTimestamp, + time.Now().Add(VisibilityTimeout), + false, + }, + ) + + _, err := s.client.Apply(ctx, []*spanner.Mutation{mutation}) + if err != nil { + log.Printf("Error writing to Spanner: %v", err) + return nil, err + } + + log.Printf("[Publish] Message successfully wrote to spanner with ID: %s", messageID) + return &pb.PublishResponse{MessageId: messageID}, nil +} + +func (s *server) Subscribe(req *pb.SubscribeRequest, stream pb.PubSubService_SubscribeServer) error { + ctx := stream.Context() + subscriberID := req.SubscriptionId + + for { + stmt := spanner.Statement{ + SQL: `SELECT id, payload, topic FROM Messages WHERE processed = FALSE AND visibilityTimeout <= CURRENT_TIMESTAMP()`, + } + iter := s.client.Single().Query(ctx, stmt) + + for { + row, err := iter.Next() + if err != nil { + return err + } + + var msg pb.Message + if err := row.Columns(&msg.Id, &msg.Payload, &msg.Topic); err != nil { + return err + } + + if _, exists := s.messageLocks.Load(msg.Id); exists { + continue // Skip locked messages + } + + 
s.messageLocks.Store(msg.Id, time.Now().Add(VisibilityTimeout)) + + if err := stream.Send(&msg); err != nil { + s.messageLocks.Delete(msg.Id) + return err + } + } + } +} + +func (s *server) Acknowledge(ctx context.Context, req *pb.AcknowledgeRequest) (*pb.AcknowledgeResponse, error) { + _, err := s.client.Apply(ctx, []*spanner.Mutation{ + spanner.Update(MessagesTable, []string{"id", "processed"}, []interface{}{req.Id, true}), + }) + if err != nil { + return nil, err + } + + s.messageLocks.Delete(req.Id) + return &pb.AcknowledgeResponse{Success: true}, nil +} + + +//update code + +package main + +import ( + "context" + "encoding/json" + "log" + "sync" + "time" + + "cloud.google.com/go/spanner" + pb "github.com/alphauslabs/pubsub-proto/v1" + "github.com/flowerinthenight/hedge/v2" + "github.com/google/uuid" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type server struct { + client *spanner.Client + op *hedge.Op + pb.UnimplementedPubSubServiceServer + messageLocks sync.Map // Tracks message locks with expiration times +} + +const ( + MessagesTable = "Messages" + VisibilityTimeout = time.Minute // 1-minute lock +) + +// Publish a message to Spanner +func (s *server) Publish(ctx context.Context, in *pb.PublishRequest) (*pb.PublishResponse, error) { + if in.TopicId == "" { + return nil, status.Error(codes.InvalidArgument, "topic must not be empty") + } + + b, _ := json.Marshal(in) + l, _ := s.op.HasLock() + if l { + log.Println("[Publish-leader] Received message:\n", string(b)) + } else { + log.Printf("[Publish] Received message:\n%v", string(b)) + } + + messageID := uuid.New().String() + mutation := spanner.InsertOrUpdate( + MessagesTable, + []string{"id", "topic", "payload", "createdAt", "updatedAt", "visibilityTimeout", "processed"}, + []interface{}{ + messageID, + in.TopicId, + in.Payload, + spanner.CommitTimestamp, + spanner.CommitTimestamp, + time.Now(), // Initially visible + false, + }, + ) + + _, err := s.client.Apply(ctx, []*spanner.Mutation{mutation}) + if err != nil { + log.Printf("Error writing to Spanner: %v", err) + return nil, err + } + + log.Printf("[Publish] Message successfully wrote to spanner with ID: %s", messageID) + return &pb.PublishResponse{MessageId: messageID}, nil +} + +// Assigns a message to a subscriber by updating its visibility timeout +func (s *server) AssignMessage(ctx context.Context) (*pb.Message, error) { + stmt := spanner.Statement{ + SQL: `SELECT id, payload, topic FROM Messages + WHERE processed = FALSE AND visibilityTimeout <= CURRENT_TIMESTAMP() + ORDER BY createdAt ASC + LIMIT 1`, + } + + row, err := s.client.Single().Query(ctx, stmt).Next() + if err != nil { + return nil, err + } + + var msg pb.Message + if err := row.Columns(&msg.Id, &msg.Payload, &msg.Topic); err != nil { + return nil, err + } + + // Lock message in Spanner by updating visibility timeout + updateMutation := spanner.Update( + MessagesTable, + []string{"id", "visibilityTimeout"}, + []interface{}{msg.Id, time.Now().Add(VisibilityTimeout)}, + ) + _, err = s.client.Apply(ctx, []*spanner.Mutation{updateMutation}) + if err != nil { + return nil, err + } + + s.messageLocks.Store(msg.Id, time.Now().Add(VisibilityTimeout)) + return &msg, nil +} + +// Subscribe to messages +func (s *server) Subscribe(req *pb.SubscribeRequest, stream pb.PubSubService_SubscribeServer) error { + ctx := stream.Context() + subscriberID := req.SubscriptionId + + for { + msg, err := s.AssignMessage(ctx) + if err != nil { + time.Sleep(time.Second) // Prevent excessive Spanner queries + 
continue + } + + if err := stream.Send(msg); err != nil { + s.messageLocks.Delete(msg.Id) + return err + } + } +} + +// Acknowledge message processing completion +func (s *server) Acknowledge(ctx context.Context, req *pb.AcknowledgeRequest) (*pb.AcknowledgeResponse, error) { + _, err := s.client.Apply(ctx, []*spanner.Mutation{ + spanner.Update(MessagesTable, []string{"id", "processed"}, []interface{}{req.Id, true}), + }) + if err != nil { + return nil, err + } + + s.messageLocks.Delete(req.Id) + return &pb.AcknowledgeResponse{Success: true}, nil +} From c3382b75e72461d11cac00d638a6bebf2824a78c Mon Sep 17 00:00:00 2001 From: Kishea Kate Andoy Date: Wed, 26 Feb 2025 02:29:03 +0800 Subject: [PATCH 13/21] had some adjustments with the server, broadcast and helpers --- app/app.go | 4 + broadcast/broadcast.go | 132 ++++++++++++++++++++- helpers.go | 256 ++++------------------------------------- server.go | 48 ++------ 4 files changed, 166 insertions(+), 274 deletions(-) diff --git a/app/app.go b/app/app.go index 442ed999..77f472f8 100644 --- a/app/app.go +++ b/app/app.go @@ -10,4 +10,8 @@ type PubSub struct { Op *hedge.Op Client *spanner.Client Storage *storage.Storage + NodeID string + MessageLocks sync.Map + MessageQueue sync.Map + Mutex sync.Mutex } diff --git a/broadcast/broadcast.go b/broadcast/broadcast.go index 0f71e354..73af5b0f 100644 --- a/broadcast/broadcast.go +++ b/broadcast/broadcast.go @@ -4,6 +4,9 @@ import ( "encoding/json" "fmt" "log" + "strings" + "strconv" + "time" pb "github.com/alphauslabs/pubsub-proto/v1" "github.com/alphauslabs/pubsub/app" @@ -12,6 +15,13 @@ import ( const ( message = "message" topicsub = "topicsub" + msgEvent = "msgEvent" + + // Message event types + lockMsg = "lock" + unlockMsg = "unlock" + deleteMsg = "delete" + extendMsg = "extend" ) type BroadCastInput struct { @@ -19,14 +29,23 @@ type BroadCastInput struct { Msg []byte } +type MessageLockInfo struct { + Locked bool + Timeout time.Time + NodeID string + SubscriberID string // Added to track which subscriber has the lock + LockHolders map[string]bool // Track which nodes have acknowledged the lock +} + var ctrlbroadcast = map[string]func(*app.PubSub, []byte) ([]byte, error){ message: handleBroadcastedMsg, topicsub: handleBroadcastedTopicsub, + msgEvent: handleMessageEvent, // Handles message locks, unlocks, deletes } // Root handler for op.Broadcast() func Broadcast(data any, msg []byte) ([]byte, error) { - var in BroadCastInput +var in BroadCastInput app := data.(*app.PubSub) if err := json.Unmarshal(msg, &in); err != nil { return nil, err @@ -48,6 +67,7 @@ func handleBroadcastedMsg(app *app.PubSub, msg []byte) ([]byte, error) { return nil, nil } + // Handles topic-subscription updates func handleBroadcastedTopicsub(app *app.PubSub, msg []byte) ([]byte, error) { log.Println("Received topic-subscriptions:\n", string(msg)) if err := app.Storage.StoreTopicSubscriptions(msg); err != nil { @@ -56,3 +76,113 @@ func handleBroadcastedTopicsub(app *app.PubSub, msg []byte) ([]byte, error) { return nil, nil } + + + // Handles lock/unlock/delete/extend operations separately +func handleMessageEvent(app *app.PubSub, msg []byte) ([]byte, error) { + parts := strings.Split(string(msg), ":") + if len(parts) < 2 { + return nil, fmt.Errorf("invalid message event format") + } + + messageType := parts[0] + messageID := parts[1] + + // Map message event handlers + eventHandlers := map[string]func(*app.PubSub, string, []string) ([]byte, error){ + lockMsg: handleLockMsg, + unlockMsg: handleUnlockMsg, + deleteMsg: 
handleDeleteMsg, + extendMsg: handleExtendMsg, + retryMsg: handleRetryMsg, + } + + handler, exists := eventHandlers[messageType] + if !exists { + return nil, fmt.Errorf("unknown message event: %s", messageType) + } + + return handler(app, messageID, parts[2:]) +} + + + // Message event handlers +func handleLockMsg(app *app.PubSub, messageID string, params []string) ([]byte, error) { + if len(params) < 2 { + return nil, fmt.Errorf("invalid lock parameters") + } + + timeoutSeconds, err := strconv.Atoi(params[0]) + if err != nil { + return nil, err + } + subscriberID := params[1] + + app.Mutex.Lock() + defer app.Mutex.Unlock() + + if _, exists := app.MessageLocks.Load(messageID); exists { + return nil, nil // Already locked + } + + lockInfo := MessageLockInfo{ + Locked: true, + Timeout: time.Now().Add(time.Duration(timeoutSeconds) * time.Second), + NodeID: app.NodeID, + SubscriberID: subscriberID, + LockHolders: make(map[string]bool), + } + app.MessageLocks.Store(messageID, lockInfo) + + return nil, nil + } + + + func handleUnlockMsg(app *app.PubSub, messageID string, _ []string) ([]byte, error) { + if !app.IsLeader() { + return nil, nil // Only leader should handle unlocks + } + + app.Mutex.Lock() + defer app.Mutex.Unlock() + + app.MessageLocks.Delete(messageID) + return nil, nil + } + +func handleDeleteMsg(app *app.PubSub, messageID string, _ []string) ([]byte, error) { + app.Mutex.Lock() + defer app.Mutex.Unlock() + + app.MessageLocks.Delete(messageID) + app.MessageQueue.Delete(messageID) + return nil, nil +} + +func handleExtendMsg(app *app.PubSub, messageID string, params []string) ([]byte, error) { + if len(params) < 1 { + return nil, fmt.Errorf("missing timeout parameter for extend message") + } + + timeoutSeconds, err := strconv.Atoi(params[0]) + if err != nil { + return nil, err + } + + app.Mutex.Lock() + defer app.Mutex.Unlock() + + if lockInfo, ok := app.MessageLocks.Load(messageID); ok { + // Update timeout + info := lockInfo.(app.MessageLockInfo) + info.Timeout = time.Now().Add(time.Duration(timeoutSeconds) * time.Second) + app.MessageLocks.Store(messageID, info) + } + + return nil, nil +} + +func handleRetryMsg(app *app.PubSub, messageID string, _ []string) ([]byte, error) { + log.Printf("[Retry] Message %s is now available again", messageID) + return nil, nil + } \ No newline at end of file diff --git a/helpers.go b/helpers.go index e0a0a84d..9a01708f 100644 --- a/helpers.go +++ b/helpers.go @@ -12,21 +12,18 @@ import ( pb "github.com/alphauslabs/pubsub-proto/v1" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" -<<<<<<< HEAD -======= "github.com/alphauslabs/pubsub/app" "github.com/alphauslabs/pubsub/broadcast" ->>>>>>> origin/kate_branch ) -// MessageLockInfo tracks lock state across nodes -type MessageLockInfo struct { - Timeout time.Time - Locked bool - NodeID string - SubscriberID string // Added to track which subscriber has the lock - LockHolders map[string]bool // Track which nodes have acknowledged the lock -} +// // MessageLockInfo tracks lock state across nodes +// type MessageLockInfo struct { +// Timeout time.Time +// Locked bool +// NodeID string +// SubscriberID string // Added to track which subscriber has the lock +// LockHolders map[string]bool // Track which nodes have acknowledged the lock +// } // validateTopicSubscription checks if subscription exists in memory func (s *server) validateTopicSubscription(subscriptionID string) (*pb.Subscription, error) { @@ -34,8 +31,8 @@ func (s *server) validateTopicSubscription(subscriptionID string) 
(*pb.Subscript return val.(*pb.Subscription), nil } // Request subscription details from the leader - using topicsub message type - broadcastData := broadCastInput{ - Type: topicsub, + broadcastData := broadcast.BroadCastInput{ + Type: broadcast.topicsub, Msg: []byte(fmt.Sprintf("get:%s", subscriptionID)), } bin, _ := json.Marshal(broadcastData) @@ -60,7 +57,7 @@ func (s *server) validateTopicSubscription(subscriptionID string) (*pb.Subscript // broadcastLock sends lock request to all nodes and waits for acknowledgment func (s *server) broadcastLock(ctx context.Context, messageID string, subscriberID string, timeout time.Duration) error { - lockInfo := MessageLockInfo{ + lockInfo := broadcast.MessageLockInfo{ Timeout: time.Now().Add(timeout), Locked: true, NodeID: s.Op.ID(), @@ -73,8 +70,8 @@ func (s *server) broadcastLock(ctx context.Context, messageID string, subscriber return fmt.Errorf("message already locked by another node") } // Broadcast lock request - format matching handleBroadcastedMsg - broadcastData := broadCastInput{ - Type: message, + broadcastData := broadcast.BroadCastInput{ + Type: broadcast.msgEvent, Msg: []byte(fmt.Sprintf("lock:%s:%d:%s", messageID, int(timeout.Seconds()), subscriberID)), } @@ -106,7 +103,7 @@ func (s *server) broadcastLock(ctx context.Context, messageID string, subscriber // other nodes can unlock the message and allow it to be processed again. func (s *server) handleMessageTimeout(messageID string) { if lockInfo, ok := s.messageLocks.Load(messageID); ok { - info := lockInfo.(MessageLockInfo) + info := lockInfo.(broadcast.MessageLockInfo) if info.Locked && time.Now().After(info.Timeout) { log.Printf("[Timeout] Unlocking expired message: %s", messageID) // Broadcast unlock @@ -114,8 +111,8 @@ func (s *server) handleMessageTimeout(messageID string) { // Remove lock entry s.messageLocks.Delete(messageID) // Notify all nodes to retry processing this message - broadcastData := broadCastInput{ - Type: message, + broadcastData := broadcast.BroadCastInput{ + Type: broadcast.msgEvent, Msg: []byte(fmt.Sprintf("retry:%s", messageID)), } bin, _ := json.Marshal(broadcastData) @@ -132,8 +129,8 @@ func (s *server) broadcastUnlock(ctx context.Context, messageID string) { return } // Format matching handleBroadcastedMsg - broadcastData := broadCastInput{ - Type: message, + broadcastData := broadcast.BroadCastInput{ + Type: broadcast.msgEvent, Msg: []byte(fmt.Sprintf("unlock:%s", messageID)), } bin, _ := json.Marshal(broadcastData) @@ -155,8 +152,8 @@ func (s *server) IsLeader() bool { // requestMessageFromLeader asks the leader node for messages func (s *server) requestMessageFromLeader(topicID string) (*pb.Message, error) { // Use the message type with proper format for requesting a message - broadcastData := broadCastInput{ - Type: message, + broadcastData := broadcast.BroadCastInput{ + Type: broadcast.msgEvent, Msg: []byte(fmt.Sprintf("getmessage:%s", topicID)), } @@ -188,7 +185,7 @@ func (s *server) ExtendVisibilityTimeout(messageID string, subscriberID string, if !exists { return status.Error(codes.NotFound, "message not locked") } - info, ok := value.(MessageLockInfo) + info, ok := value.(broadcast.MessageLockInfo) if !ok || info.SubscriberID != subscriberID { return status.Error(codes.PermissionDenied, "message locked by another subscriber") } @@ -197,8 +194,8 @@ func (s *server) ExtendVisibilityTimeout(messageID string, subscriberID string, info.Timeout = newExpiresAt s.messageLocks.Store(messageID, info) // Create broadcast message - format matching 
handleBroadcastedMsg - broadcastData := broadCastInput{ - Type: message, + broadcastData := broadcast.BroadCastInput{ + Type: broadcast.msgEvent, Msg: []byte(fmt.Sprintf("extend:%s:%d", messageID, int(visibilityTimeout.Seconds()))), } msgBytes, _ := json.Marshal(broadcastData) @@ -240,214 +237,7 @@ func (s *server) HandleBroadcastMessage(msgType string, msgData []byte) error { messageID := string(msgData) s.messageLocks.Delete(messageID) - // Add other message types as needed } return nil } -<<<<<<< HEAD - -// //HELPERFUNCTIONS// -// // validateTopicSubscription checks if subscription exists in memory -// func (s *server) validateTopicSubscription(subscriptionID string) (*pb.Subscription, error) { -// if val, ok := s.subscriptions.Load(subscriptionID); ok { -// return val.(*pb.Subscription), nil -// } - -// // Request subscription details from the leader -// leaderSubscription, err := s.requestSubscriptionFromLeader(subscriptionID) -// if err != nil { -// return nil, status.Error(codes.NotFound, "subscription not found in memory or leader") -// } - -// // Store it in memory to prevent duplicate lookups -// s.subscriptions.Store(subscriptionID, leaderSubscription) - -// return leaderSubscription, nil // Do not store in-memory cache here -// } - -// // broadcastLock sends lock request to all nodes and waits for acknowledgment -// func (s *server) broadcastLock(ctx context.Context, messageID string, timeout time.Duration) error { -// lockInfo := MessageLockInfo{ -// Timeout: time.Now().Add(timeout), -// Locked: true, -// NodeID: s.Op.ID(), -// LockHolders: make(map[string]bool), -// } - -// // Store initial lock info before broadcasting -// _, loaded := s.messageLocks.LoadOrStore(messageID, lockInfo) -// if loaded { -// return fmt.Errorf("message already locked by another node") -// } - -// // Broadcast lock request -// broadcastData := broadCastInput{ -// Type: message, -// Msg: []byte(fmt.Sprintf("lock:%s:%d", messageID, timeout.Seconds())), -// } - -// bin, _ := json.Marshal(broadcastData) -// out := s.Op.Broadcast(ctx, bin) - -// // Ensure majority of nodes acknowledged -// successCount := 0 -// for _, v := range out { -// if v.Error == nil { -// successCount++ -// } -// } - -// if successCount < (len(out)/2 + 1) { -// s.messageLocks.Delete(messageID) -// return fmt.Errorf("failed to acquire lock across majority of nodes") -// } - -// // Start local timeout timer -// timer := time.NewTimer(timeout) -// s.timeoutTimers.Store(messageID, timer) - -// go func() { -// <-timer.C -// s.handleMessageTimeout(messageID) -// }() - -// return nil -// } - -// //helper function - function ensures that if a node crashes while holding a lock, other nodes can unlock the message and allow it to be processed again. 
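The per-message timers used in this PR are the only mechanism that notices an expired lock. An alternative safety net (not implemented here) is a single periodic sweep over messageLocks; a minimal sketch, assuming the map stores broadcast.MessageLockInfo values keyed by message ID (sweepExpiredLocks, the unlock callback, and the 30-second interval in the usage comment are illustrative assumptions):

package main

import (
	"sync"
	"time"

	"github.com/alphauslabs/pubsub/broadcast"
)

// sweepExpiredLocks scans the lock map once and invokes unlock for every entry
// whose visibility timeout has already passed.
func sweepExpiredLocks(locks *sync.Map, unlock func(messageID string)) {
	now := time.Now()
	locks.Range(func(key, value any) bool {
		if info, ok := value.(broadcast.MessageLockInfo); ok && info.Locked && now.After(info.Timeout) {
			unlock(key.(string)) // e.g. s.broadcastUnlock(ctx, id) when this node owns the lock
		}
		return true // keep iterating over the remaining entries
	})
}

// Possible wiring: go func() { for range time.Tick(30 * time.Second) { sweepExpiredLocks(&app.MessageLocks, unlockFn) } }()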
-// func (s *server) handleMessageTimeout(messageID string) { -// if lockInfo, ok := s.messageLocks.Load(messageID); ok { -// info := lockInfo.(MessageLockInfo) -// if info.Locked && time.Now().After(info.Timeout) { -// log.Printf("[Timeout] Unlocking expired message: %s", messageID) - -// // Broadcast unlock -// s.broadcastUnlock(context.Background(), messageID) - -// // Remove lock entry -// s.messageLocks.Delete(messageID) - -// // Notify all nodes to retry processing this message -// broadcastData := broadCastInput{ -// Type: message, -// Msg: []byte(fmt.Sprintf("retry:%s", messageID)), -// } - -// bin, _ := json.Marshal(broadcastData) -// s.Op.Broadcast(context.Background(), bin) -// } -// } -// } - -// //helper function - function ensures that only the leader node is responsible for broadcasting unlock requests: -// func (s *server) broadcastUnlock(ctx context.Context, messageID string) { -// // Ensure only the leader sends the unlock request -// if !s.Op.IsLeader() { -// log.Printf("[Unlock] Skipping unlock broadcast. Only the leader handles unlocks.") -// return -// } - -// broadcastData := broadCastInput{ -// Type: message, -// Msg: []byte(fmt.Sprintf("unlock:%s", messageID)), -// } - -// bin, _ := json.Marshal(broadcastData) -// s.Op.Broadcast(ctx, bin) - -// // Clean up local state -// s.messageLocks.Delete(messageID) -// if timer, ok := s.timeoutTimers.Load(messageID); ok { -// timer.(*time.Timer).Stop() -// s.timeoutTimers.Delete(messageID) -// } - -// log.Printf("[Unlock] Leader node unlocked message: %s", messageID) -// } - -// //helper function - checks if the current node is the leader node in your pub/sub system. -// func (s *server) IsLeader() bool { -// return s.Op.ID() == s.Op.GetLeaderID() // Compare current node ID with leader ID -// } - - -// //helper function - asks the leader node for messages -// func (s *server) requestMessageFromLeader(topicID string) (*pb.Message, error) { -// // Simulated request to leader (replace with actual leader communication) -// log.Printf("[Leader] Requesting message for topic: %s", topicID) -// return nil, status.Error(codes.NotFound, "no messages available from leader") -// } - - -// func (s *server) ExtendVisibilityTimeout(messageID string, subscriberID string) error { -// // Non-leader nodes should not modify state directly -// if !s.IsLeader() { -// return status.Error(codes.PermissionDenied, "only the leader can extend visibility timeout") -// } - -// value, exists := s.messageLocks.Load(messageID) -// if !exists { -// return status.Error(codes.NotFound, "message not locked") -// } - -// info, ok := value.(VisibilityInfo) -// if !ok || info.SubscriberID != subscriberID { -// return status.Error(codes.PermissionDenied, "message locked by another subscriber") -// } - -// // Leader extends visibility timeout -// newExpiresAt := time.Now().Add(visibilityTimeout) -// info.ExpiresAt = newExpiresAt -// s.messageLocks.Store(messageID, info) - -// // Create broadcast message -// broadcastMsg := broadCastInput{ -// Type: "extend", -// Msg: fmt.Sprintf("%s:%d", messageID, visibilityTimeout.Seconds()), -// } -// msgBytes, _ := json.Marshal(broadcastMsg) - -// // Leader broadcasts the new timeout -// s.Op.Broadcast(context.TODO(), msgBytes) - -// log.Printf("[ExtendTimeout] Leader approved timeout extension for message: %s", messageID) -// return nil -// } - -// //helper function - listen for the leader's broadcast and apply the timeout only when received. 
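For context on how the extend path is exercised end to end: a subscriber that needs more processing time would call ModifyVisibilityTimeout over gRPC while it still holds the message. A minimal client-side sketch, assuming the generated client follows the usual protoc-gen-go-grpc naming (NewPubSubServiceClient), an insecure local connection, and a 30-second extension; none of this client code is part of the PR:

package main

import (
	"context"
	"fmt"

	pb "github.com/alphauslabs/pubsub-proto/v1"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

// extendLease asks the node holding the lock to push the visibility timeout
// for messageID another 30 seconds into the future.
func extendLease(ctx context.Context, addr, messageID string) error {
	conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		return err
	}
	defer conn.Close()

	resp, err := pb.NewPubSubServiceClient(conn).ModifyVisibilityTimeout(ctx, &pb.ModifyVisibilityTimeoutRequest{
		Id:         messageID,
		NewTimeout: 30, // seconds, interpreted by the server handler shown in server.go
	})
	if err != nil {
		return err
	}
	if !resp.Success {
		return fmt.Errorf("extension rejected for message %s", messageID)
	}
	return nil
}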
-// func (s *server) HandleTimeoutExtension(msg broadCastInput) { -// // Parse message -// parts := strings.Split(string(msg.Msg), ":") -// if len(parts) != 2 { -// log.Println("[HandleTimeoutExtension] Invalid message format") -// return -// } - -// messageID := parts[0] -// timeoutSeconds, err := strconv.Atoi(parts[1]) -// if err != nil { -// log.Println("[HandleTimeoutExtension] Failed to parse timeout value") -// return -// } - -// // Apply the extended timeout -// value, exists := s.messageLocks.Load(messageID) -// if !exists { -// log.Printf("[HandleTimeoutExtension] Message %s not found in locks", messageID) -// return -// } - -// info, ok := value.(VisibilityInfo) -// if !ok { -// log.Println("[HandleTimeoutExtension] Invalid visibility info") -// return -// } - -// info.ExpiresAt = time.Now().Add(time.Duration(timeoutSeconds) * time.Second) -// s.messageLocks.Store(messageID, info) - -// log.Printf("[HandleTimeoutExtension] Applied timeout extension for message: %s", messageID) -// } -======= ->>>>>>> origin/kate_branch diff --git a/server.go b/server.go index dcc32773..624faa94 100644 --- a/server.go +++ b/server.go @@ -5,27 +5,19 @@ import ( "encoding/json" "log" "time" + "fmt" "cloud.google.com/go/spanner" pb "github.com/alphauslabs/pubsub-proto/v1" -<<<<<<< HEAD "github.com/alphauslabs/pubsub/app" "github.com/alphauslabs/pubsub/broadcast" -======= ->>>>>>> origin/kate_branch "github.com/google/uuid" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "github.com/alphauslabs/pubsub/app" - "github.com/alphauslabs/pubsub/broadcast" ) type server struct { -<<<<<<< HEAD *app.PubSub -======= - *PubSub ->>>>>>> origin/kate_branch pb.UnimplementedPubSubServiceServer } @@ -43,10 +35,6 @@ func (s *server) Publish(ctx context.Context, in *pb.PublishRequest) (*pb.Publis return nil, status.Error(codes.InvalidArgument, "topic must not be empty") } b, _ := json.Marshal(in) -<<<<<<< HEAD - -======= ->>>>>>> origin/kate_branch l, _ := s.Op.HasLock() if l { log.Println("[Publish-leader] Received message:\n", string(b)) @@ -56,7 +44,6 @@ func (s *server) Publish(ctx context.Context, in *pb.PublishRequest) (*pb.Publis messageID := uuid.New().String() mutation := spanner.InsertOrUpdate( -<<<<<<< HEAD MessagesTable, []string{"id", "topic", "payload", "createdAt", "updatedAt", "visibilityTimeout", "processed"}, []interface{}{ @@ -70,35 +57,16 @@ func (s *server) Publish(ctx context.Context, in *pb.PublishRequest) (*pb.Publis }, ) -======= - MessagesTable, - []string{"id", "topic", "payload", "createdAt", "updatedAt", "visibilityTimeout", "processed"}, - []interface{}{ - messageID, - in.Topic, - in.Payload, - spanner.CommitTimestamp, - spanner.CommitTimestamp, - nil, // Initial visibilityTimeout is NULL - false, // Not processed yet - }, - ) ->>>>>>> origin/kate_branch + _, err := s.Client.Apply(ctx, []*spanner.Mutation{mutation}) if err != nil { log.Printf("Error writing to Spanner: %v", err) return nil, err } -<<<<<<< HEAD // broadcast message bcastin := broadcast.BroadCastInput{ - Type: "message", -======= - // broadcast message - using the correct message Type constant - bcastin := broadCastInput{ - Type: message, ->>>>>>> origin/kate_branch + Type: broadcast.message, Msg: b, } bin, _ := json.Marshal(bcastin) @@ -171,8 +139,8 @@ func (s *server) Acknowledge(ctx context.Context, in *pb.AcknowledgeRequest) (*p return nil, err } // Broadcast delete to all nodes - format matching handleBroadcastedMsg - broadcastData := broadCastInput{ - Type: message, + broadcastData := 
broadcast.BroadCastInput{ + Type: broadcast.msgEvent, Msg: []byte(fmt.Sprintf("delete:%s", in.Id)), } @@ -193,7 +161,7 @@ func (s *server) ModifyVisibilityTimeout(ctx context.Context, in *pb.ModifyVisib if !ok { return nil, status.Error(codes.NotFound, "message lock not found") } - info := lockInfo.(MessageLockInfo) + info := lockInfo.(broadcast.MessageLockInfo) if !info.Locked { return nil, status.Error(codes.FailedPrecondition, "message not locked") } @@ -202,8 +170,8 @@ func (s *server) ModifyVisibilityTimeout(ctx context.Context, in *pb.ModifyVisib return nil, status.Error(codes.PermissionDenied, "only the original lock holder can extend timeout") } // Broadcast new timeout - format matching handleBroadcastedMsg - broadcastData := broadCastInput{ - Type: message, + broadcastData := BroadCastInput{ + Type: broadcast.msgEvent, Msg: []byte(fmt.Sprintf("extend:%s:%d", in.Id, in.NewTimeout)), } bin, _ := json.Marshal(broadcastData) From 2020f23ee0bcad7d71e252753da6b56b9763236a Mon Sep 17 00:00:00 2001 From: Kishea Kate Andoy Date: Wed, 26 Feb 2025 12:10:29 +0800 Subject: [PATCH 14/21] =?UTF-8?q?change=20logic,Unlocking=20and=20extendin?= =?UTF-8?q?g=20timeouts=20=E2=86=92from=20=20Handled=20only=20by=20the=20l?= =?UTF-8?q?eader=20to=20all=20nodes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app.go | 23 ----- app/app.go | 2 + broadcast/broadcast.go | 164 ++++++++++++++++++++---------- go.mod | 4 +- go.sum | 32 +++--- helpers.go | 223 ++++++++++++++++++++++------------------- server.go | 75 +++++++------- 7 files changed, 293 insertions(+), 230 deletions(-) delete mode 100644 app.go diff --git a/app.go b/app.go deleted file mode 100644 index d400ecf8..00000000 --- a/app.go +++ /dev/null @@ -1,23 +0,0 @@ -package main - -import ( - "cloud.google.com/go/spanner" - "github.com/flowerinthenight/hedge/v2" - "sync" -) - -type PubSub struct { - Op *hedge.Op - Client *spanner.Client - // Message handling - messageLocks sync.Map // messageID -> MessageLockInfo - messageQueue sync.Map // topic -> []*pb.Message - subscriptions sync.Map // subscriptionID -> *pb.Subscription - - // Timer tracking - timeoutTimers sync.Map // messageID -> *time.Timer -<<<<<<< HEAD - storage *storage.Storage // jansen's storage -======= ->>>>>>> origin/kate_branch -} diff --git a/app/app.go b/app/app.go index 77f472f8..5df8f645 100644 --- a/app/app.go +++ b/app/app.go @@ -1,9 +1,11 @@ +//app.go package app import ( "cloud.google.com/go/spanner" storage "github.com/alphauslabs/pubsub/storage" "github.com/flowerinthenight/hedge/v2" + "sync" ) type PubSub struct { diff --git a/broadcast/broadcast.go b/broadcast/broadcast.go index 73af5b0f..8dc8fd5e 100644 --- a/broadcast/broadcast.go +++ b/broadcast/broadcast.go @@ -1,3 +1,4 @@ +//broadcast.go package broadcast import ( @@ -22,6 +23,7 @@ const ( unlockMsg = "unlock" deleteMsg = "delete" extendMsg = "extend" + retryMsg = "retry" ) type BroadCastInput struct { @@ -29,6 +31,8 @@ type BroadCastInput struct { Msg []byte } +// MessageLockInfo defines lock information structure +// Note: This should be consistent with the structure in helpers.go type MessageLockInfo struct { Locked bool Timeout time.Time @@ -108,47 +112,82 @@ func handleMessageEvent(app *app.PubSub, msg []byte) ([]byte, error) { // Message event handlers func handleLockMsg(app *app.PubSub, messageID string, params []string) ([]byte, error) { - if len(params) < 2 { - return nil, fmt.Errorf("invalid lock parameters") + if len(params) < 3 { + return nil, 
fmt.Errorf("invalid lock parameters") } - + timeoutSeconds, err := strconv.Atoi(params[0]) if err != nil { - return nil, err + return nil, err } subscriberID := params[1] - + requestingNodeID := params[2] + app.Mutex.Lock() defer app.Mutex.Unlock() - - if _, exists := app.MessageLocks.Load(messageID); exists { - return nil, nil // Already locked + + // Check if already locked + if existingLock, exists := app.MessageLocks.Load(messageID); exists { + info := existingLock.(MessageLockInfo) + + // If lock is expired, allow new lock + if time.Now().After(info.Timeout) { + // Continue with new lock + } else if info.NodeID == requestingNodeID { + // Same node is refreshing its lock, allow it + info.LockHolders[app.NodeID] = true + app.MessageLocks.Store(messageID, info) + return nil, nil + } else { + // Different node has a valid lock, reject + return nil, fmt.Errorf("message already locked by another node") + } } - + + // Create new lock lockInfo := MessageLockInfo{ - Locked: true, - Timeout: time.Now().Add(time.Duration(timeoutSeconds) * time.Second), - NodeID: app.NodeID, - SubscriberID: subscriberID, - LockHolders: make(map[string]bool), + Locked: true, + Timeout: time.Now().Add(time.Duration(timeoutSeconds) * time.Second), + NodeID: requestingNodeID, + SubscriberID: subscriberID, + LockHolders: make(map[string]bool), } + + // Mark this node as acknowledging the lock + lockInfo.LockHolders[app.NodeID] = true + app.MessageLocks.Store(messageID, lockInfo) - + return nil, nil - } +} - func handleUnlockMsg(app *app.PubSub, messageID string, _ []string) ([]byte, error) { - if !app.IsLeader() { - return nil, nil // Only leader should handle unlocks - } - - app.Mutex.Lock() - defer app.Mutex.Unlock() - - app.MessageLocks.Delete(messageID) - return nil, nil - } +func handleUnlockMsg(app *app.PubSub, messageID string, params []string) ([]byte, error) { + if len(params) < 1 { + return nil, fmt.Errorf("invalid unlock parameters") + } + + unlockingNodeID := params[0] + + app.Mutex.Lock() + defer app.Mutex.Unlock() + + // Check if the message is locked + if lockInfo, exists := app.MessageLocks.Load(messageID); exists { + info := lockInfo.(MessageLockInfo) + + // Only the lock owner can unlock + if info.NodeID == unlockingNodeID { + app.MessageLocks.Delete(messageID) + log.Printf("[Unlock] Node %s acknowledged unlock for message: %s", app.NodeID, messageID) + } else { + log.Printf("[Unlock] Rejected unlock from non-owner node %s for message: %s", unlockingNodeID, messageID) + return nil, fmt.Errorf("only lock owner can unlock") + } + } + + return nil, nil +} func handleDeleteMsg(app *app.PubSub, messageID string, _ []string) ([]byte, error) { app.Mutex.Lock() @@ -160,29 +199,50 @@ func handleDeleteMsg(app *app.PubSub, messageID string, _ []string) ([]byte, err } func handleExtendMsg(app *app.PubSub, messageID string, params []string) ([]byte, error) { - if len(params) < 1 { - return nil, fmt.Errorf("missing timeout parameter for extend message") - } - - timeoutSeconds, err := strconv.Atoi(params[0]) - if err != nil { - return nil, err - } - - app.Mutex.Lock() - defer app.Mutex.Unlock() - - if lockInfo, ok := app.MessageLocks.Load(messageID); ok { - // Update timeout - info := lockInfo.(app.MessageLockInfo) - info.Timeout = time.Now().Add(time.Duration(timeoutSeconds) * time.Second) - app.MessageLocks.Store(messageID, info) - } - - return nil, nil -} + if len(params) < 2 { + return nil, fmt.Errorf("missing parameters for extend message") + } + + timeoutSeconds, err := strconv.Atoi(params[0]) + if err != nil 
{ + return nil, err + } + + extendingNodeID := params[1] + + app.Mutex.Lock() + defer app.Mutex.Unlock() + + if lockInfo, ok := app.MessageLocks.Load(messageID); ok { + info := lockInfo.(MessageLockInfo) + + // Only update if the request comes from the lock owner + if info.NodeID == extendingNodeID { + info.Timeout = time.Now().Add(time.Duration(timeoutSeconds) * time.Second) + app.MessageLocks.Store(messageID, info) + log.Printf("[Extend] Message %s timeout extended by node %s", messageID, extendingNodeID) + } else { + log.Printf("[Extend] Rejected extend from non-owner node %s for message: %s", extendingNodeID, messageID) + return nil, fmt.Errorf("only lock owner can extend timeout") + } + } + + return nil, nil +} -func handleRetryMsg(app *app.PubSub, messageID string, _ []string) ([]byte, error) { - log.Printf("[Retry] Message %s is now available again", messageID) - return nil, nil - } \ No newline at end of file +func handleRetryMsg(app *app.PubSub, messageID string, params []string) ([]byte, error) { + if len(params) < 1 { + return nil, fmt.Errorf("invalid retry parameters") + } + + retryNodeID := params[0] + + // Make the message available again for processing + app.Mutex.Lock() + defer app.Mutex.Unlock() + + app.MessageLocks.Delete(messageID) + log.Printf("[Retry] Message %s is now available again (unlocked by node %s)", messageID, retryNodeID) + + return nil, nil +} \ No newline at end of file diff --git a/go.mod b/go.mod index bcc8abe5..146d2ccb 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ toolchain go1.24.0 require ( cloud.google.com/go/spanner v1.75.0 - github.com/alphauslabs/pubsub-proto v0.0.0-20250221062210-631fa96417c7 + github.com/alphauslabs/pubsub-proto v0.0.0-20250224043151-d2fff9627a86 github.com/flowerinthenight/hedge/v2 v2.0.1 github.com/google/uuid v1.6.0 google.golang.org/api v0.219.0 @@ -85,4 +85,4 @@ require ( google.golang.org/genproto/googleapis/api v0.0.0-20250127172529-29210b9bc287 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250219182151-9fdb1cabc7b2 // indirect google.golang.org/protobuf v1.36.5 // indirect -) \ No newline at end of file +) diff --git a/go.sum b/go.sum index f0c8b423..f223bb8a 100644 --- a/go.sum +++ b/go.sum @@ -38,8 +38,8 @@ cloud.google.com/go v0.104.0/go.mod h1:OO6xxXdJyvuJPcEPBLN9BJPD+jep5G1+2U5B5gkRY cloud.google.com/go v0.105.0/go.mod h1:PrLgOJNe5nfE9UMxKxgXj4mD3voiP+YQ6gdt6KMFOKM= cloud.google.com/go v0.107.0/go.mod h1:wpc2eNrD7hXUTy8EKS10jkxpZBjASrORK7goS+3YX2I= cloud.google.com/go v0.110.0/go.mod h1:SJnCLqQ0FCFGSZMUNUf84MV3Aia54kn7pi8st7tMzaY= -cloud.google.com/go v0.118.2 h1:bKXO7RXMFDkniAAvvuMrAPtQ/VHrs9e7J5UT3yrGdTY= -cloud.google.com/go v0.118.2/go.mod h1:CFO4UPEPi8oV21xoezZCrd3d81K4fFkDTEJu4R8K+9M= +cloud.google.com/go v0.118.1 h1:b8RATMcrK9A4BH0rj8yQupPXp+aP+cJ0l6H7V9osV1E= +cloud.google.com/go v0.118.1/go.mod h1:CFO4UPEPi8oV21xoezZCrd3d81K4fFkDTEJu4R8K+9M= cloud.google.com/go/accessapproval v1.4.0/go.mod h1:zybIuC3KpDOvotz59lFe5qxRZx6C75OtwbisN56xYB4= cloud.google.com/go/accessapproval v1.5.0/go.mod h1:HFy3tuiGvMdcd/u+Cu5b9NkO1pEICJ46IR82PoUdplw= cloud.google.com/go/accessapproval v1.6.0/go.mod h1:R0EiYnwV5fsRFiKZkPHr6mwyk2wxUJ30nL4j2pcFY2E= @@ -319,8 +319,8 @@ cloud.google.com/go/iam v0.8.0/go.mod h1:lga0/y3iH6CX7sYqypWJ33hf7kkfXJag67naqGE cloud.google.com/go/iam v0.11.0/go.mod h1:9PiLDanza5D+oWFZiH1uG+RnRCfEGKoyl6yo4cgWZGY= cloud.google.com/go/iam v0.12.0/go.mod h1:knyHGviacl11zrtZUoDuYpDgLjvr28sLQaG0YB2GYAY= cloud.google.com/go/iam v0.13.0/go.mod h1:ljOg+rcNfzZ5d6f1nAUJ8ZIxOaZUVoS14bKCtaLZ/D0= 
-cloud.google.com/go/iam v1.4.0 h1:ZNfy/TYfn2uh/ukvhp783WhnbVluqf/tzOaqVUPlIPA= -cloud.google.com/go/iam v1.4.0/go.mod h1:gMBgqPaERlriaOV0CUl//XUzDhSfXevn4OEUbg6VRs4= +cloud.google.com/go/iam v1.3.1 h1:KFf8SaT71yYq+sQtRISn90Gyhyf4X8RGgeAVC8XGf3E= +cloud.google.com/go/iam v1.3.1/go.mod h1:3wMtuyT4NcbnYNPLMBzYRFiEfjKfJlLVLrisE7bwm34= cloud.google.com/go/iap v1.4.0/go.mod h1:RGFwRJdihTINIe4wZ2iCP0zF/qu18ZwyKxrhMhygBEc= cloud.google.com/go/iap v1.5.0/go.mod h1:UH/CGgKd4KyohZL5Pt0jSKE4m3FR51qg6FKQ/z/Ix9A= cloud.google.com/go/iap v1.6.0/go.mod h1:NSuvI9C/j7UdjGjIde7t7HBz+QTwBcapPE07+sSRcLk= @@ -378,8 +378,8 @@ cloud.google.com/go/monitoring v1.7.0/go.mod h1:HpYse6kkGo//7p6sT0wsIC6IBDET0RhI cloud.google.com/go/monitoring v1.8.0/go.mod h1:E7PtoMJ1kQXWxPjB6mv2fhC5/15jInuulFdYYtlcvT4= cloud.google.com/go/monitoring v1.12.0/go.mod h1:yx8Jj2fZNEkL/GYZyTLS4ZtZEZN8WtDEiEqG4kLK50w= cloud.google.com/go/monitoring v1.13.0/go.mod h1:k2yMBAB1H9JT/QETjNkgdCGD9bPF712XiLTVr+cBrpw= -cloud.google.com/go/monitoring v1.24.0 h1:csSKiCJ+WVRgNkRzzz3BPoGjFhjPY23ZTcaenToJxMM= -cloud.google.com/go/monitoring v1.24.0/go.mod h1:Bd1PRK5bmQBQNnuGwHBfUamAV1ys9049oEPHnn4pcsc= +cloud.google.com/go/monitoring v1.23.0 h1:M3nXww2gn9oZ/qWN2bZ35CjolnVHM3qnSbu6srCPgjk= +cloud.google.com/go/monitoring v1.23.0/go.mod h1:034NnlQPDzrQ64G2Gavhl0LUHZs9H3rRmhtnp7jiJgg= cloud.google.com/go/networkconnectivity v1.4.0/go.mod h1:nOl7YL8odKyAOtzNX73/M5/mGZgqqMeryi6UPZTk/rA= cloud.google.com/go/networkconnectivity v1.5.0/go.mod h1:3GzqJx7uhtlM3kln0+x5wyFvuVH1pIBJjhCpjzSt75o= cloud.google.com/go/networkconnectivity v1.6.0/go.mod h1:OJOoEXW+0LAxHh89nXd64uGG+FbQoeH8DtxCHVOMlaM= @@ -526,8 +526,8 @@ cloud.google.com/go/shell v1.6.0/go.mod h1:oHO8QACS90luWgxP3N9iZVuEiSF84zNyLytb+ cloud.google.com/go/spanner v1.41.0/go.mod h1:MLYDBJR/dY4Wt7ZaMIQ7rXOTLjYrmxLE/5ve9vFfWos= cloud.google.com/go/spanner v1.44.0/go.mod h1:G8XIgYdOK+Fbcpbs7p2fiprDw4CaZX63whnSMLVBxjk= cloud.google.com/go/spanner v1.45.0/go.mod h1:FIws5LowYz8YAE1J8fOS7DJup8ff7xJeetWEo5REA2M= -cloud.google.com/go/spanner v1.76.1 h1:vYbVZuXfnFwvNcvH3lhI2PeUA+kHyqKmLC7mJWaC4Ok= -cloud.google.com/go/spanner v1.76.1/go.mod h1:YtwoE+zObKY7+ZeDCBtZ2ukM+1/iPaMfUM+KnTh/sx0= +cloud.google.com/go/spanner v1.75.0 h1:2zrltTJv/4P3pCgpYgde4Eb1vN8Cgy1fNy7pbTnOovg= +cloud.google.com/go/spanner v1.75.0/go.mod h1:TLFZBvPQmx3We7sGh12eTk9lLsRLczzZaiweqfMpR80= cloud.google.com/go/speech v1.6.0/go.mod h1:79tcr4FHCimOp56lwC01xnt/WPJZc4v3gzyT7FoBkCM= cloud.google.com/go/speech v1.7.0/go.mod h1:KptqL+BAQIhMsj1kOP2la5DSEEerPDuOP/2mmkhHhZQ= cloud.google.com/go/speech v1.8.0/go.mod h1:9bYIl1/tjsAnMgKGHKmBZzXKEkGgtU+MpdDPTE9f7y0= @@ -1210,8 +1210,8 @@ golang.org/x/oauth2 v0.4.0/go.mod h1:RznEsdpjGAINPTOF0UH/t+xJ75L18YO3Ho6Pyn+uRec golang.org/x/oauth2 v0.5.0/go.mod h1:9/XBHVqLaWO3/BRHs5jbpYCnOZVjj5V0ndyaAM7KB4I= golang.org/x/oauth2 v0.6.0/go.mod h1:ycmewcwgD4Rpr3eZJLSB4Kyyljb3qDh40vJ8STE5HKw= golang.org/x/oauth2 v0.7.0/go.mod h1:hPLQkd9LyjfXTiRohC/41GhcFqxisoUQ99sCUOHO9x4= -golang.org/x/oauth2 v0.26.0 h1:afQXWNNaeC4nvZ0Ed9XvCCzXM6UHJG7iCg0W4fPqSBE= -golang.org/x/oauth2 v0.26.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/oauth2 v0.25.0 h1:CY4y7XT9v0cRI9oupztF8AgiIu99L/ksR/Xp/6jrZ70= +golang.org/x/oauth2 v0.25.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync 
v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -1348,8 +1348,8 @@ golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20220922220347-f3bd1da661af/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.1.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4= -golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= +golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -1490,8 +1490,8 @@ google.golang.org/api v0.108.0/go.mod h1:2Ts0XTHNVWxypznxWOYUeI4g3WdP9Pk2Qk58+a/ google.golang.org/api v0.110.0/go.mod h1:7FC4Vvx1Mooxh8C5HWjzZHcavuS2f6pmJpZx60ca7iI= google.golang.org/api v0.111.0/go.mod h1:qtFHvU9mhgTJegR31csQ+rwxyUTHOKFqCKWp1J0fdw0= google.golang.org/api v0.114.0/go.mod h1:ifYI2ZsFK6/uGddGfAD5BMxlnkBqCmqHSDUVi45N5Yg= -google.golang.org/api v0.222.0 h1:Aiewy7BKLCuq6cUCeOUrsAlzjXPqBkEeQ/iwGHVQa/4= -google.golang.org/api v0.222.0/go.mod h1:efZia3nXpWELrwMlN5vyQrD4GmJN1Vw0x68Et3r+a9c= +google.golang.org/api v0.219.0 h1:nnKIvxKs/06jWawp2liznTBnMRQBEPpGo7I+oEypTX0= +google.golang.org/api v0.219.0/go.mod h1:K6OmjGm+NtLrIkHxv1U3a0qIf/0JOvAHd5O/6AoyKYE= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= @@ -1633,8 +1633,8 @@ google.golang.org/genproto v0.0.0-20230331144136-dcfb400f0633/go.mod h1:UUQDJDOl google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU= google.golang.org/genproto v0.0.0-20250127172529-29210b9bc287 h1:WoUI1G0DQ648FKvSl756SKxHQR/bI+y4HyyIQfxMWI8= google.golang.org/genproto v0.0.0-20250127172529-29210b9bc287/go.mod h1:wkQ2Aj/xvshAUDtO/JHvu9y+AaN9cqs28QuSVSHtZSY= -google.golang.org/genproto/googleapis/api v0.0.0-20250219182151-9fdb1cabc7b2 h1:35ZFtrCgaAjF7AFAK0+lRSf+4AyYnWRbH7og13p7rZ4= -google.golang.org/genproto/googleapis/api v0.0.0-20250219182151-9fdb1cabc7b2/go.mod h1:W9ynFDP/shebLB1Hl/ESTOap2jHd6pmLXPNZC7SVDbA= +google.golang.org/genproto/googleapis/api v0.0.0-20250127172529-29210b9bc287 h1:A2ni10G3UlplFrWdCDJTl7D7mJ7GSRm37S+PDimaKRw= +google.golang.org/genproto/googleapis/api v0.0.0-20250127172529-29210b9bc287/go.mod h1:iYONQfRdizDB8JJBybql13nArx91jcUk7zCXEsOofM4= google.golang.org/genproto/googleapis/rpc v0.0.0-20250219182151-9fdb1cabc7b2 h1:DMTIbak9GhdaSxEjvVzAeNZvyc03I61duqNbnm3SU0M= google.golang.org/genproto/googleapis/rpc v0.0.0-20250219182151-9fdb1cabc7b2/go.mod h1:LuRYeWDFV6WOn90g357N17oMCaxpgCnbi/44qJvDn2I= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= diff --git a/helpers.go b/helpers.go index 9a01708f..f855bd14 100644 --- a/helpers.go +++ b/helpers.go @@ -1,3 +1,4 @@ +//helpers.go package main import ( @@ -16,14 +17,6 @@ import 
( "github.com/alphauslabs/pubsub/broadcast" ) -// // MessageLockInfo tracks lock state across nodes -// type MessageLockInfo struct { -// Timeout time.Time -// Locked bool -// NodeID string -// SubscriberID string // Added to track which subscriber has the lock -// LockHolders map[string]bool // Track which nodes have acknowledged the lock -// } // validateTopicSubscription checks if subscription exists in memory func (s *server) validateTopicSubscription(subscriptionID string) (*pb.Subscription, error) { @@ -57,81 +50,98 @@ func (s *server) validateTopicSubscription(subscriptionID string) (*pb.Subscript // broadcastLock sends lock request to all nodes and waits for acknowledgment func (s *server) broadcastLock(ctx context.Context, messageID string, subscriberID string, timeout time.Duration) error { - lockInfo := broadcast.MessageLockInfo{ - Timeout: time.Now().Add(timeout), - Locked: true, - NodeID: s.Op.ID(), - SubscriberID: subscriberID, - LockHolders: make(map[string]bool), - } - // Store initial lock info before broadcasting - _, loaded := s.messageLocks.LoadOrStore(messageID, lockInfo) - if loaded { - return fmt.Errorf("message already locked by another node") - } - // Broadcast lock request - format matching handleBroadcastedMsg - broadcastData := broadcast.BroadCastInput{ - Type: broadcast.msgEvent, - Msg: []byte(fmt.Sprintf("lock:%s:%d:%s", messageID, int(timeout.Seconds()), subscriberID)), - } - - bin, _ := json.Marshal(broadcastData) - out := s.Op.Broadcast(ctx, bin) - // Ensure majority of nodes acknowledged - successCount := 0 - for _, v := range out { - if v.Error == nil { - successCount++ - } - } - if successCount < (len(out)/2 + 1) { - s.messageLocks.Delete(messageID) - return fmt.Errorf("failed to acquire lock across majority of nodes") - } - // Start local timeout timer - timer := time.NewTimer(timeout) - s.timeoutTimers.Store(messageID, timer) - - go func() { - <-timer.C - s.handleMessageTimeout(messageID) - }() - return nil + lockInfo := broadcast.MessageLockInfo{ + Timeout: time.Now().Add(timeout), + Locked: true, + NodeID: s.Op.ID(), + SubscriberID: subscriberID, + LockHolders: make(map[string]bool), + } + + // Add this node as a lock holder + lockInfo.LockHolders[s.Op.ID()] = true + + // Check if already locked by this node + if val, loaded := s.messageLocks.LoadOrStore(messageID, lockInfo); loaded { + existingInfo := val.(broadcast.MessageLockInfo) + if existingInfo.NodeID != s.Op.ID() { + return fmt.Errorf("message already locked by another node") + } + // Already locked by this node, just return success + return nil + } + + // Broadcast lock request to all nodes + broadcastData := broadcast.BroadCastInput{ + Type: broadcast.msgEvent, + Msg: []byte(fmt.Sprintf("lock:%s:%d:%s:%s", messageID, int(timeout.Seconds()), subscriberID, s.Op.ID())), + } + + bin, _ := json.Marshal(broadcastData) + out := s.Op.Broadcast(ctx, bin) + + // Track which nodes acknowledged the lock + successCount := 1 // Include self + for i, v := range out { + if v.Error == nil { + successCount++ + // Track which node acknowledged + lockInfo.LockHolders[fmt.Sprintf("node-%d", i)] = true + } + } + + // Need majority for consensus + if successCount < (len(out)/2 + 1) { + s.messageLocks.Delete(messageID) + return fmt.Errorf("failed to acquire lock across majority of nodes") + } + + // Update lock info with acknowledgments + s.messageLocks.Store(messageID, lockInfo) + + // Start local timeout timer + timer := time.NewTimer(timeout) + s.timeoutTimers.Store(messageID, timer) + + go func() { + <-timer.C 
+ s.handleMessageTimeout(messageID) + }() + + return nil } // handleMessageTimeout ensures that if a node crashes while holding a lock, // other nodes can unlock the message and allow it to be processed again. func (s *server) handleMessageTimeout(messageID string) { - if lockInfo, ok := s.messageLocks.Load(messageID); ok { - info := lockInfo.(broadcast.MessageLockInfo) - if info.Locked && time.Now().After(info.Timeout) { - log.Printf("[Timeout] Unlocking expired message: %s", messageID) - // Broadcast unlock - s.broadcastUnlock(context.Background(), messageID) - // Remove lock entry - s.messageLocks.Delete(messageID) - // Notify all nodes to retry processing this message - broadcastData := broadcast.BroadCastInput{ - Type: broadcast.msgEvent, - Msg: []byte(fmt.Sprintf("retry:%s", messageID)), - } - bin, _ := json.Marshal(broadcastData) - s.Op.Broadcast(context.Background(), bin) - } - } + if lockInfo, ok := s.messageLocks.Load(messageID); ok { + info := lockInfo.(broadcast.MessageLockInfo) + + // Only unlock if this node is the lock owner + if info.NodeID == s.Op.ID() && info.Locked && time.Now().After(info.Timeout) { + log.Printf("[Timeout] Node %s unlocking expired message: %s", s.Op.ID(), messageID) + + // Broadcast unlock + s.broadcastUnlock(context.Background(), messageID) + + // Notify all nodes to retry processing this message + broadcastData := broadcast.BroadCastInput{ + Type: broadcast.msgEvent, + Msg: []byte(fmt.Sprintf("retry:%s:%s", messageID, s.Op.ID())), + } + bin, _ := json.Marshal(broadcastData) + s.Op.Broadcast(context.Background(), bin) + } + } } // broadcastUnlock ensures that only the leader node is responsible for broadcasting unlock requests func (s *server) broadcastUnlock(ctx context.Context, messageID string) { - // Ensure only the leader sends the unlock request - if !s.IsLeader() { - log.Printf("[Unlock] Skipping unlock broadcast. 
Only the leader handles unlocks.") - return - } - // Format matching handleBroadcastedMsg + + // Any node can broadcast an unlock broadcastData := broadcast.BroadCastInput{ Type: broadcast.msgEvent, - Msg: []byte(fmt.Sprintf("unlock:%s", messageID)), + Msg: []byte(fmt.Sprintf("unlock:%s:%s", messageID, s.Op.ID())), } bin, _ := json.Marshal(broadcastData) s.Op.Broadcast(ctx, bin) @@ -141,13 +151,9 @@ func (s *server) broadcastUnlock(ctx context.Context, messageID string) { timer.(*time.Timer).Stop() s.timeoutTimers.Delete(messageID) } - log.Printf("[Unlock] Leader node unlocked message: %s", messageID) + log.Printf("[Unlock] Node %s unlocked message: %s", s.Op.ID(), messageID) } -// IsLeader checks if the current node is the leader node in the pub/sub system -func (s *server) IsLeader() bool { - return s.Op.ID() == s.Op.GetLeaderID() // Compare current node ID with leader ID -} // requestMessageFromLeader asks the leader node for messages func (s *server) requestMessageFromLeader(topicID string) (*pb.Message, error) { @@ -177,32 +183,43 @@ func (s *server) requestMessageFromLeader(topicID string) (*pb.Message, error) { // ExtendVisibilityTimeout extends the visibility timeout for a message func (s *server) ExtendVisibilityTimeout(messageID string, subscriberID string, visibilityTimeout time.Duration) error { - // Non-leader nodes should not modify state directly - if !s.IsLeader() { - return status.Error(codes.PermissionDenied, "only the leader can extend visibility timeout") - } - value, exists := s.messageLocks.Load(messageID) - if !exists { - return status.Error(codes.NotFound, "message not locked") - } - info, ok := value.(broadcast.MessageLockInfo) - if !ok || info.SubscriberID != subscriberID { - return status.Error(codes.PermissionDenied, "message locked by another subscriber") - } - // Leader extends visibility timeout - newExpiresAt := time.Now().Add(visibilityTimeout) - info.Timeout = newExpiresAt - s.messageLocks.Store(messageID, info) - // Create broadcast message - format matching handleBroadcastedMsg - broadcastData := broadcast.BroadCastInput{ - Type: broadcast.msgEvent, - Msg: []byte(fmt.Sprintf("extend:%s:%d", messageID, int(visibilityTimeout.Seconds()))), - } - msgBytes, _ := json.Marshal(broadcastData) - // Leader broadcasts the new timeout - s.Op.Broadcast(context.TODO(), msgBytes) - log.Printf("[ExtendTimeout] Leader approved timeout extension for message: %s", messageID) - return nil + value, exists := s.messageLocks.Load(messageID) + if !exists { + return status.Error(codes.NotFound, "message not locked") + } + + info, ok := value.(broadcast.MessageLockInfo) + if !ok { + return status.Error(codes.Internal, "invalid lock info") + } + + // Check if this node owns the lock + if info.NodeID != s.Op.ID() { + return status.Error(codes.PermissionDenied, "only the lock owner can extend timeout") + } + + // Check subscriber ID + if info.SubscriberID != subscriberID { + return status.Error(codes.PermissionDenied, "message locked by another subscriber") + } + + // Extend visibility timeout + newExpiresAt := time.Now().Add(visibilityTimeout) + info.Timeout = newExpiresAt + s.messageLocks.Store(messageID, info) + + // Create broadcast message + broadcastData := broadcast.BroadCastInput{ + Type: broadcast.msgEvent, + Msg: []byte(fmt.Sprintf("extend:%s:%d:%s", messageID, int(visibilityTimeout.Seconds()), s.Op.ID())), + } + msgBytes, _ := json.Marshal(broadcastData) + + // Broadcast new timeout to all nodes + s.Op.Broadcast(context.TODO(), msgBytes) + log.Printf("[ExtendTimeout] Node 
%s extended timeout for message: %s", s.Op.ID(), messageID) + + return nil } // HandleBroadcastMessage processes broadcast messages received from other nodes diff --git a/server.go b/server.go index 624faa94..c2743867 100644 --- a/server.go +++ b/server.go @@ -1,3 +1,4 @@ +//server.go package main import ( @@ -124,7 +125,7 @@ func (s *server) Acknowledge(ctx context.Context, in *pb.AcknowledgeRequest) (*p if !ok { return nil, status.Error(codes.NotFound, "message lock not found") } - info := lockInfo.(MessageLockInfo) + info := lockInfo.(broadcast.MessageLockInfo) if !info.Locked || time.Now().After(info.Timeout) { return nil, status.Error(codes.FailedPrecondition, "message lock expired") } @@ -157,37 +158,43 @@ func (s *server) Acknowledge(ctx context.Context, in *pb.AcknowledgeRequest) (*p // ModifyVisibilityTimeout extends message lock timeout func (s *server) ModifyVisibilityTimeout(ctx context.Context, in *pb.ModifyVisibilityTimeoutRequest) (*pb.ModifyVisibilityTimeoutResponse, error) { - lockInfo, ok := s.messageLocks.Load(in.Id) - if !ok { - return nil, status.Error(codes.NotFound, "message lock not found") - } - info := lockInfo.(broadcast.MessageLockInfo) - if !info.Locked { - return nil, status.Error(codes.FailedPrecondition, "message not locked") - } - // Ensure the same node is extending the lock - if info.NodeID != s.Op.ID() { - return nil, status.Error(codes.PermissionDenied, "only the original lock holder can extend timeout") - } - // Broadcast new timeout - format matching handleBroadcastedMsg - broadcastData := BroadCastInput{ - Type: broadcast.msgEvent, - Msg: []byte(fmt.Sprintf("extend:%s:%d", in.Id, in.NewTimeout)), - } - bin, _ := json.Marshal(broadcastData) - s.Op.Broadcast(ctx, bin) - // Update local timer - if timer, ok := s.timeoutTimers.Load(in.Id); ok { - timer.(*time.Timer).Stop() - } - newTimer := time.NewTimer(time.Duration(in.NewTimeout) * time.Second) - s.timeoutTimers.Store(in.Id, newTimer) - // Update lock info - info.Timeout = time.Now().Add(time.Duration(in.NewTimeout) * time.Second) - s.messageLocks.Store(in.Id, info) - go func() { - <-newTimer.C - s.handleMessageTimeout(in.Id) - }() - return &pb.ModifyVisibilityTimeoutResponse{Success: true}, nil + lockInfo, ok := s.messageLocks.Load(in.Id) + if !ok { + return nil, status.Error(codes.NotFound, "message lock not found") + } + info := lockInfo.(broadcast.MessageLockInfo) + if !info.Locked { + return nil, status.Error(codes.FailedPrecondition, "message not locked") + } + + // Check if this node owns the lock before extending + if info.NodeID != s.Op.ID() { + return nil, status.Error(codes.PermissionDenied, "only the lock owner can extend timeout") + } + + // Broadcast new timeout + broadcastData := broadcast.BroadCastInput{ + Type: broadcast.msgEvent, + Msg: []byte(fmt.Sprintf("extend:%s:%d:%s", in.Id, in.NewTimeout, s.Op.ID())), + } + bin, _ := json.Marshal(broadcastData) + s.Op.Broadcast(ctx, bin) + + // Update local timer + if timer, ok := s.timeoutTimers.Load(in.Id); ok { + timer.(*time.Timer).Stop() + } + newTimer := time.NewTimer(time.Duration(in.NewTimeout) * time.Second) + s.timeoutTimers.Store(in.Id, newTimer) + + // Update lock info + info.Timeout = time.Now().Add(time.Duration(in.NewTimeout) * time.Second) + s.messageLocks.Store(in.Id, info) + + go func() { + <-newTimer.C + s.handleMessageTimeout(in.Id) + }() + + return &pb.ModifyVisibilityTimeoutResponse{Success: true}, nil } From 1821064d91e661e12f64bb9c0afa5179ac5ae12b Mon Sep 17 00:00:00 2001 From: Kishea Kate Andoy Date: Wed, 26 Feb 2025 
13:56:45 +0800 Subject: [PATCH 15/21] using the storage.go? --- app/app.go | 1 - broadcast.go | 66 ----------------- broadcast/broadcast.go | 10 +-- helpers.go | 107 ++++++++++----------------- server.go | 163 ++++++++++++++++++++++++----------------- 5 files changed, 140 insertions(+), 207 deletions(-) delete mode 100644 broadcast.go diff --git a/app/app.go b/app/app.go index 5df8f645..ed453191 100644 --- a/app/app.go +++ b/app/app.go @@ -1,4 +1,3 @@ -//app.go package app import ( diff --git a/broadcast.go b/broadcast.go deleted file mode 100644 index 344ae8dc..00000000 --- a/broadcast.go +++ /dev/null @@ -1,66 +0,0 @@ -package main - -import ( - "encoding/json" -) - -const ( - message = "message" - topicsub = "topicsub" -) - -type broadCastInput struct { - Type string - Msg []byte -} - -var ctrlbroadcast = map[string]func(*PubSub, []byte) ([]byte, error){ - message: handleBroadcastedMsg, - topicsub: handleBroadcastedTopicsub, -} - -// Root handler for op.Broadcast() // do not change this -func broadcast(data any, msg []byte) ([]byte, error) { - var in broadCastInput - app := data.(*PubSub) - if err := json.Unmarshal(msg, &in); err != nil { - return nil, err - } - return ctrlbroadcast[in.Type](app, in.Msg) -} - -func handleBroadcastedMsg(app *PubSub, msg []byte) ([]byte, error) { - parts := strings.Split(string(msg), ":") - switch parts[0] { - case "lock": - //if a node receives a "lock" request for a message it already has locked, it should reject duplicate locks. - messageID := parts[1] - if _, exists := app.messageLocks.Load(messageID); exists { - return nil, nil // Already locked, ignore duplicate - } - case "unlock": - // Handle unlock request - messageID := parts[1] - app.messageLocks.Delete(messageID) - // Clean up locks and timers - case "delete": - messageID := parts[1] - app.messageLocks.Delete(messageID) - app.messageQueue.Delete(messageID) - case "extend": - // Handle timeout extension - messageID := parts[1] - newTimeout, _ := strconv.Atoi(parts[2]) - if lockInfo, ok := app.messageLocks.Load(messageID); ok { - info := lockInfo.(MessageLockInfo) - info.Timeout = time.Now().Add(time.Duration(newTimeout) * time.Second) - app.messageLocks.Store(messageID, info) - } - // Update timeout and reset timer - } - return nil, nil -} - -func handleBroadcastedTopicsub(app *PubSub, msg []byte) ([]byte, error) { - return nil, nil -} diff --git a/broadcast/broadcast.go b/broadcast/broadcast.go index 8ff4bec8..99dfc4f8 100644 --- a/broadcast/broadcast.go +++ b/broadcast/broadcast.go @@ -1,4 +1,3 @@ -//broadcast.go package broadcast import ( @@ -58,7 +57,7 @@ var in BroadCastInput } func handleBroadcastedMsg(app *app.PubSub, msg []byte) ([]byte, error) { - log.Println("[BROADCAST]: Received message:\n", string(msg)) + log.Println("Received message:\n", string(msg)) var message pb.Message if err := json.Unmarshal(msg, &message); err != nil { return nil, fmt.Errorf("failed to unmarshal message: %w", err) @@ -73,7 +72,7 @@ func handleBroadcastedMsg(app *app.PubSub, msg []byte) ([]byte, error) { // Handles topic-subscription updates func handleBroadcastedTopicsub(app *app.PubSub, msg []byte) ([]byte, error) { - log.Println("[BROADCAST]: Received topic-subscriptions:\n", string(msg)) + log.Println("Received topic-subscriptions:\n", string(msg)) if err := app.Storage.StoreTopicSubscriptions(msg); err != nil { return nil, fmt.Errorf("failed to store topic-subscriptions: %w", err) } @@ -143,8 +142,9 @@ func handleLockMsg(app *app.PubSub, messageID string, params []string) ([]byte, return nil, 
fmt.Errorf("message already locked by another node") } } - - // Create new lock + + // Each node maintains its own timer + // Create new lock lockInfo := MessageLockInfo{ Locked: true, Timeout: time.Now().Add(time.Duration(timeoutSeconds) * time.Second), diff --git a/helpers.go b/helpers.go index f855bd14..e2eca5e1 100644 --- a/helpers.go +++ b/helpers.go @@ -18,37 +18,29 @@ import ( ) -// validateTopicSubscription checks if subscription exists in memory -func (s *server) validateTopicSubscription(subscriptionID string) (*pb.Subscription, error) { - if val, ok := s.subscriptions.Load(subscriptionID); ok { - return val.(*pb.Subscription), nil - } - // Request subscription details from the leader - using topicsub message type - broadcastData := broadcast.BroadCastInput{ - Type: broadcast.topicsub, - Msg: []byte(fmt.Sprintf("get:%s", subscriptionID)), - } - bin, _ := json.Marshal(broadcastData) - resp, err := s.Op.Request(context.Background(), bin) - if err != nil { - return nil, status.Error(codes.Internal, "failed to request subscription from leader") - } +// validateTopicSubscription checks if subscription exists in storage +func (s *server) validateTopicSubscription(subscriptionID string) error { + subs, err := s.Storage.GetSubscribtionsForTopic(subscriptionID) + if err != nil { + return status.Errorf(codes.NotFound, "subscription not found") + } - var subscription pb.Subscription - if err := json.Unmarshal(resp, &subscription); err != nil { - return nil, status.Error(codes.Internal, "failed to parse subscription response") - } + found := false + for _, sub := range subs { + if sub == subscriptionID { + found = true + break + } + } - if subscription.Id == "" { - return nil, status.Error(codes.NotFound, "subscription not found") - } + if !found { + return status.Errorf(codes.NotFound, "subscription not found") + } - // Store it in memory to prevent duplicate lookups - s.subscriptions.Store(subscriptionID, &subscription) - return &subscription, nil + return nil } -// broadcastLock sends lock request to all nodes and waits for acknowledgment +// / broadcastLock handles distributed locking func (s *server) broadcastLock(ctx context.Context, messageID string, subscriberID string, timeout time.Duration) error { lockInfo := broadcast.MessageLockInfo{ Timeout: time.Now().Add(timeout), @@ -57,80 +49,57 @@ func (s *server) broadcastLock(ctx context.Context, messageID string, subscriber SubscriberID: subscriberID, LockHolders: make(map[string]bool), } - - // Add this node as a lock holder - lockInfo.LockHolders[s.Op.ID()] = true - - // Check if already locked by this node - if val, loaded := s.messageLocks.LoadOrStore(messageID, lockInfo); loaded { - existingInfo := val.(broadcast.MessageLockInfo) - if existingInfo.NodeID != s.Op.ID() { - return fmt.Errorf("message already locked by another node") - } - // Already locked by this node, just return success - return nil + + // Check if message exists in storage + _, err := s.Storage.GetMessage(messageID) + if err != nil { + return err } - - // Broadcast lock request to all nodes + + // Store lock information + lockInfo.LockHolders[s.Op.ID()] = true + s.messageLocks.Store(messageID, lockInfo) + + // Broadcast lock request broadcastData := broadcast.BroadCastInput{ Type: broadcast.msgEvent, Msg: []byte(fmt.Sprintf("lock:%s:%d:%s:%s", messageID, int(timeout.Seconds()), subscriberID, s.Op.ID())), } - + bin, _ := json.Marshal(broadcastData) out := s.Op.Broadcast(ctx, bin) - - // Track which nodes acknowledged the lock + + // Track acknowledgments 
successCount := 1 // Include self - for i, v := range out { + for _, v := range out { if v.Error == nil { successCount++ - // Track which node acknowledged - lockInfo.LockHolders[fmt.Sprintf("node-%d", i)] = true } } - + // Need majority for consensus if successCount < (len(out)/2 + 1) { s.messageLocks.Delete(messageID) return fmt.Errorf("failed to acquire lock across majority of nodes") } - - // Update lock info with acknowledgments - s.messageLocks.Store(messageID, lockInfo) - - // Start local timeout timer + + // Start timeout timer timer := time.NewTimer(timeout) s.timeoutTimers.Store(messageID, timer) - + go func() { <-timer.C s.handleMessageTimeout(messageID) }() - + return nil } -// handleMessageTimeout ensures that if a node crashes while holding a lock, -// other nodes can unlock the message and allow it to be processed again. func (s *server) handleMessageTimeout(messageID string) { if lockInfo, ok := s.messageLocks.Load(messageID); ok { info := lockInfo.(broadcast.MessageLockInfo) - - // Only unlock if this node is the lock owner if info.NodeID == s.Op.ID() && info.Locked && time.Now().After(info.Timeout) { - log.Printf("[Timeout] Node %s unlocking expired message: %s", s.Op.ID(), messageID) - - // Broadcast unlock s.broadcastUnlock(context.Background(), messageID) - - // Notify all nodes to retry processing this message - broadcastData := broadcast.BroadCastInput{ - Type: broadcast.msgEvent, - Msg: []byte(fmt.Sprintf("retry:%s:%s", messageID, s.Op.ID())), - } - bin, _ := json.Marshal(broadcastData) - s.Op.Broadcast(context.Background(), bin) } } } diff --git a/server.go b/server.go index c2743867..02ca1908 100644 --- a/server.go +++ b/server.go @@ -81,79 +81,110 @@ func (s *server) Publish(ctx context.Context, in *pb.PublishRequest) (*pb.Publis return &pb.PublishResponse{MessageId: messageID}, nil } +// Subscribe to receive messages for a subscription // Subscribe to receive messages for a subscription func (s *server) Subscribe(in *pb.SubscribeRequest, stream pb.PubSubService_SubscribeServer) error { - // Validate subscription in memory first - subscription, err := s.validateTopicSubscription(in.SubscriptionId) - if err != nil { - return err - } - log.Printf("[Subscribe] Starting subscription stream for ID: %s", in.SubscriptionId) - for { - select { - case <-stream.Context().Done(): - return nil - default: - // Request message from the leader instead of querying directly - message, err := s.requestMessageFromLeader(subscription.TopicId) - if err != nil { - log.Printf("[Subscribe] No available messages for subscription: %s", in.SubscriptionId) - time.Sleep(time.Second) // Prevent CPU overuse - continue - } - // Ensure it's not already locked by another node - if _, exists := s.messageLocks.Load(message.Id); exists { - continue // Skip locked messages - } - // Try to acquire distributed lock - if err := s.broadcastLock(stream.Context(), message.Id, in.SubscriptionId, 30*time.Second); err != nil { - continue - } - // Send message to subscriber - if err := stream.Send(message); err != nil { - s.broadcastUnlock(stream.Context(), message.Id) - return err - } - } - } + // Validate subscription using storage + subs, err := s.Storage.GetSubscribtionsForTopic(in.TopicId) + if err != nil { + return status.Errorf(codes.NotFound, "Topic %s not found", in.TopicId) + } + + found := false + for _, sub := range subs { + if sub == in.SubscriptionId { + found = true + break + } + } + + if !found { + return status.Errorf(codes.NotFound, "Subscription %s not found", in.SubscriptionId) + } + + 
log.Printf("[Subscribe] Starting subscription stream for ID: %s", in.SubscriptionId) + + for { + select { + case <-stream.Context().Done(): + return nil + default: + // Get messages from storage + messages, err := s.Storage.GetMessagesByTopic(in.TopicId) + if err != nil { + log.Printf("[Subscribe] Error getting messages: %v", err) + time.Sleep(time.Second) + continue + } + + if len(messages) == 0 { + time.Sleep(time.Second) + continue + } + + for _, message := range messages { + // Check if message is locked + if _, exists := s.messageLocks.Load(message.Id); exists { + continue + } + + // Try to acquire lock + if err := s.broadcastLock(stream.Context(), message.Id, in.SubscriptionId, 30*time.Second); err != nil { + continue + } + + // Send message + if err := stream.Send(message); err != nil { + s.broadcastUnlock(stream.Context(), message.Id) + return err + } + } + } + } } // Acknowledge a processed message func (s *server) Acknowledge(ctx context.Context, in *pb.AcknowledgeRequest) (*pb.AcknowledgeResponse, error) { - // Verify lock exists and is valid - lockInfo, ok := s.messageLocks.Load(in.Id) - if !ok { - return nil, status.Error(codes.NotFound, "message lock not found") - } - info := lockInfo.(broadcast.MessageLockInfo) - if !info.Locked || time.Now().After(info.Timeout) { - return nil, status.Error(codes.FailedPrecondition, "message lock expired") - } - // Update Spanner - mutation := spanner.Update( - MessagesTable, - []string{"id", "processed", "updatedAt"}, - []interface{}{in.Id, true, spanner.CommitTimestamp}, - ) - _, err := s.Client.Apply(ctx, []*spanner.Mutation{mutation}) - if err != nil { - return nil, err - } - // Broadcast delete to all nodes - format matching handleBroadcastedMsg - broadcastData := broadcast.BroadCastInput{ - Type: broadcast.msgEvent, - Msg: []byte(fmt.Sprintf("delete:%s", in.Id)), - } + // Verify lock exists and is valid + lockInfo, ok := s.messageLocks.Load(in.Id) + if !ok { + return nil, status.Error(codes.NotFound, "message lock not found") + } - bin, _ := json.Marshal(broadcastData) - s.Op.Broadcast(ctx, bin) - // Clean up local state - s.messageLocks.Delete(in.Id) - if timer, ok := s.timeoutTimers.Load(in.Id); ok { - timer.(*time.Timer).Stop() - s.timeoutTimers.Delete(in.Id) - } - return &pb.AcknowledgeResponse{Success: true}, nil + info := lockInfo.(broadcast.MessageLockInfo) + if !info.Locked || time.Now().After(info.Timeout) { + return nil, status.Error(codes.FailedPrecondition, "message lock expired") + } + + // Get message from storage + msg, err := s.Storage.GetMessage(in.Id) + if err != nil { + return nil, status.Error(codes.NotFound, "message not found") + } + + // Mark message as processed in storage + msg.Processed = true + if err := s.Storage.StoreMessage(msg); err != nil { + return nil, status.Error(codes.Internal, "failed to update message") + } + + // Broadcast acknowledgment + broadcastData := broadcast.BroadCastInput{ + Type: broadcast.msgEvent, + Msg: []byte(fmt.Sprintf("delete:%s", in.Id)), + } + + bin, _ := json.Marshal(broadcastData) + s.Op.Broadcast(ctx, bin) + + // Clean up local state + s.messageLocks.Delete(in.Id) + if timer, ok := s.timeoutTimers.Load(in.Id); ok { + timer.(*time.Timer).Stop() + s.timeoutTimers.Delete(in.Id) + } + + return &pb.AcknowledgeResponse{Success: true}, nil } // ModifyVisibilityTimeout extends message lock timeout From 14d86b1a03eb9acdeea62e80d31fa2d953901ed1 Mon Sep 17 00:00:00 2001 From: Kishea Kate Andoy Date: Wed, 26 Feb 2025 16:51:25 +0800 Subject: [PATCH 16/21] debugged helpers.go --- 
helpers.go | 231 ++++++++++++++++++++++++++--------------------------- main.go | 92 ++++++++++++++++++++- 2 files changed, 203 insertions(+), 120 deletions(-) diff --git a/helpers.go b/helpers.go index e2eca5e1..bf457a3d 100644 --- a/helpers.go +++ b/helpers.go @@ -1,4 +1,4 @@ -//helpers.go +// helpers.go package main import ( @@ -13,95 +13,95 @@ import ( pb "github.com/alphauslabs/pubsub-proto/v1" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "github.com/alphauslabs/pubsub/app" + + // "github.com/alphauslabs/pubsub/app" "github.com/alphauslabs/pubsub/broadcast" ) - // validateTopicSubscription checks if subscription exists in storage func (s *server) validateTopicSubscription(subscriptionID string) error { - subs, err := s.Storage.GetSubscribtionsForTopic(subscriptionID) - if err != nil { - return status.Errorf(codes.NotFound, "subscription not found") - } - - found := false - for _, sub := range subs { - if sub == subscriptionID { - found = true - break - } - } - - if !found { - return status.Errorf(codes.NotFound, "subscription not found") - } - - return nil + subs, err := s.Storage.GetSubscribtionsForTopic(subscriptionID) + if err != nil { + return status.Errorf(codes.NotFound, "subscription not found") + } + + found := false + for _, sub := range subs { + if sub == subscriptionID { + found = true + break + } + } + + if !found { + return status.Errorf(codes.NotFound, "subscription not found") + } + + return nil } // / broadcastLock handles distributed locking func (s *server) broadcastLock(ctx context.Context, messageID string, subscriberID string, timeout time.Duration) error { - lockInfo := broadcast.MessageLockInfo{ - Timeout: time.Now().Add(timeout), - Locked: true, - NodeID: s.Op.ID(), - SubscriberID: subscriberID, - LockHolders: make(map[string]bool), - } - - // Check if message exists in storage - _, err := s.Storage.GetMessage(messageID) - if err != nil { - return err - } - - // Store lock information - lockInfo.LockHolders[s.Op.ID()] = true - s.messageLocks.Store(messageID, lockInfo) - - // Broadcast lock request - broadcastData := broadcast.BroadCastInput{ - Type: broadcast.msgEvent, - Msg: []byte(fmt.Sprintf("lock:%s:%d:%s:%s", messageID, int(timeout.Seconds()), subscriberID, s.Op.ID())), - } - - bin, _ := json.Marshal(broadcastData) - out := s.Op.Broadcast(ctx, bin) - - // Track acknowledgments - successCount := 1 // Include self - for _, v := range out { - if v.Error == nil { - successCount++ - } - } - - // Need majority for consensus - if successCount < (len(out)/2 + 1) { - s.messageLocks.Delete(messageID) - return fmt.Errorf("failed to acquire lock across majority of nodes") - } - - // Start timeout timer - timer := time.NewTimer(timeout) - s.timeoutTimers.Store(messageID, timer) - - go func() { - <-timer.C - s.handleMessageTimeout(messageID) - }() - - return nil + lockInfo := broadcast.MessageLockInfo{ + Timeout: time.Now().Add(timeout), + Locked: true, + NodeID: "default-node", + SubscriberID: subscriberID, + LockHolders: make(map[string]bool), + } + + // Check if message exists in storage + _, err := s.Storage.GetMessage(messageID) + if err != nil { + return err + } + + // Store lock information + lockInfo.LockHolders[s.Op.ID()] = true + s.messageLocks.Store(messageID, lockInfo) + + // Broadcast lock request + broadcastData := broadcast.BroadCastInput{ + Type: broadcast.msgEvent, + Msg: []byte(fmt.Sprintf("lock:%s:%d:%s:%s", messageID, int(timeout.Seconds()), subscriberID, s.Op.ID())), + } + + bin, _ := json.Marshal(broadcastData) + out := 
s.Op.Broadcast(ctx, bin) + + // Track acknowledgments + successCount := 1 // Include self + for _, v := range out { + if v.Error == nil { + successCount++ + } + } + + // Need majority for consensus + if successCount < (len(out)/2 + 1) { + s.messageLocks.Delete(messageID) + return fmt.Errorf("failed to acquire lock across majority of nodes") + } + + // Start timeout timer + timer := time.NewTimer(timeout) + s.timeoutTimers.Store(messageID, timer) + + go func() { + <-timer.C + s.handleMessageTimeout(messageID) + }() + + return nil } func (s *server) handleMessageTimeout(messageID string) { - if lockInfo, ok := s.messageLocks.Load(messageID); ok { - info := lockInfo.(broadcast.MessageLockInfo) - if info.NodeID == s.Op.ID() && info.Locked && time.Now().After(info.Timeout) { - s.broadcastUnlock(context.Background(), messageID) - } - } + if lockInfo, ok := s.messageLocks.Load(messageID); ok { + info := lockInfo.(broadcast.MessageLockInfo) + if info.NodeID == s.Op.ID() && info.Locked && time.Now().After(info.Timeout) { + s.broadcastUnlock(context.Background(), messageID) + } + } } // broadcastUnlock ensures that only the leader node is responsible for broadcasting unlock requests @@ -123,7 +123,6 @@ func (s *server) broadcastUnlock(ctx context.Context, messageID string) { log.Printf("[Unlock] Node %s unlocked message: %s", s.Op.ID(), messageID) } - // requestMessageFromLeader asks the leader node for messages func (s *server) requestMessageFromLeader(topicID string) (*pb.Message, error) { // Use the message type with proper format for requesting a message @@ -152,43 +151,43 @@ func (s *server) requestMessageFromLeader(topicID string) (*pb.Message, error) { // ExtendVisibilityTimeout extends the visibility timeout for a message func (s *server) ExtendVisibilityTimeout(messageID string, subscriberID string, visibilityTimeout time.Duration) error { - value, exists := s.messageLocks.Load(messageID) - if !exists { - return status.Error(codes.NotFound, "message not locked") - } - - info, ok := value.(broadcast.MessageLockInfo) - if !ok { - return status.Error(codes.Internal, "invalid lock info") - } - - // Check if this node owns the lock - if info.NodeID != s.Op.ID() { - return status.Error(codes.PermissionDenied, "only the lock owner can extend timeout") - } - - // Check subscriber ID - if info.SubscriberID != subscriberID { - return status.Error(codes.PermissionDenied, "message locked by another subscriber") - } - - // Extend visibility timeout - newExpiresAt := time.Now().Add(visibilityTimeout) - info.Timeout = newExpiresAt - s.messageLocks.Store(messageID, info) - - // Create broadcast message - broadcastData := broadcast.BroadCastInput{ - Type: broadcast.msgEvent, - Msg: []byte(fmt.Sprintf("extend:%s:%d:%s", messageID, int(visibilityTimeout.Seconds()), s.Op.ID())), - } - msgBytes, _ := json.Marshal(broadcastData) - - // Broadcast new timeout to all nodes - s.Op.Broadcast(context.TODO(), msgBytes) - log.Printf("[ExtendTimeout] Node %s extended timeout for message: %s", s.Op.ID(), messageID) - - return nil + value, exists := s.messageLocks.Load(messageID) + if !exists { + return status.Error(codes.NotFound, "message not locked") + } + + info, ok := value.(broadcast.MessageLockInfo) + if !ok { + return status.Error(codes.Internal, "invalid lock info") + } + + // Check if this node owns the lock + if info.NodeID != s.Op.ID() { + return status.Error(codes.PermissionDenied, "only the lock owner can extend timeout") + } + + // Check subscriber ID + if info.SubscriberID != subscriberID { + return 
status.Error(codes.PermissionDenied, "message locked by another subscriber") + } + + // Extend visibility timeout + newExpiresAt := time.Now().Add(visibilityTimeout) + info.Timeout = newExpiresAt + s.messageLocks.Store(messageID, info) + + // Create broadcast message + broadcastData := broadcast.BroadCastInput{ + Type: broadcast.msgEvent, + Msg: []byte(fmt.Sprintf("extend:%s:%d:%s", messageID, int(visibilityTimeout.Seconds()), s.Op.ID())), + } + msgBytes, _ := json.Marshal(broadcastData) + + // Broadcast new timeout to all nodes + s.Op.Broadcast(context.TODO(), msgBytes) + log.Printf("[ExtendTimeout] Node %s extended timeout for message: %s", s.Op.ID(), messageID) + + return nil } // HandleBroadcastMessage processes broadcast messages received from other nodes diff --git a/main.go b/main.go index 24f8e0d3..8bcb9803 100644 --- a/main.go +++ b/main.go @@ -31,13 +31,28 @@ func main() { return } - app := &app.PubSub{ - Client: spannerClient, - Storage: storage.NewStorage(), - } + // Initialize storage with message tracking capabilities + storageInstance := storage.NewStorage() + + + // Initialize app with all necessary components for distributed message handling + app := &app.PubSub{ + Client: spannerClient, + Storage: storageInstance, + MessageLocks: sync.Map{}, // For distributed message locking + MessageQueue: sync.Map{}, // For message tracking + Mutex: sync.Mutex{}, // For concurrency control + TimeoutTimers: sync.Map{}, // Add this for tracking message timeouts across nodes + } log.Println("[STORAGE]: Storage initialized") + // Configure timeout settings + const ( + defaultVisibilityTimeout = 30 * time.Second + timeoutCheckInterval = 5 * time.Second + ) + op := hedge.New( spannerClient, ":50052", // addr will be resolved internally @@ -55,7 +70,16 @@ func main() { ) app.Op = op + app.NodeID = op.ID() // Important for tracking which node handles which message + ctx, cancel := context.WithCancel(context.Background()) + + // Start timeout monitor for all nodes + go monitorMessageTimeouts(ctx, app, defaultVisibilityTimeout) + + // Start subscription validator + go validateSubscriptions(ctx, app) + go func() { if err := run(ctx, &server{PubSub: app}); err != nil { log.Fatalf("failed to run: %v", err) @@ -113,3 +137,63 @@ func serveHealthChecks() { conn.Close() } } + +// Monitor timeouts across all nodes +func monitorMessageTimeouts(ctx context.Context, app *app.PubSub, defaultTimeout time.Duration) { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + app.MessageLocks.Range(func(key, value interface{}) bool { + messageID := key.(string) + lockInfo := value.(broadcast.MessageLockInfo) + + // Check if message lock has expired + if time.Now().After(lockInfo.Timeout) { + // Broadcast unlock message to all nodes + broadcastData := broadcast.BroadCastInput{ + Type: broadcast.msgEvent, + Msg: []byte(fmt.Sprintf("unlock:%s:%s", messageID, app.NodeID)), + } + bin, _ := json.Marshal(broadcastData) + app.Op.Broadcast(ctx, bin) + + // Clean up local state + app.MessageLocks.Delete(messageID) + if timer, ok := app.TimeoutTimers.Load(messageID); ok { + timer.(*time.Timer).Stop() + app.TimeoutTimers.Delete(messageID) + } + } + return true + }) + } + } +} + +// Validate subscriptions periodically +func validateSubscriptions(ctx context.Context, app *app.PubSub) { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + // Request latest 
subscription data from leader + broadcastData := broadcast.BroadCastInput{ + Type: broadcast.topicsub, + Msg: []byte("refresh"), + } + bin, _ := json.Marshal(broadcastData) + if _, err := app.Op.Request(ctx, bin); err != nil { + log.Printf("Error refreshing subscriptions: %v", err) + } + } + } +} \ No newline at end of file From 2e8255ed3b6638f9f87adef12d9edc2ebda23e4c Mon Sep 17 00:00:00 2001 From: Horichi Date: Wed, 26 Feb 2025 19:02:52 +0800 Subject: [PATCH 17/21] Debugged errors. Modified :Acknowledge in server.go, broadcast.go (handleBroadcastedMsg) and handleMessageTimeout (helpers.go) --- broadcast/broadcast.go | 221 ++++++++++---------- broadcast/broadcast_struct.go | 2 +- broadcast/unprocessed-broadcast.go | 2 +- helpers.go | 125 +++++++----- server.go | 318 +++++++++++++++-------------- testclient/testclient.go | 2 +- 6 files changed, 352 insertions(+), 318 deletions(-) diff --git a/broadcast/broadcast.go b/broadcast/broadcast.go index 99dfc4f8..2e88bdf5 100644 --- a/broadcast/broadcast.go +++ b/broadcast/broadcast.go @@ -4,8 +4,8 @@ import ( "encoding/json" "fmt" "log" - "strings" "strconv" + "strings" "time" pb "github.com/alphauslabs/pubsub-proto/v1" @@ -13,16 +13,16 @@ import ( ) const ( - message = "message" - topicsub = "topicsub" - msgEvent = "msgEvent" + Message = "message" + Topicsub = "topicsub" + MsgEvent = "msgEvent" // Message event types - lockMsg = "lock" - unlockMsg = "unlock" - deleteMsg = "delete" - extendMsg = "extend" - retryMsg = "retry" + LockMsg = "lock" + UnlockMsg = "unlock" + DeleteMsg = "delete" + ExtendMsg = "extend" + RetryMsg = "retry" ) type BroadCastInput struct { @@ -33,27 +33,29 @@ type BroadCastInput struct { // MessageLockInfo defines lock information structure // Note: This should be consistent with the structure in helpers.go type MessageLockInfo struct { - Locked bool - Timeout time.Time - NodeID string - SubscriberID string // Added to track which subscriber has the lock - LockHolders map[string]bool // Track which nodes have acknowledged the lock + Locked bool + Timeout time.Time + NodeID string + SubscriberID string // Added to track which subscriber has the lock + LockHolders map[string]bool // Track which nodes have acknowledged the lock } var ctrlbroadcast = map[string]func(*app.PubSub, []byte) ([]byte, error){ - message: handleBroadcastedMsg, - topicsub: handleBroadcastedTopicsub, - msgEvent: handleMessageEvent, // Handles message locks, unlocks, deletes + Message: handleBroadcastedMsg, + Topicsub: handleBroadcastedTopicsub, + MsgEvent: handleMessageEvent, // Handles message locks, unlocks, deletes } // Root handler for op.Broadcast() func Broadcast(data any, msg []byte) ([]byte, error) { -var in BroadCastInput - app := data.(*app.PubSub) + var in BroadCastInput + appInstance := data.(*app.PubSub) // Ensure we're using an instance, not a type + if err := json.Unmarshal(msg, &in); err != nil { return nil, err } - return ctrlbroadcast[in.Type](app, in.Msg) + + return ctrlbroadcast[in.Type](appInstance, in.Msg) } func handleBroadcastedMsg(app *app.PubSub, msg []byte) ([]byte, error) { @@ -63,6 +65,7 @@ func handleBroadcastedMsg(app *app.PubSub, msg []byte) ([]byte, error) { return nil, fmt.Errorf("failed to unmarshal message: %w", err) } + // Store in node queue/memory (not marking as processed yet) if err := app.Storage.StoreMessage(&message); err != nil { return nil, fmt.Errorf("failed to store message: %w", err) } @@ -70,7 +73,7 @@ func handleBroadcastedMsg(app *app.PubSub, msg []byte) ([]byte, error) { return nil, nil } - // Handles 
topic-subscription updates +// Handles topic-subscription updates func handleBroadcastedTopicsub(app *app.PubSub, msg []byte) ([]byte, error) { log.Println("Received topic-subscriptions:\n", string(msg)) if err := app.Storage.StoreTopicSubscriptions(msg); err != nil { @@ -80,9 +83,8 @@ func handleBroadcastedTopicsub(app *app.PubSub, msg []byte) ([]byte, error) { return nil, nil } - - // Handles lock/unlock/delete/extend operations separately -func handleMessageEvent(app *app.PubSub, msg []byte) ([]byte, error) { +// Handles lock/unlock/delete/extend operations separately +func handleMessageEvent(appInstance *app.PubSub, msg []byte) ([]byte, error) { parts := strings.Split(string(msg), ":") if len(parts) < 2 { return nil, fmt.Errorf("invalid message event format") @@ -91,13 +93,12 @@ func handleMessageEvent(app *app.PubSub, msg []byte) ([]byte, error) { messageType := parts[0] messageID := parts[1] - // Map message event handlers eventHandlers := map[string]func(*app.PubSub, string, []string) ([]byte, error){ - lockMsg: handleLockMsg, - unlockMsg: handleUnlockMsg, - deleteMsg: handleDeleteMsg, - extendMsg: handleExtendMsg, - retryMsg: handleRetryMsg, + LockMsg: handleLockMsg, + UnlockMsg: handleUnlockMsg, + DeleteMsg: handleDeleteMsg, + ExtendMsg: handleExtendMsg, + RetryMsg: handleRetryMsg, } handler, exists := eventHandlers[messageType] @@ -105,30 +106,29 @@ func handleMessageEvent(app *app.PubSub, msg []byte) ([]byte, error) { return nil, fmt.Errorf("unknown message event: %s", messageType) } - return handler(app, messageID, parts[2:]) + return handler(appInstance, messageID, parts[2:]) } - - // Message event handlers +// Message event handlers func handleLockMsg(app *app.PubSub, messageID string, params []string) ([]byte, error) { if len(params) < 3 { return nil, fmt.Errorf("invalid lock parameters") } - + timeoutSeconds, err := strconv.Atoi(params[0]) if err != nil { return nil, err } subscriberID := params[1] requestingNodeID := params[2] - + app.Mutex.Lock() defer app.Mutex.Unlock() - + // Check if already locked if existingLock, exists := app.MessageLocks.Load(messageID); exists { info := existingLock.(MessageLockInfo) - + // If lock is expired, allow new lock if time.Now().After(info.Timeout) { // Continue with new lock @@ -143,7 +143,7 @@ func handleLockMsg(app *app.PubSub, messageID string, params []string) ([]byte, } } - // Each node maintains its own timer + // Each node maintains its own timer // Create new lock lockInfo := MessageLockInfo{ Locked: true, @@ -152,41 +152,40 @@ func handleLockMsg(app *app.PubSub, messageID string, params []string) ([]byte, SubscriberID: subscriberID, LockHolders: make(map[string]bool), } - - // Mark this node as acknowledging the lock + + // Mark this node as acknowledging the lock lockInfo.LockHolders[app.NodeID] = true - + app.MessageLocks.Store(messageID, lockInfo) - + return nil, nil } - func handleUnlockMsg(app *app.PubSub, messageID string, params []string) ([]byte, error) { - if len(params) < 1 { - return nil, fmt.Errorf("invalid unlock parameters") - } - - unlockingNodeID := params[0] - - app.Mutex.Lock() - defer app.Mutex.Unlock() - - // Check if the message is locked - if lockInfo, exists := app.MessageLocks.Load(messageID); exists { - info := lockInfo.(MessageLockInfo) - - // Only the lock owner can unlock - if info.NodeID == unlockingNodeID { - app.MessageLocks.Delete(messageID) - log.Printf("[Unlock] Node %s acknowledged unlock for message: %s", app.NodeID, messageID) - } else { - log.Printf("[Unlock] Rejected unlock from non-owner 
node %s for message: %s", unlockingNodeID, messageID) - return nil, fmt.Errorf("only lock owner can unlock") - } - } - - return nil, nil + if len(params) < 1 { + return nil, fmt.Errorf("invalid unlock parameters") + } + + unlockingNodeID := params[0] + + app.Mutex.Lock() + defer app.Mutex.Unlock() + + // Check if the message is locked + if lockInfo, exists := app.MessageLocks.Load(messageID); exists { + info := lockInfo.(MessageLockInfo) + + // Only the lock owner can unlock + if info.NodeID == unlockingNodeID { + app.MessageLocks.Delete(messageID) + log.Printf("[Unlock] Node %s acknowledged unlock for message: %s", app.NodeID, messageID) + } else { + log.Printf("[Unlock] Rejected unlock from non-owner node %s for message: %s", unlockingNodeID, messageID) + return nil, fmt.Errorf("only lock owner can unlock") + } + } + + return nil, nil } func handleDeleteMsg(app *app.PubSub, messageID string, _ []string) ([]byte, error) { @@ -199,50 +198,50 @@ func handleDeleteMsg(app *app.PubSub, messageID string, _ []string) ([]byte, err } func handleExtendMsg(app *app.PubSub, messageID string, params []string) ([]byte, error) { - if len(params) < 2 { - return nil, fmt.Errorf("missing parameters for extend message") - } - - timeoutSeconds, err := strconv.Atoi(params[0]) - if err != nil { - return nil, err - } - - extendingNodeID := params[1] - - app.Mutex.Lock() - defer app.Mutex.Unlock() - - if lockInfo, ok := app.MessageLocks.Load(messageID); ok { - info := lockInfo.(MessageLockInfo) - - // Only update if the request comes from the lock owner - if info.NodeID == extendingNodeID { - info.Timeout = time.Now().Add(time.Duration(timeoutSeconds) * time.Second) - app.MessageLocks.Store(messageID, info) - log.Printf("[Extend] Message %s timeout extended by node %s", messageID, extendingNodeID) - } else { - log.Printf("[Extend] Rejected extend from non-owner node %s for message: %s", extendingNodeID, messageID) - return nil, fmt.Errorf("only lock owner can extend timeout") - } - } - - return nil, nil + if len(params) < 2 { + return nil, fmt.Errorf("missing parameters for extend message") + } + + timeoutSeconds, err := strconv.Atoi(params[0]) + if err != nil { + return nil, err + } + + extendingNodeID := params[1] + + app.Mutex.Lock() + defer app.Mutex.Unlock() + + if lockInfo, ok := app.MessageLocks.Load(messageID); ok { + info := lockInfo.(MessageLockInfo) + + // Only update if the request comes from the lock owner + if info.NodeID == extendingNodeID { + info.Timeout = time.Now().Add(time.Duration(timeoutSeconds) * time.Second) + app.MessageLocks.Store(messageID, info) + log.Printf("[Extend] Message %s timeout extended by node %s", messageID, extendingNodeID) + } else { + log.Printf("[Extend] Rejected extend from non-owner node %s for message: %s", extendingNodeID, messageID) + return nil, fmt.Errorf("only lock owner can extend timeout") + } + } + + return nil, nil } func handleRetryMsg(app *app.PubSub, messageID string, params []string) ([]byte, error) { - if len(params) < 1 { - return nil, fmt.Errorf("invalid retry parameters") - } - - retryNodeID := params[0] - - // Make the message available again for processing - app.Mutex.Lock() - defer app.Mutex.Unlock() - - app.MessageLocks.Delete(messageID) - log.Printf("[Retry] Message %s is now available again (unlocked by node %s)", messageID, retryNodeID) - - return nil, nil -} \ No newline at end of file + if len(params) < 1 { + return nil, fmt.Errorf("invalid retry parameters") + } + + retryNodeID := params[0] + + // Make the message available again for 
processing + app.Mutex.Lock() + defer app.Mutex.Unlock() + + app.MessageLocks.Delete(messageID) + log.Printf("[Retry] Message %s is now available again (unlocked by node %s)", messageID, retryNodeID) + + return nil, nil +} diff --git a/broadcast/broadcast_struct.go b/broadcast/broadcast_struct.go index 59b6b50f..fb5fdcaa 100644 --- a/broadcast/broadcast_struct.go +++ b/broadcast/broadcast_struct.go @@ -87,7 +87,7 @@ func fetchAndBroadcast(ctx context.Context, op *hedge.Op, client *spanner.Client } broadcastMsg := BroadCastInput{ - Type: topicsub, + Type: Topicsub, Msg: msgData, } diff --git a/broadcast/unprocessed-broadcast.go b/broadcast/unprocessed-broadcast.go index c3b955fd..9ab2f901 100644 --- a/broadcast/unprocessed-broadcast.go +++ b/broadcast/unprocessed-broadcast.go @@ -72,7 +72,7 @@ func FetchAndBroadcastUnprocessedMessage(ctx context.Context, op *hedge.Op, span // Create broadcast input broadcastInput := BroadCastInput{ - Type: message, // Using const from same package + Type: Message, // Using const from same package Msg: data, } diff --git a/helpers.go b/helpers.go index bf457a3d..4fad945d 100644 --- a/helpers.go +++ b/helpers.go @@ -10,11 +10,11 @@ import ( "strings" "time" - pb "github.com/alphauslabs/pubsub-proto/v1" + // pb "github.com/alphauslabs/pubsub-proto/v1" // removed "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - // "github.com/alphauslabs/pubsub/app" + // "github.com/alphauslabs/pubsub/app" // removed "github.com/alphauslabs/pubsub/broadcast" ) @@ -45,7 +45,7 @@ func (s *server) broadcastLock(ctx context.Context, messageID string, subscriber lockInfo := broadcast.MessageLockInfo{ Timeout: time.Now().Add(timeout), Locked: true, - NodeID: "default-node", + NodeID: s.Op.HostPort(), SubscriberID: subscriberID, LockHolders: make(map[string]bool), } @@ -57,13 +57,13 @@ func (s *server) broadcastLock(ctx context.Context, messageID string, subscriber } // Store lock information - lockInfo.LockHolders[s.Op.ID()] = true - s.messageLocks.Store(messageID, lockInfo) + lockInfo.LockHolders[s.Op.HostPort()] = true + s.MessageLocks.Store(messageID, lockInfo) // Broadcast lock request broadcastData := broadcast.BroadCastInput{ - Type: broadcast.msgEvent, - Msg: []byte(fmt.Sprintf("lock:%s:%d:%s:%s", messageID, int(timeout.Seconds()), subscriberID, s.Op.ID())), + Type: broadcast.MsgEvent, + Msg: []byte(fmt.Sprintf("lock:%s:%d:%s:%s", messageID, int(timeout.Seconds()), subscriberID, s.Op.HostPort())), } bin, _ := json.Marshal(broadcastData) @@ -79,7 +79,7 @@ func (s *server) broadcastLock(ctx context.Context, messageID string, subscriber // Need majority for consensus if successCount < (len(out)/2 + 1) { - s.messageLocks.Delete(messageID) + s.MessageLocks.Delete(messageID) return fmt.Errorf("failed to acquire lock across majority of nodes") } @@ -96,9 +96,9 @@ func (s *server) broadcastLock(ctx context.Context, messageID string, subscriber } func (s *server) handleMessageTimeout(messageID string) { - if lockInfo, ok := s.messageLocks.Load(messageID); ok { + if lockInfo, ok := s.MessageLocks.Load(messageID); ok { info := lockInfo.(broadcast.MessageLockInfo) - if info.NodeID == s.Op.ID() && info.Locked && time.Now().After(info.Timeout) { + if info.NodeID == s.Op.HostPort() && info.Locked && time.Now().After(info.Timeout) { s.broadcastUnlock(context.Background(), messageID) } } @@ -109,49 +109,76 @@ func (s *server) broadcastUnlock(ctx context.Context, messageID string) { // Any node can broadcast an unlock broadcastData := broadcast.BroadCastInput{ - Type: 
broadcast.msgEvent, - Msg: []byte(fmt.Sprintf("unlock:%s:%s", messageID, s.Op.ID())), + Type: broadcast.MsgEvent, + Msg: []byte(fmt.Sprintf("unlock:%s:%s", messageID, s.Op.HostPort())), } bin, _ := json.Marshal(broadcastData) s.Op.Broadcast(ctx, bin) // Clean up local state - s.messageLocks.Delete(messageID) + s.MessageLocks.Delete(messageID) if timer, ok := s.timeoutTimers.Load(messageID); ok { timer.(*time.Timer).Stop() s.timeoutTimers.Delete(messageID) } - log.Printf("[Unlock] Node %s unlocked message: %s", s.Op.ID(), messageID) + log.Printf("[Unlock] Node %s unlocked message: %s", s.Op.HostPort(), messageID) } -// requestMessageFromLeader asks the leader node for messages -func (s *server) requestMessageFromLeader(topicID string) (*pb.Message, error) { - // Use the message type with proper format for requesting a message - broadcastData := broadcast.BroadCastInput{ - Type: broadcast.msgEvent, - Msg: []byte(fmt.Sprintf("getmessage:%s", topicID)), - } - - bin, _ := json.Marshal(broadcastData) - resp, err := s.Op.Request(context.Background(), bin) - if err != nil { - return nil, err - } - - if len(resp) == 0 { - return nil, status.Error(codes.NotFound, "no messages available") - } - - var message pb.Message - if err := json.Unmarshal(resp, &message); err != nil { - return nil, err - } - - return &message, nil -} +// REQUEST MESSAGE USING BROADCAST METHOD +// func (s *server) requestMessageFromLeader(topicID string) (*pb.Message, error) { +// broadcastData := broadcast.BroadCastInput{ +// Type: broadcast.MsgEvent, +// Msg: []byte(fmt.Sprintf("getmessage:%s", topicID)), +// } + +// bin, _ := json.Marshal(broadcastData) +// outputs := s.Op.Broadcast(context.Background(), bin) + +// // Process broadcast responses +// for _, output := range outputs { +// if output.Error != nil { +// continue +// } +// if len(output.Reply) > 0 { +// var message pb.Message +// if err := json.Unmarshal(output.Reply, &message); err != nil { +// continue +// } +// return &message, nil +// } +// } + +// return nil, status.Error(codes.NotFound, "no messages available") +// } + +//REQUEST MESSAGE USING REQUEST METHOD +// func (s *server) requestMessageFromLeader(topicID string) (*pb.Message, error) { +// // Use the message type with proper format for requesting a message +// broadcastData := broadcast.BroadCastInput{ +// Type: broadcast.MsgEvent, +// Msg: []byte(fmt.Sprintf("getmessage:%s", topicID)), +// } + +// bin, _ := json.Marshal(broadcastData) +// resp, err := s.Op.Request(context.Background(), bin) +// if err != nil { +// return nil, err +// } + +// if len(resp) == 0 { +// return nil, status.Error(codes.NotFound, "no messages available") +// } + +// var message pb.Message +// if err := json.Unmarshal(resp, &message); err != nil { +// return nil, err +// } + +// return &message, nil +// } // ExtendVisibilityTimeout extends the visibility timeout for a message func (s *server) ExtendVisibilityTimeout(messageID string, subscriberID string, visibilityTimeout time.Duration) error { - value, exists := s.messageLocks.Load(messageID) + value, exists := s.MessageLocks.Load(messageID) if !exists { return status.Error(codes.NotFound, "message not locked") } @@ -162,7 +189,7 @@ func (s *server) ExtendVisibilityTimeout(messageID string, subscriberID string, } // Check if this node owns the lock - if info.NodeID != s.Op.ID() { + if info.NodeID != s.Op.HostPort() { return status.Error(codes.PermissionDenied, "only the lock owner can extend timeout") } @@ -174,18 +201,18 @@ func (s *server) ExtendVisibilityTimeout(messageID 
string, subscriberID string, // Extend visibility timeout newExpiresAt := time.Now().Add(visibilityTimeout) info.Timeout = newExpiresAt - s.messageLocks.Store(messageID, info) + s.MessageLocks.Store(messageID, info) // Create broadcast message broadcastData := broadcast.BroadCastInput{ - Type: broadcast.msgEvent, - Msg: []byte(fmt.Sprintf("extend:%s:%d:%s", messageID, int(visibilityTimeout.Seconds()), s.Op.ID())), + Type: broadcast.MsgEvent, + Msg: []byte(fmt.Sprintf("extend:%s:%d:%s", messageID, int(visibilityTimeout.Seconds()), s.Op.HostPort())), } msgBytes, _ := json.Marshal(broadcastData) // Broadcast new timeout to all nodes s.Op.Broadcast(context.TODO(), msgBytes) - log.Printf("[ExtendTimeout] Node %s extended timeout for message: %s", s.Op.ID(), messageID) + log.Printf("[ExtendTimeout] Node %s extended timeout for message: %s", s.Op.HostPort(), messageID) return nil } @@ -209,18 +236,18 @@ func (s *server) HandleBroadcastMessage(msgType string, msgData []byte) error { } // Store the lock locally - lockInfo := MessageLockInfo{ + lockInfo := broadcast.MessageLockInfo{ Timeout: time.Now().Add(time.Duration(timeoutSeconds) * time.Second), Locked: true, - NodeID: s.Op.ID(), // This is the current node + NodeID: s.Op.HostPort(), // This is the current node SubscriberID: subscriberID, LockHolders: make(map[string]bool), } - s.messageLocks.Store(messageID, lockInfo) + s.MessageLocks.Store(messageID, lockInfo) case "unlock": messageID := string(msgData) - s.messageLocks.Delete(messageID) + s.MessageLocks.Delete(messageID) } diff --git a/server.go b/server.go index 02ca1908..aec9bcf7 100644 --- a/server.go +++ b/server.go @@ -1,12 +1,13 @@ -//server.go +// server.go package main import ( "context" "encoding/json" + "fmt" "log" + "sync" // added sync package "time" - "fmt" "cloud.google.com/go/spanner" pb "github.com/alphauslabs/pubsub-proto/v1" @@ -20,6 +21,8 @@ import ( type server struct { *app.PubSub pb.UnimplementedPubSubServiceServer + MessageLocks sync.Map // Stores message lock information + timeoutTimers sync.Map // Stores message timeout timers } // Constant for table name and message types @@ -29,10 +32,9 @@ const ( // topicsub = "topicsub" // Match the constants in broadcast.go ) - // Publish a message to a topic func (s *server) Publish(ctx context.Context, in *pb.PublishRequest) (*pb.PublishResponse, error) { - if in.Topic == "" { + if in.TopicId == "" { return nil, status.Error(codes.InvalidArgument, "topic must not be empty") } b, _ := json.Marshal(in) @@ -45,19 +47,18 @@ func (s *server) Publish(ctx context.Context, in *pb.PublishRequest) (*pb.Publis messageID := uuid.New().String() mutation := spanner.InsertOrUpdate( - MessagesTable, - []string{"id", "topic", "payload", "createdAt", "updatedAt", "visibilityTimeout", "processed"}, - []interface{}{ - messageID, - in.Topic, - in.Payload, - spanner.CommitTimestamp, - spanner.CommitTimestamp, - nil, // Explicitly set visibilityTimeout as NULL - false, // Default to unprocessed - }, -) - + MessagesTable, + []string{"id", "topic", "payload", "createdAt", "updatedAt", "visibilityTimeout", "processed"}, + []interface{}{ + messageID, + in.TopicId, + in.Payload, + spanner.CommitTimestamp, + spanner.CommitTimestamp, + nil, // Explicitly set visibilityTimeout as NULL + false, // Default to unprocessed + }, + ) _, err := s.Client.Apply(ctx, []*spanner.Mutation{mutation}) if err != nil { @@ -67,7 +68,7 @@ func (s *server) Publish(ctx context.Context, in *pb.PublishRequest) (*pb.Publis // broadcast message bcastin := 
broadcast.BroadCastInput{ - Type: broadcast.message, + Type: broadcast.Message, Msg: b, } bin, _ := json.Marshal(bcastin) @@ -82,150 +83,157 @@ func (s *server) Publish(ctx context.Context, in *pb.PublishRequest) (*pb.Publis } // Subscribe to receive messages for a subscription -// Subscribe to receive messages for a subscription + func (s *server) Subscribe(in *pb.SubscribeRequest, stream pb.PubSubService_SubscribeServer) error { - // Validate subscription using storage - subs, err := s.Storage.GetSubscribtionsForTopic(in.TopicId) - if err != nil { - return status.Errorf(codes.NotFound, "Topic %s not found", in.TopicId) - } - - found := false - for _, sub := range subs { - if sub == in.SubscriptionId { - found = true - break - } - } - - if !found { - return status.Errorf(codes.NotFound, "Subscription %s not found", in.SubscriptionId) - } - - log.Printf("[Subscribe] Starting subscription stream for ID: %s", in.SubscriptionId) - - for { - select { - case <-stream.Context().Done(): - return nil - default: - // Get messages from storage - messages, err := s.Storage.GetMessagesByTopic(in.TopicId) - if err != nil { - log.Printf("[Subscribe] Error getting messages: %v", err) - time.Sleep(time.Second) - continue - } - - if len(messages) == 0 { - time.Sleep(time.Second) - continue - } - - for _, message := range messages { - // Check if message is locked - if _, exists := s.messageLocks.Load(message.Id); exists { - continue - } - - // Try to acquire lock - if err := s.broadcastLock(stream.Context(), message.Id, in.SubscriptionId, 30*time.Second); err != nil { - continue - } - - // Send message - if err := stream.Send(message); err != nil { - s.broadcastUnlock(stream.Context(), message.Id) - return err - } - } - } - } + // Validate if subscription exists for the given topic + subs, err := s.Storage.GetSubscribtionsForTopic(in.TopicId) + if err != nil { + return status.Errorf(codes.NotFound, "Topic %s not found", in.TopicId) + } + + // Check if the provided subscription ID exists in the topic's subscriptions + found := false + for _, sub := range subs { + if sub == in.SubscriptionId { + found = true + break + } + } + + if !found { + return status.Errorf(codes.NotFound, "Subscription %s not found", in.SubscriptionId) + } + + log.Printf("[Subscribe] Starting subscription stream for ID: %s", in.SubscriptionId) + + // Continuous loop to stream messages + for { + select { + // Check if client has disconnected + case <-stream.Context().Done(): + return nil + default: + // Get messages from local storage for the topic + messages, err := s.Storage.GetMessagesByTopic(in.TopicId) + if err != nil { + log.Printf("[Subscribe] Error getting messages: %v", err) + time.Sleep(time.Second) // Back off on error + continue + } + + // If no messages, wait before checking again + if len(messages) == 0 { + time.Sleep(time.Second) + continue + } + + // Process each message + for _, message := range messages { + // Skip if message is already locked by another subscriber + if _, exists := s.MessageLocks.Load(message.Id); exists { + continue + } + + // Attempt to acquire distributed lock for the message + // Default visibility timeout of 30 seconds + if err := s.broadcastLock(stream.Context(), message.Id, in.SubscriptionId, 30*time.Second); err != nil { + continue // Skip if unable to acquire lock + } + + // Stream message to subscriber + if err := stream.Send(message); err != nil { + // Release lock if sending fails + s.broadcastUnlock(stream.Context(), message.Id) + return err // Return error to close stream + } + } + } + } 
} // Acknowledge a processed message func (s *server) Acknowledge(ctx context.Context, in *pb.AcknowledgeRequest) (*pb.AcknowledgeResponse, error) { - // Verify lock exists and is valid - lockInfo, ok := s.messageLocks.Load(in.Id) - if !ok { - return nil, status.Error(codes.NotFound, "message lock not found") - } - - info := lockInfo.(broadcast.MessageLockInfo) - if !info.Locked || time.Now().After(info.Timeout) { - return nil, status.Error(codes.FailedPrecondition, "message lock expired") - } - - // Get message from storage - msg, err := s.Storage.GetMessage(in.Id) - if err != nil { - return nil, status.Error(codes.NotFound, "message not found") - } - - // Mark message as processed in storage - msg.Processed = true - if err := s.Storage.StoreMessage(msg); err != nil { - return nil, status.Error(codes.Internal, "failed to update message") - } - - // Broadcast acknowledgment - broadcastData := broadcast.BroadCastInput{ - Type: broadcast.msgEvent, - Msg: []byte(fmt.Sprintf("delete:%s", in.Id)), - } - - bin, _ := json.Marshal(broadcastData) - s.Op.Broadcast(ctx, bin) - - // Clean up local state - s.messageLocks.Delete(in.Id) - if timer, ok := s.timeoutTimers.Load(in.Id); ok { - timer.(*time.Timer).Stop() - s.timeoutTimers.Delete(in.Id) - } - - return &pb.AcknowledgeResponse{Success: true}, nil + // Check if message lock exists and is still valid (within 1 minute) + lockInfo, ok := s.MessageLocks.Load(in.Id) + if !ok { + return nil, status.Error(codes.NotFound, "message lock not found") + } + + info := lockInfo.(broadcast.MessageLockInfo) + // Check if lock is valid and not timed out + if !info.Locked || time.Now().After(info.Timeout) { + // Message already timed out - handled by handleMessageTimeout + return nil, status.Error(codes.FailedPrecondition, "message lock expired") + } + + // Get message processed in time + msg, err := s.Storage.GetMessage(in.Id) + if err != nil { + return nil, status.Error(codes.NotFound, "message not found") + } + // Mark as processed since subscriber acknowledged in time + msg.Processed = true + if err := s.Storage.StoreMessage(msg); err != nil { + return nil, status.Error(codes.Internal, "failed to update message") + } + + // Broadcast successful processing + broadcastData := broadcast.BroadCastInput{ + Type: broadcast.MsgEvent, + Msg: []byte(fmt.Sprintf("delete:%s", in.Id)), + } + bin, _ := json.Marshal(broadcastData) + s.Op.Broadcast(ctx, bin) + + // Clean up message (processed) + s.MessageLocks.Delete(in.Id) + if timer, ok := s.timeoutTimers.Load(in.Id); ok { + timer.(*time.Timer).Stop() + s.timeoutTimers.Delete(in.Id) + } + + return &pb.AcknowledgeResponse{Success: true}, nil } // ModifyVisibilityTimeout extends message lock timeout func (s *server) ModifyVisibilityTimeout(ctx context.Context, in *pb.ModifyVisibilityTimeoutRequest) (*pb.ModifyVisibilityTimeoutResponse, error) { - lockInfo, ok := s.messageLocks.Load(in.Id) - if !ok { - return nil, status.Error(codes.NotFound, "message lock not found") - } - info := lockInfo.(broadcast.MessageLockInfo) - if !info.Locked { - return nil, status.Error(codes.FailedPrecondition, "message not locked") - } - - // Check if this node owns the lock before extending - if info.NodeID != s.Op.ID() { - return nil, status.Error(codes.PermissionDenied, "only the lock owner can extend timeout") - } - - // Broadcast new timeout - broadcastData := broadcast.BroadCastInput{ - Type: broadcast.msgEvent, - Msg: []byte(fmt.Sprintf("extend:%s:%d:%s", in.Id, in.NewTimeout, s.Op.ID())), - } - bin, _ := json.Marshal(broadcastData) - 
s.Op.Broadcast(ctx, bin) - - // Update local timer - if timer, ok := s.timeoutTimers.Load(in.Id); ok { - timer.(*time.Timer).Stop() - } - newTimer := time.NewTimer(time.Duration(in.NewTimeout) * time.Second) - s.timeoutTimers.Store(in.Id, newTimer) - - // Update lock info - info.Timeout = time.Now().Add(time.Duration(in.NewTimeout) * time.Second) - s.messageLocks.Store(in.Id, info) - - go func() { - <-newTimer.C - s.handleMessageTimeout(in.Id) - }() - - return &pb.ModifyVisibilityTimeoutResponse{Success: true}, nil + lockInfo, ok := s.MessageLocks.Load(in.Id) + if !ok { + return nil, status.Error(codes.NotFound, "message lock not found") + } + info := lockInfo.(broadcast.MessageLockInfo) + if !info.Locked { + return nil, status.Error(codes.FailedPrecondition, "message not locked") + } + + // Check if this node owns the lock before extending + if info.NodeID != s.Op.HostPort() { + return nil, status.Error(codes.PermissionDenied, "only the lock owner can extend timeout") + } + + // Broadcast new timeout + broadcastData := broadcast.BroadCastInput{ + Type: broadcast.MsgEvent, + Msg: []byte(fmt.Sprintf("extend:%s:%d:%s", in.Id, in.NewTimeout, s.Op.HostPort())), + } + bin, _ := json.Marshal(broadcastData) + s.Op.Broadcast(ctx, bin) + + // Update local timer + if timer, ok := s.timeoutTimers.Load(in.Id); ok { + timer.(*time.Timer).Stop() + } + newTimer := time.NewTimer(time.Duration(in.NewTimeout) * time.Second) + s.timeoutTimers.Store(in.Id, newTimer) + + // Update lock info + info.Timeout = time.Now().Add(time.Duration(in.NewTimeout) * time.Second) + s.MessageLocks.Store(in.Id, info) + + go func() { + <-newTimer.C + s.handleMessageTimeout(in.Id) + }() + + return &pb.ModifyVisibilityTimeoutResponse{Success: true}, nil } diff --git a/testclient/testclient.go b/testclient/testclient.go index cad6f5e8..682c7cf7 100644 --- a/testclient/testclient.go +++ b/testclient/testclient.go @@ -31,7 +31,7 @@ func main() { switch *method { case "publish": - r, err := c.Publish(ctx, &pb.PublishRequest{Topic: "topic1", Payload: "Hello World"}) + r, err := c.Publish(ctx, &pb.PublishRequest{TopicId: "topic1", Payload: "Hello World"}) if err != nil { log.Fatalf("Publish failed: %v", err) } From 33129d298844204ad841395a1c582dac05b7b8ba Mon Sep 17 00:00:00 2001 From: Kishea Kate Andoy <142709948+kitkatchoco2002@users.noreply.github.com> Date: Wed, 26 Feb 2025 20:08:28 +0800 Subject: [PATCH 18/21] reverted the original main.go --- main.go | 117 +++++++++++++------------------------------------------- 1 file changed, 27 insertions(+), 90 deletions(-) diff --git a/main.go b/main.go index 8bcb9803..1639594c 100644 --- a/main.go +++ b/main.go @@ -3,18 +3,22 @@ package main import ( "context" "flag" + "fmt" "log" "net" "os" "os/signal" "syscall" + "time" "cloud.google.com/go/spanner" pb "github.com/alphauslabs/pubsub-proto/v1" "github.com/alphauslabs/pubsub/app" "github.com/alphauslabs/pubsub/broadcast" + "github.com/alphauslabs/pubsub/send" "github.com/alphauslabs/pubsub/storage" - "github.com/flowerinthenight/hedge/v2" + "github.com/alphauslabs/pubsub/utils" + "github.com/flowerinthenight/hedge" "google.golang.org/grpc" "google.golang.org/grpc/reflection" ) @@ -23,6 +27,7 @@ var port = flag.String("port", ":50051", "Main gRPC server port") func main() { flag.Parse() + log.SetOutput(os.Stderr) go serveHealthChecks() // handle health checks from our LB spannerClient, err := spanner.NewClient(context.Background(), "projects/labs-169405/instances/alphaus-dev/databases/main") @@ -31,28 +36,13 @@ func main() { return } - // 
-	// Initialize storage with message tracking capabilities
-	storageInstance := storage.NewStorage()
-
-
-	// Initialize app with all necessary components for distributed message handling
-	app := &app.PubSub{
-		Client:        spannerClient,
-		Storage:       storageInstance,
-		MessageLocks:  sync.Map{},   // For distributed message locking
-		MessageQueue:  sync.Map{},   // For message tracking
-		Mutex:         sync.Mutex{}, // For concurrency control
-		TimeoutTimers: sync.Map{},   // Add this for tracking message timeouts across nodes
-	}
+	app := &app.PubSub{
+		Client:  spannerClient,
+		Storage: storage.NewStorage(),
+	}

 	log.Println("[STORAGE]: Storage initialized")

-	// Configure timeout settings
-	const (
-		defaultVisibilityTimeout = 30 * time.Second
-		timeoutCheckInterval     = 5 * time.Second
-	)
-
 	op := hedge.New(
 		spannerClient,
 		":50052", // addr will be resolved internally
@@ -61,7 +51,7 @@ func main() {
 		"logtable",
 		hedge.WithLeaderHandler( // if leader only, handles Send()
 			app,
-			send,
+			send.Send,
 		),
 		hedge.WithBroadcastHandler( // handles Broadcast()
 			app,
@@ -70,16 +60,7 @@ func main() {
 	)

 	app.Op = op
-	app.NodeID = op.ID() // Important for tracking which node handles which message
 	ctx, cancel := context.WithCancel(context.Background())
-
-	// Start timeout monitor for all nodes
-	go monitorMessageTimeouts(ctx, app, defaultVisibilityTimeout)
-
-	// Start subscription validator
-	go validateSubscriptions(ctx, app)
-
 	go func() {
 		if err := run(ctx, &server{PubSub: app}); err != nil {
 			log.Fatalf("failed to run: %v", err)
@@ -89,6 +70,22 @@ func main() {
 	done := make(chan error, 1) // optional wait
 	go op.Run(ctx, done)

+	// Wait for leader availability
+	func() {
+		var m string
+		defer func(l *string, t time.Time) {
+			log.Printf("%v: %v", *l, time.Since(t))
+		}(&m, time.Now())
+		log.Println("Waiting for leader to be active...")
+		ok, err := utils.EnsureLeaderActive(op, ctx)
+		switch {
+		case !ok:
+			m = fmt.Sprintf("failed: %v, no leader after ", err)
+		default:
+			m = "leader active after "
+		}
+	}()
+
 	// Start our fetching and broadcast routine for topic-subscription structure.
 	go broadcast.StartDistributor(ctx, op, spannerClient)
 	// Start our fetching and broadcast routine for unprocessed messages.
@@ -137,63 +134,3 @@ func serveHealthChecks() {
 		conn.Close()
 	}
 }
-
-// Monitor timeouts across all nodes
-func monitorMessageTimeouts(ctx context.Context, app *app.PubSub, defaultTimeout time.Duration) {
-	ticker := time.NewTicker(time.Second)
-	defer ticker.Stop()
-
-	for {
-		select {
-		case <-ctx.Done():
-			return
-		case <-ticker.C:
-			app.MessageLocks.Range(func(key, value interface{}) bool {
-				messageID := key.(string)
-				lockInfo := value.(broadcast.MessageLockInfo)
-
-				// Check if message lock has expired
-				if time.Now().After(lockInfo.Timeout) {
-					// Broadcast unlock message to all nodes
-					broadcastData := broadcast.BroadCastInput{
-						Type: broadcast.msgEvent,
-						Msg:  []byte(fmt.Sprintf("unlock:%s:%s", messageID, app.NodeID)),
-					}
-					bin, _ := json.Marshal(broadcastData)
-					app.Op.Broadcast(ctx, bin)
-
-					// Clean up local state
-					app.MessageLocks.Delete(messageID)
-					if timer, ok := app.TimeoutTimers.Load(messageID); ok {
-						timer.(*time.Timer).Stop()
-						app.TimeoutTimers.Delete(messageID)
-					}
-				}
-				return true
-			})
-		}
-	}
-}
-
-// Validate subscriptions periodically
-func validateSubscriptions(ctx context.Context, app *app.PubSub) {
-	ticker := time.NewTicker(30 * time.Second)
-	defer ticker.Stop()
-
-	for {
-		select {
-		case <-ctx.Done():
-			return
-		case <-ticker.C:
-			// Request latest subscription data from leader
-			broadcastData := broadcast.BroadCastInput{
-				Type: broadcast.topicsub,
-				Msg:  []byte("refresh"),
-			}
-			bin, _ := json.Marshal(broadcastData)
-			if _, err := app.Op.Request(ctx, bin); err != nil {
-				log.Printf("Error refreshing subscriptions: %v", err)
-			}
-		}
-	}
-}
\ No newline at end of file
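The `utils.EnsureLeaderActive` helper introduced above is not included in this patch series, so its implementation is unknown. One plausible shape, sketched here only for illustration, is a bounded retry loop that probes the leader with `op.Request` (the same call the old `validateSubscriptions` used to reach the leader) until a reply comes back; the probe payload and retry budget are assumptions, and the registered leader handler (`send.Send`) would have to tolerate such a probe.

```go
// ensureLeaderActive is an illustrative stand-in for utils.EnsureLeaderActive:
// it retries a small request to the current hedge leader until one succeeds,
// the context is cancelled, or the attempt budget is exhausted.
// Assumed imports: context, time, github.com/flowerinthenight/hedge.
func ensureLeaderActive(op *hedge.Op, ctx context.Context) (bool, error) {
	var lastErr error
	for i := 0; i < 30; i++ {
		select {
		case <-ctx.Done():
			return false, ctx.Err()
		case <-time.After(2 * time.Second):
		}
		if _, err := op.Request(ctx, []byte("ping")); err != nil {
			lastErr = err // leader not reachable yet; keep retrying
			continue
		}
		return true, nil
	}
	return false, lastErr
}
```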
From 90fd323f141f3174b1f2a0f16555f0d1d760a7af Mon Sep 17 00:00:00 2001
From: tituscarl
Date: Wed, 26 Feb 2025 20:38:42 +0800
Subject: [PATCH 19/21] fix conflicts

---
 broadcast/broadcast_struct.go | 164 ----------------------------------
 broadcast/broadcaststruct.go  |   2 +-
 2 files changed, 1 insertion(+), 165 deletions(-)
 delete mode 100644 broadcast/broadcast_struct.go

diff --git a/broadcast/broadcast_struct.go b/broadcast/broadcast_struct.go
deleted file mode 100644
index fb5fdcaa..00000000
--- a/broadcast/broadcast_struct.go
+++ /dev/null
@@ -1,164 +0,0 @@
-package broadcast
-
-import (
-	"context"
-	"encoding/json"
-	"log"
-	"time"
-
-	"cloud.google.com/go/spanner"
-	"github.com/flowerinthenight/hedge/v2"
-	"google.golang.org/api/iterator"
-)
-
-// fetchAndBroadcast fetches updated topic-subscription data and broadcasts it if there are updates.
-func fetchAndBroadcast(ctx context.Context, op *hedge.Op, client *spanner.Client, lastChecked *time.Time, lastBroadcasted *map[string][]string) {
-	stmt := spanner.Statement{
-		SQL: `SELECT topic, ARRAY_AGG(name) AS subscriptions
-			FROM Subscriptions
-			WHERE updatedAt > @last_checked_time
-			GROUP BY topic`,
-		Params: map[string]interface{}{"last_checked_time": *lastChecked},
-	}
-
-	iter := client.Single().Query(ctx, stmt)
-	defer iter.Stop()
-
-	topicSub := make(map[string][]string)
-	for {
-		row, err := iter.Next()
-		if err == iterator.Done {
-			break
-		}
-		if err != nil {
-			log.Printf("Error iterating rows: %v", err)
-			return
-		}
-
-		var topic string
-		var subscriptions []string
-		if err := row.Columns(&topic, &subscriptions); err != nil {
-			log.Printf("Error reading row: %v", err)
-			continue
-		}
-
-		// Ensure subscriptions is not nil
-		if subscriptions == nil {
-			subscriptions = []string{}
-		}
-
-		topicSub[topic] = subscriptions
-	}
-
-	// if there are no new updates, it will log
-	if len(topicSub) == 0 {
-		log.Println("Leader: No new updates, skipping broadcast.")
-		if len(*lastBroadcasted) > 0 {
-			log.Println("Leader: Subscription topic structure is still:", *lastBroadcasted)
-		} else {
-			log.Println("Leader: No previous topic-subscription structure available.")
-		}
-		return
-	}
-
-	// compare topicSub with lastBroadcasted to check if they are exactly the same
-	// ex: subscription was updated but reverted back to its original state before the next check
-	same := true
-	for key, subs := range topicSub {
-		if lastSubs, exists := (*lastBroadcasted)[key]; !exists || !equalStringSlices(subs, lastSubs) {
-			same = false
-			break
-		}
-	}
-
-	if same {
-		log.Println("Leader: No new updates, skipping broadcast.")
-		log.Println("Leader: Subscription topic structure is still:", *lastBroadcasted)
-		return
-	}
-
-	log.Println("Leader: Fetched topic subscriptions:", topicSub)
-
-	// marshal topic-subscription data
-	msgData, err := json.Marshal(topicSub)
-	if err != nil {
-		log.Printf("Error marshalling topicSub: %v", err)
-		return
-	}
-
-	broadcastMsg := BroadCastInput{
-		Type: Topicsub,
-		Msg:  msgData,
-	}
-
-	// marshal BroadCastInput
-	broadcastData, err := json.Marshal(broadcastMsg)
-	if err != nil {
-		log.Printf("Error marshalling BroadCastInput: %v", err)
-		return
-	}
-
-	// broadcast message
-	for _, r := range op.Broadcast(ctx, broadcastData) {
-		if r.Error != nil {
-			log.Printf("Error broadcasting to %s: %v", r.Id, r.Error)
-		}
-	}
-
-	// update last checked time and last broadcasted structure
-	*lastChecked = time.Now()
-	*lastBroadcasted = topicSub
-	log.Println("Leader: Topic-subscription structure broadcast completed.")
-}
-
-// StartDistributor initializes the distributor that periodically checks for updates.
-func StartDistributor(ctx context.Context, op *hedge.Op, client *spanner.Client) {
-	lastChecked := time.Now().Add(-10 * time.Second)
-	lastBroadcasted := make(map[string][]string) // this will store the last known structure
-	ticker := time.NewTicker(10 * time.Second)
-	defer ticker.Stop()
-
-	for {
-		select {
-		case <-ctx.Done():
-			return
-		case <-ticker.C:
-			if hasLock, _ := op.HasLock(); hasLock {
-				log.Println("Leader: Processing updates...")
-				fetchAndBroadcast(ctx, op, client, &lastChecked, &lastBroadcasted)
-			} else {
-				log.Println("Follower: No action needed.")
-			}
-		}
-	}
-}
-
-// used to compare two string slices
-func equalStringSlices(a, b []string) bool {
-	if len(a) != len(b) {
-		return false
-	}
-	exists := make(map[string]bool)
-	for _, val := range a {
-		exists[val] = true
-	}
-	for _, val := range b {
-		if !exists[val] {
-			return false
-		}
-	}
-	return true
-}
-
-// Leader broadcasts topic-subscription to all nodes (even if no changes/updates happened)
-/*
-func broadcastTopicSubStruct(op *hedge.Op, topicSub map[string][]string) {
-	data, err := json.Marshal(topicSub)
-	if err != nil {
-		log.Printf("Error marshalling topic-subscription: %v", err)
-		return
-	}
-	op.Broadcast(context.Background(), data)
-	log.Println("Leader: Broadcasted topic-subscription structure to all nodes")
-}
-*/
diff --git a/broadcast/broadcaststruct.go b/broadcast/broadcaststruct.go
index 4e2a7787..c886b4a2 100644
--- a/broadcast/broadcaststruct.go
+++ b/broadcast/broadcaststruct.go
@@ -110,7 +110,7 @@ func fetchAndBroadcast(ctx context.Context, op *hedge.Op, client *spanner.Client
 	}

 	broadcastMsg := BroadCastInput{
-		Type: topicsub,
+		Type: Topicsub,
 		Msg:  msgData,
 	}
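The payload the leader broadcasts above is just the JSON-encoded `map[string][]string` wrapped in `BroadCastInput{Type: Topicsub, ...}`. For context, the receiving side of such a broadcast would only need to reverse that encoding; the sketch below is illustrative (the real follower handler lives in broadcast.go and may differ, and `handleTopicSub` is a made-up name).

```go
// handleTopicSub decodes a Topicsub broadcast back into the
// topic -> subscriptions map that the leader published.
// Assumed imports: encoding/json, fmt.
func handleTopicSub(msg []byte) (map[string][]string, error) {
	topicSub := make(map[string][]string)
	if err := json.Unmarshal(msg, &topicSub); err != nil {
		return nil, fmt.Errorf("invalid topic-subscription payload: %w", err)
	}
	return topicSub, nil
}
```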
From c1151e24366963f728208dd0f1c5d8bbc275b063 Mon Sep 17 00:00:00 2001
From: tituscarl
Date: Thu, 27 Feb 2025 08:43:52 +0800
Subject: [PATCH 20/21] add comments

---
 app/app.go             |  6 +++---
 broadcast/broadcast.go |  4 ++--
 helpers.go             | 14 ++++++++++----
 server.go              | 17 +++++------------
 storage/storage.go     |  6 ++----
 5 files changed, 22 insertions(+), 25 deletions(-)

diff --git a/app/app.go b/app/app.go
index d843b1d3..2200f105 100644
--- a/app/app.go
+++ b/app/app.go
@@ -13,7 +13,7 @@ type PubSub struct {
 	Client       *spanner.Client
 	Storage      *storage.Storage
 	NodeID       string
-	MessageLocks sync.Map
-	MessageQueue sync.Map
-	Mutex        sync.Mutex
+	MessageLocks sync.Map   // messageID -> MessageLockInfo
+	MessageTimer sync.Map   // messageID -> *time.Timer
+	Mutex        sync.Mutex // app level mutex
 }

diff --git a/broadcast/broadcast.go b/broadcast/broadcast.go
index 2e88bdf5..dbf9a34e 100644
--- a/broadcast/broadcast.go
+++ b/broadcast/broadcast.go
@@ -122,7 +122,7 @@ func handleLockMsg(app *app.PubSub, messageID string, params []string) ([]byte,
 	subscriberID := params[1]
 	requestingNodeID := params[2]

-	app.Mutex.Lock()
+	app.Mutex.Lock() // todo: mutex lock and unlock, might remove this if no need
 	defer app.Mutex.Unlock()

 	// Check if already locked
@@ -193,7 +193,7 @@ func handleDeleteMsg(app *app.PubSub, messageID string, _ []string) ([]byte, err
 	defer app.Mutex.Unlock()

 	app.MessageLocks.Delete(messageID)
-	app.MessageQueue.Delete(messageID)
+	app.MessageTimer.Delete(messageID)

 	return nil, nil
 }

diff --git a/helpers.go b/helpers.go
index 4fad945d..4a021d5b 100644
--- a/helpers.go
+++ b/helpers.go
@@ -15,9 +15,12 @@ import (
 	"google.golang.org/grpc/status"

 	// "github.com/alphauslabs/pubsub/app" // removed
+	"github.com/alphauslabs/pubsub/broadcast"
 )

+// PubSubWrapper wraps app.PubSub to allow defining methods on it
+
 // validateTopicSubscription checks if subscription exists in storage
 func (s *server) validateTopicSubscription(subscriptionID string) error {
 	subs, err := s.Storage.GetSubscribtionsForTopic(subscriptionID)
@@ -78,6 +81,7 @@ func (s *server) broadcastLock(ctx context.Context, messageID string, subscriber
 	}

 	// Need majority for consensus
+	// todo: Nice idea, but what if we have to be strict, like all nodes (instead of majority) must acknowledge the lock?
 	if successCount < (len(out)/2 + 1) {
 		s.MessageLocks.Delete(messageID)
 		return fmt.Errorf("failed to acquire lock across majority of nodes")
@@ -85,7 +89,7 @@ func (s *server) broadcastLock(ctx context.Context, messageID string, subscriber

 	// Start timeout timer
 	timer := time.NewTimer(timeout)
-	s.timeoutTimers.Store(messageID, timer)
+	s.MessageTimer.Store(messageID, timer)

 	go func() {
 		<-timer.C
@@ -95,6 +99,8 @@ func (s *server) broadcastLock(ctx context.Context, messageID string, subscriber
 	return nil
 }

+// todo: not sure if we only allow the node that locked the message to unlock it, what if every node will just unlock it by themselves without broadcasting.
+// todo: if the locker node will crash, no one will broadcast to unlock?
 func (s *server) handleMessageTimeout(messageID string) {
 	if lockInfo, ok := s.MessageLocks.Load(messageID); ok {
 		info := lockInfo.(broadcast.MessageLockInfo)
@@ -116,9 +122,9 @@ func (s *server) broadcastUnlock(ctx context.Context, messageID string) {
 	s.Op.Broadcast(ctx, bin)
 	// Clean up local state
 	s.MessageLocks.Delete(messageID)
-	if timer, ok := s.timeoutTimers.Load(messageID); ok {
+	if timer, ok := s.MessageTimer.Load(messageID); ok {
 		timer.(*time.Timer).Stop()
-		s.timeoutTimers.Delete(messageID)
+		s.MessageTimer.Delete(messageID)
 	}
 	log.Printf("[Unlock] Node %s unlocked message: %s", s.Op.HostPort(), messageID)
 }
@@ -194,7 +200,7 @@ func (s *server) ExtendVisibilityTimeout(messageID string, subscriberID string,
 	}

 	// Check subscriber ID
-	if info.SubscriberID != subscriberID {
+	if info.SubscriberID != subscriberID { // todo: what does this mean?
 		return status.Error(codes.PermissionDenied, "message locked by another subscriber")
 	}

diff --git a/server.go b/server.go
index 171dbc25..05c6a5c8 100644
--- a/server.go
+++ b/server.go
@@ -6,7 +6,6 @@ import (
 	"encoding/json"
 	"fmt"
 	"log"
-	"sync" // added sync package
 	"time"

 	"cloud.google.com/go/spanner"
@@ -21,15 +20,10 @@ import (

 type server struct {
 	*app.PubSub
 	pb.UnimplementedPubSubServiceServer
-	MessageLocks  sync.Map // Stores message lock information
-	timeoutTimers sync.Map // Stores message timeout timers
 }

-// Constant for table name and message types
 const (
 	MessagesTable = "Messages"
-	// message = "message" // Match the constants in broadcast.go
-	// topicsub = "topicsub" // Match the constants in broadcast.go
 )

 // Publish a message to a topic
@@ -77,7 +71,6 @@ func (s *server) Publish(ctx context.Context, in *pb.PublishRequest) (*pb.Publis
 }

 // Subscribe to receive messages for a subscription
-
 func (s *server) Subscribe(in *pb.SubscribeRequest, stream pb.PubSubService_SubscribeServer) error {
 	// Validate if subscription exists for the given topic
 	subs, err := s.Storage.GetSubscribtionsForTopic(in.TopicId)
@@ -117,7 +110,7 @@ func (s *server) Subscribe(in *pb.SubscribeRequest, stream pb.PubSubService_Subs
 		// If no messages, wait before checking again
 		if len(messages) == 0 {
-			time.Sleep(time.Second)
+			time.Sleep(time.Second) // todo: not sure if this is the best way
 			continue
 		}

@@ -181,9 +174,9 @@ func (s *server) Acknowledge(ctx context.Context, in *pb.AcknowledgeRequest) (*p

 	// Clean up message (processed)
 	s.MessageLocks.Delete(in.Id)
-	if timer, ok := s.timeoutTimers.Load(in.Id); ok {
+	if timer, ok := s.MessageTimer.Load(in.Id); ok {
 		timer.(*time.Timer).Stop()
-		s.timeoutTimers.Delete(in.Id)
+		s.MessageTimer.Delete(in.Id)
 	}

 	return &pb.AcknowledgeResponse{Success: true}, nil
@@ -214,11 +207,11 @@ func (s *server) ModifyVisibilityTimeout(ctx context.Context, in *pb.ModifyVisib
 	s.Op.Broadcast(ctx, bin)

 	// Update local timer
-	if timer, ok := s.timeoutTimers.Load(in.Id); ok {
+	if timer, ok := s.MessageTimer.Load(in.Id); ok {
 		timer.(*time.Timer).Stop()
 	}
 	newTimer := time.NewTimer(time.Duration(in.NewTimeout) * time.Second)
-	s.timeoutTimers.Store(in.Id, newTimer)
+	s.MessageTimer.Store(in.Id, newTimer)

 	// Update lock info
 	info.Timeout = time.Now().Add(time.Duration(in.NewTimeout) * time.Second)

diff --git a/storage/storage.go b/storage/storage.go
index fe0d4241..1f56dadb 100644
--- a/storage/storage.go
+++ b/storage/storage.go
@@ -133,8 +133,6 @@ func (s *Storage) GetSubscribtionsForTopic(topicID string) ([]string, error) {
 	if !exists {
 		return nil, ErrTopicNotFound
 	}
-
-	result := make([]string, len(subs))
-	copy(result, subs)
-	return result, nil
+	// todo: check if subs is empty, meaning no subscriptions for that topic, and also check if a subscription is attached to a topic
+	return subs, nil
 }

From 815bf952a889c02c13a7668c30edfe3b8cdc5fda Mon Sep 17 00:00:00 2001
From: tituscarl
Date: Thu, 27 Feb 2025 08:45:38 +0800
Subject: [PATCH 21/21] remove comments

---
 helpers.go | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/helpers.go b/helpers.go
index 4a021d5b..012a964e 100644
--- a/helpers.go
+++ b/helpers.go
@@ -10,17 +10,12 @@ import (
 	"strings"
 	"time"

-	// pb "github.com/alphauslabs/pubsub-proto/v1" // removed
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/status"

-	// "github.com/alphauslabs/pubsub/app" // removed
-
 	"github.com/alphauslabs/pubsub/broadcast"
 )

-// PubSubWrapper wraps app.PubSub to allow defining methods on it
-
 // validateTopicSubscription checks if subscription exists in storage
 func (s *server) validateTopicSubscription(subscriptionID string) error {
 	subs, err := s.Storage.GetSubscribtionsForTopic(subscriptionID)
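The todo left in `GetSubscribtionsForTopic` by patch 20 (returning the internal slice directly and not distinguishing "topic has no subscriptions") could be addressed without touching Storage internals. The sketch below is not part of the patch series; `subscriptionsOrError` is an illustrative helper that wraps the existing accessor, surfaces the empty case, and restores the defensive copy the pre-patch code made.

```go
// subscriptionsOrError wraps Storage.GetSubscribtionsForTopic: it reports an
// explicit error when a topic exists but has no subscriptions, and returns a
// copy so callers cannot mutate the cached slice.
// Assumed imports: fmt, github.com/alphauslabs/pubsub/storage.
func subscriptionsOrError(st *storage.Storage, topicID string) ([]string, error) {
	subs, err := st.GetSubscribtionsForTopic(topicID)
	if err != nil {
		return nil, err
	}
	if len(subs) == 0 {
		return nil, fmt.Errorf("topic %q has no subscriptions", topicID)
	}
	out := make([]string, len(subs)) // defensive copy, as the pre-patch code did
	copy(out, subs)
	return out, nil
}
```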