Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Ray Autoscaler to the Flyte-Ray plugin #1937

Merged
merged 10 commits into from
Mar 15, 2024
40 changes: 36 additions & 4 deletions plugins/flytekit-ray/flytekitplugins/ray/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ def __init__(
self,
group_name: str,
replicas: int,
min_replicas: typing.Optional[int] = 0,
min_replicas: typing.Optional[int] = None,
max_replicas: typing.Optional[int] = None,
ray_start_params: typing.Optional[typing.Dict[str, str]] = None,
):
self._group_name = group_name
self._replicas = replicas
self._min_replicas = min_replicas
self._max_replicas = max_replicas if max_replicas else replicas
self._max_replicas = max(replicas, max_replicas) if max_replicas is not None else replicas
self._min_replicas = min(replicas, min_replicas) if min_replicas is not None else replicas
self._ray_start_params = ray_start_params

@property
Expand Down Expand Up @@ -127,10 +127,14 @@ class RayCluster(_common.FlyteIdlEntity):
"""

def __init__(
self, worker_group_spec: typing.List[WorkerGroupSpec], head_group_spec: typing.Optional[HeadGroupSpec] = None
self,
worker_group_spec: typing.List[WorkerGroupSpec],
head_group_spec: typing.Optional[HeadGroupSpec] = None,
enable_autoscaling: bool = False,
):
self._head_group_spec = head_group_spec
self._worker_group_spec = worker_group_spec
self._enable_autoscaling = enable_autoscaling

@property
def head_group_spec(self) -> HeadGroupSpec:
Expand All @@ -148,13 +152,22 @@ def worker_group_spec(self) -> typing.List[WorkerGroupSpec]:
"""
return self._worker_group_spec

@property
def enable_autoscaling(self) -> bool:
"""
Whether to enable autoscaling.
:rtype: bool
"""
return self._enable_autoscaling

def to_flyte_idl(self) -> _ray_pb2.RayCluster:
"""
:rtype: flyteidl.plugins._ray_pb2.RayCluster
"""
return _ray_pb2.RayCluster(
head_group_spec=self.head_group_spec.to_flyte_idl() if self.head_group_spec else None,
worker_group_spec=[wg.to_flyte_idl() for wg in self.worker_group_spec],
enable_autoscaling=self.enable_autoscaling,
)

@classmethod
Expand All @@ -166,6 +179,7 @@ def from_flyte_idl(cls, proto):
return cls(
head_group_spec=HeadGroupSpec.from_flyte_idl(proto.head_group_spec) if proto.head_group_spec else None,
worker_group_spec=[WorkerGroupSpec.from_flyte_idl(wg) for wg in proto.worker_group_spec],
enable_autoscaling=proto.enable_autoscaling,
)


Expand All @@ -178,9 +192,13 @@ def __init__(
self,
ray_cluster: RayCluster,
runtime_env: typing.Optional[str],
ttl_seconds_after_finished: typing.Optional[int] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a no-op if shutdown_after_job_finishes is set to False, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

shutdown_after_job_finishes: bool = False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this mean that by default the rayjob will not be reclaimed by kuberay once the job finishes?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

):
self._ray_cluster = ray_cluster
self._runtime_env = runtime_env
self._ttl_seconds_after_finished = ttl_seconds_after_finished
self._shutdown_after_job_finishes = shutdown_after_job_finishes

@property
def ray_cluster(self) -> RayCluster:
Expand All @@ -190,15 +208,29 @@ def ray_cluster(self) -> RayCluster:
def runtime_env(self) -> typing.Optional[str]:
return self._runtime_env

@property
def ttl_seconds_after_finished(self) -> typing.Optional[int]:
# ttl_seconds_after_finished specifies the number of seconds after which the RayCluster will be deleted after the RayJob finishes.
return self._ttl_seconds_after_finished

@property
def shutdown_after_job_finishes(self) -> bool:
# shutdown_after_job_finishes specifies whether the RayCluster should be deleted after the RayJob finishes.
return self._shutdown_after_job_finishes

def to_flyte_idl(self) -> _ray_pb2.RayJob:
return _ray_pb2.RayJob(
ray_cluster=self.ray_cluster.to_flyte_idl(),
runtime_env=self.runtime_env,
ttl_seconds_after_finished=self.ttl_seconds_after_finished,
shutdown_after_job_finishes=self.shutdown_after_job_finishes,
)

@classmethod
def from_flyte_idl(cls, proto: _ray_pb2.RayJob):
return cls(
ray_cluster=RayCluster.from_flyte_idl(proto.ray_cluster) if proto.ray_cluster else None,
runtime_env=proto.runtime_env,
ttl_seconds_after_finished=proto.ttl_seconds_after_finished,
shutdown_after_job_finishes=proto.shutdown_after_job_finishes,
)
6 changes: 6 additions & 0 deletions plugins/flytekit-ray/flytekitplugins/ray/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@ class WorkerNodeConfig:
class RayJobConfig:
worker_node_config: typing.List[WorkerNodeConfig]
head_node_config: typing.Optional[HeadNodeConfig] = None
enable_autoscaling: bool = False
runtime_env: typing.Optional[dict] = None
address: typing.Optional[str] = None
shutdown_after_job_finishes: bool = False
ttl_seconds_after_finished: typing.Optional[int] = None


class RayFunctionTask(PythonFunctionTask):
Expand Down Expand Up @@ -67,9 +70,12 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]
WorkerGroupSpec(c.group_name, c.replicas, c.min_replicas, c.max_replicas, c.ray_start_params)
for c in cfg.worker_node_config
],
enable_autoscaling=cfg.enable_autoscaling if cfg.enable_autoscaling else False,
),
# Use base64 to encode runtime_env dict and convert it to byte string
runtime_env=base64.b64encode(json.dumps(cfg.runtime_env).encode()).decode(),
ttl_seconds_after_finished=cfg.ttl_seconds_after_finished,
shutdown_after_job_finishes=cfg.shutdown_after_job_finishes,
)
return MessageToDict(ray_job.to_flyte_idl())

Expand Down
10 changes: 8 additions & 2 deletions plugins/flytekit-ray/tests/test_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
from flytekit.configuration import Image, ImageConfig, SerializationSettings

config = RayJobConfig(
worker_node_config=[WorkerNodeConfig(group_name="test_group", replicas=3)],
worker_node_config=[WorkerNodeConfig(group_name="test_group", replicas=3, min_replicas=0, max_replicas=10)],
runtime_env={"pip": ["numpy"]},
enable_autoscaling=True,
shutdown_after_job_finishes=True,
ttl_seconds_after_finished=20,
Comment on lines +15 to +17
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs a new release of flyteidl.

)


Expand All @@ -37,8 +40,11 @@ def t1(a: int) -> str:
)

ray_job_pb = RayJob(
ray_cluster=RayCluster(worker_group_spec=[WorkerGroupSpec("test_group", 3)]),
ray_cluster=RayCluster(worker_group_spec=[WorkerGroupSpec("test_group", 3, 0, 10)]),
runtime_env=base64.b64encode(json.dumps({"pip": ["numpy"]}).encode()).decode(),
enable_autoscaling=True,
shutdown_after_job_finishes=True,
ttl_seconds_after_finished=20,
).to_flyte_idl()

assert t1.get_custom(settings) == MessageToDict(ray_job_pb)
Expand Down
Loading