Skip to content

Commit

Permalink
fix: patch N1081B, adding timeout to connect()
Browse files Browse the repository at this point in the history
- also update tests
  • Loading branch information
furkan-bilgin committed Oct 16, 2024
1 parent c84735d commit e9700c1
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 18 deletions.
15 changes: 13 additions & 2 deletions src/daq/jobs/caen/n1081b.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass

from N1081B import N1081B
from websocket import WebSocket
from websocket import WebSocket, create_connection

from daq.base import DAQJob
from daq.models import DAQJobMessage
Expand All @@ -21,14 +21,25 @@ class DAQJobN1081BConfig(StorableDAQJobConfig):
sections_to_store: list[str]


class N1081BPatched(N1081B):
def __init__(self, ip):
super().__init__(ip)

def connect(self):
self.ws = create_connection(
self.API_ENDPOINT, timeout=N1081B_WEBSOCKET_TIMEOUT_SECONDS
)
return self.ws.connected


class DAQJobN1081B(DAQJob):
config_type = DAQJobN1081BConfig
device: N1081B
config: DAQJobN1081BConfig

def __init__(self, config: DAQJobN1081BConfig):
super().__init__(config)
self.device = N1081B(f"{config.host}:{config.port}?")
self.device = N1081BPatched(f"{config.host}:{config.port}?")

for section in config.sections_to_store:
if section not in N1081B.Section.__members__:
Expand Down
33 changes: 17 additions & 16 deletions src/tests/test_n1081b.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import unittest
from unittest.mock import MagicMock, patch

from N1081B import N1081B
from websocket import WebSocket

from daq.jobs.caen.n1081b import DAQJobN1081B, DAQJobN1081BConfig
from daq.jobs.caen.n1081b import DAQJobN1081B, DAQJobN1081BConfig, N1081BPatched
from daq.store.models import DAQJobMessageStore


Expand All @@ -20,30 +19,30 @@ def setUp(self):
)
self.daq_job = DAQJobN1081B(self.config)

@patch.object(N1081B, "connect", return_value=True)
@patch.object(N1081B, "login", return_value=True)
@patch.object(N1081BPatched, "connect", return_value=True)
@patch.object(N1081BPatched, "login", return_value=True)
def test_connect_to_device_success(self, mock_login, mock_connect):
self.daq_job.device.ws = MagicMock(spec=WebSocket)
self.daq_job._connect_to_device()
mock_connect.assert_called_once()
mock_login.assert_called_once_with("password")
self.assertTrue(isinstance(self.daq_job.device.ws, WebSocket))

@patch.object(N1081B, "connect", return_value=False)
@patch.object(N1081BPatched, "connect", return_value=False)
def test_connect_to_device_failure(self, mock_connect):
with self.assertRaises(Exception) as context:
self.daq_job._connect_to_device()
self.assertTrue("Connection failed" in str(context.exception))

@patch.object(N1081B, "login", return_value=False)
@patch.object(N1081B, "connect", return_value=True)
@patch.object(N1081BPatched, "login", return_value=False)
@patch.object(N1081BPatched, "connect", return_value=True)
def test_login_failure(self, mock_connect, mock_login):
with self.assertRaises(Exception) as context:
self.daq_job._connect_to_device()
self.assertTrue("Login failed" in str(context.exception))

@patch.object(
N1081B,
N1081BPatched,
"get_function_results",
return_value={"data": {"counters": [{"lemo": 1, "value": 100}]}},
)
Expand All @@ -53,7 +52,7 @@ def test_poll_sections(self, mock_get_function_results):
self.daq_job._send_store_message.assert_called()
self.assertEqual(self.daq_job._send_store_message.call_count, 2)

@patch.object(N1081B, "get_function_results", return_value={"data": {}})
@patch.object(N1081BPatched, "get_function_results", return_value={"data": {}})
def test_poll_sections_no_counters(self, mock_get_function_results):
self.daq_job._send_store_message = MagicMock()
with self.assertRaises(Exception) as context:
Expand All @@ -73,7 +72,7 @@ def test_start(
mock_connect_to_device.assert_called_once()
mock_poll_sections.assert_called_once()

@patch.object(N1081B, "get_function_results", return_value=None)
@patch.object(N1081BPatched, "get_function_results", return_value=None)
def test_poll_sections_no_results(self, mock_get_function_results):
self.daq_job._send_store_message = MagicMock()
with self.assertRaises(Exception) as context:
Expand All @@ -82,7 +81,7 @@ def test_poll_sections_no_results(self, mock_get_function_results):
self.daq_job._send_store_message.assert_not_called()

@patch.object(
N1081B,
N1081BPatched,
"get_function_results",
return_value={"data": {"counters": [{"lemo": 1, "value": 100}]}},
)
Expand Down Expand Up @@ -112,8 +111,8 @@ def test_invalid_section_in_config(self):
DAQJobN1081B(invalid_config)
self.assertTrue("Invalid section: INVALID_SECTION" in str(context.exception))

@patch.object(N1081B, "connect", return_value=True)
@patch.object(N1081B, "login", return_value=True)
@patch.object(N1081BPatched, "connect", return_value=True)
@patch.object(N1081BPatched, "login", return_value=True)
def test_connect_to_device_timeout(self, mock_login, mock_connect):
self.daq_job.device.ws = MagicMock(spec=WebSocket)
self.daq_job.device.ws.settimeout = MagicMock(side_effect=Exception("Timeout"))
Expand All @@ -123,16 +122,18 @@ def test_connect_to_device_timeout(self, mock_login, mock_connect):
mock_connect.assert_called_once()
mock_login.assert_called_once_with("password")

@patch.object(N1081B, "get_function_results", side_effect=Exception("Timeout"))
@patch.object(
N1081BPatched, "get_function_results", side_effect=Exception("Timeout")
)
def test_poll_sections_timeout(self, mock_get_function_results):
self.daq_job._send_store_message = MagicMock()
with self.assertRaises(Exception) as context:
self.daq_job._poll_sections()
self.assertTrue("Timeout" in str(context.exception))
self.daq_job._send_store_message.assert_not_called()

@patch.object(N1081B, "connect", return_value=True)
@patch.object(N1081B, "login", return_value=True)
@patch.object(N1081BPatched, "connect", return_value=True)
@patch.object(N1081BPatched, "login", return_value=True)
def test_connect_to_device_no_websocket(self, mock_login, mock_connect):
self.daq_job.device.ws = None # type: ignore
with self.assertRaises(Exception) as context:
Expand Down

0 comments on commit e9700c1

Please sign in to comment.