Skip to content

Commit

Permalink
Refactoring code to accommodate NULL values in the OAuth email column…
Browse files Browse the repository at this point in the history
…, as the default database value is NULL
  • Loading branch information
diegosperes committed Apr 17, 2024
1 parent a7bdddb commit ff4c955
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 12 deletions.
4 changes: 2 additions & 2 deletions app/data/mock/account_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (s *accountStore) AddOauthAccount(accountID int, provider, providerID, emai

now := time.Now()
oauthAccount := &models.OauthAccount{
Email: email,
Email: &email,
AccountID: accountID,
Provider: provider,
ProviderID: providerID,
Expand All @@ -130,7 +130,7 @@ func (s *accountStore) UpdateOauthAccount(accountID int, provider, email string)

for i, oauthAccount := range oauthAccounts {
if oauthAccount.Provider == provider {
s.oauthAccountsByID[accountID][i].Email = email
s.oauthAccountsByID[accountID][i].Email = &email
return true, nil
}
}
Expand Down
31 changes: 30 additions & 1 deletion app/data/mysql/account_store_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package mysql_test

import (
"database/sql"
"testing"

"github.com/keratin/authn-server/app/data"
"github.com/keratin/authn-server/app/data/mysql"
"github.com/keratin/authn-server/app/data/testers"
"github.com/stretchr/testify/require"
Expand All @@ -11,10 +13,37 @@ import (
func TestAccountStore(t *testing.T) {
db, err := mysql.TestDB()
require.NoError(t, err)
store := &mysql.AccountStore{db}
var store data.AccountStore = &mysql.AccountStore{db}
for _, tester := range testers.AccountStoreTesters {
db.MustExec("TRUNCATE accounts")
db.MustExec("TRUNCATE oauth_accounts")
tester(t, store)
}

t.Run("handle oauth email with null value", func(t *testing.T) {
db := store.(interface {
Exec(query string, args ...interface{}) (sql.Result, error)
})

account, err := store.Create("migrated-user", []byte("old"))
require.NoError(t, err)

err = store.AddOauthAccount(account.ID, "provider", "provider_id", "", "token")
require.NoError(t, err)

result, err := db.Exec("UPDATE oauth_accounts SET email = NULL WHERE account_id = ?", account.ID)
require.NoError(t, err)

rowsAffected, err := result.RowsAffected()
require.NoError(t, err)

require.Equal(t, int64(1), rowsAffected)

oAccounts, err := store.GetOauthAccounts(account.ID)
require.NoError(t, err)

require.Len(t, oAccounts, 1)
require.True(t, oAccounts[0].Email == nil)
require.Equal(t, oAccounts[0].GetEmail(), "")
})
}
31 changes: 30 additions & 1 deletion app/data/postgres/account_store_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package postgres_test

import (
"database/sql"
"fmt"
"net/url"
"os"
"testing"

"github.com/jmoiron/sqlx"
"github.com/keratin/authn-server/app/data"
"github.com/keratin/authn-server/app/data/postgres"
"github.com/keratin/authn-server/app/data/testers"
"github.com/pkg/errors"
Expand Down Expand Up @@ -38,10 +40,37 @@ func newTestDB() (*sqlx.DB, error) {
func TestAccountStore(t *testing.T) {
db, err := newTestDB()
require.NoError(t, err)
store := &postgres.AccountStore{db}
var store data.AccountStore = &postgres.AccountStore{db}
for _, tester := range testers.AccountStoreTesters {
db.MustExec("TRUNCATE accounts")
db.MustExec("TRUNCATE oauth_accounts")
tester(t, store)
}

t.Run("handle oauth email with null value", func(t *testing.T) {
db := store.(interface {
Exec(query string, args ...interface{}) (sql.Result, error)
})

account, err := store.Create("migrated-user", []byte("old"))
require.NoError(t, err)

err = store.AddOauthAccount(account.ID, "provider", "provider_id", "", "token")
require.NoError(t, err)

result, err := db.Exec("UPDATE oauth_accounts SET email = NULL WHERE account_id = $1", account.ID)
require.NoError(t, err)

rowsAffected, err := result.RowsAffected()
require.NoError(t, err)

require.Equal(t, int64(1), rowsAffected)

oAccounts, err := store.GetOauthAccounts(account.ID)
require.NoError(t, err)

require.Len(t, oAccounts, 1)
require.True(t, oAccounts[0].Email == nil)
require.Equal(t, oAccounts[0].GetEmail(), "")
})
}
12 changes: 10 additions & 2 deletions app/models/oauth_account.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,20 @@ type OauthAccount struct {
AccountID int `db:"account_id"`
Provider string
ProviderID string `db:"provider_id"`
Email string `db:"email"`
Email *string `db:"email"`
AccessToken string `db:"access_token"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
}

func (a OauthAccount) GetEmail() string {
if a.Email != nil {
return *a.Email
}

return ""
}

func (o OauthAccount) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Provider string `json:"provider"`
Expand All @@ -24,6 +32,6 @@ func (o OauthAccount) MarshalJSON() ([]byte, error) {
}{
Provider: o.Provider,
ProviderID: o.ProviderID,
Email: o.Email,
Email: o.GetEmail(),
})
}
4 changes: 2 additions & 2 deletions app/services/account_getter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ func TestAccountGetter(t *testing.T) {
require.Equal(t, 2, len(oAccounts))
require.Equal(t, "test", oAccounts[0].Provider)
require.Equal(t, "ID1", oAccounts[0].ProviderID)
require.Equal(t, "email1", oAccounts[0].Email)
require.Equal(t, "email1", oAccounts[0].GetEmail())

require.Equal(t, "trial", oAccounts[1].Provider)
require.Equal(t, "ID2", oAccounts[1].ProviderID)
require.Equal(t, "email2", oAccounts[1].Email)
require.Equal(t, "email2", oAccounts[1].GetEmail())
})
}
2 changes: 1 addition & 1 deletion app/services/identity_reconciler.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func updateUserInfo(accountStore data.AccountStore, accountID int, providerName
continue
}

if oAccount.Email != providerUser.Email {
if oAccount.GetEmail() != providerUser.Email {
_, err = accountStore.UpdateOauthAccount(accountID, oAccount.Provider, providerUser.Email)
if err != nil {
return errors.Wrap(err, "UpdateOauthAccount")
Expand Down
4 changes: 2 additions & 2 deletions app/services/identity_reconciler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func TestIdentityReconciler(t *testing.T) {
oAccounts, err := store.GetOauthAccounts(account.ID)
assert.NoError(t, err)
assert.Equal(t, 1, len(oAccounts))
assert.Equal(t, email, oAccounts[0].Email)
assert.Equal(t, email, oAccounts[0].GetEmail())
})

t.Run("update oauth email when is outdated", func(t *testing.T) {
Expand All @@ -123,6 +123,6 @@ func TestIdentityReconciler(t *testing.T) {
oAccounts, err := store.GetOauthAccounts(account.ID)
assert.NoError(t, err)
assert.Equal(t, 1, len(oAccounts))
assert.Equal(t, email, oAccounts[0].Email)
assert.Equal(t, email, oAccounts[0].GetEmail())
})
}
2 changes: 1 addition & 1 deletion server/handlers/get_account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func assertGetAccountResponse(t *testing.T, res *http.Response, acc *models.Acco
oAccounts = append(oAccounts, map[string]interface{}{
"provider": oAcc.Provider,
"provider_account_id": oAcc.ProviderID,
"email": oAcc.Email,
"email": oAcc.GetEmail(),
})
}

Expand Down

0 comments on commit ff4c955

Please sign in to comment.