Skip to content

Commit

Permalink
feat: allow trajectory run
Browse files Browse the repository at this point in the history
  • Loading branch information
adeprez committed Sep 27, 2024
1 parent f1fcddc commit 16a543a
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 7 deletions.
3 changes: 3 additions & 0 deletions lavague-core/lavague/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,6 @@ def run(self, url: str, objective: str, async_run=False) -> Trajectory:
if not async_run:
trajectory.run_to_completion()
return trajectory

def load(self, run_id: str) -> Trajectory:
return self.client.load_run(run_id)
23 changes: 22 additions & 1 deletion lavague-core/lavague/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from lavague.trajectory import Trajectory
from lavague.trajectory.controller import TrajectoryController
from typing import Any, Optional
from PIL import Image, ImageFile
from io import BytesIO
import requests


Expand Down Expand Up @@ -48,7 +50,11 @@ def create_run(self, url: str, objective: str, step_by_step=False) -> Trajectory
"POST",
{"url": url, "objective": objective, "step_by_step": step_by_step},
)
return Trajectory.from_data(content, self.parser)
return Trajectory.from_data(content, self.parser, self)

def load_run(self, run_id: str) -> Trajectory:
content = self.request_api(f"/runs/{run_id}", "GET")
return Trajectory.from_data(content, self.parser, self)

def next_step(self, run_id: str) -> StepCompletion:
content = self.request_api(
Expand All @@ -63,6 +69,21 @@ def stop_run(self, run_id: str) -> None:
"POST",
)

def get_preaction_screenshot(self, step_id: str) -> ImageFile.ImageFile:
content = self.request_api(f"/steps/{step_id}/screenshot/preaction", "GET")
return Image.open(BytesIO(content))

def get_postaction_screenshot(self, step_id: str) -> ImageFile.ImageFile:
content = self.request_api(f"/steps/{step_id}/screenshot/preaction", "GET")
return Image.open(BytesIO(content))

def get_run_screenshot(self, run_id: str) -> ImageFile.ImageFile:
content = self.request_api(f"/runs/{run_id}/screenshot", "GET")
return Image.open(BytesIO(content))

def get_run_view_url(self, run_id: str) -> str:
return f"{self.api_base_url}/runs/{run_id}/view"


class ApiException(Exception):
pass
9 changes: 3 additions & 6 deletions lavague-core/lavague/trajectory/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from lavague.action import ActionParser, DEFAULT_PARSER
from lavague.trajectory.controller import TrajectoryController
from lavague.trajectory.model import TrajectoryData, RunStatus
from lavague.action import Action


class Trajectory(TrajectoryData):
Expand Down Expand Up @@ -33,15 +34,11 @@ def stop_run(self):
self._controller.stop_run(self.run_id)
self.status = RunStatus.CANCELLED

def iter(self) -> Iterator:
def iter_actions(self) -> Iterator[Action]:
yield from self.actions
while self.is_running:
yield self.next_action()

def __next__(self):
if not self.is_running:
raise StopIteration
return self.next_action()

@classmethod
def from_data(
cls,
Expand Down

0 comments on commit 16a543a

Please sign in to comment.