Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable GCP Workload Monitoring #932

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .axlearn/axlearn.default.config
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,11 @@ default_dockerfile = "Dockerfile"
# (Optional) Enable VertexAI Tensorboard support during training.
vertexai_tensorboard = "1231231231231231231"
vertexai_region = "us-central1"

# (Optional) Enable GCP Workload Monitoring
# Change to true to send workload metrics to Google Cloud Monitoring
enable_gcp_workload_monitoring = false
# Used to identify the current ML workload.
workload_id = "my_workload_id" # Optional (defaults to the environment variable JOBSET_NAME, then JOB_NAME, else "unknown")
# Used to identify the replica of the workload.
replica_id = "0" # Optional (default: "0")
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
| [Concepts](docs/02-concepts.md) | Core concepts and design principles. |
| [CLI User Guide](docs/03-cli.md) | How to use the CLI. |
| [Infrastructure](docs/04-infrastructure.md) | Core infrastructure components. |
| [GCP Monitoring](docs/05-monitoring.md) | How to enable GCP workload monitoring. |

## Introduction

Expand Down
164 changes: 164 additions & 0 deletions axlearn/cloud/gcp/monitoring/monitor_workload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
"""
Module for GCP Workload Monitoring, enabling performance and heartbeat metrics
to be sent to Google Cloud Monitoring.
"""

import os
import time

from absl import app, logging
from google.api import metric_pb2, monitored_resource_pb2
from google.api_core.exceptions import GoogleAPIError
from google.cloud import monitoring_v3

from axlearn.cloud.gcp.config import gcp_settings
from axlearn.cloud.gcp.utils import get_gcp_metadata


class GCPWorkloadMonitoring:
    """Sends workload performance and heartbeat metrics to Google Cloud Monitoring.

    Every time series written by this class is created under `projects/<project_id>` and is
    labelled with the zone, workload id and replica id supplied at construction time.

    Args:
        project_id: GCP project that receives the time series.
        zone: Value written to the "location" resource label.
        workload_id: Value written to the "workload_id" resource label.
        replica_id: Value written to the "replica_id" resource label.
    """

    def __init__(self, project_id: str, zone: str, workload_id: str, replica_id: str):
        self.project_id = project_id
        self.zone = zone
        self.workload_id = workload_id
        self.replica_id = replica_id
        self.client = monitoring_v3.MetricServiceClient()
        self.project_name = f"projects/{self.project_id}"

    def check_connectivity(self) -> None:
        """Checks that the Monitoring API can be called for the configured project.

        Only the project is probed (by listing its monitored resource descriptors); the zone
        is not part of the request.

        Raises:
            ValueError: If the API call raises a `GoogleAPIError`.
        """
        try:
            resource_descriptors = self.client.list_monitored_resource_descriptors(
                name=self.project_name
            )
            if resource_descriptors:
                logging.info("Successfully connected to GCP project %s.", self.project_id)
        except GoogleAPIError as e:
            raise ValueError(f"Unable to connect to GCP project {self.project_id}: {e}") from e

    @staticmethod
    def _current_timestamp() -> dict:
        """Returns the current wall-clock time split into whole seconds and nanoseconds."""
        now = time.time()
        seconds = int(now)
        nanos = int((now - seconds) * 10**9)
        return {"seconds": seconds, "nanos": nanos}

    def _point_now(self, value):
        """Wraps a `monitoring_v3.TypedValue` in a Point whose interval ends now."""
        return monitoring_v3.Point(
            interval=monitoring_v3.TimeInterval(end_time=self._current_timestamp()),
            value=value,
        )

    def _write_series(self, series) -> None:
        """Writes a single time series under this monitor's project."""
        self.client.create_time_series(
            request={"name": self.project_name, "time_series": [series]}
        )

    def send_performance_metric(self, perf_metric: float) -> None:
        """Sends a performance metric to Google Cloud Monitoring.

        A `GoogleAPIError` raised while building or sending the series is logged and
        swallowed, so a monitoring failure does not propagate to the caller.

        Args:
            perf_metric: Value written as the point's `double_value`.
        """
        metric_type = "compute.googleapis.com/workload/performance"
        resource_type = "compute.googleapis.com/Workload"

        try:
            point = self._point_now(monitoring_v3.TypedValue(double_value=perf_metric))
            series = monitoring_v3.TimeSeries(
                metric=metric_pb2.Metric(type=metric_type),
                resource=monitored_resource_pb2.MonitoredResource(
                    type=resource_type,
                    labels={
                        "location": self.zone,
                        "workload_id": self.workload_id,
                        "replica_id": self.replica_id,
                    },
                ),
                points=[point],
            )
            self._write_series(series)
            logging.info(
                "Perf metric (%.3f) successfully sent to GCP resource %s.",
                perf_metric,
                resource_type,
            )
        except GoogleAPIError as e:
            logging.error("Failed to send metric to GCP. Metric: %s, Error: %s", metric_type, e)

    def send_heartbeat_metric(self, local_rank: str, global_rank: str) -> None:
        """Sends a heartbeat metric (always `True`) to Google Cloud Monitoring.

        A `GoogleAPIError` raised while building or sending the series is logged and
        swallowed, so a monitoring failure does not propagate to the caller.

        Args:
            local_rank: Value written to the "local_rank" metric label.
            global_rank: Value written to the "process_id" resource label.
        """
        is_alive = True
        metric_type = "compute.googleapis.com/workload_process/heartbeat"
        resource_type = "compute.googleapis.com/WorkloadProcess"

        try:
            point = self._point_now(monitoring_v3.TypedValue(bool_value=is_alive))
            series = monitoring_v3.TimeSeries(
                metric=metric_pb2.Metric(
                    type=metric_type,
                    labels={
                        "local_rank": local_rank,
                        # Looked up on every heartbeat via the GCP metadata helper.
                        "instance_id": get_gcp_metadata(category="instance", attribute="id"),
                    },
                ),
                resource=monitored_resource_pb2.MonitoredResource(
                    type=resource_type,
                    labels={
                        "project_id": self.project_id,
                        "location": self.zone,
                        "workload_id": self.workload_id,
                        "replica_id": self.replica_id,
                        "process_id": global_rank,
                    },
                ),
                points=[point],
            )
            self._write_series(series)
            logging.info(
                "Heartbeat metric successfully sent to GCP with value %s for resource %s.",
                is_alive,
                resource_type,
            )
        except GoogleAPIError as e:
            logging.error("Failed to send metric to GCP. Metric: %s, Error: %s", metric_type, e)


def main(argv):
    """Sends one sample performance metric and one heartbeat if monitoring is enabled.

    Does nothing unless the `enable_gcp_workload_monitoring` setting is truthy. Otherwise it
    builds a `GCPWorkloadMonitoring` from the `project`, `zone`, `workload_id` and
    `replica_id` settings, checks connectivity, and sends example metrics.

    Args:
        argv: Command-line arguments passed by `app.run`; unused.
    """
    del argv  # Unused.

    if not gcp_settings("enable_gcp_workload_monitoring", default=False):
        return

    # Used as the default for the `workload_id` setting: JOBSET_NAME, then JOB_NAME, then
    # the literal "unknown". Empty environment values are skipped by the `or` chain.
    fallback_workload_id = (
        os.environ.get("JOBSET_NAME") or os.environ.get("JOB_NAME") or "unknown"
    )

    monitor = GCPWorkloadMonitoring(
        project_id=gcp_settings("project", required=True),
        zone=gcp_settings("zone", required=True),
        workload_id=gcp_settings("workload_id", default=fallback_workload_id),
        replica_id=gcp_settings("replica_id", default="0"),
    )

    # Raises ValueError when the project cannot be reached, before any metric is sent.
    monitor.check_connectivity()

    # Example: send metrics.
    monitor.send_performance_metric(perf_metric=0.123)
    monitor.send_heartbeat_metric(local_rank="0", global_rank="0")


if __name__ == "__main__":
    app.run(main)
163 changes: 163 additions & 0 deletions axlearn/cloud/gcp/monitoring/monitor_workload_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
"""
Unit tests for GCPWorkloadMonitoring in axlearn.cloud.gcp.monitoring.monitor_workload.

This module provides tests for verifying the behavior of GCPWorkloadMonitoring.
The tests ensure connectivity, performance metrics, and heartbeat metrics are handled correctly.
"""

import unittest
from unittest.mock import MagicMock, patch

from google.api_core.exceptions import GoogleAPIError

from axlearn.cloud.gcp.monitoring.monitor_workload import GCPWorkloadMonitoring


class TestGCPWorkloadMonitoring(unittest.TestCase):
    """Tests for the GCPWorkloadMonitoring class.

    Covers:
    - Connectivity validation.
    - Performance metric submission.
    - Heartbeat metric submission.

    `MetricServiceClient` is patched in every test, so no request leaves the process.
    """

    # Patch targets are resolved in the module under test, where the names are looked up.
    _MODULE = "axlearn.cloud.gcp.monitoring.monitor_workload"
    _CLIENT_PATH = f"{_MODULE}.monitoring_v3.MetricServiceClient"
    _METADATA_PATH = f"{_MODULE}.get_gcp_metadata"
    _LOG_ERROR_PATH = f"{_MODULE}.logging.error"

    def setUp(self):
        self.project_id = "test-project"
        self.zone = "test-zone"
        self.workload_id = "test-workload"
        self.replica_id = "test-replica"

    def _make_monitor(self, mock_metric_service_client):
        """Builds a monitor backed by a fresh mock client.

        Must be called while `MetricServiceClient` is patched, because the monitor creates
        its client inside `__init__`.

        Args:
            mock_metric_service_client: The patched `MetricServiceClient` class.

        Returns:
            A `(monitor, mock_client_instance)` tuple.
        """
        mock_client_instance = MagicMock()
        mock_metric_service_client.return_value = mock_client_instance
        monitor = GCPWorkloadMonitoring(
            project_id=self.project_id,
            zone=self.zone,
            workload_id=self.workload_id,
            replica_id=self.replica_id,
        )
        return monitor, mock_client_instance

    def _assert_single_write_to_project(self, mock_client_instance):
        """Asserts exactly one series was written under the configured project."""
        mock_client_instance.create_time_series.assert_called_once()
        _, kwargs = mock_client_instance.create_time_series.call_args
        self.assertEqual(kwargs["request"]["name"], f"projects/{self.project_id}")
        self.assertEqual(len(kwargs["request"]["time_series"]), 1)

    @patch(_CLIENT_PATH)
    def test_check_connectivity_success(self, mock_metric_service_client):
        monitor, mock_client_instance = self._make_monitor(mock_metric_service_client)
        mock_client_instance.list_monitored_resource_descriptors.return_value = [MagicMock()]

        monitor.check_connectivity()

        mock_metric_service_client.assert_called_once()
        mock_client_instance.list_monitored_resource_descriptors.assert_called_once_with(
            name=f"projects/{self.project_id}"
        )

    @patch(_CLIENT_PATH)
    def test_check_connectivity_failure(self, mock_metric_service_client):
        monitor, mock_client_instance = self._make_monitor(mock_metric_service_client)
        mock_client_instance.list_monitored_resource_descriptors.side_effect = GoogleAPIError(
            "API Error"
        )

        # API errors are surfaced to the caller as ValueError.
        with self.assertRaises(ValueError):
            monitor.check_connectivity()

    @patch(_CLIENT_PATH)
    def test_send_performance_metric_success(self, mock_metric_service_client):
        monitor, mock_client_instance = self._make_monitor(mock_metric_service_client)
        mock_client_instance.create_time_series.return_value = None

        monitor.send_performance_metric(perf_metric=0.123)

        mock_metric_service_client.assert_called_once()
        self._assert_single_write_to_project(mock_client_instance)

    @patch(_CLIENT_PATH)
    def test_send_performance_metric_failure(self, mock_metric_service_client):
        monitor, mock_client_instance = self._make_monitor(mock_metric_service_client)
        mock_client_instance.create_time_series.side_effect = GoogleAPIError("API Error")

        # The error must be logged rather than raised.
        with patch(self._LOG_ERROR_PATH) as mock_logging_error:
            monitor.send_performance_metric(perf_metric=0.123)
            mock_logging_error.assert_called_once()

    @patch(_METADATA_PATH)
    @patch(_CLIENT_PATH)
    def test_send_heartbeat_metric_success(self, mock_metric_service_client, mock_get_gcp_metadata):
        monitor, mock_client_instance = self._make_monitor(mock_metric_service_client)
        mock_client_instance.create_time_series.return_value = None
        mock_get_gcp_metadata.return_value = "test-instance-id"

        monitor.send_heartbeat_metric(local_rank="0", global_rank="1")

        mock_metric_service_client.assert_called_once()
        mock_get_gcp_metadata.assert_called_once_with(category="instance", attribute="id")
        self._assert_single_write_to_project(mock_client_instance)

    @patch(_METADATA_PATH)
    @patch(_CLIENT_PATH)
    def test_send_heartbeat_metric_failure(self, mock_metric_service_client, mock_get_gcp_metadata):
        monitor, mock_client_instance = self._make_monitor(mock_metric_service_client)
        mock_client_instance.create_time_series.side_effect = GoogleAPIError("API Error")
        mock_get_gcp_metadata.return_value = "test-instance-id"

        # The error must be logged rather than raised.
        with patch(self._LOG_ERROR_PATH) as mock_logging_error:
            monitor.send_heartbeat_metric(local_rank="0", global_rank="1")
            mock_logging_error.assert_called_once()


if __name__ == "__main__":
    unittest.main()
Loading