Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexKaravaev committed Mar 24, 2024
1 parent ef1822e commit 01d64d6
Show file tree
Hide file tree
Showing 9 changed files with 208 additions and 109 deletions.
2 changes: 1 addition & 1 deletion .isort.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ default_section=THIRDPARTY

known_localfolder=
ciare_world_creator,
known_third_party = aiohttp,aspose,chromadb,click,langchain,lxml,objaverse,openai,pandas,pytest,questionary,requests,tabulate,tinydb,tqdm
known_third_party = aiohttp,aspose,chromadb,click,langchain,lxml,obj2mjcf,objaverse,openai,pandas,pytest,questionary,requests,tabulate,tinydb,tqdm,trimesh
4 changes: 2 additions & 2 deletions ciare_world_creator/collections/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def fill_index(collection, loader: BaseLoader):
f"Generating indicies for chromadb. This might take a while, but it's done only once",
style="bold italic fg:green",
)
models = loader.get_models()

models, _ = loader.get_models()
print(models)
df_models = pd.DataFrame(models)
df_models = df_models.drop_duplicates(subset="name")
df_models["tags"] = df_models["tags"].apply(
Expand Down
76 changes: 40 additions & 36 deletions ciare_world_creator/commands/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,11 @@
from ciare_world_creator.model_databases.fetch_worlds import download_world
from ciare_world_creator.model_databases.gazebo import GazeboLoader
from ciare_world_creator.model_databases.objaverse import ObjaverseLoader
from ciare_world_creator.sim_interfaces.gazebo import GazeboSimInterface
from ciare_world_creator.sim_interfaces.mujoco import MujocoSimInterface
from ciare_world_creator.utils.cache import Cache
from ciare_world_creator.utils.style import STYLE
from ciare_world_creator.xml.worlds import (
add_model_to_xml,
check_world,
find_model,
find_world,
save_xml,
)
from ciare_world_creator.xml.worlds import find_model


@click.command(
Expand All @@ -39,25 +35,26 @@ def cli(ctx):
from ciare_world_creator.llm.model import prompt_model

simulators = ["mujoco", "gazebo"]
chosen_simulator = questionary.select(
message=("Choose simulator to generate world for."),
choices=simulators,
style=STYLE,
).ask()

# chosen_simulator = questionary.select(
# message=("Choose simulator to generate world for."),
# choices=simulators,
# style=STYLE,
# ).ask()
chosen_simulator = "mujoco"
if chosen_simulator == "gazebo":
# Only gazebo is supported
loader = GazeboLoader()
full_models = loader.get_models_full()
full_worlds = loader.get_worlds_full()
interface = GazeboSimInterface()
elif chosen_simulator == "mujoco":
loader = ObjaverseLoader()
interface = MujocoSimInterface()
models, worlds = loader.get_models()

world_query = questionary.text(
"Enter query for world generation(E.g Two cars and person next to it)",
style=STYLE,
).ask()
# world_query = questionary.text(
# "Enter query for world generation(E.g Two cars and person next to it)",
# style=STYLE,
# ).ask()
world_query = "10 cups"
if not world_query:
sys.exit(os.EX_OK)

Expand All @@ -68,7 +65,7 @@ def cli(ctx):
exists = db.search(World.prompt == query)

openai.api_key = os.getenv("OPENAI_API_KEY")
models = openai.Model.list()
llm_models = openai.Model.list()

chosen_model = "gpt-4"
if exists:
Expand All @@ -78,7 +75,7 @@ def cli(ctx):
)
return

model_collection = get_or_create_collection("models_" + chosen_simulator)
model_collection = get_or_create_collection("models_" + chosen_simulator, loader)
try:
claim_query_result = model_collection.query(
query_texts=[query],
Expand All @@ -101,16 +98,18 @@ def cli(ctx):
)
]

generate_world = questionary.confirm(
"Do you want to spawn model in an empty world?"
" Saying no will download world from database, but it's very unstable. Y/n",
style=STYLE,
).ask()
# generate_world = questionary.confirm(
# "Do you want to spawn model in an empty world?"
# " Saying no will download world from database, but it's very unstable. Y/n",
# style=STYLE,
# ).ask()

generate_world = False

if generate_world is None:
sys.exit(os.EX_OK)

if not generate_world:
if generate_world:
content = fmt_world_qa_tmpl.format(context_str=worlds)

questionary.print("Generating world... 🌎", style="bold fg:yellow")
Expand All @@ -120,7 +119,7 @@ def cli(ctx):
f"World is {world['World']}, downloading it", style="bold italic fg:green"
)

full_world = find_world(world["World"], full_worlds)
full_world = interface.find_world(world["World"], worlds)
template_world_path = None
if world["World"] != "None":
template_world_path = download_world(
Expand All @@ -137,7 +136,7 @@ def cli(ctx):
world = {"World": "None"}
template_world_path = os.path.join(cache.worlds_path, "empty.sdf")

if not check_world(template_world_path):
if not interface.check_world(template_world_path):
questionary.print(
"Suggested world is malformed. Falling back to empty world",
style="bold italic fg:red",
Expand All @@ -148,11 +147,12 @@ def cli(ctx):
"Spawning models in the world... 🫖", style="bold italic fg:yellow"
)
content = fmt_model_qa_tmpl.format(context_str=context)
models = prompt_model(content, query, chosen_model)
chosen_models = prompt_model(content, query, chosen_model)

for model in models:
if not find_model(model["Model"], full_models):
models = prompt_model(
print(chosen_models)
for model in chosen_models:
if not find_model(model["Model"], models):
chosen_models = prompt_model(
content,
f"{model} was not found in context list. "
"Generate only the one that are in the context",
Expand All @@ -161,12 +161,14 @@ def cli(ctx):

questionary.print("Placing models in the world... 📍", style="bold italic fg:yellow")
content = fmt_place_qa_tmpl.format(
context_str=f"Arrange following models: {str(models)}",
context_str=f"Arrange following models: {str(chosen_models)}",
world_file=open(template_world_path, "r"),
)

# print(content)
# sys.exit(0)
placement = prompt_model(content, query, chosen_model)

print(placement)
# TODO handle ,.; etc
cleaned_query = re.sub(r'[<>:;.,"/\\|?*]', "", query).strip()
world_name = f'world_{cleaned_query.replace(" ", "_")}'
Expand All @@ -177,9 +179,11 @@ def cli(ctx):
# TODO add asserts on model fields
non_existent_models = []

interface.add_models(placement, models)
sys.exit(0)
for model in placement:
# Example usage
m = find_model(model["Model"], full_models)
m = find_model(model["Model"], models)
if not m:
questionary.print(
f"Model {model} was not found in database. "
Expand Down
14 changes: 14 additions & 0 deletions ciare_world_creator/model_databases/objaverse.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
import json
import os

import objaverse

from ciare_world_creator.model_databases.base import BaseLoader


class ObjaverseLoader(BaseLoader):
def __init__(self):
fp = "./LVIS.json"
if os.path.exists(fp):
# If the file does not exist, create it and dump the JSON data
with open(fp, "r") as file:
cached = json.load(file)
self.annotations = cached[0]
self.uid_to_category = cached[1]
return

lvis_annotations = objaverse.load_lvis_annotations()

truncated_annotations = (
Expand All @@ -19,6 +31,8 @@ def __init__(self):
self.uid_to_category[item] = key

self.annotations = objaverse.load_annotations(self.uid_to_category.keys())
with open(fp, "w") as file:
json.dump([self.annotations, self.uid_to_category], file)

def get_models(self):
only_description_models = []
Expand Down
Empty file.
61 changes: 61 additions & 0 deletions ciare_world_creator/sim_interfaces/gazebo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from lxml import etree
from lxml import etree as ET


class GazeboSimInterface:
def __init__(self):
pass

def check_world(self, template_world_path):
"""Load world and asserts if basic tags are there."""
parser = ET.XMLParser(recover=True, remove_blank_text=True)

tree = etree.parse(template_world_path, parser=parser)

root = tree.getroot()
world_xml = root.find("world")
return world_xml is not None

def add_model_to_xml(
self, model_name, pose_x, pose_y, pose_z, pose_roll, pose_pitch, pose_yaw, uri
):
# Create the new <include> element
include = ET.Element("include")

name = ET.SubElement(include, "name")
name.text = model_name

pose = ET.SubElement(include, "pose")
pose.text = f"{pose_x} {pose_y} {pose_z} {pose_roll} {pose_pitch} {pose_yaw}"

uri_element = ET.SubElement(include, "uri")
uri_element.text = uri

return include

def save_xml(self, xml_file, template_world_path, include_tags):
parser = ET.XMLParser(recover=True, remove_blank_text=True)

tree = etree.parse(template_world_path, parser=parser)

root = tree.getroot()

world_xml = root.find("world")

for include in include_tags:
world_xml.append(include)

# Indent the XML with two spaces
tree_str = ET.tostring(
root,
pretty_print=True,
encoding="utf-8",
xml_declaration=True,
with_tail=True,
)

# parsed_tree = ET.fromstring(tree_str)

# Save the formatted XML to the file
with open(xml_file, "wb") as file:
file.write(tree_str)
62 changes: 62 additions & 0 deletions ciare_world_creator/sim_interfaces/mujoco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import os
from pathlib import Path

import objaverse
import trimesh
from obj2mjcf.cli import Args, process_obj


class MujocoSimInterface:
def __init__(self):
pass

def check_world(self, world):
pass

def generate_world(self):
world = {"World": "None"}
template_world_path = os.path.join(cache.worlds_path, "empty.sdf")

def find_entry_by_name(self, name, full_list):
for entry in full_list:
if entry["name"] == name:
return entry
return None

def add_models(self, placed_models, models):
full_placed_models = []

for model in placed_models:
if model_entry := self.find_entry_by_name(model["Model"], models):
model_entry.update(model)
full_placed_models.append(model_entry)
print(model_entry)
print(full_placed_models)

# model_db_interface.load_models(full_placed_models)
objects = objaverse.load_objects(
uids=[entry["uuid"] for entry in full_placed_models]
)
obj_locs = list(objects.values())

print(obj_locs)

for i in range(len(full_placed_models)):
full_placed_models[i]["model_loc"] = obj_locs[i]

mesh = trimesh.load(obj_locs[i])
# trimesh.exchange.obj.export_obj(mesh)
obj, data = trimesh.exchange.export.export_obj(
mesh, include_texture=True, return_texture=True
)

obj_path = f"./converted/{full_placed_models[i]['uuid']}.obj"
with open(obj_path, "w") as f:
f.write(obj)
# save the MTL and images
for k, v in data.items():
with open(os.path.join("./converted/", k), "wb") as f:
f.write(v)
args = Args("./", save_mjcf=True, compile_model=True, overwrite=True)

process_obj(Path(obj_path), args)
Loading

0 comments on commit 01d64d6

Please sign in to comment.