Skip to content

Commit

Permalink
workflow improvements
Browse files Browse the repository at this point in the history
* always stream the script output
* make terminal output (breaks, spaces, etc) more legible
* remove duplicate prints
* always say "Retrying..." on error
  • Loading branch information
granawkins committed Feb 17, 2024
1 parent fe57bda commit 39d9042
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 63 deletions.
44 changes: 19 additions & 25 deletions src/rawdog/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,48 +9,42 @@


def rawdog(prompt: str, config, llm_client):
llm_client.add_message("user", prompt)
leash = config.get("leash")
retries = int(config.get("retries"))
_continue = True
_first = True
while _continue is True:
_continue = False
error, script, output, return_code = "", "", "", 0
try:
if _first:
message, script = llm_client.get_script(prompt, stream=leash)
_first = False
else:
message, script = llm_client.get_script(stream=leash)
if leash:
print(80 * "-")
message, script = llm_client.get_script()
if script:
if leash:
print(f"\n{80 * '-'}")
if (
input("Execute script in markdown block? (Y/n): ")
.strip()
.lower()
== "n"
):
_ok = input(
f"\n{38 * '-'} Execute script in markdown block? (Y/n):"
)
if _ok.strip().lower() == "n":
llm_client.add_message("user", "User chose not to run script")
break
output, error, return_code = execute_script(script, llm_client)
elif message:
elif not leash and message:
print(message)
except KeyboardInterrupt:
break

_continue = (output and output.strip().endswith("CONTINUE")) or (
return_code != 0 and error and retries > 0
)
if error:
retries -= 1
llm_client.add_message("user", f"Error: {error}")
print(f"Error: {error}")
if script and not leash:
print(f"{80 * '-'}\n{script}\n{80 * '-'}")
if output:
llm_client.add_message("user", f"LAST SCRIPT OUTPUT:\n{output}")
if leash or not _continue:
print(output)
if output.endswith("CONTINUE"):
_continue = True
if error:
llm_client.add_message("user", f"Error: {error}")
if return_code != 0:
retries -= 1
if retries > 0:
print("Retrying...\n")
_continue = True


def banner(config):
Expand Down
98 changes: 65 additions & 33 deletions src/rawdog/execute_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,39 +36,71 @@ def install_pip_packages(*packages: str):
)


def execute_script(script: str, llm_client) -> tuple[str, str, int]:
python_executable = get_rawdog_python_executable()
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp_script:
tmp_script_name = tmp_script.name
tmp_script.write(script)
tmp_script.flush()
def _execute_script_in_subprocess(script) -> tuple[str, str, int]:
"""Write script to tempfile, execute from .rawdog/venv, stream and return output"""
output, error, return_code = "", "", 0
try:
python_executable = get_rawdog_python_executable()
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp_script:
tmp_script_name = tmp_script.name
tmp_script.write(script)
tmp_script.flush()

retry = True
while retry:
retry = False
result = subprocess.run(
[python_executable, tmp_script_name], capture_output=True, text=True
process = subprocess.Popen(
[python_executable, tmp_script_name],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
stdin=subprocess.DEVNULL, # Raises EOF error if subprocess asks for input
text=True,
)
output = result.stdout
error = result.stderr
return_code = result.returncode
if error and "ModuleNotFoundError: No module named" in error:
match = re.search(r"No module named '(\w+)'", error)
if match:
module = match.group(1)
module_name = llm_client.get_python_package(module)
if (
input(
f"Rawdog wants to use {module_name}. Install to rawdog's"
" venv with pip? (Y/n): "
)
.strip()
.lower()
!= "n"
):
install_result = install_pip_packages(module_name)
if install_result.returncode == 0:
retry = True
else:
print("Failed to install package")
while True:
_stdout = process.stdout.readline()
_stderr = process.stderr.readline()
if _stdout:
output += _stdout
print(_stdout, end="")
if _stderr:
error += _stderr
print(_stderr, end="", file=sys.stderr)
if _stdout == "" and _stderr == "" and process.poll() is not None:
break
return_code = process.returncode
except Exception as e:
error += str(e)
print(e)
return_code = 1
return output, error, return_code


def _execute_script_with_dependency_resolution(
script, llm_client
) -> tuple[str, str, int]:
retry = True
output, error, return_code = "", "", 0
while retry:
retry = False
output, error, return_code = _execute_script_in_subprocess(script)
if error and "ModuleNotFoundError: No module named" in error:
match = re.search(r"No module named '(\w+)'", error)
if match:
module = match.group(1)
module_name = llm_client.get_python_package(module)
if (
input(
f"Rawdog wants to use {module_name}. Install to rawdog's"
" venv with pip? (Y/n): "
)
.strip()
.lower()
!= "n"
):
install_result = install_pip_packages(module_name)
if install_result.returncode == 0:
retry = True
else:
print("Failed to install package")
return output, error, return_code


def execute_script(script: str, llm_client) -> tuple[str, str, int]:
return _execute_script_with_dependency_resolution(script, llm_client)
6 changes: 2 additions & 4 deletions src/rawdog/llm_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
import os
from textwrap import dedent
from typing import Optional

from litellm import completion, completion_cost

Expand Down Expand Up @@ -75,15 +74,14 @@ def get_python_package(self, import_name: str):

return response.choices[0].message.content

def get_script(self, prompt: Optional[str] = None, stream=False):
if prompt:
self.conversation.append({"role": "user", "content": prompt})
def get_script(self):
messages = self.conversation.copy()

base_url = self.config.get("llm_base_url")
model = self.config.get("llm_model")
temperature = self.config.get("llm_temperature")
custom_llm_provider = self.config.get("llm_custom_provider")
stream = self.config.get("leash")

log = {
"model": model,
Expand Down
2 changes: 1 addition & 1 deletion src/rawdog/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def parse_script(response: str) -> tuple[str, str]:
# Parse delimiter
n_delimiters = response.count("```")
if n_delimiters < 2:
return f"Error: No script found in response:\n{response}", ""
return response, ""
segments = response.split("```")
message = f"{segments[0]}\n{segments[-1]}"
script = "```".join(segments[1:-1]).strip() # Leave 'inner' delimiters alone
Expand Down

0 comments on commit 39d9042

Please sign in to comment.