diff --git a/python/ray/dashboard/datacenter.py b/python/ray/dashboard/datacenter.py index 2a2c660ecd440..b99c46bb70b40 100644 --- a/python/ray/dashboard/datacenter.py +++ b/python/ray/dashboard/datacenter.py @@ -1,5 +1,5 @@ import logging -from typing import Any, List, Optional +from typing import List, Optional import ray.dashboard.consts as dashboard_consts from ray._private.utils import ( @@ -188,41 +188,6 @@ async def get_all_node_summary(cls): for node_id in DataSource.nodes.keys() ] - @classmethod - async def get_agent_infos( - cls, target_node_ids: Optional[List[str]] = None - ) -> Dict[str, Dict[str, Any]]: - """Fetches running Agent (like HTTP/gRPC ports, IP, etc) running on every node - - :param target_node_ids: Target node ids to fetch agent info for. If omitted will - fetch the info for all agents - """ - - # Return all available agent infos in case no target node-ids were provided - target_node_ids = target_node_ids or DataSource.agents.keys() - - missing_node_ids = [ - node_id for node_id in target_node_ids if node_id not in DataSource.agents - ] - if missing_node_ids: - logger.warning( - f"Agent info was not found for {missing_node_ids}" - f" (having agent infos for {list(DataSource.agents.keys())})" - ) - return {} - - def _create_agent_info(node_id: str): - (node_ip, http_port, grpc_port) = DataSource.agents[node_id] - - return dict( - ipAddress=node_ip, - httpPort=int(http_port or -1), - grpcPort=int(grpc_port or -1), - httpAddress=f"{node_ip}:{http_port}", - ) - - return {node_id: _create_agent_info(node_id) for node_id in target_node_ids} - @classmethod async def get_actor_infos(cls, actor_ids: Optional[List[str]] = None): target_actor_table_entries: dict[str, Optional[dict]] diff --git a/python/ray/dashboard/modules/job/job_head.py b/python/ray/dashboard/modules/job/job_head.py index 185c8fc94983a..9572c4eec2871 100644 --- a/python/ray/dashboard/modules/job/job_head.py +++ b/python/ray/dashboard/modules/job/job_head.py @@ -3,25 +3,31 @@ import json import logging import traceback -from random import sample -from typing import AsyncIterator, List, Optional +from random import choice +from typing import AsyncIterator, Dict, List, Optional, Tuple import aiohttp.web from aiohttp.client import ClientResponse from aiohttp.web import Request, Response import ray +from ray import NodeID import ray.dashboard.consts as dashboard_consts +from ray.dashboard.consts import ( + GCS_RPC_TIMEOUT_SECONDS, + DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX, + TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS, + WAIT_AVAILABLE_AGENT_TIMEOUT, +) import ray.dashboard.optional_utils as optional_utils import ray.dashboard.utils as dashboard_utils -from ray._private.ray_constants import env_bool +from ray._private.ray_constants import env_bool, KV_NAMESPACE_DASHBOARD from ray._private.runtime_env.packaging import ( package_exists, pin_runtime_env_uri, upload_package_to_gcs, ) from ray._private.utils import get_or_create_event_loop -from ray.dashboard.datacenter import DataOrganizer from ray.dashboard.modules.job.common import ( JobDeleteResponse, JobInfoStorageClient, @@ -166,7 +172,8 @@ def __init__(self, config: dashboard_utils.DashboardHeadModuleConfig): # `JobHead` has ever used, and will not be deleted # from it unless `JobAgentSubmissionClient` is no # longer available (the corresponding agent process is dead) - self._agents = dict() + # {node_id: JobAgentSubmissionClient} + self._agents: Dict[NodeID, JobAgentSubmissionClient] = dict() async def get_target_agent(self) -> Optional[JobAgentSubmissionClient]: if RAY_JOB_AGENT_USE_HEAD_NODE_ONLY: @@ -191,7 +198,7 @@ async def _pick_random_agent(self) -> Optional[JobAgentSubmissionClient]: """ # NOTE: Following call will block until there's at least 1 agent info # being populated from GCS - agent_infos = await self._fetch_agent_infos() + agent_infos = await self._fetch_all_agent_infos() # delete dead agents. for dead_node in set(self._agents) - set(agent_infos): @@ -199,18 +206,17 @@ async def _pick_random_agent(self) -> Optional[JobAgentSubmissionClient]: await client.close() if len(self._agents) >= dashboard_consts.CANDIDATE_AGENT_NUMBER: - node_id = sample(list(set(self._agents)), 1)[0] + node_id = choice(list(self._agents)) return self._agents[node_id] else: # Randomly select one from among all agents, it is possible that # the selected one already exists in `self._agents` - node_id = sample(sorted(agent_infos), 1)[0] + node_id = choice(list(agent_infos)) agent_info = agent_infos[node_id] if node_id not in self._agents: - node_ip = agent_info["ipAddress"] - http_port = agent_info["httpPort"] - agent_http_address = f"http://{node_ip}:{http_port}" + ip, http_port, grpc_port = agent_info + agent_http_address = f"http://{ip}:{http_port}" self._agents[node_id] = JobAgentSubmissionClient(agent_http_address) return self._agents[node_id] @@ -218,49 +224,109 @@ async def _pick_random_agent(self) -> Optional[JobAgentSubmissionClient]: async def _get_head_node_agent(self) -> Optional[JobAgentSubmissionClient]: """Retrieves HTTP client for `JobAgent` running on the Head node""" - head_node_id = await get_head_node_id(self.gcs_aio_client) + head_node_id_hex = await get_head_node_id(self.gcs_aio_client) - if not head_node_id: + if not head_node_id_hex: logger.warning("Head node id has not yet been persisted in GCS") return None + head_node_id = NodeID.from_hex(head_node_id_hex) + if head_node_id not in self._agents: - agent_infos = await self._fetch_agent_infos(target_node_ids=[head_node_id]) + agent_infos = await self._fetch_agent_infos([head_node_id]) if head_node_id not in agent_infos: - logger.error("Head node agent's information was not found") + logger.error( + f"Head node agent's information was not found: {head_node_id} not in {agent_infos}" + ) return None - agent_info = agent_infos[head_node_id] - - node_ip = agent_info["ipAddress"] - http_port = agent_info["httpPort"] - agent_http_address = f"http://{node_ip}:{http_port}" + ip, http_port, grpc_port = agent_infos[head_node_id] + agent_http_address = f"http://{ip}:{http_port}" self._agents[head_node_id] = JobAgentSubmissionClient(agent_http_address) return self._agents[head_node_id] - @staticmethod - async def _fetch_agent_infos(target_node_ids: Optional[List[str]] = None): - """Fetches agent infos for nodes identified by provided node-ids (for all - nodes if not provided) - - NOTE: This call will block until there's at least 1 valid agent info populated + async def _fetch_all_agent_infos(self) -> Dict[NodeID, Tuple[str, int, int]]: """ + Fetches all agent infos for all nodes in the cluster. + + If there's no agent available at all, or there's exception, it will retry every + `TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS` seconds indefinitely. + Returns: {node_id_hex: (ip, http_port, grpc_port)} + """ while True: - raw_agent_infos = await DataOrganizer.get_agent_infos(target_node_ids) - # Filter out invalid agent infos with unset HTTP port - agent_infos = { - key: value - for key, value in raw_agent_infos.items() - if value.get("httpPort", -1) > 0 - } + try: + keys = await self.gcs_aio_client.internal_kv_keys( + f"{DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}".encode(), + namespace=KV_NAMESPACE_DASHBOARD, + timeout=GCS_RPC_TIMEOUT_SECONDS, + ) + if not keys: + # No agent keys found, retry + raise Exception() + values: Dict[ + bytes, bytes + ] = await self.gcs_aio_client.internal_kv_multi_get( + keys, + namespace=KV_NAMESPACE_DASHBOARD, + timeout=GCS_RPC_TIMEOUT_SECONDS, + ) + prefix_len = len(DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX) + return { + NodeID.from_hex(key[prefix_len:].decode()): json.loads( + value.decode() + ) + for key, value in values.items() + } + + except Exception: + logger.info( + f"Failed to fetch all agent infos, retrying in {TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS} seconds..." + ) + await asyncio.sleep(TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS) + + async def _fetch_agent_infos( + self, target_node_ids: List[NodeID] + ) -> Dict[NodeID, Tuple[str, int, int]]: + """ + Fetches agent infos for nodes identified by provided node-ids. + + If any of the node-ids is not found, it will retry every + `TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS` seconds indefinitely. - if len(agent_infos) > 0: - return agent_infos + Returns: {node_id_hex: (ip, http_port, grpc_port)} + """ - await asyncio.sleep(dashboard_consts.TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS) + while True: + try: + keys = [ + f"{DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{node_id.hex()}" + for node_id in target_node_ids + ] + values: Dict[ + bytes, bytes + ] = await self.gcs_aio_client.internal_kv_multi_get( + keys, + namespace=KV_NAMESPACE_DASHBOARD, + timeout=GCS_RPC_TIMEOUT_SECONDS, + ) + if not values or len(values) != len(target_node_ids): + # Not all agent infos found, retry + raise Exception() + prefix_len = len(DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX) + return { + NodeID.from_hex(key[prefix_len:].decode()): json.loads( + value.decode() + ) + for key, value in values.items() + } + except Exception: + logger.info( + f"Failed to fetch agent infos for nodes {target_node_ids}, retrying in {TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS} seconds..." + ) + await asyncio.sleep(TRY_TO_GET_AGENT_INFO_INTERVAL_SECONDS) @routes.get("/api/version") async def get_version(self, req: Request) -> Response: @@ -337,7 +403,7 @@ async def submit_job(self, req: Request) -> Response: try: job_agent_client = await asyncio.wait_for( self.get_target_agent(), - timeout=dashboard_consts.WAIT_AVAILABLE_AGENT_TIMEOUT, + timeout=WAIT_AVAILABLE_AGENT_TIMEOUT, ) resp = await job_agent_client.submit_job_internal(submit_request) except asyncio.TimeoutError: @@ -384,7 +450,7 @@ async def stop_job(self, req: Request) -> Response: try: job_agent_client = await asyncio.wait_for( self.get_target_agent(), - timeout=dashboard_consts.WAIT_AVAILABLE_AGENT_TIMEOUT, + timeout=WAIT_AVAILABLE_AGENT_TIMEOUT, ) resp = await job_agent_client.stop_job_internal(job.submission_id) except Exception: @@ -419,7 +485,7 @@ async def delete_job(self, req: Request) -> Response: try: job_agent_client = await asyncio.wait_for( self.get_target_agent(), - timeout=dashboard_consts.WAIT_AVAILABLE_AGENT_TIMEOUT, + timeout=WAIT_AVAILABLE_AGENT_TIMEOUT, ) resp = await job_agent_client.delete_job_internal(job.submission_id) except Exception: diff --git a/python/ray/dashboard/modules/job/tests/test_http_job_server.py b/python/ray/dashboard/modules/job/tests/test_http_job_server.py index 1441d89bae1b1..bac9c379a870d 100644 --- a/python/ray/dashboard/modules/job/tests/test_http_job_server.py +++ b/python/ray/dashboard/modules/job/tests/test_http_job_server.py @@ -8,13 +8,14 @@ import tempfile import time from pathlib import Path -from typing import Optional +from typing import Optional, List from unittest.mock import patch import pytest import yaml import ray +from ray import NodeID from ray._private.test_utils import ( chdir, format_web_url, @@ -22,6 +23,7 @@ wait_for_condition, wait_until_server_available, ) +from ray.dashboard.consts import DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX from ray.dashboard.modules.dashboard_sdk import ClusterInfo, parse_cluster_info from ray.dashboard.modules.job.job_head import JobHead from ray.dashboard.modules.job.pydantic_models import JobDetails @@ -736,30 +738,58 @@ async def test_job_head_pick_random_job_agent(monkeypatch): importlib.reload(ray.dashboard.consts) - from ray.dashboard.datacenter import DataSource + # Fake GCS client + class _FakeGcsClient: + def __init__(self): + self._kv = {} + + async def internal_kv_put(self, key: bytes, value: bytes, **kwargs): + self._kv[key] = value + + async def internal_kv_get(self, key: bytes, **kwargs): + return self._kv.get(key, None) + + async def internal_kv_multi_get(self, keys: List[bytes], **kwargs): + return {key: self._kv.get(key, None) for key in keys} + + async def internal_kv_del(self, key: bytes, **kwargs): + self._kv.pop(key) + + async def internal_kv_keys(self, prefix: bytes, **kwargs): + return [key for key in self._kv.keys() if key.startswith(prefix)] class MockJobHead(JobHead): def __init__(self): self._agents = dict() - DataSource.agents = {} - DataSource.nodes = {} job_head = MockJobHead() + job_head._gcs_aio_client = _FakeGcsClient() - def add_agent(agent): + async def add_agent(agent): node_id = agent[0] node_ip = agent[1]["ipAddress"] http_port = agent[1]["httpPort"] grpc_port = agent[1]["grpcPort"] - DataSource.nodes[node_id] = {"nodeManagerAddress": node_ip} - DataSource.agents[node_id] = (node_ip, http_port, grpc_port) - def del_agent(agent): - node_id = agent[0] - DataSource.nodes.pop(node_id) - DataSource.agents.pop(node_id) + await job_head._gcs_aio_client.internal_kv_put( + f"{DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{node_id.hex()}".encode(), + json.dumps([node_ip, http_port, grpc_port]).encode(), + namespace=ray_constants.KV_NAMESPACE_DASHBOARD, + ) - head_node_id = "node1" + async def del_agent(agent): + node_id = agent[0] + await job_head._gcs_aio_client.internal_kv_del( + f"{DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{node_id.hex()}".encode(), + namespace=ray_constants.KV_NAMESPACE_DASHBOARD, + ) + + head_node_id = NodeID.from_random() + await job_head._gcs_aio_client.internal_kv_put( + ray_constants.KV_HEAD_NODE_ID_KEY, + head_node_id.hex().encode(), + namespace=ray_constants.KV_NAMESPACE_JOB, + ) agent_1 = ( head_node_id, @@ -771,7 +801,7 @@ def del_agent(agent): ), ) agent_2 = ( - "node2", + NodeID.from_random(), dict( ipAddress="2.2.2.2", httpPort=2, @@ -780,7 +810,7 @@ def del_agent(agent): ), ) agent_3 = ( - "node3", + NodeID.from_random(), dict( ipAddress="3.3.3.3", httpPort=3, @@ -796,12 +826,12 @@ def del_agent(agent): ) # Check only 1 agent present, only agent being returned - add_agent(agent_1) + await add_agent(agent_1) job_agent_client = await job_head.get_target_agent() assert job_agent_client._agent_address == "http://1.1.1.1:1" # Remove only agent, no agents present, should time out - del_agent(agent_1) + await del_agent(agent_1) with pytest.raises(asyncio.TimeoutError): await asyncio.wait_for(job_head.get_target_agent(), timeout=3) @@ -812,19 +842,9 @@ def del_agent(agent): ) # Add 3 agents - add_agent(agent_1) - add_agent(agent_2) - add_agent(agent_3) - - # Mock GCS client - class _MockedGCSClient: - async def internal_kv_get(self, key: bytes, **kwargs): - if key == ray_constants.KV_HEAD_NODE_ID_KEY: - return head_node_id.encode() - - return None - - job_head._gcs_aio_client = _MockedGCSClient() + await add_agent(agent_1) + await add_agent(agent_2) + await add_agent(agent_3) # Make sure returned agent is a head-node # NOTE: We run 3 tims to make sure we're not hitting branch probabilistically @@ -853,7 +873,7 @@ async def internal_kv_get(self, key: bytes, **kwargs): for agent in [agent_1, agent_2, agent_3]: if f"http://{agent[1]['httpAddress']}" in addresses_2: break - del_agent(agent) + await del_agent(agent) # Theoretically, the probability of failure is 1/2^100 addresses_3 = set() @@ -871,7 +891,7 @@ async def internal_kv_get(self, key: bytes, **kwargs): for agent in [agent_1, agent_2, agent_3]: if f"http://{agent[1]['httpAddress']}" in addresses_4: break - del_agent(agent) + await del_agent(agent) address = None for _ in range(3): job_agent_client = await job_head.get_target_agent() diff --git a/python/ray/dashboard/modules/job/tests/test_sdk.py b/python/ray/dashboard/modules/job/tests/test_sdk.py index e440cc2efb917..0e1500bc9ce58 100644 --- a/python/ray/dashboard/modules/job/tests/test_sdk.py +++ b/python/ray/dashboard/modules/job/tests/test_sdk.py @@ -7,17 +7,23 @@ from unittest.mock import Mock, patch import pytest -import requests import ray import ray.experimental.internal_kv as kv -from ray._private.ray_constants import DEFAULT_DASHBOARD_AGENT_LISTEN_PORT +from ray._private.ray_constants import ( + DEFAULT_DASHBOARD_AGENT_LISTEN_PORT, + KV_NAMESPACE_DASHBOARD, +) from ray._private.test_utils import ( format_web_url, wait_for_condition, wait_until_server_available, ) -from ray.dashboard.consts import RAY_JOB_ALLOW_DRIVER_ON_WORKER_NODES_ENV_VAR +from ray.dashboard.consts import ( + RAY_JOB_ALLOW_DRIVER_ON_WORKER_NODES_ENV_VAR, + DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX, + GCS_RPC_TIMEOUT_SECONDS, +) from ray.dashboard.modules.dashboard_sdk import ( DEFAULT_DASHBOARD_ADDRESS, ClusterInfo, @@ -28,7 +34,7 @@ from ray.dashboard.tests.conftest import * # noqa from ray.tests.conftest import _ray_start from ray.util.state import list_nodes - +from ray._raylet import GcsClient import psutil @@ -165,12 +171,13 @@ def mock_candidate_number(): os.environ.pop("CANDIDATE_AGENT_NUMBER", None) -def get_register_agents_number(webui_url): - response = requests.get(webui_url + "/internal/node_module") - response.raise_for_status() - result = response.json() - data = result["data"] - return data["registeredAgents"] +def get_register_agents_number(gcs_client): + keys = gcs_client.internal_kv_keys( + prefix=DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX, + namespace=KV_NAMESPACE_DASHBOARD, + timeout=GCS_RPC_TIMEOUT_SECONDS, + ) + return len(keys) @pytest.mark.parametrize( @@ -195,6 +202,7 @@ def test_job_head_choose_job_agent_E2E(ray_start_cluster_head_with_env_vars): webui_url = cluster.webui_url webui_url = format_web_url(webui_url) client = JobSubmissionClient(webui_url) + gcs_client = GcsClient(address=cluster.gcs_address) def submit_job_and_wait_finish(): submission_id = client.submit_job(entrypoint="echo hello") @@ -206,7 +214,7 @@ def submit_job_and_wait_finish(): head_http_port = DEFAULT_DASHBOARD_AGENT_LISTEN_PORT worker_1_http_port = 52366 cluster.add_node(dashboard_agent_listen_port=worker_1_http_port) - wait_for_condition(lambda: get_register_agents_number(webui_url) == 2, timeout=20) + wait_for_condition(lambda: get_register_agents_number(gcs_client) == 2, timeout=20) assert len(cluster.worker_nodes) == 1 node_try_to_kill = list(cluster.worker_nodes)[0] @@ -250,7 +258,7 @@ def _kill_all_driver(): worker_2_http_port = 52367 cluster.add_node(dashboard_agent_listen_port=worker_2_http_port) - wait_for_condition(lambda: get_register_agents_number(webui_url) == 3, timeout=20) + wait_for_condition(lambda: get_register_agents_number(gcs_client) == 3, timeout=20) # The third `JobAgent` will not be called here. submit_job_and_wait_finish() @@ -281,7 +289,7 @@ def get_all_new_supervisor_actor_info(old_supervisor_actor_ids): node_try_to_kill.kill_raylet() # make sure the head updates the info of the dead node. - wait_for_condition(lambda: get_register_agents_number(webui_url) == 2, timeout=20) + wait_for_condition(lambda: get_register_agents_number(gcs_client) == 2, timeout=20) # Make sure the third JobAgent will be called here. wait_for_condition( @@ -324,6 +332,7 @@ def test_jobs_run_on_head_by_default_E2E(ray_start_cluster_head_with_env_vars): webui_url = cluster.webui_url webui_url = format_web_url(webui_url) client = JobSubmissionClient(webui_url) + gcs_client = GcsClient(address=cluster.gcs_address) def _check_nodes(num_nodes): try: @@ -334,7 +343,7 @@ def _check_nodes(num_nodes): return False wait_for_condition(lambda: _check_nodes(num_nodes=3), timeout=15) - wait_for_condition(lambda: get_register_agents_number(webui_url) == 3, timeout=20) + wait_for_condition(lambda: get_register_agents_number(gcs_client) == 3, timeout=20) # Submit 20 simple jobs. for i in range(20): diff --git a/python/ray/dashboard/modules/job/utils.py b/python/ray/dashboard/modules/job/utils.py index 8c00a7014cecd..4f2a2ca5c5c12 100644 --- a/python/ray/dashboard/modules/job/utils.py +++ b/python/ray/dashboard/modules/job/utils.py @@ -36,13 +36,14 @@ async def get_head_node_id(gcs_aio_client: GcsAioClient) -> Optional[str]: """Fetches Head node id persisted in GCS""" - head_node_id_bytes = await gcs_aio_client.internal_kv_get( + head_node_id_hex_bytes = await gcs_aio_client.internal_kv_get( ray_constants.KV_HEAD_NODE_ID_KEY, namespace=ray_constants.KV_NAMESPACE_JOB, timeout=30, ) - - return head_node_id_bytes.decode() if head_node_id_bytes is not None else None + if head_node_id_hex_bytes is None: + return None + return head_node_id_hex_bytes.decode() def strip_keys_with_value_none(d: Dict[str, Any]) -> Dict[str, Any]: diff --git a/python/ray/dashboard/modules/log/log_manager.py b/python/ray/dashboard/modules/log/log_manager.py index bb21446f15f62..a05b09c8f9d4b 100644 --- a/python/ray/dashboard/modules/log/log_manager.py +++ b/python/ray/dashboard/modules/log/log_manager.py @@ -12,7 +12,6 @@ GetLogOptions, protobuf_to_task_state_dict, ) -from ray.util.state.exception import DataSourceUnavailable from ray.util.state.state_manager import StateDataSourceClient if BaseModel is None: @@ -74,9 +73,8 @@ async def list_logs( Dictionary of {component_name -> list of log files} Raises: - DataSourceUnavailable: If a source is unresponsive. + ValueError: If a source is unresponsive. """ - self._verify_node_registered(node_id) reply = await self.client.list_logs(node_id, glob_filter, timeout=timeout) return self._categorize_log_files(reply.log_files) @@ -126,18 +124,6 @@ async def stream_logs( async for streamed_log in stream: yield streamed_log.data - def _verify_node_registered(self, node_id: str): - if node_id not in self.client.get_all_registered_log_agent_ids(): - raise DataSourceUnavailable( - f"Given node id {node_id} is not available. " - "It's either the node is dead, or it is not registered. " - "Use `ray list nodes` " - "to see the node status. If the node is registered, " - "it is highly likely " - "a transient issue. Try again." - ) - assert node_id is not None - async def _resolve_job_filename(self, sub_job_id: str) -> Tuple[str, str]: """Return the log file name and node id for a given job submission id. @@ -249,7 +235,6 @@ async def _resolve_actor_filename( "Actor is not scheduled yet." ) node_id = NodeID(node_id_binary) - self._verify_node_registered(node_id.hex()) log_filename = await self._resolve_worker_file( node_id_hex=node_id.hex(), worker_id_hex=worker_id.hex(), @@ -415,7 +400,6 @@ async def resolve_filename( "Node id needs to be specified for resolving" f" filenames of pid {pid}" ) - self._verify_node_registered(node_id) log_filename = await self._resolve_worker_file( node_id_hex=node_id, worker_id_hex=None, diff --git a/python/ray/dashboard/modules/node/node_head.py b/python/ray/dashboard/modules/node/node_head.py index 8707c6abae196..2b36007cbeb36 100644 --- a/python/ray/dashboard/modules/node/node_head.py +++ b/python/ray/dashboard/modules/node/node_head.py @@ -11,7 +11,6 @@ import grpc import ray._private.utils -import ray.dashboard.consts as dashboard_consts import ray.dashboard.optional_utils as dashboard_optional_utils import ray.dashboard.utils as dashboard_utils from ray._private import ray_constants @@ -30,7 +29,11 @@ parse_usage, ) from ray.core.generated import gcs_pb2, node_manager_pb2, node_manager_pb2_grpc -from ray.dashboard.consts import GCS_RPC_TIMEOUT_SECONDS +from ray.dashboard.consts import ( + GCS_RPC_TIMEOUT_SECONDS, + DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX, + DASHBOARD_AGENT_ADDR_IP_PREFIX, +) from ray.dashboard.datacenter import DataOrganizer, DataSource from ray.dashboard.modules.node import node_consts from ray.dashboard.modules.node.node_consts import ( @@ -125,7 +128,6 @@ def get_internal_states(self): return { "head_node_registration_time_s": self._head_node_registration_time_s, "registered_nodes": len(DataSource.nodes), - "registered_agents": len(DataSource.agents), "module_lifetime_s": time.time() - self._module_start_time, } @@ -195,48 +197,27 @@ async def _update_node(self, node: dict): ) assert node["state"] in ["ALIVE", "DEAD"] is_alive = node["state"] == "ALIVE" - # Prepare agents for alive node, and pop agents for dead node. - if is_alive: - if node_id not in DataSource.agents: - # Agent port is read from internal KV, which is only populated - # upon Agent startup. In case this update received before agent - # fully started up, we schedule a task to asynchronously update - # DataSource with appropriate agent port. - asyncio.create_task(self._update_agent(node_id)) - else: - DataSource.agents.pop(node_id, None) - self._dead_node_queue.append(node_id) - if len(self._dead_node_queue) > node_consts.MAX_DEAD_NODES_TO_CACHE: - DataSource.nodes.pop(self._dead_node_queue.popleft(), None) - DataSource.nodes[node_id] = node - - async def _update_agent(self, node_id): - """ - Given a node, update the agent_port in DataSource.agents. Problem is it's not - present until agent.py starts, so we need to loop waiting for agent.py writes - its port to internal kv. - """ - key = ( - f"{dashboard_consts.DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{node_id}".encode() - ) - while True: - try: - agent_addr = await self.gcs_aio_client.internal_kv_get( + if not is_alive: + # Remove the agent address from the internal KV. + keys = [ + f"{DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{node_id}", + f"{DASHBOARD_AGENT_ADDR_IP_PREFIX}{node['nodeManagerAddress']}", + ] + tasks = [ + self.gcs_aio_client.internal_kv_del( key, + del_by_prefix=False, namespace=ray_constants.KV_NAMESPACE_DASHBOARD, - timeout=None, + timeout=GCS_RPC_TIMEOUT_SECONDS, ) - # The node may be dead already. Only update DataSource.agents if the - # node is still alive. - if DataSource.nodes.get(node_id, {}).get("state") != "ALIVE": - return - if agent_addr: - DataSource.agents[node_id] = json.loads(agent_addr) - return - except Exception: - logger.exception(f"Error getting agent port for node {node_id}.") + for key in keys + ] + await asyncio.gather(*tasks) - await asyncio.sleep(node_consts.RAY_DASHBOARD_AGENT_POLL_INTERVAL_S) + self._dead_node_queue.append(node_id) + if len(self._dead_node_queue) > node_consts.MAX_DEAD_NODES_TO_CACHE: + DataSource.nodes.pop(self._dead_node_queue.popleft(), None) + DataSource.nodes[node_id] = node async def _update_nodes(self): """ @@ -263,14 +244,6 @@ async def _update_nodes(self): ) warning_shown = True - @routes.get("/internal/node_module") - async def get_node_module_internal_state(self, req) -> aiohttp.web.Response: - return dashboard_optional_utils.rest_response( - success=True, - message="", - **self.get_internal_states(), - ) - async def get_nodes_logical_resources(self) -> dict: from ray.autoscaler.v2.utils import is_autoscaler_v2 diff --git a/python/ray/dashboard/modules/node/tests/test_node.py b/python/ray/dashboard/modules/node/tests/test_node.py index 93d6cbd600991..8884b93619e1e 100644 --- a/python/ray/dashboard/modules/node/tests/test_node.py +++ b/python/ray/dashboard/modules/node/tests/test_node.py @@ -43,22 +43,10 @@ def test_nodes_update(enable_test_module, ray_start_with_dashboard): assert dump_info["result"] is True dump_data = dump_info["data"] assert len(dump_data["nodes"]) == 1 - assert len(dump_data["agents"]) == 1 - - response = requests.get(webui_url + "/test/notified_agents") - response.raise_for_status() - try: - notified_agents = response.json() - except Exception as ex: - logger.info("failed response: %s", response.text) - raise ex - assert notified_agents["result"] is True - notified_agents = notified_agents["data"] - assert len(notified_agents) == 1 - assert notified_agents == dump_data["agents"] break + except (AssertionError, requests.exceptions.ConnectionError) as e: - logger.info("Retry because of %s", e) + logger.exception("Retry") finally: if time.time() > start_time + timeout_seconds: raise Exception("Timed out while testing.") @@ -190,10 +178,6 @@ def _check_nodes(): else: assert detail["raylet"]["state"] == "DEAD" assert detail["raylet"].get("objectStoreAvailableMemory", 0) == 0 - response = requests.get(webui_url + "/test/dump?key=agents") - response.raise_for_status() - agents = response.json() - assert len(agents["data"]["agents"]) == 3 return True except Exception as ex: logger.info(ex) diff --git a/python/ray/dashboard/modules/state/state_head.py b/python/ray/dashboard/modules/state/state_head.py index 824fe30265251..1f32eda3574ad 100644 --- a/python/ray/dashboard/modules/state/state_head.py +++ b/python/ray/dashboard/modules/state/state_head.py @@ -75,7 +75,6 @@ def __init__( ) DataSource.nodes.signal.append(self._update_raylet_stubs) - DataSource.agents.signal.append(self._update_agent_stubs) async def limit_handler_(self): return do_reply( @@ -119,20 +118,6 @@ async def _update_raylet_stubs(self, change: Change): int(node_info["runtimeEnvAgentPort"]), ) - async def _update_agent_stubs(self, change: Change): - """Callback that's called when a new agent is added to Datasource.""" - if change.old: - node_id, _ = change.old - self._state_api_data_source_client.unregister_agent_client(node_id) - if change.new: - # When a new node information is written to DataSource. - node_id, (node_ip, http_port, grpc_port) = change.new - self._state_api_data_source_client.register_agent_client( - node_id, - node_ip, - grpc_port, - ) - @routes.get("/api/v0/actors") @RateLimitedModule.enforce_max_concurrent_calls async def list_actors(self, req: aiohttp.web.Request) -> aiohttp.web.Response: diff --git a/python/ray/dashboard/modules/tests/test_head.py b/python/ray/dashboard/modules/tests/test_head.py index 98e46f2fa9828..4258052facb87 100644 --- a/python/ray/dashboard/modules/tests/test_head.py +++ b/python/ray/dashboard/modules/tests/test_head.py @@ -20,16 +20,6 @@ class TestHead(dashboard_utils.DashboardHeadModule): def __init__(self, config: dashboard_utils.DashboardHeadModuleConfig): super().__init__(config) - self._notified_agents = {} - DataSource.agents.signal.append(self._update_notified_agents) - - async def _update_notified_agents(self, change): - if change.old: - node_id, _ = change.old - self._notified_agents.pop(node_id) - if change.new: - node_id, (node_ip, http_port, grpc_port) = change.new - self._notified_agents[node_id] = (node_ip, http_port, grpc_port) @staticmethod def is_minimal_module(): @@ -73,14 +63,6 @@ async def dump(self, req) -> aiohttp.web.Response: **{key: data}, ) - @routes.get("/test/notified_agents") - async def get_notified_agents(self, req) -> aiohttp.web.Response: - return dashboard_optional_utils.rest_response( - success=True, - message="Fetch notified agents success.", - **self._notified_agents, - ) - @routes.get("/test/http_get") async def get_url(self, req) -> aiohttp.web.Response: url = req.query.get("url") diff --git a/python/ray/dashboard/tests/test_dashboard.py b/python/ray/dashboard/tests/test_dashboard.py index fd652926ea96d..82b6be6ae4393 100644 --- a/python/ray/dashboard/tests/test_dashboard.py +++ b/python/ray/dashboard/tests/test_dashboard.py @@ -381,11 +381,15 @@ def test_http_get(enable_test_module, ray_start_with_dashboard): logger.info("failed response: %s", response.text) raise ex assert dump_info["result"] is True - dump_data = dump_info["data"] - assert len(dump_data["agents"]) == 1 - node_id, (node_ip, http_port, grpc_port) = next( - iter(dump_data["agents"].items()) + + # Get agent ip and http port + node_id_hex = ray_start_with_dashboard["node_id"] + agent_addr = ray.experimental.internal_kv._internal_kv_get( + f"{dashboard_consts.DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{node_id_hex}", + namespace=ray_constants.KV_NAMESPACE_DASHBOARD, ) + assert agent_addr is not None + node_ip, http_port, _ = json.loads(agent_addr) response = requests.get( f"http://{node_ip}:{http_port}" diff --git a/python/ray/tests/test_memory_pressure.py b/python/ray/tests/test_memory_pressure.py index 1a6f82a02896a..9dbc7a72cccfb 100644 --- a/python/ray/tests/test_memory_pressure.py +++ b/python/ray/tests/test_memory_pressure.py @@ -42,7 +42,6 @@ def get_local_state_client(): port = int(node["NodeManagerPort"]) runtime_env_agent_port = int(node["RuntimeEnvAgentPort"]) client.register_raylet_client(node_id, ip, port, runtime_env_agent_port) - client.register_agent_client(node_id, ip, port) return client diff --git a/python/ray/tests/test_state_api.py b/python/ray/tests/test_state_api.py index 11f9f620cabad..156af8ac3d199 100644 --- a/python/ray/tests/test_state_api.py +++ b/python/ray/tests/test_state_api.py @@ -1665,8 +1665,6 @@ def get_addr(): ) wait_for_condition(lambda: get_addr() is not None) - ip, http_port, grpc_port = json.loads(get_addr()) - client.register_agent_client(node_id, ip, grpc_port) result = await client.get_runtime_envs_info(node_id) assert isinstance(result, GetRuntimeEnvsInfoReply) @@ -1835,8 +1833,6 @@ def get_addr(): ) wait_for_condition(lambda: get_addr() is not None) - ip, http_port, grpc_port = json.loads(get_addr()) - client.register_agent_client(node_id, ip, grpc_port) @ray.remote class Actor: diff --git a/python/ray/tests/test_state_api_log.py b/python/ray/tests/test_state_api_log.py index 2aefa941b7137..d9031ced47638 100644 --- a/python/ray/tests/test_state_api_log.py +++ b/python/ray/tests/test_state_api_log.py @@ -4,7 +4,6 @@ import asyncio from typing import List import urllib -import re from unittest.mock import MagicMock, AsyncMock import pytest @@ -45,7 +44,7 @@ from ray.dashboard.tests.conftest import * # noqa from ray.util.state import get_log, list_logs, list_nodes, list_workers from ray.util.state.common import GetLogOptions -from ray.util.state.exception import DataSourceUnavailable, RayStateApiException +from ray.util.state.exception import RayStateApiException from ray.util.state.state_manager import StateDataSourceClient @@ -441,16 +440,15 @@ async def generate_logs_stream(num_chunks: int): async def test_logs_manager_list_logs(logs_manager): logs_client = logs_manager.data_source_client - logs_client.get_all_registered_log_agent_ids = MagicMock() - logs_client.get_all_registered_log_agent_ids.return_value = ["1", "2"] + async def my_list_logs(node_id, glob_filter, timeout): + if node_id != "2": + raise ValueError("Agent for node id: 3 doesn't exist.") + return generate_list_logs(["gcs_server.out"]) - logs_client.list_logs.side_effect = [ - generate_list_logs(["gcs_server.out"]), - DataSourceUnavailable(), - ] + logs_client.list_logs = AsyncMock() + logs_client.list_logs.side_effect = my_list_logs - # Unregistered node id should raise a DataSourceUnavailable. - with pytest.raises(DataSourceUnavailable): + with pytest.raises(ValueError): result = await logs_manager.list_logs( node_id="3", timeout=30, glob_filter="*gcs*" ) @@ -459,12 +457,8 @@ async def test_logs_manager_list_logs(logs_manager): assert len(result) == 1 assert result["gcs_server"] == ["gcs_server.out"] assert result["raylet"] == [] - logs_client.get_all_registered_log_agent_ids.assert_called() - logs_client.list_logs.assert_awaited_with("2", "*gcs*", timeout=30) - # The second call raises DataSourceUnavailable, which will - # return DataSourceUnavailable to the caller. - with pytest.raises(DataSourceUnavailable): + with pytest.raises(ValueError): result = await logs_manager.list_logs( node_id="1", timeout=30, glob_filter="*gcs*" ) @@ -477,8 +471,6 @@ async def test_logs_manager_resolve_file(logs_manager): Test filename is given. """ logs_client = logs_manager.data_source_client - logs_client.get_all_registered_log_agent_ids = MagicMock() - logs_client.get_all_registered_log_agent_ids.return_value = [node_id.hex()] expected_filename = "filename" res = await logs_manager.resolve_filename( node_id=node_id.hex(), @@ -699,8 +691,6 @@ async def test_logs_manager_stream_log(logs_manager): NUM_LOG_CHUNKS = 10 logs_client = logs_manager.data_source_client - logs_client.get_all_registered_log_agent_ids = MagicMock() - logs_client.get_all_registered_log_agent_ids.return_value = ["1", "2"] logs_client.ip_to_node_id = MagicMock() logs_client.stream_log.return_value = generate_logs_stream(NUM_LOG_CHUNKS) @@ -771,8 +761,6 @@ async def test_logs_manager_keepalive_no_timeout(logs_manager): NUM_LOG_CHUNKS = 10 logs_client = logs_manager.data_source_client - logs_client.get_all_registered_log_agent_ids = MagicMock() - logs_client.get_all_registered_log_agent_ids.return_value = ["1", "2"] logs_client.ip_to_node_id = MagicMock() logs_client.stream_log.return_value = generate_logs_stream(NUM_LOG_CHUNKS) @@ -1011,8 +999,8 @@ def verify(): with pytest.raises(requests.HTTPError) as e: list_logs(node_id=node_id) - assert re.match( - f"Given node id {node_id} is not available", e.value.response.json()["msg"] + assert ( + f"Agent for node id: {node_id} doesn't exist." in e.value.response.json()["msg"] ) diff --git a/python/ray/util/state/state_manager.py b/python/ray/util/state/state_manager.py index a4afa825f4668..bde90ae6a8e15 100644 --- a/python/ray/util/state/state_manager.py +++ b/python/ray/util/state/state_manager.py @@ -4,6 +4,7 @@ from collections import defaultdict from functools import wraps from typing import List, Optional, Tuple +import json import aiohttp import grpc @@ -11,6 +12,7 @@ import ray import ray.dashboard.modules.log.log_consts as log_consts +import ray.dashboard.consts as dashboard_consts from ray._private import ray_constants from ray._private.gcs_utils import GcsAioClient from ray._private.utils import hex_to_binary @@ -154,7 +156,6 @@ def __init__(self, gcs_channel: grpc.aio.Channel, gcs_aio_client: GcsAioClient): self.register_gcs_client(gcs_channel) self._raylet_stubs = {} self._runtime_env_agent_addresses = {} # {node_id -> url} - self._log_agent_stub = {} self._job_client = JobInfoStorageClient(gcs_aio_client) self._id_ip_map = IdToIpMap() self._gcs_aio_client = gcs_aio_client @@ -204,18 +205,6 @@ def unregister_raylet_client(self, node_id: str): self._runtime_env_agent_addresses.pop(node_id) self._id_ip_map.pop(node_id) - def register_agent_client(self, node_id, address: str, port: int): - options = _STATE_MANAGER_GRPC_OPTIONS - channel = ray._private.utils.init_grpc_channel( - f"{address}:{port}", options=options, asynchronous=True - ) - self._log_agent_stub[node_id] = LogServiceStub(channel) - self._id_ip_map.put(node_id, address) - - def unregister_agent_client(self, node_id: str): - self._log_agent_stub.pop(node_id) - self._id_ip_map.pop(node_id) - def get_all_registered_raylet_ids(self) -> List[str]: return self._raylet_stubs.keys() @@ -223,9 +212,21 @@ def get_all_registered_raylet_ids(self) -> List[str]: def get_all_registered_runtime_env_agent_ids(self) -> List[str]: return self._runtime_env_agent_addresses.keys() - # Returns all nod_ids which registered their log_agent_stub. - def get_all_registered_log_agent_ids(self) -> List[str]: - return self._log_agent_stub.keys() + async def get_log_service_stub(self, node_id: NodeID) -> LogServiceStub: + """Returns None if the agent on the node is not registered in Internal KV.""" + agent_addr = await self._gcs_aio_client.internal_kv_get( + f"{dashboard_consts.DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{node_id.hex()}".encode(), + namespace=ray_constants.KV_NAMESPACE_DASHBOARD, + timeout=dashboard_consts.GCS_RPC_TIMEOUT_SECONDS, + ) + if not agent_addr: + return None + ip, http_port, grpc_port = json.loads(agent_addr) + options = ray_constants.GLOBAL_GRPC_OPTIONS + channel = ray._private.utils.init_grpc_channel( + f"{ip}:{grpc_port}", options=options, asynchronous=True + ) + return LogServiceStub(channel) def ip_to_node_id(self, ip: Optional[str]) -> Optional[str]: """Return the node id that corresponds to the given ip. @@ -495,7 +496,7 @@ async def get_runtime_envs_info( async def list_logs( self, node_id: str, glob_filter: str, timeout: int = None ) -> ListLogsReply: - stub = self._log_agent_stub.get(node_id) + stub = await self.get_log_service_stub(NodeID.from_hex(node_id)) if not stub: raise ValueError(f"Agent for node id: {node_id} doesn't exist.") return await stub.ListLogs( @@ -514,7 +515,7 @@ async def stream_log( start_offset: Optional[int] = None, end_offset: Optional[int] = None, ) -> UnaryStreamCall: - stub = self._log_agent_stub.get(node_id) + stub = await self.get_log_service_stub(NodeID.from_hex(node_id)) if not stub: raise ValueError(f"Agent for node id: {node_id} doesn't exist.")