From e99851e09a45b95344174fb2b0f8b08a7aa9d444 Mon Sep 17 00:00:00 2001 From: CXC2001 Date: Fri, 10 May 2024 15:36:25 +0200 Subject: [PATCH] fix for b2b --- Project.toml | 3 ++ src/UnfoldDecode.jl | 3 ++ src/b2b.jl | 74 +++++++++++++++++++++++---------------------- 3 files changed, 44 insertions(+), 36 deletions(-) diff --git a/Project.toml b/Project.toml index 6b997b3c..1d218709 100644 --- a/Project.toml +++ b/Project.toml @@ -5,10 +5,13 @@ version = "0.1.0" [deps] DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +MLBase = "f0e99cf1-93fa-52ec-9ecc-5026115318e0" MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411" +ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Unfold = "181c99d8-e21b-4ff3-b70b-c233eddec679" [compat] diff --git a/src/UnfoldDecode.jl b/src/UnfoldDecode.jl index fa06349e..99f1a2b7 100644 --- a/src/UnfoldDecode.jl +++ b/src/UnfoldDecode.jl @@ -6,7 +6,10 @@ import Unfold.fit using MLJ using MultivariateStats import MLJBase +using MLBase using DataFrames +using ProgressMeter +using LinearAlgebra using Logging # to deactivate some MLJ output # Write your package code here. diff --git a/src/b2b.jl b/src/b2b.jl index 4b24005f..a70dc6a3 100644 --- a/src/b2b.jl +++ b/src/b2b.jl @@ -2,42 +2,44 @@ # https://github.com/unfoldtoolbox/Unfold.jl/edit/main/src/solver.jl # Basic implementation of https://doi.org/10.1016/j.neuroimage.2020.117028 -solver_b2b(X, data, cross_val_reps) = solver_b2b(X, data, cross_val_reps = cross_val_reps) -function solver_b2b( - X, - data::AbstractArray{T,3}; - cross_val_reps = 10, - multithreading = true, - showprogress=true, -) where {T<:Union{Missing,<:Number}} - - X, data = dropMissingEpochs(X, data) - - - E = zeros(size(data, 2), size(X, 2), size(X, 2)) - W = Array{Float64}(undef, size(data, 2), size(X, 2), size(data, 1)) - - prog = Progress(size(data, 2) * cross_val_reps, 0.1;enabled=showprogress) - @maybe_threads multithreading for m = 1:cross_val_reps - k_ix = collect(Kfold(size(data, 3), 2)) - X1 = @view X[k_ix[1], :] - X2 = @view X[k_ix[2], :] - - for t = 1:size(data, 2) - - Y1 = @view data[:, t, k_ix[1]] - Y2 = @view data[:, t, k_ix[2]] - - - G = (Y1' \ X1) - H = X2 \ (Y2' * G) - - E[t, :, :] += Diagonal(H[diagind(H)]) - ProgressMeter.next!(prog; showvalues = [(:time, t), (:cross_val_rep, m)]) - end - E[t, :, :] = E[t, :, :] ./ cross_val_reps - W[t, :, :] = (X * E[t, :, :])' / data[:, t, :] +function solver_b2b(X,data::AbstractArray{T,3};kwargs...) where {T<:Union{Missing,<:Number}} + X, data = drop_missing_epochs(X, data) + solver_b2b(X,data; kwargs...) +end +function solver_b2b( + X, + data::AbstractArray{T,3}; + cross_val_reps = 10, + multithreading = true, + show_progress=true, + ) where {T<:Number} + + E = zeros(T,size(data, 2), size(X, 2), size(X, 2)) + W = Array{T}(undef, size(data, 2), size(X, 2), size(data, 1)) + + prog = Progress(size(data, 2) * cross_val_reps;dt=0.1,enabled=show_progress) + Unfold.@maybe_threads multithreading for t = 1:size(data, 2) + + for m = 1:cross_val_reps + k_ix = collect(Kfold(size(data, 3), 2)) + X1 = @view X[k_ix[1], :] # view(X,k_ix[1],:) + X2 = @view X[k_ix[2], :] + + Y1 = @view data[:, t, k_ix[1]] + Y2 = @view data[:, t, k_ix[2]] + + + G = (Y1' \ X1) + H = X2 \ (Y2' * G) + + E[t, :, :] += Diagonal(H[diagind(H)]) + ProgressMeter.next!(prog; showvalues = [(:time, t), (:cross_val_rep, m)]) + end + E[t, :, :] .= E[t, :, :] ./ cross_val_reps + W[t, :, :] .= (X * E[t, :, :])' / data[:, t, :] + + end # extract diagonal @@ -45,5 +47,5 @@ function solver_b2b( # reshape to conform to ch x time x pred beta = permutedims(beta, [3 1 2]) modelinfo = Dict("W" => W, "E" => E, "cross_val_reps" => cross_val_reps) # no history implemented (yet?) - return LinearModelFit(beta, modelinfo) + return Unfold.LinearModelFit(beta, modelinfo) end