Skip to content

Commit

Permalink
fix: update tests and requirements
Browse files Browse the repository at this point in the history
  • Loading branch information
future-xy committed Feb 14, 2025
1 parent 7c24a00 commit 597011a
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v4
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
matplotlib>=3.10
numpy>=1.26.4
pandas>=2.2.3
requests>=2.31.0
seaborn>=0.13.2
124 changes: 124 additions & 0 deletions tests/test_trace_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import unittest
from datetime import datetime
from unittest.mock import MagicMock, mock_open, patch

import pandas as pd

from tracestorm.constants import AZURE_DATASET_PATHS
from tracestorm.trace_generator import (
AzureTraceGenerator,
SyntheticTraceGenerator,
)


class TestSyntheticTraceGenerator(unittest.TestCase):
def test_uniform_distribution(self):
"""Test uniform distribution pattern."""
generator = SyntheticTraceGenerator(
rps=2, pattern="uniform", duration=3
)
result = generator.generate()
expected = [0, 500, 1000, 1500, 2000, 2500]
self.assertEqual(result, expected)

def test_invalid_rps(self):
"""Test invalid RPS value."""
with self.assertRaises(ValueError) as context:
SyntheticTraceGenerator(rps=-1, pattern="uniform", duration=10)
self.assertEqual(
str(context.exception), "rps must be a non-negative integer"
)

def test_invalid_duration(self):
"""Test invalid duration value."""
with self.assertRaises(ValueError) as context:
SyntheticTraceGenerator(rps=1, pattern="uniform", duration=-5)
self.assertEqual(
str(context.exception), "duration must be a non-negative integer"
)

def test_zero_requests(self):
"""Test zero requests case."""
generator = SyntheticTraceGenerator(
rps=0, pattern="uniform", duration=10
)
result = generator.generate()
self.assertEqual(result, [])

def test_invalid_pattern(self):
"""Test invalid pattern."""
generator = SyntheticTraceGenerator(
rps=1, pattern="invalid", duration=10
)
with self.assertRaises(ValueError) as context:
generator.generate()
self.assertEqual(str(context.exception), "Unknown pattern: invalid")


class TestAzureTraceGenerator(unittest.TestCase):
def setUp(self):
self.generator = AzureTraceGenerator(dataset_type="code")

def test_invalid_dataset_type(self):
"""Test initialization with invalid dataset type."""
with self.assertRaises(ValueError) as context:
AzureTraceGenerator(dataset_type="invalid")
self.assertEqual(
str(context.exception),
"Invalid dataset type. Please choose 'code' or 'conv'.",
)

@patch("tracestorm.trace_generator.requests.get")
def test_download_file_success(self, mock_requests):
"""Test successful file download."""
mock_requests.return_value.status_code = 200
mock_requests.return_value.content = b"Test content"

with patch("builtins.open", mock_open()) as mock_file:
self.generator._download_file("test/path.csv", "local.csv")

mock_requests.assert_called_once()
mock_file.assert_called_once_with("local.csv", "wb")
mock_file().write.assert_called_once_with(b"Test content")

@patch("tracestorm.trace_generator.requests.get")
def test_download_file_failure(self, mock_requests):
"""Test file download failure."""
mock_requests.return_value.status_code = 404
mock_requests.return_value.text = "Not Found"

with self.assertRaises(Exception) as context:
self.generator._download_file("test/path.csv", "local.csv")
self.assertIn("Failed to download file: 404", str(context.exception))

def test_process_dataset(self):
"""Test dataset processing."""
# Create test DataFrame
dates = [
datetime(2023, 1, 1, 12, 0, 0),
datetime(2023, 1, 1, 12, 0, 1),
datetime(2023, 1, 1, 12, 0, 2),
]
test_df = pd.DataFrame({"TIMESTAMP": dates})

with patch("pandas.read_csv", return_value=test_df):
result = self.generator._process_dataset("dummy_path.csv")

expected = [0, 1000, 2000] # timestamps in milliseconds
self.assertEqual(result, expected)

@patch("tracestorm.trace_generator.AzureTraceGenerator._download_file")
@patch("tracestorm.trace_generator.AzureTraceGenerator._process_dataset")
def test_generate(self, mock_process, mock_download):
"""Test the generate method."""
mock_process.return_value = [0, 1000, 2000]

result = self.generator.generate()

self.assertEqual(result, [0, 1000, 2000])
mock_download.assert_called_once()
mock_process.assert_called_once()


if __name__ == "__main__":
unittest.main()

0 comments on commit 597011a

Please sign in to comment.