Skip to content

Commit

Permalink
First round of optimisations
Browse files Browse the repository at this point in the history
  • Loading branch information
JordiManyer committed Sep 20, 2024
1 parent 95797f9 commit 71744ed
Show file tree
Hide file tree
Showing 18 changed files with 542 additions and 27 deletions.
6 changes: 6 additions & 0 deletions src/Adaptivity/MacroFEs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ function Arrays.get_children(n::TreeNode,a::FineToCoarseArray)
(similar_tree_node(n,a.rrule),similar_tree_node(n,a.fine_data))
end

function Arrays.testitem(a::FineToCoarseArray{<:FineToCoarseField,A,Nothing}) where {T,A}
n_children = num_subcells(a.rrule)
fine_fields = Fill(testitem(testitem(a.fine_data)),n_children)
return FineToCoarseField(fine_fields,a.rrule)
end

@inline function combine_fine_to_coarse_type(
rr::RefinementRule,fine_data::AbstractVector{<:AbstractVector{T}},ids::FineToCoarseIndices
) where T <: Field
Expand Down
8 changes: 8 additions & 0 deletions src/Arrays/Autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ struct AutoDiffMap{F} <: Map
f::F
end

function return_value(k::AutoDiffMap,ydual,x,cfg::ForwardDiff.GradientConfig{T}) where T
return_cache(k,ydual,x,cfg)
end

function return_cache(k::AutoDiffMap,ydual,x,cfg::ForwardDiff.GradientConfig{T}) where T
ydual isa Real || throw(ForwardDiff.GRAD_ERROR)
result = similar(x, ForwardDiff.valtype(ydual))
Expand All @@ -109,6 +113,10 @@ function evaluate!(result,k::AutoDiffMap,ydual,x,cfg::ForwardDiff.GradientConfig
return result
end

function return_value(k::AutoDiffMap,ydual,x,cfg::ForwardDiff.JacobianConfig{T,V,N}) where {T,V,N}
return_cache(k,ydual,x,cfg)
end

function return_cache(k::AutoDiffMap,ydual,x,cfg::ForwardDiff.JacobianConfig{T,V,N}) where {T,V,N}
ydual isa AbstractArray || throw(ForwardDiff.JACOBIAN_ERROR)
result = similar(ydual, ForwardDiff.valtype(eltype(ydual)), length(ydual), N)
Expand Down
25 changes: 25 additions & 0 deletions src/Arrays/CompressedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,25 @@ function lazy_map(::typeof(evaluate),::Type{T},g::Union{CompressedArray,Fill}...
end
end

function lazy_map(::typeof(return_value),::Type{T},g::CompressedArray...) where T
if _have_same_ptrs(g)
_lazy_map_compressed_value(g...)
else
LazyArray(T,g...)
end
end

function lazy_map(::typeof(return_value),::Type{T},g::Union{CompressedArray,Fill}...) where T
g_compressed = _find_compressed_ones(g)
if _have_same_ptrs(g_compressed)
g1 = first(g_compressed)
g_all_compressed = map(gi->_compress(gi,g1),g)
_lazy_map_compressed_value(g_all_compressed...)
else
LazyArray(T,g...)
end
end

function _find_compressed_ones(g)
g_compressed = ( gi for gi in g if isa(gi,CompressedArray) )
g_compressed
Expand All @@ -83,6 +102,12 @@ function _lazy_map_compressed(g::CompressedArray...)
CompressedArray(vals,ptrs)
end

function _lazy_map_compressed_value(g::CompressedArray...)
vals = map(return_value, map(gi->gi.values,g)...)
ptrs = first(g).ptrs
CompressedArray(vals,ptrs)
end

function _have_same_ptrs(g)
g1 = first(g)
all(map( gi -> gi.ptrs === g1.ptrs || gi.ptrs == g1.ptrs, g))
Expand Down
36 changes: 27 additions & 9 deletions src/Arrays/LazyArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ function lazy_map(k,f::AbstractArray...)
lazy_map(k,T,f...)
end

#lazy_map(::typeof(evaluate),k::AbstractArray,f::AbstractArray...) = LazyArray(k,f...)

# This is the function to be overload to specialize on the Map f
"""
lazy_map(f,::Type{T},a::AbstractArray...) where T
Expand Down Expand Up @@ -253,14 +251,20 @@ function _sum_lazy_array(cache,a)
r
end

# function testitem(a::LazyArray{A,T} where A) where T
# if length(a) > 0
# first(a)
# else
# gi = testitem(a.maps)
# fi = map(testitem,a.args)
# return_value(gi,fi...)
# end::T
# end

function testitem(a::LazyArray{A,T} where A) where T
if length(a) > 0
first(a)
else
gi = testitem(a.maps)
fi = map(testitem,a.args)
return_value(gi,fi...)
end::T
gi = testitem(a.maps)
fi = map(testitem,a.args)
return_value(gi,fi...)
end

# Particular implementations for Fill
Expand All @@ -279,6 +283,20 @@ function lazy_map(::typeof(evaluate),::Type{T}, f::Fill, a::Fill...) where T
Fill(r, s)
end

function lazy_map(::typeof(return_value),f::Fill, a::Fill...)
ai = map(ai->ai.value,a)
r = return_value(f.value, ai...)
s = _common_size(f, a...)
Fill(r, s)
end

function lazy_map(::typeof(return_value),::Type{T}, f::Fill, a::Fill...) where T
ai = map(ai->ai.value,a)
r = return_value(f.value, ai...)
s = _common_size(f, a...)
Fill(r, s)
end

function _common_size(a::AbstractArray...)
a1, = a
#@check all(map(ai->length(a1) == length(ai),a)) "Array sizes $(map(size,a)) are not compatible."
Expand Down
5 changes: 5 additions & 0 deletions src/Arrays/Maps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,11 @@ struct OperationMap{K,L} <: Map
end
end

function return_value(c::OperationMap,x...)
lx = map(fi -> return_value(fi,x...),c.l)
return_value(c.k,lx...)
end

function return_cache(c::OperationMap,x...)
cl = map(fi -> return_cache(fi,x...),c.l)
lx = map(fi -> return_value(fi,x...),c.l)
Expand Down
1 change: 1 addition & 0 deletions src/CellData/CellData.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import Gridap.Arrays: lazy_append
import Gridap.Arrays: get_array
import Gridap.Arrays: evaluate!
import Gridap.Arrays: return_cache
import Gridap.Arrays: return_value
import Gridap.Fields: gradient, DIV
import Gridap.Fields: ∇∇
import Gridap.Fields: integrate
Expand Down
20 changes: 14 additions & 6 deletions src/CellData/CellFields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,14 @@ function evaluate!(cache,f::CellField,x::CellPoint)
lazy_map(evaluate,cell_field,cell_point)
end

# This is quite important for optimizing the OperationCellField constructor checks
function return_value(f::CellField,x::CellPoint)
_f, _x = _to_common_domain(f,x)
cell_field = get_data(_f)
cell_point = get_data(_x)
lazy_map(return_value,cell_field,cell_point)
end

function _to_common_domain(f::CellField,x::CellPoint)
trian_f = get_triangulation(f)
trian_x = get_triangulation(x)
Expand Down Expand Up @@ -479,12 +487,12 @@ struct OperationCellField{DS} <: CellField
if num_cells(trian) > 0
@check begin
pts = _get_cell_points(args...)
#x = testitem(get_data(pts))
#f = map(ak -> testitem(get_data(ak)), args)
#fx = map(fk -> return_value(fk,x), f)
#r = Fields.BroadcastingFieldOpMap(op.op)(fx...)
ax = map(i->i(pts),args)
axi = map(first,ax)
# x = testitem(get_data(pts))
# f = map(ak -> testitem(get_data(ak)), args)
# fx = map(fk -> return_value(fk,x), f)
# r = Fields.BroadcastingFieldOpMap(op.op)(fx...)
ax = map(ai->return_value(ai,pts),args)
axi = map(testitem,ax)
r = Fields.BroadcastingFieldOpMap(op.op)(axi...)
true
end
Expand Down
64 changes: 64 additions & 0 deletions src/Fields/ApplyOptimizations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ function lazy_map(
i_to_basis_x = lazy_map(evaluate,i_to_basis,x)
lazy_map(LinearCombinationMap(:),i_to_values,i_to_basis_x)
end
function lazy_map(
::typeof(return_value), a::LazyArray{<:Fill{typeof(linear_combination)}}, x::AbstractArray)

i_to_values = a.args[1]
i_to_basis = a.args[2]
i_to_basis_x = lazy_map(return_value,i_to_basis,x)
lazy_map(LinearCombinationMap(:),i_to_values,i_to_basis_x)
end

# We always keep the parent when transposing (needed for some optimizations below)

Expand All @@ -33,6 +41,12 @@ function lazy_map(
i_to_basis_x = lazy_map(evaluate,a.args[1],x)
lazy_map(TransposeMap(),i_to_basis_x)
end
function lazy_map(
::typeof(return_value), a::LazyArray{<:Fill{typeof(transpose)}}, x::AbstractArray)

i_to_basis_x = lazy_map(return_value,a.args[1],x)
lazy_map(TransposeMap(),i_to_basis_x)
end

# Optimization for
#
Expand All @@ -48,6 +62,15 @@ function lazy_map(
fx = lazy_map(evaluate,f,gx)
fx
end
function lazy_map(
::typeof(return_value), a::LazyArray{<:Fill{typeof(∘)}}, x::AbstractArray)

f = a.args[1]
g = a.args[2]
gx = lazy_map(return_value,g,x)
fx = lazy_map(return_value,f,gx)
fx
end

# Optimization for
#
Expand All @@ -63,6 +86,15 @@ function lazy_map(
fx = lazy_map(evaluate,f,gx)
fx
end
function lazy_map(
::typeof(return_value), a::LazyArray{<:Fill{Broadcasting{typeof(∘)}}}, x::AbstractArray)

f = a.args[1]
g = a.args[2]
gx = lazy_map(return_value,g,x)
fx = lazy_map(return_value,f,gx)
fx
end

# Optimization for
#
Expand All @@ -78,6 +110,15 @@ function lazy_map(
op = a.maps.value.op
lazy_map( Broadcasting(op), fx...)
end
function lazy_map(
::typeof(return_value),
a::LazyArray{<:Fill{<:Operation}},
x::AbstractArray)

fx = map( fi->lazy_map(return_value,fi,x), a.args)
op = a.maps.value.op
lazy_map( Broadcasting(op), fx...)
end

# Optimization for
#
Expand All @@ -91,6 +132,13 @@ function lazy_map(
op = a.maps.value.f.op
lazy_map(BroadcastingFieldOpMap(op),fx...)
end
function lazy_map(
::typeof(return_value), a::LazyArray{<:Fill{<:Broadcasting{<:Operation}}}, x::AbstractArray)

fx = map( fi->lazy_map(return_value,fi,x), a.args)
op = a.maps.value.f.op
lazy_map(BroadcastingFieldOpMap(op),fx...)
end

# Optimization for
#
Expand Down Expand Up @@ -315,6 +363,22 @@ Arrays.getindex!(cache,a::MemoArray{T,N},i::Vararg{Integer,N}) where {T,N} = get
Arrays.testitem(a::MemoArray) = testitem(a.parent)
Arrays.get_array(a::MemoArray) = get_array(a.parent)

function lazy_map(::typeof(return_value),a::MemoArray,x::AbstractArray{<:Point})
key = (:return_value,objectid(x))
if ! haskey(a.memo,key)
a.memo[key] = lazy_map(return_value,a.parent,x)
end
a.memo[key]
end

function lazy_map(::typeof(return_value),a::MemoArray,x::AbstractArray{<:AbstractArray{<:Point}})
key = (:return_value,objectid(x))
if ! haskey(a.memo,key)
a.memo[key] = lazy_map(return_value,a.parent,x)
end
a.memo[key]
end

function lazy_map(::typeof(evaluate),a::MemoArray,x::AbstractArray{<:Point})
key = (:evaluate,objectid(x))
if ! haskey(a.memo,key)
Expand Down
12 changes: 12 additions & 0 deletions src/Fields/ArrayBlocks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,18 @@ function evaluate!(cache,k::Broadcasting{<:Operation},h::Field,f::ArrayBlock)
g
end

function return_value(k::Broadcasting{<:Operation},h::ArrayBlock{A,N},f::ArrayBlock{B,N}) where {A,B,N}
i = findfirst(h.touched)
j = findfirst(f.touched)
@notimplementedif (isnothing(i) || isnothing(j))
ci = return_value(k,h.array[i],f.array[j])
a = Array{typeof(ci),N}(undef,size(f.array))
fill!(a,ci)
touched = Array{Bool,N}(undef,size(f.array))
touched .= f.touched .&& h.touched
ArrayBlock(a,touched)
end

function return_value(k::Broadcasting{<:Operation},h::ArrayBlock,f::ArrayBlock)
evaluate(k,h,f)
end
Expand Down
Loading

0 comments on commit 71744ed

Please sign in to comment.