Skip to content

Commit

Permalink
Fix unit test
Browse files Browse the repository at this point in the history
Signed-off-by: Beat Buesser <[email protected]>
  • Loading branch information
beat-buesser committed Dec 24, 2023
1 parent 4abbedb commit 11aac26
Showing 1 changed file with 7 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ def test_loss_gradient(art_warning, get_pytorch_detr):
try:
object_detector, x_test, y_test = get_pytorch_detr

print("x_test[0]")
print(x_test[0])
print("x_test[1]")
print(x_test[1])

grads = object_detector.loss_gradient(x=x_test, y=y_test)

assert grads.shape == (2, 3, 800, 800)
Expand Down Expand Up @@ -141,7 +146,7 @@ def test_loss_gradient(art_warning, get_pytorch_detr):
print("expected_gradients1")
print(grads[0, 0, 10, :32])

np.testing.assert_array_almost_equal(grads[0, 0, 10, :32], expected_gradients1, decimal=1)
np.testing.assert_array_almost_equal(grads[0, 0, 10, :32], expected_gradients1, decimal=4)

expected_gradients2 = np.asarray(
[
Expand Down Expand Up @@ -183,7 +188,7 @@ def test_loss_gradient(art_warning, get_pytorch_detr):
print("expected_gradients2")
print(grads[1, 0, 10, :32])

np.testing.assert_array_almost_equal(grads[1, 0, 10, :32], expected_gradients2, decimal=2)
np.testing.assert_array_almost_equal(grads[1, 0, 10, :32], expected_gradients2, decimal=4)

except ARTTestException as e:
art_warning(e)
Expand Down

0 comments on commit 11aac26

Please sign in to comment.