diff --git a/test/layer_tests.jl b/test/layer_tests.jl index 5789a64..db80e52 100644 --- a/test/layer_tests.jl +++ b/test/layer_tests.jl @@ -305,6 +305,14 @@ end __f = (x, ps) -> sum(first(pd(x, ps, st))) @test_gradients(__f, x, ps; atol=1.0f-3, rtol=1.0f-3, broken_backends=[AutoTracker()]) + + pd2 = Layers.PositiveDefinite(model, ones(2)) + ps, st = Lux.setup(StableRNG(0), pd2) |> dev + + x0 = ones(Float32, 2) |> aType + y, _ = pd2(x0, ps, st) + + @test all(y .== 0.0f0) end end