diff --git a/assets/tests/Makefile b/assets/tests/Makefile index 417d331..c6c6ca5 100644 --- a/assets/tests/Makefile +++ b/assets/tests/Makefile @@ -1,11 +1,7 @@ include .env -BUILD_FLAGS=-ldflags="-s -w" -OUT_DIR=dist -OUT_PREFIX=nuvola - all: docker compose up -d pipenv install pipenv run tflocal apply -auto-approve - pipenv run ../../nuvola dump --output-dir ./ --aws-endpoint-url http://localhost:4566 \ No newline at end of file + pipenv run ../../nuvola dump --output-dir ./ --aws-endpoint-url http://localhost:4566 --verbose --debug \ No newline at end of file diff --git a/cmd/assess.go b/cmd/assess.go index 60c77fc..6d72f1b 100644 --- a/cmd/assess.go +++ b/cmd/assess.go @@ -63,7 +63,7 @@ func importZipFile(connector *connector.StorageConnector, zipfile string) { continue } if err := processZipFile(connector, f); err != nil { - logging.HandleError(err, "Assess", "Processing ZIP file") + logger.Error("Processing ZIP file", "err", err) } } } diff --git a/cmd/dump.go b/cmd/dump.go index 9bee81c..217889f 100644 --- a/cmd/dump.go +++ b/cmd/dump.go @@ -3,7 +3,6 @@ package cmd import ( "encoding/json" "fmt" - "reflect" "time" "github.com/primait/nuvola/pkg/connector" @@ -12,80 +11,94 @@ import ( "github.com/spf13/cobra" ) -var AWSResults = map[string]interface{}{ - "Whoami": nil, - "CredentialReport": nil, - "Groups": nil, - "Users": nil, - "Roles": nil, - "Buckets": nil, - "EC2s": nil, - "VPCs": nil, - "Lambdas": nil, - "RDS": nil, - "DynamoDBs": nil, - "RedshiftDBs": nil, -} +var ( + AWSResults = map[string]interface{}{ + "Whoami": nil, + "CredentialReport": nil, + "Groups": nil, + "Users": nil, + "Roles": nil, + "Buckets": nil, + "EC2s": nil, + "VPCs": nil, + "Lambdas": nil, + "RDS": nil, + "DynamoDBs": nil, + "RedshiftDBs": nil, + } + dumpCmd = &cobra.Command{ + Use: "dump", + Short: "Dump AWS resources and policies information and store them in Neo4j", + Run: runDumpCmd, + } +) -var dumpCmd = &cobra.Command{ - Use: "dump", - Short: "Dump AWS resources and policies information and store them in Neo4j", - Run: func(cmd *cobra.Command, args []string) { - startTime := time.Now() - if cmd.Flags().Changed(flagVerbose) { - logger.SetVerboseLevel() - } - if cmd.Flags().Changed(flagDebug) { - logger.SetDebugLevel() - } +func runDumpCmd(cmd *cobra.Command, args []string) { + startTime := time.Now() - cloudConnector, err := connector.NewCloudConnector(awsProfile, awsEndpointUrl) - if err != nil { - logger.Error(err.Error()) - } + if cmd.Flags().Changed(flagVerbose) { + logger.SetVerboseLevel() + } + if cmd.Flags().Changed(flagDebug) { + logger.SetDebugLevel() + } - if dumpOnly { - dumpData(nil, cloudConnector) - } else { - storageConnector := connector.NewStorageConnector().FlushAll() - dumpData(storageConnector, cloudConnector) - } - saveResults(awsProfile, outputDirectory, outputFormat) - logger.Info("Execution Time", "seconds", time.Since(startTime)) - }, + cloudConnector, err := connector.NewCloudConnector(awsProfile, awsEndpointUrl) + if err != nil { + logger.Error("Failed to create cloud connector", "err", err) + return + } + + if dumpOnly { + dumpData(nil, cloudConnector) + } else { + storageConnector := connector.NewStorageConnector().FlushAll() + dumpData(storageConnector, cloudConnector) + } + + saveResults(awsProfile, outputDirectory, outputFormat) + logger.Info("Execution Time", "seconds", time.Since(startTime)) } func dumpData(storageConnector *connector.StorageConnector, cloudConnector *connector.CloudConnector) { dataChan := make(chan map[string]interface{}) - go cloudConnector.DumpAll("aws", dataChan) - for { - a, ok := <-dataChan // receive data step by step and import it to Neo4j - if !ok { - break - } - v := reflect.ValueOf(a) - mapKey := v.MapKeys()[0].Interface().(string) - obj, err := json.Marshal(a[mapKey]) + go func() { + cloudConnector.DumpAll("aws", dataChan) + defer close(dataChan) + }() + + for data := range dataChan { + processData(data, storageConnector) + } +} + +func processData(data map[string]interface{}, storageConnector *connector.StorageConnector) { + for key, value := range data { + obj, err := json.Marshal(value) if err != nil { - logger.Error("DumpData: error marshalling output", err) + logger.Error("Error marshalling output", "err", err) + continue + } + if storageConnector != nil { + storageConnector.ImportResults(key, obj) } - storageConnector.ImportResults(mapKey, obj) - AWSResults[mapKey] = a[mapKey] + AWSResults[key] = value } } -func saveResults(awsProfile string, outputDir string, outputFormat string) { +func saveResults(awsProfile, outputDir, outputFormat string) { if awsProfile == "" { awsProfile = "default" } if outputFormat == "zip" { - zip.Zip(outputDir, awsProfile, &AWSResults) + zip.Zip(outputDir, awsProfile, AWSResults) } today := time.Now().Format("20060102") for key, value := range AWSResults { if outputFormat == "json" { - files.PrettyJSONToFile(outputDir, fmt.Sprintf("%s_%s.json", key, today), value) + filename := fmt.Sprintf("%s_%s.json", key, today) + files.PrettyJSONToFile(outputDir, filename, value) } } } diff --git a/pkg/connector/cloud_connector.go b/pkg/connector/cloud_connector.go index 1c8c176..7745b21 100644 --- a/pkg/connector/cloud_connector.go +++ b/pkg/connector/cloud_connector.go @@ -26,56 +26,30 @@ func SetActions() { func (cc *CloudConnector) DumpAll(cloudprovider string, c chan map[string]interface{}) { switch strings.ToLower(cloudprovider) { case "aws": - whoami := cc.AWSConfig.DumpWhoami() - c <- map[string]interface{}{ - "Whoami": whoami, - } - credentialReport := cc.AWSConfig.DumpCredentialReport() - c <- map[string]interface{}{ - "CredentialReport": credentialReport, - } - groups := cc.AWSConfig.DumpIAMGroups() - c <- map[string]interface{}{ - "Groups": groups, - } - users := cc.AWSConfig.DumpIAMUsers() - c <- map[string]interface{}{ - "Users": users, - } - roles := cc.AWSConfig.DumpIAMRoles() - c <- map[string]interface{}{ - "Roles": roles, - } - buckets := cc.AWSConfig.DumpBuckets() - c <- map[string]interface{}{ - "Buckets": buckets, - } - ec2 := cc.AWSConfig.DumpEC2Instances() - c <- map[string]interface{}{ - "EC2s": ec2, - } - vpc := cc.AWSConfig.DumpVpcs() - c <- map[string]interface{}{ - "VPCs": vpc, - } - lambda := cc.AWSConfig.DumpLambdas() - c <- map[string]interface{}{ - "Lambdas": lambda, - } - rds := cc.AWSConfig.DumpRDS() - c <- map[string]interface{}{ - "RDS": rds, - } - dynamodb := cc.AWSConfig.DumpDynamoDBs() - c <- map[string]interface{}{ - "DynamoDBs": dynamodb, - } - redshift := cc.AWSConfig.DumpRedshiftDBs() - c <- map[string]interface{}{ - "RedshiftDBs": redshift, - } - close(c) + cc.dumpAWSData(c) default: + cc.logger.Error("Unsupported cloud provider", "cloudprovider", cloudprovider) + } +} + +func (cc *CloudConnector) dumpAWSData(c chan map[string]interface{}) { + data := map[string]interface{}{ + "Whoami": cc.AWSConfig.DumpWhoami(), + "CredentialReport": cc.AWSConfig.DumpCredentialReport(), + "Groups": cc.AWSConfig.DumpIAMGroups(), + "Users": cc.AWSConfig.DumpIAMUsers(), + "Roles": cc.AWSConfig.DumpIAMRoles(), + "Buckets": cc.AWSConfig.DumpBuckets(), + "EC2s": cc.AWSConfig.DumpEC2Instances(), + "VPCs": cc.AWSConfig.DumpVpcs(), + "Lambdas": cc.AWSConfig.DumpLambdas(), + "RDS": cc.AWSConfig.DumpRDS(), + "DynamoDBs": cc.AWSConfig.DumpDynamoDBs(), + "RedshiftDBs": cc.AWSConfig.DumpRedshiftDBs(), + } + + for key, value := range data { + c <- map[string]interface{}{key: value} } } @@ -84,6 +58,7 @@ func (cc *CloudConnector) testConnection(cloudprovider string) bool { case "aws": return cc.AWSConfig.TestConnection() default: + cc.logger.Error("Unsupported cloud provider", "cloudprovider", cloudprovider) return false } } diff --git a/pkg/connector/connector_structs.go b/pkg/connector/connector_structs.go index c0f073f..6cd3b0b 100644 --- a/pkg/connector/connector_structs.go +++ b/pkg/connector/connector_structs.go @@ -12,6 +12,6 @@ type StorageConnector struct { } type CloudConnector struct { - AWSConfig awsconfig.AWSConfig + AWSConfig *awsconfig.AWSConfig logger logging.LogManager } diff --git a/pkg/connector/services/aws/aws.go b/pkg/connector/services/aws/aws.go index 3533528..2f465a5 100644 --- a/pkg/connector/services/aws/aws.go +++ b/pkg/connector/services/aws/aws.go @@ -23,7 +23,7 @@ var ( countRetries = 100 ) -func InitAWSConfiguration(profile string, awsEndpoint string) (awsc AWSConfig) { +func InitAWSConfiguration(profile string, awsEndpoint string) (awsc *AWSConfig) { // Load the Shared AWS Configuration (~/.aws/config) cfg, _ := config.LoadDefaultConfig(context.TODO(), config.WithSharedConfigProfile(profile), config.WithRetryer(func() aws.Retryer { @@ -34,7 +34,7 @@ func InitAWSConfiguration(profile string, awsEndpoint string) (awsc AWSConfig) { if awsEndpoint != "" { cfg.BaseEndpoint = aws.String(awsEndpoint) } - awsc = AWSConfig{Profile: profile, Config: cfg} + awsc = &AWSConfig{Profile: profile, Config: cfg} SetActions() // Get the available AWS regions dynamically ec2.ListAndSaveRegions(cfg) diff --git a/pkg/connector/services/aws/tools.go b/pkg/connector/services/aws/tools.go index a6bb98b..4a86a01 100644 --- a/pkg/connector/services/aws/tools.go +++ b/pkg/connector/services/aws/tools.go @@ -11,37 +11,38 @@ import ( ) func SetActions() { + logger := logging.GetLogManager() URL := "https://awspolicygen.s3.amazonaws.com/js/policies.js" client := req.C().SetBaseURL(URL).SetTimeout(30 * time.Second).SetUserAgent("Mozilla/5.0 (X11; Linux x86_64; rv:103.0) Gecko/20100101 Firefox/103.0") - response := client.Get(). + response, err := client.R(). SetHeader("Connection", "keep-alive"). SetHeader("Pragma", "no-cache"). SetHeader("Cache-Control", "no-cache"). - Do() - if response.Err != nil { - logging.HandleError(response.Err, "AWS - SetActions", "Error on calling HTTP endpoint") + Get(URL) + if err != nil { + logger.Error("Error on calling HTTP endpoint", "err", err) } resString := strings.Replace(response.String(), "app.PolicyEditorConfig=", "", 1) obj, err := oj.ParseString(resString) if err != nil { - logging.HandleError(err, "AWS - SetActions", "Error on parsing output string") + logger.Error("Error on parsing output string", "err", err) } query, err := gojq.Parse(`.serviceMap[] | .StringPrefix as $prefix | .Actions[] | "\($prefix):\(.)"`) if err != nil { - logging.HandleError(err, "AWS - SetActions", "Error on mapping string to object") + logger.Error("Error on mapping string to object", "err", err) } iter := query.Run(obj) - ActionsMap = make(map[string][]string, 0) + ActionsMap = make(map[string][]string) for { v, ok := iter.Next() if !ok { break } if err, ok := v.(error); ok { - logging.HandleError(err, "AWS - SetActions", "Error on itering over objects") + logger.Error("Error on iterating over objects", "err", err) } ActionsList = append(ActionsList, v.(string)) @@ -54,9 +55,9 @@ func SetActions() { func unique(slice []string) []string { keys := make(map[string]bool) - list := []string{} + var list []string for _, entry := range slice { - if _, value := keys[entry]; !value { + if !keys[entry] { keys[entry] = true list = append(list, entry) } diff --git a/pkg/connector/storage_connector.go b/pkg/connector/storage_connector.go index 769308b..0a18643 100644 --- a/pkg/connector/storage_connector.go +++ b/pkg/connector/storage_connector.go @@ -17,10 +17,9 @@ import ( func NewStorageConnector() *StorageConnector { neo4jURL := os.Getenv("NEO4J_URL") - neo4jUsername := "neo4j" neo4jPassword := os.Getenv("PASSWORD") logger := logging.GetLogManager() - client, err := neo4j.Connect(neo4jURL, neo4jUsername, neo4jPassword) + client, err := neo4j.Connect(neo4jURL, "neo4j", neo4jPassword) if err != nil { logger.Error("Error connecting to database", "err", err) } diff --git a/pkg/io/logging/handlerrors.go b/pkg/io/logging/handlerrors.go index f0bf2c9..393447c 100644 --- a/pkg/io/logging/handlerrors.go +++ b/pkg/io/logging/handlerrors.go @@ -1,31 +1,27 @@ package logging import ( - "fmt" - "log" "runtime" "github.com/aws/aws-sdk-go-v2/aws/transport/http" ) func HandleAWSError(err *http.ResponseError, service string, operation string) { - fmt.Println(runtime.Caller(1)) + logger.Warn(runtime.Caller(1)) switch err.Response.StatusCode { - case 400: - log.Fatalf("Service: %s, error: %v\n", service, err.Unwrap()) case 403: - log.Fatalf("Service: %s, Operation: %s, error: %s\n", service, operation, "Permission Denied") + logger.Warn("service", service, "operation", operation, "status", err.HTTPStatusCode(), "err", "permission denied") default: - log.Fatalf("Service: %s, Operation: %s, StatusCode: %d, error: %v", service, operation, err.HTTPStatusCode(), err.ResponseError) + logger.Warn("service", service, "operation", operation, "status", err.HTTPStatusCode(), "error", err.ResponseError, "err", err.Unwrap()) } } func HandleError(err error, service string, operation string, exitonError ...bool) { _, file, line, _ := runtime.Caller(1) - fmt.Printf("Error pointer: %s:%d\n", file, line) + logger.Warn("Error pointer: %s:%d\n", file, line) if len(exitonError) >= 1 && !exitonError[0] { - log.Printf("Service: %s, Operation: %s, Error: %s\n", service, operation, err) + logger.Warn("service", service, "operation", operation, "err", err) } else { - log.Fatalf("Service: %s, Operation: %s, Error: %s\n", service, operation, err) + logger.Error("service", service, "operation", operation, "err", err) } } diff --git a/pkg/io/logging/logging.go b/pkg/io/logging/logging.go index 5967b87..ff62836 100644 --- a/pkg/io/logging/logging.go +++ b/pkg/io/logging/logging.go @@ -18,14 +18,20 @@ type LogManager interface { Info(message interface{}, keyvals ...interface{}) Warn(message interface{}, keyvals ...interface{}) Error(message interface{}, keyvals ...interface{}) + PrettyJSON(s interface{}) []byte + JSON(s interface{}) []byte } type logManager struct { logger *log.Logger } -var logger *logManager -var once sync.Once +const INDENT_SPACES int = 4 + +var ( + logger *logManager + once sync.Once +) func GetLogManager() LogManager { once.Do(func() { @@ -40,56 +46,54 @@ func GetLogManager() LogManager { } }) - return *logger + return logger } -func (lm logManager) SetVerboseLevel() { +func (lm *logManager) SetVerboseLevel() { lm.logger.SetLevel(log.InfoLevel) } -func (lm logManager) SetDebugLevel() { +func (lm *logManager) SetDebugLevel() { lm.logger.SetLevel(log.DebugLevel) } -func (lm logManager) Debug(message interface{}, keyvals ...interface{}) { +func (lm *logManager) Debug(message interface{}, keyvals ...interface{}) { lm.logger.Debug(message, keyvals...) } -func (lm logManager) Info(message interface{}, keyvals ...interface{}) { +func (lm *logManager) Info(message interface{}, keyvals ...interface{}) { lm.logger.Info(message, keyvals...) } -func (lm logManager) Warn(message interface{}, keyvals ...interface{}) { +func (lm *logManager) Warn(message interface{}, keyvals ...interface{}) { lm.logger.Warn(message, keyvals...) } -func (lm logManager) Error(message interface{}, keyvals ...interface{}) { +func (lm *logManager) Error(message interface{}, keyvals ...interface{}) { lm.logger.Error(message, keyvals...) os.Exit(1) } -var INDENT_SPACES int = 4 - -func PrettyJSON(s interface{}) (data []byte) { +func (lm *logManager) PrettyJSON(s interface{}) []byte { data, err := json.MarshalIndent(s, "", strings.Repeat(" ", INDENT_SPACES)) if err != nil { if _, ok := err.(*json.UnsupportedTypeError); ok { - return []byte("Tried to Marshal Invalid Type") + lm.Error("Tried to Marshal invalid type", "err", err) } - return []byte("Struct does not exist") + lm.Error("Struct does not exist", "err", err) } - return + return data } -func JSON(s interface{}) (data []byte) { +func (lm *logManager) JSON(s interface{}) []byte { data, err := json.Marshal(s) if err != nil { if _, ok := err.(*json.UnsupportedTypeError); ok { - return []byte("Tried to Marshal Invalid Type") + lm.Error("Tried to Marshal invalid type", "err", err) } - return []byte("Struct does not exist") + lm.Error("Struct does not exist", "err", err) } - return + return data } func PrintRed(s string) { diff --git a/tools/filesystem/files/files.go b/tools/filesystem/files/files.go index af0ecc3..271a430 100644 --- a/tools/filesystem/files/files.go +++ b/tools/filesystem/files/files.go @@ -12,13 +12,14 @@ import ( ) func PrettyJSONToFile(filePath string, fileName string, s interface{}) { + logger := logging.GetLogManager() if err := os.MkdirAll(filePath, os.FileMode(0775)); err != nil { - logging.HandleError(err, "Files - PrettyJSONToFile", "Error on creating/reading output folder") + logger.Error("Error on creating/reading output folder", "err", err) } filePath = filePath + string(filepath.Separator) + fileName - if err := os.WriteFile(filePath, logging.PrettyJSON(s), 0600); err != nil { - logging.HandleError(err, "Files - PrettyJSONToFile", "Error on writing file") + if err := os.WriteFile(filePath, logger.PrettyJSON(s), 0600); err != nil { + logger.Error("Error on writing file", "err", err) } } @@ -36,7 +37,7 @@ func GetFiles(root, pattern string) []string { return nil }) if err != nil { - logging.HandleError(err, "Files - GetFiles", "Error on reading file") + logging.GetLogManager().Error("Error on reading file", "err", err) } return a } diff --git a/tools/filesystem/zip/zip.go b/tools/filesystem/zip/zip.go index 876420f..45ecfb2 100644 --- a/tools/filesystem/zip/zip.go +++ b/tools/filesystem/zip/zip.go @@ -11,40 +11,48 @@ import ( "github.com/primait/nuvola/pkg/io/logging" ) -func Zip(path string, profile string, values *map[string]interface{}) { +func Zip(path string, profile string, values map[string]interface{}) { + logger := logging.GetLogManager() today := time.Now().Format("20060102") - fileSeparator := string(filepath.Separator) - profile = filepath.Clean(strings.Replace(profile, fileSeparator, "-", -1)) - filePtr, err := os.Create(fmt.Sprintf("%s%snuvola-%s_%s.zip", filepath.Clean(path), fileSeparator, profile, today)) + profile = filepath.Clean(strings.Replace(profile, string(filepath.Separator), "-", -1)) + filePtr, err := os.Create( + filepath.Join( + filepath.Clean(path), + fmt.Sprintf("nuvola-%s_%s.zip", profile, today)), + ) if err != nil { - logging.HandleError(err, "Zip", "Error on creating output folder") + logger.Error("Error on creating output folder", "err", err) } defer func() { - if err := filePtr.Close(); err != nil { - logging.HandleError(err, "Zip", "Error closing file") + if cerr := filePtr.Close(); cerr != nil { + logger.Error("Error closing file", "err", err) } }() - MyZipWriter := zip.NewWriter(filePtr) - defer MyZipWriter.Close() + zipWriter := zip.NewWriter(filePtr) + defer func() { + if cerr := zipWriter.Close(); cerr != nil { + logger.Error("Error closing zip writer", "err", cerr) + } + }() - for key, value := range *values { - writer, err := MyZipWriter.Create(fmt.Sprintf("%s_%s.json", key, today)) + for key, value := range values { + writer, err := zipWriter.Create(fmt.Sprintf("%s_%s.json", key, today)) if err != nil { - fmt.Println(err) + logger.Error("Error on creating file", "err", err) } - _, err = writer.Write(logging.PrettyJSON(value)) - if err != nil { - logging.HandleError(err, "Zip", "Error on writing file content") + data := logger.PrettyJSON(value) + if _, err := writer.Write(data); err != nil { + logger.Error("Error writing file content", "err", err) } } } -func UnzipInMemory(zipfile string) (r *zip.ReadCloser) { +func UnzipInMemory(zipfile string) *zip.ReadCloser { r, err := zip.OpenReader(zipfile) if err != nil { - logging.HandleError(err, "Zip", "Error on opening ZIP file") + logging.GetLogManager().Error("Error on opening ZIP file", "err", err) } - return + return r } diff --git a/tools/yamler/yamler.go b/tools/yamler/yamler.go index 92372aa..33b0c69 100644 --- a/tools/yamler/yamler.go +++ b/tools/yamler/yamler.go @@ -23,6 +23,7 @@ type Conf struct { Return []string `yaml:"return"` Enabled bool `yaml:"enabled"` Find Find `yaml:"find,omitempty"` + logger logging.LogManager } type Find struct { @@ -33,15 +34,16 @@ type Find struct { } func GetConf(file string) (c *Conf) { - c = &Conf{} + logger := logging.GetLogManager() + c = &Conf{Enabled: true, logger: logger} yamlFile, err := os.ReadFile(files.NormalizePath(file)) if err != nil { - logging.HandleError(err, "Yamler - GetConf", "Error on reading rule file") + logger.Error("Error on reading rule file", "err", err) } c.Enabled = true // Default value is: Enabled err = yaml.Unmarshal(yamlFile, &c) if err != nil { - logging.HandleError(err, "Yamler - GetConf", "Umarshalling yamlFile") + logger.Error("Error unmarshalling yamlFile", "err", err) } return c @@ -49,107 +51,87 @@ func GetConf(file string) (c *Conf) { func PrepareQuery(config *Conf) (query string, arguments map[string]interface{}) { arguments = make(map[string]interface{}, 0) - if config.Services != nil { + if len(config.Services) > 0 { // Direct access to properties query = prepareService(config.Services, arguments) + prepareProperties(config.Properties, arguments) + prepareResults(config.Return, arguments) - } else if config.Find.To != nil || config.Find.Who != nil || config.Find.With != nil { - if config.Find.Target != nil { + } else if len(config.Find.To) > 0 || len(config.Find.Who) > 0 || len(config.Find.With) > 0 { + if len(config.Find.Target) > 0 { query = prepareQueryPrivEsc(config, arguments) } else { query = preparePathQuery(config, arguments) } } else { - logging.HandleError(nil, "Yamler - PrepareQuery", "Malformed rule!") + config.logger.Error("Malformed rule") } return query, arguments } -func preparePathQuery(rule *Conf, arguments map[string]interface{}) (query string) { +func preparePathQuery(rule *Conf, arguments map[string]interface{}) string { template := "MATCH m%d = (who)-[:HAS_POLICY]->(:Policy)-[:ALLOWS]->(:Action {Service: $service%d, Action: $action%d}) \n" - - matchQueries := "" - whereFilters := "" - returnValues := "" + var matchQueries, whereFilters, returnValues strings.Builder for i, perm := range rule.Find.With { withSplit := strings.Split(perm, ":") - service := withSplit[0] - action := withSplit[1] - arguments["action"+strconv.Itoa(i)] = action - arguments["service"+strconv.Itoa(i)] = service + service, action := withSplit[0], withSplit[1] + arguments[fmt.Sprintf("action%d", i)] = action + arguments[fmt.Sprintf("service%d", i)] = service - matchQueries += fmt.Sprintf(template, i, i, i) - returnValues += fmt.Sprintf("NODES(m%d) + ", i) + matchQueries.WriteString(fmt.Sprintf(template, i, i, i)) + returnValues.WriteString(fmt.Sprintf("NODES(m%d) + ", i)) } - returnValues = strings.TrimRight(returnValues, "+ ") - query = matchQueries + query := matchQueries.String() + returnValuesStr := strings.TrimSuffix(returnValues.String(), " + ") if len(rule.Find.Who) > 0 { - whereFilters = "WHERE (" - } - for i, who := range rule.Find.Who { - block := fmt.Sprintf(`$who%d IN LABELS(who)`, i) - arguments["who"+strconv.Itoa(i)] = cases.Title(language.Und).String(who) - whereFilters += fmt.Sprintf("%s OR ", block) - } - if len(rule.Find.Who) > 0 { - whereFilters = strings.TrimRight(whereFilters, "OR ") - whereFilters += ") " + whereFilters.WriteString("WHERE (") + for i, who := range rule.Find.Who { + arguments[fmt.Sprintf("who%d", i)] = cases.Title(language.Und).String(who) + whereFilters.WriteString(fmt.Sprintf(`$who%d IN LABELS(who) OR `, i)) + } + whereFiltersStr := strings.TrimSuffix(whereFilters.String(), " OR ") + query += whereFiltersStr + ") " } - query += whereFilters - query += fmt.Sprintf("\nWITH %s AS nds UNWIND nds as nd RETURN DISTINCT nd", returnValues) - return + query += fmt.Sprintf("\nWITH %s AS nds UNWIND nds as nd RETURN DISTINCT nd", returnValuesStr) + return query } -func prepareQueryPrivEsc(rule *Conf, arguments map[string]interface{}) (query string) { +func prepareQueryPrivEsc(rule *Conf, arguments map[string]interface{}) string { template := "MATCH m%d = (who)-[:HAS_POLICY]->(:Policy)-[:ALLOWS]->(:Action {Service: $service%d, Action: $action%d}) \n" - - matchQueries := "" - whereFilters := "" - returnValues := "" - shortestPath := "" + var matchQueries, whereFilters, returnValues, shortestPath strings.Builder for i, perm := range rule.Find.With { withSplit := strings.Split(perm, ":") - service := withSplit[0] - action := withSplit[1] - arguments["action"+strconv.Itoa(i)] = action - arguments["service"+strconv.Itoa(i)] = service + service, action := withSplit[0], withSplit[1] + arguments[fmt.Sprintf("action%d", i)] = action + arguments[fmt.Sprintf("service%d", i)] = service - matchQueries += fmt.Sprintf(template, i, i, i) - returnValues += fmt.Sprintf("NODES(m%d) + ", i) + matchQueries.WriteString(fmt.Sprintf(template, i, i, i)) + returnValues.WriteString(fmt.Sprintf("NODES(m%d) + ", i)) } - returnValues = strings.TrimRight(returnValues, "+ ") - query = matchQueries + returnValuesStr := strings.TrimSuffix(returnValues.String(), " + ") + query := matchQueries.String() if len(rule.Find.Who) > 0 { - whereFilters = "WHERE (" - } - for i, who := range rule.Find.Who { - block := fmt.Sprintf(`$who%d IN LABELS(who)`, i) - arguments["who"+strconv.Itoa(i)] = cases.Title(language.Und).String(who) - whereFilters += fmt.Sprintf("%s OR ", block) - } - if len(rule.Find.Who) > 0 { - whereFilters = strings.TrimRight(whereFilters, "OR ") - whereFilters += ") " + whereFilters.WriteString("WHERE (") + for i, who := range rule.Find.Who { + arguments[fmt.Sprintf("who%d", i)] = cases.Title(language.Und).String(who) + whereFilters.WriteString(fmt.Sprintf(`$who%d IN LABELS(who) OR `, i)) + } + whereFiltersStr := strings.TrimSuffix(whereFilters.String(), " OR ") + query += whereFiltersStr + ") " } if len(rule.Find.Target) > 0 { - targetWhereFilter := "WHERE (%s) AND (%s)" - targetWhereLabelFilters := "" - targetWherePropertyFilters := "" + var targetWhereLabelFilters, targetWherePropertyFilters strings.Builder for i, target := range rule.Find.Target { for what, id := range target { what = cases.Title(language.Und).String(what) + arguments[fmt.Sprintf("targetType%d", i)] = what + targetWhereLabelFilters.WriteString(fmt.Sprintf(`$targetType%d IN LABELS(target) OR `, i)) - blockLabel := fmt.Sprintf(`$targetType%d IN LABELS(target)`, i) - arguments["targetType"+strconv.Itoa(i)] = what - targetWhereLabelFilters += fmt.Sprintf("%s OR ", blockLabel) - - blockProperty := "" + var blockProperty string switch what { case "Policy": blockProperty = fmt.Sprintf(`target.Name = $target%d`, i) @@ -162,68 +144,61 @@ func prepareQueryPrivEsc(rule *Conf, arguments map[string]interface{}) (query st case "User": blockProperty = fmt.Sprintf(`target.UserName = $target%d`, i) } - arguments["target"+strconv.Itoa(i)] = id - targetWherePropertyFilters += fmt.Sprintf("%s OR ", blockProperty) + arguments[fmt.Sprintf("target%d", i)] = id + targetWherePropertyFilters.WriteString(fmt.Sprintf("%s OR ", blockProperty)) } } - targetWhereLabelFilters = strings.TrimRight(targetWhereLabelFilters, "OR ") - targetWherePropertyFilters = strings.TrimRight(targetWherePropertyFilters, "OR ") - targetWhereFilter = fmt.Sprintf(targetWhereFilter, targetWhereLabelFilters, targetWherePropertyFilters) - shortestPath = fmt.Sprintf("\nMATCH p0 = allShortestPaths((who)-[*1..10]->(target))\n%s", targetWhereFilter) - returnValues += fmt.Sprintf(" + NODES(p%d)", 0) + targetWhereLabelFiltersStr := strings.TrimSuffix(targetWhereLabelFilters.String(), " OR ") + targetWherePropertyFiltersStr := strings.TrimSuffix(targetWherePropertyFilters.String(), " OR ") + shortestPath.WriteString(fmt.Sprintf("\nMATCH p0 = allShortestPaths((who)-[*1..10]->(target))\nWHERE (%s) AND (%s)", targetWhereLabelFiltersStr, targetWherePropertyFiltersStr)) + returnValues.WriteString(" + NODES(p0)") } - query += whereFilters - query += shortestPath - query += fmt.Sprintf("\nWITH %s AS nds UNWIND nds as nd RETURN DISTINCT nd", returnValues) - return + query += shortestPath.String() + query += fmt.Sprintf("\nWITH %s AS nds UNWIND nds as nd RETURN DISTINCT nd", returnValuesStr) + fmt.Println(query) + return query } func prepareService(services []string, arguments map[string]interface{}) string { - var query string - query = "MATCH (s:Service)\nWHERE" + var query strings.Builder + query.WriteString("MATCH (s:Service)\nWHERE") for i, service := range services { service = cases.Title(language.Und).String(service) + query.WriteString(fmt.Sprintf(" $name%d IN LABELS(s)", i)) if i < len(services)-1 { - query = fmt.Sprintf("%s $name%d IN LABELS(s) OR", query, i) - } else { - query = fmt.Sprintf("%s $name%d IN LABELS(s)\n", query, i) + query.WriteString(" OR") } - arguments["name"+strconv.Itoa(i)] = service + arguments[fmt.Sprintf("name%d", i)] = service } - query += "WITH s\n" - return query + query.WriteString("\nWITH s\n") + return query.String() } func prepareProperties(props []map[string]interface{}, arguments map[string]interface{}) string { - var query = "MATCH (s)\n" - var separator = "_" - var count = 0 - - if len(props) > 0 { - query += "WHERE" + if len(props) == 0 { + return "MATCH (s)\n" } + var query strings.Builder + separator := "_" + query.WriteString("MATCH (s)\nWHERE") + var count int for _, prop := range props { subprops := walk(reflect.ValueOf(prop), separator) for _, p := range subprops { - // Split the input in Key/Value - key := strings.Split(p, "__")[0] - value := strings.Split(p, "__")[1] - arguments["key"+strconv.Itoa(count)] = key - // Boolean must be used on queries + key, value := strings.Split(p, "__")[0], strings.Split(p, "__")[1] + arguments[fmt.Sprintf("key%d", count)] = key if b, err := strconv.ParseBool(value); err == nil { - arguments["value"+strconv.Itoa(count)] = b + arguments[fmt.Sprintf("value%d", count)] = b } else { - arguments["value"+strconv.Itoa(count)] = value + arguments[fmt.Sprintf("value%d", count)] = value } - query = fmt.Sprintf(`%s any(prop in keys(s) where toLower(prop) STARTS WITH toLower($key%d) AND s[prop] = $value%d) AND`, - query, count, count) + query.WriteString(fmt.Sprintf(` any(prop in keys(s) where toLower(prop) STARTS WITH toLower($key%d) AND s[prop] = $value%d) AND`, count, count)) count++ } } - - query = strings.TrimRight(query, " AND") + "\n" - return query + queryStr := strings.TrimSuffix(query.String(), " AND") + "\n" + return queryStr } //nolint:unused @@ -246,7 +221,8 @@ func prepareResults(results []string, arguments map[string]interface{}) string { return query } -func walk(v reflect.Value, separator string) (output []string) { +func walk(v reflect.Value, separator string) []string { + var output []string // Indirect through pointers and interfaces for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface { v = v.Elem()