Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix __eq__ and __ne__ for classes implementing them #707

Merged
merged 1 commit into from
Sep 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 21 additions & 30 deletions src/build123d/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,10 @@ def __repr__(self) -> str:

__str__ = __repr__

def __eq__(self, other: Vector) -> bool: # type: ignore[override]
def __eq__(self, other: object) -> bool:
"""Vectors equal operator =="""
if not isinstance(other, Vector):
return NotImplemented
return self.wrapped.IsEqual(other.wrapped, 0.00001, 0.00001)

def __hash__(self) -> int:
Expand Down Expand Up @@ -670,7 +672,7 @@ def __str__(self) -> str:

def __eq__(self, other: object) -> bool:
if not isinstance(other, Axis):
return False
return NotImplemented
return self.position == other.position and self.direction == other.direction

def located(self, new_location: Location):
Expand Down Expand Up @@ -1468,10 +1470,10 @@ def __mul__(self, other: T) -> T:
def __pow__(self, exponent: int) -> Location:
return Location(self.wrapped.Powered(exponent))

def __eq__(self, other: Location) -> bool:
def __eq__(self, other: object) -> bool:
"""Compare Locations"""
if not isinstance(other, Location):
raise ValueError("other must be a Location")
return NotImplemented
quaternion1 = gp_Quaternion()
quaternion1.SetEulerAngles(
gp_EulerSequence.gp_Intrinsic_XYZ,
Expand Down Expand Up @@ -2139,27 +2141,6 @@ def offset(self, amount: float) -> Plane:
origin=self.origin + self.z_dir * amount, x_dir=self.x_dir, z_dir=self.z_dir
)

def _eq_iter(self, other: Plane):
"""Iterator to successively test equality

Args:
other: Plane to compare to

Returns:
Are planes equal
"""
# equality tolerances
eq_tolerance_origin = 1e-6
eq_tolerance_dot = 1e-6

yield isinstance(other, Plane) # comparison is with another Plane
# origins are the same
yield abs(self._origin - other.origin) < eq_tolerance_origin
# z-axis vectors are parallel (assumption: both are unit vectors)
yield abs(self.z_dir.dot(other.z_dir) - 1) < eq_tolerance_dot
# x-axis vectors are parallel (assumption: both are unit vectors)
yield abs(self.x_dir.dot(other.x_dir) - 1) < eq_tolerance_dot

def __copy__(self) -> Plane:
"""Return copy of self"""
return Plane(gp_Pln(self.wrapped.Position()))
Expand All @@ -2168,13 +2149,23 @@ def __deepcopy__(self, _memo) -> Plane:
"""Return deepcopy of self"""
return Plane(gp_Pln(self.wrapped.Position()))

def __eq__(self, other: Plane):
def __eq__(self, other: object):
"""Are planes equal operator =="""
return all(self._eq_iter(other))
if not isinstance(other, Plane):
return NotImplemented

def __ne__(self, other: Plane):
"""Are planes not equal operator !+"""
return not self.__eq__(other)
# equality tolerances
eq_tolerance_origin = 1e-6
eq_tolerance_dot = 1e-6

return (
# origins are the same
abs(self._origin - other.origin) < eq_tolerance_origin
# z-axis vectors are parallel (assumption: both are unit vectors)
and abs(self.z_dir.dot(other.z_dir) - 1) < eq_tolerance_dot
# x-axis vectors are parallel (assumption: both are unit vectors)
and abs(self.x_dir.dot(other.x_dir) - 1) < eq_tolerance_dot
)

def __neg__(self) -> Plane:
"""Reverse z direction of plane operator -"""
Expand Down
12 changes: 9 additions & 3 deletions src/build123d/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -1972,7 +1972,7 @@ def is_equal(self, other: Shape) -> bool:

def __eq__(self, other) -> bool:
"""Are shapes same operator =="""
return self.is_same(other) if isinstance(other, Shape) else False
return self.is_same(other) if isinstance(other, Shape) else NotImplemented

def is_valid(self) -> bool:
"""Returns True if no defect is detected on the shape S or any of its
Expand Down Expand Up @@ -3704,9 +3704,15 @@ def __or__(self, filter_by: Union[Axis, GeomType] = Axis.Z):
"""Filter by axis or geomtype operator |"""
return self.filter_by(filter_by)

def __eq__(self, other: ShapeList):
def __eq__(self, other: object):
"""ShapeLists equality operator =="""
return set(self) == set(other)
return set(self) == set(other) if isinstance(other, ShapeList) else NotImplemented

# Normally implementing __eq__ is enough, but ShapeList subclasses list,
# which already implements __ne__, so we need to override it, too
def __ne__(self, other: ShapeList):
"""ShapeLists inequality operator !="""
return set(self) != set(other) if isinstance(other, ShapeList) else NotImplemented

def __add__(self, other: ShapeList):
"""Combine two ShapeLists together operator +"""
Expand Down
54 changes: 50 additions & 4 deletions tests/test_direct_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@
RAD2DEG = 180 / math.pi


# Always equal to any other object, to test that __eq__ cooperation is working
class AlwaysEqual:
def __eq__(self, other):
return True


class DirectApiTestCase(unittest.TestCase):
def assertTupleAlmostEquals(
self,
Expand Down Expand Up @@ -363,13 +369,13 @@ def test_axis_equal(self):
self.assertEqual(Axis.X, Axis.X)
self.assertEqual(Axis.Y, Axis.Y)
self.assertEqual(Axis.Z, Axis.Z)
self.assertEqual(Axis.X, AlwaysEqual())

def test_axis_not_equal(self):
self.assertNotEqual(Axis.X, Axis.Y)
random_obj = object()
self.assertNotEqual(Axis.X, random_obj)


class TestBoundBox(DirectApiTestCase):
def test_basic_bounding_box(self):
v = Vertex(1, 1, 1)
Expand Down Expand Up @@ -1730,15 +1736,21 @@ def test_to_axis(self):
self.assertVectorAlmostEquals(axis.position, (1, 2, 3), 6)
self.assertVectorAlmostEquals(axis.direction, (0, 1, 0), 6)

def test_eq(self):
def test_equal(self):
loc = Location((1, 2, 3), (4, 5, 6))
diff_position = Location((10, 20, 30), (4, 5, 6))
diff_orientation = Location((1, 2, 3), (40, 50, 60))
same = Location((1, 2, 3), (4, 5, 6))

self.assertEqual(loc, same)
self.assertEqual(loc, AlwaysEqual())

def test_not_equal(self):
loc = Location((1, 2, 3), (40, 50, 60))
diff_position = Location((3, 2, 1), (40, 50, 60))
diff_orientation = Location((1, 2, 3), (60, 50, 40))

self.assertNotEqual(loc, diff_position)
self.assertNotEqual(loc, diff_orientation)
self.assertNotEqual(loc, object())

def test_neg(self):
loc = Location((1, 2, 3), (0, 35, 127))
Expand Down Expand Up @@ -2666,6 +2678,8 @@ def test_plane_equal(self):
Plane(origin=(0, 0, 0), x_dir=(1, 0, 0), z_dir=(0, 1, 1)),
Plane(origin=(0, 0, 0), x_dir=(1, 0, 0), z_dir=(0, 1, 1)),
)
# __eq__ cooperation
self.assertEqual(Plane.XY, AlwaysEqual())

def test_plane_not_equal(self):
# type difference
Expand Down Expand Up @@ -2955,6 +2969,17 @@ def test_is_equal(self):
box = Solid.make_box(1, 1, 1)
self.assertTrue(box.is_equal(box))

def test_equal(self):
box = Solid.make_box(1, 1, 1)
self.assertEqual(box, box)
self.assertEqual(box, AlwaysEqual())

def test_not_equal(self):
box = Solid.make_box(1, 1, 1)
diff = Solid.make_box(1, 2, 3)
self.assertNotEqual(box, diff)
self.assertNotEqual(box, object())

def test_tessellate(self):
box123 = Solid.make_box(1, 2, 3)
verts, triangles = box123.tessellate(1e-6)
Expand Down Expand Up @@ -3439,6 +3464,20 @@ def test_compound(self):
sl = ShapeList([Box(1, 2, 3), Vertex(1, 1, 1)])
self.assertAlmostEqual(sl.compound().volume, 1 * 2 * 3, 5)

def test_equal(self):
box = Box(1, 1, 1)
cyl = Cylinder(1, 1)
sl = ShapeList([box, cyl])
same = ShapeList([cyl, box])
self.assertEqual(sl, same)
self.assertEqual(sl, AlwaysEqual())

def test_not_equal(self):
sl = ShapeList([Box(1, 1, 1), Cylinder(1, 1)])
diff = ShapeList([Box(1, 1, 1), Box(1, 2, 3)])
self.assertNotEqual(sl, diff)
self.assertNotEqual(sl, object())


class TestShells(DirectApiTestCase):
def test_shell_init(self):
Expand Down Expand Up @@ -3753,6 +3792,13 @@ def test_vector_equals(self):
c = Vector(1, 2, 3.000001)
self.assertEqual(a, b)
self.assertEqual(a, c)
self.assertEqual(a, AlwaysEqual())

def test_vector_not_equal(self):
a = Vector(1, 2, 3)
b = Vector(3, 2, 1)
self.assertNotEqual(a, b)
self.assertNotEqual(a, object())

def test_vector_distance(self):
"""
Expand Down
Loading