Skip to content

Commit

Permalink
fix for b2b
Browse files Browse the repository at this point in the history
  • Loading branch information
CXC2001 committed May 10, 2024
1 parent e6a6648 commit e99851e
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 36 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@ version = "0.1.0"

[deps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MLBase = "f0e99cf1-93fa-52ec-9ecc-5026115318e0"
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Unfold = "181c99d8-e21b-4ff3-b70b-c233eddec679"

[compat]
Expand Down
3 changes: 3 additions & 0 deletions src/UnfoldDecode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ import Unfold.fit
using MLJ
using MultivariateStats
import MLJBase
using MLBase
using DataFrames
using ProgressMeter
using LinearAlgebra
using Logging # to deactivate some MLJ output

# Write your package code here.
Expand Down
74 changes: 38 additions & 36 deletions src/b2b.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,50 @@
# https://github.com/unfoldtoolbox/Unfold.jl/edit/main/src/solver.jl

# Basic implementation of https://doi.org/10.1016/j.neuroimage.2020.117028
solver_b2b(X, data, cross_val_reps) = solver_b2b(X, data, cross_val_reps = cross_val_reps)
function solver_b2b(
X,
data::AbstractArray{T,3};
cross_val_reps = 10,
multithreading = true,
showprogress=true,
) where {T<:Union{Missing,<:Number}}

X, data = dropMissingEpochs(X, data)


E = zeros(size(data, 2), size(X, 2), size(X, 2))
W = Array{Float64}(undef, size(data, 2), size(X, 2), size(data, 1))

prog = Progress(size(data, 2) * cross_val_reps, 0.1;enabled=showprogress)
@maybe_threads multithreading for m = 1:cross_val_reps
k_ix = collect(Kfold(size(data, 3), 2))
X1 = @view X[k_ix[1], :]
X2 = @view X[k_ix[2], :]

for t = 1:size(data, 2)

Y1 = @view data[:, t, k_ix[1]]
Y2 = @view data[:, t, k_ix[2]]


G = (Y1' \ X1)
H = X2 \ (Y2' * G)

E[t, :, :] += Diagonal(H[diagind(H)])
ProgressMeter.next!(prog; showvalues = [(:time, t), (:cross_val_rep, m)])
end
E[t, :, :] = E[t, :, :] ./ cross_val_reps
W[t, :, :] = (X * E[t, :, :])' / data[:, t, :]
function solver_b2b(X,data::AbstractArray{T,3};kwargs...) where {T<:Union{Missing,<:Number}}
X, data = drop_missing_epochs(X, data)
solver_b2b(X,data; kwargs...)
end

function solver_b2b(
X,
data::AbstractArray{T,3};
cross_val_reps = 10,
multithreading = true,
show_progress=true,
) where {T<:Number}

E = zeros(T,size(data, 2), size(X, 2), size(X, 2))
W = Array{T}(undef, size(data, 2), size(X, 2), size(data, 1))

prog = Progress(size(data, 2) * cross_val_reps;dt=0.1,enabled=show_progress)
Unfold.@maybe_threads multithreading for t = 1:size(data, 2)

for m = 1:cross_val_reps
k_ix = collect(Kfold(size(data, 3), 2))
X1 = @view X[k_ix[1], :] # view(X,k_ix[1],:)
X2 = @view X[k_ix[2], :]

Y1 = @view data[:, t, k_ix[1]]
Y2 = @view data[:, t, k_ix[2]]


G = (Y1' \ X1)
H = X2 \ (Y2' * G)

E[t, :, :] += Diagonal(H[diagind(H)])
ProgressMeter.next!(prog; showvalues = [(:time, t), (:cross_val_rep, m)])
end
E[t, :, :] .= E[t, :, :] ./ cross_val_reps
W[t, :, :] .= (X * E[t, :, :])' / data[:, t, :]


end

# extract diagonal
beta = mapslices(diag, E, dims = [2, 3])
# reshape to conform to ch x time x pred
beta = permutedims(beta, [3 1 2])
modelinfo = Dict("W" => W, "E" => E, "cross_val_reps" => cross_val_reps) # no history implemented (yet?)
return LinearModelFit(beta, modelinfo)
return Unfold.LinearModelFit(beta, modelinfo)
end

0 comments on commit e99851e

Please sign in to comment.