Commit

add gpu and cpu metric (#33)
* refactor: create prime metric module


* add cpu and gpu metric

* apply Jannik's suggestion

Co-authored-by: JannikSt <[email protected]>

---------

Co-authored-by: JannikSt <[email protected]>
samsja and JannikSt authored Feb 3, 2025
1 parent a2c5e55 commit 42383e4
Showing 5 changed files with 165 additions and 43 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -14,6 +14,7 @@ dependencies = [
    "google-cloud-storage",
    "tomli",
    "docker>=7.1.0",
+    "pynvml>=12.0.0"
]

[project.optional-dependencies]
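The new pynvml>=12.0.0 dependency is what the metrics module added below uses to detect and query NVIDIA GPUs. Here is a minimal standalone sketch of the same detection pattern, assuming only pynvml is installed; the helper name detect_gpus is illustrative and not part of the commit:

import pynvml

def detect_gpus() -> int:
    """Return the number of visible NVIDIA GPUs, or 0 if NVML is unavailable."""
    try:
        pynvml.nvmlInit()
        return pynvml.nvmlDeviceGetCount()
    except pynvml.NVMLError:
        # No NVIDIA driver or no GPU: fall back to CPU-only metrics, as PrimeMetric does.
        return 0

if __name__ == "__main__":
    print(f"GPUs detected: {detect_gpus()}")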
13 changes: 8 additions & 5 deletions src/genesys/data.py
@@ -4,7 +4,7 @@
import rich.progress
from transformers import AutoTokenizer
import random
-from genesys.utils import log_prime
+from genesys.prime_metrics import PrimeMetric


class DataConfig(BaseConfig):
@@ -18,6 +18,8 @@ class DataConfig(BaseConfig):

    prime_log: bool = False

+    prime_log_freq: int = 5


def repeat_elements(lst, n):
return [item for item in lst for _ in range(n)]
@@ -85,6 +87,8 @@ def _add_column(dataset, path):
                for i, length in enumerate(self.dataset_lengths)
            ]

+        self.prime_metric = PrimeMetric(disable=not (config.prime_log), period=config.prime_log_freq)

    def _prepare_batch(self, batch: dict, dataset: str) -> tuple:
        batch = repeat_elements(
            [b for b in batch], self.config.num_responses_per_question
@@ -129,7 +133,6 @@ def __iter__(self) -> Generator[tuple, None, None]:
                    break

    def log_progress_prime(self, paths: list[str], dataset_counters: list[int]):
-        if self.config.prime_log:
-            metric = {path: counter for path, counter in zip(paths, dataset_counters)}
-            metric.update({"total": sum(dataset_counters)})
-            log_prime(metric)
+        metric = {path: counter for path, counter in zip(paths, dataset_counters)}
+        metric.update({"total": sum(dataset_counters)})
+        self.prime_metric.log_prime(metric)
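For context on how these new fields are used: prime_log gates the logger and prime_log_freq sets its collection period, via the PrimeMetric constructed in __init__ above. A rough wiring sketch with a stand-in config object (the real DataConfig subclasses BaseConfig and has more fields than shown here):

from dataclasses import dataclass

from genesys.prime_metrics import PrimeMetric

@dataclass
class FakeDataConfig:  # illustrative stand-in for genesys.data.DataConfig
    prime_log: bool = True
    prime_log_freq: int = 5

config = FakeDataConfig()

# Same construction as in __init__ above: disabled when prime_log is False,
# collecting every prime_log_freq seconds otherwise.
prime_metric = PrimeMetric(disable=not config.prime_log, period=config.prime_log_freq)

# Same shape as log_progress_prime: per-dataset counters plus a running total.
counters = {"dataset_a": 12, "dataset_b": 7}
metric = dict(counters)
metric.update({"total": sum(counters.values())})
prime_metric.log_prime(metric)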
131 changes: 131 additions & 0 deletions src/genesys/prime_metrics.py
@@ -0,0 +1,131 @@
import socket
import platform
import json
import os
from typing import Any
import psutil
import threading
import time
import pynvml


class PrimeMetric:
    """
    Logs metrics to the Prime Miner via a Unix socket.
    Periodically collects and logs system metrics, including CPU, memory, and GPU usage.

    Args:
        disable (bool): If True, disables metric logging. Defaults to False.
        period (int): Collection interval in seconds. Defaults to 5.

    Usage:
        metrics = PrimeMetric()
        metrics.log_prime({"custom_metric": value})
    """

    def __init__(self, disable: bool = False, period: int = 5):
        self.disable = disable
        self.period = period
        self._thread = None
        self._stop_event = threading.Event()

        # Detect GPUs before starting the collection thread, so the loop
        # never reads self.has_gpu before it is set.
        self.has_gpu = False
        try:
            pynvml.nvmlInit()
            pynvml.nvmlDeviceGetHandleByIndex(0)  # check that at least one GPU exists
            self.has_gpu = True
        except pynvml.NVMLError:
            pass

        self._start_metrics_thread()

    ## public

    def log_prime(self, metric: dict[str, Any]):
        if self.disable:
            return
        if not (self._send_message_prime(metric)):
            print(f"Prime logging failed: {metric}")

    ## private

    @classmethod
    def _get_default_socket_path(cls) -> str:
        """Returns the default socket path based on the operating system."""
        default = (
            "/tmp/com.prime.miner/metrics.sock"
            if platform.system() == "Darwin"
            else "/var/run/com.prime.miner/metrics.sock"
        )
        return os.getenv("PRIME_TASK_BRIDGE_SOCKET", default=default)

    def _send_message_prime(self, metric: dict, socket_path: str = None) -> bool:
        """Sends a message to the specified socket path or uses the default if none is provided."""
        socket_path = socket_path or os.getenv("PRIME_TASK_BRIDGE_SOCKET", self._get_default_socket_path())
        # print("Sending message to socket: ", socket_path)

        task_id = os.getenv("PRIME_TASK_ID", None)
        if task_id is None:
            print("No task ID found, skipping logging to Prime")
            return False
        try:
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
                sock.connect(socket_path)

                for key, value in metric.items():
                    message = {"label": key, "value": value, "task_id": task_id}
                    sock.sendall(json.dumps(message).encode())
                return True
        except Exception:
            return False

    ### background system metrics

    def _start_metrics_thread(self):
        """Starts the metrics collection thread."""
        if self._thread is not None:
            return
        self._stop_event.clear()
        self._thread = threading.Thread(target=self._collect_metrics)
        self._thread.daemon = True
        self._thread.start()

    def _stop_metrics_thread(self):
        """Stops the metrics collection thread."""
        if self._thread is None:
            return
        self._stop_event.set()
        self._thread.join()
        self._thread = None

    def _collect_metrics(self):
        while not self._stop_event.is_set():
            metrics = {
                "cpu_percent": psutil.cpu_percent(),
                "memory_percent": psutil.virtual_memory().percent,
                "memory_usage": psutil.virtual_memory().used,
                "memory_total": psutil.virtual_memory().total,
            }

            if self.has_gpu:
                gpu_count = pynvml.nvmlDeviceGetCount()
                for i in range(gpu_count):
                    handle = pynvml.nvmlDeviceGetHandleByIndex(i)
                    info = pynvml.nvmlDeviceGetMemoryInfo(handle)
                    gpu_util = pynvml.nvmlDeviceGetUtilizationRates(handle)

                    metrics.update(
                        {
                            f"gpu_{i}_memory_used": info.used,
                            f"gpu_{i}_memory_total": info.total,
                            f"gpu_{i}_utilization": gpu_util.gpu,
                        }
                    )

            self.log_prime(metrics)
            time.sleep(self.period)

    def __del__(self):
        # hasattr check: __del__ can run during interpreter shutdown, after some
        # attributes have already been deleted.
        if hasattr(self, "_thread") and self._thread is not None:
            self._stop_metrics_thread()
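For local debugging, the messages PrimeMetric writes can be inspected with a throwaway Unix-socket listener and the two environment variables the module reads, PRIME_TASK_BRIDGE_SOCKET and PRIME_TASK_ID. The following is a sketch under those assumptions, not part of the commit; note that a single connection may carry several JSON objects back to back (sendall writes them with no delimiter), and the first connection accepted may be the periodic system-metrics payload rather than the custom metric:

import os
import socket
import tempfile

from genesys.prime_metrics import PrimeMetric

# Point PrimeMetric at a scratch socket and give it a task id so it does not skip sending.
sock_path = os.path.join(tempfile.mkdtemp(), "metrics.sock")
os.environ["PRIME_TASK_BRIDGE_SOCKET"] = sock_path
os.environ["PRIME_TASK_ID"] = "local-test"

server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
server.bind(sock_path)
server.listen(8)

metrics = PrimeMetric(period=1)           # background collector starts immediately
metrics.log_prime({"custom_metric": 42})  # one extra connection with the custom metric

# Drain a couple of queued connections and print the raw payloads.
for _ in range(2):
    conn, _addr = server.accept()
    print(conn.recv(65536).decode())  # e.g. {"label": "custom_metric", "value": 42, "task_id": "local-test"}
    conn.close()

server.close()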
39 changes: 1 addition & 38 deletions src/genesys/utils.py
@@ -5,8 +5,7 @@
import random
import base64
import threading
-import socket
-import platform

from google.cloud import storage
from google.oauth2 import service_account
from queue import Queue
@@ -134,39 +133,3 @@ def extract_json(text):
        return json.loads(json_str)
    except json.JSONDecodeError:
        raise ValueError("Failed to parse JSON from the extracted content")
-
-
-def get_default_socket_path() -> str:
-    """Returns the default socket path based on the operating system."""
-    default = (
-        "/tmp/com.prime.miner/metrics.sock"
-        if platform.system() == "Darwin"
-        else "/var/run/com.prime.miner/metrics.sock"
-    )
-    return os.getenv("PRIME_TASK_BRIDGE_SOCKET", default=default)
-
-
-def send_message_prime(metric: dict, socket_path: str = None) -> bool:
-    """Sends a message to the specified socket path or uses the default if none is provided."""
-    socket_path = socket_path or os.getenv("PRIME_TASK_BRIDGE_SOCKET", get_default_socket_path())
-    # print("Sending message to socket: ", socket_path)
-
-    task_id = os.getenv("PRIME_TASK_ID", None)
-    if task_id is None:
-        print("No task ID found, skipping logging to Prime")
-        return False
-    try:
-        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
-            sock.connect(socket_path)
-
-            for key, value in metric.items():
-                message = {"label": key, "value": value, "task_id": task_id}
-                sock.sendall(json.dumps(message).encode())
-            return True
-    except Exception:
-        return False
-
-
-def log_prime(metric: dict):
-    if not (send_message_prime(metric)):
-        print(f"Prime logging failed: {metric}")
24 changes: 24 additions & 0 deletions uv.lock

Generated lockfile; diff not rendered.
