Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Regularization #39

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions src/regularizer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ julia> reg([1 2 3; 4 5 6; 7 8 9])
```
"""
function Tikhonov(;num_dims=2, sum_dims=1:num_dims, weights=[1, 1], step=1, mode="laplace")
if weights == nothing
if isnothing(weights)
weights = ones(Int, num_dims)
end
if mode == "laplace"
Expand Down Expand Up @@ -235,7 +235,7 @@ julia> reg([1 2 3; 4 5 6; 7 8 9])
"""
function GR(; num_dims=2, sum_dims=1:num_dims, weights=[1, 1], step=1,
mode="forward", ϵ=1f-8)
if weights == nothing
if isnothing(weights)
weights = ones(Int, num_dims)
end
if mode == "central"
Expand Down Expand Up @@ -273,7 +273,17 @@ indicating over which dimensions we must sum over.
"""
function generate_TV(num_dims, sum_dims_arr, weights, ind1, ind2, ϵ=1f-8; debug=false)
out, add = [], []
for (d, w) in zip(sum_dims_arr, weights)
# just for now. This should be changed to adapt to the number of dimensions present in `arr`
if isnothing(num_dims)
num_dims=2
end
if isnothing(sum_dims_arr)
sum_dims_arr = 1:num_dims
end
if isnothing(weights)
weights = ones(Int, num_dims)
end
for (d, w) in zip(sum_dims_arr, weights)
inds1, inds2 = generate_indices(num_dims, d, ind1, ind2)
push!(add, :($w * abs2(arr[$(inds1...)] - arr[$(inds2...)])))
end
Expand All @@ -286,6 +296,8 @@ function generate_TV(num_dims, sum_dims_arr, weights, ind1, ind2, ϵ=1f-8; debug
return out
end

# a hack to find out whether the input arr `arr` is a CuArray, without needing to import CuArray
is_cuda_arr(arr) = startswith("$(typeof(arr))","CuArray")

"""
TV(; <keyword arguments>)
Expand All @@ -312,12 +324,9 @@ julia> reg([1 2 3; 4 5 6; 7 8 9])
12.649111f0
```
"""
function TV(; num_dims=2, sum_dims=1:num_dims, weights=nothing, step=1, mode="forward", ϵ=1f-8)
function TV(; num_dims=nothing, sum_dims=nothing, weights=nothing, step=1, mode="forward", ϵ=1f-8)

if weights == nothing
weights = ones(Int, num_dims)
end


if mode == "central"
total_var = @eval arr -> ($(generate_TV(num_dims, sum_dims, weights,
step, (-1) * step, ϵ)...))
Expand All @@ -327,7 +336,10 @@ function TV(; num_dims=2, sum_dims=1:num_dims, weights=nothing, step=1, mode="fo
else
throw(ArgumentError("The provided mode is not valid."))
end
return total_var

# ToDo: This needs more tweeking such that both methods have the same signature!
total_var_cuda = TV_cuda(num_dims=num_dims, weights=weights, ϵ=ϵ)
return arr -> is_cuda_arr(arr) ? total_var_cuda(arr) : total_var(arr)
end


Expand Down
34 changes: 27 additions & 7 deletions src/regularizer_cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@ f_inds(rs, b) = ntuple(i -> i == b ? rs[i] .+ 1 : rs[i], length(rs))


"""
TV_cuda(; num_dims=2)
TV_cuda(; num_dims=nothing, weights=nothing, ϵ=1f-8)
This function returns a function to calculate the Total Variation regularizer
of a 2 or 3 dimensional array.
`num_dims` can be either `2` or `3`.

# Arguments
`num_dims` can be either `2` or `3` or `nothing` in which case the array dimension is assumed upon use.
`weights` specifies the weight along each dimension. By default a weight of one is assumed along each dimension.
`ϵ` specifies a constant which allows to smoothly vary between TV and grad^2 regularization: L = sqrt.(grad^2+ϵ).

```julia-repl
julia> using CUDA
Expand All @@ -18,19 +22,32 @@ julia> reg(CuArray([1 2 3; 4 5 6; 7 8 9]))
12.649111f0
```
"""
function TV_cuda(; num_dims=2, weights=ones(Float32, num_dims), ϵ=1f-8)
if num_dims == 3
function TV_cuda(; num_dims=nothing, weights=nothing, ϵ=1f-8)
if isnothing(num_dims)
return arr -> TV_view(arr, weights, ϵ)
elseif num_dims == 3
return arr -> TV_3D_view(arr, weights, ϵ)
elseif num_dims == 2
return arr -> TV_2D_view(arr, weights, ϵ)
else
throw(ArgumentError("num_dims must be 2 or 3"))
throw(ArgumentError("num_dims must be nothing or 2 or 3 "))
end

return reg_TV
end

function TV_2D_view(arr::AbstractArray{T, N}, weights, ϵ=1f-8) where {T, N}
function TV_view(arr::AbstractArray{T, 2}, weights=nothing, ϵ=1f-8) where {T}
return TV_2D_view(arr, weights, ϵ)
end

function TV_view(arr::AbstractArray{T, 3}, weights=nothing, ϵ=1f-8) where {T}
return TV_3D_view(arr, weights, ϵ)
end

function TV_2D_view(arr::AbstractArray{T, N}, weights=nothing, ϵ=1f-8) where {T, N}
if isnothing(weights)
weights = ones(Float32, ndims(arr))
end
as = ntuple(i -> axes(arr, i), Val(N))
rs = map(x -> first(x):last(x)-1, as)
arr0 = view(arr, f_inds(rs, 0)...)
Expand All @@ -39,7 +56,10 @@ function TV_2D_view(arr::AbstractArray{T, N}, weights, ϵ=1f-8) where {T, N}
return @fastmath sum(sqrt.(ϵ .+ weights[1] .* (arr1 .- arr0).^2 .+ weights[2] .* (arr0 .- arr2).^2))
end

function TV_3D_view(arr::AbstractArray{T, N}, weights, ϵ=1f-8) where {T, N}
function TV_3D_view(arr::AbstractArray{T, N}, weights=nothing, ϵ=1f-8) where {T, N}
if isnothing(weights)
weights = ones(Float32, ndims(arr))
end
as = ntuple(i -> axes(arr, i), Val(N))
rs = map(x -> first(x):last(x)-1, as)
arr0 = view(arr, f_inds(rs, 0)...)
Expand Down