I wasn't sure where to put this, but I figured this was a good place. Maybe Discourse instead? Open to suggestions.
@ChrisRackauckas and I discussed yesterday that Julia's subtyping dispatch model makes symbolic graph languages like ModelingToolkit (MTK) difficult. For instance, in MTK we have `Expression <: Number`, but this is far from ideal: what if someone is more interested in symbolically representing arrays? Or strings? Or some other type that doesn't have anything to do with numbers? Sure, we can get by with `<: Number` even if the object isn't actually a number, but it's awkward.
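Concretely, the status quo looks something like this (paraphrasing MTK's types, not its exact definitions):

```julia
abstract type Expression <: Number end

struct Variable <: Expression
    name::Symbol
end

v = Variable(:greeting)  # even if this stands for a string...
v isa Number             # ...it "is" a Number, which is the awkward part
```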
I think once (if?) MTK starts doing algebraic simplifications, it's going to need to be able to apply type-dependent simplification rules, for example only assuming commutativity between `Real` and `Complex` numbers but not, say, `Matrix` or `Quaternion`.
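Just to make that concrete, a minimal sketch of what such a type-dependent rule could look like (all names here are hypothetical):

```julia
# Whether * is commutative for values of type T.
iscommutative(::Type{<:Union{Real,Complex}}) = true
iscommutative(::Type) = false  # matrices, quaternions, etc.

# A rewrite pass would consult the type before reordering factors:
maybe_reorder(T::Type, x, y) = iscommutative(T) ? (y, x) : (x, y)

maybe_reorder(Float64, :a, :b)          # returns (:b, :a), safe to swap
maybe_reorder(Matrix{Float64}, :A, :B)  # returns (:A, :B), left alone
```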
One approach for getting around the subtyping constraint that I've been exploring is to instead have symbolic computations happen inside a Cassette.jl pass, inside of which we apply code transformations such that `Expression{T}` acts like it is a `T`.
It turns out that @shashi just recently opened a PR to Cassette that would make this quite easy, so I mocked up a demo of what I was thinking that I'd like to share here for discussion.
Apologies for not following MTK naming conventions in this demo, but I think you'll get the idea.
```julia
#--------------------------------------------------------------------------------
# Set up some symbolic types
#--------------------------------------------------------------------------------
abstract type Symbolic{T} end

# Symbolic{T} will act like it is <: T
struct Sym{T} <: Symbolic{T}
    name::Symbol
end
Symbol(s::Sym) = s.name

struct SymExpr{T} <: Symbolic{T}
    op
    args::Vector{Any}
end

#--------------------------------------------------------------------------------
# Pretty printing
#--------------------------------------------------------------------------------
function Base.show(io::IO, s::Sym{T}) where {T}
    print(io, string(s.name)*"::$T")
end

expr(se::SymExpr, Ts) = Expr(:call, expr(se.op, Ts), expr.(se.args, (Ts,))...)
expr(x, Ts) = x
expr(f::Function, Ts) = Symbol(f)
function expr(s::Sym{T}, Ts) where {T}
    if s ∉ Ts
        push!(Ts, s)
    end
    s.name
end

function Base.show(io::IO, se::SymExpr{T}) where {T}
    sset = Set()
    ex = expr(se, sset)
    print(io, repr(ex)[2:end]*"::$T"*" where {"*repr(sset)[9:end-2]*"}")
end

#--------------------------------------------------------------------------------
# Set up the Cassette pass
#--------------------------------------------------------------------------------
using Cassette: Cassette, overdub, @context, ReflectOn
using Base.Core.Compiler: return_type
using SpecializeVarargs

@context SymContext

sym_substitute(::Type{Sym{T}})     where {T} = T
sym_substitute(::Type{SymExpr{T}}) where {T} = T
sym_substitute(::Type{T})          where {T} = T

function Cassette.overdub(ctx::SymContext, f::Function, args...)
    argsT = typeof(args)
    if any((<:).(argsT.parameters, Symbolic))
        argsT′ = [sym_substitute.(argsT.parameters)...]
        if isprimitive(f)
            SymExpr{return_type(f, Tuple{argsT′...})}(f, [args...])
        else
            overdub(ctx, ReflectOn{Tuple{typeof(f), argsT′...}}(), f, args...)
        end
    else
        f(args...)
    end
end

# If isprimitive(f) == true, then inside a pass we won't recurse into the
# insides of f. Primitives are stopping points for us.
for f in [:+, :-, :*, :/, :^, :exp, :log,
          :sin, :cos, :tan, :asin, :acos, :atan,
          :sinh, :cosh, :tanh, :asinh, :acosh, :atanh, :adjoint]
    @eval isprimitive(::typeof($f)) = true
end
isprimitive(::Any) = false

# Convenience macro for wrapping any enclosed function calls in the Cassette pass
using MacroTools: postwalk
macro sym(expr)
    out = postwalk(expr) do ex
        if ex isa Expr && ex.head == :call
            :(overdub(SymContext(), $(ex.args[1]), $(ex.args[2:end]...)))
        else
            ex
        end
    end
    esc(out)
end
```
Okay, with that setup code out of the way, let's see what this gets us. Suppose we have two functions from an external library that have very strict type constraints:
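```julia
# (definitions inferred from the demo output below)
f(x::Float64) = 1 + sin(x)^0.5
g(y::Vector{Int}) = adjoint(y)*y + 2
```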
With the above definitions, we have no problems getting inside these functions:
```julia
julia> x = Sym{Float64}(:x)
x::Float64

julia> y = Sym{Vector{Int}}(:y)
y::Array{Int64,1}

julia> @sym f(x)
(1 + sin(x) ^ 0.5)::Float64 where {x::Float64}

julia> @sym g(y)
(adjoint(y) * y + 2)::Int64 where {y::Array{Int64,1}}
```
Okay, I thought that was pretty neat, but there are still some major problems with this sort of approach. Here are two that are at the forefront of my mind:
1. This approach can be too good at recursing into code. For instance, you wouldn't want to accidentally call `sin(x)` on a symbolic `x` if it wasn't registered, because then you'd be exposed to the internals of `sin`, which nobody is interested in seeing. Alternatively, we could make it so that it only recurses into registered functions instead of the other way around, but I think that might be almost as problematic, because it'd be incredibly annoying to have to register every function you might care about, plus it'd make symbolic differentiation more difficult.
2. If we want types associated with expressions, then we either have to build in the mapping between input and output types for every registered function, or we have to rely on `Core.Compiler.return_type`, which brings its own host of problems, not limited to the usual warnings you'll hear from compiler people. For instance, `return_type(+, Tuple{Real, Real}) == Any` (demonstrated below), so `Sym`s should be concretely typed instead of abstract 😕.
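To make that `return_type` pitfall concrete, here's roughly what inference reports for abstract versus concrete argument types (exact output can vary across Julia versions):

```julia
julia> using Base.Core.Compiler: return_type

julia> return_type(+, Tuple{Real, Real})        # abstract arguments: inference gives up
Any

julia> return_type(+, Tuple{Float64, Float64})  # concrete arguments infer fine
Float64
```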
Do you think this is a dead end? Do you think it's promising? Any other thoughts or comments? Questions?
I think that, along with a 'blacklist' of functions where we might want to throw an error if we recurse into them, would be pretty solid.
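A minimal sketch of that blacklist idea (names hypothetical, not existing Cassette API):

```julia
# Functions whose internals we never want exposed to the symbolic pass.
const blacklist = Set{Any}([Base.print, Base.iterate])

check_blacklisted(f) = f in blacklist &&
    error("refusing to recurse into $f with symbolic arguments")

# Inside Cassette.overdub, before the recursive overdub call:
#     check_blacklisted(f)
```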
How would staged programming work in this setup? I.e. how could a user run a simplification pass and get an optimized function?
This is mostly just about getting the computational graph, not how it's used. The only thing that changes for simplification and staged programming is that (supposing we get good type inference) the computational graph is fully typed at every level, letting us do type-dependent optimizations if we wish.
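For instance, turning a traced `SymExpr` back into runnable code would mostly be a matter of lowering the graph with the `expr` helper from the setup code and splicing it into a function. A rough sketch (`compile` is hypothetical, and a real version would need a deterministic argument order rather than `Set` iteration order):

```julia
function compile(se::SymExpr)
    syms = Set()
    body = expr(se, syms)                    # lower the graph to a Julia Expr
    argnames = Symbol[s.name for s in syms]  # NB: Set order is arbitrary
    @eval ($(argnames...),) -> $body
end

# compile(@sym f(x)) would yield something like x -> 1 + sin(x) ^ 0.5
```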