Skip to content

Commit

Permalink
feat: add new methods to ds and awslambda
Browse files Browse the repository at this point in the history
  • Loading branch information
oxh252 committed Feb 19, 2025
1 parent 41dd981 commit b40901d
Show file tree
Hide file tree
Showing 8 changed files with 335 additions and 12 deletions.
22 changes: 20 additions & 2 deletions moto/awslambda/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ def __init__(self, spec: Dict[str, Any], account_id: str, region: str):
self.compatible_architectures = spec.get("CompatibleArchitectures", [])
self.compatible_runtimes = spec.get("CompatibleRuntimes", [])
self.license_info = spec.get("LicenseInfo", "")
self.policy = Policy(self) # type: ignore[no-untyped-call]

# auto-generated
self.created_date = utcnow().strftime("%Y-%m-%d %H:%M:%S")
Expand Down Expand Up @@ -620,7 +621,7 @@ def __init__(
self.run_time = spec.get("Runtime")
self.logs_backend = logs_backends[account_id][self.region]
self.environment_vars = spec.get("Environment", {}).get("Variables", {})
self.policy = Policy(self)
self.policy = Policy(self) # type: ignore[no-untyped-call]
self.url_config: Optional[FunctionUrlConfig] = None
self.state = "Active"
self.reserved_concurrency = spec.get("ReservedConcurrentExecutions", None)
Expand Down Expand Up @@ -2348,7 +2349,8 @@ def add_permission(
self, function_name: str, qualifier: str, raw: str
) -> Dict[str, Any]:
fn = self.get_function(function_name, qualifier)
return fn.policy.add_statement(raw, qualifier)
statement, revision = fn.policy.add_statement(raw, qualifier)
return statement

def remove_permission(
self, function_name: str, sid: str, revision: str = ""
Expand Down Expand Up @@ -2499,6 +2501,22 @@ def list_function_event_invoke_configs(self, function_name: str) -> Dict[str, An
except UnknownEventConfig:
return response

def add_layer_version_permission(
self, layer_name: str, version_number: int, statement: str
) -> Tuple[str, str]:
layer_version = self.get_layer_version(layer_name, str(version_number))
return layer_version.policy.add_statement(statement)

def get_layer_version_policy(self, layer_name: str, version_number: int) -> str:
layer_version = self.get_layer_version(layer_name, str(version_number))
return layer_version.policy.wire_format()

def remove_layer_version_permission(
self, layer_name: str, version_number: str, sid: str, revision: str = ""
) -> None:
layer_version = self.get_layer_version(layer_name, str(version_number))
layer_version.policy.del_statement(sid, revision)


def do_validate_s3() -> bool:
return os.environ.get("VALIDATE_LAMBDA_S3", "") in ["", "1", "true"]
Expand Down
21 changes: 11 additions & 10 deletions moto/awslambda/policy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar

from moto.awslambda.exceptions import (
GenericResourcNotFound,
Expand All @@ -8,15 +8,11 @@
)
from moto.moto_api._internal import mock_random

if TYPE_CHECKING:
from .models import LambdaFunction


TYPE_IDENTITY = TypeVar("TYPE_IDENTITY")


class Policy:
def __init__(self, parent: "LambdaFunction"):
def __init__(self, parent): # type: ignore[no-untyped-def]
self.revision = str(mock_random.uuid4())
self.statements: List[Dict[str, Any]] = []
self.parent = parent
Expand All @@ -41,7 +37,7 @@ def get_policy(self) -> Dict[str, Any]:
# adds the raw JSON statement to the policy
def add_statement(
self, raw: str, qualifier: Optional[str] = None
) -> Dict[str, Any]:
) -> Tuple[Any, str]:
policy = json.loads(raw, object_hook=self.decode_policy)
if len(policy.revision) > 0 and self.revision != policy.revision:
raise PreconditionFailedException(
Expand All @@ -58,7 +54,7 @@ def add_statement(
)
self.statements.append(policy.statements[0])
self.revision = str(mock_random.uuid4())
return policy.statements[0]
return policy.statements[0], self.revision

# removes the statement that matches 'sid' from the policy
def del_statement(self, sid: str, revision: str = "") -> None:
Expand All @@ -78,12 +74,17 @@ def del_statement(self, sid: str, revision: str = "") -> None:
# converts AddPermission request to PolicyStatement
# https://docs.aws.amazon.com/lambda/latest/dg/API_AddPermission.html
def decode_policy(self, obj: Dict[str, Any]) -> "Policy":
policy = Policy(self.parent)
policy = Policy(self.parent) # type: ignore[no-untyped-call]
policy.revision = obj.get("RevisionId", "")
# get function_arn or arn from parent
if hasattr(self.parent, "arn"):
resource_arn = self.parent.arn
else:
resource_arn = self.parent.function_arn

# set some default values if these keys are not set
self.ensure_set(obj, "Effect", "Allow")
self.ensure_set(obj, "Resource", self.parent.function_arn + ":$LATEST")
self.ensure_set(obj, "Resource", resource_arn + ":$LATEST")
self.ensure_set(obj, "StatementId", str(mock_random.uuid4()))

# transform field names and values
Expand Down
31 changes: 31 additions & 0 deletions moto/awslambda/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,3 +447,34 @@ def list_function_event_invoke_configs(self) -> str:
return json.dumps(
self.backend.list_function_event_invoke_configs(function_name)
)

def add_layer_version_permission(self) -> str:
statement = self.body
layer_name = self._get_param("LayerName")
version_number = self._get_param("VersionNumber")
statement, revision_id = self.backend.add_layer_version_permission(
layer_name=layer_name,
version_number=version_number,
statement=statement,
)
return json.dumps(dict(Statement=json.dumps(statement), RevisionId=revision_id))

def get_layer_version_policy(self) -> str:
layer_name = self._get_param("LayerName")
version_number = self._get_param("VersionNumber")
return self.backend.get_layer_version_policy(
layer_name=layer_name, version_number=version_number
)

def remove_layer_version_permission(self) -> TYPE_RESPONSE:
layer_name = self._get_param("LayerName")
version_number = self._get_param("VersionNumber")
statement_id = self.path.split("/")[-1].split("?")[0]
revision = self.querystring.get("RevisionId", "")
if self.backend.get_layer_version(layer_name, version_number):
self.backend.remove_layer_version_permission(
layer_name, version_number, statement_id, revision
)
return 204, {"status": 204}, "{}"
else:
return 404, {"status": 404}, "{}"
2 changes: 2 additions & 0 deletions moto/awslambda/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,6 @@
r"{0}/(?P<api_version>[^/]+)/layers/(?P<layer_name>.+)/versions/(?P<layer_version>[\w_-]+)$": LambdaResponse.dispatch,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_:%-]+)/event-invoke-config/?$": LambdaResponse.dispatch,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_:%-]+)/event-invoke-config/list$": LambdaResponse.dispatch,
r"{0}/(?P<api_version>[^/]+)/layers/(?P<layer_name>.+)/versions/(?P<layer_version>[\w_-]+)/policy$": LambdaResponse.dispatch,
r"{0}/(?P<api_version>[^/]+)/layers/(?P<layer_name>.+)/versions/(?P<layer_version>[\w_-]+)/policy/(?P<statement_id>[\w_-]+)$": LambdaResponse.dispatch,
}
49 changes: 49 additions & 0 deletions moto/ds/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,31 @@ def to_dict(self) -> Dict[str, Any]:
return attributes


class LogSubscription(BaseModel):
def __init__(
self,
directory_id: str,
log_group_name: str,
):
self.directory_id = directory_id
self.log_group_name = log_group_name
self.subscription_created_date_time = unix_time()

def to_dict(self) -> Dict[str, Any]:
return {
"DirectoryId": self.directory_id,
"LogGroupName": self.log_group_name,
"SubscriptionCreatedDateTime": self.subscription_created_date_time,
}


class DirectoryServiceBackend(BaseBackend):
"""Implementation of DirectoryService APIs."""

def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.directories: Dict[str, Directory] = {}
self.log_subscriptions: Dict[str, LogSubscription] = {}
self.tagger = TaggingService()

def _verify_subnets(self, region: str, vpc_settings: Dict[str, Any]) -> None:
Expand Down Expand Up @@ -740,5 +759,35 @@ def update_settings(self, directory_id: str, settings: List[Dict[str, Any]]) ->

return directory_id

def create_log_subscription(self, directory_id: str, log_group_name: str) -> None:
self._validate_directory_id(directory_id)
directory = self.directories[directory_id]
if directory.directory_type != "MicrosoftAD":
raise UnsupportedOperationException(
message="Log subscriptions are only supported for Microsoft AD."
)
log_subscription = LogSubscription(directory_id, log_group_name)
self.log_subscriptions[directory_id] = log_subscription
return

def delete_log_subscription(self, directory_id: str) -> None:
self._validate_directory_id(directory_id)
self.log_subscriptions.pop(directory_id)
return

def list_log_subscriptions(
self, directory_id: Optional[str]
) -> List[LogSubscription]:
"""
Pagination is not yet implemented
"""
if directory_id:
self._validate_directory_id(directory_id)
log_subscriptions = [self.log_subscriptions[directory_id]]
else:
log_subscriptions = list(self.log_subscriptions.values())

return log_subscriptions


ds_backends = BackendDict(DirectoryServiceBackend, service_name="ds")
25 changes: 25 additions & 0 deletions moto/ds/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,28 @@ def update_settings(self) -> str:
settings=settings,
)
return json.dumps(dict(DirectoryId=directory_id))

def create_log_subscription(self) -> str:
directory_id = self._get_param("DirectoryId")
log_group_name = self._get_param("LogGroupName")
self.ds_backend.create_log_subscription(
directory_id=directory_id,
log_group_name=log_group_name,
)
return json.dumps(dict())

def delete_log_subscription(self) -> str:
directory_id = self._get_param("DirectoryId")
self.ds_backend.delete_log_subscription(
directory_id=directory_id,
)
return json.dumps(dict())

def list_log_subscriptions(self) -> str:
directory_id = self._get_param("DirectoryId")
log_subscriptions = self.ds_backend.list_log_subscriptions(
directory_id=directory_id,
)
return json.dumps(
dict(LogSubscriptions=[ls.to_dict() for ls in log_subscriptions])
)
135 changes: 135 additions & 0 deletions tests/test_awslambda/test_lambda_layers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
import sys
from unittest import SkipTest, mock
Expand Down Expand Up @@ -321,3 +322,137 @@ def get_layer_by_layer_name_from_list_of_layer_dicts(layer_name, layer_list):
)["LatestMatchingVersion"]["Version"]
== 3
)


@mock_aws
def test_add_layer_version_permission():
if LooseVersion(boto3_version) < LooseVersion("1.29.0"):
raise SkipTest("Parameters only available in newer versions")
bucket_name = str(uuid4())
s3_conn = boto3.client("s3", _lambda_region)
s3_conn.create_bucket(
Bucket=bucket_name,
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)

zip_content = get_test_zip_file1()
s3_conn.put_object(Bucket=bucket_name, Key="test.zip", Body=zip_content)
conn = boto3.client("lambda", _lambda_region)
layer_name = str(uuid4())[0:6]

resp = conn.publish_layer_version(
LayerName=layer_name,
Content={"ZipFile": get_test_zip_file1()},
CompatibleRuntimes=["python3.6"],
LicenseInfo="MIT",
CompatibleArchitectures=["x86_64"],
)
layer_version = resp["Version"]
resp = conn.add_layer_version_permission(
LayerName=layer_name,
VersionNumber=layer_version,
StatementId="xaccount",
Action="lambda:GetLayerVersion",
Principal="432143214321",
OrganizationId="o-123456",
)
assert "RevisionId" in resp
assert "Statement" in resp
res = json.loads(resp["Statement"])
assert res["Action"] == "lambda:GetLayerVersion"


@mock_aws
def test_get_layer_version_policy():
if LooseVersion(boto3_version) < LooseVersion("1.29.0"):
raise SkipTest("Parameters only available in newer versions")
bucket_name = str(uuid4())
s3_conn = boto3.client("s3", _lambda_region)
s3_conn.create_bucket(
Bucket=bucket_name,
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)

zip_content = get_test_zip_file1()
s3_conn.put_object(Bucket=bucket_name, Key="test.zip", Body=zip_content)
conn = boto3.client("lambda", _lambda_region)
layer_name = str(uuid4())[0:6]

resp = conn.publish_layer_version(
LayerName=layer_name,
Content={"ZipFile": get_test_zip_file1()},
CompatibleRuntimes=["python3.6"],
LicenseInfo="MIT",
CompatibleArchitectures=["x86_64"],
)
layer_version = resp["Version"]
conn.add_layer_version_permission(
LayerName=layer_name,
VersionNumber=layer_version,
StatementId="xaccount",
Action="lambda:GetLayerVersion",
Principal="432143214321",
)
resp = conn.get_layer_version_policy(
LayerName=layer_name, VersionNumber=layer_version
)
assert "Policy" in resp
assert "RevisionId" in resp
res = json.loads(resp["Policy"])
assert res["Statement"][0]["Action"] == "lambda:GetLayerVersion"
assert (
res["Statement"][0]["Resource"]
== f"arn:aws:lambda:us-west-2:123456789012:layer:{layer_name}:1"
)


@mock_aws
def test_remove_layer_version_permission():
if LooseVersion(boto3_version) < LooseVersion("1.29.0"):
raise SkipTest("Parameters only available in newer versions")
bucket_name = str(uuid4())
s3_conn = boto3.client("s3", _lambda_region)
s3_conn.create_bucket(
Bucket=bucket_name,
CreateBucketConfiguration={"LocationConstraint": _lambda_region},
)

zip_content = get_test_zip_file1()
s3_conn.put_object(Bucket=bucket_name, Key="test.zip", Body=zip_content)
conn = boto3.client("lambda", _lambda_region)
layer_name = str(uuid4())[0:6]

resp = conn.publish_layer_version(
LayerName=layer_name,
Content={"ZipFile": get_test_zip_file1()},
CompatibleRuntimes=["python3.6"],
LicenseInfo="MIT",
CompatibleArchitectures=["x86_64"],
)
layer_version = resp["Version"]
conn.add_layer_version_permission(
LayerName=layer_name,
VersionNumber=layer_version,
StatementId="xaccount",
Action="lambda:GetLayerVersion",
Principal="432143214321",
)
resp = conn.get_layer_version_policy(
LayerName=layer_name, VersionNumber=layer_version
)
assert "Policy" in resp

resp = conn.remove_layer_version_permission(
LayerName=layer_name,
VersionNumber=layer_version,
StatementId="xaccount",
)
assert resp["ResponseMetadata"]["HTTPStatusCode"] == 204
with pytest.raises(ClientError) as exc:
conn.get_layer_version_policy(
LayerName=layer_name, VersionNumber=layer_version
)["Policy"]

err = exc.value.response["Error"]
assert err["Code"] == "ResourceNotFoundException"
assert err["Message"] == "The resource you requested does not exist."
Loading

0 comments on commit b40901d

Please sign in to comment.