diff --git a/src/constraint_tree.jl b/src/constraint_tree.jl index 2fb4a03..d9061f2 100644 --- a/src/constraint_tree.jl +++ b/src/constraint_tree.jl @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import DataStructures: SortedDict +import DataStructures: SortedDict, SortedSet """ $(TYPEDEF) @@ -98,7 +98,7 @@ $(TYPEDSIGNATURES) Find the expected count of variables in a [`ConstraintTree`](@ref). """ -var_count(x::ConstraintTree) = isempty(elems(x)) ? 0 : maximum(var_count.(values(elems(x)))) +var_count(x::ConstraintTree) = isempty(elems(x)) ? 0 : maximum(var_count.(values(x))) """ $(TYPEDSIGNATURES) @@ -159,6 +159,65 @@ incr_var_idxs(x::QuadraticValue, incr::Int) = QuadraticValue( weights = x.weights, ) +""" +$(TYPEDSIGNATURES) + +Push all variable indexes found in `x` to the `out` container. + +(The container needs to support the standard `push!`.) +""" +collect_variables!(x::Constraint, out) = collect_variables!(x.value) +collect_variables!(x::LinearValue, out) = + for idx in x.idxs + push!(out, idx) + end +collect_variables!(x::QuadraticValue, out) = + for (idx, idy) in x.idxs + push!(out, idx, idy) + end +collect_variables!(x::ConstraintTree, out) = collect_variables!.(values(x), Ref(out)) + +""" +$(TYPEDSIGNATURES) + +Prune the unused variable indexes from an object `x` (such as a +[`ConstraintTree`](@ref)). + +This first runs [`collect_variables!`](@ref) to determine the actual used +variables, then calls [`renumber_variables`](@ref) to create a renumbered +object. +""" +function prune_variables(x) + vars = SortedSet{Int}() + collect_variables!(x, vars) + push!(vars, 0) + vv = collect(vars) + @assert vv[1] == 0 "variable indexes are broken" + return renumber_variables(x, SortedDict(vv .=> 0:length(vv))) +end + +""" +$(TYPEDSIGNATURES) + +Renumber all variables in an object (such as [`ConstraintTree`](@ref)). The new +variable indexes are taken from the `mapping` parameter at the index of the old +variable's index. + +This does not run any consistency checks on the result; the `mapping` must +therefore be monotonically increasing, and the zero index must map to itself, +otherwise invalid [`Value`](@ref)s will be produced. +""" +renumber_variables(x::ConstraintTree, mapping) = + ConstraintTree(k => renumber_Variables(v, mapping) for (k, v) in x) +renumber_variables(x::Constraint, mapping) = + Constraint(renumber_variables(x.value, mapping), x.bound) +renumber_variables(x::LinearValue, mapping) = + LinearValue(idxs = [mapping[idx] for idx in x.idxs], weights = x.weights) +renumber_variables(x::QuadraticValue, mapping) = QuadraticValue( + idxs = [(mapping[idx], mapping[idy]) for (idx, idy) in x.idxs], + weights = x.weights, +) + # # Algebraic construction #