From 964564102643951bdb892ab07f21f44c9461d2f3 Mon Sep 17 00:00:00 2001 From: Michael Dawson-Haggerty Date: Wed, 4 Dec 2024 14:40:59 -0500 Subject: [PATCH] fix more tests --- tests/test_export.py | 14 ++++---------- tests/test_gltf.py | 7 +++++-- trimesh/exchange/load.py | 11 +++++++++-- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/tests/test_export.py b/tests/test_export.py index e671fc785..ea750297b 100644 --- a/tests/test_export.py +++ b/tests/test_export.py @@ -336,8 +336,6 @@ def test_parse_file_args(self): # it's wordy f = g.trimesh.exchange.load._parse_file_args - RET_COUNT = 5 - # a path that doesn't exist nonexists = f"/banana{g.random()}" assert not g.os.path.exists(nonexists) @@ -348,13 +346,11 @@ def test_parse_file_args(self): # should be able to extract type from passed filename args = f(file_obj=exists, file_type=None) - assert len(args) == RET_COUNT - assert args[1] == "obj" + assert args.file_type == "obj" # should be able to extract correct type from longer name args = f(file_obj=exists, file_type="YOYOMA.oBj") - assert len(args) == RET_COUNT - assert args[1] == "obj" + assert args.file_type == "obj" # with a nonexistent file and no extension it should raise try: @@ -367,15 +363,13 @@ def test_parse_file_args(self): # nonexistent file with extension passed should return # file name anyway, maybe something else can handle it args = f(file_obj=nonexists, file_type=".ObJ") - assert len(args) == RET_COUNT # should have cleaned up case - assert args[1] == "obj" + assert args.file_type == "obj" # make sure overriding type works for string filenames args = f(file_obj=exists, file_type="STL") - assert len(args) == RET_COUNT # should have used manually passed type over .obj - assert args[1] == "stl" + assert args.file_type == "stl" def test_buffered_random(self): """Test writing to non-standard file""" diff --git a/tests/test_gltf.py b/tests/test_gltf.py index c43a5e9e1..c2c78c8e4 100644 --- a/tests/test_gltf.py +++ b/tests/test_gltf.py @@ -53,6 +53,9 @@ def validate_glb(data, name=None): raise ValueError("gltf_validator failed") +load_kwargs = g.trimesh.exchange.load._load_kwargs + + class GLTFTest(g.unittest.TestCase): def test_duck(self): scene = g.get_mesh("Duck.glb", process=False) @@ -196,7 +199,7 @@ def test_units(self): kwargs = g.trimesh.exchange.gltf.load_glb(g.trimesh.util.wrap_as_stream(export)) # roundtrip it - reloaded = g.trimesh.exchange.load.load_kwargs(kwargs) + reloaded = load_kwargs(kwargs) # make basic assertions g.scene_equal(original, reloaded) @@ -264,7 +267,7 @@ def test_merge_buffers(self): assert len(export.keys()) == 2 # reload the export - reloaded = g.trimesh.exchange.load.load_kwargs( + reloaded = load_kwargs( g.trimesh.exchange.gltf.load_gltf( file_obj=None, resolver=g.trimesh.visual.resolvers.ZipResolver(export) ) diff --git a/trimesh/exchange/load.py b/trimesh/exchange/load.py index 0b5c33b4a..6f336da38 100644 --- a/trimesh/exchange/load.py +++ b/trimesh/exchange/load.py @@ -120,9 +120,16 @@ def load( # we are matching deprecated behavior here. # gltf/glb always return a scene - # - file_type = loaded.metadata["file_type"] - if len(loaded.geometry) == 1 and file_type in {"obj", "stl", "ply", "svg", "binvox"}: + if len(loaded.geometry) == 1 and file_type in { + "obj", + "stl", + "ply", + "svg", + "binvox", + "xaml", + "dxf", + }: # matching old behavior, you should probably use `load_scene` return next(iter(loaded.geometry.values()))