Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(fixer): add Prowler Fixer feature! #3634

Merged
merged 20 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion prowler/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def prowler():
global_provider,
custom_checks_metadata,
getattr(args, "mutelist_file", None),
args.config_file,
args,
)
else:
logger.error(
Expand Down
59 changes: 53 additions & 6 deletions prowler/lib/check/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,14 @@ def import_check(check_path: str) -> ModuleType:


def run_check(check: Check, output_options) -> list:
"""
Run the check and return the findings
Args:
check (Check): check class
output_options (Any): output options
Returns:
list: list of findings
"""
findings = []
if output_options.verbose:
print(
Expand All @@ -419,12 +427,40 @@ def run_check(check: Check, output_options) -> list:
return findings


def run_fixer(check_findings: list, check_name: str, check_class: Check):
"""
Run the fixer for the check if it exists and there are any FAIL findings
Args:
check_findings (list): list of findings
check_name (str): check name
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you don't need the check name here, you can use check.CheckID.

check_class (Check): check class
"""
try:
fixer = getattr(check_class, "fixer")
# Check if there are any FAIL findings
if any("FAIL" in finding.status for finding in check_findings):
print(
f"Fixing fails for check {Fore.YELLOW}{check_name}{Style.RESET_ALL}...\n"
)
for finding in check_findings:
if finding.status == "FAIL":
print(
f"\t{orange_color}FIXING{Style.RESET_ALL} {finding.region}... {(Fore.GREEN + 'DONE') if fixer(finding.region) else (Fore.RED + 'ERROR')}{Style.RESET_ALL}\n"
)
except AttributeError:
logger.error(f"Fixer method not implemented for check {check_name}")
except Exception as error:
logger.error(
f"{check_name} - {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)


def execute_checks(
checks_to_execute: list,
global_provider: Any,
custom_checks_metadata: Any,
mutelist_file: str,
config_file: str,
args: Any,
) -> list:
# List to store all the check's findings
all_findings = []
Expand Down Expand Up @@ -476,6 +512,7 @@ def execute_checks(
services_executed,
checks_executed,
custom_checks_metadata,
args,
)
all_findings.extend(check_findings)

Expand All @@ -491,7 +528,7 @@ def execute_checks(
else:
# Prepare your messages
messages = [
f"{Style.BRIGHT}Config File: {Style.RESET_ALL}{Fore.YELLOW}{config_file}{Style.RESET_ALL}"
f"{Style.BRIGHT}Config File: {Style.RESET_ALL}{Fore.YELLOW}{args.config_file}{Style.RESET_ALL}"
]
if mutelist_file:
messages.append(
Expand Down Expand Up @@ -536,6 +573,7 @@ def execute_checks(
services_executed,
checks_executed,
custom_checks_metadata,
args,
)
all_findings.extend(check_findings)

Expand All @@ -562,21 +600,26 @@ def execute(
services_executed: set,
checks_executed: set,
custom_checks_metadata: Any,
args: Any,
):
try:
# Import check module
check_module_path = f"prowler.providers.{global_provider.type}.services.{service}.{check_name}.{check_name}"
lib = import_check(check_module_path)
# Recover functions from check
check_to_execute = getattr(lib, check_name)
c = check_to_execute()
check_class = check_to_execute()

# Update check metadata to reflect that in the outputs
if custom_checks_metadata and custom_checks_metadata["Checks"].get(c.CheckID):
c = update_check_metadata(c, custom_checks_metadata["Checks"][c.CheckID])
if custom_checks_metadata and custom_checks_metadata["Checks"].get(
check_class.CheckID
):
check_class = update_check_metadata(
check_class, custom_checks_metadata["Checks"][check_class.CheckID]
)

# Run check
check_findings = run_check(c, global_provider.output_options)
check_findings = run_check(check_class, global_provider.output_options)

# Update Audit Status
services_executed.add(service)
Expand All @@ -595,6 +638,10 @@ def execute(
# Report the check's findings
report(check_findings, global_provider)

# Prowler Fixer
if args.fix and args.check:
run_fixer(check_findings, check_name, check_class)

if os.environ.get("PROWLER_REPORT_LIB_PATH"):
try:
logger.info("Using custom report interface ...")
Expand Down
8 changes: 8 additions & 0 deletions prowler/providers/aws/lib/arguments/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,14 @@ def init_parser(self):
help="Scan unused services",
)

# Prowler Fixer
prowler_fixer_subparser = aws_parser.add_argument_group("Prowler Fixer")
prowler_fixer_subparser.add_argument(
"--fix",
action="store_true",
help="Fix the failed findings that can be fixed by Prowler",
)


def validate_session_duration(duration):
"""validate_session_duration validates that the AWS STS Assume Role Session Duration is between 900 and 43200 seconds."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ def execute(self):
findings.append(report)

return findings

def fixer(self, region):
return ec2_client.__enable_ebs_encryption_by_default__(region)
12 changes: 12 additions & 0 deletions prowler/providers/aws/services/ec2/ec2_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,18 @@ def __get_ebs_encryption_settings__(self, regional_client):
f"{regional_client.region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)

def __enable_ebs_encryption_by_default__(self, region):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a test for this in the service. Thanks!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

try:
regional_client = self.regional_clients[region]
return regional_client.enable_ebs_encryption_by_default()[
"EbsEncryptionByDefault"
]
except Exception as error:
logger.error(
f"{region} -- {error.__class__.__name__}[{error.__traceback__.tb_lineno}]: {error}"
)
return False


class Instance(BaseModel):
id: str
Expand Down
6 changes: 6 additions & 0 deletions tests/lib/cli/parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,6 +1048,12 @@ def test_aws_parser_scan_unused_services(self):
parsed = self.parser.parse(command)
assert parsed.scan_unused_services

def test_aws_parser_fixer(self):
argument = "--fix"
command = [prowler_command, argument]
parsed = self.parser.parse(command)
assert parsed.fix

def test_aws_parser_config_file(self):
argument = "--config-file"
config_file = "./test-config.yaml"
Expand Down
16 changes: 16 additions & 0 deletions tests/providers/aws/services/ec2/ec2_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,3 +544,19 @@ def test__describe_volumes__(self):
assert ec2.volumes[0].tags == [
{"Key": "test", "Value": "test"},
]

# Test EC2 EBS Enabling Encryption by Default
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lovely!

@mock_aws
def test__describe_ebs_encryption_by_default__(self):
# Generate EC2 Client
ec2_client = client("ec2", region_name=AWS_REGION_US_EAST_1)

# EC2 client for this test class
aws_provider = set_mocked_aws_provider(
[AWS_REGION_EU_WEST_1, AWS_REGION_US_EAST_1]
)
ec2 = EC2(aws_provider)

assert not ec2.__enable_ebs_encryption_by_default__()
ec2_client.enable_ebs_encryption_by_default()
assert ec2.__enable_ebs_encryption_by_default__()
Loading