Networkx -> Rustworkx Migration #1

Merged · 5 commits · Jun 23, 2024
Changes from all commits
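For context on the title: the central API difference a networkx-to-rustworkx migration has to absorb is that rustworkx identifies nodes by integer indices returned from add_node, rather than by hashable keys. A minimal sketch of the two styles (toy graph, not code from this PR):

import networkx as nx
import rustworkx as rx

# networkx: nodes are identified by hashable keys.
G = nx.DiGraph()
G.add_edge("init", "goal")

# rustworkx: add_node returns an integer index; keep it to add edges.
H = rx.PyDiGraph()
init_idx = H.add_node("init")  # the node payload can be any Python object
goal_idx = H.add_node("goal")
H.add_edge(init_idx, goal_idx, None)  # edges also carry an optional payload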
downward.py: 12 changes (7 additions, 5 deletions)
@@ -1,15 +1,13 @@
 # FastDownward python wrapper

-from typing import Optional, Tuple
-
 import glob
 import os
 import re
 import subprocess
 import tempfile


-def _get_best_plan(plan_filepath: str) -> Tuple[str, float]:
+def _get_best_plan(plan_filepath: str) -> tuple[str, float]:
     best_cost = float("inf")
     best_plan = None

@@ -25,8 +23,12 @@ def _get_best_plan(plan_filepath: str) -> Tuple[str, float]:


 def plan(
-    domain: str, problem: str, downward: str = "downward", alias: str = "lama", **kwargs
-) -> Tuple[Optional[str], int]:
+    domain: str,
+    problem: str,
+    downward: str = "downward",
+    alias: str = "lama",
+    **kwargs,
+) -> tuple[str | None, int]:
     """Find plan using FastDownward.

     Args:
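The annotation changes above modernize the signatures: typing.Tuple and typing.Optional give way to builtin generics (PEP 585, Python 3.9+) and the union operator (PEP 604, Python 3.10+). A before/after sketch of the same return type, just the two spellings side by side:

from typing import Optional, Tuple

# Spelling before the PR:
def plan_legacy(domain: str, problem: str) -> Tuple[Optional[str], int]:
    return None, 0

# Spelling after the PR; no typing import needed (Python 3.10+):
def plan_modern(domain: str, problem: str) -> tuple[str | None, int]:
    return None, 0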
evaluate.py: 14 changes (6 additions, 8 deletions)
@@ -14,7 +14,7 @@
 import tqdm
 import torch

-from planetarium import pddl, graph, metric, oracle
+from planetarium import builder, graph, metric, oracle
 import llm_planner as llmp

 from utils import apply_template
@@ -199,18 +199,18 @@ def result():

     try:
         # try to parse the LLM output
-        llm_problem_graph = pddl.build(llm_problem_pddl)
+        llm_problem_graph = builder.build(llm_problem_pddl)
         parseable = True

         # reduce and further validate the LLM output
-        oracle.reduce(llm_problem_graph.decompose()[0], validate=True)
+        oracle.reduce(llm_problem_graph.decompose()[1], validate=True)
         valid = True

-        problem_graph = pddl.build(problem_pddl)
+        problem_graph = builder.build(problem_pddl)
         init, _ = problem_graph.decompose()

-        if len(llm_problem_graph._constants) != len(problem_graph._constants):
+        if len(llm_problem_graph.constants) != len(problem_graph.constants):
             resolved = True
             return result()
@@ -255,8 +255,8 @@ def full_equivalence(
         bool: True if the scene graphs are equivalent, False otherwise.
     """
     return metric.equals(
-        oracle.fully_specify(source),
-        oracle.fully_specify(target),
+        oracle.fully_specify(source, return_reduced=True),
+        oracle.fully_specify(target, return_reduced=True),
         is_placeholder=is_placeholder,
     )

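With return_reduced=True, full_equivalence now compares the reduced fully specified graphs rather than the expanded ones. A hedged sketch of the comparison flow it implements (file paths hypothetical; in evaluate.py the PDDL strings come from the dataset and the LLM):

from planetarium import builder, metric, oracle

with open("ground_truth.pddl") as f:  # hypothetical input file
    problem_pddl = f.read()
with open("llm_output.pddl") as f:    # hypothetical input file
    llm_problem_pddl = f.read()

# Mirrors full_equivalence after this PR: fully specify (reduced), then compare.
equivalent = metric.equals(
    oracle.fully_specify(builder.build(llm_problem_pddl), return_reduced=True),
    oracle.fully_specify(builder.build(problem_pddl), return_reduced=True),
    is_placeholder=True,
)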
@@ -612,8 +612,6 @@ def main(config_path: str):

     # Get LLM output first
     problems = load_ungenerated_problems(config, config_str, problem_ids)
-    print(config_str)
-    print(len(problems))
     # if len(problems) > 0:
     #     if config["evaluate"]["model"]["type"] == "openai":
     #         generate_openai(problems, config, config_str)
finetune.py: 61 changes (58 additions, 3 deletions)
@@ -1,5 +1,7 @@
 from collections import defaultdict
 from functools import partial
+import os
+import sqlite3
 import yaml

 import dotenv
@@ -10,6 +12,7 @@
 from torch import nn

 import bitsandbytes as bnb
+from datasets import Dataset
 from peft import LoraConfig, get_peft_model
 from transformers import (
     AutoTokenizer,
@@ -23,14 +26,62 @@
 import tqdm as tqdm

 import llm_planner as llmp
-from utils import apply_template, load_dataset, strip
+from utils import apply_template

 from accelerate import Accelerator


 HF_USER_TOKEN = os.getenv("HF_USER_TOKEN")


+def load_dataset(config: dict) -> dict[str, Dataset]:
+    """Load the dataset from the configuration.
+
+    Args:
+        config (dict): The dataset configuration.
+
+    Returns:
+        dict[str, Dataset]: The loaded dataset.
+    """
+    with open(config["splits_path"], "r") as f:
+        split_ids_cfg = yaml.safe_load(f)
+
+    splits: set[str] = config.get("splits", {}).keys()
+    dataset = {split: defaultdict(list) for split in splits}
+
+    # Connect to database
+    conn = sqlite3.connect(config["database_path"])
+    c = conn.cursor()
+
+    # load domains
+    domains = {}
+    c.execute("SELECT name, domain_pddl FROM domains")
+    for domain_name, domain_pddl in c.fetchall():
+        domains[domain_name] = domain_pddl
+
+    # load problems
+    for split in splits:
+        queries = []
+        split_keys: list[str] = config["splits"][split]
+        for split_key in split_keys:
+            split_ids = split_ids_cfg
+            for key in split_key:
+                split_ids = split_ids[key]
+
+            c.execute(
+                f"SELECT domain, problem_pddl, natural_language FROM problems WHERE id in ({', '.join(['?'] * len(split_ids))})",
+                split_ids,
+            )
+            queries.extend(c.fetchall())
+
+        for domain, problem_pddl, natural_language in queries:
+            dataset[split]["domain"].append(domains[domain])
+            dataset[split]["problem"].append(problem_pddl)
+            dataset[split]["natural_language"].append(natural_language)
+
+    return {s: Dataset.from_dict(d, split=s) for s, d in dataset.items()}
+
+
 def find_all_linear_names(
     model: nn.Module,
     bits: int | None = None,
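The new load_dataset builds one Hugging Face Dataset per split by resolving split ids from a YAML file into parameterized SQLite IN queries. A hedged sketch of a config it would accept (paths and YAML keys hypothetical, shaped only by the keys the function reads above):

config = {
    "database_path": "dataset.db",   # hypothetical SQLite file of domains/problems
    "splits_path": "splits.yaml",    # hypothetical YAML file of problem-id lists
    "splits": {
        # each split maps to a list of key paths into the splits YAML
        "train": [["splits", "train"]],
        "test": [["splits", "test"]],
    },
}

datasets = load_dataset(config)
# datasets["train"] is a datasets.Dataset with columns
# "domain", "problem", and "natural_language".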
@@ -62,6 +113,10 @@ def find_all_linear_names(
     return list(lora_module_names)


+def strip(text: str, bos_token: str, eos_token: str) -> str:
+    return text.removeprefix(bos_token) + eos_token
+
+
 def preprocess(
     tokenizer: PreTrainedTokenizer,
     examples,
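The strip helper replaces the version previously imported from utils: it drops the tokenizer's beginning-of-sequence marker and appends the end-of-sequence marker, preparing text as a training target. A minimal illustration (Llama-style token strings assumed):

bos, eos = "<s>", "</s>"
assert strip("<s>(define (problem ...))", bos, eos) == "(define (problem ...))</s>"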
@@ -130,7 +185,7 @@ def load_model(config: dict) -> tuple[PreTrainedTokenizer, PreTrainedModel]:
         )
     else:
         bnb_config = None
-
+    device_index = Accelerator().process_index
     device_map = {"": device_index}
     model = AutoModelForCausalLM.from_pretrained(
@@ -139,7 +194,7 @@ def load_model(config: dict) -> tuple[PreTrainedTokenizer, PreTrainedModel]:
         token=HF_USER_TOKEN,
         torch_dtype=torch.bfloat16,
         quantization_config=bnb_config,
-        device_map=device_map
+        device_map=device_map,
     )

     lora_config = LoraConfig(
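The device_map pattern above pins the entire model to the local Accelerate process index (the empty-string key means the whole module tree), so each data-parallel worker holds a full copy on its own GPU instead of letting from_pretrained shard the model across devices. A minimal sketch of the pattern:

from accelerate import Accelerator

device_index = Accelerator().process_index  # local rank of this worker
device_map = {"": device_index}             # e.g. {"": 0}, {"": 1}, ...
# passed to AutoModelForCausalLM.from_pretrained(..., device_map=device_map)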
File renamed without changes.