diff --git a/client.go b/client.go index f333652..67fc55c 100644 --- a/client.go +++ b/client.go @@ -6,9 +6,14 @@ package asyncjobs import ( "context" + "crypto/ed25519" + "crypto/rand" + "encoding/hex" "errors" "fmt" + "io" "net/http" + "os" "time" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -41,6 +46,11 @@ func NewClient(opts ...ClientOpt) (*Client, error) { } } + err = copts.validate() + if err != nil { + return nil, err + } + c := &Client{opts: copts, log: copts.logger} c.storage, err = newJetStreamStorage(copts.nc, copts.retryPolicy, c.log) if err != nil { @@ -85,7 +95,17 @@ func (c *Client) Run(ctx context.Context, router *Mux) error { // LoadTaskByID loads a task from the backend using its ID func (c *Client) LoadTaskByID(id string) (*Task, error) { - return c.storage.LoadTaskByID(id) + task, err := c.storage.LoadTaskByID(id) + if err != nil { + return nil, err + } + + err = c.verifyTaskSignature(task) + if err != nil { + return nil, err + } + + return task, nil } // RetryTaskByID will retry a task, first removing an entry from the Work Queue if already there @@ -95,9 +115,130 @@ func (c *Client) RetryTaskByID(ctx context.Context, id string) error { // EnqueueTask adds a task to the named queue which must already exist func (c *Client) EnqueueTask(ctx context.Context, task *Task) error { + task.Queue = c.opts.queue.Name + + err := c.signTask(task) + if err != nil { + return err + } + return c.opts.queue.enqueueTask(ctx, task) } +func (c *Client) verifyTaskSignature(task *Task) error { + switch { + case !c.opts.optionalTaskSignatures && task.Signature == "": + return ErrTaskNotSigned + + case task.Signature == "": + return nil + } + + var pubKey ed25519.PublicKey + + switch { + case c.opts.publicKey != nil: + pubKey = c.opts.publicKey + + case c.opts.publicKeyFile != "": + kf, err := os.ReadFile(c.opts.publicKeyFile) + if err != nil { + return err + } + + kb, err := hex.DecodeString(string(kf)) + if err != nil { + return err + } + + if len(kb) != ed25519.PublicKeySize { + return fmt.Errorf("invalid public key") + } + + pubKey = kb + + case c.opts.seedFile != "": + sf, err := os.ReadFile(c.opts.seedFile) + if err != nil { + return err + } + + sb, err := hex.DecodeString(string(sf)) + if err != nil { + return err + } + + if len(sb) != ed25519.SeedSize { + return fmt.Errorf("invalid seed length") + } + + pk := ed25519.NewKeyFromSeed(sb) + defer func() { + io.ReadFull(rand.Reader, pk[:]) + io.ReadFull(rand.Reader, sb[:]) + io.ReadFull(rand.Reader, sf[:]) + }() + + pubKey = pk.Public().(ed25519.PublicKey) + + default: + if !c.opts.optionalTaskSignatures { + return fmt.Errorf("no task verification keys configured") + } + + return nil + } + + msg, err := task.signatureMessage() + if err != nil { + return fmt.Errorf("%w: %w", ErrTaskSignatureInvalid, err) + } + + sig, err := hex.DecodeString(task.Signature) + if err != nil { + return fmt.Errorf("%w: %w", ErrTaskSignatureInvalid, err) + } + + if !ed25519.Verify(pubKey, msg, sig) { + return ErrTaskSignatureInvalid + } + + return nil +} + +func (c *Client) signTask(task *Task) error { + switch { + case c.opts.privateKey != nil: + return task.sign(c.opts.privateKey) + + case c.opts.seedFile != "": + sf, err := os.ReadFile(c.opts.seedFile) + if err != nil { + return err + } + + sb, err := hex.DecodeString(string(sf)) + if err != nil { + return err + } + + if len(sb) != ed25519.SeedSize { + return fmt.Errorf("invalid seed length") + } + + pk := ed25519.NewKeyFromSeed(sb) + defer func() { + io.ReadFull(rand.Reader, pk[:]) + io.ReadFull(rand.Reader, sb[:]) + io.ReadFull(rand.Reader, sf[:]) + }() + + return task.sign(pk) + } + + return nil +} + // StorageAdmin access admin features of the storage backend func (c *Client) StorageAdmin() StorageAdmin { return c.storage.(*jetStreamStorage) diff --git a/client_options.go b/client_options.go index 21a6d5d..a7cd061 100644 --- a/client_options.go +++ b/client_options.go @@ -5,6 +5,7 @@ package asyncjobs import ( + "crypto/ed25519" "fmt" "time" @@ -14,16 +15,21 @@ import ( // ClientOpts configures the client type ClientOpts struct { - concurrency int - replicas int - queue *Queue - taskRetention time.Duration - retryPolicy RetryPolicyProvider - memoryStore bool - statsPort int - logger Logger - skipPrepare bool - discard []TaskState + concurrency int + replicas int + queue *Queue + taskRetention time.Duration + retryPolicy RetryPolicyProvider + memoryStore bool + statsPort int + logger Logger + skipPrepare bool + discard []TaskState + privateKey ed25519.PrivateKey + seedFile string + publicKey ed25519.PublicKey + publicKeyFile string + optionalTaskSignatures bool nc *nats.Conn } @@ -31,6 +37,20 @@ type ClientOpts struct { // ClientOpt configures the client type ClientOpt func(opts *ClientOpts) error +func (c *ClientOpts) validate() error { + if c.privateKey != nil && c.seedFile != "" { + return fmt.Errorf("cannot set both private key and seed file") + } + if c.publicKey != nil && c.publicKeyFile != "" { + return fmt.Errorf("cannot set both public key and public key file") + } + if c.seedFile != "" && (c.publicKeyFile != "" || c.publicKey != nil) { + return fmt.Errorf("cannot set a seedfile and public key information") + } + + return nil +} + // DiscardTaskStates configures the client to discard Tasks that reach a final state in the list of supplied TaskState func DiscardTaskStates(states ...TaskState) ClientOpt { return func(opts *ClientOpts) error { @@ -235,3 +255,43 @@ func TaskRetention(r time.Duration) ClientOpt { return nil } } + +// TaskSigningKey sets a key used to sign tasks, will be kept in memory for the duration +func TaskSigningKey(pk ed25519.PrivateKey) ClientOpt { + return func(opts *ClientOpts) error { + opts.privateKey = pk + return nil + } +} + +// TaskSigningSeedFile sets the path to a file holding a ed25519 seed, will be used for signing and verification and wiped between uses +func TaskSigningSeedFile(sf string) ClientOpt { + return func(opts *ClientOpts) error { + opts.seedFile = sf + return nil + } +} + +// TaskVerificationKey sets a public key used to verify tasks +func TaskVerificationKey(pk ed25519.PublicKey) ClientOpt { + return func(opts *ClientOpts) error { + opts.publicKey = pk + return nil + } +} + +// TaskVerificationKeyFile sets the path to a file holding a ed25519 public key, will be used for verification of tasks +func TaskVerificationKeyFile(sf string) ClientOpt { + return func(opts *ClientOpts) error { + opts.publicKeyFile = sf + return nil + } +} + +// TaskSignaturesOptional indicates that only signed tasks can be loaded +func TaskSignaturesOptional() ClientOpt { + return func(opts *ClientOpts) error { + opts.optionalTaskSignatures = true + return nil + } +} diff --git a/client_test.go b/client_test.go index 47b772f..5fe2253 100644 --- a/client_test.go +++ b/client_test.go @@ -6,6 +6,7 @@ package asyncjobs import ( "context" + "crypto/ed25519" "encoding/json" "fmt" "log" @@ -70,6 +71,72 @@ var _ = Describe("Client", func() { log.SetOutput(GinkgoWriter) }) + Describe("SignedTasks", func() { + var pubk ed25519.PublicKey + var prik ed25519.PrivateKey + var err error + + BeforeEach(func() { + pubk, prik, err = ed25519.GenerateKey(nil) + Expect(err).ToNot(HaveOccurred()) + }) + + It("Should sign messages", func() { + withJetStream(func(nc *nats.Conn, mgr *jsm.Manager) { + client, err := NewClient(NatsConn(nc), TaskSigningKey(prik), TaskVerificationKey(pubk)) + Expect(err).ToNot(HaveOccurred()) + + task, err := NewTask("x", nil) + Expect(err).ToNot(HaveOccurred()) + Expect(client.EnqueueTask(context.Background(), task)).ToNot(HaveOccurred()) + + Expect(task.Signature).To(HaveLen(128)) + + task, err = client.LoadTaskByID(task.ID) + Expect(err).ToNot(HaveOccurred()) + Expect(task.Signature).To(HaveLen(128)) + }) + }) + + It("Should verify loaded tasks", func() { + withJetStream(func(nc *nats.Conn, mgr *jsm.Manager) { + client, err := NewClient(NatsConn(nc), TaskSigningKey(prik), TaskVerificationKey(pubk)) + Expect(err).ToNot(HaveOccurred()) + + task, err := NewTask("x", nil) + Expect(err).ToNot(HaveOccurred()) + Expect(client.EnqueueTask(context.Background(), task)).ToNot(HaveOccurred()) + + client.opts.publicKey, _, err = ed25519.GenerateKey(nil) + Expect(err).ToNot(HaveOccurred()) + + task, err = client.LoadTaskByID(task.ID) + Expect(err).To(MatchError(ErrTaskSignatureInvalid)) + Expect(task).To(BeNil()) + }) + }) + + It("Should support only loading signed messages", func() { + withJetStream(func(nc *nats.Conn, mgr *jsm.Manager) { + client, err := NewClient(NatsConn(nc), TaskVerificationKey(pubk)) + Expect(err).ToNot(HaveOccurred()) + + task, err := NewTask("x", nil) + Expect(err).ToNot(HaveOccurred()) + Expect(client.EnqueueTask(context.Background(), task)).ToNot(HaveOccurred()) + + id := task.ID + task, err = client.LoadTaskByID(id) + Expect(err).To(MatchError(ErrTaskNotSigned)) + + client.opts.optionalTaskSignatures = true + task, err = client.LoadTaskByID(id) + Expect(err).ToNot(HaveOccurred()) + Expect(task).ToNot(BeNil()) + }) + }) + }) + Describe("DiscardTaskStatesByName", func() { It("Should correctly parse state names", func() { opts := &ClientOpts{} diff --git a/errors.go b/errors.go index 940c02e..abc6565 100644 --- a/errors.go +++ b/errors.go @@ -40,6 +40,10 @@ var ( ErrTaskAlreadySigned = fmt.Errorf("task is already signed") // ErrTaskSignatureRequiresQueue indicates a signature request was made without configuring the queue name for a task ErrTaskSignatureRequiresQueue = fmt.Errorf("signing a task requires the queue to be set") + // ErrTaskNotSigned indicates a task was loaded that had no signature while signatures are required + ErrTaskNotSigned = fmt.Errorf("task is not signed") + // ErrTaskSignatureInvalid indicates a signature did not pass validation + ErrTaskSignatureInvalid = fmt.Errorf("invalid task signature") // ErrNoHandlerForTaskType indicates that a task could not be handled by any known handlers ErrNoHandlerForTaskType = fmt.Errorf("no handler for task type") diff --git a/storage.go b/storage.go index c84ebca..4d2ed6b 100644 --- a/storage.go +++ b/storage.go @@ -212,11 +212,6 @@ func (s *jetStreamStorage) EnqueueTask(ctx context.Context, queue *Queue, task * task.Queue = queue.Name - err = task.Sign() - if err != nil { - return err - } - err = s.SaveTaskState(ctx, task, true) if err != nil { return err diff --git a/storage_test.go b/storage_test.go index 8432651..6a23086 100644 --- a/storage_test.go +++ b/storage_test.go @@ -7,7 +7,6 @@ package asyncjobs import ( "bytes" "context" - "crypto/ed25519" "encoding/json" "fmt" "log" @@ -915,16 +914,12 @@ var _ = Describe("Storage", func() { err = storage.PrepareQueue(q, 1, true) Expect(err).ToNot(HaveOccurred()) - _, pri, err := ed25519.GenerateKey(nil) - Expect(err).ToNot(HaveOccurred()) - - task, err := NewTask("ginkgo", nil, TaskSigner(pri)) + task, err := NewTask("ginkgo", nil) Expect(err).ToNot(HaveOccurred()) Expect(task.Signature).To(HaveLen(0)) err = storage.EnqueueTask(ctx, q, task) Expect(err).ToNot(HaveOccurred()) - Expect(task.Signature).To(HaveLen(128)) msg, err := storage.qStreams[q.Name].ReadMessage(1) Expect(err).ToNot(HaveOccurred()) @@ -935,9 +930,8 @@ var _ = Describe("Storage", func() { Expect(item.Kind).To(Equal(TaskItem)) Expect(item.JobID).To(Equal(task.ID)) - t, err := storage.LoadTaskByID(task.ID) + _, err = storage.LoadTaskByID(task.ID) Expect(err).ToNot(HaveOccurred()) - Expect(t.Signature).To(Equal(task.Signature)) }) }) }) diff --git a/task.go b/task.go index b2ea5b8..ab7d15a 100644 --- a/task.go +++ b/task.go @@ -6,11 +6,9 @@ package asyncjobs import ( "crypto/ed25519" - "crypto/rand" "encoding/hex" "encoding/json" "fmt" - "io" "sync" "time" @@ -96,7 +94,6 @@ type Task struct { // Signature is an ed25519 signature of key properties Signature string `json:"signature,omitempty"` - sigPk ed25519.PrivateKey storageOptions any mu sync.Mutex } @@ -171,28 +168,17 @@ func (t *Task) HasDependencies() bool { return len(t.Dependencies) > 0 } -func (t *Task) Sign() error { - if t.sigPk == nil { - return nil - } - +func (t *Task) sign(pk ed25519.PrivateKey) error { if t.Signature != "" { return ErrTaskAlreadySigned } - if len(t.sigPk) != ed25519.PrivateKeySize { - return ErrInvalidPrivateKey - } - msg, err := t.signatureMessage() if err != nil { return err } - t.Signature = hex.EncodeToString(ed25519.Sign(t.sigPk, msg)) - - io.ReadFull(rand.Reader, t.sigPk[:]) - t.sigPk = nil + t.Signature = hex.EncodeToString(ed25519.Sign(pk, msg)) return nil } @@ -277,11 +263,3 @@ func TaskRequiresDependencyResults() TaskOpt { return nil } } - -// TaskSigner signs the task using the given private key -func TaskSigner(key ed25519.PrivateKey) TaskOpt { - return func(t *Task) error { - t.sigPk = key - return nil - } -} diff --git a/task_test.go b/task_test.go index 624519b..78fbafd 100644 --- a/task_test.go +++ b/task_test.go @@ -1,8 +1,6 @@ package asyncjobs import ( - "crypto/ed25519" - "encoding/hex" "time" . "github.com/onsi/ginkgo/v2" @@ -35,29 +33,19 @@ var _ = Describe("Tasks", func() { Expect(task.LoadDependencies).To(BeTrue()) Expect(task.MaxTries).To(Equal(DefaultMaxTries)) - pub, pri, err := ed25519.GenerateKey(nil) - Expect(err).ToNot(HaveOccurred()) - // without dependencies, should be new - task, err = NewTask("test", payload, TaskDeadline(deadline), TaskMaxTries(10), TaskSigner(pri)) + task, err = NewTask("test", payload, TaskDeadline(deadline), TaskMaxTries(10)) Expect(err).ToNot(HaveOccurred()) Expect(task.State).To(Equal(TaskStateNew)) Expect(task.LoadDependencies).To(BeFalse()) Expect(task.MaxTries).To(Equal(10)) - Expect(task.sigPk).To(Equal(pri)) - - Expect(task.Sign()).To(MatchError(ErrTaskSignatureRequiresQueue)) - task.Queue = "x" - Expect(task.Sign()).To(Succeed()) - Expect(task.Signature).ToNot(HaveLen(0)) msg, err := task.signatureMessage() + Expect(err).To(MatchError(ErrTaskSignatureRequiresQueue)) + task.Queue = "x" + msg, err = task.signatureMessage() Expect(err).ToNot(HaveOccurred()) Expect(msg).To(HaveLen(77)) - sig, err := hex.DecodeString(task.Signature) - Expect(err).ToNot(HaveOccurred()) - Expect(sig).To(HaveLen(64)) - Expect(ed25519.Verify(pub, msg, sig)).To(BeTrue()) }) }) })