Skip to content

Commit

Permalink
Apigateway Account APIs (#8119)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattheidelbaugh authored Sep 14, 2024
1 parent 551da1b commit c636d01
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 0 deletions.
60 changes: 60 additions & 0 deletions moto/apigateway/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
from .utils import create_id, to_path

STAGE_URL = "https://{api_id}.execute-api.{region_name}.amazonaws.com/{stage_name}"
PATCH_OPERATIONS = ["add", "remove", "replace", "move", "copy", "test"]


class Deployment(CloudFormationModel):
Expand Down Expand Up @@ -1020,6 +1021,7 @@ class RestAPI(CloudFormationModel):
PROP_POLICY = "policy"
PROP_DISABLE_EXECUTE_API_ENDPOINT = "disableExecuteApiEndpoint"
PROP_MINIMUM_COMPRESSION_SIZE = "minimumCompressionSize"
PROP_ROOT_RESOURCE_ID = "rootResourceId"

# operations
OPERATION_ADD = "add"
Expand Down Expand Up @@ -1064,6 +1066,7 @@ def __init__(
self.models: Dict[str, Model] = {}
self.request_validators: Dict[str, RequestValidator] = {}
self.default = self.add_child("/") # Add default child
self.root_resource_id = self.default.id

def __repr__(self) -> str:
return str(self.id)
Expand All @@ -1082,6 +1085,7 @@ def to_dict(self) -> Dict[str, Any]:
self.PROP_POLICY: self.policy,
self.PROP_DISABLE_EXECUTE_API_ENDPOINT: self.disableExecuteApiEndpoint,
self.PROP_MINIMUM_COMPRESSION_SIZE: self.minimum_compression_size,
self.PROP_ROOT_RESOURCE_ID: self.root_resource_id,
}

def apply_patch_operations(self, patch_operations: List[Dict[str, Any]]) -> None:
Expand Down Expand Up @@ -1529,6 +1533,54 @@ def to_json(self) -> Dict[str, Any]:
return dct


class Account(BaseModel):
def __init__(self) -> None:
self.cloudwatch_role_arn: Optional[str] = None
self.throttle_settings: Dict[str, Any] = {
"burstLimit": 5000,
"rateLimit": 10000.0,
}
self.features: Optional[List[str]] = None
self.api_key_version: str = "1"

def apply_patch_operations(
self, patch_operations: List[Dict[str, Any]]
) -> "Account":
for op in patch_operations:
if "/cloudwatchRoleArn" in op["path"]:
self.cloudwatch_role_arn = op["value"]
elif "/features" in op["path"]:
if op["op"] == "add":
if self.features is None:
self.features = [op["value"]]
else:
self.features.append(op["value"])
elif op["op"] == "remove":
if op["value"] == "UsagePlans":
raise BadRequestException(
"Usage Plans cannot be disabled once enabled"
)
if self.features is not None:
self.features.remove(op["value"])
else:
raise NotImplementedError(
f'Patch operation "{op["op"]}" for "/features" not implemented'
)
else:
raise NotImplementedError(
f'Patch operation "{op["op"]}" for "{op["path"]}" not implemented'
)
return self

def to_json(self) -> Dict[str, Any]:
return {
"cloudwatchRoleArn": self.cloudwatch_role_arn,
"throttleSettings": self.throttle_settings,
"features": self.features,
"apiKeyVersion": self.api_key_version,
}


class APIGatewayBackend(BaseBackend):
"""
API Gateway mock.
Expand Down Expand Up @@ -1558,6 +1610,7 @@ class APIGatewayBackend(BaseBackend):

def __init__(self, region_name: str, account_id: str):
super().__init__(region_name, account_id)
self.account: Account = Account()
self.apis: Dict[str, RestAPI] = {}
self.keys: Dict[str, ApiKey] = {}
self.usage_plans: Dict[str, UsagePlan] = {}
Expand Down Expand Up @@ -2485,5 +2538,12 @@ def delete_gateway_response(self, rest_api_id: str, response_type: str) -> None:
api = self.get_rest_api(rest_api_id)
api.delete_gateway_response(response_type)

def update_account(self, patch_operations: List[Dict[str, Any]]) -> Account:
account = self.account.apply_patch_operations(patch_operations)
return account

def get_account(self) -> Account:
return self.account


apigateway_backends = BackendDict(APIGatewayBackend, "apigateway")
9 changes: 9 additions & 0 deletions moto/apigateway/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,3 +851,12 @@ def delete_gateway_response(self) -> TYPE_RESPONSE:
rest_api_id=rest_api_id, response_type=response_type
)
return 202, {}, json.dumps(dict())

def update_account(self) -> str:
patch_operations = self._get_param("patchOperations")
account = self.backend.update_account(patch_operations)
return json.dumps(account.to_json())

def get_account(self) -> str:
account = self.backend.get_account()
return json.dumps(account.to_json())
1 change: 1 addition & 0 deletions moto/apigateway/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"{0}/restapis/(?P<api_id>[^/]+)/gatewayresponses/(?P<response_type>[^/]+)/?$": APIGatewayResponse.dispatch,
"{0}/vpclinks$": APIGatewayResponse.dispatch,
"{0}/vpclinks/(?P<vpclink_id>[^/]+)": APIGatewayResponse.dispatch,
"{0}/account$": APIGatewayResponse.dispatch,
}

# Also manages the APIGatewayV2
Expand Down
81 changes: 81 additions & 0 deletions tests/test_apigateway/test_apigateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_create_and_get_rest_api():
name="my_api", description="this is my api", disableExecuteApiEndpoint=True
)
api_id = response["id"]
root_resource_id = response["rootResourceId"]

response = client.get_rest_api(restApiId=api_id)

Expand All @@ -34,6 +35,7 @@ def test_create_and_get_rest_api():
"endpointConfiguration": {"types": ["EDGE"]},
"tags": {},
"disableExecuteApiEndpoint": True,
"rootResourceId": root_resource_id,
}


Expand All @@ -42,6 +44,7 @@ def test_update_rest_api():
client = boto3.client("apigateway", region_name="us-west-2")
response = client.create_rest_api(name="my_api", description="this is my api")
api_id = response["id"]
root_resource_id = response["rootResourceId"]
patchOperations = [
{"op": "replace", "path": "/name", "value": "new-name"},
{"op": "replace", "path": "/description", "value": "new-description"},
Expand Down Expand Up @@ -71,6 +74,7 @@ def test_update_rest_api():
"endpointConfiguration": {"types": ["EDGE"]},
"tags": {},
"disableExecuteApiEndpoint": True,
"rootResourceId": root_resource_id,
}
# should fail with wrong apikeysoruce
patchOperations = [
Expand Down Expand Up @@ -2486,3 +2490,80 @@ def test_update_path_mapping_with_unknown_stage():
assert ex.value.response["Error"]["Message"] == "Invalid stage identifier specified"
assert ex.value.response["Error"]["Code"] == "BadRequestException"
assert ex.value.response["ResponseMetadata"]["HTTPStatusCode"] == 400


@mock_aws
def test_update_account():
client = boto3.client("apigateway", region_name="eu-west-1")

patch_operations = [
{
"op": "replace",
"path": "/cloudwatchRoleArn",
"value": "arn:aws:iam:123456789012:role/moto-test-apigw-role-1",
},
{"op": "add", "path": "/features", "value": "UsagePlans"},
{"op": "add", "path": "/features", "value": "TestFeature"},
]

account = client.update_account(patchOperations=patch_operations)

assert (
account["cloudwatchRoleArn"]
== "arn:aws:iam:123456789012:role/moto-test-apigw-role-1"
)
assert account["features"] == ["UsagePlans", "TestFeature"]

patch_operations = [
{
"op": "replace",
"path": "/cloudwatchRoleArn",
"value": "arn:aws:iam:123456789012:role/moto-test-apigw-role-2",
},
{"op": "remove", "path": "/features", "value": "TestFeature"},
]

account = client.update_account(patchOperations=patch_operations)

assert (
account["cloudwatchRoleArn"]
== "arn:aws:iam:123456789012:role/moto-test-apigw-role-2"
)
assert account["throttleSettings"]["burstLimit"] == 5000
assert account["throttleSettings"]["rateLimit"] == 10000.0
assert account["apiKeyVersion"] == "1"
assert account["features"] == ["UsagePlans"]


@mock_aws
def test_update_account_error():
client = boto3.client("apigateway", region_name="eu-west-1")
patch_operations = [
{
"op": "remove",
"path": "/features",
"value": "UsagePlans",
},
]

with pytest.raises(ClientError) as ex:
client.update_account(patchOperations=patch_operations)

assert (
ex.value.response["Error"]["Message"]
== "Usage Plans cannot be disabled once enabled"
)
assert ex.value.response["Error"]["Code"] == "BadRequestException"
assert ex.value.response["ResponseMetadata"]["HTTPStatusCode"] == 400


@mock_aws
def test_get_account():
client = boto3.client("apigateway", region_name="eu-west-1")
account = client.get_account()

assert account["throttleSettings"]["burstLimit"] == 5000
assert account["throttleSettings"]["rateLimit"] == 10000.0
assert account["apiKeyVersion"] == "1"
assert "features" not in account
assert "cloudwatchRoleArn" not in account

0 comments on commit c636d01

Please sign in to comment.