Skip to content

Commit

Permalink
PMM-13132 Another suggested refactor.
Browse files Browse the repository at this point in the history
  • Loading branch information
JiriCtvrtka committed Sep 20, 2024
1 parent f603530 commit 7fab00f
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 74 deletions.
46 changes: 9 additions & 37 deletions managed/models/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"net"
"net/url"
"os"
"slices"
"strconv"
"strings"

Expand Down Expand Up @@ -1177,50 +1176,17 @@ func SetupDB(ctx context.Context, sqlDB *sql.DB, params SetupDBParams) (*reform.

// EncryptDB encrypts a set of columns in a specific database and table.
func EncryptDB(tx *reform.TX, database string, itemsToEncrypt []encryption.Table) error {
return dbEncryption(tx, database, itemsToEncrypt, encryption.EncryptItems, true, addToEncryptedItems)
}

func encryptionExists(m map[string]bool, key string) bool {
return m[key]
}

func addToEncryptedItems(encryptedItems []string, items []string) []string {
return slices.Concat(encryptedItems, items)
return dbEncryption(tx, database, itemsToEncrypt, encryption.EncryptItems, true)
}

// DecryptDB decrypts a set of columns in a specific database and table.
func DecryptDB(tx *reform.TX, database string, itemsToEncrypt []encryption.Table) error {
return dbEncryption(tx, database, itemsToEncrypt, encryption.DecryptItems, false, removeFromEncryptedItems)
}

func encryptionNotExists(m map[string]bool, key string) bool {
return !encryptionExists(m, key)
}

func removeFromEncryptedItems(encryptedItems []string, items []string) []string {
res := []string{}
for _, encryptedItem := range encryptedItems {
exists := false
for _, item := range items {
if encryptedItem == item {
exists = true
}
}

if exists {
continue
}

res = append(res, encryptedItem)
}

return res
return dbEncryption(tx, database, itemsToEncrypt, encryption.DecryptItems, false)
}

func dbEncryption(tx *reform.TX, database string, items []encryption.Table,
encryptionHandler func(tx *reform.TX, tables []encryption.Table) error,
expectedState bool,
settingsHandler func(encryptedItems []string, items []string) []string,
) error {
if len(items) == 0 {
return nil
Expand Down Expand Up @@ -1263,8 +1229,14 @@ func dbEncryption(tx *reform.TX, database string, items []encryption.Table,
if err != nil {
return err
}

encryptedItems := []string{}
if expectedState {
encryptedItems = prepared
}

_, err = UpdateSettings(tx, &ChangeSettingsParams{
EncryptedItems: settingsHandler(settings.EncryptedItems, prepared),
EncryptedItems: encryptedItems,
})
if err != nil {
return err
Expand Down
4 changes: 3 additions & 1 deletion managed/models/settings_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@ func UpdateSettings(q reform.DBTX, params *ChangeSettingsParams) (*Settings, err
settings.DefaultRoleID = *params.DefaultRoleID
}

settings.EncryptedItems = params.EncryptedItems
if settings.EncryptedItems != nil {
settings.EncryptedItems = params.EncryptedItems
}

err = SaveSettings(q, settings)
if err != nil {
Expand Down
20 changes: 2 additions & 18 deletions managed/utils/encryption/encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,6 @@ func RotateEncryptionKey() error {
return nil
}

// RotateEncryptionKey is will backup old encryption key and generate new one.
func (e *Encryption) RotateEncryptionKey() error {
err := e.backupOldEncryptionKey()
if err != nil {
return err
}

enc := New()
e.Key = enc.Key
e.Path = enc.Path
e.Primitive = enc.Primitive

return nil
}

// RestoreOldEncryptionKey is a wrapper around DefaultEncryption.RestoreOldEncryptionKey.
func RestoreOldEncryptionKey() error {
err := os.Rename(fmt.Sprintf("%s_old.key", strings.TrimSuffix(encryptionKeyPath(), ".key")), encryptionKeyPath())
Expand All @@ -105,9 +90,8 @@ func RestoreOldEncryptionKey() error {
return nil
}

// RestoreOldEncryptionKey will restore previous backup during rotation.
func (e *Encryption) RestoreOldEncryptionKey() error {
err := os.Rename(fmt.Sprintf("%s_old.key", strings.TrimSuffix(e.Path, ".key")), e.Path)
func backupOldEncryptionKey() error {
err := os.Rename(encryptionKeyPath(), fmt.Sprintf("%s_old.key", strings.TrimSuffix(encryptionKeyPath(), ".key")))
if err != nil {
return err
}
Expand Down
18 changes: 0 additions & 18 deletions managed/utils/encryption/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,24 +40,6 @@ func encryptionKeyPath() string {
return DefaultEncryptionKeyPath
}

func backupOldEncryptionKey() error {
err := os.Rename(encryptionKeyPath(), fmt.Sprintf("%s_old.key", strings.TrimSuffix(encryptionKeyPath(), ".key")))
if err != nil {
return err
}

return nil
}

func (e *Encryption) backupOldEncryptionKey() error {
err := os.Rename(e.Path, fmt.Sprintf("%s_old.key", strings.TrimSuffix(e.Path, ".key")))
if err != nil {
return err
}

return nil
}

func prepareRowPointers(rows *sql.Rows) ([]any, error) {
columnTypes, err := rows.ColumnTypes()
if err != nil {
Expand Down

0 comments on commit 7fab00f

Please sign in to comment.