Skip to content

Commit

Permalink
PENG-2456 move the SQL query logic from the *job_submissions_metrics*…
Browse files Browse the repository at this point in the history
… endpoint to a dedicated function in the helpers module
  • Loading branch information
matheushent committed Dec 13, 2024
1 parent 9aee2d6 commit 0c589c6
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 58 deletions.
73 changes: 72 additions & 1 deletion jobbergate-api/jobbergate_api/apps/job_submissions/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@

from collections.abc import Iterable
from math import ceil
from typing import Any, Type
from textwrap import dedent
from typing import Any, assert_never, Type

from loguru import logger

from jobbergate_api.apps.job_submissions.constants import (
JobSubmissionMetricSampleRate,
JobSubmissionMetricAggregateNames,
)


def _force_cast(object: Any, expected_type: Type[Any]) -> Any:
"""Forcefully cast a value to the expected type.
Expand Down Expand Up @@ -60,3 +66,68 @@ def validate_job_metric_upload_input(
logger.error(f"Failed to cast data to expected types: {e}")
raise ValueError("Failed to cast data to expected types.")
return data


def build_job_metric_aggregation_query(node: str | None, sample_rate: JobSubmissionMetricSampleRate) -> str:
"""
Build a SQL query string to aggregate job metrics based on the provided node and sample rate.
Args:
node (str | None): The node host identifier. If None, the query will aggregate metrics for all nodes.
sample_rate (JobSubmissionMetricSampleRate): The sample rate for the metrics aggregation. Determines the view name to use.
Returns:
str: The SQL query string for aggregating job metrics.
"""
if node is not None:
where_statement = "WHERE job_submission_id = :job_submission_id AND node_host = :node_host"
match sample_rate:
case JobSubmissionMetricSampleRate.ten_seconds:
view_name = JobSubmissionMetricAggregateNames.metrics_nodes_mv_10_seconds_by_node
case JobSubmissionMetricSampleRate.one_minute:
view_name = JobSubmissionMetricAggregateNames.metrics_nodes_mv_1_minute_by_node
case JobSubmissionMetricSampleRate.ten_minutes:
view_name = JobSubmissionMetricAggregateNames.metrics_nodes_mv_10_minutes_by_node
case JobSubmissionMetricSampleRate.one_hour:
view_name = JobSubmissionMetricAggregateNames.metrics_nodes_mv_1_hour_by_node
case JobSubmissionMetricSampleRate.one_week:
view_name = JobSubmissionMetricAggregateNames.metrics_nodes_mv_1_week_by_node
case _ as unreachable:
assert_never(unreachable)
else:
where_statement = "WHERE job_submission_id = :job_submission_id"
match sample_rate:
case JobSubmissionMetricSampleRate.ten_seconds:
view_name = JobSubmissionMetricAggregateNames.metrics_nodes_mv_10_seconds_all_nodes
case JobSubmissionMetricSampleRate.one_minute:
view_name = JobSubmissionMetricAggregateNames.metrics_nodes_mv_1_minute_all_nodes
case JobSubmissionMetricSampleRate.ten_minutes:
view_name = JobSubmissionMetricAggregateNames.metrics_nodes_mv_10_minutes_all_nodes
case JobSubmissionMetricSampleRate.one_hour:
view_name = JobSubmissionMetricAggregateNames.metrics_nodes_mv_1_hour_all_nodes
case JobSubmissionMetricSampleRate.one_week:
view_name = JobSubmissionMetricAggregateNames.metrics_nodes_mv_1_week_all_nodes
case _ as unreachable:
assert_never(unreachable)

return dedent(
f"""
SELECT bucket,
node_host,
cpu_frequency,
cpu_time,
cpu_utilization,
gpu_memory,
gpu_utilization,
page_faults,
memory_rss,
memory_virtual,
disk_read,
disk_write
FROM {view_name}
{where_statement}
AND bucket >= :start_time
AND bucket <= :end_time
ORDER BY bucket
"""
)
62 changes: 6 additions & 56 deletions jobbergate-api/jobbergate_api/apps/job_submissions/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
"""

from datetime import datetime, timedelta, timezone
from textwrap import dedent
from typing import Any, assert_never
from typing import Any

from fastapi import APIRouter, Depends, HTTPException, Path, Query, Body
from fastapi import Response as FastAPIResponse
Expand All @@ -23,7 +22,6 @@
JobSubmissionStatus,
slurm_job_state_details,
JobSubmissionMetricSampleRate,
JobSubmissionMetricAggregateNames,
)
from jobbergate_api.apps.job_submissions.schemas import (
ActiveJobSubmission,
Expand All @@ -39,7 +37,10 @@
JobSubmissionAgentMaxTimes,
JobSubmissionMetricSchema,
)
from jobbergate_api.apps.job_submissions.helpers import validate_job_metric_upload_input
from jobbergate_api.apps.job_submissions.helpers import (
validate_job_metric_upload_input,
build_job_metric_aggregation_query,
)
from jobbergate_api.apps.permissions import Permissions, can_bypass_ownership_check
from jobbergate_api.apps.schemas import ListParams
from jobbergate_api.email_notification import notify_submission_rejected
Expand Down Expand Up @@ -615,58 +616,7 @@ async def job_submissions_metrics(
)
end_time = end_time or datetime.now(tz=timezone.utc)

if node is not None:
where_statement = "WHERE job_submission_id = :job_submission_id AND node_host = :node_host"
match sample_rate:
case JobSubmissionMetricSampleRate.ten_seconds:
view_name = JobSubmissionMetricAggregateNames.metrics_nodes_mv_10_seconds_by_node
case JobSubmissionMetricSampleRate.one_minute:
view_name = JobSubmissionMetricAggregateNames.metrics_nodes_mv_1_minute_by_node
case JobSubmissionMetricSampleRate.ten_minutes:
view_name = JobSubmissionMetricAggregateNames.metrics_nodes_mv_10_minutes_by_node
case JobSubmissionMetricSampleRate.one_hour:
view_name = JobSubmissionMetricAggregateNames.metrics_nodes_mv_1_hour_by_node
case JobSubmissionMetricSampleRate.one_week:
view_name = JobSubmissionMetricAggregateNames.metrics_nodes_mv_1_week_by_node
case _ as unreachable:
assert_never(unreachable)
else:
where_statement = "WHERE job_submission_id = :job_submission_id"
match sample_rate:
case JobSubmissionMetricSampleRate.ten_seconds:
view_name = JobSubmissionMetricAggregateNames.metrics_nodes_mv_10_seconds_all_nodes
case JobSubmissionMetricSampleRate.one_minute:
view_name = JobSubmissionMetricAggregateNames.metrics_nodes_mv_1_minute_all_nodes
case JobSubmissionMetricSampleRate.ten_minutes:
view_name = JobSubmissionMetricAggregateNames.metrics_nodes_mv_10_minutes_all_nodes
case JobSubmissionMetricSampleRate.one_hour:
view_name = JobSubmissionMetricAggregateNames.metrics_nodes_mv_1_hour_all_nodes
case JobSubmissionMetricSampleRate.one_week:
view_name = JobSubmissionMetricAggregateNames.metrics_nodes_mv_1_week_all_nodes
case _ as unreachable:
assert_never(unreachable)

query = dedent(
f"""
SELECT bucket,
node_host,
cpu_frequency,
cpu_time,
cpu_utilization,
gpu_memory,
gpu_utilization,
page_faults,
memory_rss,
memory_virtual,
disk_read,
disk_write
FROM {view_name}
{where_statement}
AND bucket >= :start_time
AND bucket <= :end_time
ORDER BY bucket
"""
)
query = build_job_metric_aggregation_query(node, sample_rate)
query_params = {
"job_submission_id": job_submission_id,
"start_time": start_time,
Expand Down
102 changes: 101 additions & 1 deletion jobbergate-api/tests/apps/job_submissions/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
"""Core module for testing the helper functions of the job submissions app."""

from math import ceil
from textwrap import dedent

import pytest
from jobbergate_api.apps.job_submissions.helpers import validate_job_metric_upload_input, _force_cast

from jobbergate_api.apps.job_submissions.constants import (
JobSubmissionMetricAggregateNames,
JobSubmissionMetricSampleRate,
)
from jobbergate_api.apps.job_submissions.helpers import (
build_job_metric_aggregation_query,
_force_cast,
validate_job_metric_upload_input,
)


class TestValidateJobMetricUploadInput:
Expand Down Expand Up @@ -174,3 +185,92 @@ def test_force_cast_failure(self):
_force_cast("not a float", float)
with pytest.raises(TypeError):
_force_cast(123, list)

class TestBuildJobMetricAggregationQuery:
    """
    Test suite for the `build_job_metric_aggregation_query` function.

    The function is expected to build a SQL query string based on the provided node
    and sample rate.

    Test Cases:
        - `test_build_query_with_node`: Tests the function with a specific node.
        - `test_build_query_without_node`: Tests the function without a specific node.
        - `test_build_query_invalid_sample_rate`: Tests that the function raises an
          `AssertionError` for a sample rate it does not handle.
    """

    def test_build_query_with_node(self):
        """
        Test the `build_job_metric_aggregation_query` function with a specific node.

        Checks that the per-node view for the sample rate is selected and that the
        `node_host` filter is part of the WHERE clause.
        """
        node = "node1"
        sample_rate = JobSubmissionMetricSampleRate.ten_seconds
        expected_query = dedent(
            f"""
            SELECT bucket,
            node_host,
            cpu_frequency,
            cpu_time,
            cpu_utilization,
            gpu_memory,
            gpu_utilization,
            page_faults,
            memory_rss,
            memory_virtual,
            disk_read,
            disk_write
            FROM {JobSubmissionMetricAggregateNames.metrics_nodes_mv_10_seconds_by_node}
            WHERE job_submission_id = :job_submission_id AND node_host = :node_host
            AND bucket >= :start_time
            AND bucket <= :end_time
            ORDER BY bucket
            """
        )
        result = build_job_metric_aggregation_query(node, sample_rate)
        assert result == expected_query

    def test_build_query_without_node(self):
        """
        Test the `build_job_metric_aggregation_query` function without a specific node.

        Checks that the all-nodes view for the sample rate is selected and that no
        `node_host` filter is added to the WHERE clause.
        """
        node = None
        sample_rate = JobSubmissionMetricSampleRate.one_minute
        expected_query = dedent(
            f"""
            SELECT bucket,
            node_host,
            cpu_frequency,
            cpu_time,
            cpu_utilization,
            gpu_memory,
            gpu_utilization,
            page_faults,
            memory_rss,
            memory_virtual,
            disk_read,
            disk_write
            FROM {JobSubmissionMetricAggregateNames.metrics_nodes_mv_1_minute_all_nodes}
            WHERE job_submission_id = :job_submission_id
            AND bucket >= :start_time
            AND bucket <= :end_time
            ORDER BY bucket
            """
        )
        result = build_job_metric_aggregation_query(node, sample_rate)
        assert result == expected_query

    def test_build_query_invalid_sample_rate(self):
        """
        Test that `build_job_metric_aggregation_query` rejects an invalid sample rate.

        A value that matches none of the handled sample rates reaches the catch-all
        branch of the function, which calls `typing.assert_never`.
        """
        node = "node1"
        sample_rate = "invalid_sample_rate"  # Not a JobSubmissionMetricSampleRate member
        # `assert_never` raises AssertionError at runtime (not TypeError), and nothing
        # earlier in the function raises for a plain string, so that is what to expect.
        with pytest.raises(AssertionError):
            build_job_metric_aggregation_query(node, sample_rate)

0 comments on commit 0c589c6

Please sign in to comment.