Skip to content

Commit

Permalink
util: Explicitly call __dlpack__ built-in method in `xp2tensorflo…
Browse files Browse the repository at this point in the history
…w` (#757)

`tf.experimental.dlpack.from_dlpack` expects a `PyCapsule` object.
  • Loading branch information
shadeMe authored Sep 7, 2022
1 parent fba3bf0 commit fc323e1
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion thinc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,8 @@ def xp2tensorflow(
dlpack_tensor = xp_tensor.toDlpack() # type: ignore
tf_tensor = tf.experimental.dlpack.from_dlpack(dlpack_tensor)
elif hasattr(xp_tensor, "__dlpack__"):
tf_tensor = tf.experimental.dlpack.from_dlpack(xp_tensor)
dlpack_tensor = xp_tensor.__dlpack__() # type: ignore
tf_tensor = tf.experimental.dlpack.from_dlpack(dlpack_tensor)
else:
tf_tensor = tf.convert_to_tensor(xp_tensor)
if as_variable:
Expand Down

0 comments on commit fc323e1

Please sign in to comment.