Skip to content

Commit

Permalink
feat: don't unroll Recurrence
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 18, 2025
1 parent 30e7b01 commit 4ae54ef
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.5.1"
version = "1.6.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
5 changes: 4 additions & 1 deletion ext/LuxReactantExt/LuxReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ module LuxReactantExt

using Enzyme: Enzyme, Const, Duplicated, Active
using Optimisers: Optimisers
using Reactant: Reactant, @compile, @code_hlo, AnyTracedRArray, TracedRArray, TracedRNumber
using Reactant: Reactant, @compile, @code_hlo, @trace, AnyTracedRArray, TracedRArray,
TracedRNumber
using Setfield: @set!
using Static: True, False

using Lux: Lux, LuxOps, Training, Utils
using Lux.Training: TrainingBackendCache, ReactantBackend
using LuxCore: LuxCore

Lux.is_extension_loaded(::Val{:Reactant}) = true

Expand All @@ -26,5 +28,6 @@ end

include("patches.jl")
include("training.jl")
include("layers.jl")

end
30 changes: 30 additions & 0 deletions ext/LuxReactantExt/layers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Embedding
function (e::Lux.Embedding)(x::TracedRNumber{<:Reactant.ReactantInt}, ps, st::NamedTuple)
return ps.weight[:, x], st
end

# Recurrent Layers
function (r::Lux.Recurrence{False})(x::AnyTracedRArray, ps, st::NamedTuple)
if r.ordering isa Lux.TimeLastIndex ||
(r.ordering isa Lux.BatchLastIndex && ndims(x) == 2)
idxs = ntuple(Returns(Colon()), ndims(x) - 1)
(out, carry), st = r.cell(x[idxs..., 1], ps, st)
T = size(x, ndims(x))
@trace for i in 2:T
(out, carry), st = r.cell((x[idxs..., i], carry), ps, st)
end
return out, st
elseif r.ordering isa Lux.BatchLastIndex
idxs = ntuple(Returns(Colon()), ndims(x) - 2)
(out, carry), st = r.cell(x[idxs..., 1, :], ps, st)
T = size(x, ndims(x) - 1)
@trace for i in 2:T
(out, carry), st = r.cell((x[idxs..., i, :], carry), ps, st)
end
return out, st
else
error("Unknown ordering: $(r.ordering)")
end
end

# TODO: We need to implement the return sequence version as well
5 changes: 0 additions & 5 deletions ext/LuxReactantExt/patches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,3 @@ Utils.vec(x::AnyTracedRArray) = Reactant.TracedUtils.materialize_traced_array(ve

# XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint
Lux.calculate_pool_dims(g::Lux.GlobalPoolMode, ::TracedRArray) = g

# Embedding
function (e::Lux.Embedding)(x::TracedRNumber{<:Reactant.ReactantInt}, ps, st::NamedTuple)
return ps.weight[:, x], st
end

0 comments on commit 4ae54ef

Please sign in to comment.