Skip to content

Commit

Permalink
feat: trajectory export features
Browse files Browse the repository at this point in the history
  • Loading branch information
adeprez committed Sep 18, 2024
1 parent d39eca5 commit b834544
Show file tree
Hide file tree
Showing 22 changed files with 874 additions and 283 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from lavague.core.utilities.version_checker import check_latest_version
from lavague.utilities.version_checker import check_latest_version

from lavague.core.agent import WebAgent
from lavague.core.trajectory import Trajectory
from lavague.agent import WebAgent
from lavague.trajectory import Trajectory

import os
import warnings
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from lavague.core.action.base import (
from lavague.action.base import (
Action,
ActionStatus,
ActionParser,
DEFAULT_PARSER,
UnhandledTypeException,
ActionTranslator,
)

from lavague.core.action.navigation import NavigationAction
from lavague.action.navigation import NavigationAction

DEFAULT_PARSER.register("navigation", NavigationAction)
60 changes: 60 additions & 0 deletions lavague-core/lavague/action/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import Dict, Type, Optional, Callable, TypeVar, Self
from pydantic import BaseModel, validate_call
from enum import Enum


class ActionStatus(Enum):
COMPLETED = "completed"
FAILED = "failed"


class Action(BaseModel):
"""Action performed by the agent."""

engine: str
action: str
status: ActionStatus

@classmethod
def parse(cls, action_dict: Dict) -> "Action":
return cls(**action_dict)

@classmethod
def add_translator(cls, name: str, translator: "ActionTranslator[Self]"):
setattr(cls, name, translator)


class ActionParser(BaseModel):
engine_action_builders: Dict[str, Type[Action]]

def __init__(self):
super().__init__(engine_action_builders={})

@validate_call
def register(self, engine: str, action: Type[Action]):
self.engine_action_builders[engine] = action

def unregister(self, engine: str):
if engine in self.engine_action_builders:
del self.engine_action_builders[engine]

def parse(self, action_dict: Dict) -> Action:
engine = action_dict.get("engine", "")
target_type: Type[Action] = self.engine_action_builders.get(engine, Action)
try:
return target_type.parse(action_dict)
except UnhandledTypeException:
return Action.parse(action_dict)


class UnhandledTypeException(Exception):
pass


T = TypeVar("T", bound=Action)


ActionTranslator = Callable[[T], Optional[str]]


DEFAULT_PARSER = ActionParser()
83 changes: 83 additions & 0 deletions lavague-core/lavague/action/navigation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from lavague.action import Action
from typing import ClassVar, Dict, Type, Optional, TypeVar

T = TypeVar("T", bound="NavigationAction")


class NavigationAction(Action):
"""Navigation action performed by the agent."""

subtypes: ClassVar[Dict[str, Type["NavigationAction"]]] = {}

xpath: str
value: Optional[str] = None

@classmethod
def parse(cls, action_dict: Dict) -> "NavigationAction":
action_name = action_dict.get("action", "")
target_type = cls.subtypes.get(action_name, NavigationAction)
return target_type(**action_dict)

@classmethod
def register_subtype(cls, subtype: str, action: Type[T]):
cls.subtypes[subtype] = action
return cls


def register_navigation(name: str):
def wrapper(cls: Type[T]) -> Type[T]:
NavigationAction.register_subtype(name, cls)
return cls

return wrapper


class NavigationWithValueAction(NavigationAction):
"""Navigation action performed by the agent with a value."""

value: str


@register_navigation("click")
class ClickAction(NavigationAction):
pass


@register_navigation("hover")
class HoverAction(NavigationAction):
pass


@register_navigation("setValue")
class SetValueAction(NavigationWithValueAction):
pass


@register_navigation("setValueAndEnter")
class SetValueAndEnterAction(SetValueAction):
pass


@register_navigation("dropdownSelect")
class DropdownSelectAction(NavigationWithValueAction):
pass


@register_navigation("scroll_down")
class ScrollDownAction(NavigationAction):
pass


@register_navigation("scroll_up")
class ScrollUpAction(NavigationAction):
pass


@register_navigation("back")
class BackAction(NavigationAction):
pass


@register_navigation("switch_tab")
class SwitchTabAction(NavigationAction):
pass
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import logging
from pydantic import BaseModel
from typing import Optional
from lavague.core.trajectory import Trajectory
from lavague.core.client import LaVagueClient
from lavague.core.utilities.config import get_config
from lavague.trajectory import Trajectory
from lavague.client import LaVagueClient
from lavague.utilities.config import get_config

logging_print = logging.getLogger(__name__)
logging_print.setLevel(logging.INFO)
Expand Down
27 changes: 27 additions & 0 deletions lavague-core/lavague/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import click
from lavague import WebAgent
import sys
from typing import Optional


def run(url: str, objective: str, file: Optional[str] = None):
agent = WebAgent()
trajectory = agent.run(url, objective)
if file:
trajectory.write_to_file(file)
else:
print(trajectory.model_dump_json(indent=2))


@click.command()
@click.argument("url", required=True)
@click.argument("objective", required=True)
@click.option("--file", "-f", required=False)
def cli_run(url: str, objective: str, file: Optional[str]):
run(url, objective, file)


if __name__ == "__main__":
url = sys.argv[1]
objective = sys.argv[2]
run(url, objective)
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from pydantic import BaseModel
from lavague.core.utilities.config import get_config, LAVAGUE_API_BASE_URL
from lavague.core.action import ActionParser, DEFAULT_PARSER
from lavague.core.trajectory import Trajectory
from pydantic_core import from_json
from lavague.utilities.config import get_config, is_flag_true, LAVAGUE_API_BASE_URL
from lavague.action import ActionParser, DEFAULT_PARSER
from lavague.trajectory import Trajectory
from typing import Any, Optional
import requests

Expand All @@ -14,28 +13,28 @@ class LaVagueClient(BaseModel):

api_base_url: str = get_config("LAVAGUE_API_BASE_URL", LAVAGUE_API_BASE_URL)
api_key: str = get_config("LAVAGUE_API_KEY")
telemetry: bool = is_flag_true("LAVAGUE_TELEMETRY", True)
parser: ActionParser = DEFAULT_PARSER

def request_api(self, endpoint: str, method: str, json: Optional[Any]) -> bytes:
headers = {
"Authorization": f"Bearer {self.api_key}",
}
if not self.telemetry:
headers["DNT"] = "1"
response = requests.request(
method,
f"{self.api_base_url}/{endpoint}",
json=json,
headers={
"Authorization": f"Bearer {self.api_key}",
},
headers=headers,
)
if response.status_code > 299:
raise ApiException(response.text)
return response.content

def run(self, url: str, objective: str) -> Trajectory:
content = self.request_api("/run", "POST", {"url": url, "objective": objective})
result = from_json(content)
result_list = result.get("results", [])
actions = [self.parser.parse(action) for action in result_list]
trajectory = Trajectory(**result, actions=actions)
return trajectory
return Trajectory.from_data(content, self.parser)


class ApiException(Exception):
Expand Down
104 changes: 0 additions & 104 deletions lavague-core/lavague/core/action/base.py

This file was deleted.

Loading

0 comments on commit b834544

Please sign in to comment.