Skip to content

Commit

Permalink
Implemented Möller-Trumbore intersection algorithm for differentiable…
Browse files Browse the repository at this point in the history
… ray triangle intersection

PiperOrigin-RevId: 551483422
  • Loading branch information
tensorflower-gardener authored and copybara-github committed Oct 18, 2023
1 parent 1b0203e commit e422cc1
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 0 deletions.
72 changes: 72 additions & 0 deletions tensorflow_graphics/geometry/representation/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,5 +451,77 @@ def intersection_ray_sphere(sphere_center,
return intersections_points, normals


def intersection_ray_triangle(
ray_org,
ray_dir,
triangles,
epsilon=1e-8,
name="ray_intersection_ray_triangle",
):
"""Möller-Trumbore intersection algorithm.
Simultaneously computes barycentric coordinates and distance to intersections
of ray to planes defined by triangles. Uses epsilon to detect and ignore
numerically unstable cases, returning all zeros instead. No attempt is made
to ensure that intersections are contained within each triangle.
Note:
In the following, A1 to An are optional batch dimensions.
Args:
ray_org: A tensor of shape `[A1, ..., An, 3]`,
where the last dimension represents the 3D position of the ray origin.
ray_dir: A tensor of shape `[A1, ..., An, 3]`, where
the last dimension represents the normalized 3D direction of the ray.
triangles: A tensor of shape `[A1, ..., An, 3, 3]`, containing batches of
triangles represented using 3 vertices, where the last dimension
represents the 3D position of each vertex.
epsilon: Epsilon value use to detect and ignore degenerate cases.
name: A name for this op that defaults to "ray_intersection_ray_triangle"
Returns:
A tensor of shape `[A1, ..., An, 3]` representing the barycentric
coordinates of each intersection location, and a tensor of shape
`[A1, ..., An]` containing the distance of each ray origin to the
intersection location
"""
with tf.name_scope(name):
ray_org = tf.convert_to_tensor(value=ray_org)
ray_dir = tf.convert_to_tensor(value=ray_dir)
triangles = tf.convert_to_tensor(value=triangles)

shape.check_static(
tensor=ray_org, tensor_name="ray_org", has_dim_equals=(-1, 3))
shape.check_static(
tensor=ray_dir, tensor_name="ray_dir", has_dim_equals=(-1, 3))
shape.check_static(
tensor=triangles,
tensor_name="triangles",
has_dim_equals=[(-2, 3), (-1, 3)],
)

shape.compare_batch_dimensions(
(ray_org, ray_dir, triangles), (-2, -2, -3),
broadcast_compatible=False)

e1 = triangles[..., 1, :] - triangles[..., 0, :]
e2 = triangles[..., 2, :] - triangles[..., 0, :]
s = ray_org - triangles[..., 0, :]
h = tf.linalg.cross(ray_dir, e2)
q = tf.linalg.cross(s, e1)
a = vector.dot(h, e1, keepdims=False)
invalid = tf.abs(a) < epsilon
denom = tf.where(invalid, tf.zeros_like(a), tf.math.divide_no_nan(1.0, a))

t = denom * vector.dot(q, e2, keepdims=False)
b1 = denom * vector.dot(h, s, keepdims=False)
b2 = denom * vector.dot(q, ray_dir, keepdims=False)
b0 = 1 - b1 - b2
barys = tf.stack((b0, b1, b2), axis=-1)
barys = tf.where(invalid[..., tf.newaxis], tf.zeros_like(barys), barys)
t = tf.where(invalid, tf.zeros_like(t), t)
return barys, t


# API contains all public functions and classes.
__all__ = export_api.get_functions_and_classes()
52 changes: 52 additions & 0 deletions tensorflow_graphics/geometry/representation/tests/ray_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,58 @@ def test_intersection_ray_sphere_preset(self, test_inputs, test_outputs):
self.assert_output_is_correct(
ray.intersection_ray_sphere, test_inputs, test_outputs, tile=False)

@flagsaver.flagsaver(tfg_add_asserts_to_graph=False)
def test_intersection_ray_triangle_random(self):
"""Test the intersection_ray_triangle function."""
tensor_size = np.random.randint(3)
tensor_shape = np.random.randint(1, 10, size=(tensor_size)).tolist()
ray_org = np.random.uniform(size=tensor_shape + [3])
ray_dir = np.random.uniform(size=tensor_shape + [3])
ray_dir /= np.linalg.norm(ray_dir, axis=-1, keepdims=True)
triangles = np.random.uniform(size=tensor_shape + [3, 3])

barys, t = ray.intersection_ray_triangle(ray_org, ray_dir, triangles)

intersections_barys = tf.math.reduce_sum(
barys[..., tf.newaxis] * triangles, axis=-2)

intersections_dists = t[..., tf.newaxis] * ray_dir + ray_org

intersections_barys = tf.where(
tf.abs(t) > 0,
intersections_barys,
tf.zeros_like(intersections_barys))

intersections_dists = tf.where(
tf.abs(t) > 0,
intersections_dists,
tf.zeros_like(intersections_dists))

self.assertAllClose(
intersections_barys, intersections_dists, atol=1e-04, rtol=1e-04)

@parameterized.parameters(
(
(
(0.0, 0.0, 0.0),
(0.0, 0.0, 1.0),
((1.0, 0.0, 1.0), (0.0, 1.0, 1.0), (0.0, 0.0, 1.0)),
),
((0.0, 0.0, 1.0), 1.0),
),
(
(
(0.5, 0.5, 0.0),
(0.0, 0.0, 1.0),
((1.0, 0.0, 0.5), (0.0, 1.0, 0.5), (0.0, 0.0, 0.5)),
),
((0.5, 0.5, 0.0), 0.5),
),
)
def test_intersection_ray_triangle_preset(self, test_inputs, test_outputs):
self.assert_output_is_correct(
ray.intersection_ray_triangle, test_inputs, test_outputs, tile=False)


if __name__ == "__main__":
test_case.main()

0 comments on commit e422cc1

Please sign in to comment.