Skip to content

Commit

Permalink
Apply override only to local generation, do not propagate overridden …
Browse files Browse the repository at this point in the history
…Runner to child components
  • Loading branch information
pbourke committed Aug 8, 2024
1 parent fe47fbb commit 42e0025
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
11 changes: 7 additions & 4 deletions sammo/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,6 @@ async def _call(
dynamic_context: frozendict | None,
priority: int = 0,
) -> LLMResult:
if self._override_runner is not None:
runner = self._override_runner

y = await self._child(runner, context, dynamic_context)
parents = [y]
if self._history:
Expand All @@ -98,8 +95,14 @@ async def _call(
history = previous_turn.history
else:
history = None

try:
result = await runner.generate_text(
if self._override_runner is not None:
runner_for_generation = self._override_runner
else:
runner_for_generation = runner

result = await runner_for_generation.generate_text(
y.value,
priority=priority,
system_prompt=self._system_prompt,
Expand Down
10 changes: 10 additions & 0 deletions sammo/components_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,13 @@ async def test_override_runner():
res = await res2(runner2, dict())
assert runner2.prompt_log[0] == "I got test1"
assert res.value == "test2"


@pytest.mark.asyncio
async def test_child_runner_not_overridden():
runner1 = MockedRunner("test1")
runner2 = MockedRunner("test2")
res2 = GenerateText(Template("I got {{res1}}", res1=GenerateText("Get test1")), runner=runner2)
res = await res2(runner1, dict())
assert runner2.prompt_log[0] == "I got test1"
assert res.value == "test2"

0 comments on commit 42e0025

Please sign in to comment.