Commit 8634f0e: test streaming edge cases

mpnowacki-reef committed Jan 30, 2025
1 parent b00cb89
Showing 3 changed files with 62 additions and 25 deletions.
@@ -1,3 +1,4 @@
+import re
import uuid
from collections.abc import Sequence
from datetime import timedelta
@@ -14,11 +15,13 @@
)


-async def check_synthetic_job(job_uuid: uuid.UUID, miner_id: int, status: str, score: float):
+async def check_synthetic_job(job_uuid: uuid.UUID, miner_id: int, status: str, score: float, comment: re.Pattern | None = None):
    job = await SyntheticJob.objects.aget(job_uuid=job_uuid)
    assert job.miner_id == miner_id, f"{job.miner_id} != {miner_id}"
    assert job.status == status, f"{job.status} != {status}"
    assert job.score == score, f"{job.score} != {score}"
+   if comment:
+       assert comment.match(job.comment), f"{job.comment} does not match {comment}"


async def check_miner_job_system_events(
@@ -82,11 +82,13 @@ async def create(self, executor_class: ExecutorClass, **kwargs) -> BaseSynthetic

class LlmPromptsSyntheticJobGeneratorFactory:
    def __init__(
-       self, uuids: list[uuid.UUID], prompt_samples: list[PromptSample], prompts: list[Prompt]
+       self, uuids: list[uuid.UUID], prompt_samples: list[PromptSample], prompts: list[Prompt],
+       streaming: bool = False,
    ):
        self._uuids = uuids
        self._prompt_samples = prompt_samples
        self._prompts = prompts
+       self._streaming = streaming

    async def create(
        self, executor_class: ExecutorClass, *args, **kwargs
@@ -96,7 +98,7 @@ async def create
            expected_prompts=self._prompts,
            s3_url="mock",
            seed=0,
-           streaming=False,
+           streaming=self._streaming,
        )
        generator._uuid = self._uuids.pop(0)
        return generator
@@ -5,8 +5,10 @@
from unittest.mock import patch

import bittensor
+import httpx
import pytest
import pytest_asyncio
+from compute_horde.certificate import generate_certificate_at
from compute_horde.executor_class import DEFAULT_LLM_EXECUTOR_CLASS
from compute_horde.miner_client.base import AbstractTransport
from compute_horde.mv_protocol import miner_requests
@@ -28,6 +30,7 @@

from .helpers import check_miner_job_system_events, check_synthetic_job, generate_prompts
from .mock_generator import NOT_SCORED, LlmPromptsSyntheticJobGeneratorFactory
+from ...synthetic_jobs.generator import llm_prompts

pytestmark = [
    pytest.mark.asyncio,
@@ -94,6 +97,11 @@ def _create(ctx: BatchContext, miner_hotkey: str):
    return _create


+@pytest.fixture
+def ssl_public_key():
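+   # only the public key is used; the test passes it as
+   # V0StreamingJobReadyRequest.public_key below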
+   return generate_certificate_at()[1]


async def test_all_succeed(
    axon_dict: dict[str, bittensor.AxonInfo],
    transports: list[MinerSimulationTransport],
@@ -132,17 +140,8 @@ async def test_all_succeed(
        await check_synthetic_job(job_uuid, miner.pk, SyntheticJob.Status.COMPLETED, 1)


-async def test_all_streaming_succeed(
-   axon_dict: dict[str, bittensor.AxonInfo],
-   transports: list[MinerSimulationTransport],
-   miners: list[Miner],
-   create_simulation_miner_client: Callable,
-   job_uuids: list[uuid.UUID],
-   streaming_manifest_message: str,
-   httpx_mock: HTTPXMock,
-   mocker: MockerFixture,
-   settings,
-):

+async def prep_mocks_for_streaming(mocker: MockerFixture, httpx_mock: HTTPXMock, job_uuids: list[uuid.UUID], settings):
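+   # shared setup for the streaming tests: a generator factory in streaming mode,
+   # shortened processing timeouts, mocked S3 answers, and one deliberately slow
+   # streaming server (see sleepy_request below)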
    prompts, prompt_samples = await generate_prompts(num_miners=len(job_uuids))
    mocker.patch(
        "compute_horde_validator.validator.synthetic_jobs.batch_run.get_streaming_job_executor_classes",
@@ -151,19 +150,46 @@ async def test_all_streaming_succeed(
    mocker.patch(
        "compute_horde_validator.validator.synthetic_jobs.generator.current.synthetic_job_generator_factory",
        LlmPromptsSyntheticJobGeneratorFactory(
-           uuids=job_uuids.copy(), prompt_samples=prompt_samples, prompts=prompts
+           uuids=job_uuids.copy(), prompt_samples=prompt_samples, prompts=prompts, streaming=True,
        ),
    )
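+   # shorten the streaming-answer processing timeouts so the slow server mocked
+   # below exceeds them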
+   mocker.patch.object(llm_prompts, "STREAMING_PROCESSING_TIMEOUT", 1)
+   mocker.patch.object(llm_prompts, "STREAMING_PROCESSING_TIMEOUT_LEEWAY", 0.5)

    httpx_mock.add_response(
        url=re.compile(
            get_public_url(key=".*", bucket_name=settings.S3_BUCKET_NAME_ANSWERS, prefix="solved/")
        ),
        json={p.content: p.answer for p in prompts},
    )

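+   # requests to port 8004 hang for 2s before returning 201, longer than the 1s
+   # timeout plus 0.5s leeway patched above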
+   async def sleepy_request(*_):
+       await asyncio.sleep(2)
+       return httpx.Response(201)

+   httpx_mock.add_callback(sleepy_request, url=re.compile("https://127.0.0.1:8004.*"))


+@pytest.mark.override_config(
+   DYNAMIC_SYNTHETIC_STREAMING_JOB_READY_TIMEOUT=0.5,
+)
+async def test_some_streaming_succeed(
+   axon_dict: dict[str, bittensor.AxonInfo],
+   transports: list[MinerSimulationTransport],
+   miners: list[Miner],
+   create_simulation_miner_client: Callable,
+   job_uuids: list[uuid.UUID],
+   streaming_manifest_message: str,
+   httpx_mock: HTTPXMock,
+   mocker: MockerFixture,
+   ssl_public_key: str,
+   settings,
+):
+   await prep_mocks_for_streaming(mocker, httpx_mock, job_uuids, settings)
    # generator will solve to the right answer
    MOCK_SCORE = 1.0

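+   # ports are handed out sequentially from 8001; assuming the default fixtures,
+   # the second-to-last job lands on the slow port 8004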
+   port = 8000
    for job_uuid, transport in zip(job_uuids, transports):
        await transport.add_message(streaming_manifest_message, send_before=1)

@@ -175,28 +201,34 @@
        ).model_dump_json()
        await transport.add_message(executor_ready_message, send_before=0)

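+       # the last job never reports streaming readiness, so it should fail once
+       # DYNAMIC_SYNTHETIC_STREAMING_JOB_READY_TIMEOUT (0.5s) elapses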
-       streaming_ready_message = miner_requests.V0StreamingJobReadyRequest(
-           job_uuid=str(job_uuid), public_key="123", ip="127.0.0.1", port=8000
-       ).model_dump_json()
-       await transport.add_message(streaming_ready_message, send_before=0)
+       if job_uuid != job_uuids[-1]:
+           streaming_ready_message = miner_requests.V0StreamingJobReadyRequest(
+               job_uuid=str(job_uuid), public_key=ssl_public_key, ip="127.0.0.1", port=(port := port + 1)
+           ).model_dump_json()
+           await transport.add_message(streaming_ready_message, send_before=0)

-       job_finish_message = miner_requests.V0JobFinishedRequest(
-           job_uuid=str(job_uuid), docker_process_stdout="", docker_process_stderr=""
-       ).model_dump_json()
+           job_finish_message = miner_requests.V0JobFinishedRequest(
+               job_uuid=str(job_uuid), docker_process_stdout="", docker_process_stderr=""
+           ).model_dump_json()

-       await transport.add_message(job_finish_message, send_before=2)
+           await transport.add_message(job_finish_message, send_before=2)

    await asyncio.wait_for(
        execute_synthetic_batch_run(
            axon_dict,
            miners,
            create_miner_client=create_simulation_miner_client,
        ),
-       timeout=1,
+       timeout=10,
    )

    for job_uuid, miner in zip(job_uuids, miners):
-       await check_synthetic_job(job_uuid, miner.pk, SyntheticJob.Status.COMPLETED, MOCK_SCORE)
+       if job_uuid == job_uuids[-1]:
+           await check_synthetic_job(job_uuid, miner.pk, SyntheticJob.Status.FAILED, 0)
+       elif job_uuid == job_uuids[-2]:
+           await check_synthetic_job(
+               job_uuid, miner.pk, SyntheticJob.Status.FAILED, 0, re.compile("took too long: time_took_sec=.*")
+           )
+       else:
+           await check_synthetic_job(job_uuid, miner.pk, SyntheticJob.Status.COMPLETED, MOCK_SCORE)


@pytest_asyncio.fixture
