Skip to content

Commit

Permalink
fixup actually worknig
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexKaravaev committed Mar 28, 2024
1 parent 01d64d6 commit 3632bc3
Show file tree
Hide file tree
Showing 4 changed files with 334 additions and 62 deletions.
2 changes: 1 addition & 1 deletion .isort.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ default_section=THIRDPARTY

known_localfolder=
ciare_world_creator,
known_third_party = aiohttp,aspose,chromadb,click,langchain,lxml,obj2mjcf,objaverse,openai,pandas,pytest,questionary,requests,tabulate,tinydb,tqdm,trimesh
known_third_party = aiohttp,chromadb,click,langchain,lxml,numpy,obj2mjcf,objaverse,openai,pandas,pytest,questionary,requests,tabulate,tinydb,tqdm,trimesh
2 changes: 1 addition & 1 deletion ciare_world_creator/commands/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def cli(ctx):
# "Enter query for world generation(E.g Two cars and person next to it)",
# style=STYLE,
# ).ask()
world_query = "10 cups"
world_query = "Surgical room with medical personnel"
if not world_query:
sys.exit(os.EX_OK)

Expand Down
226 changes: 220 additions & 6 deletions ciare_world_creator/sim_interfaces/mujoco.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import os
import shutil
import sys
import xml.etree.ElementTree as ET
from io import StringIO
from pathlib import Path

import numpy as np
import obj2mjcf
import objaverse
import trimesh
from obj2mjcf.cli import Args, process_obj
Expand Down Expand Up @@ -37,26 +43,234 @@ def add_models(self, placed_models, models):
objects = objaverse.load_objects(
uids=[entry["uuid"] for entry in full_placed_models]
)
print(objects)
obj_locs = list(objects.values())

print(obj_locs)
# Step 3: Modify the XML tree
visual_count = 0
collision_count = 0
material_count = 0

for i in range(len(full_placed_models)):
full_placed_models[i]["model_loc"] = obj_locs[i]
main_root = ET.Element("mujoco", model="test")

mesh = trimesh.load(obj_locs[i])
for i, _ in enumerate(full_placed_models):
material_map = {}
print(full_placed_models)
full_placed_models[i]["model_loc"] = objects[full_placed_models[i]["uuid"]]

mesh = trimesh.load(full_placed_models[i]["model_loc"], force="mesh")
print(mesh.extents)
if mesh.extents[0] < 1:
mesh.apply_scale(1.0 / 1.0)
elif mesh.extents[0] < 10:
mesh.apply_scale(1.0 / 10.0)
elif mesh.extents[0] < 100:
mesh.apply_scale(1.0 / 100.0)
elif mesh.extents[0] < 1000:
mesh.apply_scale(1.0 / 1000.0)
elif mesh.extents[0] < 10000:
mesh.apply_scale(1.0 / 10000.0)
print(mesh.extents)
# trimesh.exchange.obj.export_obj(mesh)
obj, data = trimesh.exchange.export.export_obj(
mesh, include_texture=True, return_texture=True
)

obj_path = f"./converted/{full_placed_models[i]['uuid']}.obj"
path = f"./converted/{full_placed_models[i]['uuid']}"
path = os.path.abspath(path)
path = Path(path)
path.mkdir(parents=True, exist_ok=True)
nested_path = Path(
os.path.abspath(str(path) + f"/{full_placed_models[i]['uuid']}")
)
nested_path.mkdir(parents=True, exist_ok=True)

obj_path = f"{path}/{full_placed_models[i]['uuid']}.obj"
with open(obj_path, "w") as f:
f.write(obj)
# save the MTL and images
for k, v in data.items():
with open(os.path.join("./converted/", k), "wb") as f:
with open(os.path.join(path, k), "wb") as f:
f.write(v)
args = Args("./", save_mjcf=True, compile_model=True, overwrite=True)
# List all files in the source directory
files = os.listdir(path)
for file in files:
# Check if the file is a .png or .jpg file
if file.endswith(".png") or file.endswith(".jpg"):
print(file)
# Construct paths for the source and destination
source_path = os.path.join(path, file)
destination_path = os.path.join(nested_path, file)
print(source_path, destination_path)
shutil.copy(source_path, destination_path)
# sys.exit(0)
args = Args(
obj_dir=path,
verbose=True,
save_mjcf=True,
compile_model=True,
overwrite=True,
)

sys.stdout = StringIO()
process_obj(Path(obj_path), args)
printed_output = sys.stdout.getvalue()

# Restore stdout
sys.stdout = sys.__stdout__
if "Error compiling model" in printed_output:
continue

saved_mjc_path = Path(
os.path.abspath(
str(path)
+ f"/{full_placed_models[i]['uuid']}"
+ f"/{full_placed_models[i]['uuid']}.xml"
)
)
tree = ET.parse(saved_mjc_path)
root = tree.getroot()

included_tree = ET.parse(saved_mjc_path)
included_root = included_tree.getroot()

# Step 1: Modify default class attributes
for default in included_root.findall(".//default"):
print(default)
class_attribute = default.get("class")
if class_attribute == "visual":
material_map[class_attribute] = f"visual{visual_count}"
default.set("class", material_map[class_attribute])
visual_count += 1
elif class_attribute == "collision":
material_map[class_attribute] = f"collision{visual_count}"
default.set("class", material_map[class_attribute])
collision_count += 1

# Step 2: Find and modify body tag
for body in included_root.findall(".//body"):
# Add pos and euler attributes
body.set("pos", f"0 {visual_count} 3.025")
body.set("euler", "90 0 0")
# body.append('<joint type="free" />')
joint_tag = ET.SubElement(body, "joint", type="free")

# Step 3: Rewrite material name and references
materials = included_root.findall(".//material")
for i, texture in enumerate(included_root.findall(".//texture")):
old_name = texture.get("name")
material_map[old_name] = f"material_{material_count}"
texture.set("name", material_map[old_name])
material_count += 1

for i, material in enumerate(materials):
old_name = material.get("name")
if old_name not in material_map:
material_map[old_name] = f"material_{material_count}"
material_count += 1
material.set("name", material_map[old_name])

texture = material.get("texture")
if texture:
material.set("texture", material_map[old_name])
for i, geom in enumerate(included_root.findall(".//geom")):
# material.set('name', f'material_{visual_count}')
material = geom.get("material")
if material:
geom.set("material", material_map[material])
class_ = geom.get("class")
if class_:
geom.set("class", material_map[class_])
# Replace include element with modified content
# include.clear()
# include.tag = included_root.tag
# include.attrib = included_root.attrib
# include.extend(included_root)
# Step 4: Write the modified XML to a file
print(saved_mjc_path)
included_tree.write(saved_mjc_path)
# Insert include tags for each filepath
include = ET.SubElement(main_root, "include", file=str(saved_mjc_path))
include.tail = "\n"
# Create the tree
tree = ET.ElementTree(main_root)
asset_xml_elements = [ET.Element("asset"), ET.Element("worldbody")]
asset_elements = [
ET.SubElement(
asset_xml_elements[0],
"texture",
type="skybox",
builtin="gradient",
rgb1=".3 .5 .7",
rgb2="0 0 0",
width="32",
height="512",
),
ET.SubElement(
asset_xml_elements[0],
"texture",
name="body",
type="cube",
builtin="flat",
mark="cross",
width="128",
height="128",
rgb1="0.8 0.6 0.4",
rgb2="0.8 0.6 0.4",
markrgb="1 1 1",
random="0.01",
),
ET.SubElement(
asset_xml_elements[0],
"texture",
name="grid",
type="2d",
builtin="checker",
width="512",
height="512",
rgb1=".1 .2 .3",
rgb2=".2 .3 .4",
),
ET.SubElement(
asset_xml_elements[0],
"material",
name="grid",
texture="grid",
texrepeat="1 1",
texuniform="true",
reflectance=".2",
),
]

# Adding child elements to 'worldbody' element
worldbody_elements = [
ET.SubElement(
asset_xml_elements[1],
"geom",
name="floor",
size="0 0 .05",
type="plane",
material="grid",
condim="3",
),
ET.SubElement(
asset_xml_elements[1],
"light",
name="spot",
mode="fixed",
diffuse=".8 .8 .8",
specular="0.3 0.3 0.3",
pos="0 -6 4",
cutoff="30",
),
]
for elem in asset_elements:
elem.tail = "\n"
for elem in worldbody_elements:
elem.tail = "\n"

# Append the asset and worldbody XML elements to the root
main_root.extend(asset_xml_elements)
# Write to file
tree.write("text.xml", encoding="utf-8", xml_declaration=True)
Loading

0 comments on commit 3632bc3

Please sign in to comment.