Skip to content

Commit

Permalink
Assorted adjustments to AWS batch workflow: use SPOT_PRICE_CAPACITY_OPTIMIZED allocation, raise Batch retry attempts to 5, make Fargate spot and frozen-credential passing configurable for Dask, log the Dask dashboard link, and disable timeseries metadata file writing
Browse files Browse the repository at this point in the history
  • Loading branch information
nmerket committed Aug 22, 2024
1 parent 4db8c26 commit 1bb1ed5
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 14 deletions.
24 changes: 13 additions & 11 deletions buildstockbatch/aws/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,7 @@ def create_compute_environment(self, maxCPUs=10000):
"type": "SPOT",
"bidPercentage": 100,
"spotIamFleetRole": self.spot_service_role_arn,
"allocationStrategy": "SPOT_PRICE_CAPACITY_OPTIMIZED",
}
)
else:
Expand Down Expand Up @@ -692,7 +693,7 @@ def create_job_definition(self, docker_image, vcpus, memory, command, env_vars):
"jobRoleArn": self.task_role_arn,
"environment": self.generate_name_value_inputs(env_vars),
},
retryStrategy={"attempts": 2},
retryStrategy={"attempts": 5},
tags=self.get_tags(),
)

Expand Down Expand Up @@ -1185,7 +1186,7 @@ def run_job(cls, job_id, bucket, prefix, job_name, region):
jobs_file_path = sim_dir.parent / "jobs.tar.gz"
s3.download_file(bucket, f"{prefix}/jobs.tar.gz", str(jobs_file_path))
with tarfile.open(jobs_file_path, "r") as tar_f:
jobs_d = json.load(tar_f.extractfile(f"jobs/job{job_id:05d}.json"), encoding="utf-8")
jobs_d = json.load(tar_f.extractfile(f"jobs/job{job_id:05d}.json"))
logger.debug("Number of simulations = {}".format(len(jobs_d["batch"])))

logger.debug("Getting weather files")
Expand Down Expand Up @@ -1215,9 +1216,10 @@ def get_dask_client(self):
m = 1024
self.dask_cluster = FargateCluster(
region_name=self.region,
fargate_spot=True,
fargate_spot=dask_cfg.get("fargate_spot", True),
image=self.image_url,
cluster_name_template=f"dask-{self.job_identifier}",
scheduler_timeout="3600",
scheduler_cpu=dask_cfg.get("scheduler_cpu", 2 * m),
scheduler_mem=dask_cfg.get("scheduler_memory", 8 * m),
worker_cpu=dask_cfg.get("worker_cpu", 2 * m),
Expand All @@ -1227,6 +1229,7 @@ def get_dask_client(self):
tags=batch_env.get_tags(),
)
self.dask_client = Client(self.dask_cluster)
logger.info(f"Dask Dashboard: {self.dask_client.dashboard_link}")
return self.dask_client

def cleanup_dask(self):
Expand Down Expand Up @@ -1262,14 +1265,13 @@ def process_results(self, *args, **kwargs):
with open(tmppath / "args.json", "w") as f:
json.dump([args, kwargs], f)

credentials = boto3.Session().get_credentials().get_frozen_credentials()
env = {
"AWS_ACCESS_KEY_ID": credentials.access_key,
"AWS_SECRET_ACCESS_KEY": credentials.secret_key,
}
if credentials.token:
env["AWS_SESSION_TOKEN"] = credentials.token
env["POSTPROCESSING_INSIDE_DOCKER_CONTAINER"] = "true"
env = {"POSTPROCESSING_INSIDE_DOCKER_CONTAINER": "true"}
if self.cfg["aws"]["dask"].get("pass_frozen_credentials", False):
credentials = boto3.Session().get_credentials().get_frozen_credentials()
env["AWS_ACCESS_KEY_ID"] = credentials.access_key
env["AWS_SECRET_ACCESS_KEY"] = credentials.secret_key
if credentials.token:
env["AWS_SESSION_TOKEN"] = credentials.token

volumes = {
tmpdir: {"bind": str(container_workpath), "mode": "rw"},
Expand Down
6 changes: 3 additions & 3 deletions buildstockbatch/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,9 +608,9 @@ def combine_results(fs, results_dir, cfg, do_timeseries=True):

logger.info(f"Finished combining and saving timeseries for upgrade{upgrade_id}.")
logger.info("All aggregation completed. ")
if do_timeseries:
logger.info("Writing timeseries metadata files")
write_metadata_files(fs, ts_dir, partition_columns)
# if do_timeseries:
# logger.info("Writing timeseries metadata files")
# write_metadata_files(fs, ts_dir, partition_columns)


def remove_intermediate_files(fs, results_dir, keep_individual_timeseries=False):
Expand Down
2 changes: 2 additions & 0 deletions buildstockbatch/schemas/v0.3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ aws-dask-spec:
worker_cpu: enum(1024, 2048, 4096, 8192, 16384, required=False)
worker_memory: int(min=1024, required=False)
n_workers: int(min=1, required=True)
pass_frozen_credentials: bool(required=False)
fargate_spot: bool(required=False)

hpc-spec:
account: str(required=True)
Expand Down

0 comments on commit 1bb1ed5

Please sign in to comment.