Skip to content

Commit

Permalink
tests(agent): cover metrics module
Browse files Browse the repository at this point in the history
  • Loading branch information
rezib committed Oct 23, 2024
1 parent b6523d2 commit c799b08
Show file tree
Hide file tree
Showing 2 changed files with 244 additions and 6 deletions.
204 changes: 198 additions & 6 deletions slurmweb/tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,23 @@
from unittest import mock
import tempfile
import os
import textwrap
import ipaddress

from flask import Blueprint
from rfl.authentication.user import AuthenticatedUser
from prometheus_client.parser import text_string_to_metric_families

from slurmweb.version import get_version
from slurmweb.apps import SlurmwebConfSeed
from slurmweb.apps.agent import SlurmwebAppAgent
from slurmweb.slurmrestd.errors import (
SlurmrestConnectionError,
SlurmrestdNotFoundError,
SlurmrestdInvalidResponseError,
SlurmrestdInternalError,
)
from slurmweb.errors import SlurmwebCacheError

from .utils import (
all_slurm_versions,
Expand Down Expand Up @@ -48,8 +54,9 @@ def __init__(self, **kwargs):
super().__init__("Fake RacksDB web blueprint", __name__)


class TestAgent(unittest.TestCase):
def setUp(self):
class TestAgentBase(unittest.TestCase):

def setup_client(self, additional_conf=None):
# Generate JWT signing key
key = tempfile.NamedTemporaryFile(mode="w+")
key.write("hey")
Expand All @@ -67,7 +74,12 @@ def setUp(self):

# Generate configuration file
conf = tempfile.NamedTemporaryFile(mode="w+")
conf.write(CONF.format(key=key.name, policy_defs=policy_defs, policy=policy))
conf_content = CONF
if additional_conf is not None:
conf_content += additional_conf
conf.write(
conf_content.format(key=key.name, policy_defs=policy_defs, policy=policy)
)
conf.seek(0)

# Configuration definition path
Expand Down Expand Up @@ -104,13 +116,19 @@ def setUp(self):
self.client = self.app.test_client()
self.client.environ_base["HTTP_AUTHORIZATION"] = "Bearer " + token

def mock_slurmrestd_responses(self, slurm_version, assets):
return mock_slurmrestd_responses(self.app.slurmrestd, slurm_version, assets)


class TestAgent(TestAgentBase):

def setUp(self):
self.setup_client()

#
# Generic routes (without slurmrestd requests)
#

def mock_slurmrestd_responses(self, slurm_version, assets):
return mock_slurmrestd_responses(self.app.slurmrestd, slurm_version, assets)

def test_version(self):
response = self.client.get("/version")
self.assertEqual(response.status_code, 200)
Expand Down Expand Up @@ -566,3 +584,177 @@ def test_request_accounts(self, slurm_version):
self.assertEqual(len(response.json), len(accounts_asset))
for idx in range(len(response.json)):
self.assertEqual(response.json[idx]["name"], accounts_asset[idx]["name"])

def test_request_metrics(self):
# Metrics feature is disabled in this test case, check that the corresponding
# endpoint returns HTTP/404 (not found).
response = self.client.get("/metrics")
self.assertEqual(response.status_code, 404)


class TestAgentMetrics(TestAgentBase):

def setUp(self):
self.setup_client(
additional_conf=textwrap.dedent(
"""
[metrics]
enabled=yes
"""
)
)

def tearDown(self):
self.app.metrics.unregister()

@all_slurm_versions
def test_request_metrics(self, slurm_version):
try:
[nodes_asset, jobs_asset] = self.mock_slurmrestd_responses(
slurm_version,
[("slurm-nodes", "nodes"), ("slurm-jobs", "jobs")],
)
except SlurmwebAssetUnavailable:
return
response = self.client.get("/metrics")
self.assertEqual(response.status_code, 200)
families = list(text_string_to_metric_families(response.text))
# Check expected metrics are present
metrics_names = [family.name for family in families]
self.assertCountEqual(
[
"slurm_nodes",
"slurm_nodes_total",
"slurm_cores",
"slurm_cores_total",
"slurm_jobs",
"slurm_jobs_total",
],
metrics_names,
)
# Check some values against assets
for family in families:
if family.name == "slurm_nodes_total":
self.assertEqual(family.samples[0].value, len(nodes_asset))
if family.name == "slurm_jobs_total":
self.assertEqual(family.samples[0].value, len(jobs_asset))

def test_request_metrics_forbidden(self):
# Change restricted list of network allowed to request metrics
self.app.settings.metrics.restrict = [ipaddress.ip_network("192.168.1.0/24")]
with self.assertLogs("slurmweb", level="WARNING") as cm:
response = self.client.get("/metrics")

# Check HTTP/403 is returned with text message. Also check warning message is
# emitted in logs.
self.assertEqual(response.status_code, 403)
self.assertEqual(
response.text, "IP address 127.0.0.1 not authorized to request metrics\n"
)
self.assertEqual(
cm.output,
[
"WARNING:slurmweb.metrics:IP address 127.0.0.1 not authorized to "
"request metrics"
],
)

def test_request_metrics_slurmrest_connection_error(self):
self.app.slurmrestd._request = mock.Mock(
side_effect=SlurmrestConnectionError("connection error")
)
with self.assertLogs("slurmweb", level="ERROR") as cm:
response = self.client.get("/metrics")
# In case of connection error with slurmrestd, metrics WSGI application returns
# HTTP/200 empty response. Check error message is emitted in logs.
self.assertEqual(response.status_code, 200)
self.assertEqual(response.text, "")
self.assertEqual(
cm.output,
[
"ERROR:slurmweb.metrics:Unable to collect metrics due to slurmrestd "
"connection error: connection error"
],
)

def test_request_metrics_slurmrestd_invalid_type(self):
self.app.slurmrestd._request = mock.Mock(
side_effect=SlurmrestdInvalidResponseError("invalid type")
)
with self.assertLogs("slurmweb", level="ERROR") as cm:
response = self.client.get("/metrics")
# In case of invalid response from slurmrestd, metrics WSGI application returns
# HTTP/200 empty response. Check error message is emitted in logs.
self.assertEqual(response.status_code, 200)
self.assertEqual(response.text, "")
self.assertEqual(
cm.output,
[
"ERROR:slurmweb.metrics:Unable to collect metrics due to slurmrestd "
"invalid response: invalid type"
],
)

def test_request_metrics_slurmrestd_internal_error(self):
self.app.slurmrestd._request = mock.Mock(
side_effect=SlurmrestdInternalError(
"slurmrestd fake error",
-1,
"fake error description",
"fake error source",
)
)
with self.assertLogs("slurmweb", level="ERROR") as cm:
response = self.client.get("/metrics")
# In case of slurmrestd internal error, metrics WSGI application returns
# HTTP/200 empty response. Check error message is emitted in logs.
self.assertEqual(response.status_code, 200)
self.assertEqual(response.text, "")
self.assertEqual(
cm.output,
[
"ERROR:slurmweb.metrics:Unable to collect metrics due to slurmrestd "
"internal error: fake error description (fake error source)"
],
)

@all_slurm_versions
def test_request_metrics_slurmrestd_not_found(self, slurm_version):
self.app.slurmrestd._request = mock.Mock(
side_effect=SlurmrestdNotFoundError("/unfound")
)
with self.assertLogs("slurmweb", level="ERROR") as cm:
response = self.client.get("/metrics")
# In case of slurmrestd not found error, metrics WSGI application returns
# HTTP/200 empty response. Check error message is emitted in logs.
self.assertEqual(response.status_code, 200)
self.assertEqual(response.text, "")
self.assertEqual(
cm.output,
[
"ERROR:slurmweb.metrics:Unable to collect metrics due to URL not found "
"on slurmrestd: /unfound"
],
)

@all_slurm_versions
def test_request_metrics_cache_error(self, slurm_version):
# Collector first calls slurmrestd.nodes() then trigger SlurmwebCacheError on
# this method call.
self.app.slurmrestd.nodes = mock.Mock(
side_effect=SlurmwebCacheError("fake error")
)
with self.assertLogs("slurmweb", level="ERROR") as cm:
response = self.client.get("/metrics")
# In case of cache error, metrics WSGI application returns HTTP/200 empty
# response. Check error message is emitted in logs.
self.assertEqual(response.status_code, 200)
self.assertEqual(response.text, "")
self.assertEqual(
cm.output,
[
"ERROR:slurmweb.metrics:Unable to collect metrics due to cache error: "
"fake error"
],
)
46 changes: 46 additions & 0 deletions slurmweb/tests/test_slurmrestd.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,25 @@ def test_jobs(self, slurm_version):
jobs = self.slurmrestd.jobs()
self.assertCountEqual(jobs, asset)

@all_slurm_versions
def test_jobs_states(self, slurm_version):
try:
[asset] = self.mock_slurmrestd_responses(
slurm_version, [("slurm-jobs", "jobs")]
)
except SlurmwebAssetUnavailable:
return

jobs, total = self.slurmrestd.jobs_states()
# Check total value matches the number of jobs in asset
self.assertEqual(total, len(asset))

# Check sum of jobs states matches the total number of jobs
jobs_sum = 0
for value in jobs.values():
jobs_sum += value
self.assertEqual(total, jobs_sum)

@all_slurm_versions
def test_nodes(self, slurm_version):
try:
Expand All @@ -149,6 +168,33 @@ def test_nodes(self, slurm_version):
nodes = self.slurmrestd.nodes()
self.assertCountEqual(nodes, asset)

@all_slurm_versions
def test_nodes_cores_states(self, slurm_version):
try:
[asset] = self.mock_slurmrestd_responses(
slurm_version, [("slurm-nodes", "nodes")]
)
except SlurmwebAssetUnavailable:
return

nodes_states, cores_states, nodes_total, cores_total = (
self.slurmrestd.nodes_cores_states()
)
# Check total number of nodes matches the number of nodes in asset
self.assertEqual(nodes_total, len(asset))

# Check sum of nodes states matches the total number of nodes
nodes_sum = 0
for value in nodes_states.values():
nodes_sum += value
self.assertEqual(nodes_total, nodes_sum)

# Check sum of cores states matches the total number of cores
cores_sum = 0
for value in cores_states.values():
cores_sum += value
self.assertEqual(cores_total, cores_sum)

@all_slurm_versions
def test_node(self, slurm_version):
# We can use slurm-node-allocated asset for this test.
Expand Down

0 comments on commit c799b08

Please sign in to comment.