diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 09de98cf79..e7c38a3d46 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -16,6 +16,9 @@ jobs: name: Checks runs-on: ubuntu-22.04 + env: + PMM_ENCRYPTION_KEY_PATH: pmm-encryption.key + steps: - name: Check out code uses: actions/checkout@v4 diff --git a/go.mod b/go.mod index f2e9f2d32a..f25d8d8e8f 100644 --- a/go.mod +++ b/go.mod @@ -38,6 +38,7 @@ require ( github.com/go-sql-driver/mysql v1.7.1 github.com/gogo/status v1.1.1 github.com/golang-migrate/migrate/v4 v4.17.0 + github.com/google/tink/go v1.7.0 github.com/google/uuid v1.6.0 github.com/grafana/grafana-api-golang-client v0.27.0 github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 @@ -101,7 +102,6 @@ require ( github.com/google/btree v1.0.0 // indirect github.com/hashicorp/go-hclog v1.6.2 // indirect github.com/hashicorp/go-msgpack/v2 v2.1.1 // indirect - github.com/hashicorp/go-uuid v1.0.2 // indirect github.com/kr/fs v0.1.0 // indirect github.com/mattn/go-colorable v0.1.12 // indirect github.com/miekg/dns v1.1.26 // indirect diff --git a/go.sum b/go.sum index eec4683a48..c236c1d13f 100644 --- a/go.sum +++ b/go.sum @@ -239,6 +239,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/tink/go v1.7.0 h1:6Eox8zONGebBFcCBqkVmt60LaWZa6xg1cl/DwAh/J1w= +github.com/google/tink/go v1.7.0/go.mod h1:GAUOd+QE3pgj9q8VKIGTCP33c/B7eb4NhxLcgTJZStM= github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= diff --git a/managed/Makefile b/managed/Makefile index fa0ab06fa8..dcf35ec550 100644 --- a/managed/Makefile +++ b/managed/Makefile @@ -37,7 +37,6 @@ clean: ## Remove generated files release: ## Build pmm-managed release binaries env CGO_ENABLED=0 go build -v $(PMM_LD_FLAGS) -o $(PMM_RELEASE_PATH)/ ./cmd/... - $(PMM_RELEASE_PATH)/pmm-managed --version release-starlark: env CGO_ENABLED=0 go build -v $(PMM_LD_FLAGS) -o $(PMM_RELEASE_PATH)/ ./cmd/pmm-managed-starlark/... diff --git a/managed/models/agent_helpers.go b/managed/models/agent_helpers.go index 041fbe4499..eb677fc791 100644 --- a/managed/models/agent_helpers.go +++ b/managed/models/agent_helpers.go @@ -229,7 +229,8 @@ func FindAgents(q *reform.Querier, filters AgentFilters) ([]*Agent, error) { agents := make([]*Agent, len(structs)) for i, s := range structs { - agents[i] = s.(*Agent) //nolint:forcetypeassert + decryptedAgent := DecryptAgent(*s.(*Agent)) //nolint:forcetypeassert + agents[i] = &decryptedAgent } return agents, nil @@ -249,8 +250,9 @@ func FindAgentByID(q *reform.Querier, id string) (*Agent, error) { } return nil, errors.WithStack(err) } + decryptedAgent := DecryptAgent(*agent) - return agent, nil + return &decryptedAgent, nil } // FindAgentsByIDs finds Agents by IDs. @@ -272,7 +274,8 @@ func FindAgentsByIDs(q *reform.Querier, ids []string) ([]*Agent, error) { res := make([]*Agent, len(structs)) for i, s := range structs { - res[i] = s.(*Agent) //nolint:forcetypeassert + decryptedAgent := DecryptAgent(*s.(*Agent)) //nolint:forcetypeassert + res[i] = &decryptedAgent } return res, nil } @@ -323,7 +326,8 @@ func FindDBConfigForService(q *reform.Querier, serviceID string) (*DBConfig, err res := make([]*Agent, len(structs)) for i, s := range structs { - res[i] = s.(*Agent) //nolint:forcetypeassert + decryptedAgent := DecryptAgent(*s.(*Agent)) //nolint:forcetypeassert + res[i] = &decryptedAgent } if len(res) == 0 { @@ -350,8 +354,8 @@ func FindPMMAgentsRunningOnNode(q *reform.Querier, nodeID string) ([]*Agent, err res := make([]*Agent, 0, len(structs)) for _, str := range structs { - row := str.(*Agent) //nolint:forcetypeassert - res = append(res, row) + decryptedAgent := DecryptAgent(*str.(*Agent)) //nolint:forcetypeassert + res = append(res, &decryptedAgent) } return res, nil @@ -395,8 +399,8 @@ func FindPMMAgentsForService(q *reform.Querier, serviceID string) ([]*Agent, err } res := make([]*Agent, 0, len(pmmAgentRecords)) for _, str := range pmmAgentRecords { - row := str.(*Agent) //nolint:forcetypeassert - res = append(res, row) + decryptedAgent := DecryptAgent(*str.(*Agent)) //nolint:forcetypeassert + res = append(res, &decryptedAgent) } return res, nil @@ -477,7 +481,8 @@ func FindAgentsForScrapeConfig(q *reform.Querier, pmmAgentID *string, pushMetric res := make([]*Agent, len(allAgents)) for i, s := range allAgents { - res[i] = s.(*Agent) //nolint:forcetypeassert + decryptedAgent := DecryptAgent(*s.(*Agent)) //nolint:forcetypeassert + res[i] = &decryptedAgent } return res, nil } @@ -641,11 +646,14 @@ func CreateNodeExporter(q *reform.Querier, if err := row.SetCustomLabels(customLabels); err != nil { return nil, err } - if err := q.Insert(row); err != nil { + + encryptedAgent := EncryptAgent(*row) + if err := q.Insert(&encryptedAgent); err != nil { return nil, errors.WithStack(err) } + agent := DecryptAgent(encryptedAgent) - return row, nil + return &agent, nil } // CreateExternalExporterParams params for add external exporter. @@ -725,11 +733,14 @@ func CreateExternalExporter(q *reform.Querier, params *CreateExternalExporterPar if err := row.SetCustomLabels(params.CustomLabels); err != nil { return nil, err } - if err := q.Insert(row); err != nil { + + encryptedAgent := EncryptAgent(*row) + if err := q.Insert(&encryptedAgent); err != nil { return nil, errors.WithStack(err) } + agent := DecryptAgent(encryptedAgent) - return row, nil + return &agent, nil } // CreateAgentParams params for add common exporter. @@ -912,15 +923,17 @@ func CreateAgent(q *reform.Querier, agentType AgentType, params *CreateAgentPara DisabledCollectors: params.DisableCollectors, LogLevel: pointer.ToStringOrNil(params.LogLevel), } - if err := row.SetCustomLabels(params.CustomLabels); err != nil { return nil, err } - if err := q.Insert(row); err != nil { + + encryptedAgent := EncryptAgent(*row) + if err := q.Insert(&encryptedAgent); err != nil { return nil, errors.WithStack(err) } + agent := DecryptAgent(encryptedAgent) - return row, nil + return &agent, nil } // ChangeCommonAgentParams contains parameters that can be changed for all Agents. diff --git a/managed/models/database.go b/managed/models/database.go index 8f45a0364c..d2c2066c7f 100644 --- a/managed/models/database.go +++ b/managed/models/database.go @@ -27,6 +27,7 @@ import ( "net" "net/url" "os" + "slices" "strconv" "strings" @@ -36,6 +37,8 @@ import ( "google.golang.org/grpc/status" "gopkg.in/reform.v1" "gopkg.in/reform.v1/dialects/postgresql" + + "github.com/percona/pmm/managed/utils/encryption" ) const ( @@ -1146,12 +1149,87 @@ func SetupDB(ctx context.Context, sqlDB *sql.DB, params SetupDBParams) (*reform. return nil, errCV } - if err := migrateDB(db, params); err != nil { + agentColumnsToEncrypt := []encryption.Column{ + {Name: "username"}, + {Name: "password"}, + {Name: "aws_access_key"}, + {Name: "aws_secret_key"}, + {Name: "mongo_db_tls_options", CustomHandler: EncryptMongoDBOptionsHandler}, + {Name: "azure_options", CustomHandler: EncryptAzureOptionsHandler}, + {Name: "mysql_options", CustomHandler: EncryptMySQLOptionsHandler}, + {Name: "postgresql_options", CustomHandler: EncryptPostgreSQLOptionsHandler}, + {Name: "agent_password"}, + } + + itemsToEncrypt := []encryption.Table{ + { + Name: "agents", + Identifiers: []string{"agent_id"}, + Columns: agentColumnsToEncrypt, + }, + } + + if err := migrateDB(db, params, itemsToEncrypt); err != nil { return nil, err } + return db, nil } +// EncryptDB encrypts a set of columns in a specific database and table. +func EncryptDB(tx *reform.TX, params SetupDBParams, itemsToEncrypt []encryption.Table) error { + if len(itemsToEncrypt) == 0 { + return nil + } + + settings, err := GetSettings(tx) + if err != nil { + return err + } + alreadyEncrypted := make(map[string]bool) + for _, v := range settings.EncryptedItems { + alreadyEncrypted[v] = true + } + + notEncrypted := []encryption.Table{} + newlyEncrypted := []string{} + for _, table := range itemsToEncrypt { + columns := []encryption.Column{} + for _, column := range table.Columns { + dbTableColumn := fmt.Sprintf("%s.%s.%s", params.Name, table.Name, column.Name) + if alreadyEncrypted[dbTableColumn] { + continue + } + + columns = append(columns, column) + newlyEncrypted = append(newlyEncrypted, dbTableColumn) + } + if len(columns) == 0 { + continue + } + + table.Columns = columns + notEncrypted = append(notEncrypted, table) + } + + if len(notEncrypted) == 0 { + return nil + } + + err = encryption.EncryptItems(tx, notEncrypted) + if err != nil { + return err + } + _, err = UpdateSettings(tx, &ChangeSettingsParams{ + EncryptedItems: slices.Concat(settings.EncryptedItems, newlyEncrypted), + }) + if err != nil { + return err + } + + return nil +} + // checkVersion checks minimal required PostgreSQL server version. func checkVersion(ctx context.Context, db reform.DBTXContext) error { PGVersion, err := GetPostgreSQLVersion(ctx, db) @@ -1211,7 +1289,7 @@ func initWithRoot(params SetupDBParams) error { } // migrateDB runs PostgreSQL database migrations. -func migrateDB(db *reform.DB, params SetupDBParams) error { +func migrateDB(db *reform.DB, params SetupDBParams, itemsToEncrypt []encryption.Table) error { var currentVersion int errDB := db.QueryRow("SELECT id FROM schema_migrations ORDER BY id DESC LIMIT 1").Scan(¤tVersion) // undefined_table (see https://www.postgresql.org/docs/current/errcodes-appendix.html) @@ -1247,6 +1325,11 @@ func migrateDB(db *reform.DB, params SetupDBParams) error { } } + err := EncryptDB(tx, params, itemsToEncrypt) + if err != nil { + return err + } + if params.SetupFixtures == SkipFixtures { return nil } @@ -1260,14 +1343,16 @@ func migrateDB(db *reform.DB, params SetupDBParams) error { return err } - if err = setupFixture1(tx.Querier, params); err != nil { + err = setupPMMServerAgents(tx.Querier, params) + if err != nil { return err } + return nil }) } -func setupFixture1(q *reform.Querier, params SetupDBParams) error { +func setupPMMServerAgents(q *reform.Querier, params SetupDBParams) error { // create PMM Server Node and associated Agents node, err := createNodeWithID(q, PMMServerNodeID, GenericNodeType, &CreateNodeParams{ NodeName: "pmm-server", diff --git a/managed/models/database_test.go b/managed/models/database_test.go index 2b487629d5..efafd5e2f7 100644 --- a/managed/models/database_test.go +++ b/managed/models/database_test.go @@ -21,7 +21,6 @@ import ( "database/sql" "fmt" "testing" - "time" "github.com/AlekSi/pointer" "github.com/lib/pq" @@ -327,60 +326,6 @@ func TestDatabaseChecks(t *testing.T) { } func TestDatabaseMigrations(t *testing.T) { - t.Run("Update metrics resolutions", func(t *testing.T) { - sqlDB := testdb.Open(t, models.SkipFixtures, pointer.ToInt(9)) - defer sqlDB.Close() //nolint:errcheck - settings, err := models.GetSettings(sqlDB) - require.NoError(t, err) - metricsResolutions := models.MetricsResolutions{ - HR: 5 * time.Second, - MR: 5 * time.Second, - LR: 60 * time.Second, - } - settings.MetricsResolutions = metricsResolutions - err = models.SaveSettings(sqlDB, settings) - require.NoError(t, err) - - settings, err = models.GetSettings(sqlDB) - require.NoError(t, err) - require.Equal(t, metricsResolutions, settings.MetricsResolutions) - - testdb.SetupDB(t, sqlDB, models.SkipFixtures, pointer.ToInt(10)) - settings, err = models.GetSettings(sqlDB) - require.NoError(t, err) - require.Equal(t, models.MetricsResolutions{ - HR: 5 * time.Second, - MR: 10 * time.Second, - LR: 60 * time.Second, - }, settings.MetricsResolutions) - }) - t.Run("Shouldn' update metrics resolutions if it's already changed", func(t *testing.T) { - sqlDB := testdb.Open(t, models.SkipFixtures, pointer.ToInt(9)) - defer sqlDB.Close() //nolint:errcheck - settings, err := models.GetSettings(sqlDB) - require.NoError(t, err) - metricsResolutions := models.MetricsResolutions{ - HR: 1 * time.Second, - MR: 5 * time.Second, - LR: 60 * time.Second, - } - settings.MetricsResolutions = metricsResolutions - err = models.SaveSettings(sqlDB, settings) - require.NoError(t, err) - - settings, err = models.GetSettings(sqlDB) - require.NoError(t, err) - require.Equal(t, metricsResolutions, settings.MetricsResolutions) - - testdb.SetupDB(t, sqlDB, models.SkipFixtures, pointer.ToInt(10)) - settings, err = models.GetSettings(sqlDB) - require.NoError(t, err) - require.Equal(t, models.MetricsResolutions{ - HR: 1 * time.Second, - MR: 5 * time.Second, - LR: 60 * time.Second, - }, settings.MetricsResolutions) - }) t.Run("stats_collections field migration: string to string array", func(t *testing.T) { sqlDB := testdb.Open(t, models.SkipFixtures, pointer.ToInt(57)) defer sqlDB.Close() //nolint:errcheck diff --git a/managed/models/encryption_helpers.go b/managed/models/encryption_helpers.go new file mode 100644 index 0000000000..f3ae2e01a8 --- /dev/null +++ b/managed/models/encryption_helpers.go @@ -0,0 +1,296 @@ +// Copyright (C) 2023 Percona LLC +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package models + +import ( + "database/sql" + "encoding/json" + + "github.com/sirupsen/logrus" + + "github.com/percona/pmm/managed/utils/encryption" +) + +// EncryptAgent encrypt agent. +func EncryptAgent(agent Agent) Agent { + return agentEncryption(agent, encryption.Encrypt) +} + +// DecryptAgent decrypt agent. +func DecryptAgent(agent Agent) Agent { + return agentEncryption(agent, encryption.Decrypt) +} + +func agentEncryption(agent Agent, handler func(string) (string, error)) Agent { + if agent.Username != nil { + username, err := handler(*agent.Username) + if err != nil { + logrus.Warning(err) + } + agent.Username = &username + } + + if agent.Password != nil { + password, err := handler(*agent.Password) + if err != nil { + logrus.Warning(err) + } + agent.Password = &password + } + + if agent.AgentPassword != nil { + agentPassword, err := handler(*agent.AgentPassword) + if err != nil { + logrus.Warning(err) + } + agent.AgentPassword = &agentPassword + } + + if agent.AWSAccessKey != nil { + awsAccessKey, err := handler(*agent.AWSAccessKey) + if err != nil { + logrus.Warning(err) + } + agent.AWSAccessKey = &awsAccessKey + } + + if agent.AWSSecretKey != nil { + awsSecretKey, err := handler(*agent.AWSSecretKey) + if err != nil { + logrus.Warning(err) + } + agent.AWSSecretKey = &awsSecretKey + } + + var err error + if agent.MySQLOptions != nil { + agent.MySQLOptions.TLSCert, err = handler(agent.MySQLOptions.TLSCert) + if err != nil { + logrus.Warning(err) + } + agent.MySQLOptions.TLSKey, err = handler(agent.MySQLOptions.TLSKey) + if err != nil { + logrus.Warning(err) + } + } + + if agent.PostgreSQLOptions != nil { + agent.PostgreSQLOptions.SSLCert, err = handler(agent.PostgreSQLOptions.SSLCert) + if err != nil { + logrus.Warning(err) + } + agent.PostgreSQLOptions.SSLKey, err = handler(agent.PostgreSQLOptions.SSLKey) + if err != nil { + logrus.Warning(err) + } + } + + if agent.MongoDBOptions != nil { + agent.MongoDBOptions.TLSCertificateKey, err = handler(agent.MongoDBOptions.TLSCertificateKey) + if err != nil { + logrus.Warning(err) + } + agent.MongoDBOptions.TLSCertificateKeyFilePassword, err = handler(agent.MongoDBOptions.TLSCertificateKeyFilePassword) + if err != nil { + logrus.Warning(err) + } + } + + if agent.AzureOptions != nil { + agent.AzureOptions.ClientID, err = handler(agent.AzureOptions.ClientID) + if err != nil { + logrus.Warning(err) + } + agent.AzureOptions.ClientSecret, err = handler(agent.AzureOptions.ClientSecret) + if err != nil { + logrus.Warning(err) + } + agent.AzureOptions.SubscriptionID, err = handler(agent.AzureOptions.SubscriptionID) + if err != nil { + logrus.Warning(err) + } + agent.AzureOptions.TenantID, err = handler(agent.AzureOptions.TenantID) + if err != nil { + logrus.Warning(err) + } + } + + return agent +} + +// EncryptMySQLOptionsHandler returns encrypted MySQL Options. +func EncryptMySQLOptionsHandler(e *encryption.Encryption, val any) (any, error) { + return mySQLOptionsHandler(val, e.Encrypt) +} + +// DecryptMySQLOptionsHandler returns decrypted MySQL Options. +func DecryptMySQLOptionsHandler(e *encryption.Encryption, val any) (any, error) { + return mySQLOptionsHandler(val, e.Decrypt) +} + +func mySQLOptionsHandler(val any, handler func(string) (string, error)) (any, error) { + o := MySQLOptions{} + value := val.(*sql.NullString) //nolint:forcetypeassert + if !value.Valid { + return sql.NullString{}, nil + } + + err := json.Unmarshal([]byte(value.String), &o) + if err != nil { + return nil, err + } + + o.TLSCert, err = handler(o.TLSCert) + if err != nil { + return nil, err + } + o.TLSKey, err = handler(o.TLSKey) + if err != nil { + return nil, err + } + + res, err := json.Marshal(o) + if err != nil { + return nil, err + } + + return res, nil +} + +// EncryptPostgreSQLOptionsHandler returns encrypted PostgreSQL Options. +func EncryptPostgreSQLOptionsHandler(e *encryption.Encryption, val any) (any, error) { + return postgreSQLOptionsHandler(val, e.Encrypt) +} + +// DecryptPostgreSQLOptionsHandler returns decrypted PostgreSQL Options. +func DecryptPostgreSQLOptionsHandler(e *encryption.Encryption, val any) (any, error) { + return postgreSQLOptionsHandler(val, e.Decrypt) +} + +func postgreSQLOptionsHandler(val any, handler func(string) (string, error)) (any, error) { + o := PostgreSQLOptions{} + value := val.(*sql.NullString) //nolint:forcetypeassert + if !value.Valid { + return sql.NullString{}, nil + } + + err := json.Unmarshal([]byte(value.String), &o) + if err != nil { + return nil, err + } + + o.SSLCert, err = handler(o.SSLCert) + if err != nil { + return nil, err + } + o.SSLKey, err = handler(o.SSLKey) + if err != nil { + return nil, err + } + + res, err := json.Marshal(o) + if err != nil { + return nil, err + } + + return res, nil +} + +// EncryptMongoDBOptionsHandler returns encrypted MongoDB Options. +func EncryptMongoDBOptionsHandler(e *encryption.Encryption, val any) (any, error) { + return mongoDBOptionsHandler(val, e.Encrypt) +} + +// DecryptMongoDBOptionsHandler returns decrypted MongoDB Options. +func DecryptMongoDBOptionsHandler(e *encryption.Encryption, val any) (any, error) { + return mongoDBOptionsHandler(val, e.Decrypt) +} + +func mongoDBOptionsHandler(val any, handler func(string) (string, error)) (any, error) { + o := MongoDBOptions{} + value := val.(*sql.NullString) //nolint:forcetypeassert + if !value.Valid { + return sql.NullString{}, nil + } + + err := json.Unmarshal([]byte(value.String), &o) + if err != nil { + return nil, err + } + + o.TLSCertificateKey, err = handler(o.TLSCertificateKey) + if err != nil { + return nil, err + } + o.TLSCertificateKeyFilePassword, err = handler(o.TLSCertificateKeyFilePassword) + if err != nil { + return nil, err + } + + res, err := json.Marshal(o) + if err != nil { + return nil, err + } + + return res, nil +} + +// EncryptAzureOptionsHandler returns encrypted Azure Options. +func EncryptAzureOptionsHandler(e *encryption.Encryption, val any) (any, error) { + return azureOptionsHandler(val, e.Encrypt) +} + +// DecryptAzureOptionsHandler returns decrypted Azure Options. +func DecryptAzureOptionsHandler(e *encryption.Encryption, val any) (any, error) { + return azureOptionsHandler(val, e.Decrypt) +} + +func azureOptionsHandler(val any, handler func(string) (string, error)) (any, error) { + o := AzureOptions{} + value := val.(*sql.NullString) //nolint:forcetypeassert + if !value.Valid { + return sql.NullString{}, nil + } + + err := json.Unmarshal([]byte(value.String), &o) + if err != nil { + return nil, err + } + + o.ClientID, err = handler(o.ClientID) + if err != nil { + return nil, err + } + o.ClientSecret, err = handler(o.ClientSecret) + if err != nil { + return nil, err + } + o.SubscriptionID, err = handler(o.SubscriptionID) + if err != nil { + return nil, err + } + o.TenantID, err = handler(o.TenantID) + if err != nil { + return nil, err + } + + res, err := json.Marshal(o) + if err != nil { + return nil, err + } + + return res, nil +} diff --git a/managed/models/settings.go b/managed/models/settings.go index cf86398765..1b12259c23 100644 --- a/managed/models/settings.go +++ b/managed/models/settings.go @@ -109,6 +109,9 @@ type Settings struct { // Enabled is true if access control is enabled. Enabled *bool `json:"enabled"` } `json:"access_control"` + + // Contains all encrypted tables in format 'db.table.column'. + EncryptedItems []string `json:"encrypted_items"` } // IsAlertingEnabled returns true if alerting is enabled. diff --git a/managed/models/settings_helpers.go b/managed/models/settings_helpers.go index 351202dd04..fb5125f336 100644 --- a/managed/models/settings_helpers.go +++ b/managed/models/settings_helpers.go @@ -92,6 +92,9 @@ type ChangeSettingsParams struct { // DefaultRoleID sets a default role to be assigned to new users. DefaultRoleID *int + + // List of items in format 'db.table.column' to be encrypted. + EncryptedItems []string } // SetPMMServerID should be run on start up to generate unique PMM Server ID. @@ -223,6 +226,9 @@ func UpdateSettings(q reform.DBTX, params *ChangeSettingsParams) (*Settings, err settings.DefaultRoleID = *params.DefaultRoleID } + if len(params.EncryptedItems) != 0 { + settings.EncryptedItems = params.EncryptedItems + } err = SaveSettings(q, settings) if err != nil { return nil, err diff --git a/managed/models/settings_helpers_test.go b/managed/models/settings_helpers_test.go index bd92db48cd..14112c30cd 100644 --- a/managed/models/settings_helpers_test.go +++ b/managed/models/settings_helpers_test.go @@ -37,6 +37,7 @@ func TestSettings(t *testing.T) { t.Run("Defaults", func(t *testing.T) { actual, err := models.GetSettings(sqlDB) require.NoError(t, err) + require.NotEmpty(t, actual.EncryptedItems) expected := &models.Settings{ MetricsResolutions: models.MetricsResolutions{ HR: 5 * time.Second, @@ -52,7 +53,8 @@ func TestSettings(t *testing.T) { FrequentInterval: 4 * time.Hour, }, }, - DefaultRoleID: 1, + DefaultRoleID: 1, + EncryptedItems: actual.EncryptedItems, } assert.Equal(t, expected, actual) }) diff --git a/managed/services/agents/agents.go b/managed/services/agents/agents.go index 71ffbefadb..451b34f09b 100644 --- a/managed/services/agents/agents.go +++ b/managed/services/agents/agents.go @@ -109,6 +109,25 @@ func redactWords(agent *models.Agent) []string { words = append(words, s) } } + if agent.MySQLOptions != nil { + if s := agent.MySQLOptions.TLSKey; s != "" { + words = append(words, s) + } + } + if agent.PostgreSQLOptions != nil { + if s := agent.PostgreSQLOptions.SSLKey; s != "" { + words = append(words, s) + } + } + if agent.MongoDBOptions != nil { + if s := agent.MongoDBOptions.TLSCertificateKey; s != "" { + words = append(words, s) + } + if s := agent.MongoDBOptions.TLSCertificateKeyFilePassword; s != "" { + words = append(words, s) + } + } + return words } diff --git a/managed/services/agents/connection_checker.go b/managed/services/agents/connection_checker.go index 447dbb9ee8..93e5c92a73 100644 --- a/managed/services/agents/connection_checker.go +++ b/managed/services/agents/connection_checker.go @@ -86,9 +86,9 @@ func (c *ConnectionChecker) CheckConnectionToService(ctx context.Context, q *ref return err } - var sanitizedDSN string + sanitizedDSN := request.Dsn for _, word := range redactWords(agent) { - sanitizedDSN = strings.ReplaceAll(request.Dsn, word, "****") + sanitizedDSN = strings.ReplaceAll(sanitizedDSN, word, "****") } l.Infof("CheckConnectionRequest: type: %s, DSN: %s timeout: %s.", request.Type, sanitizedDSN, request.Timeout) diff --git a/managed/services/agents/mysql_test.go b/managed/services/agents/mysql_test.go index 5a8a3b7b41..6cee95abeb 100644 --- a/managed/services/agents/mysql_test.go +++ b/managed/services/agents/mysql_test.go @@ -199,7 +199,7 @@ func TestMySQLdExporterConfigTablestatsGroupDisabled(t *testing.T) { "DATA_SOURCE_NAME=username:s3cur3 p@$$w0r4.@tcp(1.2.3.4:3306)/?timeout=1s&tls=custom", "HTTP_AUTH=pmm:agent-id", }, - RedactWords: []string{"s3cur3 p@$$w0r4."}, + RedactWords: []string{"s3cur3 p@$$w0r4.", "content-of-tls-key"}, TextFiles: map[string]string{ "tlsCa": "content-of-tls-ca", "tlsCert": "content-of-tls-cert", diff --git a/managed/services/agents/service_info_broker.go b/managed/services/agents/service_info_broker.go index e8aa3e88f7..6802de66f8 100644 --- a/managed/services/agents/service_info_broker.go +++ b/managed/services/agents/service_info_broker.go @@ -154,9 +154,9 @@ func (c *ServiceInfoBroker) GetInfoFromService(ctx context.Context, q *reform.Qu return err } - var sanitizedDSN string + sanitizedDSN := request.Dsn for _, word := range redactWords(agent) { - sanitizedDSN = strings.ReplaceAll(request.Dsn, word, "****") + sanitizedDSN = strings.ReplaceAll(sanitizedDSN, word, "****") } l.Infof("ServiceInfoRequest: type: %s, DSN: %s timeout: %s.", request.Type, sanitizedDSN, request.Timeout) @@ -182,9 +182,11 @@ func (c *ServiceInfoBroker) GetInfoFromService(ctx context.Context, q *reform.Qu case models.MySQLServiceType: agent.TableCount = &sInfo.TableCount l.Debugf("Updating table count: %d.", sInfo.TableCount) - if err = q.Update(agent); err != nil { + encryptedAgent := models.EncryptAgent(*agent) + if err = q.Update(&encryptedAgent); err != nil { return errors.Wrap(err, "failed to update table count") } + return updateServiceVersion(ctx, q, resp, service) case models.PostgreSQLServiceType: if agent.PostgreSQLOptions == nil { @@ -206,9 +208,11 @@ func (c *ServiceInfoBroker) GetInfoFromService(ctx context.Context, q *reform.Qu agent.PostgreSQLOptions.DatabaseCount = int32(databaseCount - excludedDatabaseCount) l.Debugf("Updating PostgreSQL options, database count: %d.", agent.PostgreSQLOptions.DatabaseCount) - if err = q.Update(agent); err != nil { + encryptedAgent := models.EncryptAgent(*agent) + if err = q.Update(&encryptedAgent); err != nil { return errors.Wrap(err, "failed to update database count") } + return updateServiceVersion(ctx, q, resp, service) case models.MongoDBServiceType, models.ProxySQLServiceType: diff --git a/managed/services/management/agent.go b/managed/services/management/agent.go index fc13caeb90..c2e74b6960 100644 --- a/managed/services/management/agent.go +++ b/managed/services/management/agent.go @@ -130,9 +130,9 @@ func (s *ManagementService) agentToAPI(agent *models.Agent) (*managementv1.Unive Disabled: agent.Disabled, DisabledCollectors: agent.DisabledCollectors, IsConnected: s.r.IsConnected(agent.AgentID), - IsAgentPasswordSet: agent.AgentPassword != nil, - IsAwsSecretKeySet: agent.AWSSecretKey != nil, - IsPasswordSet: agent.Password != nil, + IsAgentPasswordSet: pointer.GetString(agent.AgentPassword) != "", + IsAwsSecretKeySet: pointer.GetString(agent.AWSSecretKey) != "", + IsPasswordSet: pointer.GetString(agent.Password) != "", ListenPort: uint32(pointer.GetUint16(agent.ListenPort)), LogLevel: pointer.GetString(agent.LogLevel), MaxQueryLength: agent.MaxQueryLength, diff --git a/managed/services/management/mongodb.go b/managed/services/management/mongodb.go index 6ce4d94072..fe7782c194 100644 --- a/managed/services/management/mongodb.go +++ b/managed/services/management/mongodb.go @@ -57,8 +57,6 @@ func (s *ManagementService) addMongoDB(ctx context.Context, req *managementv1.Ad } mongodb.Service = invService.(*inventoryv1.MongoDBService) //nolint:forcetypeassert - mongoDBOptions := models.MongoDBOptionsFromRequest(req) - req.MetricsMode, err = supportedMetricsMode(tx.Querier, req.MetricsMode, req.PmmAgentId) if err != nil { return err @@ -72,7 +70,7 @@ func (s *ManagementService) addMongoDB(ctx context.Context, req *managementv1.Ad AgentPassword: req.AgentPassword, TLS: req.Tls, TLSSkipVerify: req.TlsSkipVerify, - MongoDBOptions: mongoDBOptions, + MongoDBOptions: models.MongoDBOptionsFromRequest(req), PushMetrics: isPushMode(req.MetricsMode), ExposeExporter: req.ExposeExporter, DisableCollectors: req.DisableCollectors, @@ -106,7 +104,7 @@ func (s *ManagementService) addMongoDB(ctx context.Context, req *managementv1.Ad Password: req.Password, TLS: req.Tls, TLSSkipVerify: req.TlsSkipVerify, - MongoDBOptions: mongoDBOptions, + MongoDBOptions: models.MongoDBOptionsFromRequest(req), MaxQueryLength: req.MaxQueryLength, LogLevel: services.SpecifyLogLevel(req.LogLevel, inventoryv1.LogLevel_LOG_LEVEL_FATAL), // TODO QueryExamplesDisabled https://jira.percona.com/browse/PMM-7860 diff --git a/managed/services/management/postgresql.go b/managed/services/management/postgresql.go index 94ccff5845..e1f1a473da 100644 --- a/managed/services/management/postgresql.go +++ b/managed/services/management/postgresql.go @@ -64,7 +64,6 @@ func (s *ManagementService) addPostgreSQL(ctx context.Context, req *managementv1 return err } - options := models.PostgreSQLOptionsFromRequest(req) row, err := models.CreateAgent(tx.Querier, models.PostgresExporterType, &models.CreateAgentParams{ PMMAgentID: req.PmmAgentId, ServiceID: service.ServiceID, @@ -76,7 +75,7 @@ func (s *ManagementService) addPostgreSQL(ctx context.Context, req *managementv1 PushMetrics: isPushMode(req.MetricsMode), ExposeExporter: req.ExposeExporter, DisableCollectors: req.DisableCollectors, - PostgreSQLOptions: options, + PostgreSQLOptions: models.PostgreSQLOptionsFromRequest(req), LogLevel: services.SpecifyLogLevel(req.LogLevel, inventoryv1.LogLevel_LOG_LEVEL_ERROR), }) if err != nil { @@ -117,7 +116,7 @@ func (s *ManagementService) addPostgreSQL(ctx context.Context, req *managementv1 CommentsParsingDisabled: req.DisableCommentsParsing, TLS: req.Tls, TLSSkipVerify: req.TlsSkipVerify, - PostgreSQLOptions: options, + PostgreSQLOptions: models.PostgreSQLOptionsFromRequest(req), LogLevel: services.SpecifyLogLevel(req.LogLevel, inventoryv1.LogLevel_LOG_LEVEL_FATAL), }) if err != nil { @@ -142,7 +141,7 @@ func (s *ManagementService) addPostgreSQL(ctx context.Context, req *managementv1 CommentsParsingDisabled: req.DisableCommentsParsing, TLS: req.Tls, TLSSkipVerify: req.TlsSkipVerify, - PostgreSQLOptions: options, + PostgreSQLOptions: models.PostgreSQLOptionsFromRequest(req), LogLevel: services.SpecifyLogLevel(req.LogLevel, inventoryv1.LogLevel_LOG_LEVEL_FATAL), }) if err != nil { diff --git a/managed/utils/encryption/encryption.go b/managed/utils/encryption/encryption.go new file mode 100644 index 0000000000..a7cba048df --- /dev/null +++ b/managed/utils/encryption/encryption.go @@ -0,0 +1,205 @@ +// Copyright (C) 2023 Percona LLC +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +// Package encryption contains functions to encrypt/decrypt items or DB. +package encryption + +import ( + "encoding/base64" + "os" + "slices" + + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "gopkg.in/reform.v1" +) + +// DefaultEncryptionKeyPath contains default PMM encryption key path. +const DefaultEncryptionKeyPath = "/srv/pmm-encryption.key" + +var ( + // ErrEncryptionNotInitialized is error in case of encryption is not initialized. + ErrEncryptionNotInitialized = errors.New("encryption is not initialized") + // DefaultEncryption is the default implementation of encryption. + DefaultEncryption = New(DefaultEncryptionKeyPath) +) + +// New creates an encryption; if key on path doesn't exist, it will be generated. +func New(keyPath string) *Encryption { + e := &Encryption{} + customKeyPath := os.Getenv("PMM_ENCRYPTION_KEY_PATH") + if customKeyPath != "" { + e.Path = customKeyPath + } else { + e.Path = keyPath + } + + bytes, err := os.ReadFile(e.Path) + switch { + case os.IsNotExist(err): + err = e.generateKey() + if err != nil { + logrus.Panicf("Encryption: %v", err) + } + case err != nil: + logrus.Panicf("Encryption: %v", err) + default: + e.Key = string(bytes) + } + + primitive, err := e.getPrimitive() + if err != nil { + logrus.Panicf("Encryption: %v", err) + } + e.Primitive = primitive + + return e +} + +// Encrypt is a wrapper around DefaultEncryption.Encrypt. +func Encrypt(secret string) (string, error) { + return DefaultEncryption.Encrypt(secret) +} + +// Encrypt returns input string encrypted. +func (e *Encryption) Encrypt(secret string) (string, error) { + if e == nil || e.Primitive == nil { + return secret, ErrEncryptionNotInitialized + } + if secret == "" { + return secret, nil + } + cipherText, err := e.Primitive.Encrypt([]byte(secret), []byte("")) + if err != nil { + return secret, err + } + + return base64.StdEncoding.EncodeToString(cipherText), nil +} + +// EncryptItems is a wrapper around DefaultEncryption.EncryptItems. +func EncryptItems(tx *reform.TX, tables []Table) error { + return DefaultEncryption.EncryptItems(tx, tables) +} + +// EncryptItems will encrypt all columns provided in DB connection. +func (e *Encryption) EncryptItems(tx *reform.TX, tables []Table) error { + if len(tables) == 0 { + return nil + } + + for _, table := range tables { + res, err := table.read(tx) + if err != nil { + return err + } + + for k, v := range res.SetValues { + for i, val := range v { + var encrypted any + var err error + switch table.Columns[i].CustomHandler { + case nil: + encrypted, err = encryptColumnStringHandler(e, val) + default: + encrypted, err = table.Columns[i].CustomHandler(e, val) + } + + if err != nil { + return err + } + res.SetValues[k][i] = encrypted + } + data := slices.Concat([]any{}, v) + data = slices.Concat(data, res.WhereValues[k]) + _, err := tx.Exec(res.Query, data...) + if err != nil { + return err + } + } + } + + return nil +} + +// Decrypt is wrapper around DefaultEncryption.Decrypt. +func Decrypt(cipherText string) (string, error) { + return DefaultEncryption.Decrypt(cipherText) +} + +// Decrypt returns input string decrypted. +func (e *Encryption) Decrypt(cipherText string) (string, error) { + if e == nil || e.Primitive == nil { + return cipherText, ErrEncryptionNotInitialized + } + if cipherText == "" { + return cipherText, nil + } + decoded, err := base64.StdEncoding.DecodeString(cipherText) + if err != nil { + return cipherText, err + } + secret, err := e.Primitive.Decrypt(decoded, []byte("")) + if err != nil { + return cipherText, err + } + + return string(secret), nil +} + +// DecryptItems is wrapper around DefaultEncryption.DecryptItems. +func DecryptItems(tx *reform.TX, tables []Table) error { + return DefaultEncryption.DecryptItems(tx, tables) +} + +// DecryptItems will decrypt all columns provided in DB connection. +func (e *Encryption) DecryptItems(tx *reform.TX, tables []Table) error { + if len(tables) == 0 { + return nil + } + + for _, table := range tables { + res, err := table.read(tx) + if err != nil { + return err + } + + for k, v := range res.SetValues { + for i, val := range v { + var decrypted any + var err error + switch table.Columns[i].CustomHandler { + case nil: + decrypted, err = decryptColumnStringHandler(e, val) + default: + decrypted, err = table.Columns[i].CustomHandler(e, val) + } + + if err != nil { + return err + } + res.SetValues[k][i] = decrypted + } + data := slices.Concat([]any{}, v) + data = slices.Concat(data, res.WhereValues[k]) + _, err := tx.Exec(res.Query, data...) + if err != nil { + return err + } + } + } + + return nil +} diff --git a/managed/utils/encryption/helpers.go b/managed/utils/encryption/helpers.go new file mode 100644 index 0000000000..c8c11ab4e3 --- /dev/null +++ b/managed/utils/encryption/helpers.go @@ -0,0 +1,177 @@ +// Copyright (C) 2023 Percona LLC +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package encryption + +import ( + "bytes" + "database/sql" + "encoding/base64" + "fmt" + "os" + "slices" + "strings" + + "github.com/google/tink/go/aead" + "github.com/google/tink/go/insecurecleartextkeyset" + "github.com/google/tink/go/keyset" + "github.com/google/tink/go/tink" + "gopkg.in/reform.v1" +) + +func prepareRowPointers(rows *sql.Rows) ([]any, error) { + columnTypes, err := rows.ColumnTypes() + if err != nil { + return nil, err + } + columns := make(map[string]string) + for _, columnType := range columnTypes { + columns[columnType.Name()] = columnType.DatabaseTypeName() + } + + row := []any{} + for _, t := range columns { + switch t { + case "VARCHAR", "JSONB": + row = append(row, &sql.NullString{}) + default: + return nil, fmt.Errorf("unsupported identificator type %s", t) + } + } + + return row, nil +} + +func encryptColumnStringHandler(e *Encryption, val any) (any, error) { + value := val.(*sql.NullString) //nolint:forcetypeassert + if !value.Valid { + return sql.NullString{}, nil + } + + encrypted, err := e.Encrypt(value.String) + if err != nil { + return nil, err + } + + return encrypted, nil +} + +func decryptColumnStringHandler(e *Encryption, val any) (any, error) { + value := val.(*sql.NullString) //nolint:forcetypeassert + if !value.Valid { + return nil, nil //nolint:nilnil + } + + decrypted, err := e.Decrypt(value.String) + if err != nil { + return nil, err + } + + return decrypted, nil +} + +func (e *Encryption) getPrimitive() (tink.AEAD, error) { //nolint:ireturn + serializedKeyset, err := base64.StdEncoding.DecodeString(e.Key) + if err != nil { + return nil, err + } + + binaryReader := keyset.NewBinaryReader(bytes.NewBuffer(serializedKeyset)) + parsedHandle, err := insecurecleartextkeyset.Read(binaryReader) + if err != nil { + return nil, err + } + + return aead.New(parsedHandle) +} + +func (e *Encryption) generateKey() error { + handle, err := keyset.NewHandle(aead.AES256GCMKeyTemplate()) + if err != nil { + return err + } + + buff := &bytes.Buffer{} + err = insecurecleartextkeyset.Write(handle, keyset.NewBinaryWriter(buff)) + if err != nil { + return err + } + e.Key = base64.StdEncoding.EncodeToString(buff.Bytes()) + + return e.saveKeyToFile() +} + +func (e *Encryption) saveKeyToFile() error { + return os.WriteFile(e.Path, []byte(e.Key), 0o644) //nolint:gosec +} + +func (table Table) columnsList() []string { + res := []string{} + for _, c := range table.Columns { + res = append(res, c.Name) + } + + return res +} + +func (table Table) read(tx *reform.TX) (*QueryValues, error) { + what := slices.Concat(table.Identifiers, table.columnsList()) + query := fmt.Sprintf("SELECT %s FROM %s", strings.Join(what, ", "), table.Name) + rows, err := tx.Query(query) + if err != nil { + return nil, err + } + + q := &QueryValues{} + for rows.Next() { + row, err := prepareRowPointers(rows) + if err != nil { + return nil, err + } + err = rows.Scan(row...) + if err != nil { + return nil, err + } + + i := 1 + set := []string{} + setValues := []any{} + for k, v := range row[len(table.Identifiers):] { + set = append(set, fmt.Sprintf("%s = $%d", table.Columns[k].Name, i)) + setValues = append(setValues, v) + i++ + } + setSQL := fmt.Sprintf("SET %s", strings.Join(set, ", ")) + q.SetValues = append(q.SetValues, setValues) + + where := []string{} + whereValues := []any{} + for k, id := range table.Identifiers { + where = append(where, fmt.Sprintf("%s = $%d", id, i)) + whereValues = append(whereValues, row[k]) + i++ + } + whereSQL := "WHERE " + strings.Join(where, " AND ") + q.WhereValues = append(q.WhereValues, whereValues) + + q.Query = fmt.Sprintf("UPDATE %s %s %s", table.Name, setSQL, whereSQL) + } + err = rows.Close() //nolint:sqlclosecheck + if err != nil { + return nil, err + } + + return q, nil +} diff --git a/managed/utils/encryption/models.go b/managed/utils/encryption/models.go new file mode 100644 index 0000000000..257b49b1de --- /dev/null +++ b/managed/utils/encryption/models.go @@ -0,0 +1,45 @@ +// Copyright (C) 2023 Percona LLC +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package encryption + +import "github.com/google/tink/go/tink" + +// Encryption contains fields required for encryption. +type Encryption struct { + Path string + Key string + Primitive tink.AEAD +} + +// Table represents table name, it's identifiers and columns to be encrypted/decrypted. +type Table struct { + Name string + Identifiers []string + Columns []Column +} + +// Column represents column name and column's custom handler (if needed). +type Column struct { + Name string + CustomHandler func(e *Encryption, val any) (any, error) +} + +// QueryValues represents query to update row after encrypt/decrypt. +type QueryValues struct { + Query string + SetValues [][]any + WhereValues [][]any +} diff --git a/managed/utils/testdb/db.go b/managed/utils/testdb/db.go index d5707e367a..2f6ee84cc7 100644 --- a/managed/utils/testdb/db.go +++ b/managed/utils/testdb/db.go @@ -70,7 +70,8 @@ func Open(tb testing.TB, setupFixtures models.SetupFixturesMode, migrationVersio // Please use Open method to recreate DB for each test if you don't need to control migrations. func SetupDB(tb testing.TB, db *sql.DB, setupFixtures models.SetupFixturesMode, migrationVersion *int) { tb.Helper() - _, err := models.SetupDB(context.TODO(), db, models.SetupDBParams{ + ctx := context.TODO() + params := models.SetupDBParams{ // Uncomment to see all setup queries: // Logf: tb.Logf, Address: models.DefaultPostgreSQLAddr, @@ -79,7 +80,9 @@ func SetupDB(tb testing.TB, db *sql.DB, setupFixtures models.SetupFixturesMode, Password: password, SetupFixtures: setupFixtures, MigrationVersion: migrationVersion, - }) + } + + _, err := models.SetupDB(ctx, db, params) require.NoError(tb, err) }