Skip to content

Commit

Permalink
feat(embed): add unittest for embed weights endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
larisa17 committed Jan 31, 2025
1 parent a47cc10 commit 33b4f46
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions api/embed/test/test_api_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from unittest.mock import patch

from django.test import TestCase
from ninja.testing import TestClient

from embed.api import api_router


class TestGetEmbedWeights(TestCase):
def setUp(self):
self.client = TestClient(api_router)

@patch("embed.api.handle_get_scorer_weights")
def test_get_embed_weights_no_community(self, mock_handle_get_scorer_weights):
mock_handle_get_scorer_weights.return_value = {"weight1": 0.5, "weight2": 1.0}

response = self.client.get("/weights")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"weight1": 0.5, "weight2": 1.0})
mock_handle_get_scorer_weights.assert_called_once_with(None)

@patch("embed.api.handle_get_scorer_weights")
def test_get_embed_weights_with_community(self, mock_handle_get_scorer_weights):
mock_handle_get_scorer_weights.return_value = {"weightA": 0.7, "weightB": 0.3}

response = self.client.get("/weights?community_id=community123")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json(), {"weightA": 0.7, "weightB": 0.3})
mock_handle_get_scorer_weights.assert_called_once_with("community123")

0 comments on commit 33b4f46

Please sign in to comment.