diff --git a/.gitignore b/.gitignore index 20fe29d..7233576 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ *.jl.mem /Manifest.toml /docs/build/ +/test/Manifest.toml diff --git a/src/utils.jl b/src/utils.jl index c9837a9..2a34014 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -4,6 +4,14 @@ function (project::ProjectTo{CoherencyMatrix})(dx::AbstractMatrix) return CoherencyMatrix(dx, project.basis1, project.basis2) end +function (project::ProjectTo{CoherencyMatrix})(dx::CoherencyMatrix{B1, B2}) where {B1, B2} + @assert B1() == project.basis1 "First basis does not match in $(typeof(dx)) and $(project.basis1)" + @assert B2() == project.basis2 "Second basis does not match in $(typeof(dx)) and $(project.basis2)" + @assert size(dx) == (2,2) "Issue in Coherency pullback the matrix is not 2x2" + return dx +end + + # function ChainRulesCore.rrule(::Type{<:CoherencyMatrix}, e11, e21, e12, e22, basis::NTuple{2, <:PolBasis}) # c = CoherencyMatrix(e11, e21, e12, e22, basis) diff --git a/test/Project.toml b/test/Project.toml index c416f8b..41b37d0 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,3 +1,5 @@ [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/runtests.jl b/test/runtests.jl index ed9dd26..a2e3b76 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,6 @@ using PolarizedTypes +using StaticArrays +using ChainRulesCore using JET using Test @@ -190,4 +192,20 @@ using Test end end + + @testset "ChainRules" begin + I = 2.0 + 0.5im + Q = rand(ComplexF64) - 0.5 + U = rand(ComplexF64) - 0.5 + V = rand(ComplexF64) - 0.5 + s = StokesParams(I, Q, U, V) + c = CoherencyMatrix{CirBasis, LinBasis}(s) + + cmat = SMatrix(c) + prc = ChainRulesCore.ProjectTo(c) + @test prc(cmat) == c + @test prc(c) == c + @test_throws "First basis does" prc(CoherencyMatrix(cmat, LinBasis())) + @test_throws "Second basis does" prc(CoherencyMatrix(cmat, CirBasis())) + end end