diff --git a/spec/04-services/05-sts_spec.lua b/spec/04-services/05-sts_spec.lua new file mode 100644 index 0000000..cb7660e --- /dev/null +++ b/spec/04-services/05-sts_spec.lua @@ -0,0 +1,188 @@ +setmetatable(_G, nil) + +-- -- hock request sending +-- package.loaded["resty.aws.request.execute"] = function(...) +-- return ... +-- end + +local AWS = require("resty.aws") +local AWS_global_config = require("resty.aws.config").global + +local config = AWS_global_config +local aws = AWS(config) + +aws.config.credentials = aws:Credentials { + accessKeyId = "test_id", + secretAccessKey = "test_key", +} + +-- aws.config.region = "test_region" + +local test_assume_role_arn = "arn:aws:iam::123456789012:role/test-role" +local test_role_session_name = "lua-resty-aws-test-assumeRole" + +describe("STS service", function() + local origin_time + setup(function() + origin_time = ngx.time + ngx.time = function () --luacheck: ignore + return 1667543171 + end + end) + + teardown(function () + ngx.time = origin_time --luacheck: ignore + end) + + -- before_each(function() + -- sts = aws:STS() + -- end) + + -- after_each(function() + + -- end) + + for _, region in ipairs({"us-east-1", "us-east-2", "ap-south-1", "ca-west-1", "eu-west-2", "sa-east-1"}) do + describe("In Region #" .. region, function () + -- before_each(function() + -- aws.config.region = region + -- end) + + it("AWS_STS_REGIONAL_ENDPOINT==regional with default endpoint", function () + local config = { + region = region, + stsRegionalEndpoints = "regional", + dry_run = true, + } + + local sts = aws:STS(config) + local request = sts:assumeRole({ + RoleArn = test_assume_role_arn, + RoleSessionName = test_role_session_name, + }) + + assert.same(sts.config.stsRegionalEndpoints, "regional") + -- Check the signing region has been injected + assert.same(region, sts.config.signingRegion) + assert.truthy(sts.config._regionalEndpointInjected) + -- Check the endpoint has been injected + assert.same(sts.config.endpoint, "https://sts." .. region .. ".amazonaws.com") + assert.not_nil(request.headers.Authorization:find(region, 1, true)) + end) + + describe("AWS_STS_REGIONAL_ENDPOINT==regional with non-default endpoint", function() + it("and endpoint is regional domain", function () + local config = { + region = region, + stsRegionalEndpoints = "regional", + endpoint = "https://sts." .. region .. ".amazonaws.com", + dry_run = true, + } + + local sts = aws:STS(config) + local request = sts:assumeRole({ + RoleArn = test_assume_role_arn, + RoleSessionName = test_role_session_name, + }) + + assert.same(sts.config.stsRegionalEndpoints, "regional") + -- Check the signing region has been injected + assert.same(region, sts.config.signingRegion) + assert.truthy(sts.config._regionalEndpointInjected) + -- Check thes endpoint has not been injected twice + assert.same(sts.config.endpoint, config.endpoint) + assert.not_nil(request.headers.Authorization:find(region, 1, true)) + end) + + it("and endpoint is global domain", function () + local config = { + region = region, + stsRegionalEndpoints = "regional", + endpoint = "https://sts.amazonaws.com", + dry_run = true, + } + + local sts = aws:STS(config) + local request = sts:assumeRole({ + RoleArn = test_assume_role_arn, + RoleSessionName = test_role_session_name, + }) + + assert.same(sts.config.stsRegionalEndpoints, "regional") + -- Check the signing region has been injected + assert.same(region, sts.config.signingRegion) + assert.truthy(sts.config._regionalEndpointInjected) + -- Check the endpoint has been injected + assert.same(sts.config.endpoint, "https://sts." .. region .. ".amazonaws.com") + assert.not_nil(request.headers.Authorization:find(region, 1, true)) + end) + + it("and endpoint is region VPC endpoint", function () + local config = { + region = region, + stsRegionalEndpoints = "regional", + endpoint = "https://vpce-1234567-abcdefg.sts." .. region .. ".vpce.amazonaws.com", + dry_run = true, + } + + local sts = aws:STS(config) + local request = sts:assumeRole({ + RoleArn = test_assume_role_arn, + RoleSessionName = test_role_session_name, + }) + + assert.same(sts.config.stsRegionalEndpoints, "regional") + -- Check the signing region has been injected + assert.same(region, sts.config.signingRegion) + assert.truthy(sts.config._regionalEndpointInjected) + -- Check the endpoint has not been injected when endpoint is a vpc endpoint + assert.same(sts.config.endpoint, config.endpoint) + assert.not_nil(request.headers.Authorization:find(region, 1, true)) + end) + + it("and endpoint is AZ VPC endpoint", function () + local config = { + region = region, + stsRegionalEndpoints = "regional", + endpoint = "https://vpce-1234567-abcdefg-" .. region .. "c" .. ".sts." .. region .. ".vpce.amazonaws.com", + dry_run = true, + } + + local sts = aws:STS(config) + local request = sts:assumeRole({ + RoleArn = test_assume_role_arn, + RoleSessionName = test_role_session_name, + }) + + assert.same(sts.config.stsRegionalEndpoints, "regional") + -- Check the signing region has been injected + assert.same(region, sts.config.signingRegion) + assert.truthy(sts.config._regionalEndpointInjected) + -- Check the endpoint has not been injected when endpoint is a vpc endpoint + assert.same(sts.config.endpoint, config.endpoint) + assert.not_nil(request.headers.Authorization:find(region, 1, true)) + end) + end) + + it("AWS_STS_REGIONAL_ENDPOINT==legacy with default endpoint", function () + local config = { + region = region, + stsRegionalEndpoints = "legacy", + dry_run = true, + } + + local sts = aws:STS(config) + local request = sts:assumeRole({ + RoleArn = test_assume_role_arn, + RoleSessionName = test_role_session_name, + }) + + assert.same(sts.config.stsRegionalEndpoints, "legacy") + assert.same("us-east-1", sts.config.signingRegion) + assert.is_nil(sts.config._regionalEndpointInjected) + assert.same(sts.config.endpoint, "https://sts.amazonaws.com") + assert.not_nil(request.headers.Authorization:find("us-east-1", 1, true)) + end) + end) + end +end) diff --git a/src/resty/aws/init.lua b/src/resty/aws/init.lua index f5b7b2a..9d3b9a1 100644 --- a/src/resty/aws/init.lua +++ b/src/resty/aws/init.lua @@ -260,6 +260,24 @@ do end +local is_regional_sts_domain do + -- from the list described in https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html + -- TODO: not sure if gov cloud also has their own endpoints so leave it for now + local stsRegionRegexes = { + [[sts\.(us|eu|ap|sa|ca|me)\-\w+\-\d+\.amazonaws\.com$]], + [[sts\.cn\-\w+\-\d+\.amazonaws\.com\.cn$]], + } + + function is_regional_sts_domain(domain) + for _, entry in ipairs(stsRegionRegexes) do + if ngx.re.match(domain, entry, "jo") then + return true + end + end + + return false + end +end -- written from scratch @@ -325,14 +343,19 @@ local function generate_service_methods(service) -- https://github.com/aws/aws-sdk-js/blob/307e82673b48577fce4389e4ce03f95064e8fe0d/lib/services/sts.js#L78-L82 assert(service.config.region, "region is required when using STS regional endpoints") - -- If the endpoint is a VPC endpoint DNS hostname then we don't need to inject the region - -- VPC endpoint DNS hostnames always contain region, see - -- https://docs.aws.amazon.com/vpc/latest/privatelink/privatelink-access-aws-services.html#interface-endpoint-dns-hostnames - if not service.config._regionalEndpointInjected and not service.config.endpoint:match(AWS_VPC_ENDPOINT_DOMAIN_PATTERN) then - local pre, post = service.config.endpoint:match(AWS_PUBLIC_DOMAIN_PATTERN) - service.config.endpoint = pre .. "." .. service.config.region .. post - service.config.signingRegion = service.config.region + if not service.config._regionalEndpointInjected then service.config._regionalEndpointInjected = true + -- stsRegionalEndpoints is set to 'regional', so inject region into the + -- signingRegion to override global region_config_data + service.config.signingRegion = service.config.region + + -- If the endpoint is a VPC endpoint DNS hostname, or a regional STS domain, then we don't need to inject the region + -- VPC endpoint DNS hostnames always contain region, see + -- https://docs.aws.amazon.com/vpc/latest/privatelink/privatelink-access-aws-services.html#interface-endpoint-dns-hostnames + if not service.config.endpoint:match(AWS_VPC_ENDPOINT_DOMAIN_PATTERN) and not is_regional_sts_domain(service.config.endpoint) then + local pre, post = service.config.endpoint:match(AWS_PUBLIC_DOMAIN_PATTERN) + service.config.endpoint = pre .. "." .. service.config.region .. post + end end end