Skip to content

Commit

Permalink
Controlling more node types (#424)
Browse files Browse the repository at this point in the history
Fixes #398.

---------

Co-authored-by: Bart de Koning <[email protected]>
  • Loading branch information
SouthEndMusic and SouthEndMusic authored Jul 14, 2023
1 parent 9e7ca64 commit 4a363d6
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 60 deletions.
156 changes: 116 additions & 40 deletions core/src/create.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,80 @@
function parse_static(
static::StructVector,
db::DB,
nodetype::String,
defaults::NamedTuple,
)::NamedTuple
static_type = eltype(static)
columnnames_static = collect(fieldnames(static_type))
mask = [symb [:node_id, :control_state] for symb in columnnames_static]
columnnames_variables = columnnames_static[mask]
columntypes_variables = collect(fieldtypes(static_type))[mask]
vals = []

node_ids = get_ids(db, nodetype)
n_nodes = length(node_ids)

# Initialize the vectors for the output
for i in eachindex(columntypes_variables)
if isa(columntypes_variables[i], Union)
columntype = nonmissingtype(columntypes_variables[i])
else
columntype = columntypes_variables[i]
end

push!(vals, zeros(columntype, n_nodes))
end

columnnames_out = copy(columnnames_variables)
columnnames_variables = Tuple(columnnames_variables)

push!(columnnames_out, :node_id)
push!(vals, node_ids)

control_mapping = Dict{Tuple{Int, String}, NamedTuple}()

push!(columnnames_out, :control_mapping)
push!(vals, control_mapping)

out = NamedTuple{Tuple(columnnames_out)}(Tuple(vals))

if n_nodes == 0
return out
end

# Node id of the node being processed
node_id = node_ids[1]

# Index in the output vectors for this node ID
node_idx = 1

for row in static
if node_id != row.node_id
node_idx += 1
node_id = row.node_id
end

# If this row is a control state, add it to the control mapping
if !ismissing(row.control_state)
control_values = NamedTuple{columnnames_variables}(values(row)[mask])
control_mapping[(row.node_id, row.control_state)] = control_values
end

# Assign the parameter values to the output
for columnname in columnnames_variables
val = getfield(row, columnname)

if ismissing(val)
val = getfield(defaults, columnname)
end

getfield(out, columnname)[node_idx] = val
end
end

return out
end

function Connectivity(db::DB)::Connectivity
graph_flow, edge_ids_flow, edge_connection_types_flow = create_graph(db, "flow")
graph_control, edge_ids_control, edge_connection_types_control =
Expand All @@ -22,7 +99,13 @@ end

function LinearResistance(db::DB, config::Config)::LinearResistance
static = load_structvector(db, config, LinearResistanceStaticV1)
return LinearResistance(static.node_id, static.resistance)
defaults = (;)
static_parsed = parse_static(static, db, "LinearResistance", defaults)
return LinearResistance(
static_parsed.node_id,
static_parsed.resistance,
static_parsed.control_mapping,
)
end

function TabulatedRatingCurve(db::DB, config::Config)::TabulatedRatingCurve
Expand Down Expand Up @@ -56,18 +139,27 @@ end

function ManningResistance(db::DB, config::Config)::ManningResistance
static = load_structvector(db, config, ManningResistanceStaticV1)
defaults = (;)
static_parsed = parse_static(static, db, "ManningResistance", defaults)
return ManningResistance(
static.node_id,
static.length,
static.manning_n,
static.profile_width,
static.profile_slope,
static_parsed.node_id,
static_parsed.length,
static_parsed.manning_n,
static_parsed.profile_width,
static_parsed.profile_slope,
static_parsed.control_mapping,
)
end

function FractionalFlow(db::DB, config::Config)::FractionalFlow
static = load_structvector(db, config, FractionalFlowStaticV1)
return FractionalFlow(static.node_id, static.fraction)
defaults = (;)
static_parsed = parse_static(static, db, "FractionalFlow", defaults)
return FractionalFlow(
static_parsed.node_id,
static_parsed.fraction,
static_parsed.control_mapping,
)
end

function LevelBoundary(db::DB, config::Config)::LevelBoundary
Expand All @@ -83,31 +175,16 @@ end
function Pump(db::DB, config::Config)::Pump
static = load_structvector(db, config, PumpStaticV1)

control_mapping = Dict{Tuple{Int, String}, NamedTuple}()

if length(static.control_state) > 0 && !any(ismissing.(static.control_state))
# Starting flow_rates are first one found (can be updated by control initialisation)
node_ids::Vector{Int} = []
flow_rates::Vector{Float64} = []
defaults = (; min_flow_rate = 0.0, max_flow_rate = NaN)
static_parsed = parse_static(static, db, "Pump", defaults)

for (node_id, control_state, row) in
zip(static.node_id, static.control_state, static)
if node_id node_ids
push!(node_ids, node_id)
push!(flow_rates, row.flow_rate)
end

control_mapping[(node_id, control_state)] = variable_nt(row)
end
else
node_ids = static.node_id
flow_rates = static.flow_rate
end

min_flow_rate = coalesce.(static.min_flow_rate, 0.0)
max_flow_rate = coalesce.(static.max_flow_rate, NaN)

return Pump(node_ids, flow_rates, min_flow_rate, max_flow_rate, control_mapping)
return Pump(
static_parsed.node_id,
static_parsed.flow_rate,
static_parsed.min_flow_rate,
static_parsed.max_flow_rate,
static_parsed.control_mapping,
)
end

function Terminal(db::DB, config::Config)::Terminal
Expand Down Expand Up @@ -202,18 +279,17 @@ end

function PidControl(db::DB, config::Config)::PidControl
static = load_structvector(db, config, PidControlStaticV1)
defaults = (proportional = NaN, integral = NaN, derivative = NaN)
static_parsed = parse_static(static, db, "PidControl", defaults)

proportional = coalesce.(static.proportional, NaN)
integral = coalesce.(static.integral, NaN)
derivative = coalesce.(static.derivative, NaN)
error = zero(derivative)
error = zero(static_parsed.node_id)

return PidControl(
static.node_id,
static.listen_node_id,
proportional,
integral,
derivative,
static_parsed.node_id,
static_parsed.listen_node_id,
static_parsed.proportional,
static_parsed.integral,
static_parsed.derivative,
error,
)
end
Expand Down
3 changes: 3 additions & 0 deletions core/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ Requirements:
struct LinearResistance <: AbstractParameterNode
node_id::Vector{Int}
resistance::Vector{Float64}
control_mapping::Dict{Tuple{Int, String}, NamedTuple}
end

"""
Expand Down Expand Up @@ -180,6 +181,7 @@ struct ManningResistance <: AbstractParameterNode
manning_n::Vector{Float64}
profile_width::Vector{Float64}
profile_slope::Vector{Float64}
control_mapping::Dict{Tuple{Int, String}, NamedTuple}
end

"""
Expand All @@ -192,6 +194,7 @@ Requirements:
struct FractionalFlow <: AbstractParameterNode
node_id::Vector{Int}
fraction::Vector{Float64}
control_mapping::Dict{Tuple{Int, String}, NamedTuple}
end

"""
Expand Down
3 changes: 1 addition & 2 deletions core/src/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,11 @@ end
@version LevelBoundaryStaticV1 begin
node_id::Int
level::Float64
control_state::Union{Missing, String}
end

@version FlowBoundaryStaticV1 begin
node_id::Int
flow_rate::Float64
control_state::Union{Missing, String}
end

@version LinearResistanceStaticV1 begin
Expand Down Expand Up @@ -183,6 +181,7 @@ end
proportional::Float64
integral::Union{Missing, Float64}
derivative::Union{Missing, Float64}
control_state::Union{Missing, String}
end

function variable_names(s::Any)
Expand Down
1 change: 0 additions & 1 deletion docs/core/usage.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,6 @@ column | type | unit | restriction
------------- | ------- | ------------ | -----------
node_id | Int | - | sorted
flow_rate | Float64 | $m^3 s^{-1}$ | -
control_state | String | - | (optional)


### LinearResistance
Expand Down
7 changes: 0 additions & 7 deletions docs/schema/FlowBoundaryStatic.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,6 @@
"format": "default",
"description": "node_id",
"type": "integer"
},
"control_state": {
"format": "default",
"description": "control_state",
"type": [
"string"
]
}
},
"required": [
Expand Down
7 changes: 0 additions & 7 deletions docs/schema/LevelBoundaryStatic.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,6 @@
"format": "double",
"description": "level",
"type": "number"
},
"control_state": {
"format": "default",
"description": "control_state",
"type": [
"string"
]
}
},
"required": [
Expand Down
7 changes: 7 additions & 0 deletions docs/schema/PIDControlStatic.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@
"type": [
"number"
]
},
"control_state": {
"format": "default",
"description": "control_state",
"type": [
"string"
]
}
},
"required": [
Expand Down
5 changes: 2 additions & 3 deletions python/ribasim/ribasim/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: root.schema.json
# timestamp: 2023-07-11T15:21:47+00:00
# timestamp: 2023-07-13T14:34:28+00:00

from __future__ import annotations

Expand Down Expand Up @@ -38,7 +38,6 @@ class LevelBoundaryStatic(BaseModel):
remarks: Optional[str] = Field("", description="a hack for pandera")
node_id: int = Field(..., description="node_id")
level: float = Field(..., description="level")
control_state: Optional[str] = Field(None, description="control_state")


class DiscreteControlCondition(BaseModel):
Expand Down Expand Up @@ -81,6 +80,7 @@ class PidControlStatic(BaseModel):
proportional: float = Field(..., description="proportional")
node_id: int = Field(..., description="node_id")
derivative: Optional[float] = Field(None, description="derivative")
control_state: Optional[str] = Field(None, description="control_state")


class ManningResistanceStatic(BaseModel):
Expand All @@ -97,7 +97,6 @@ class FlowBoundaryStatic(BaseModel):
remarks: Optional[str] = Field("", description="a hack for pandera")
flow_rate: float = Field(..., description="flow_rate")
node_id: int = Field(..., description="node_id")
control_state: Optional[str] = Field(None, description="control_state")


class Node(BaseModel):
Expand Down

0 comments on commit 4a363d6

Please sign in to comment.