Skip to content

Commit

Permalink
Support multiple rds instances (#9)
Browse files Browse the repository at this point in the history
* Allow multiple RDS instances

* Update the test example

* PR reviews updates

* Note that the instance needs to be a writer

* Set target group port to RDS instance port

* Fix a typo

* Change lambda status if update fails

* Add source_code_hash

* Terraform format
  • Loading branch information
bobbyiliev authored Apr 11, 2024
1 parent 7fad4e0 commit 8878b26
Show file tree
Hide file tree
Showing 10 changed files with 146 additions and 138 deletions.
9 changes: 6 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@ The module creates the following resources:

## Important Remarks

> **Note**
> [!NOTE]
> The RDS instance needs to be private. If your RDS instance is public, there is no need to use PrivateLink.
> [!NOTE]
> When using Aurora, the RDS instance needs to be a **writer** instance as the reader instances will not work.
- The RDS instance must be in the same VPC as the PrivateLink endpoint.
- Review this module with your Cloud Security team to ensure that it meets your security requirements.
- Finally, after the Terraform module has been applied, you will need to make sure that the Target Groups heatlth checks are passing. As the NLB does not have security groups, you will need to make sure that the NLB is able to reach the RDS instance by allowing the subnet CIDR blocks in the security groups of the RDS instance.
- Finally, after the Terraform module has been applied, you will need to make sure that the Target Groups health checks are passing. As the NLB does not have security groups, you will need to make sure that the NLB is able to reach the RDS instance by allowing the subnet CIDR blocks in the security groups of the RDS instance.

To override the default AWS provider variables, you can export the following environment variables:

Expand All @@ -43,7 +46,7 @@ cp terraform.tfvars.example terraform.tfvars

| Name | Description | Type | Example | Required |
|------|-------------|:----:|:-----:|:-----:|
| mz_rds_instance_name | The name of the RDS instance | string | `'my-rds-instance'` | yes |
| mz_rds_instance_names | The name of the RDS instances | list | `{ name = "instance1", listener_port = 5001 }` | yes |
| mz_rds_vpc_id | The VPC ID of the RDS instance | string | `'vpc-1234567890abcdef0'` | yes |
| mz_acceptance_required | Whether or not to require manual acceptance of new connections | bool | `true` | no |
| schedule_expression | [The scheduling expression](https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/cloudwatch_event_rule#schedule_expression). For example, `cron(0 20 * * ? *)` | string | `'rate(5 minutes)'` | no |
Expand Down
21 changes: 13 additions & 8 deletions datasources.tf
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Get the state of the RDS instance using aws_db_instance
# Get the state of the RDS instances using aws_db_instance
data "aws_db_instance" "mz_rds_instance" {
db_instance_identifier = var.mz_rds_instance_name
for_each = { for inst in var.mz_rds_instance_details : inst.name => inst }

db_instance_identifier = each.key

lifecycle {
postcondition {
condition = self.publicly_accessible == false
error_message = "The RDS instance needs to be private, but it is public."
condition = self.publicly_accessible == false && self.replicate_source_db == ""
error_message = "The RDS instance must be private and a writer instance."
}
}
}
Expand All @@ -16,16 +18,19 @@ data "aws_vpc" "mz_rds_vpc" {
}

data "aws_db_subnet_group" "mz_rds_subnet_group" {
name = data.aws_db_instance.mz_rds_instance.db_subnet_group
for_each = { for inst in var.mz_rds_instance_details : inst.name => inst }
name = data.aws_db_instance.mz_rds_instance[each.key].db_subnet_group
}

data "aws_subnet" "mz_rds_subnet" {
for_each = toset(data.aws_db_subnet_group.mz_rds_subnet_group.subnet_ids)
id = each.value
for_each = toset(flatten([for inst in var.mz_rds_instance_details : data.aws_db_subnet_group.mz_rds_subnet_group[inst.name].subnet_ids]))

id = each.value
}

data "dns_a_record_set" "rds_ip" {
host = data.aws_db_instance.mz_rds_instance.address
for_each = { for inst in var.mz_rds_instance_details : inst.name => inst }
host = data.aws_db_instance.mz_rds_instance[each.key].address
}

data "aws_iam_policy_document" "lambda_assume_role_policy" {
Expand Down
39 changes: 0 additions & 39 deletions examples/rds_privatelink_setup/.terraform.lock.hcl

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 10 additions & 4 deletions examples/rds_privatelink_setup/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,25 @@ This example demonstrates how to create a new Amazon RDS Postgres instance and c

1. Once the resources have been created, you can test the module with:

> [!NOTE]
> The module requires that each RDS instance has a **unique** listener port.
```hcl
module "materialize_privatelink_rds" {
source = "../.."
mz_rds_instance_name = var.mz_rds_instance_name
mz_rds_vpc_id = module.rds_postgres.vpc.vpc_id
aws_region = var.aws_region
mz_rds_instance_details = [
{ name = "instance1", listener_port = 5001 },
{ name = "instance2", listener_port = 5002 }
]
mz_rds_vpc_id = module.rds_postgres.vpc.vpc_id
aws_region = var.aws_region
}
```
1. **Follow the Output Instructions**
After Terraform successfully applies the configuration, it will output instructions for configuring the PrivateLink endpoint and the Postgres connection in Materialize. Follow these instructions to complete the setup.
After Terraform successfully applies the configuration, it will output instructions for configuring the PrivateLink endpoint and the Postgres connections in Materialize. Follow these instructions to complete the setup.
## Cleanup
Expand Down
4 changes: 2 additions & 2 deletions examples/rds_privatelink_setup/outputs.tf
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
output "rds_instance_endpoint" {
value = module.rds_postgres.rds_instance.endpoint
value = module.rds_postgres.rds_instance.endpoint
sensitive = true
}

output "mz_rds_details" {
value = module.rds_postgres.mz_rds_details
value = module.rds_postgres.mz_rds_details
sensitive = true
}
89 changes: 46 additions & 43 deletions lambda_function.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,60 @@
import boto3
import socket
import os
import json

# Define the clients at the top of your function
# Initialize clients
elbv2_client = boto3.client('elbv2')
rds_client = boto3.client('rds')

RDS_IDENTIFIER = os.environ['RDS_IDENTIFIER'] # RDS instance identifier
TARGET_GROUP_ARN = os.environ['TARGET_GROUP_ARN'] # Target Group ARN
# Load RDS details from environment variables
RDS_DETAILS = json.loads(os.environ['RDS_DETAILS'])

def update_target_registration(rds_identifier, details):
try:
# Retrieve the current IP address of the RDS instance
rds_instances = rds_client.describe_db_instances(DBInstanceIdentifier=rds_identifier)
rds_port = rds_instances['DBInstances'][0]['Endpoint']['Port']
if not rds_instances['DBInstances']:
raise Exception(f"No instances found for {rds_identifier}")

def lambda_handler(event, context):
# Retrieve the current IP address of the RDS instance
rds_instances = rds_client.describe_db_instances(
DBInstanceIdentifier=RDS_IDENTIFIER)
rds_endpoint = rds_instances['DBInstances'][0]['Endpoint']['Address']
ip_address = socket.gethostbyname(rds_endpoint)
rds_port = rds_instances['DBInstances'][0]['Endpoint']['Port']

# Retrieve the existing target of the target group
targets = elbv2_client.describe_target_health(
TargetGroupArn=TARGET_GROUP_ARN)

# Get the current IP address in the target group
if targets['TargetHealthDescriptions']:
current_ip = targets['TargetHealthDescriptions'][0]['Target']['Id']
else:
current_ip = None

# If the IP addresses don't match, update the target group
if current_ip and current_ip != ip_address:
# Deregister the current target
elbv2_client.deregister_targets(
TargetGroupArn=TARGET_GROUP_ARN,
Targets=[
{
'Id': current_ip
},
]
)
rds_endpoint = rds_instances['DBInstances'][0]['Endpoint']['Address']
ip_address = socket.gethostbyname(rds_endpoint)

# Retrieve the existing target of the target group
target_group_arn = details['target_group_arn']
targets = elbv2_client.describe_target_health(TargetGroupArn=target_group_arn)

# Check and update the target group
current_ip = targets['TargetHealthDescriptions'][0]['Target']['Id'] if targets['TargetHealthDescriptions'] else None
if current_ip != ip_address:
if current_ip:
# Deregister the current target
elbv2_client.deregister_targets(TargetGroupArn=target_group_arn, Targets=[{'Id': current_ip}])

# Register the new target
elbv2_client.register_targets(
TargetGroupArn=TARGET_GROUP_ARN,
Targets=[
{
'Id': ip_address,
'Port': rds_port
},
]
)
elbv2_client.register_targets(TargetGroupArn=target_group_arn, Targets=[{'Id': ip_address, 'Port': rds_port}])
message = f"Target group {target_group_arn} updated. New target IP: {ip_address}"
else:
message = f"Target group {target_group_arn} already up to date. Current target IP: {ip_address} and Port: {rds_port}"

return {'success': True, 'message': message}
except Exception as e:
return {'success': False, 'message': f"Failed to update targets for {rds_identifier} with error: {e}"}

def lambda_handler(event, context):
update_messages = []
all_success = True

for rds_identifier, details in RDS_DETAILS.items():
result = update_target_registration(rds_identifier, details)
update_messages.append(result['message'])
if not result['success']:
all_success = False

status_code = 200 if all_success else 500

return {
'statusCode': 200,
'body': f'Target group updated. Current target IP: {ip_address}'
'statusCode': status_code,
'body': json.dumps(update_messages)
}
44 changes: 27 additions & 17 deletions main.tf
Original file line number Diff line number Diff line change
@@ -1,24 +1,30 @@
# Create a target group for the RDS instance
# Create a target group for each RDS instance
resource "aws_lb_target_group" "mz_rds_target_group" {
name = "mz-rds-${substr(var.mz_rds_instance_name, 0, 12)}-tg"
port = data.aws_db_instance.mz_rds_instance.port
for_each = { for inst in var.mz_rds_instance_details : inst.name => inst }

name = "${substr(each.key, 0, 12)}-${each.value.listener_port}-tg"
port = data.aws_db_instance.mz_rds_instance[each.key].port
protocol = "TCP"
vpc_id = data.aws_vpc.mz_rds_vpc.id
target_type = "ip"
}

# Attach a target to the target group
# Attach a target to each target group
resource "aws_lb_target_group_attachment" "mz_rds_target_group_attachment" {
target_group_arn = aws_lb_target_group.mz_rds_target_group.arn
target_id = data.dns_a_record_set.rds_ip.addrs[0]
for_each = { for inst in var.mz_rds_instance_details : inst.name => inst }

target_group_arn = aws_lb_target_group.mz_rds_target_group[each.key].arn
target_id = data.dns_a_record_set.rds_ip[each.key].addrs[0]

lifecycle {
ignore_changes = [target_id]
}
depends_on = [aws_lb_target_group.mz_rds_target_group]
}

# Create a network Load Balancer
resource "aws_lb" "mz_rds_lb" {
name = "mz-rds-${substr(var.mz_rds_instance_name, 0, 12)}-lb"
name = var.mz_nlb_name
internal = true
load_balancer_type = "network"
subnets = values(data.aws_subnet.mz_rds_subnet)[*].id
Expand All @@ -28,14 +34,16 @@ resource "aws_lb" "mz_rds_lb" {
}
}

# Create a tcp listener on the Load Balancer for the RDS instance
# Create listeners for each RDS instance, mapping each to its respective target group
resource "aws_lb_listener" "mz_rds_listener" {
for_each = { for inst in var.mz_rds_instance_details : inst.name => inst }

load_balancer_arn = aws_lb.mz_rds_lb.arn
port = data.aws_db_instance.mz_rds_instance.port
port = each.value.listener_port
protocol = "TCP"
default_action {
type = "forward"
target_group_arn = aws_lb_target_group.mz_rds_target_group.arn
target_group_arn = aws_lb_target_group.mz_rds_target_group[each.key].arn
}
}

Expand All @@ -50,30 +58,32 @@ resource "aws_vpc_endpoint_service" "mz_rds_lb_endpoint_service" {

# Create an IAM policy for the Lambda function
resource "aws_iam_role" "lambda_execution_role" {
name = "lambda_execution_${substr(var.mz_rds_instance_name, 0, 12)}-role"
name = "lambda_execution_${substr(var.mz_nlb_name, 0, 12)}-role"
assume_role_policy = data.aws_iam_policy_document.lambda_assume_role_policy.json
}

# Create a Lambda function to check the RDS instance IP address
resource "aws_lambda_function" "check_rds_ip" {
function_name = "${substr(var.mz_rds_instance_name, 0, 12)}-check-rds-ip"
function_name = "${substr(var.mz_nlb_name, 0, 12)}-check-rds-ip"
role = aws_iam_role.lambda_execution_role.arn
handler = "lambda_function.lambda_handler"
runtime = "python3.11"

filename = data.archive_file.lambda_zip.output_path

source_code_hash = data.archive_file.lambda_zip.output_base64sha256

environment {
variables = {
RDS_IDENTIFIER = var.mz_rds_instance_name
TARGET_GROUP_ARN = aws_lb_target_group.mz_rds_target_group.arn
RDS_DETAILS = jsonencode({ for inst in var.mz_rds_instance_details : inst.name => { port = inst.listener_port, target_group_arn = aws_lb_target_group.mz_rds_target_group[inst.name].arn } })
}
}
}


# Create an IAM policy for the Lambda function
resource "aws_iam_role_policy" "lambda_execution_role_policy" {
name = "${substr(var.mz_rds_instance_name, 0, 12)}-lambda-execution-role-policy"
name = "${substr(var.mz_nlb_name, 0, 12)}-lambda-execution-role-policy"
role = aws_iam_role.lambda_execution_role.id
policy = <<EOF
{
Expand All @@ -98,14 +108,14 @@ EOF
}

resource "aws_cloudwatch_event_rule" "rds_ip_check_rule" {
name = "${substr(var.mz_rds_instance_name, 0, 12)}-rds-ip-check-rule"
name = "${substr(var.mz_nlb_name, 0, 12)}-rds-ip-check-rule"
description = "Fires every ${var.schedule_expression} to check the RDS instance IP address"
schedule_expression = var.schedule_expression
}

resource "aws_cloudwatch_event_target" "check_rds_ip_event_target" {
rule = aws_cloudwatch_event_rule.rds_ip_check_rule.name
target_id = "${substr(var.mz_rds_instance_name, 0, 12)}-check-rds-ip"
target_id = "${substr(var.mz_nlb_name, 0, 12)}-check-rds-ip"
arn = aws_lambda_function.check_rds_ip.arn
}

Expand Down
Loading

0 comments on commit 8878b26

Please sign in to comment.