Skip to content

Commit

Permalink
Merge pull request #119 from aws-beam/add-additional-special-handling…
Browse files Browse the repository at this point in the history
…-for-apigatewaymanagementapi-due-to-api-id

Add additional handling for apigatewaymanagementapi since it requires an api-id to be passed into the host
  • Loading branch information
onno-vos-dev authored Jan 25, 2025
2 parents feec7b3 + b2f60f1 commit acb7a87
Showing 1 changed file with 27 additions and 22 deletions.
49 changes: 27 additions & 22 deletions priv/rest.erl.eex
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<%= if context.docstring != "%% @doc" do %><%= context.docstring %><% end %>
-module(<%= context.module_name %>).
<%= if context.module_name == "aws_apigatewaymanagementapi" do %>
-export([<%= Enum.map(context.actions, fn(action) -> if action.method == "GET" do ["#{action.function_name}/#{action.arity - 2}"] else [] end ++ ["#{action.function_name}/#{action.arity}", "#{action.function_name}/#{action.arity + 1}"] end) |> List.flatten |> Enum.join(",\n ") %>]).
-export([<%= Enum.map(context.actions, fn(action) -> if action.method == "GET" do ["#{action.function_name}/#{action.arity - 1}"] else [] end ++ ["#{action.function_name}/#{action.arity + 1}", "#{action.function_name}/#{action.arity + 2}"] end) |> List.flatten |> Enum.join(",\n ") %>]).
<% else %>
-export([<%= Enum.map(context.actions, fn(action) -> if action.method == "GET" do ["#{action.function_name}/#{action.arity - 3}"] else [] end ++ ["#{action.function_name}/#{action.arity - 1}", "#{action.function_name}/#{action.arity}"] end) |> List.flatten |> Enum.join(",\n ") %>]).
<% end %>
Expand Down Expand Up @@ -50,21 +50,21 @@ end) %>
%%====================================================================
<%= for action <- context.actions do %>
<%= action.docstring %><%= if action.method == "GET" do %>
-spec <%= action.function_name %>(aws_client:aws_client()<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, list() | binary()<% end %><%= AWS.CodeGen.Types.required_function_parameter_types(action) %>) ->
-spec <%= action.function_name %>(aws_client:aws_client()<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, list() | binary(), list() | binary()<% end %><%= AWS.CodeGen.Types.required_function_parameter_types(action) %>) ->
<%= AWS.CodeGen.Types.return_type(context.language, action)%>.
<%= action.function_name %>(Client<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, Stage<% end %><%= AWS.CodeGen.RestService.required_function_parameters(action) %>)
<%= action.function_name %>(Client<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, ApiId, Stage<% end %><%= AWS.CodeGen.RestService.required_function_parameters(action) %>)
when is_map(Client) ->
<%= action.function_name %>(Client<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, Stage<% end %><%= AWS.CodeGen.RestService.required_function_parameters(action) %>, #{}, #{}).
<%= action.function_name %>(Client<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, ApiId, Stage<% end %><%= AWS.CodeGen.RestService.required_function_parameters(action) %>, #{}, #{}).

-spec <%= action.function_name %>(aws_client:aws_client()<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, list() | binary()<% end %><%= AWS.CodeGen.Types.required_function_parameter_types(action) %>, map(), map()) ->
-spec <%= action.function_name %>(aws_client:aws_client()<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, list() | binary(), list() | binary()<% end %><%= AWS.CodeGen.Types.required_function_parameter_types(action) %>, map(), map()) ->
<%= AWS.CodeGen.Types.return_type(context.language, action)%>.
<%= action.function_name %>(Client<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, Stage<% end %><%= AWS.CodeGen.RestService.required_function_parameters(action) %>, QueryMap, HeadersMap)
<%= action.function_name %>(Client<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, ApiId, Stage<% end %><%= AWS.CodeGen.RestService.required_function_parameters(action) %>, QueryMap, HeadersMap)
when is_map(Client), is_map(QueryMap), is_map(HeadersMap) ->
<%= action.function_name %>(Client<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, Stage<% end %><%= AWS.CodeGen.RestService.required_function_parameters(action) %>, QueryMap, HeadersMap, []).
<%= action.function_name %>(Client<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, ApiId, Stage<% end %><%= AWS.CodeGen.RestService.required_function_parameters(action) %>, QueryMap, HeadersMap, []).

-spec <%= action.function_name %>(aws_client:aws_client()<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, list() | binary()<% end %><%= AWS.CodeGen.Types.required_function_parameter_types(action) %>, map(), map(), proplists:proplist()) ->
-spec <%= action.function_name %>(aws_client:aws_client()<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, list() | binary(), list() | binary()<% end %><%= AWS.CodeGen.Types.required_function_parameter_types(action) %>, map(), map(), proplists:proplist()) ->
<%= AWS.CodeGen.Types.return_type(context.language, action)%>.
<%= action.function_name %>(Client<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, Stage<% end %><%= AWS.CodeGen.RestService.required_function_parameters(action) %>, QueryMap, HeadersMap, Options0)
<%= action.function_name %>(Client<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, ApiId, Stage<% end %><%= AWS.CodeGen.RestService.required_function_parameters(action) %>, QueryMap, HeadersMap, Options0)
when is_map(Client), is_map(QueryMap), is_map(HeadersMap), is_list(Options0) ->
Path = ["<%= if context.module_name == "aws_apigatewaymanagementapi" do %>/", Stage, "<% end %><%= AWS.CodeGen.RestService.Action.url_path(action) %>"],<%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>
<%= if !String.contains?("Bucket", AWS.CodeGen.RestService.required_function_parameters(action)) do %><% else %> Bucket = undefined,<% end %><% end %>
Expand Down Expand Up @@ -93,7 +93,7 @@ end) %>
<% else %>
Query_ = [],
<% end %><%= if length(action.response_header_parameters) > 0 do %>
case request(Client, get, Path, Query_, Headers, undefined, Options, SuccessStatusCode<%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>, Bucket<% end %>) of
case request(Client<%= if context.module_name == "aws_apigatewaymanagementapi" do %>#{api_id => ApiId}<% else %><% end %>, get, Path, Query_, Headers, undefined, Options, SuccessStatusCode<%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>, Bucket<% end %>) of
{ok, Body0, {_, ResponseHeaders, _} = Response} ->
ResponseHeadersParams =
[<%= for parameter <- Enum.drop action.response_header_parameters, -1 do %>
Expand All @@ -111,16 +111,16 @@ end) %>
Result ->
Result
end.<% else %>
request(Client, get, Path, Query_, Headers, undefined, Options, SuccessStatusCode<%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>, Bucket<% end %>).<% end %>
request(Client<%= if context.module_name == "aws_apigatewaymanagementapi" do %>#{api_id => ApiId}<% else %><% end %>, get, Path, Query_, Headers, undefined, Options, SuccessStatusCode<%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>, Bucket<% end %>).<% end %>
<% else %>
-spec <%= action.function_name %>(aws_client:aws_client()<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, list() | binary()<% end %><%= AWS.CodeGen.Types.required_function_parameter_types(action) %>, <%= AWS.CodeGen.Types.function_argument_type(context.language, action)%>) ->
-spec <%= action.function_name %>(aws_client:aws_client()<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, list() | binary(), list() | binary()<% end %><%= AWS.CodeGen.Types.required_function_parameter_types(action) %>, <%= AWS.CodeGen.Types.function_argument_type(context.language, action)%>) ->
<%= AWS.CodeGen.Types.return_type(context.language, action)%>.
<%= action.function_name %>(Client<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, Stage<% end %><%= AWS.CodeGen.RestService.function_parameters(action) %>, Input) ->
<%= action.function_name %>(Client<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, Stage<% end %><%= AWS.CodeGen.RestService.function_parameters(action) %>, Input, []).
<%= action.function_name %>(Client<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, ApiId, Stage<% end %><%= AWS.CodeGen.RestService.function_parameters(action) %>, Input) ->
<%= action.function_name %>(Client<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, ApiId, Stage<% end %><%= AWS.CodeGen.RestService.function_parameters(action) %>, Input, []).

-spec <%= action.function_name %>(aws_client:aws_client()<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, list() | binary()<% end %><%= AWS.CodeGen.Types.required_function_parameter_types(action) %>, <%= AWS.CodeGen.Types.function_argument_type(context.language, action)%>, proplists:proplist()) ->
-spec <%= action.function_name %>(aws_client:aws_client()<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, list() | binary(), list() | binary()<% end %><%= AWS.CodeGen.Types.required_function_parameter_types(action) %>, <%= AWS.CodeGen.Types.function_argument_type(context.language, action)%>, proplists:proplist()) ->
<%= AWS.CodeGen.Types.return_type(context.language, action)%>.
<%= action.function_name %>(Client<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, Stage<% end %><%= AWS.CodeGen.RestService.function_parameters(action) %>, Input0, Options0) ->
<%= action.function_name %>(Client<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, ApiId, Stage<% end %><%= AWS.CodeGen.RestService.function_parameters(action) %>, Input0, Options0) ->
Method = <%= AWS.CodeGen.RestService.Action.method(action) %>,
Path = ["<%= if context.module_name == "aws_apigatewaymanagementapi" do %>/", Stage, "<% end %><%= AWS.CodeGen.RestService.Action.url_path(action) %>"],<%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>
<%= if !String.contains?("Bucket", AWS.CodeGen.RestService.required_function_parameters(action)) do %><% else %> Bucket = undefined,<% end %><% end %>
Expand Down Expand Up @@ -177,7 +177,7 @@ end) %>
Result ->
Result
end.<% else %>
request(Client, Method, Path, Query_, CustomHeaders ++ Headers, Input, Options, SuccessStatusCode<%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>, Bucket<% end %>).<% end %>
request(Client<%= if context.module_name == "aws_apigatewaymanagementapi" do %>#{api_id => ApiId}<% else %><% end %>, Method, Path, Query_, CustomHeaders ++ Headers, Input, Options, SuccessStatusCode<%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>, Bucket<% end %>).<% end %>
<% end %><% end %>
%%====================================================================
%% Internal functions
Expand Down Expand Up @@ -231,7 +231,7 @@ do_request(Client, Method, Path, Query, Headers0, Input, Options, SuccessStatusC
Headers1 = aws_request:add_headers(AdditionalHeaders, Headers0),

MethodBin = aws_request:method_to_binary(Method),
SignedHeaders = aws_request:sign_request(Client1, MethodBin, URL, Headers1, Payload),
SignedHeaders = aws_request:sign_request(Client1, MethodBin, URL, Headers1, Payload<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, [{uri_encode_path, true}]<% else %><% end %>),
Response = hackney:request(Method, URL, SignedHeaders, Payload, Options),
DecodeBody = not proplists:get_value(receive_body_as_binary, Options),
handle_response(Response, SuccessStatusCode, DecodeBody).
Expand Down Expand Up @@ -288,7 +288,13 @@ handle_response({ok, StatusCode, ResponseHeaders, Client}, _, _DecodeBody) ->
end;
handle_response({error, Reason}, _, _DecodeBody) ->
{error, Reason}.
<%= if context.endpoint_prefix == "s3-control" do %>
<%= if context.module_name == "aws_apigatewaymanagementapi" do %>
build_host(_EndpointPrefix, #{region := <<"local">>, endpoint := Endpoint}) ->
Endpoint;
build_host(_EndpointPrefix, #{region := <<"local">>}) ->
<<"localhost">>;
build_host(EndpointPrefix, #{api_id := ApiId, region := Region, endpoint := Endpoint}) ->
aws_util:binary_join([ApiId, EndpointPrefix, Region, Endpoint], <<".">>).<% else %><%= if context.endpoint_prefix == "s3-control" do %>
build_host(_AccountId, _EndpointPrefix, #{region := <<"local">>, endpoint := Endpoint}) ->
Endpoint;
build_host(_AccountId, _EndpointPrefix, #{region := <<"local">>}) ->
Expand All @@ -314,8 +320,7 @@ build_host(_EndpointPrefix, #{region := <<"local">>}, _Bucket) ->
build_host(EndpointPrefix, #{endpoint := Endpoint}, undefined) ->
aws_util:binary_join([EndpointPrefix, Endpoint], <<".">>);
build_host(EndpointPrefix, #{endpoint := Endpoint}, Bucket) ->
aws_util:binary_join([Bucket, EndpointPrefix, Endpoint], <<".">>).
<% else %>
aws_util:binary_join([Bucket, EndpointPrefix, Endpoint], <<".">>).<% else %>
build_host(EndpointPrefix, #{region := Region, endpoint := Endpoint}, undefined) ->
aws_util:binary_join([EndpointPrefix, Region, Endpoint], <<".">>);
build_host(EndpointPrefix, #{region := Region, endpoint := Endpoint}, Bucket) ->
Expand All @@ -326,7 +331,7 @@ build_host(_EndpointPrefix, #{region := <<"local">>}) ->
build_host(EndpointPrefix, #{endpoint := Endpoint}) ->
aws_util:binary_join([EndpointPrefix, Endpoint], <<".">>).<% else %>
build_host(EndpointPrefix, #{region := Region, endpoint := Endpoint}) ->
aws_util:binary_join([EndpointPrefix, Region, Endpoint], <<".">>).<% end %><% end %><% end %>
aws_util:binary_join([EndpointPrefix, Region, Endpoint], <<".">>).<% end %><% end %><% end %><% end %>

<%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>build_url(Host0, Path0, Client, Bucket) ->
Proto = aws_client:proto(Client),
Expand Down

0 comments on commit acb7a87

Please sign in to comment.