diff --git a/ratchet/ratchet.go b/ratchet/ratchet.go index a6c26f2..833d87c 100644 --- a/ratchet/ratchet.go +++ b/ratchet/ratchet.go @@ -238,17 +238,20 @@ func (r *Ratchet) CompleteKeyExchange(kx *KeyExchange, alice bool) error { } // Encrypt acts like append() but appends an encrypted version of msg to out. -func (r *Ratchet) Encrypt(out, msg []byte) []byte { +func (r *Ratchet) Encrypt(out, msg []byte) ([]byte, error) { if r.ratchet { r.randBytes(r.sendRatchetPrivate[:]) copy(r.sendHeaderKey[:], r.nextSendHeaderKey[:]) - var sharedKey, keyMaterial [32]byte - curve25519.ScalarMult(&sharedKey, &r.sendRatchetPrivate, &r.recvRatchetPublic) + var keyMaterial [32]byte + sharedKey, err := curve25519.X25519(r.sendRatchetPrivate[:], r.recvRatchetPublic[:]) + if err != nil { + return nil, err + } sha := sha256.New() sha.Write(rootKeyUpdateLabel) sha.Write(r.rootKey[:]) - sha.Write(sharedKey[:]) + sha.Write(sharedKey) sha.Sum(keyMaterial[:0]) h := hmac.New(sha256.New, keyMaterial[:]) deriveKey(&r.rootKey, rootKeyLabel, h) @@ -277,7 +280,7 @@ func (r *Ratchet) Encrypt(out, msg []byte) []byte { out = append(out, headerNonce[:]...) out = secretbox.Seal(out, header[:], &headerNonce, &r.sendHeaderKey) r.sendCount++ - return secretbox.Seal(out, msg, &messageNonce, &messageKey) + return secretbox.Seal(out, msg, &messageNonce, &messageKey), nil } // trySavedKeys tries to decrypt ciphertext using keys saved for missing messages. @@ -458,15 +461,17 @@ func (r *Ratchet) Decrypt(ciphertext []byte) ([]byte, error) { return nil, err } - var dhPublic, sharedKey, rootKey, chainKey, keyMaterial [32]byte + var dhPublic, rootKey, chainKey, keyMaterial [32]byte copy(dhPublic[:], header[8:]) - curve25519.ScalarMult(&sharedKey, &r.sendRatchetPrivate, &dhPublic) - + sharedKey, err := curve25519.X25519(r.sendRatchetPrivate[:], dhPublic[:]) + if err != nil { + return nil, err + } sha := sha256.New() sha.Write(rootKeyUpdateLabel) sha.Write(r.rootKey[:]) - sha.Write(sharedKey[:]) + sha.Write(sharedKey) var rootKeyHMAC hash.Hash diff --git a/ratchet/ratchet_test.go b/ratchet/ratchet_test.go index d97ef7d..c3dbb8c 100644 --- a/ratchet/ratchet_test.go +++ b/ratchet/ratchet_test.go @@ -97,7 +97,10 @@ func TestExchange(t *testing.T) { a, b := pairedRatchet(t) msg := []byte(strings.Repeat("test message", 1024*1024)) - encrypted := a.Encrypt(nil, msg) + encrypted, err := a.Encrypt(nil, msg) + if err != nil { + t.Fatal(err) + } result, err := b.Decrypt(encrypted) if err != nil { t.Fatal(err) @@ -163,8 +166,10 @@ func testScript(t *testing.T, script []scriptAction) { var msg [20]byte rand.Reader.Read(msg[:]) - encrypted := sender.Encrypt(nil, msg[:]) - + encrypted, err := sender.Encrypt(nil, msg[:]) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } switch action.result { case deliver: result, err := receiver.Decrypt(encrypted) @@ -339,7 +344,10 @@ func TestDiskState(t *testing.T) { a, b := pairedRatchet(t) msg := []byte("test message") - encrypted := a.Encrypt(nil, msg) + encrypted, err := a.Encrypt(nil, msg) + if err != nil { + t.Fatal(err) + } result, err := b.Decrypt(encrypted) if err != nil { t.Fatal(err) @@ -348,7 +356,10 @@ func TestDiskState(t *testing.T) { t.Fatalf("result doesn't match: %x vs %x", msg, result) } - encrypted = b.Encrypt(nil, msg) + encrypted, err = b.Encrypt(nil, msg) + if err != nil { + t.Fatal(err) + } result, err = a.Decrypt(encrypted) if err != nil { t.Fatal(err) @@ -414,7 +425,10 @@ func TestDiskState(t *testing.T) { } // send message to alice - encrypted = newBob.Encrypt(nil, msg) + encrypted, err = newBob.Encrypt(nil, msg) + if err != nil { + t.Fatal(err) + } result, err = newAlice.Decrypt(encrypted) if err != nil { t.Fatal(err) @@ -423,7 +437,10 @@ func TestDiskState(t *testing.T) { t.Fatalf("result doesn't match: %x vs %x", msg, result) } - encrypted = newAlice.Encrypt(nil, msg) + encrypted, err = newAlice.Encrypt(nil, msg) + if err != nil { + t.Fatal(err) + } result, err = newBob.Decrypt(encrypted) if err != nil { t.Fatal(err) diff --git a/zkclient/msg.go b/zkclient/msg.go index 9abfc23..ddd06cf 100644 --- a/zkclient/msg.go +++ b/zkclient/msg.go @@ -415,9 +415,7 @@ func (z *ZKC) crpc(r *ratchet.Ratchet, payload interface{}) ([]byte, error) { bb.Write(p) // encrypt CRPC - blob := r.Encrypt(nil, bb.Bytes()) - - return blob, nil + return r.Encrypt(nil, bb.Bytes()) } func (z *ZKC) pm(id [zkidentity.IdentitySize]byte, message string, mode rpc.MessageMode) error { diff --git a/zkidentity/zkidentity_test.go b/zkidentity/zkidentity_test.go index 4b8fece..64404f2 100644 --- a/zkidentity/zkidentity_test.go +++ b/zkidentity/zkidentity_test.go @@ -78,7 +78,10 @@ func TestEncryptDecryptSmall(t *testing.T) { a, b := pairedRatchet() msg := []byte("test message") - encrypted := a.Encrypt(nil, msg) + encrypted, err := a.Encrypt(nil, msg) + if err != nil { + t.Fatal(err) + } result, err := b.Decrypt(encrypted) if err != nil { t.Fatal(err) @@ -124,7 +127,10 @@ func TestEncryptDecryptLarge(t *testing.T) { a, b := pairedRatchet() msg := []byte(strings.Repeat("test message", 1024*1024)) - encrypted := a.Encrypt(nil, msg) + encrypted, err := a.Encrypt(nil, msg) + if err != nil { + t.Fatal(err) + } result, err := b.Decrypt(encrypted) if err != nil { t.Fatal(err)