diff --git a/src/gpuhunt/providers/aws.py b/src/gpuhunt/providers/aws.py index 89570bf..3ec14bb 100644 --- a/src/gpuhunt/providers/aws.py +++ b/src/gpuhunt/providers/aws.py @@ -138,7 +138,7 @@ def fill_gpu_details(self, offers: List[RawCatalogItem]): gpu = i["GpuInfo"]["Gpus"][0] gpus[i["InstanceType"]] = ( gpu["Name"], - gpu["MemoryInfo"]["SizeInMiB"] / 1024, + _get_gpu_memory_gib(gpu["Name"], gpu["MemoryInfo"]["SizeInMiB"]), ) regions = { @@ -230,6 +230,22 @@ def filter(cls, offers: List[RawCatalogItem]) -> List[RawCatalogItem]: ] +def _get_gpu_memory_gib(gpu_name: str, reported_memory_mib: int) -> float: + """ + Fixes L4 memory size misreported by AWS API + """ + + if gpu_name != "L4": + return reported_memory_mib / 1024 + + if reported_memory_mib not in (22888, 91553, 183105): + logger.warning( + "The L4 memory size reported by AWS changed. " + "Please check that it is now correct and remove the hardcoded size if it is." + ) + return 24 + + def parse_memory(s: str) -> float: r = re.match(r"^([0-9.]+) GiB$", s) return float(r.group(1))