Skip to content

Commit

Permalink
multiple fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
TorkelE committed Jan 18, 2025
1 parent 9230667 commit 78ca087
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 62 deletions.
49 changes: 24 additions & 25 deletions src/dsl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,9 @@ This model also contains production and degradation reactions, where `0` denotes
either no substrates or no products in a reaction.
Options:
In addition to reactions, the macro also supports "option" inputs. Each option is designated
by a tag starting with a `@` followed by its input. A list of options can be found [here](https://docs.sciml.ai/Catalyst/stable/api/#api_dsl_options).
In addition to reactions, the macro also supports "option" inputs (permitting e.g. the addition
of observables). Each option is designated by a tag starting with a `@` followed by its input.
A list of options can be found [here](https://docs.sciml.ai/Catalyst/stable/api/#api_dsl_options).
"""
macro reaction_network(name::Symbol, network_expr::Expr)
make_rs_expr(QuoteNode(name), network_expr)
Expand Down Expand Up @@ -198,8 +199,8 @@ function recursive_find_reactants!(ex::ExprValues, mult::ExprValues,
# If the expression corresponds to a reactant on our list, increase its multiplicity.
idx = findfirst(r.reactant == ex for r in reactants)
if !isnothing(idx)
new_mult = processmult(+, mult, reactants[idx].stoichiometry)
reactants[idx] = DSLReactant(ex, new_mult)
newmult = processmult(+, mult, reactants[idx].stoichiometry)
reactants[idx] = DSLReactant(ex, newmult)

# If the expression corresponds to a new reactant, add it to the list.
else
Expand All @@ -210,12 +211,12 @@ function recursive_find_reactants!(ex::ExprValues, mult::ExprValues,
elseif ex.args[1] == :*
# The normal case (e.g. 3*X or 3*(X+Y)). Update the current multiplicity and continue.
if length(ex.args) == 3
new_mult = processmult(*, mult, ex.args[2])
recursive_find_reactants!(ex.args[3], new_mult, reactants)
newmult = processmult(*, mult, ex.args[2])
recursive_find_reactants!(ex.args[3], newmult, reactants)
# More complicated cases (e.g. 2*3*X). Yes, `ex.args[1:(end - 1)]` should start at 1 (not 2).
else
new_mult = processmult(*, mult, Expr(:call, ex.args[1:(end - 1)]...))
recursive_find_reactants!(ex.args[end], new_mult, reactants)
newmult = processmult(*, mult, Expr(:call, ex.args[1:(end - 1)]...))
recursive_find_reactants!(ex.args[end], newmult, reactants)
end
# If we have encountered a sum of different reactants, apply recursion on each.
elseif ex.args[1] == :+
Expand Down Expand Up @@ -243,11 +244,10 @@ end
function extract_metadata(metadata_line::Expr)
metadata = :([])
for arg in metadata_line.args
if arg.head != :(=)
(arg.head != :(=)) &&
error("Malformatted metadata line: $metadata_line. Each entry in the vector should contain a `=`.")
elseif !(arg.args[1] isa Symbol)
(arg.args[1] isa Symbol) ||
error("Malformatted metadata entry: $arg. Entries left-hand-side should be a single symbol.")
end
push!(metadata.args, :($(QuoteNode(arg.args[1])) => $(arg.args[2])))
end
return metadata
Expand Down Expand Up @@ -422,7 +422,7 @@ function push_reactions!(reactions::Vector{DSLReaction}, subs::ExprValues,
lengs = (tup_leng(subs), tup_leng(prods), tup_leng(rate), tup_leng(metadata))
maxl = maximum(lengs)
if any(!(leng == 1 || leng == maxl) for leng in lengs)
throw("Malformed reaction, rate=$rate, subs=$subs, prods=$prods, metadata=$metadata.")
error("Malformed reaction, rate: $rate, subs: $subs, prods: $prods, metadata: $metadata.")
end

# Loops through each reaction encoded by the reaction's different components.
Expand Down Expand Up @@ -452,12 +452,12 @@ end
function extract_syms(opts, vartype::Symbol)
# If the corresponding option have been used, uses `Symbolics._parse_vars` to find all
# variable within it (returning them in a vector).
if haskey(opts, vartype)
return if haskey(opts, vartype)
ex = opts[vartype]
vars = Symbolics._parse_vars(vartype, Real, ex.args[3:end])
return Vector{Union{Symbol, Expr}}(vars.args[end].args)
Vector{Union{Symbol, Expr}}(vars.args[end].args)
else
return Union{Symbol, Expr}[]
Union{Symbol, Expr}[]
end
end

Expand Down Expand Up @@ -593,6 +593,14 @@ function get_rxexpr(rx::DSLReaction)
return rx_constructor
end

# Recursively escape functions within equations of an equation written using user-defined functions.
# Does not expand special function calls like "hill(...)" and differential operators.
function escape_equation!(eqexpr::Expr, diffsyms)
eqexpr.args[2] = recursive_escape_functions!(eqexpr.args[2], diffsyms)
eqexpr.args[3] = recursive_escape_functions!(eqexpr.args[3], diffsyms)
eqexpr
end

### DSL Option Handling ###

# Finds the time independent variable, and any potential spatial independent variables.
Expand Down Expand Up @@ -856,7 +864,7 @@ end

### `@reaction` Macro & its Internals ###

@doc raw"""
"""
@reaction
Macro for generating a single [`Reaction`](@ref) object using a similar syntax as the `@reaction_network`
Expand Down Expand Up @@ -962,15 +970,6 @@ function recursive_escape_functions!(expr::ExprValues, diffsyms = [])
expr
end

# Returns the length of a expression tuple, or 1 if it is not an expression tuple (probably
# a Symbol/Numerical). This is used to handle bundled reaction (like `d, (X,Y) --> 0`).
# Recursively escape functions in the right-hand-side of an equation written using user-defined functions. Special function calls like "hill(...)" are not expanded.
function escape_equation!(eqexpr::Expr, diffsyms)
eqexpr.args[2] = recursive_escape_functions!(eqexpr.args[2], diffsyms)
eqexpr.args[3] = recursive_escape_functions!(eqexpr.args[3], diffsyms)
eqexpr
end

# Returns the length of a expression tuple, or 1 if it is not an expression tuple (probably a Symbol/Numerical).
function tup_leng(ex::ExprValues)
(typeof(ex) == Expr && ex.head == :tuple) && (return length(ex.args))
Expand Down
12 changes: 3 additions & 9 deletions src/expression_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,9 @@ end

# Throws an error when a forbidden symbol is used.
function forbidden_symbol_check(sym)
used_forbidden_syms = intersect(forbidden_symbols_error, sym)
isempty(used_forbidden_syms) && return
error("The following symbol(s) are used as species or parameters: $used_forbidden_syms, this is not permitted.")
end

# Checks that no symbol was sued for multiple purposes.
function unique_symbol_check(syms)
allunique(syms)||
error("Reaction network independent variables, parameters, species, and variables must all have distinct names, but a duplicate has been detected. ")
used_forbidden_syms =
isempty(used_forbidden_syms) ||
error("The following symbol(s) are used as species or parameters: $used_forbidden_syms, this is not permitted.")
end

### Catalyst-specific Expressions Manipulation ###
Expand Down
22 changes: 11 additions & 11 deletions test/dsl/dsl_basic_model_construction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ let
u0 = rnd_u0(networks[1], rng; factor)
p = rnd_ps(networks[1], rng; factor)
t = rand(rng)

@test f_eval(networks[1], u0, p, t) f_eval(networks[2], u0, p, t)
@test jac_eval(networks[1], u0, p, t) jac_eval(networks[2], u0, p, t)
@test g_eval(networks[1], u0, p, t) g_eval(networks[2], u0, p, t)
Expand All @@ -207,18 +207,18 @@ let
(l3, l4), Y2 Y3
(l5, l6), Y3 Y4
c, Y4
end
end

# Checks that the networks' functions evaluates equally for various randomised inputs.
@unpack X1, X2, X3, X4, p, d, k1, k2, k3, k4, k5, k6 = network
for factor in [1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3]
u0_1 = Dict(rnd_u0(network, rng; factor))
p_1 = Dict(rnd_ps(network, rng; factor))
u0_2 = [:Y1 => u0_1[X1], :Y2 => u0_1[X2], :Y3 => u0_1[X3], :Y4 => u0_1[X4]]
p_2 = [:q => p_1[p], :c => p_1[d], :l1 => p_1[k1], :l2 => p_1[k2], :l3 => p_1[k3],
p_2 = [:q => p_1[p], :c => p_1[d], :l1 => p_1[k1], :l2 => p_1[k2], :l3 => p_1[k3],
:l4 => p_1[k4], :l5 => p_1[k5], :l6 => p_1[k6]]
t = rand(rng)

@test f_eval(network, u0_1, p_1, t) f_eval(differently_written_5, u0_2, p_2, t)
@test jac_eval(network, u0_1, p_1, t) jac_eval(differently_written_5, u0_2, p_2, t)
@test g_eval(network, u0_1, p_1, t) g_eval(differently_written_5, u0_2, p_2, t)
Expand Down Expand Up @@ -271,7 +271,7 @@ let
u0 = rnd_u0(networks[1], rng; factor)
p = rnd_ps(networks[1], rng; factor)
t = rand(rng)

@test f_eval(networks[1], u0, p, t) f_eval(networks[2], u0, p, t)
@test jac_eval(networks[1], u0, p, t) jac_eval(networks[2], u0, p, t)
@test g_eval(networks[1], u0, p, t) g_eval(networks[2], u0, p, t)
Expand All @@ -293,7 +293,7 @@ let
(sqrt(3.7), exp(1.9)), X4 X1 + X2
end
push!(identical_networks_3, reaction_networks_standard[9] => no_parameters_9)
push!(parameter_sets, [:p1 => 1.5, :p2 => 1, :p3 => 2, :d1 => 0.01, :d2 => 2.3, :d3 => 1001,
push!(parameter_sets, [:p1 => 1.5, :p2 => 1, :p3 => 2, :d1 => 0.01, :d2 => 2.3, :d3 => 1001,
:k1 => π, :k2 => 42, :k3 => 19.9, :k4 => 999.99, :k5 => sqrt(3.7), :k6 => exp(1.9)])

no_parameters_10 = @reaction_network begin
Expand All @@ -305,14 +305,14 @@ let
1.0, X5
end
push!(identical_networks_3, reaction_networks_standard[10] => no_parameters_10)
push!(parameter_sets, [:p => 0.01, :k1 => 3.1, :k2 => 3.2, :k3 => 0.0, :k4 => 2.1, :k5 => 901.0,
push!(parameter_sets, [:p => 0.01, :k1 => 3.1, :k2 => 3.2, :k3 => 0.0, :k4 => 2.1, :k5 => 901.0,
:k6 => 63.5, :k7 => 7, :k8 => 8, :d => 1.0])

for (networks, p_1) in zip(identical_networks_3, parameter_sets)
for factor in [1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3]
u0 = rnd_u0(networks[1], rng; factor)
t = rand(rng)

@test f_eval(networks[1], u0, p_1, t) f_eval(networks[2], u0, [], t)
@test jac_eval(networks[1], u0, p_1, t) jac_eval(networks[2], u0, [], t)
@test g_eval(networks[1], u0, p_1, t) g_eval(networks[2], u0, [], t)
Expand Down Expand Up @@ -383,7 +383,7 @@ let
τ = rand(rng)
u = rnd_u0(reaction_networks_conserved[1], rng; factor)
p_2 = rnd_ps(time_network, rng; factor)
p_1 = [p_2; reaction_networks_conserved[1].k1 => τ;
p_1 = [p_2; reaction_networks_conserved[1].k1 => τ;
reaction_networks_conserved[1].k4 => τ; reaction_networks_conserved[1].k5 => τ]

@test f_eval(reaction_networks_conserved[1], u, p_1, τ) f_eval(time_network, u, p_2, τ)
Expand Down Expand Up @@ -463,7 +463,7 @@ let
@test rn1 == rn2
end

# Tests arrow variants in `@reaction`` macro.
# Tests arrow variants in `@reaction` macro.
let
@test isequal((@reaction k, 0 --> X), (@reaction k, X <-- 0))
@test isequal((@reaction k, 0 --> X), (@reaction k, X 0))
Expand Down Expand Up @@ -533,4 +533,4 @@ let
@test_throws Exception @eval @reaction_network begin
k, X^Y --> XY
end
end
end
26 changes: 9 additions & 17 deletions test/dsl/dsl_options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -417,23 +417,15 @@ let
@test issetequal(species(rn), spcs)
end

# Tests errors in `@variables` declarations.
# Tests error when disallowed name is used for variable.
let
# Variable used as species in reaction.
@test_throws Exception @eval rn = @reaction_network begin
@variables K(t)
k, K + A --> B
end

# Tests error when disallowed name is used for variable.
@test_throws Exception @eval @reaction_network begin
@variables π(t)
end
end

# Tests that explicitly declaring a single symbol as several things does not work.
# Several of these are broken, but note sure how to test broken-ness on `@test_throws false Exception @eval`.
# Relevant issue: https://github.com/SciML/Catalyst.jl/issues/1173
let
# Species + parameter.
@test_throws Exception @eval @reaction_network begin
Expand Down Expand Up @@ -1150,14 +1142,6 @@ let
end
end

# Erroneous `@default_noise_scaling` declaration (other noise scaling tests are mostly in the SDE file).
let
# Default noise scaling with multiple entries.
@test_throws Exception @eval @reaction_network begin
@default_noise_scaling η1 η2
end
end

### Other DSL Option Tests ###

# test combinatoric_ratelaws DSL option
Expand Down Expand Up @@ -1264,6 +1248,14 @@ let
@test isequal(Catalyst.expand_registered_functions(equations(rn4)[1]), D(A) ~ v*(A^n))
end

# Erroneous `@default_noise_scaling` declaration (other noise scaling tests are mostly in the SDE file).
let
# Default noise scaling with multiple entries.
@test_throws Exception @eval @reaction_network begin
@default_noise_scaling η1 η2
end
end

### test that @no_infer properly throws errors when undeclared variables are written ###

import Catalyst: UndeclaredSymbolicError
Expand Down

0 comments on commit 78ca087

Please sign in to comment.