Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
Signed-off-by: R.I.Pienaar <[email protected]>
  • Loading branch information
ripienaar committed May 8, 2023
1 parent ced4df5 commit cf62095
Show file tree
Hide file tree
Showing 8 changed files with 291 additions and 64 deletions.
143 changes: 142 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@ package asyncjobs

import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"io"
"net/http"
"os"
"time"

"github.com/prometheus/client_golang/prometheus/promhttp"
Expand Down Expand Up @@ -41,6 +46,11 @@ func NewClient(opts ...ClientOpt) (*Client, error) {
}
}

err = copts.validate()
if err != nil {
return nil, err
}

c := &Client{opts: copts, log: copts.logger}
c.storage, err = newJetStreamStorage(copts.nc, copts.retryPolicy, c.log)
if err != nil {
Expand Down Expand Up @@ -85,7 +95,17 @@ func (c *Client) Run(ctx context.Context, router *Mux) error {

// LoadTaskByID loads a task from the backend using its ID
func (c *Client) LoadTaskByID(id string) (*Task, error) {
return c.storage.LoadTaskByID(id)
task, err := c.storage.LoadTaskByID(id)
if err != nil {
return nil, err
}

err = c.verifyTaskSignature(task)
if err != nil {
return nil, err
}

return task, nil
}

// RetryTaskByID will retry a task, first removing an entry from the Work Queue if already there
Expand All @@ -95,9 +115,130 @@ func (c *Client) RetryTaskByID(ctx context.Context, id string) error {

// EnqueueTask adds a task to the named queue which must already exist
func (c *Client) EnqueueTask(ctx context.Context, task *Task) error {
task.Queue = c.opts.queue.Name

err := c.signTask(task)
if err != nil {
return err
}

return c.opts.queue.enqueueTask(ctx, task)
}

func (c *Client) verifyTaskSignature(task *Task) error {
switch {
case !c.opts.optionalTaskSignatures && task.Signature == "":
return ErrTaskNotSigned

case task.Signature == "":
return nil
}

var pubKey ed25519.PublicKey

switch {
case c.opts.publicKey != nil:
pubKey = c.opts.publicKey

case c.opts.publicKeyFile != "":
kf, err := os.ReadFile(c.opts.publicKeyFile)
if err != nil {
return err
}

kb, err := hex.DecodeString(string(kf))
if err != nil {
return err
}

if len(kb) != ed25519.PublicKeySize {
return fmt.Errorf("invalid public key")
}

pubKey = kb

case c.opts.seedFile != "":
sf, err := os.ReadFile(c.opts.seedFile)
if err != nil {
return err
}

sb, err := hex.DecodeString(string(sf))
if err != nil {
return err
}

if len(sb) != ed25519.SeedSize {
return fmt.Errorf("invalid seed length")
}

pk := ed25519.NewKeyFromSeed(sb)
defer func() {
io.ReadFull(rand.Reader, pk[:])
io.ReadFull(rand.Reader, sb[:])
io.ReadFull(rand.Reader, sf[:])
}()

pubKey = pk.Public().(ed25519.PublicKey)

default:
if !c.opts.optionalTaskSignatures {
return fmt.Errorf("no task verification keys configured")
}

return nil
}

msg, err := task.signatureMessage()
if err != nil {
return fmt.Errorf("%w: %w", ErrTaskSignatureInvalid, err)
}

sig, err := hex.DecodeString(task.Signature)
if err != nil {
return fmt.Errorf("%w: %w", ErrTaskSignatureInvalid, err)
}

if !ed25519.Verify(pubKey, msg, sig) {
return ErrTaskSignatureInvalid
}

return nil
}

func (c *Client) signTask(task *Task) error {
switch {
case c.opts.privateKey != nil:
return task.sign(c.opts.privateKey)

case c.opts.seedFile != "":
sf, err := os.ReadFile(c.opts.seedFile)
if err != nil {
return err
}

sb, err := hex.DecodeString(string(sf))
if err != nil {
return err
}

if len(sb) != ed25519.SeedSize {
return fmt.Errorf("invalid seed length")
}

pk := ed25519.NewKeyFromSeed(sb)
defer func() {
io.ReadFull(rand.Reader, pk[:])
io.ReadFull(rand.Reader, sb[:])
io.ReadFull(rand.Reader, sf[:])
}()

return task.sign(pk)
}

return nil
}

// StorageAdmin access admin features of the storage backend
func (c *Client) StorageAdmin() StorageAdmin {
return c.storage.(*jetStreamStorage)
Expand Down
80 changes: 70 additions & 10 deletions client_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package asyncjobs

import (
"crypto/ed25519"
"fmt"
"time"

Expand All @@ -14,23 +15,42 @@ import (

// ClientOpts configures the client
type ClientOpts struct {
concurrency int
replicas int
queue *Queue
taskRetention time.Duration
retryPolicy RetryPolicyProvider
memoryStore bool
statsPort int
logger Logger
skipPrepare bool
discard []TaskState
concurrency int
replicas int
queue *Queue
taskRetention time.Duration
retryPolicy RetryPolicyProvider
memoryStore bool
statsPort int
logger Logger
skipPrepare bool
discard []TaskState
privateKey ed25519.PrivateKey
seedFile string
publicKey ed25519.PublicKey
publicKeyFile string
optionalTaskSignatures bool

nc *nats.Conn
}

// ClientOpt configures the client
type ClientOpt func(opts *ClientOpts) error

func (c *ClientOpts) validate() error {
if c.privateKey != nil && c.seedFile != "" {
return fmt.Errorf("cannot set both private key and seed file")
}
if c.publicKey != nil && c.publicKeyFile != "" {
return fmt.Errorf("cannot set both public key and public key file")
}
if c.seedFile != "" && (c.publicKeyFile != "" || c.publicKey != nil) {
return fmt.Errorf("cannot set a seedfile and public key information")
}

return nil
}

// DiscardTaskStates configures the client to discard Tasks that reach a final state in the list of supplied TaskState
func DiscardTaskStates(states ...TaskState) ClientOpt {
return func(opts *ClientOpts) error {
Expand Down Expand Up @@ -235,3 +255,43 @@ func TaskRetention(r time.Duration) ClientOpt {
return nil
}
}

// TaskSigningKey sets a key used to sign tasks, will be kept in memory for the duration
func TaskSigningKey(pk ed25519.PrivateKey) ClientOpt {
return func(opts *ClientOpts) error {
opts.privateKey = pk
return nil
}
}

// TaskSigningSeedFile sets the path to a file holding a ed25519 seed, will be used for signing and verification and wiped between uses
func TaskSigningSeedFile(sf string) ClientOpt {
return func(opts *ClientOpts) error {
opts.seedFile = sf
return nil
}
}

// TaskVerificationKey sets a public key used to verify tasks
func TaskVerificationKey(pk ed25519.PublicKey) ClientOpt {
return func(opts *ClientOpts) error {
opts.publicKey = pk
return nil
}
}

// TaskVerificationKeyFile sets the path to a file holding a ed25519 public key, will be used for verification of tasks
func TaskVerificationKeyFile(sf string) ClientOpt {
return func(opts *ClientOpts) error {
opts.publicKeyFile = sf
return nil
}
}

// TaskSignaturesOptional indicates that only signed tasks can be loaded
func TaskSignaturesOptional() ClientOpt {
return func(opts *ClientOpts) error {
opts.optionalTaskSignatures = true
return nil
}
}
67 changes: 67 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package asyncjobs

import (
"context"
"crypto/ed25519"
"encoding/json"
"fmt"
"log"
Expand Down Expand Up @@ -70,6 +71,72 @@ var _ = Describe("Client", func() {
log.SetOutput(GinkgoWriter)
})

Describe("SignedTasks", func() {
var pubk ed25519.PublicKey
var prik ed25519.PrivateKey
var err error

BeforeEach(func() {
pubk, prik, err = ed25519.GenerateKey(nil)
Expect(err).ToNot(HaveOccurred())
})

It("Should sign messages", func() {
withJetStream(func(nc *nats.Conn, mgr *jsm.Manager) {
client, err := NewClient(NatsConn(nc), TaskSigningKey(prik), TaskVerificationKey(pubk))
Expect(err).ToNot(HaveOccurred())

task, err := NewTask("x", nil)
Expect(err).ToNot(HaveOccurred())
Expect(client.EnqueueTask(context.Background(), task)).ToNot(HaveOccurred())

Expect(task.Signature).To(HaveLen(128))

task, err = client.LoadTaskByID(task.ID)
Expect(err).ToNot(HaveOccurred())
Expect(task.Signature).To(HaveLen(128))
})
})

It("Should verify loaded tasks", func() {
withJetStream(func(nc *nats.Conn, mgr *jsm.Manager) {
client, err := NewClient(NatsConn(nc), TaskSigningKey(prik), TaskVerificationKey(pubk))
Expect(err).ToNot(HaveOccurred())

task, err := NewTask("x", nil)
Expect(err).ToNot(HaveOccurred())
Expect(client.EnqueueTask(context.Background(), task)).ToNot(HaveOccurred())

client.opts.publicKey, _, err = ed25519.GenerateKey(nil)
Expect(err).ToNot(HaveOccurred())

task, err = client.LoadTaskByID(task.ID)
Expect(err).To(MatchError(ErrTaskSignatureInvalid))
Expect(task).To(BeNil())
})
})

It("Should support only loading signed messages", func() {
withJetStream(func(nc *nats.Conn, mgr *jsm.Manager) {
client, err := NewClient(NatsConn(nc), TaskVerificationKey(pubk))
Expect(err).ToNot(HaveOccurred())

task, err := NewTask("x", nil)
Expect(err).ToNot(HaveOccurred())
Expect(client.EnqueueTask(context.Background(), task)).ToNot(HaveOccurred())

id := task.ID
task, err = client.LoadTaskByID(id)
Expect(err).To(MatchError(ErrTaskNotSigned))

client.opts.optionalTaskSignatures = true
task, err = client.LoadTaskByID(id)
Expect(err).ToNot(HaveOccurred())
Expect(task).ToNot(BeNil())
})
})
})

Describe("DiscardTaskStatesByName", func() {
It("Should correctly parse state names", func() {
opts := &ClientOpts{}
Expand Down
4 changes: 4 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ var (
ErrTaskAlreadySigned = fmt.Errorf("task is already signed")
// ErrTaskSignatureRequiresQueue indicates a signature request was made without configuring the queue name for a task
ErrTaskSignatureRequiresQueue = fmt.Errorf("signing a task requires the queue to be set")
// ErrTaskNotSigned indicates a task was loaded that had no signature while signatures are required
ErrTaskNotSigned = fmt.Errorf("task is not signed")
// ErrTaskSignatureInvalid indicates a signature did not pass validation
ErrTaskSignatureInvalid = fmt.Errorf("invalid task signature")

// ErrNoHandlerForTaskType indicates that a task could not be handled by any known handlers
ErrNoHandlerForTaskType = fmt.Errorf("no handler for task type")
Expand Down
5 changes: 0 additions & 5 deletions storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,6 @@ func (s *jetStreamStorage) EnqueueTask(ctx context.Context, queue *Queue, task *

task.Queue = queue.Name

err = task.Sign()
if err != nil {
return err
}

err = s.SaveTaskState(ctx, task, true)
if err != nil {
return err
Expand Down
Loading

0 comments on commit cf62095

Please sign in to comment.