diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go index 4bb42ab..d64f98c 100644 --- a/pkg/cmd/cmd.go +++ b/pkg/cmd/cmd.go @@ -47,13 +47,11 @@ var runCmd = &cobra.Command{ fmt.Println() } - // Nuclei args flag if nucleiArgs == "" { log.Fatal("Nuclei arguments are required") os.Exit(1) } - // Targets flag if targets == "" && target == "" { log.Fatal("Either a target or a list of targets is required") os.Exit(1) @@ -66,12 +64,12 @@ var runCmd = &cobra.Command{ batches := helpers.SplitSlice(urls, batchSize) log.Println("Splitting targets into", len(batches), "individual executions") log.Println("Running with " + fmt.Sprint(threads) + " threads") - core.ExecuteScans(batches, output, functionName, nucleiArgs, threads, silent) + core.ExecuteScans(batches, output, functionName, nucleiArgs, threads, silent, region) } else { log.Println("Running nuclei against the target", target) log.Println("Running with " + fmt.Sprint(threads) + " threads") batches := [][]string{{target}} - core.ExecuteScans(batches, output, functionName, nucleiArgs, threads, silent) + core.ExecuteScans(batches, output, functionName, nucleiArgs, threads, silent, region) } }, } @@ -108,13 +106,15 @@ func init() { // Region flag runCmd.Flags().StringVarP(®ion, "region", "r", "", "AWS region to run nuclei") if region == "" { - region, ok := os.LookupEnv("AWS_REGION") + var ok bool // Declare ok here to avoid shadowing + region, ok = os.LookupEnv("AWS_REGION") // Removed := to modify the existing region variable if !ok { runCmd.MarkFlagRequired("region") } else { runCmd.Flags().Set("region", region) } } + // Function name flag runCmd.Flags().StringVarP(&functionName, "function-name", "f", "", "AWS Lambda function name") if functionName == "" { diff --git a/pkg/core/core.go b/pkg/core/core.go index 3de16e4..db82083 100644 --- a/pkg/core/core.go +++ b/pkg/core/core.go @@ -10,7 +10,7 @@ import ( "github.com/DevSecOpsDocs/nuclearpond/pkg/lambda" ) -func ExecuteScans(batches [][]string, output string, lambdaName string, nucleiArgs string, threads int, silent bool) { +func ExecuteScans(batches [][]string, output string, lambdaName string, nucleiArgs string, threads int, silent bool, region string) { // Get start time start := time.Now() @@ -47,7 +47,7 @@ func ExecuteScans(batches [][]string, output string, lambdaName string, nucleiAr Output: output, } tasks <- func() { - lambda.InvokeLambdas(lambdaInvoke, lambdaName, output) + lambda.InvokeLambdas(lambdaInvoke, lambdaName, output, region) } } @@ -59,3 +59,4 @@ func ExecuteScans(batches [][]string, output string, lambdaName string, nucleiAr log.Println("Completed all parallel operations, best of luck! Completed in", time.Since(start)) } } + diff --git a/pkg/lambda/lambda.go b/pkg/lambda/lambda.go index b7b032a..bd7ed6f 100644 --- a/pkg/lambda/lambda.go +++ b/pkg/lambda/lambda.go @@ -20,7 +20,7 @@ type LambdaInvoke struct { } // Stage the lambda function for executing -func InvokeLambdas(payload LambdaInvoke, lambda string, output string) { +func InvokeLambdas(payload LambdaInvoke, lambda string, output string, region string) { // Bug to fix another day if payload.Targets[0] == "" { return @@ -33,7 +33,7 @@ func InvokeLambdas(payload LambdaInvoke, lambda string, output string) { } // invoke lambda function - response, err := invokeFunction(string(lambdaInvokeJson), lambda) + response, err := invokeFunction(string(lambdaInvokeJson), lambda, region) if err != nil { fmt.Println(err) } @@ -56,10 +56,10 @@ func InvokeLambdas(payload LambdaInvoke, lambda string, output string) { } // Execute a lambda function and return the response -func invokeFunction(payload string, functionName string) (string, error) { +func invokeFunction(payload string, functionName string, region string) (string, error) { // Create a new session sess, err := session.NewSession(&aws.Config{ - Region: aws.String("us-east-1")}, + Region: aws.String(region)}, // Using the passed region here ) // Create a Lambda service client. @@ -81,3 +81,4 @@ func invokeFunction(payload string, functionName string) (string, error) { // Return the response return string(result.Payload), nil } + diff --git a/pkg/server/scanner.go b/pkg/server/scanner.go index 55506a6..1edc729 100644 --- a/pkg/server/scanner.go +++ b/pkg/server/scanner.go @@ -23,20 +23,19 @@ func backgroundScan(scanInput Request, scanId string) { NucleiArgs := base64.StdEncoding.EncodeToString([]byte(scanInput.Args)) silent := true - // Fail if AWS_LAMBDA_FUNCTION_NAME and AWS_REGION are not set functionName := os.Getenv("AWS_LAMBDA_FUNCTION_NAME") regionName := os.Getenv("AWS_REGION") dynamodbTable := os.Getenv("AWS_DYNAMODB_TABLE") + if functionName == "" || regionName == "" || dynamodbTable == "" { - log.Fatal("AWS_LAMBDA_FUNCTION_NAME is not set") + log.Fatal("Environment variables (AWS_LAMBDA_FUNCTION_NAME, AWS_REGION, AWS_DYNAMODB_TABLE) are not set.") } - // Convert scanId to a valid DynamoDB key requestId := strings.ReplaceAll(scanId, "-", "") log.Println("Initiating scan with the id of ", scanId, "with", len(targets), "targets") storeScanState(requestId, "running") - core.ExecuteScans(batches, output, functionName, NucleiArgs, threads, silent) + core.ExecuteScans(batches, output, functionName, NucleiArgs, threads, silent, regionName) storeScanState(requestId, "completed") log.Println("Scan", scanId, "completed") }