Skip to content

Commit

Permalink
Writeonly capture fix (EnzymeAD#1616)
Browse files Browse the repository at this point in the history
* Writeonly capture fix

* Update runtests.jl

* Update runtests.jl

* Update runtests.jl

* Update runtests.jl
  • Loading branch information
wsmoses authored Jul 8, 2024
1 parent 4da44dd commit c799584
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 9 deletions.
24 changes: 15 additions & 9 deletions src/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1543,8 +1543,11 @@ function detect_writeonly!(mod::LLVM.Module)
end
for (i, a) in enumerate(parameters(f))
if isa(value_type(a), LLVM.PointerType)
todo = LLVM.Value[a]
seen = Set{LLVM.Value}()
todo = Tuple{LLVM.Value, LLVM.Instruction}[]
for u in LLVM.uses(a)
push!(todo, (a, LLVM.user(u)))
end
seen = Set{Tuple{LLVM.Value, LLVM.Instruction}}()
mayread = false
maywrite = false
while length(todo) > 0
Expand All @@ -1553,20 +1556,23 @@ function detect_writeonly!(mod::LLVM.Module)
continue
end
push!(seen, cur)
curv, curi = cur

if isa(cur, LLVM.StoreInst)
maywrite = true
continue
if isa(curi, LLVM.StoreInst)
if operands(curi)[1] != curv
maywrite = true
continue
end
end

if isa(cur, LLVM.LoadInst)
if isa(curi, LLVM.LoadInst)
mayread = true
continue
end

if isa(cur, LLVM.Argument) || isa(cur, LLVM.GetElementPtrInst) || isa(cur, LLVM.BitCastInst) || isa(cur, LLVM.AddrSpaceCastInst)
for u in LLVM.uses(cur)
push!(todo, LLVM.user(u))
if isa(curi, LLVM.GetElementPtrInst) || isa(curi, LLVM.BitCastInst) || isa(curi, LLVM.AddrSpaceCastInst)
for u in LLVM.uses(curi)
push!(todo, (curi, LLVM.user(u)))
end
continue
end
Expand Down
18 changes: 18 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,24 @@ function euroad(f::T) where T
return g
end

@noinline function womylogpdf(X::AbstractArray{<:Real})
map(womylogpdf, X)
end

function womylogpdf(x::Real)
(x - 2)
end


function wologpdf_test(x)
return womylogpdf(x)
end

@testset "Ensure writeonly deduction combines with capture" begin
res = Enzyme.autodiff(Enzyme.Forward, wologpdf_test, Duplicated([0.5], [0.7]))
@test res[1] [0.7]
end

euroad′(x) = first(autodiff(Reverse, euroad, Active, Active(x)))[1]

@test euroad(0.5) -log(0.5) # -log(1-x)
Expand Down

0 comments on commit c799584

Please sign in to comment.