Skip to content

Commit

Permalink
Add tests for projectto
Browse files Browse the repository at this point in the history
  • Loading branch information
ptiede committed Aug 23, 2023
1 parent 20f27bd commit fae9e31
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
*.jl.mem
/Manifest.toml
/docs/build/
/test/Manifest.toml
8 changes: 8 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@ function (project::ProjectTo{CoherencyMatrix})(dx::AbstractMatrix)
return CoherencyMatrix(dx, project.basis1, project.basis2)
end

function (project::ProjectTo{CoherencyMatrix})(dx::CoherencyMatrix{B1, B2}) where {B1, B2}
@assert B1() == project.basis1 "First basis does not match in $(typeof(dx)) and $(project.basis1)"
@assert B2() == project.basis2 "Second basis does not match in $(typeof(dx)) and $(project.basis2)"
@assert size(dx) == (2,2) "Issue in Coherency pullback the matrix is not 2x2"
return dx
end



# function ChainRulesCore.rrule(::Type{<:CoherencyMatrix}, e11, e21, e12, e22, basis::NTuple{2, <:PolBasis})
# c = CoherencyMatrix(e11, e21, e12, e22, basis)
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
18 changes: 18 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using PolarizedTypes
using StaticArrays
using ChainRulesCore
using JET
using Test

Expand Down Expand Up @@ -190,4 +192,20 @@ using Test
end

end

@testset "ChainRules" begin
I = 2.0 + 0.5im
Q = rand(ComplexF64) - 0.5
U = rand(ComplexF64) - 0.5
V = rand(ComplexF64) - 0.5
s = StokesParams(I, Q, U, V)
c = CoherencyMatrix{CirBasis, LinBasis}(s)

cmat = SMatrix(c)
prc = ChainRulesCore.ProjectTo(c)
@test prc(cmat) == c
@test prc(c) == c
@test_throws "First basis does" prc(CoherencyMatrix(cmat, LinBasis()))
@test_throws "Second basis does" prc(CoherencyMatrix(cmat, CirBasis()))
end
end

0 comments on commit fae9e31

Please sign in to comment.