diff --git a/src/regularizer.jl b/src/regularizer.jl index 070b1494..4ca890fc 100644 --- a/src/regularizer.jl +++ b/src/regularizer.jl @@ -157,7 +157,7 @@ julia> reg([1 2 3; 4 5 6; 7 8 9]) ``` """ function Tikhonov(;num_dims=2, sum_dims=1:num_dims, weights=[1, 1], step=1, mode="laplace") - if weights == nothing + if isnothing(weights) weights = ones(Int, num_dims) end if mode == "laplace" @@ -235,7 +235,7 @@ julia> reg([1 2 3; 4 5 6; 7 8 9]) """ function GR(; num_dims=2, sum_dims=1:num_dims, weights=[1, 1], step=1, mode="forward", ϵ=1f-8) - if weights == nothing + if isnothing(weights) weights = ones(Int, num_dims) end if mode == "central" @@ -273,7 +273,17 @@ indicating over which dimensions we must sum over. """ function generate_TV(num_dims, sum_dims_arr, weights, ind1, ind2, ϵ=1f-8; debug=false) out, add = [], [] - for (d, w) in zip(sum_dims_arr, weights) + # just for now. This should be changed to adapt to the number of dimensions present in `arr` + if isnothing(num_dims) + num_dims=2 + end + if isnothing(sum_dims_arr) + sum_dims_arr = 1:num_dims + end + if isnothing(weights) + weights = ones(Int, num_dims) + end + for (d, w) in zip(sum_dims_arr, weights) inds1, inds2 = generate_indices(num_dims, d, ind1, ind2) push!(add, :($w * abs2(arr[$(inds1...)] - arr[$(inds2...)]))) end @@ -286,6 +296,8 @@ function generate_TV(num_dims, sum_dims_arr, weights, ind1, ind2, ϵ=1f-8; debug return out end +# a hack to find out whether the input arr `arr` is a CuArray, without needing to import CuArray +is_cuda_arr(arr) = startswith("$(typeof(arr))","CuArray") """ TV(; ) @@ -312,12 +324,9 @@ julia> reg([1 2 3; 4 5 6; 7 8 9]) 12.649111f0 ``` """ -function TV(; num_dims=2, sum_dims=1:num_dims, weights=nothing, step=1, mode="forward", ϵ=1f-8) +function TV(; num_dims=nothing, sum_dims=nothing, weights=nothing, step=1, mode="forward", ϵ=1f-8) - if weights == nothing - weights = ones(Int, num_dims) - end - + if mode == "central" total_var = @eval arr -> ($(generate_TV(num_dims, sum_dims, weights, step, (-1) * step, ϵ)...)) @@ -327,7 +336,10 @@ function TV(; num_dims=2, sum_dims=1:num_dims, weights=nothing, step=1, mode="fo else throw(ArgumentError("The provided mode is not valid.")) end - return total_var + + # ToDo: This needs more tweeking such that both methods have the same signature! + total_var_cuda = TV_cuda(num_dims=num_dims, weights=weights, ϵ=ϵ) + return arr -> is_cuda_arr(arr) ? total_var_cuda(arr) : total_var(arr) end diff --git a/src/regularizer_cuda.jl b/src/regularizer_cuda.jl index 1eaf82e5..6c9c791e 100644 --- a/src/regularizer_cuda.jl +++ b/src/regularizer_cuda.jl @@ -4,10 +4,14 @@ f_inds(rs, b) = ntuple(i -> i == b ? rs[i] .+ 1 : rs[i], length(rs)) """ - TV_cuda(; num_dims=2) + TV_cuda(; num_dims=nothing, weights=nothing, ϵ=1f-8) This function returns a function to calculate the Total Variation regularizer of a 2 or 3 dimensional array. -`num_dims` can be either `2` or `3`. + +# Arguments +`num_dims` can be either `2` or `3` or `nothing` in which case the array dimension is assumed upon use. +`weights` specifies the weight along each dimension. By default a weight of one is assumed along each dimension. +`ϵ` specifies a constant which allows to smoothly vary between TV and grad^2 regularization: L = sqrt.(grad^2+ϵ). ```julia-repl julia> using CUDA @@ -18,19 +22,32 @@ julia> reg(CuArray([1 2 3; 4 5 6; 7 8 9])) 12.649111f0 ``` """ -function TV_cuda(; num_dims=2, weights=ones(Float32, num_dims), ϵ=1f-8) - if num_dims == 3 +function TV_cuda(; num_dims=nothing, weights=nothing, ϵ=1f-8) + if isnothing(num_dims) + return arr -> TV_view(arr, weights, ϵ) + elseif num_dims == 3 return arr -> TV_3D_view(arr, weights, ϵ) elseif num_dims == 2 return arr -> TV_2D_view(arr, weights, ϵ) else - throw(ArgumentError("num_dims must be 2 or 3")) + throw(ArgumentError("num_dims must be nothing or 2 or 3 ")) end return reg_TV end -function TV_2D_view(arr::AbstractArray{T, N}, weights, ϵ=1f-8) where {T, N} +function TV_view(arr::AbstractArray{T, 2}, weights=nothing, ϵ=1f-8) where {T} + return TV_2D_view(arr, weights, ϵ) +end + +function TV_view(arr::AbstractArray{T, 3}, weights=nothing, ϵ=1f-8) where {T} + return TV_3D_view(arr, weights, ϵ) +end + +function TV_2D_view(arr::AbstractArray{T, N}, weights=nothing, ϵ=1f-8) where {T, N} + if isnothing(weights) + weights = ones(Float32, ndims(arr)) + end as = ntuple(i -> axes(arr, i), Val(N)) rs = map(x -> first(x):last(x)-1, as) arr0 = view(arr, f_inds(rs, 0)...) @@ -39,7 +56,10 @@ function TV_2D_view(arr::AbstractArray{T, N}, weights, ϵ=1f-8) where {T, N} return @fastmath sum(sqrt.(ϵ .+ weights[1] .* (arr1 .- arr0).^2 .+ weights[2] .* (arr0 .- arr2).^2)) end -function TV_3D_view(arr::AbstractArray{T, N}, weights, ϵ=1f-8) where {T, N} +function TV_3D_view(arr::AbstractArray{T, N}, weights=nothing, ϵ=1f-8) where {T, N} + if isnothing(weights) + weights = ones(Float32, ndims(arr)) + end as = ntuple(i -> axes(arr, i), Val(N)) rs = map(x -> first(x):last(x)-1, as) arr0 = view(arr, f_inds(rs, 0)...)