Skip to content

Commit

Permalink
fix: add lambda to return rate limit
Browse files Browse the repository at this point in the history
  • Loading branch information
Gerald Iakobinyi-Pich committed Dec 16, 2024
1 parent d428100 commit 533006c
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 11 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ postgres_db_passport_data
.DS_Store
**/.next/
infra/aws/python
__lambda__
38 changes: 28 additions & 10 deletions api/embed/lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@
ProgrammingError,
Unauthorized,
parse_body,
strip_event,
)

""" Load the django apps after the aws_lambdas.utils """ # pylint: disable=pointless-string-statement
from django.db import close_old_connections
from ninja import Schema

from account.models import Account, AccountAPIKey, AccountAPIKeyAnalytics
from ceramic_cache.api.schema import CacheStampPayload
from ceramic_cache.api.v1 import handle_add_stamps
from registry.api.utils import (
Expand All @@ -44,6 +46,8 @@
InvalidAddressException,
)

from .api import AccountAPIKeySchema, AddStampsPayload

# pylint: enable=wrong-import-position

logger = logging.getLogger(__name__)
Expand All @@ -55,15 +59,17 @@ def with_embed_request_exception_handling(func):
"""

def wrapper(event, context, *args):
def wrapper(_event, context, *args):
try:
bind_contextvars(request_id=context.aws_request_id)
sensitive_data, event = strip_event(_event)

logger.info("Received event: %s", event)

# Parse the body and call the function
body = parse_body(event)

return func(event, context, body)
return func(event, context, body, sensitive_data)
except Exception as e:
if isinstance(e, APIException):
status = e.status_code
Expand Down Expand Up @@ -115,11 +121,6 @@ def wrapper(event, context, *args):
return wrapper


class AddStampsPayload(Schema):
scorer_id: int
stamps: List[Any]


# Define the pattern
pattern = r"/([^/]+)/?$"

Expand All @@ -135,7 +136,7 @@ def get_address(value: str) -> str:


@with_embed_request_exception_handling
def _handler(event, _context, body):
def _handler_save_stamps(event, _context, body, _sensitive_date):
"""
Request handler implementation.
Expand Down Expand Up @@ -170,6 +171,23 @@ def _handler(event, _context, body):
}


def lambda_handler(*args, **kwargs):
def lambda_handler_save_stamps(*args, **kwargs):
close_old_connections()
return _handler_save_stamps(*args, **kwargs)


@with_embed_request_exception_handling
def _handler_get_rate_limit(_event, _context, body, sensitive_date):
# TODO: raise 404 if key does not exist
api_key = AccountAPIKey.objects.get_from_key(sensitive_date["x-api-key"])

return {
"statusCode": 200,
"headers": {"Content-Type": "application/json"},
"body": AccountAPIKeySchema.from_orm(api_key).model_dump_json(),
}


def lambda_handler_get_rate_limit(*args, **kwargs):
close_old_connections()
return _handler(*args, **kwargs)
return _handler_get_rate_limit(*args, **kwargs)
2 changes: 1 addition & 1 deletion infra/aws/embed/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ export function createEmbedLambda(config: {
description: "Handle requests related to the embed API",
code: new pulumi.asset.FileArchive("lambda_function_payload.zip"),
// role: lambdaRole.arn,
handler: "embed.lambda.lambda_handler", // TODO: change this
handler: "embed.lambda.lambda_handler_save_stamps", // TODO: change this
sourceCodeHash: lambdaCode.then((archive) => archive.outputBase64sha256),
runtime: aws.lambda.Runtime.Python3d12,
environment: {
Expand Down
173 changes: 173 additions & 0 deletions infra/aws/embed/rate_limit.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import * as pulumi from "@pulumi/pulumi";
import * as archive from "@pulumi/archive";
import * as aws from "@pulumi/aws";
import { ListenerRule } from "@pulumi/aws/lb";
import { Listener } from "@pulumi/aws/alb";
import { secretsManager } from "infra-libs";
import { defaultTags, stack } from "../../lib/tags";

import { createLambdaFunction } from "../../lib/lambda";

export function createEmbedLambdaRateLimit(config: {
name: string;
snsAlertsTopicArn: pulumi.Input<string>;
httpsListener: pulumi.Output<Listener>;
ceramicCacheScorerId: number;
scorerSecret: aws.secretsmanager.Secret;
privateSubnetSecurityGroup: aws.ec2.SecurityGroup;
vpcId: pulumi.Input<string>;
vpcPrivateSubnetIds: pulumi.Input<any>;
lambdaLayerArn: pulumi.Input<string>;
bucketId: pulumi.Input<string>;
}) {
const apiLambdaEnvironment = [
...secretsManager.getEnvironmentVars({
vault: "DevOps",
repo: "passport-scorer",
env: stack,
section: "api",
}),
{
name: "DEBUG",
value: "off",
},
{
name: "LOGGING_STRATEGY",
value: "structlog_json",
},
{
name: "FF_API_ANALYTICS",
value: "on",
},
{
name: "CERAMIC_CACHE_SCORER_ID",
value: `${config.ceramicCacheScorerId}`,
},
{
name: "SCORER_SERVER_SSM_ARN",
value: config.scorerSecret.arn,
},
{
name: "VERIFIER_URL",
value: "http://core-alb.private.gitcoin.co/verifier/verify",
},
].sort(secretsManager.sortByName);

// The lambda will contain our own code (everything from the `api` folder for now)
const lambdaCode = archive.getFile({
type: "zip",
sourceDir: "../../api",
outputPath: "lambda_function_payload.zip",
excludes: ["**/__pycache__"],
});

const lambdaName = `${config.name}-lambda`;
const { lambdaFunction, lambdaFunctionUrl } = createLambdaFunction(
[config.scorerSecret.arn],
config.vpcId,
config.vpcPrivateSubnetIds,
{
name: lambdaName,
description: "Retreive the rate limit for an API key",
code: new pulumi.asset.FileArchive("lambda_function_payload.zip"),
// role: lambdaRole.arn,
handler: "embed.lambda.lambda_handler_get_rate_limit", // TODO: change this
sourceCodeHash: lambdaCode.then((archive) => archive.outputBase64sha256),
runtime: aws.lambda.Runtime.Python3d12,
environment: {
variables: apiLambdaEnvironment.reduce(
(
acc: { [key: string]: pulumi.Input<string> },
e: { name: string; value: pulumi.Input<string> }
) => {
acc[e.name] = e.value;
return acc;
},
{}
),
},
memorySize: 128,
timeout: 60,
layers: [config.lambdaLayerArn],
tags: {
...defaultTags,
Name: lambdaName,
},
}
);

// Create alarm to monitor lambda errors
const metricAlarmName = `${config.name}-lambda-errors`;
const lambdaErrorsAlarm = new aws.cloudwatch.MetricAlarm(metricAlarmName, {
tags: { ...defaultTags, Name: metricAlarmName },
alarmActions: [config.snsAlertsTopicArn],
okActions: [config.snsAlertsTopicArn],
comparisonOperator: "GreaterThanOrEqualToThreshold",
dimensions: {
FunctionName: lambdaName,
},
datapointsToAlarm: 3,
evaluationPeriods: 5,
metricName: "Errors",
name: metricAlarmName,
namespace: "AWS/Lambda",
period: 60, // 1 min
unit: "Seconds",
statistic: "SampleCount",
treatMissingData: "notBreaching",
threshold: 1,
});

///////////////////////////////////////////////////////////////////////////
const lambdaTargetGroup = new aws.lb.TargetGroup(
`${config.name}-lambda-target-group`,
{
name: `${config.name}-lambda-target-group`,
targetType: "lambda",
tags: { ...defaultTags, Name: `${config.name}-lambda` },
}
);

const withLb = new aws.lambda.Permission(`${config.name}-lambda-permission`, {
action: "lambda:InvokeFunction",
function: lambdaFunction.name,
principal: "elasticloadbalancing.amazonaws.com",
sourceArn: lambdaTargetGroup.arn,
});
const lambdaTargetGroupAttachment = new aws.lb.TargetGroupAttachment(
`${config.name}-lambda-target-group-attachment`,
{
targetGroupArn: lambdaTargetGroup.arn,
targetId: lambdaFunction.arn,
},
{
dependsOn: [withLb],
}
);

const conditions: any = [
{
pathPattern: {
values: ["/embed/validate-api-key"],
},
},
{
httpRequestMethod: {
values: ["GET"],
},
},
];

const targetPassportRule = new ListenerRule(`${config.name}-rule-lambda`, {
tags: { ...defaultTags, Name: `${config.name}-rule-lambda` },
listenerArn: config.httpsListener.arn,
priority: 2101,
actions: [
{
type: "forward",
targetGroupArn: lambdaTargetGroup.arn,
},
],
conditions,
});
}
14 changes: 14 additions & 0 deletions infra/aws/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import { createS3InitiatedECSTask } from "../lib/scorer/s3_initiated_ecs_task";
import { stack, defaultTags, StackType } from "../lib/tags";
import { createV2Api } from "./v2/index";
import { createEmbedLambda } from "./embed";
import { createEmbedLambdaRateLimit } from "./embed/rate_limit";
import { createPythonLambdaLayer } from "./layer";

//////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -2215,3 +2216,16 @@ const embedLambda = createEmbedLambda({
lambdaLayerArn: pythonLambdaLayer.arn,
bucketId: codeBucketId,
});

const embedLambdaRateLimit = createEmbedLambdaRateLimit({
name: "embed-rl",
vpcId: vpcID,
snsAlertsTopicArn: pagerdutyTopic.arn,
httpsListener: httpsListener,
ceramicCacheScorerId: CERAMIC_CACHE_SCORER_ID,
scorerSecret: scorerSecret,
privateSubnetSecurityGroup: privateSubnetSecurityGroup,
vpcPrivateSubnetIds: vpcPrivateSubnetIds,
lambdaLayerArn: pythonLambdaLayer.arn,
bucketId: codeBucketId,
});

0 comments on commit 533006c

Please sign in to comment.