Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ZacLN committed Oct 14, 2015
1 parent a010777 commit a1e462a
Show file tree
Hide file tree
Showing 9 changed files with 257 additions and 176 deletions.
1 change: 1 addition & 0 deletions src/EconModel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ export ARSim,
grow!,
interp,
save,
setaggregate!,
shrink!,
solveit!,
solve!,
Expand Down
23 changes: 13 additions & 10 deletions src/adapt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,23 +112,26 @@ function Base.setindex!(M::Model,val::Float64,x::Symbol)
for i = 1:length(static.args)
push!(static.args,tchange!(copy(static.args[i]),1))
end
static = Dict(zip([x.args[1] for x in static.args],[x.args[2] for x in static.args]))

staticvars = filter(x->!in(x[1],[x.args[1].args[1] for x in static.args]) ,unique(getv(static,Any[])))

static = Dict(zip([x.args[1] for x in static.args],[x.args[2] for x in static.args]))
subs!(foc,static)

allvariables = unique(getv(foc,Any[]))
allvariables = unique(vcat(unique(getv(foc,Any[])),staticvars))
M.future = FutureVariables(foc,M.meta.auxillary,M.state)


variablelist = getMnames(allvariables,M.state,M.policy,M.future,M.auxillary,M.aggregate)
variablelist = getMnames(allvariables,M.state,M.policy,M.future,M.auxillary,M.aggregate)

for i = 1:length(aux.args)
if !in(aux.args[i].args[1],[x.args[1] for x in variablelist[:,1]])
x = copy(aux.args[i].args[1])
x = addindex!(x)
x = hcat(x,:(M.auxillary.X[i,$i]),symbol("A$i"))
variablelist = vcat(variablelist,x)
for i = 1:length(aux.args)
if !in(aux.args[i].args[1],[x.args[1] for x in variablelist[:,1]])
x = copy(aux.args[i].args[1])
x = addindex!(x)
x = hcat(x,:(M.auxillary.X[i,$i]),symbol("A$i"))
variablelist = vcat(variablelist,x)
end
end
end

for i = M.state.nendo+1:M.state.n
if !in(M.state.names[i],[x.args[1] for x in variablelist[:,1]])
Expand Down
55 changes: 18 additions & 37 deletions src/aggregate.jl
Original file line number Diff line number Diff line change
@@ -1,34 +1,9 @@
import Base:*,convert,kron,spzeros

type brack
i::Vector{Int}
w::Float64
end

convert(::Type{Tuple},x::Vector{Int})=ntuple(i->x[i],length(x))

*(a::brack,b::brack) = (x= deepcopy(a);push!(x.i,b.i[1]);x.w=a.w*b.w;x)

kron(a::Array{brack,1})= a

spzeros(n::Int) = spzeros(n,1)

function findbracket(v,x::Vector)
i=searchsortedfirst(x,v)
if i==1
return [brack([1],.5);brack([1],.5)]
else
dx = x[i]-x[i-1]
return [brack([i-1],(x[i]-v)/dx);brack([i],(v-x[i-1])/dx)]
end
end

type AggregateVariables
n::Int64
names::Array{Symbol,1}
X::Array{Float64}
XP::Array{Float64}
upd::Vector{Int}
target::Vector{Symbol}
isag::Vector{Bool}
d::Array{Float64}
Expand All @@ -37,20 +12,26 @@ type AggregateVariables
G
T::SparseMatrixCSC
end
AggregateVariables() = AggregateVariables(0,Symbol[],Float64[],Float64[],Int[],Symbol[],Bool[],Float64[],Float64[],nothing,nothing,spzeros(0,0))
AggregateVariables() = AggregateVariables(0,Symbol[],Float64[],Float64[],Symbol[],Bool[],Float64[],Float64[],nothing,nothing,spzeros(0,0))

function AggregateVariables(agg::Expr,State::StateVariables,Future::FutureVariables,Policy::PolicyVariables)
if agg == :[]
return AggregateVariables()
end
if agg.args[1].head==:(=)
if isa(agg.args[1],Symbol) || agg.args[1].head==:tuple
# warn("Agg warn")
isag= zeros(Bool,State.n)
nag = length(agg.args)
else
warn("Agg warn")
alist = isa(agg.args[1],Symbol) ? [agg.args[1]] : agg.args[1].args
for a in alist
in(a,State.names[1:State.nendo]) ? error("Endogenous state variable chosen as aggregate") : nothing
isag[find(a.==State.names)] = true
end
nag = length(agg.args)-1
elseif agg.args[1].head==:(=)
isag= zeros(Bool,State.n)
nag = length(agg.args)
end

isag = isag[State.nendo+1:end]

g = [sort(unique(State.X[:,i])) for i = 1:State.n]
G = ndgrid(g...)
Expand All @@ -61,15 +42,15 @@ function AggregateVariables(agg::Expr,State::StateVariables,Future::FutureVariab
X = zeros(State.G.n,nag)
XP = zeros(State.G.n*Future.nP,nag)

for i = 1:length(agg.args)
X[:,i] = agg.args[i].args[2].args[2]
for i = 1:length(agg.args)-any(isag)
X[:,i] = agg.args[i+any(isag)].args[2].args[2]
XP[:,i] = agg.args[i+any(isag)].args[2].args[2]
end
AggregateVariables(length(agg.args),
Symbol[x.args[1] for x in agg.args],
AggregateVariables(length(agg.args)-any(isag),
Symbol[x.args[1] for x in agg.args[1+any(isag):end]],
X,
XP,
Int[x.args[2].args[3] for x in agg.args],
Symbol[x.args[2].args[1] for x in agg.args],
Symbol[x.args[2].args[1] for x in agg.args[1+any(isag):end]],
isag,
d,
dG,
Expand Down
74 changes: 61 additions & 13 deletions src/aggregatefuncs.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,25 @@
import Base:*,convert,kron,spzeros

function updateaggregate!(M::Model)
for i = 1:M.aggregate.n
M.aggregate.X[:,i] = (M,M.aggregate.target[i])
type brack
i::Vector{Int}
w::Float64
end

convert(::Type{Tuple},x::Vector{Int})=ntuple(i->x[i],length(x))

*(a::brack,b::brack) = (x= deepcopy(a);push!(x.i,b.i[1]);x.w=a.w*b.w;x)

kron(a::Array{brack,1})= a

spzeros(n::Int) = spzeros(n,1)

function findbracket(v,x::Vector)
i=searchsortedfirst(x,v)
if i==1
return [brack([1],.5);brack([1],.5)]
else
dx = x[i]-x[i-1]
return [brack([i-1],(x[i]-v)/dx);brack([i],(v-x[i-1])/dx)]
end
end

Expand All @@ -14,9 +32,9 @@ function updatetransition!(M::Model)
brackets = ntuple(ip->findbracket(Pf[ip][i],M.aggregate.g[ip]),M.state.nendo)
ei = [findfirst(M.aggregate.G[M.state.nendo+ie][i].==M.aggregate.g[M.state.nendo+ie]) for ie = 1:M.state.nexog]
abrackets = kron(brackets...)
w = spzeros(ntuple(i->length(M.aggregate.g[i]),length(M.state.nendo))...)
w = spzeros(ntuple(i->length(M.aggregate.g[i]),M.state.nendo)...)
for b in abrackets
w[Tuple(b.i)...]=b.w
w[Tuple(b.i)...]+=b.w
end
for ie = 1:M.state.nexog
if M.aggregate.isag[ie]
Expand Down Expand Up @@ -51,13 +69,11 @@ end
function (M::Model,v::Symbol)
if in(v,M.state.names)
V = M.aggregate.G[findfirst(M.state.names.==v)]
elseif in(v,M.policy.names)
elseif in(v,vcat(M.aggregate.names,M.auxillary.names,M.policy.names))
V = reshape(interp(M,v,hcat([vec(x) for x in M.aggregate.G]...)),size(M.aggregate.G[1]))
elseif in(v,M.static.names)
M.static.sget(M)
V = reshape(interp(M,v,hcat([vec(x) for x in M.aggregate.G]...)),size(M.aggregate.G[1]))
elseif in(v,M.auxillary.names)
V = reshape(interp(M,v,hcat([vec(x) for x in M.aggregate.G]...)),size(M.aggregate.G[1]))
end
d = V.*M.aggregate.d
if all(!M.aggregate.isag)
Expand All @@ -74,12 +90,44 @@ function ∫(M::Model,v::Symbol)
end
end

function (M::Model,v::Symbol,V::Symbol)
ag = (M,v)
# function ∫(M::Model,v::Symbol,V::Symbol)
# ag = ∫(M,v)
# eag = M.state.names[M.state.nendo+1:end][M.aggregate.isag]
# X = ndgrid([M[e].x for e in eag]...)
# for i = 1:length(X[1])
# id=BitArray(.*([M[eag[ie],0].==X[ie][i] for ie = 1:length(eag)]...))
# M[V,0,id]=ag[i]*ones(sum(id))
# end
# end


function updateaggregate!(M::Model=0.0)
if any(M.aggregate.isag)
for i = 1:M.aggregate.n
v = M.aggregate.target[i]
ag = (M,v)
eag = M.state.names[M.state.nendo+1:end][M.aggregate.isag]
X = ndgrid([M[e].x for e in eag]...)
for ii = 1:length(X[1])
id=BitArray(.*([M[eag[ie],0].==X[ie][ii] for ie = 1:length(eag)]...))
M.aggregate.X[id,i] *= ϕ
M.aggregate.X[id,i] += (1-ϕ)*ag[ii]*ones(sum(id))
end
end
else
for i = 1:M.aggregate.n
M.aggregate.X[:,i] *= ϕ
M.aggregate.X[:,i] += (1-ϕ)*(M,M.aggregate.target[i])
end
end
end

function setaggregate!(M::Model,V::Symbol,ag::Vector{Float64})
eag = M.state.names[M.state.nendo+1:end][M.aggregate.isag]
X = ndgrid([M[e].x for e in eag]...)
for i = 1:length(X[1])
id=BitArray(.*([M[eag[ie],0].==X[ie][i] for ie = 1:length(eag)]...))
M[V,0,id]=ag[i]*ones(sum(id))
@assert size(X[1]) == size(ag) "Input does not match AGGREGATE state space size"
for ii = 1:length(X[1])
id=BitArray(.*([M[eag[ie],0].==X[ie][ii] for ie = 1:length(eag)]...))
M[V,0,id]=ag[ii]*ones(sum(id))
end
end
2 changes: 1 addition & 1 deletion src/auxillary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ function AuxillaryVariables(aux::Expr,State::StateVariables,Future::FutureVariab
XP = zeros(State.G.n*Future.nP,length(aux.args))

for i = 1:length(aux.args)
X[:,i] = aux.args[i].args[2]
X[:,i] = aux.args[i].args[2].args[1]
# XP[:,i] = aux.args[i].args[2].args[2]
end
AuxillaryVariables(length(aux.args),
Expand Down
Loading

0 comments on commit a1e462a

Please sign in to comment.