From b7a6bd589c74bd454c271b03906fcdc480ef5e33 Mon Sep 17 00:00:00 2001 From: Rui Qiao Date: Wed, 22 Jan 2025 10:40:30 -0800 Subject: [PATCH] Revert "[Core] Fix token expiration for ray autoscaler (#48481)" This reverts commit bd3d90a3e835efb9c448b3d650de905039364018. Signed-off-by: Rui Qiao --- .../_private/kuberay/autoscaling_config.py | 22 +++++++++++---- .../_private/kuberay/node_provider.py | 28 ++++--------------- .../tests/kuberay/test_autoscaling_config.py | 11 ++------ 3 files changed, 25 insertions(+), 36 deletions(-) diff --git a/python/ray/autoscaler/_private/kuberay/autoscaling_config.py b/python/ray/autoscaler/_private/kuberay/autoscaling_config.py index b4ec5aba67016..41122047eb9df 100644 --- a/python/ray/autoscaler/_private/kuberay/autoscaling_config.py +++ b/python/ray/autoscaler/_private/kuberay/autoscaling_config.py @@ -49,10 +49,10 @@ class AutoscalingConfigProducer: """ def __init__(self, ray_cluster_name, ray_cluster_namespace): - self.kubernetes_api_client = node_provider.KubernetesHttpApiClient( - namespace=ray_cluster_namespace + self._headers, self._verify = node_provider.load_k8s_secrets() + self._ray_cr_url = node_provider.url_from_resource( + namespace=ray_cluster_namespace, path=f"rayclusters/{ray_cluster_name}" ) - self._ray_cr_path = f"rayclusters/{ray_cluster_name}" def __call__(self): ray_cr = self._fetch_ray_cr_from_k8s_with_retries() @@ -67,7 +67,7 @@ def _fetch_ray_cr_from_k8s_with_retries(self) -> Dict[str, Any]: """ for i in range(1, MAX_RAYCLUSTER_FETCH_TRIES + 1): try: - return self.kubernetes_api_client.get(self._ray_cr_path) + return self._fetch_ray_cr_from_k8s() except requests.HTTPError as e: if i < MAX_RAYCLUSTER_FETCH_TRIES: logger.exception( @@ -80,6 +80,18 @@ def _fetch_ray_cr_from_k8s_with_retries(self) -> Dict[str, Any]: # This branch is inaccessible. Raise to satisfy mypy. raise AssertionError + def _fetch_ray_cr_from_k8s(self) -> Dict[str, Any]: + result = requests.get( + self._ray_cr_url, + headers=self._headers, + timeout=node_provider.KUBERAY_REQUEST_TIMEOUT_S, + verify=self._verify, + ) + if not result.status_code == 200: + result.raise_for_status() + ray_cr = result.json() + return ray_cr + def _derive_autoscaling_config_from_ray_cr(ray_cr: Dict[str, Any]) -> Dict[str, Any]: provider_config = _generate_provider_config(ray_cr["metadata"]["namespace"]) @@ -167,7 +179,7 @@ def _generate_legacy_autoscaling_config_fields() -> Dict[str, Any]: def _generate_available_node_types_from_ray_cr_spec( - ray_cr_spec: Dict[str, Any], + ray_cr_spec: Dict[str, Any] ) -> Dict[str, Any]: """Formats autoscaler "available_node_types" field based on the Ray CR's group specs. diff --git a/python/ray/autoscaler/_private/kuberay/node_provider.py b/python/ray/autoscaler/_private/kuberay/node_provider.py index 0bf01e5504433..6e788564d7a9a 100644 --- a/python/ray/autoscaler/_private/kuberay/node_provider.py +++ b/python/ray/autoscaler/_private/kuberay/node_provider.py @@ -1,4 +1,3 @@ -import datetime import json import logging import os @@ -55,8 +54,6 @@ # Key for GKE label that identifies which multi-host replica a pod belongs to REPLICA_INDEX_KEY = "replicaIndex" -TOKEN_REFRESH_PERIOD = datetime.timedelta(minutes=1) - # Design: # Each modification the autoscaler wants to make is posted to the API server goal state @@ -267,19 +264,7 @@ class KubernetesHttpApiClient(IKubernetesHttpApiClient): def __init__(self, namespace: str, kuberay_crd_version: str = KUBERAY_CRD_VER): self._kuberay_crd_version = kuberay_crd_version self._namespace = namespace - self._token_expires_at = datetime.datetime.now() + TOKEN_REFRESH_PERIOD - self._headers, self._verify = None, None - - def _get_refreshed_headers_and_verify(self): - if (datetime.datetime.now() >= self._token_expires_at) or ( - self._headers is None or self._verify is None - ): - logger.info("Refreshing K8s API client token and certs.") - self._headers, self._verify = load_k8s_secrets() - self._token_expires_at = datetime.datetime.now() + TOKEN_REFRESH_PERIOD - return self._headers, self._verify - else: - return self._headers, self._verify + self._headers, self._verify = load_k8s_secrets() def get(self, path: str) -> Dict[str, Any]: """Wrapper for REST GET of resource with proper headers. @@ -298,13 +283,11 @@ def get(self, path: str) -> Dict[str, Any]: path=path, kuberay_crd_version=self._kuberay_crd_version, ) - - headers, verify = self._get_refreshed_headers_and_verify() result = requests.get( url, - headers=headers, + headers=self._headers, timeout=KUBERAY_REQUEST_TIMEOUT_S, - verify=verify, + verify=self._verify, ) if not result.status_code == 200: result.raise_for_status() @@ -328,12 +311,11 @@ def patch(self, path: str, payload: List[Dict[str, Any]]) -> Dict[str, Any]: path=path, kuberay_crd_version=self._kuberay_crd_version, ) - headers, verify = self._get_refreshed_headers_and_verify() result = requests.patch( url, json.dumps(payload), - headers={**headers, "Content-type": "application/json-patch+json"}, - verify=verify, + headers={**self._headers, "Content-type": "application/json-patch+json"}, + verify=self._verify, ) if not result.status_code == 200: result.raise_for_status() diff --git a/python/ray/tests/kuberay/test_autoscaling_config.py b/python/ray/tests/kuberay/test_autoscaling_config.py index 82aec91ff9696..7fe8759fc1c2c 100644 --- a/python/ray/tests/kuberay/test_autoscaling_config.py +++ b/python/ray/tests/kuberay/test_autoscaling_config.py @@ -395,22 +395,17 @@ def test_autoscaling_config_fetch_retries(exception, num_exceptions): AutoscalingConfigProducer._fetch_ray_cr_from_k8s_with_retries. """ - class MockKubernetesHttpApiClient: - def __init__(self): + class MockAutoscalingConfigProducer(AutoscalingConfigProducer): + def __init__(self, *args, **kwargs): self.exception_counter = 0 - def get(self, *args, **kwargs): + def _fetch_ray_cr_from_k8s(self) -> Dict[str, Any]: if self.exception_counter < num_exceptions: self.exception_counter += 1 raise exception else: return {"ok-key": "ok-value"} - class MockAutoscalingConfigProducer(AutoscalingConfigProducer): - def __init__(self, *args, **kwargs): - self.kubernetes_api_client = MockKubernetesHttpApiClient() - self._ray_cr_path = "rayclusters/mock" - config_producer = MockAutoscalingConfigProducer() # Patch retry backoff period. with mock.patch(