Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor AWS client creation #332

Merged
merged 16 commits into from
Sep 30, 2020
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package com.conveyal.datatools.common.utils;

import java.util.HashMap;

import static com.conveyal.datatools.common.utils.AWSUtils.DEFAULT_EXPIRING_AWS_ASSET_VALID_DURATION_MILLIS;

/**
* This abstract class provides a framework for managing the creation of AWS Clients. Three types of clients are stored
* in this class:
* 1. A default client to use when not requesting a client using a specific role and/or region
* 2. A client to use when using a specific region, but not with a role
* 3. A client to use with a specific role and region combination (including null regions)
*
* The {@link AWSClientManager#getClient(String, String)} handles the creation and caching of clients based on the given
* role and region inputs.
*/
public abstract class AWSClientManager<T> {
evansiroky marked this conversation as resolved.
Show resolved Hide resolved
protected final T defaultClient;
private HashMap<String, T> nonRoleClientsByRegion = new HashMap<>();
private HashMap<String, ExpiringAsset<T>> clientsByRoleAndRegion = new HashMap<>();

public AWSClientManager (T defaultClient) {
this.defaultClient = defaultClient;
}

/**
* An abstract method where the implementation will create a client with the specified region, but not with a role.
*/
public abstract T buildDefaultClientWithRegion(String region);

/**
* An abstract method where the implementation will create a client with the specified role and region.
*/
protected abstract T buildCredentialedClientForRoleAndRegion(String role, String region)
throws CheckedAWSException;

/**
* Obtain a potentially cached AWS client for the provided role ARN and region. If the role and region are null, the
* default AWS client will be used. If just the role is null a cached client configured for the specified
* region will be returned. For clients that require using a role, a client will be obtained (either via a cache or
* by creation and then insertion into the cache) that has obtained the proper credentials.
*/
public T getClient(String role, String region) throws CheckedAWSException {
// return default client for null region and role
if (role == null && region == null) {
return defaultClient;
}

// if the role is null, return a potentially cached EC2 client with the region configured
T client;
if (role == null) {
client = nonRoleClientsByRegion.get(region);
if (client == null) {
client = buildDefaultClientWithRegion(region);
nonRoleClientsByRegion.put(region, client);
}
return client;
}

// check for the availability of a client already associated with the given role and region
String roleRegionKey = makeRoleRegionKey(role, region);
ExpiringAsset<T> clientWithRole = clientsByRoleAndRegion.get(roleRegionKey);
if (clientWithRole != null && clientWithRole.isActive()) return clientWithRole.asset;

// Either a new client hasn't been created or it has expired. Create a new client and cache it.
T credentialedClientForRoleAndRegion = buildCredentialedClientForRoleAndRegion(role, region);
clientsByRoleAndRegion.put(
roleRegionKey,
new ExpiringAsset<T>(credentialedClientForRoleAndRegion, DEFAULT_EXPIRING_AWS_ASSET_VALID_DURATION_MILLIS)
);
return credentialedClientForRoleAndRegion;
}

private static String makeRoleRegionKey(String role, String region) {
return String.format("role=%s,region=%s", role, region);
}
}
247 changes: 200 additions & 47 deletions src/main/java/com/conveyal/datatools/common/utils/AWSUtils.java

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package com.conveyal.datatools.common.utils;

import com.amazonaws.AmazonServiceException;

/**
* A helper exception class that does not extend the RunTimeException class in order to make the compiler properly
* detect possible places where an exception could occur.
*/
/**
 * A helper exception class that does not extend the RuntimeException class in order to make the compiler properly
 * detect possible places where an exception could occur.
 */
public class CheckedAWSException extends Exception {
    /** The underlying AWS exception, or null when this exception was created from a plain message. */
    public final Exception originalException;

    public CheckedAWSException(String message) {
        super(message);
        originalException = null;
    }

    public CheckedAWSException(AmazonServiceException e) {
        // Chain the cause so full stack traces of the AWS failure are preserved (previously only the message was
        // kept, discarding the original stack trace from logs).
        super(e.getMessage(), e);
        originalException = e;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package com.conveyal.datatools.common.utils;

/**
* A class that holds another variable and keeps track of whether the variable is still considered to be active (ie not
* expired)
*/
/**
 * Wraps a value together with an expiration deadline, so callers can check whether the held value is still considered
 * active (i.e., not yet expired).
 */
public class ExpiringAsset<T> {
    /** The wrapped value. */
    public final T asset;
    // Absolute wall-clock time (epoch millis) after which the asset is considered expired.
    private final long expiresAtMillis;

    /**
     * @param asset the value to hold
     * @param validDurationMillis how long, from now, the value should be considered active
     */
    public ExpiringAsset(T asset, long validDurationMillis) {
        this.asset = asset;
        this.expiresAtMillis = System.currentTimeMillis() + validDurationMillis;
    }

    /**
     * @return true if the asset hasn't yet expired
     */
    public boolean isActive() {
        return System.currentTimeMillis() < expiresAtMillis;
    }
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package com.conveyal.datatools.editor.controllers.api;


import com.amazonaws.AmazonServiceException;
import com.amazonaws.services.s3.AmazonS3;
import com.conveyal.datatools.common.utils.CheckedAWSException;
import com.conveyal.datatools.common.utils.SparkUtils;
import com.conveyal.datatools.editor.jobs.CreateSnapshotJob;
import com.conveyal.datatools.editor.jobs.ExportSnapshotToGTFSJob;
Expand All @@ -14,7 +17,6 @@
import com.conveyal.datatools.manager.models.FeedVersion;
import com.conveyal.datatools.manager.models.JsonViews;
import com.conveyal.datatools.manager.models.Snapshot;
import com.conveyal.datatools.manager.persistence.FeedStore;
import com.conveyal.datatools.manager.persistence.Persistence;
import com.conveyal.datatools.manager.utils.json.JsonManager;
import org.slf4j.Logger;
Expand All @@ -26,6 +28,7 @@
import java.util.Collection;

import static com.conveyal.datatools.common.utils.AWSUtils.downloadFromS3;
import static com.conveyal.datatools.common.utils.AWSUtils.getDefaultS3Client;
import static com.conveyal.datatools.common.utils.SparkUtils.downloadFile;
import static com.conveyal.datatools.common.utils.SparkUtils.formatJobMessage;
import static com.conveyal.datatools.common.utils.SparkUtils.logMessageAndHalt;
Expand Down Expand Up @@ -193,16 +196,22 @@ private static Object getSnapshotToken(Request req, Response res) {
// an actual object to download.
// FIXME: use new FeedStore.
if (DataManager.useS3) {
if (!FeedStore.s3Client.doesObjectExist(DataManager.feedBucket, key)) {
logMessageAndHalt(
req,
500,
String.format("Error downloading snapshot from S3. Object %s does not exist.", key),
new Exception("s3 object does not exist")
);
try {
AmazonS3 S3Client = getDefaultS3Client();
if (!S3Client.doesObjectExist(DataManager.feedBucket, key)) {
logMessageAndHalt(
req,
500,
String.format("Error downloading snapshot from S3. Object %s does not exist.", key),
new Exception("s3 object does not exist")
);
}
// Return presigned download link if using S3.
return downloadFromS3(S3Client, DataManager.feedBucket, key, false, res);
} catch (AmazonServiceException | CheckedAWSException e) {
logMessageAndHalt(req, 500, "Failed to download snapshot from S3.", e);
return null;
}
// Return presigned download link if using S3.
return downloadFromS3(FeedStore.s3Client, DataManager.feedBucket, key, false, res);
} else {
// If not storing on s3, just use the token download method.
token = new FeedDownloadToken(snapshot);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package com.conveyal.datatools.editor.jobs;

import com.amazonaws.AmazonServiceException;
import com.conveyal.datatools.common.status.MonitorableJob;
import com.conveyal.datatools.common.utils.CheckedAWSException;
import com.conveyal.datatools.manager.DataManager;
import com.conveyal.datatools.manager.auth.Auth0UserProfile;
import com.conveyal.datatools.manager.models.FeedVersion;
import com.conveyal.datatools.manager.models.Snapshot;
import com.conveyal.datatools.manager.persistence.FeedStore;
import com.conveyal.gtfs.loader.FeedLoadResult;
import com.conveyal.gtfs.loader.JdbcGtfsExporter;
import com.fasterxml.jackson.annotation.JsonProperty;
Expand All @@ -16,6 +17,8 @@
import java.io.FileInputStream;
import java.io.IOException;

import static com.conveyal.datatools.common.utils.AWSUtils.getDefaultS3Client;

/**
* This job will export a database snapshot (i.e., namespace) to a GTFS file. If a feed version is supplied in the
* constructor, it will assume that the GTFS file is intended for ingestion into Data Tools as a new feed version.
Expand Down Expand Up @@ -70,7 +73,12 @@ public void jobLogic() {
status.update("Writing snapshot to GTFS file", 90);
if (DataManager.useS3) {
String s3Key = String.format("%s/%s", bucketPrefix, filename);
FeedStore.s3Client.putObject(DataManager.feedBucket, s3Key, tempFile);
try {
getDefaultS3Client().putObject(DataManager.feedBucket, s3Key, tempFile);
} catch (AmazonServiceException | CheckedAWSException e) {
status.fail("Failed to upload file to S3", e);
return;
}
LOG.info("Storing snapshot GTFS at s3://{}/{}", DataManager.feedBucket, s3Key);
} else {
try {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
package com.conveyal.datatools.manager.controllers.api;

import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.services.ec2.AmazonEC2;
import com.amazonaws.services.ec2.AmazonEC2Client;
import com.amazonaws.services.ec2.model.DescribeInstancesRequest;
import com.amazonaws.services.ec2.model.DescribeInstancesResult;
import com.amazonaws.services.ec2.model.Filter;
import com.amazonaws.services.ec2.model.Instance;
import com.amazonaws.services.ec2.model.Reservation;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.AmazonS3URI;
import com.conveyal.datatools.common.status.MonitorableJob;
import com.conveyal.datatools.common.utils.AWSUtils;
import com.conveyal.datatools.common.utils.CheckedAWSException;
import com.conveyal.datatools.common.utils.SparkUtils;
import com.conveyal.datatools.manager.DataManager;
import com.conveyal.datatools.manager.auth.Auth0UserProfile;
Expand All @@ -23,7 +20,6 @@
import com.conveyal.datatools.manager.models.JsonViews;
import com.conveyal.datatools.manager.models.OtpServer;
import com.conveyal.datatools.manager.models.Project;
import com.conveyal.datatools.manager.persistence.FeedStore;
import com.conveyal.datatools.manager.persistence.Persistence;
import com.conveyal.datatools.manager.utils.json.JsonManager;
import org.bson.Document;
Expand All @@ -46,8 +42,8 @@
import java.util.stream.Collectors;

import static com.conveyal.datatools.common.utils.AWSUtils.downloadFromS3;
import static com.conveyal.datatools.common.utils.AWSUtils.getS3Client;
import static com.conveyal.datatools.common.utils.SparkUtils.logMessageAndHalt;
import static com.conveyal.datatools.manager.persistence.FeedStore.getAWSCreds;
import static spark.Spark.delete;
import static spark.Spark.get;
import static spark.Spark.options;
Expand All @@ -61,7 +57,6 @@
public class DeploymentController {
private static final Logger LOG = LoggerFactory.getLogger(DeploymentController.class);
private static Map<String, DeployJob> deploymentJobsByServer = new HashMap<>();
private static final AmazonEC2 ec2 = AmazonEC2Client.builder().build();

/**
* Gets the deployment specified by the request's id parameter and ensure that user has access to the
Expand Down Expand Up @@ -95,11 +90,9 @@ private static Deployment deleteDeployment (Request req, Response res) {
/**
evansiroky marked this conversation as resolved.
Show resolved Hide resolved
* HTTP endpoint for downloading a build artifact (e.g., otp build log or Graph.obj) from S3.
*/
private static String downloadBuildArtifact (Request req, Response res) {
private static String downloadBuildArtifact (Request req, Response res) throws CheckedAWSException {
Deployment deployment = getDeploymentWithPermissions(req, res);
DeployJob.DeploySummary summaryToDownload = null;
// Default client to use if no role was used during the deployment.
AmazonS3 s3Client = FeedStore.s3Client;
String role = null;
String region = null;
String uriString;
Expand Down Expand Up @@ -144,12 +137,13 @@ private static String downloadBuildArtifact (Request req, Response res) {
}
AmazonS3URI uri = new AmazonS3URI(uriString);
// Assume the alternative role if needed to download the deploy artifact.
if (role != null) {
s3Client = AWSUtils.getS3ClientForRole(role, region);
} else if (region != null) {
s3Client = AWSUtils.getS3ClientForCredentials(getAWSCreds(), region);
}
return downloadFromS3(s3Client, uri.getBucket(), String.join("/", uri.getKey(), filename), false, res);
return downloadFromS3(
getS3Client(role, region),
uri.getBucket(),
String.join("/", uri.getKey(), filename),
false,
res
);
}

/**
Expand Down Expand Up @@ -340,7 +334,8 @@ private static Deployment updateDeployment (Request req, Response res) {
* perhaps two people somehow kicked off a deploy job for the same deployment simultaneously and one of the EC2
* instances has out-of-date data).
*/
private static boolean terminateEC2InstanceForDeployment(Request req, Response res) {
private static boolean terminateEC2InstanceForDeployment(Request req, Response res)
throws CheckedAWSException {
Deployment deployment = getDeploymentWithPermissions(req, res);
String instanceIds = req.queryParams("instanceIds");
if (instanceIds == null) {
Expand All @@ -355,18 +350,12 @@ private static boolean terminateEC2InstanceForDeployment(Request req, Response r
.collect(Collectors.toList());
// Get the target group ARN from the latest deployment. Surround in a try/catch in case of NPEs.
// TODO: Perhaps provide some other way to provide the target group ARN.
String targetGroupArn;
DeployJob.DeploySummary latest;
AWSStaticCredentialsProvider credentials;
try {
latest = deployment.latest();
targetGroupArn = latest.ec2Info.targetGroupArn;
// Also, get credentials for role (if exists), which are needed to terminate instances in external AWS account.
credentials = AWSUtils.getCredentialsForRole(latest.role, "deregister-instances");
} catch (Exception e) {
DeployJob.DeploySummary latest = deployment.latest();
if (latest == null || latest.ec2Info == null) {
logMessageAndHalt(req, 400, "Latest deploy job does not exist or is missing target group ARN.");
return false;
}
String targetGroupArn = latest.ec2Info.targetGroupArn;
for (String id : idsToTerminate) {
if (!instanceIdsForDeployment.contains(id)) {
logMessageAndHalt(req, HttpStatus.UNAUTHORIZED_401, "It is not permitted to terminate an instance that is not associated with deployment " + deployment.id);
Expand All @@ -381,7 +370,7 @@ private static boolean terminateEC2InstanceForDeployment(Request req, Response r
}
// If checks are ok, terminate instances.
boolean success = ServerController.deRegisterAndTerminateInstances(
credentials,
latest.role,
targetGroupArn,
latest.ec2Info.region,
idsToTerminate
Expand All @@ -396,7 +385,10 @@ private static boolean terminateEC2InstanceForDeployment(Request req, Response r
/**
* HTTP controller to fetch information about provided EC2 machines that power ELBs running a trip planner.
*/
private static List<EC2InstanceSummary> fetchEC2InstanceSummaries(Request req, Response res) {
private static List<EC2InstanceSummary> fetchEC2InstanceSummaries(
Request req,
Response res
) throws CheckedAWSException {
Deployment deployment = getDeploymentWithPermissions(req, res);
return deployment.retrieveEC2Instances();
}
Expand All @@ -412,7 +404,7 @@ public static List<EC2InstanceSummary> fetchEC2InstanceSummaries(AmazonEC2 ec2Cl
* Fetch EC2 instances from AWS that match the provided set of filters (e.g., tags, instance ID, or other properties).
*/
public static List<Instance> fetchEC2Instances(AmazonEC2 ec2Client, Filter... filters) {
if (ec2Client == null) ec2Client = ec2;
if (ec2Client == null) throw new IllegalArgumentException("Must provide EC2Client");
List<Instance> instances = new ArrayList<>();
DescribeInstancesRequest request = new DescribeInstancesRequest().withFilters(filters);
DescribeInstancesResult result = ec2Client.describeInstances(request);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.conveyal.datatools.manager.controllers.api;

import com.amazonaws.AmazonServiceException;
import com.conveyal.datatools.common.utils.CheckedAWSException;
import com.conveyal.datatools.common.utils.SparkUtils;
import com.conveyal.datatools.manager.DataManager;
import com.conveyal.datatools.manager.auth.Auth0UserProfile;
Expand Down Expand Up @@ -36,6 +38,7 @@
import java.util.Set;

import static com.conveyal.datatools.common.utils.AWSUtils.downloadFromS3;
import static com.conveyal.datatools.common.utils.AWSUtils.getDefaultS3Client;
import static com.conveyal.datatools.common.utils.SparkUtils.copyRequestStreamIntoFile;
import static com.conveyal.datatools.common.utils.SparkUtils.downloadFile;
import static com.conveyal.datatools.common.utils.SparkUtils.formatJobMessage;
Expand Down Expand Up @@ -209,7 +212,12 @@ private static Object getDownloadCredentials(Request req, Response res) {

if (DataManager.useS3) {
// Return pre-signed download link if using S3.
return downloadFromS3(FeedStore.s3Client, DataManager.feedBucket, FeedStore.s3Prefix + version.id, false, res);
try {
return downloadFromS3(getDefaultS3Client(), DataManager.feedBucket, FeedStore.s3Prefix + version.id, false, res);
} catch (AmazonServiceException | CheckedAWSException e) {
logMessageAndHalt(req, 500, "Failed to download file", e);
return null;
}
evansiroky marked this conversation as resolved.
Show resolved Hide resolved
} else {
// when feeds are stored locally, single-use download token will still be used
FeedDownloadToken token = new FeedDownloadToken(version);
Expand Down
Loading