From c5a3ab9760f6f2b8ee495a0a7a0c4471aa86f555 Mon Sep 17 00:00:00 2001 From: Dhruva2 Date: Mon, 24 Jul 2023 19:23:39 +0100 Subject: [PATCH] fixed bugs --- src/BasicUpdates.jl | 133 ---------------------------- src/InteractionSpecification.jl | 7 +- src/Particles.jl | 150 -------------------------------- src/SimulationHelper.jl | 17 ---- src/SimulationStatistics.jl | 67 -------------- 5 files changed, 3 insertions(+), 371 deletions(-) delete mode 100644 src/BasicUpdates.jl delete mode 100644 src/Particles.jl delete mode 100644 src/SimulationStatistics.jl diff --git a/src/BasicUpdates.jl b/src/BasicUpdates.jl deleted file mode 100644 index 3d9fd2a..0000000 --- a/src/BasicUpdates.jl +++ /dev/null @@ -1,133 +0,0 @@ -""" -State update and measurement structs for general use. - Add specific methods for entity subtypes -""" - -mutable struct StateUpdate{L} <: Interaction - last::L -end -mutable struct Measurement{L} <: Interaction - last::L -end - -observation() = nothing -state() = nothing - - -StateUpdate(env::Entity) = StateUpdate(env |> state) -Measurement(env::Entity) = Measurement(env |> observation) - -summary(s::StateUpdate) = s.last -summary(m::Measurement) = m.last - - -mutable struct MovingCovarianceEstimator{N<:Number} <: Interaction - mse::N - cov::N - innov::N -end -MovingCovarianceEstimator(; innov=5.0) = MovingCovarianceEstimator(0.0, 1.0, innov) -summary(m::MovingCovarianceEstimator) = m.cov - -function (m::MovingCovarianceEstimator)(time, learner::Entity, effected_updates, yₜ, args...; kwargs...) - λ = 1.0 / m.innov - m.mse = (1.0 - λ) * m.mse + λ * ( - (sum(abs2, yₜ - state(learner))) - ) - m.cov = m.mse - stochasticity(learner) -end - - -""" -We use OptimalKalmanGain instead of OptimalKalmanGain2 because the low pass filter irons out a lot of fluctuations. -""" -mutable struct OptimalKalmanGain{E<:Entity,N<:Number} <: Interaction - env::E - last::N - cov::N - innov::N -end -OptimalKalmanGain(env::Entity; innov=5.0) = OptimalKalmanGain(env, 0.0, 0.0, innov) - -""" -1. low pass filters squared error of learner and environmental states, to give an estimate of covariance -2. puts this covariance estimate into the the formula for kalman gain, along with true environmental volatility /stochasticity -""" -function (k::OptimalKalmanGain)(time, learner::Entity, effected_updates, args...; kwargs...) - env = k.env - λ = 1.0 / k.innov - k.cov = (1.0 - λ) * k.cov + λ * (sum(abs2, state(learner) - state(k.env))) - k.last = (k.cov + volatility(env)) / (k.cov + volatility(env) + stochasticity(env)) - nothing -end -summary(k::OptimalKalmanGain) = k.last - - -mutable struct OptimalKalmanGain2{E<:Entity, N<:Number} <: Interaction - env::E - last::N -end - -function (k::OptimalKalmanGain2)(time, learner::Entity, effected_updates, args...; kwargs...) - env = k.env - cov = sum(abs2, state(env) - state(learner)) - k.last = (cov + volatility(env)) / (cov + volatility(env) + stochasticity(env)) - nothing -end -OptimalKalmanGain2(env::Entity) = OptimalKalmanGain2(env, 0.0) -summary(k::OptimalKalmanGain2) = k.last - -mutable struct KalmanGainDiff{L,O<:OptimalKalmanGain} <: Interaction - okg::O - last::L -end -KalmanGainDiff(okg::OptimalKalmanGain) = KalmanGainDiff(okg, 0.0) - -function (k::KalmanGainDiff)(time, learner::Entity, effected_updates, args...; kwargs...) - k.okg(time, learner, effected_updates, args...) - k.last = k.okg.last - kalman_gain(learner) - nothing -end - -summary(k::KalmanGainDiff) = k.last - -mutable struct CumulativeSquaredError{E<:Entity,N<:Number, F<:Function} <: Interaction - env::E - of_what::F - error::N - divide_by::N -end - -CumulativeSquaredError(env::Entity; divide_by=1.0) = CumulativeSquaredError(env, state, 0.0, divide_by) -CumulativeSquaredError(env::Entity, of_what::Function; divide_by=1.0) = CumulativeSquaredError(env, of_what, 0.0, divide_by) - - -function (c::CumulativeSquaredError)(time, learner::Entity, effected_updates, args...; kwargs...) - n = c.divide_by - env = c.env - c.error += (1.0 / n) * sum(abs2, c.of_what(env) - c.of_what(learner)) - nothing -end -summary(c::CumulativeSquaredError) = c.error - -""" -change_func(t) gives the change to apply_to(env) at time t. The change NOT the actual value -""" -mutable struct ExternalFunctionUpdate{F1<:Function,L} <: Interaction - change_func::F1 - apply_to::Symbol - last::L -end - -ExternalFunctionUpdate(cf::Function, apply_to::Symbol, env::Entity) = ExternalFunctionUpdate(cf, apply_to, getfield(env, apply_to)) - -function (e::ExternalFunctionUpdate)(time, env::Entity, effected_updates, args...; kwargs...) - e.last = e.change_func(time) - new = e.last + getfield(env, e.apply_to) - setfield!(env, e.apply_to, new) - nothing -end - -summary(e::ExternalFunctionUpdate) = e.last - - diff --git a/src/InteractionSpecification.jl b/src/InteractionSpecification.jl index fc10947..8dfa481 100644 --- a/src/InteractionSpecification.jl +++ b/src/InteractionSpecification.jl @@ -11,6 +11,7 @@ struct InteractionSpecification{I<:Interaction,S,T,N<:Function} <: Specification record_what::N end +record_nothing(::Any) = nothing struct ObservationSpecification{I<:Thing,T,N<:Function} <: Specification of_what::I @@ -22,8 +23,8 @@ ObservationSpecification(i::Interaction, times, what) = InteractionSpecification Record(t::Thing, times, what) = ObservationSpecification(t, times, what) ### IE if you don't want to record, use the pair syntax -Base.Pair(i::Interaction, t) = InteractionSpecification(i, t, empty(t), summary) #update don't record -Base.Pair(e::Entity, t) = ObservationSpecification(e, t, summary) +Base.Pair(i::Interaction, t) = InteractionSpecification(i, t, empty(t), record_nothing) #update don't record +Base.Pair(e::Entity, t) = ObservationSpecification(e, t, record_nothing) Record(os::ObservationSpecification, what) = ObservationSpecification( os.of_what, @@ -38,7 +39,6 @@ Record(is::InteractionSpecification, t, what::Function) = InteractionSpecificati t, what ) -Record(is::InteractionSpecification, t) = Record(is, t, summary) Record(is::InteractionSpecification, t, what::Symbol) = Record(is, t, r -> getfield(r, what)) Record(is::InteractionSpecification, what::Function) = InteractionSpecification( @@ -69,7 +69,6 @@ nb this == is a bit misleading, but useful for subsequent code. Interaction Spec ==(i1::Specification, i2::Specification) = (==(i1.of_what, i2.of_what)) && ==(i1.record_what, i2.record_what) -summary(i::Interaction) = i diff --git a/src/Particles.jl b/src/Particles.jl deleted file mode 100644 index 093dcbb..0000000 --- a/src/Particles.jl +++ /dev/null @@ -1,150 +0,0 @@ -""" -Particle structs -""" - -abstract type Particle <: Entity end - -mutable struct ParticleCollection{P<:Entity, N<:Number} <: Entity - particles::Vector{P} - weights::Vector{N} -end - - -ParticleCollection(ps::Vector{P}) where {P<:Entity} = ParticleCollection(ps, ones(length(ps)) ./ length(ps)) - - -states(pc::ParticleCollection) = [p.state for p in pc.particles] -weights(pc::ParticleCollection) = pc.weights - -function collection_average(pc::ParticleCollection, f::Function) - sum(zip(pc.particles, pc.weights)) do (p, w) - f(p) * w - end -end - -# state(pc::ParticleCollection) = collection_average(pc, state) -# kalman_gain(pc::ParticleCollection) = collection_average(pc, kalman_gain) - -function collection_variance(pc::ParticleCollection, f::Function) - second_moment = sum(zip(pc.particles, pc.weights)) do (p, w) - f(p)^2 * w - end - return second_moment - collection_average(pc, f)^2 -end - - - -function get_particle_updates(pc::ParticleCollection, interactions::Vector{Vector{T}}) where {T<:Specification} - funcs, records = [p(interaction) for (p, interaction) in zip(pc.particles, interactions)] |> x -> (first.(x), last.(x)) - - function func(args...) - foreach(funcs) do f - f(args...) - end - nothing - end - - return func, records -end - - - - - -# fundamentally different operations for updating LR vs SV particles -struct ParticleWeightUpdate <: Interaction end - -function (w::ParticleWeightUpdate)(time, pc::ParticleCollection, effected_updates, yₜ, args...) - uw = unscaled_weights(pc.particles, yₜ) - pc.weights = uw / norm(uw, 1) - nothing -end - - - - - -""" -Systematic resampling - -generate a cdf from the weights: each point in the interval [0, 1] -""" - -abstract type Resampler <: Interaction end -struct SystematicResampler{N<:Number,I<:Integer,P<:Particle} <: Resampler - particle_copies::Vector{P} - bins::Vector{N} - bin_positions::Vector{I} - samples::Vector{N} - cutoff::N - resampled::typeof(Vector{Bool}(undef, 1)) -end - -function SystematicResampler(pc::ParticleCollection, cutoff) - _N = length(pc.particles) - T = eltype(pc.weights) - return SystematicResampler( - deepcopy(pc.particles), - cumsum(pc.weights), - Vector{Int64}(undef, _N), - Vector{T}(undef, _N), - cutoff, - [false] - ) -end - -resampled(s::SystematicResampler) = s.resampled[1] - -function (s::SystematicResampler)(time, pc::ParticleCollection, effected_updates, yₜ, args...) - - Neff = 1.0 / sum(pc.weights .^ 2) - _N = length(s.bins) - - if (Neff / _N) > s.cutoff - # println("not resampling at time", time) - s.resampled[1] = false - return nothing - else - # @info "resampling at time $time" - s.resampled[1] = true - end - - s.bins[:] = cumsum(pc.weights) - - if any(isnan.(s.bins)) - println(pc.weights) - end - - # a lin range with a bit of jiggle - map!(s.samples, 0:_N-1) do i - return (i + rand(Uniform(0.0, 1.0))) / _N - end - - # find the indices of the old samplers to map to the new samplers - map!(s.bin_positions, s.samples) do m - findfirst(1:_N) do i - # (i == _N) && (println(m); println(s.bins[end])) - if i == 1 - return m < s.bins[i] - else - return (m > s.bins[i-1]) && (m < s.bins[i]) - end - end - end - - retained_particles_indxs = unique(s.bin_positions) - num_retained_particles = length(retained_particles_indxs) - which_retained_p(i::Int64) = findfirst(x -> x == s.bin_positions[i], retained_particles_indxs) - # i -> bin_position[i] -> which element of unique is bin_positions[i] - - for (copy, particle) in zip(s.particle_copies[1:num_retained_particles], pc.particles[retained_particles_indxs]) - provide!(copy, particle) - end - - for (i, particle) in enumerate(pc.particles) - provide!(particle, s.particle_copies[which_retained_p(i)]) - end - - # @info "new particles spawned from $(length(unique(s.bin_positions))), new particles" - nothing -end diff --git a/src/SimulationHelper.jl b/src/SimulationHelper.jl index 8c96b33..a65142f 100644 --- a/src/SimulationHelper.jl +++ b/src/SimulationHelper.jl @@ -40,24 +40,7 @@ include("RecordBuilder.jl") export Recorder, name, get_record, difference -# include("BasicUpdates.jl") - -# export StateUpdate, Measurement, MovingCovarianceEstimator, OptimalKalmanGain, OptimalKalmanGain2, KalmanGainDiff, CumulativeSquaredError, ExternalFunctionUpdate - -export summary - include("PlotBuilder.jl") -include("Particles.jl") - -export Particle, ParticleCollection, ParticleWeightUpdate, Resampler, SystematicResampler - -export states, weights, resampled -# export collection_average, collection_variance, get_particle_updates - -include("SimulationStatistics.jl") - -# export do_repeats, param_against_difference, param_against_summary, param_against_summary_repeated, - end diff --git a/src/SimulationStatistics.jl b/src/SimulationStatistics.jl deleted file mode 100644 index ba19850..0000000 --- a/src/SimulationStatistics.jl +++ /dev/null @@ -1,67 +0,0 @@ -""" -These functions are convenient for plotting simulation hyperparameters / repeats against statistics (e.g. learning performance.) - -They all assume a simulation function of the form: - -f(hyperparameters::Dict) = _, _, records - -where records is something from which all useful statistics can be taken with functions. - -Statistics type functions take a simulation function, and output a function of the form: - range, outs = statistic(simulation::Function, hyperparameters) - - -""" - -""" - x -> f(x, ...) over range - f should return two outputs. of which the last is outputs over one repeat. -""" -function do_repeats(statistic::Function, range, num_repeats) - outputs = map(1:num_repeats) do i - map(range) do r - _, output = statistic(r) - return output - end - end - return range, outputs -end - - -""" -The final arguments should be functions that take the overall vector of records, and provide the specific record required -""" -function param_against_difference(f::Function, hypers::Dict, pname::Symbol, prange, record1::Function, record2::Function) - diffs = map(prange) do r - hypers[pname] = r - _, _, records = f(hypers) - return difference(record1(records), record2(records)) - end - return prange, diffs -end - -function param_against_summary(f::Function, hypers, pname::Symbol, prange, extract_summary::Function) - diffs = map(prange) do r - hypers[pname] = r - _, _, records = f(hypers) - return extract_summary(records) - end - return prange, diffs -end - -function param_against_summary_repeated(f::Function, hypers, pname::Symbol, prange, extract_summary::Function, repeats::Integer) - - stat(x) = param_against_summary(f, hypers, pname, x, extract_summary) - return do_repeats(stat, prange, repeats) -end - -""" - param_against_mse(hyperparameters, :num_particles, 1:1:20) |> plot -""" -param_against_mse(f::Function, hypers::Dict, pname::Symbol, prange) = param_against_summary( - f, - hypers, - pname, - prange, - x -> x[:collection](CumulativeSquaredError).summary[end] -) \ No newline at end of file