From 6e7fd43997968270d98c6684e779e715ad3b89a9 Mon Sep 17 00:00:00 2001 From: Michael Dawson-Haggerty Date: Tue, 3 Dec 2024 15:36:35 -0500 Subject: [PATCH] fix some metadata passing --- tests/test_loaded.py | 2 +- trimesh/__init__.py | 2 ++ trimesh/exchange/load.py | 58 ++++++++++++++++++++++------------------ 3 files changed, 35 insertions(+), 27 deletions(-) diff --git a/tests/test_loaded.py b/tests/test_loaded.py index 0fe534e46..09105545f 100644 --- a/tests/test_loaded.py +++ b/tests/test_loaded.py @@ -36,7 +36,7 @@ def test_fileobj(self): # check load_mesh file_obj = open(g.os.path.join(g.dir_models, "featuretype.STL"), "rb") assert not file_obj.closed - mesh = g.trimesh.load(file_obj=file_obj, file_type="stl") + mesh = g.trimesh.load_mesh(file_obj=file_obj, file_type="stl") # should have actually loaded the mesh assert len(mesh.faces) == 3476 # should not close the file object diff --git a/trimesh/__init__.py b/trimesh/__init__.py index 40d3b85c3..ec75e39cc 100644 --- a/trimesh/__init__.py +++ b/trimesh/__init__.py @@ -54,6 +54,7 @@ load_mesh, load_path, load_remote, + load_scene, ) # geometry objects @@ -108,6 +109,7 @@ "load_mesh", "load_path", "load_remote", + "load_scene", "nsphere", "path", "permutate", diff --git a/trimesh/exchange/load.py b/trimesh/exchange/load.py index 51282697e..bc5c11b6e 100644 --- a/trimesh/exchange/load.py +++ b/trimesh/exchange/load.py @@ -108,6 +108,7 @@ def load( file_type=file_type, resolver=resolver, allow_remote=allow_remote, + **kwargs, ) # combine a scene into a single mesh @@ -156,6 +157,7 @@ def load_scene( Loaded geometry as trimesh classes """ + # parse all possible values of file objects into simple types arg = _parse_file_args( file_obj=file_obj, file_type=file_type, @@ -201,14 +203,17 @@ def load_scene( arg.file_obj.close() if not isinstance(loaded, Scene): - return Scene(loaded) + loaded = Scene(loaded) + + # add any file path metadata + loaded.metadata.update(arg.metadata) return loaded def load_mesh(*args, **kwargs) -> Trimesh: """ - Load a mesh file into a Trimesh object. + Load a file into a Trimesh object. Parameters ----------- @@ -284,11 +289,11 @@ def _load_compressed(file_obj, file_type=None, resolver=None, mixed=False, **kwa compressed_type = util.split_extension(name).lower() # if file has metadata type include it - if compressed_type in "yaml": + if compressed_type in ("yaml", "yml"): import yaml meta_archive[name] = yaml.safe_load(data) - elif compressed_type in "json": + elif compressed_type == "json": import json meta_archive[name] = json.loads(data) @@ -297,14 +302,18 @@ def _load_compressed(file_obj, file_type=None, resolver=None, mixed=False, **kwa # don't raise an exception, just try the next one continue # store the file name relative to the archive - arg.metadata["file_name"] = archive_name + "/" + os.path.basename(name) + metadata = { + "file_name": os.path.basename(name), + "file_path": os.path.join(archive_name, name), + } + # load the individual geometry geometries.append( load_scene( file_obj=data, file_type=compressed_type, resolver=arg.resolver, - metadata=arg.metadata, + metadata=metadata, **kwargs, ) ) @@ -315,7 +324,7 @@ def _load_compressed(file_obj, file_type=None, resolver=None, mixed=False, **kwa # if we opened the file in this function # clean up after ourselves if arg.was_opened: - file_obj.close() + arg.file_obj.close() # append meshes or scenes into a single Scene object result = append_scenes(geometries) @@ -327,19 +336,19 @@ def _load_compressed(file_obj, file_type=None, resolver=None, mixed=False, **kwa return result -def load_remote(url, **kwargs) -> Scene: +def load_remote(url: str, **kwargs) -> Scene: """ Load a mesh at a remote URL into a local trimesh object. - This must be called explicitly rather than automatically - from trimesh.load to ensure users don't accidentally make - network requests. + This is a thin wrapper around: + `trimesh.load_scene(file_obj=url, allow_remote=True, **kwargs)` Parameters ------------ - url : string + url URL containing mesh file - **kwargs : passed to `load` + **kwargs + Passed to `load_scene` Returns ------------ @@ -354,7 +363,7 @@ def _load_kwargs(*args, **kwargs) -> Geometry: Load geometry from a properly formatted dict or kwargs """ - def handle_scene(): + def handle_scene() -> Scene: """ Load a scene from our kwargs. @@ -407,7 +416,7 @@ def handle_scene(): return scene - def handle_mesh(): + def handle_mesh() -> Trimesh: """ Handle the keyword arguments for a Trimesh object """ @@ -464,13 +473,9 @@ def handle_pointcloud(): for func, expected in handlers: if all(i in kwargs for i in expected): # all expected kwargs exist - handler = func - # exit the loop as we found one - break - else: - raise ValueError(f"unable to determine type: {kwargs.keys()}") + return func() - return handler() + raise ValueError(f"unable to determine type: {kwargs.keys()}") @dataclass @@ -542,6 +547,7 @@ def _parse_file_args( metadata = {} opened = False + file_path = None if "metadata" in kwargs and isinstance(kwargs["metadata"], dict): metadata.update(kwargs["metadata"]) @@ -561,6 +567,7 @@ def _parse_file_args( exists = os.path.isfile(file_path) except BaseException: exists = False + file_path = None # file obj is a string which exists on filesystm if exists: @@ -568,8 +575,6 @@ def _parse_file_args( if resolver is None: resolver = resolvers.FilePathResolver(file_path) # save the file name and path to metadata - metadata["file_path"] = file_path - metadata["file_name"] = os.path.basename(file_obj) # if file_obj is a path that exists use extension as file_type if file_type is None: file_type = util.split_extension(file_path, special=["tar.gz", "tar.bz2"]) @@ -604,15 +609,16 @@ def _parse_file_args( if isinstance(file_type, str) and "." in file_type: # if someone has passed the whole filename as the file_type # use the file extension as the file_type - if "file_path" not in metadata: - metadata["file_path"] = file_type - metadata["file_name"] = os.path.basename(file_type) + file_path = file_type file_type = util.split_extension(file_type) if resolver is None and os.path.exists(file_type): resolver = resolvers.FilePathResolver(file_type) # all our stored extensions reference in lower case file_type = file_type.lower() + if file_path is not None: + metadata["file_path"] = file_path + metadata["file_name"] = os.path.basename(file_path) # if we still have no resolver try using file_obj name if (