Skip to content

Commit

Permalink
Merge pull request #84 from evanuk/migrate-to-imdsv2
Browse files Browse the repository at this point in the history
Migrate from IMDSv1 to IMDSv2
  • Loading branch information
evanuk authored Sep 10, 2021
2 parents 8ac0e9f + fba1075 commit c895e9f
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import java.net.URI;
import java.util.Collections;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.javatuples.Pair;
import software.amazon.cloudwatchlogs.emf.Constants;
import software.amazon.cloudwatchlogs.emf.config.Configuration;
import software.amazon.cloudwatchlogs.emf.exception.EMFClientException;
Expand All @@ -33,7 +35,13 @@ public class EC2Environment extends AgentBasedEnvironment {

private static final String INSTANCE_IDENTITY_URL =
"http://169.254.169.254/latest/dynamic/instance-identity/document";

private static final String INSTANCE_TOKEN_URL = "http://169.254.169.254/latest/api/token";
private static final String CFN_EC2_TYPE = "AWS::EC2::Instance";
private static final String TOKEN_REQUEST_HEADER_KEY = "X-aws-ec2-metadata-token-ttl-seconds";
private static final String TOKEN_REQUEST_HEADER_VALUE = "21600";

private static final String METADATA_REQUEST_TOKEN_HEADER_KEY = "X-aws-ec2-metadata-token";

EC2Environment(Configuration config, ResourceFetcher fetcher) {
super(config);
Expand All @@ -43,6 +51,28 @@ public class EC2Environment extends AgentBasedEnvironment {

@Override
public boolean probe() {
String token;
Pair<String, String> tokenRequestHeader =
new Pair<>(TOKEN_REQUEST_HEADER_KEY, TOKEN_REQUEST_HEADER_VALUE);

URI tokenEndpoint = null;
try {
tokenEndpoint = new URI(INSTANCE_TOKEN_URL);
} catch (Exception ex) {
log.debug("Failed to construct url: " + INSTANCE_IDENTITY_URL);
return false;
}
try {
token =
fetcher.fetch(
tokenEndpoint, "PUT", Collections.singletonList(tokenRequestHeader));
} catch (EMFClientException ex) {
log.debug("Failed to get response from: " + tokenEndpoint, ex);
return false;
}

Pair<String, String> metadataRequestTokenHeader =
new Pair<>(METADATA_REQUEST_TOKEN_HEADER_KEY, token);
URI endpoint = null;
try {
endpoint = new URI(INSTANCE_IDENTITY_URL);
Expand All @@ -51,7 +81,12 @@ public boolean probe() {
return false;
}
try {
metadata = fetcher.fetch(endpoint, EC2Metadata.class);
metadata =
fetcher.fetch(
endpoint,
"GET",
EC2Metadata.class,
Collections.singletonList(metadataRequestTokenHeader));
return true;
} catch (EMFClientException ex) {
log.debug("Failed to get response from: " + endpoint, ex);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
import java.net.HttpURLConnection;
import java.net.Proxy;
import java.net.URI;
import java.util.Collections;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.javatuples.Pair;
import software.amazon.cloudwatchlogs.emf.exception.EMFClientException;
import software.amazon.cloudwatchlogs.emf.util.IOUtils;
import software.amazon.cloudwatchlogs.emf.util.Jackson;
Expand All @@ -33,24 +36,37 @@ public class ResourceFetcher {

/** Fetch a json object from a given uri and deserialize it to the specified class: clazz. */
<T> T fetch(URI endpoint, Class<T> clazz) {
String response = doReadResource(endpoint, "GET");
String response = doReadResource(endpoint, "GET", Collections.emptyList());
return Jackson.fromJsonString(response, clazz);
}

/**
* Request a json object from a given uri with the provided headers and deserialize it to the
* specified class: clazz.
*/
<T> T fetch(URI endpoint, String method, Class<T> clazz, List<Pair<String, String>> headers) {
String response = doReadResource(endpoint, method, headers);
return Jackson.fromJsonString(response, clazz);
}

/** Request a string from a given uri with the provided headers */
String fetch(URI endpoint, String method, List<Pair<String, String>> headers) {
return doReadResource(endpoint, method, headers);
}

/**
* Fetch a json object from a given uri and deserialize it to the specified class with a given
* Jackson ObjectMapper.
*/
<T> T fetch(URI endpoint, ObjectMapper objectMapper, Class<T> clazz) {
String response = doReadResource(endpoint, "GET");
String response = doReadResource(endpoint, "GET", Collections.emptyList());
return Jackson.fromJsonString(response, objectMapper, clazz);
}

private String doReadResource(URI endpoint, String method) {
private String doReadResource(URI endpoint, String method, List<Pair<String, String>> headers) {
InputStream inputStream = null;
try {

HttpURLConnection connection = connectToEndpoint(endpoint, method);
HttpURLConnection connection = connectToEndpoint(endpoint, method, headers);

int statusCode = connection.getResponseCode();

Expand Down Expand Up @@ -105,13 +121,16 @@ private void handleErrorResponse(InputStream errorStream, String responseMessage
}
}

private HttpURLConnection connectToEndpoint(URI endpoint, String method) throws IOException {
private HttpURLConnection connectToEndpoint(
URI endpoint, String method, List<Pair<String, String>> headers) throws IOException {
HttpURLConnection connection =
(HttpURLConnection) endpoint.toURL().openConnection(Proxy.NO_PROXY);
connection.setConnectTimeout(1000);
connection.setReadTimeout(1000);
connection.setRequestMethod(method);
connection.setDoOutput(true);
headers.forEach(
header -> connection.setRequestProperty(header.getValue0(), header.getValue1()));

connection.connect();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,11 @@ public void setUp() {
environment = new EC2Environment(config, fetcher);
}

@SuppressWarnings("unchecked")
@Test
public void testProbeReturnFalse() {
when(fetcher.fetch(any(), any())).thenThrow(new EMFClientException("Invalid URL"));
when(fetcher.fetch(any(), any(), (Class<Object>) any(), any()))
.thenThrow(new EMFClientException("Invalid URL"));

assertFalse(environment.probe());
}
Expand All @@ -71,8 +73,10 @@ public void testGetTypeWhenNoMetadata() {
}

@Test
@SuppressWarnings("unchecked")
public void testGetTypeReturnDefined() {
when(fetcher.fetch(any(), any())).thenReturn(new EC2Environment.EC2Metadata());
when(fetcher.fetch(any(), any(), (Class<Object>) any(), any()))
.thenReturn(new EC2Environment.EC2Metadata());
environment.probe();
assertEquals(environment.getType(), "AWS::EC2::Instance");
}
Expand All @@ -87,10 +91,11 @@ public void testGetTypeFromConfiguration() {
}

@Test
@SuppressWarnings("unchecked")
public void testConfigureContext() {
EC2Environment.EC2Metadata metadata = new EC2Environment.EC2Metadata();
getRandomMetadata(metadata);
when(fetcher.fetch(any(), any())).thenReturn(metadata);
when(fetcher.fetch(any(), any(), (Class<Object>) any(), any())).thenReturn(metadata);
environment.probe();

MetricsContext context = new MetricsContext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.javafaker.Faker;
import java.net.InetAddress;
import java.net.UnknownHostException;
Expand Down Expand Up @@ -68,7 +69,7 @@ public void testReturnTrueWithCorrectURL() {
String uri = "http://ecs-metata.com";
PowerMockito.when(SystemWrapper.getenv("ECS_CONTAINER_METADATA_URI")).thenReturn(uri);
ECSEnvironment.ECSMetadata metadata = new ECSEnvironment.ECSMetadata();
when(fetcher.fetch(any(), any(), any())).thenReturn(metadata);
when(fetcher.fetch(any(), (ObjectMapper) any(), any())).thenReturn(metadata);

assertTrue(environment.probe());
}
Expand All @@ -81,7 +82,7 @@ public void testFormatImageName() {
ECSEnvironment.ECSMetadata metadata = new ECSEnvironment.ECSMetadata();
metadata.image = "testAccount.dkr.ecr.us-west-2.amazonaws.com/testImage:latest";
metadata.labels = new HashMap<>();
when(fetcher.fetch(any(), any(), any())).thenReturn(metadata);
when(fetcher.fetch(any(), (ObjectMapper) any(), any())).thenReturn(metadata);

assertTrue(environment.probe());
assertEquals(environment.getName(), "testImage:latest");
Expand Down Expand Up @@ -122,7 +123,8 @@ public void testSetFluentBit() {
PowerMockito.when(SystemWrapper.getenv("FLUENT_HOST")).thenReturn(fluentHost);

environment.probe();
when(fetcher.fetch(any(), any(), any())).thenReturn(new ECSEnvironment.ECSMetadata());
when(fetcher.fetch(any(), (ObjectMapper) any(), any()))
.thenReturn(new ECSEnvironment.ECSMetadata());
ArgumentCaptor<String> argument = ArgumentCaptor.forClass(String.class);
Mockito.verify(config, times(1)).setAgentEndpoint(argument.capture());
assertEquals(
Expand Down Expand Up @@ -155,7 +157,7 @@ public void testConfigureContext() throws UnknownHostException {
PowerMockito.when(SystemWrapper.getenv("ECS_CONTAINER_METADATA_URI")).thenReturn(uri);
ECSEnvironment.ECSMetadata metadata = new ECSEnvironment.ECSMetadata();
getRandomMetadata(metadata);
when(fetcher.fetch(any(), any(), any())).thenReturn(metadata);
when(fetcher.fetch(any(), (ObjectMapper) any(), any())).thenReturn(metadata);

environment.probe();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
import java.net.ServerSocket;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Collections;
import lombok.Data;
import org.javatuples.Pair;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Test;
Expand Down Expand Up @@ -94,6 +96,33 @@ public void testReadDataWith200Response() {
assertEquals(data.size, 10);
}

@Test
public void testReadDataWithHeaders200Response() {
Pair<String, String> mockHeader = new Pair<>("X-mock-header-key", "headerValue");
generateStub(200, "{\"name\":\"test\",\"size\":10}");
TestData data =
fetcher.fetch(
endpoint, "GET", TestData.class, Collections.singletonList(mockHeader));

verify(
getRequestedFor(urlEqualTo(endpoint_path))
.withHeader("X-mock-header-key", equalTo("headerValue")));
assertEquals(data.name, "test");
assertEquals(data.size, 10);
}

@Test
public void testWithProvidedMethodAndHeadersWith200Response() {
generatePutStub(200, "putResponseData");
Pair<String, String> mockHeader = new Pair<>("X-mock-header-key", "headerValue");
String data = fetcher.fetch(endpoint, "PUT", Collections.singletonList(mockHeader));

verify(
putRequestedFor(urlEqualTo(endpoint_path))
.withHeader("X-mock-header-key", equalTo("headerValue")));
assertEquals(data, "putResponseData");
}

@Test
public void testReadCaseInsensitiveDataWith200Response() {
generateStub(200, "{\"Name\":\"test\",\"Size\":10}");
Expand Down Expand Up @@ -136,6 +165,17 @@ private void generateStub(int statusCode, String message) {
.withBody(message)));
}

private void generatePutStub(int statusCode, String message) {
stubFor(
put(urlPathEqualTo(endpoint_path))
.willReturn(
aResponse()
.withStatus(statusCode)
.withHeader("Content-Type", "application/json")
.withHeader("charset", "utf-8")
.withBody(message)));
}

@Data
private static class TestData {
private String name;
Expand Down

0 comments on commit c895e9f

Please sign in to comment.