Skip to content

Commit

Permalink
feat: Use encrypted cookie to store OAuth2 state nonce (instead of re…
Browse files Browse the repository at this point in the history
…dis) (argoproj#8241)

feat: Use encrypted cookie to store OAuth2 state nonce (instead of redis) (argoproj#8241)

Signed-off-by: Alexander Matyushentsev <[email protected]>
  • Loading branch information
Alexander Matyushentsev authored Jan 26, 2022
1 parent 0aeda43 commit ecc3ab3
Show file tree
Hide file tree
Showing 9 changed files with 247 additions and 80 deletions.
4 changes: 4 additions & 0 deletions common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ const (
ArgoCDUserAgentName = "argocd-client"
// AuthCookieName is the HTTP cookie name where we store our auth token
AuthCookieName = "argocd.token"
// StateCookieName is the HTTP cookie name that holds temporary nonce tokens for CSRF protection
StateCookieName = "argocd.oauthstate"
// StateCookieMaxAge is the maximum age of the oauth state cookie
StateCookieMaxAge = time.Minute * 5

// ChangePasswordSSOTokenMaxAge is the max token age for password change operation
ChangePasswordSSOTokenMaxAge = time.Minute * 5
Expand Down
17 changes: 0 additions & 17 deletions server/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
cacheutil "github.com/argoproj/argo-cd/v2/util/cache"
appstatecache "github.com/argoproj/argo-cd/v2/util/cache/appstate"
"github.com/argoproj/argo-cd/v2/util/env"
"github.com/argoproj/argo-cd/v2/util/oidc"
)

var ErrCacheMiss = appstatecache.ErrCacheMiss
Expand All @@ -25,8 +24,6 @@ type Cache struct {
loginAttemptsExpiration time.Duration
}

var _ oidc.OIDCStateStorage = &Cache{}

func NewCache(
cache *appstatecache.Cache,
connectionStatusCacheExpiration time.Duration,
Expand Down Expand Up @@ -91,20 +88,6 @@ func (c *Cache) SetClusterInfo(server string, res *appv1.ClusterInfo) error {
return c.cache.SetClusterInfo(server, res)
}

func oidcStateKey(key string) string {
return fmt.Sprintf("oidc|%s", key)
}

func (c *Cache) GetOIDCState(key string) (*oidc.OIDCState, error) {
res := oidc.OIDCState{}
err := c.cache.GetItem(oidcStateKey(key), &res)
return &res, err
}

func (c *Cache) SetOIDCState(key string, state *oidc.OIDCState) error {
return c.cache.SetItem(oidcStateKey(key), state, c.oidcCacheExpiration, state == nil)
}

func (c *Cache) GetCache() *cacheutil.Cache {
return c.cache.Cache
}
18 changes: 0 additions & 18 deletions server/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
. "github.com/argoproj/argo-cd/v2/pkg/apis/application/v1alpha1"
cacheutil "github.com/argoproj/argo-cd/v2/util/cache"
appstatecache "github.com/argoproj/argo-cd/v2/util/cache/appstate"
"github.com/argoproj/argo-cd/v2/util/oidc"
)

type fixtures struct {
Expand Down Expand Up @@ -46,23 +45,6 @@ func TestCache_GetRepoConnectionState(t *testing.T) {
assert.Equal(t, ConnectionState{Status: "my-state"}, value)
}

func TestCache_GetOIDCState(t *testing.T) {
cache := newFixtures().Cache
// cache miss
_, err := cache.GetOIDCState("my-key")
assert.Equal(t, ErrCacheMiss, err)
// populate cache
err = cache.SetOIDCState("my-key", &oidc.OIDCState{ReturnURL: "my-return-url"})
assert.NoError(t, err)
//cache miss
_, err = cache.GetOIDCState("other-key")
assert.Equal(t, ErrCacheMiss, err)
// cache hit
value, err := cache.GetOIDCState("my-key")
assert.NoError(t, err)
assert.Equal(t, &oidc.OIDCState{ReturnURL: "my-return-url"}, value)
}

func TestAddCacheFlagsToCmd(t *testing.T) {
cache, err := AddCacheFlagsToCmd(&cobra.Command{})()
assert.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ func (a *ArgoCDServer) registerDexHandlers(mux *http.ServeMux) {
tlsConfig := a.settings.TLSConfig()
tlsConfig.InsecureSkipVerify = true
}
a.ssoClientApp, err = oidc.NewClientApp(a.settings, a.Cache, a.DexServerAddr, a.BaseHRef)
a.ssoClientApp, err = oidc.NewClientApp(a.settings, a.DexServerAddr, a.BaseHRef)
errors.CheckError(err)
mux.HandleFunc(common.LoginEndpoint, a.ssoClientApp.HandleLogin)
mux.HandleFunc(common.CallbackEndpoint, a.ssoClientApp.HandleCallback)
Expand Down
62 changes: 62 additions & 0 deletions util/crypto/crypto.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package crypto

import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"errors"
"io"

"golang.org/x/crypto/scrypt"
)

// KeyFromPassphrase generates 32 byte key from the passphrase
func KeyFromPassphrase(passphrase string) ([]byte, error) {
// salt is just a hash of a passphrase (effectively no salt)
salt := sha256.Sum256([]byte(passphrase))
// These defaults will consume approximately 16MB of memory (128 * r * N)
const N = 16384
const r = 8
return scrypt.Key([]byte(passphrase), salt[:], N, r, 1, 32)
}

// Encrypt encrypts the given data with the given passphrase.
func Encrypt(data []byte, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonce := make([]byte, gcm.NonceSize())
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}
ciphertext := gcm.Seal(nonce, nonce, data, nil)
return ciphertext, nil
}

// Decrypt decrypts the given data using the given passphrase.
func Decrypt(data []byte, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return nil, errors.New("data length is less than nonce size")
}
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, err
}
return plaintext, nil
}
43 changes: 43 additions & 0 deletions util/crypto/crypto_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package crypto

import (
"crypto/rand"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func newKey() ([]byte, error) {
b := make([]byte, 32)
_, err := rand.Read(b)
if err != nil {
b = nil
}
return b, err
}

func TestEncryptDecrypt_Successful(t *testing.T) {
key, err := newKey()
require.NoError(t, err)
encrypted, err := Encrypt([]byte("test"), key)
require.NoError(t, err)

decrypted, err := Decrypt(encrypted, key)
require.NoError(t, err)

assert.Equal(t, "test", string(decrypted))
}

func TestEncryptDecrypt_Failed(t *testing.T) {
key, err := newKey()
require.NoError(t, err)
encrypted, err := Encrypt([]byte("test"), key)
require.NoError(t, err)

wrongKey, err := newKey()
require.NoError(t, err)

_, err = Decrypt(encrypted, wrongKey)
assert.Error(t, err)
}
Loading

0 comments on commit ecc3ab3

Please sign in to comment.