Skip to content

Commit

Permalink
PMM-13132 Suggested refactor.
Browse files Browse the repository at this point in the history
  • Loading branch information
JiriCtvrtka committed Sep 20, 2024
1 parent 89692ac commit f603530
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions managed/models/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -1177,7 +1177,7 @@ func SetupDB(ctx context.Context, sqlDB *sql.DB, params SetupDBParams) (*reform.

// EncryptDB encrypts a set of columns in a specific database and table.
func EncryptDB(tx *reform.TX, database string, itemsToEncrypt []encryption.Table) error {
return dbEncryption(tx, database, itemsToEncrypt, encryption.EncryptItems, encryptionExists, addToEncryptedItems)
return dbEncryption(tx, database, itemsToEncrypt, encryption.EncryptItems, true, addToEncryptedItems)
}

func encryptionExists(m map[string]bool, key string) bool {
Expand All @@ -1190,7 +1190,7 @@ func addToEncryptedItems(encryptedItems []string, items []string) []string {

// DecryptDB decrypts a set of columns in a specific database and table.
func DecryptDB(tx *reform.TX, database string, itemsToEncrypt []encryption.Table) error {
return dbEncryption(tx, database, itemsToEncrypt, encryption.DecryptItems, encryptionNotExists, removeFromEncryptedItems)
return dbEncryption(tx, database, itemsToEncrypt, encryption.DecryptItems, false, removeFromEncryptedItems)
}

func encryptionNotExists(m map[string]bool, key string) bool {
Expand Down Expand Up @@ -1219,7 +1219,7 @@ func removeFromEncryptedItems(encryptedItems []string, items []string) []string

func dbEncryption(tx *reform.TX, database string, items []encryption.Table,
encryptionHandler func(tx *reform.TX, tables []encryption.Table) error,
checkHandler func(m map[string]bool, key string) bool,
expectedState bool,
settingsHandler func(encryptedItems []string, items []string) []string,
) error {
if len(items) == 0 {
Expand All @@ -1241,7 +1241,7 @@ func dbEncryption(tx *reform.TX, database string, items []encryption.Table,
columns := []encryption.Column{}
for _, column := range table.Columns {
dbTableColumn := fmt.Sprintf("%s.%s.%s", database, table.Name, column.Name)
if checkHandler(currentColumns, dbTableColumn) {
if currentColumns[dbTableColumn] == expectedState {
continue
}

Expand Down

0 comments on commit f603530

Please sign in to comment.