Discussion: The problem with Expression <: Number, and a proposed fix #38

Closed
MasonProtter opened this issue Jan 21, 2020 · 2 comments

@MasonProtter
Member

I wasn't sure where to put this, but I figured this was a good place. Maybe Discourse instead? Open to suggestions.


@ChrisRackauckas and I discussed yesterday that Julia's subtyping-based dispatch model makes symbolic graph languages like ModelingToolkit (MTK) difficult to build. In MTK, for example, we have Expression <: Number, which is far from ideal. What if someone is more interested in symbolically representing arrays? Or strings? Or some other type that has nothing to do with numbers? We can get by with <: Number even when the object isn't actually a number, but it's awkward.

I think once (if?) MTK starts doing algebraic simplifications, it's going to need to apply type-dependent simplification rules, for example assuming commutativity of multiplication only for Real and Complex numbers but not for, say, Matrix or Quaternion.
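To make the idea of a type-dependent rule concrete, here is a minimal sketch, assuming each factor of a product is represented as a `(name, declared type)` pair. The names `commutes` and `canonical_mul` are hypothetical and are not part of ModelingToolkit; the rule only reorders a product into a canonical form when every factor's type is known to commute.

```julia
# Hypothetical sketch, not ModelingToolkit API: a canonicalization rule for a
# product that may reorder its factors only when every factor's declared type
# is known to commute under multiplication.

commutes(::Type{<:Union{Real, Complex}}) = true  # real and complex scalars
commutes(::Type) = false                         # e.g. Matrix, quaternion types

# Each factor is a (name, declared type) pair, standing in for a symbolic leaf.
function canonical_mul(factors)
    if all(fac -> commutes(fac[2]), factors)
        # Reordering is safe: sort by name so that y*x and x*y share one form.
        return sort(factors; by = first)
    else
        # Order matters for at least one factor; leave the product untouched.
        return factors
    end
end
```

This rule is deliberately conservative: it refuses to reorder as soon as one factor is non-commuting, so it misses cases such as a scalar times a matrix, which a finer-grained rule could still handle.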

One approach for getting around the subtyping constraint that I've been exploring is to instead have symbolic computations happen inside a Cassette.jl pass, in which we apply code transformations such that Expression{T} acts as if it were a T.

It turns out that @shashi just recently opened a PR to Cassette that would make this quite easy, so I mocked up a demo of the idea that I'd like to share here for discussion.

Two notes:

  1. The following code requires Shashi's Cassette PR JuliaLabs/Cassette.jl#157 ("ReflectOn: pick which method's body to rewrite").
  2. Apologies for not following MTK naming conventions in this demo, but I think you'll get the idea.
#--------------------------------------------------------------------------------
# Set up some symbolic types
#--------------------------------------------------------------------------------
abstract type Symbolic{T} end # Symbolic{T} will act like it is <: T

struct Sym{T} <: Symbolic{T}
    name::Symbol
end

Base.Symbol(s::Sym) = s.name  # qualify with Base so the existing Symbol constructor is extended

struct SymExpr{T} <: Symbolic{T}
    op
    args::Vector{Any}
end

#--------------------------------------------------------------------------------
# Pretty printing
#--------------------------------------------------------------------------------
function Base.show(io::IO, s::Sym{T}) where {T}
    print(io, string(s.name)*"::$T")
end

expr(se::SymExpr, Ts) = Expr(:call, expr(se.op, Ts), expr.(se.args, (Ts,))...)
expr(x, Ts)           = x
expr(f::Function, Ts) = Symbol(f)
function expr(s::Sym{T}, Ts) where {T} 
    if s ∉ Ts
        push!(Ts, s)
    end
    s.name
end

function Base.show(io::IO, se::SymExpr{T}) where {T}
    sset = Set()
    ex = expr(se, sset)
    print(io, repr(ex)[2:end]*"::$T" * " where {"*repr(sset)[9:end-2]*"}")
end

#--------------------------------------------------------------------------------
# Set up the Cassette pass
#--------------------------------------------------------------------------------
using Cassette: Cassette, overdub, @context, ReflectOn
using Base.Core.Compiler: return_type
using SpecializeVarargs
@context SymContext

sym_substitute(::Type{Sym{T}})     where {T} = T
sym_substitute(::Type{SymExpr{T}}) where {T} = T
sym_substitute(::Type{T})          where {T} = T

function Cassette.overdub(ctx::SymContext, f::Function, args...)
    argsT = typeof(args)
    if any((<:).(argsT.parameters, Symbolic))
        argsT′ = [sym_substitute.(argsT.parameters)...]
        if isprimitive(f)
            SymExpr{return_type(f, Tuple{argsT′...})}(f, [args...])
        else
            overdub(ctx, ReflectOn{Tuple{typeof(f), argsT′...}}(), f, args...)
        end
    else
        f(args...)
    end
end

# If isprimitive(f) == true, the pass records the call to f as a SymExpr instead of recursing into f's body.
# Primitives are the stopping points of the trace.
for f in [:+, :-, :*, :/, :^, :exp, :log, 
          :sin, :cos, :tan, :asin, :acos, :atan,
          :sinh, :cosh, :tanh, :asinh, :acosh, :atanh, :adjoint]
    @eval isprimitive(::typeof($f)) = true
end
isprimitive(::Any) = false

# Convenience macro for wrapping any enclosed function calls in the cassette pass
using MacroTools: postwalk
macro sym(expr)
    out = postwalk(expr) do ex
        if ex isa Expr && ex.head == :call
            :(overdub(SymContext(), $(ex.args[1]), $(ex.args[2:end]...)))
        else
            ex
        end
    end
    esc(out)
end
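The `@sym` macro above relies on `MacroTools.postwalk`. As a Base-only sketch of the same rewrite (the helpers `postwalk_sketch` and `wrap_calls` are hypothetical, not MacroTools API), the walk rebuilds children before their parent, so nested calls end up wrapped innermost-first, and the wrapper calls it inserts are not themselves re-wrapped:

```julia
# A Base-only sketch of the rewrite performed by `@sym`: every call expression
# `f(args...)` becomes `wrapper(ctx, f, args...)`. Names are hypothetical.

# Minimal post-order walk: rewrite the children first, then apply `f` to the
# rebuilt node. The node returned by `f` is not walked again.
function postwalk_sketch(f, ex)
    ex isa Expr || return f(ex)
    return f(Expr(ex.head, (postwalk_sketch(f, a) for a in ex.args)...))
end

function wrap_calls(ex; wrapper::Symbol = :overdub, ctx = :(SymContext()))
    postwalk_sketch(ex) do node
        if node isa Expr && node.head == :call
            # Splice the context and the original callee in front of the args.
            Expr(:call, wrapper, ctx, node.args...)
        else
            node
        end
    end
end
```

For example, `wrap_calls(:(f(g(x))))` yields `:(overdub(SymContext(), f, overdub(SymContext(), g, x)))`. Infix operators such as `x + 1` are also `:call` expressions, so they get wrapped too.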

Okay, with that setup code out of the way, let's see what this gets us. Suppose we have two functions from an external library that have very strict type constraints:

f(x::Float64)     = 1 + sin(x)^(1/2)
g(x::Vector{Int}) = x'x + 2

With the above definitions, we have no problems getting inside these functions:

julia> x = Sym{Float64}(:x)
x::Float64

julia> y = Sym{Vector{Int}}(:y)
y::Array{Int64,1}

julia> @sym f(x)
(1 + sin(x) ^ 0.5)::Float64 where {x::Float64}

julia> @sym g(y)
(adjoint(y) * y + 2)::Int64 where {y::Array{Int64,1}}
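Why the pass substitutes types before choosing a method body can be checked with plain Base reflection, without Cassette. The sketch below mirrors the setup types; `reflect_signature` is a hypothetical helper that builds the same kind of signature the `overdub` method passes to `ReflectOn`. `f` has no method for a `Sym{Float64}` argument, but the substituted signature selects the `Float64` method whose body gets rewritten.

```julia
# Sketch: mirror the setup types and show which method a substituted
# signature selects. `reflect_signature` is a hypothetical helper.

abstract type Symbolic{T} end
struct Sym{T} <: Symbolic{T}
    name::Symbol
end

# Replace any Symbolic{T} argument type by T; leave other types alone.
sym_substitute(::Type{<:Symbolic{T}}) where {T} = T
sym_substitute(::Type{T}) where {T} = T

f(x::Float64) = 1 + sin(x)^(1/2)

# Build Tuple{typeof(fn), T1, T2, ...} from the runtime argument types.
reflect_signature(fn, args...) =
    Tuple{typeof(fn), map(a -> sym_substitute(typeof(a)), args)...}

x = Sym{Float64}(:x)
hasmethod(f, Tuple{typeof(x)})         # false: f only accepts Float64
reflect_signature(f, x)                # Tuple{typeof(f), Float64}
which(f, Tuple{Float64})               # the method whose body would be traced
```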

Okay, I thought that was pretty neat, but there are still some major problems with this sort of approach. Here are two that are at the forefront of my mind:

  1. This approach can be too good at recursing into code. For instance, you wouldn't want to accidentally call sin(x) on a symbolic x if sin wasn't registered, because then you'd be exposed to the internals of sin, which nobody is interested in seeing. Alternatively, we could make the pass recurse only into registered functions instead of the other way around, but I think that might be almost as problematic: it'd be incredibly annoying to have to register every function you might care about, plus it'd make symbolic differentiation more difficult.

  2. If we want types associated with expressions, then we either have to build in the mapping between input and output types for every registered function, or we have to rely on Core.Compiler.return_type, which brings its own host of problems beyond the usual warnings you'll hear from compiler people. For instance, return_type(+, Tuple{Real, Real}) == Any, so Syms would need concrete type parameters rather than abstract ones 😕.
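The first option in point 2, building in the input-to-output type mapping by hand, could look roughly like the sketch below. `result_type` is a hypothetical name, not an existing API, and only two primitives are covered; the point is that abstract inputs can still get a useful bound instead of `Any`.

```julia
# Hypothetical sketch (not an existing API): hand-written output-type rules
# for a few registered primitives, used instead of Core.Compiler.return_type.

# sin keeps the floating-point type of a float input ...
result_type(::typeof(sin), ::Type{T}) where {T<:AbstractFloat} = T
# ... and for Base's other real types the result is still some AbstractFloat
# (e.g. sin(1) is a Float64), so return that looser but still useful bound.
result_type(::typeof(sin), ::Type{<:Real}) = AbstractFloat

# + on reals: use Julia's promotion rules. This ignores corner cases such as
# Bool + Bool, which returns an Int rather than a Bool.
result_type(::typeof(+), ::Type{A}, ::Type{B}) where {A<:Real, B<:Real} =
    promote_type(A, B)
```

With rules like these, `result_type(+, Real, Real)` is `Real`, at the cost of writing and maintaining one rule per registered function.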

Do you think this is a dead end? Do you think it's promising? Any other thoughts or comments? Questions?

@YingboMa
Member

We could register all the functions with frules.

How would staged programming work in this setup? I.e. how could a user run a simplification pass and get an optimized function?

@MasonProtter
Member Author

MasonProtter commented Jan 21, 2020

> We could register all the functions with frules.

That's a great idea.

I think that, along with a 'blacklist' of functions for which we throw an error if we ever recurse into them, would be pretty solid.
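A minimal sketch of that policy, assuming a simple three-way decision per call: registered functions become primitives (the call is recorded), blacklisted functions raise an error, and everything else is traced into. In a fuller version the registered set might be derived from the functions that have frules, as suggested above; here both sets are hand-written, and `REGISTERED`, `BLACKLISTED`, `BlacklistedCall`, and `trace_action` are all hypothetical names.

```julia
# Hypothetical sketch of a register/blacklist policy for the symbolic pass.
# All names are illustrative; the sets are hand-written, not derived from frules.

const REGISTERED  = Set{Any}([+, -, *, /, ^, sin, cos, exp, log])
const BLACKLISTED = Set{Any}([print, println, rand])

struct BlacklistedCall <: Exception
    f::Any
end
Base.showerror(io::IO, e::BlacklistedCall) =
    print(io, "refusing to trace into blacklisted function ", e.f)

# Decide what the pass should do with a call to `f`:
# :record for primitives, :recurse for everything else, error if blacklisted.
function trace_action(f)
    f in BLACKLISTED && throw(BlacklistedCall(f))
    return f in REGISTERED ? :record : :recurse
end
```

Here `trace_action(sin)` returns `:record`, `trace_action(sqrt)` returns `:recurse` (it is in neither set), and `trace_action(rand)` throws a `BlacklistedCall`.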

> How would staged programming work in this setup? I.e. how could a user run a simplification pass and get an optimized function?

This is mostly just about building the computational graph, not about how it's used. The only thing that changes for simplification and staged programming is that (supposing we get good type inference) the computational graph is fully typed at every level, which lets us do type-dependent optimizations if we wish.
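As a rough illustration of what a type-dependent optimization followed by staging might look like, here is a sketch that assumes the graph has already been turned into a Julia `Expr` (as the `expr` helper in the setup code does) and that each variable's declared type is known. The rule rewrites `adjoint(v)` to `v` only when `v` is declared `<: Real`, and the simplified expression is then compiled with `eval`. `TypedVar`, `simplify`, and `compile` are hypothetical names.

```julia
# Hypothetical sketch: one type-dependent rewrite on a typed expression,
# then staging the result into a Julia function via eval.

struct TypedVar
    name::Symbol
    T::Type
end

# adjoint(v) is the identity when v is declared <: Real, but must be kept
# for complex scalars, vectors, and matrices.
function simplify(ex, vars::Dict{Symbol, TypedVar})
    ex isa Expr || return ex
    # Simplify children first so rules see already-simplified arguments.
    ex = Expr(ex.head, (simplify(a, vars) for a in ex.args)...)
    if ex.head == :call && length(ex.args) == 2 && ex.args[1] == :adjoint
        v = ex.args[2]
        if v isa Symbol && haskey(vars, v) && vars[v].T <: Real
            return v
        end
    end
    return ex
end

# Build an anonymous function of `argnames` whose body is `ex`.
compile(ex, argnames::Vector{Symbol}) = eval(:(($(argnames...),) -> $ex))
```

For example, with `x` declared `Float64`, `:(adjoint(x) * x + 2)` simplifies to `:(x * x + 2)`, while `adjoint(z)` is kept for a `ComplexF64` variable `z`. A function produced by `eval` should be called from a later top-level statement (or through `Base.invokelatest`) to avoid world-age errors.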

@ChrisRackauckas ChrisRackauckas transferred this issue from SciML/ModelingToolkit.jl Feb 26, 2021