Skip to content

Commit

Permalink
Merge pull request #47 from unfoldtoolbox/preprocFunc
Browse files Browse the repository at this point in the history
Allow to specify preprocessing function
  • Loading branch information
ReneSkukies authored Jan 18, 2024
2 parents 82e2bdc + 8e7ce94 commit d763c13
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
"""
- removeTimeexpandedXs (true): Removes the timeexpanded designmatrix which significantly reduces the memory-consumption. This Xs is rarely needed, but can be recovered (look into the Unfold.load function)
extractData (function) - specify the function that translate the MNE Raw object to an data array. Default is `rawToData` which uses get_data and allows to pick `channels` - see @Ref(`rawToData`). The optional kw- arguments (e.g. channels) need to be specified directly in the `runUnfold` function as kw-args
"""
function runUnfold(dataDF, eventsDF, bfDict; channels::AbstractVector{<:Union{String, Integer}}=[], eventcolumn="event",removeTimeexpandedXs=true)
function runUnfold(dataDF, eventsDF, bfDict; eventcolumn="event",removeTimeexpandedXs=true, extractData = rawToData,kwargs...)
subjects = unique(dataDF.subject)

resultsDF = DataFrame()
Expand All @@ -16,10 +18,11 @@ function runUnfold(dataDF, eventsDF, bfDict; channels::AbstractVector{<:Union{St
# Get current subject
raw = @subset(dataDF, :subject .== sub).data

tmpData = pyconvert(Array,raw[1].get_data(picks=pylist(channels),units="uV"))

tmpEvents = @subset(eventsDF, :subject .== sub)

tmpData = extractData(raw[1],tmpEvents;kwargs...)


# Fit Model
m = fit(UnfoldModel,bfDict,tmpEvents,tmpData; eventcolumn=eventcolumn);

Expand All @@ -35,6 +38,14 @@ function runUnfold(dataDF, eventsDF, bfDict; channels::AbstractVector{<:Union{St
return resultsDF
end

# Function to run Preprocessing functions on data
function rawToData(raw,tmpEvents;channels::AbstractVector{<:Union{String, Integer}}=[])
return pyconvert(Array,raw.get_data(picks=pylist(channels),units="uV"))
end

# Calculate Grand average; this is likely a TODO
# Commented this out for now as this might go into UnfoldStats; R.S. 18/01/24
#=
function calculateGA(resultsDF; channels=:false)
GA = @chain resultsDF begin
# TODO: check if this works
Expand All @@ -44,7 +55,11 @@ function calculateGA(resultsDF; channels=:false)
# need to check which variables to use
@by([:basisname,:coefname,:time, :channel], :estimate = mean(estimate))
end
end
=#

#=
# Function to run unfold on epoched data
function runUnfold(DataDF, EventsDF, formula, sfreq, τ = (-0.3,1.); channels::Union{Nothing, String, Integer}=nothing)
Expand Down Expand Up @@ -83,4 +98,4 @@ function runUnfold(DataDF, EventsDF, formula, sfreq, τ = (-0.3,1.); channels::U
end
return resultsDF
end
=#
=#

0 comments on commit d763c13

Please sign in to comment.