Extract not working within a rule #403

Open
AzizZayed opened this issue Jul 31, 2024 · 6 comments

@AzizZayed

Given these types and functions:

(datatype Dimension
    (UnknownDim)
    (Dim i64)
)

(function idim (Dimension) i64)
(rule
    ((Dim i))
    ((set (idim (Dim i)) i))
)

(datatype Shape
   (Shape2D Dimension Dimension)
)

(function nrows (Shape) Dimension)
(function ncols (Shape) Dimension)
(rewrite (nrows (Shape2D ?r ?z)) ?r)
(rewrite (ncols (Shape2D ?r ?z)) ?z)

(datatype MatOp
   (Mat String Shape)
   (MatMul MatOp MatOp Shape)
)

(function get_shape (MatOp) Shape)
(rewrite (get_shape (Mat ?n ?s)) ?s)
(rewrite (get_shape (MatMul ?x ?y ?s)) ?s)
   
(let mx (Mat "x" (Shape2D (Dim 5) (Dim 10)))) ; a*b = 5*10
(let my (Mat "y" (Shape2D (Dim 10) (Dim 15)))) ; b*c = 10*15
(let mz (Mat "z" (Shape2D (Dim 15) (Dim 2)))) ; c*d = 15*2

(let mxy (MatMul mx my (Shape2D (Dim 5) (Dim 15))))
(let myz (MatMul my mz (Shape2D (Dim 10) (Dim 2))))

The extract action is not working within this rule:

(rule
    ((= ?lhs (MatMul ?x ?y (Shape2D ?a ?c)))
     (= ?sy (ncols (get_shape ?y))))
    ((extract ?sy))
)

(run 100)

The output is simply blank, which I assume means the lhs was not found in the e-graph and thus no match occurred. However, the following rule works:

(rule
    ((= ?lhs (MatMul ?x ?y (Shape2D ?a ?c))))
    ((extract ?c))
)

(run 100)

I'm assuming this means that the rule can't find (ncols (get_shape ?y)). However, I want (= ?sy (ncols (get_shape ?y))) to be like calling a function to retrieve the number of columns in ?y. What is going on here?

@saulshanabrook
Member

I'm assuming this means that the rule can't find (ncols (get_shape ?y)) however, I want (= ?sy (ncols (get_shape ?y))) to be like calling a function to retrieve the number of columns in ?y.

You need to "demand" that expression so that it exists in the e-graph. A rule won't match an expression that hasn't already been added to the e-graph, and that expression won't be rewritten unless it has been added. So you need to do something like:

(rule
    ((= ?lhs (MatMul ?x ?y (Shape2D ?a ?c))))
    ((ncols (get_shape ?y)))
)

(run 100)

(rule
    ((= ?lhs (MatMul ?x ?y (Shape2D ?a ?c)))
     (= ?sy (ncols (get_shape ?y))))
    ((extract ?sy))
)

(run 100)
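To see the demand behaviour in isolation, here is a minimal sketch. The Num datatype and the twice function are hypothetical names made up for illustration, and the sketch assumes the 2024-era egglog syntax used elsewhere in this thread. The rewrite has nothing to match until (twice three) is actually added to the e-graph.

```egglog
; Hypothetical minimal example: rewrites only see terms already in the e-graph.
(datatype Num (Lit i64))

; No :merge, so looking up (twice x) creates a fresh e-class for the result.
(function twice (Num) Num)
(rewrite (twice (Lit ?n)) (Lit (* 2 ?n)))

(let three (Lit 3))
(run 10)
; Nothing was rewritten here: (twice three) has never been added.

; Demand the term by binding it, then run again so the rewrite can match it.
(let six (twice three))
(run 10)
(check (= six (Lit 6)))
```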

@AzizZayed
Author

AzizZayed commented Aug 1, 2024

Can we use a let expression within a rule to add things to the e-graph, instead of using another rule? For example:

(rule
    ((= ?lhs (MatMul ?x ?y (Shape2D ?a ?c)))
     (let sy (ncols (get_shape ?y)))) ; let in facts
    ((extract sy))
)

or

(rule
    ((= ?lhs (MatMul ?x ?y (Shape2D ?a ?c))))
    ((let sy (ncols (get_shape ?y))) (extract sy)) ; let in actions
)

@saulshanabrook
Member

Yes, the second example you posted should work. Let can appear in an action, but not in a fact.
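For reference, here is a self-contained sketch of the let-in-actions approach, trimmed down from the program in the original post; the final check is an addition for illustration. It assumes ncols keeps the default merge (no :merge option), so that looking it up creates a fresh e-class that the ncols rewrite can then union with the real column dimension.

```egglog
(datatype Dimension
    (Dim i64))

(datatype Shape
    (Shape2D Dimension Dimension))

; Default merge: a lookup of (ncols s) creates a fresh e-class.
(function ncols (Shape) Dimension)
(rewrite (ncols (Shape2D ?r ?z)) ?z)

(datatype MatOp
    (Mat String Shape)
    (MatMul MatOp MatOp Shape))

(function get_shape (MatOp) Shape)
(rewrite (get_shape (Mat ?n ?s)) ?s)
(rewrite (get_shape (MatMul ?x ?y ?s)) ?s)

; The let in the actions adds (ncols (get_shape ?y)) to the e-graph.
(rule
    ((= ?lhs (MatMul ?x ?y (Shape2D ?a ?c))))
    ((let sy (ncols (get_shape ?y)))))

(let mx (Mat "x" (Shape2D (Dim 5) (Dim 10))))
(let my (Mat "y" (Shape2D (Dim 10) (Dim 15))))
(let mxy (MatMul mx my (Shape2D (Dim 5) (Dim 15))))

(run 100)

; The demanded term has been rewritten to the column count of y.
(check (= (ncols (get_shape my)) (Dim 15)))
```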

@AzizZayed
Author

AzizZayed commented Aug 1, 2024

Why doesn't this work?

I want to add (ncols (get_shape ?y)) to the e-graph using a let in the actions of a rule.

(rule
    ((= ?lhs (MatMul ?x ?y (Shape2D ?a ?c))))
    ((let b (ncols (get_shape ?y))) )
)

I get the error in the console:

panicked at src/lib.rs:1039:37:
error while running actions for (rule ((= ?lhs (MatMul ?x ?y (Shape2D ?a ?c))))
      ((let b (ncols (get_shape ?y))))
         ): Not found: No value found for ncols [Value { tag: "Shape", bits: 15 }]

Stack:

Error
    at imports.wbg.__wbg_new_abda76e883ba8a5f (https://egraphs-good.github.io/egglog/web_demo.js:339:21)
    at https://egraphs-good.github.io/egglog/web_demo_bg.wasm:wasm-function[1592]:0x150048
    at https://egraphs-good.github.io/egglog/web_demo_bg.wasm:wasm-function[864]:0x13953a
    at https://egraphs-good.github.io/egglog/web_demo_bg.wasm:wasm-function[1274]:0x14a566
    at https://egraphs-good.github.io/egglog/web_demo_bg.wasm:wasm-function[55]:0x74c7e
    at https://egraphs-good.github.io/egglog/web_demo_bg.wasm:wasm-function[31]:0x49b1d
    at https://egraphs-good.github.io/egglog/web_demo_bg.wasm:wasm-function[31]:0x48ae5
    at https://egraphs-good.github.io/egglog/web_demo_bg.wasm:wasm-function[20]:0x1fc7d
    at https://egraphs-good.github.io/egglog/web_demo_bg.wasm:wasm-function[33]:0x4d4f9
    at __exports.run_program (https://egraphs-good.github.io/egglog/web_demo.js:193:22)

Full program:

(datatype Dimension
    (UnknownDim)
    (Dim i64)
)

(function idim (Dimension) i64)

(rule
    ((Dim i))
    ((set (idim (Dim i)) i))
)

(sort DimVec (Vec Dimension))

(datatype Shape
   (Shape1D Dimension)
   (Shape2D Dimension Dimension)
   (ShapeND DimVec)
)

(function nrows (Shape) Dimension)
(function ncols (Shape) Dimension :merge old)
(function dimN (Shape i64) Dimension)

(rewrite (nrows (Shape1D ?d)) ?d)
(rewrite (dimN (Shape1D ?d) 0) ?d)

(rewrite (nrows (Shape2D ?r ?z)) ?r)
(rewrite (ncols (Shape2D ?r ?z)) ?z)
(rewrite (dimN (Shape2D ?r ?z) 0) ?r)
(rewrite (dimN (Shape2D ?r ?z) 1) ?z)

(rewrite (nrows (ShapeND ?dvec)) (dimN (ShapeND ?dvec) 0))
(rewrite (ncols (ShapeND ?dvec)) (dimN (ShapeND ?dvec) 1))
(rewrite (dimN (ShapeND ?dvec) ?i) (vec-get ?dvec ?i))

(datatype MatOp
   (Mat String Shape)
   (MatMul MatOp MatOp Shape)
)

(function get_shape (MatOp) Shape)

(rewrite (get_shape (Mat ?n ?s)) ?s)
(rewrite (get_shape (MatMul ?x ?y ?s)) ?s)

(function Cost (MatOp) i64 :merge (max old new))

(rule ((= ?lhs (Mat ?n ?s))) ((set (Cost ?lhs) 0)))

(rule
    ((= ?lhs (MatMul ?x ?y (Shape2D ?a ?c))))
    ((let b (ncols (get_shape ?y))) )
) ; (set (Cost ?lhs) (+ (+ (Cost ?x) (Cost ?y)) (* (* (idim ?a) (idim ?b)) (idim ?c))))
    
(let mx (Mat "x" (Shape2D (Dim 5) (Dim 10)))) ; a*b = 5*10
(let my (Mat "y" (Shape2D (Dim 10) (Dim 15)))) ; b*c = 10*15
(let mz (Mat "z" (Shape2D (Dim 15) (Dim 2)))) ; c*d = 15*2

(let sx (get_shape mx))
(let sy (get_shape my))
(let sz (get_shape mz))

(let mxy (MatMul mx my (Shape2D (Dim 5) (Dim 15))))
(let myz (MatMul my mz (Shape2D (Dim 10) (Dim 2))))

(let sxy (get_shape mxy))
(let syz (get_shape myz))

(run 1000)

(print-function idim 10)
(print-function get_shape 10)
(print-function nrows 10)
(print-function ncols 10)
(print-function Cost 10)

@saulshanabrook
Member

If a function has a custom merge function, as ncols does here (:merge old), then it won't have a default value. You can set a custom default value with :default <...>. With the default merge function, the default value is effectively a fresh, empty equivalence class.

Sorry, that is the most confusing part of egglog, and even after you understand it, it isn't particularly intuitive.

Basically, any function that has a custom merge function but no default, or any function that returns a primitive, must be set explicitly to a value; it can't just be requested before a value has been set.
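As a sketch of what "set explicitly" can look like in practice, here is one possible way to compute the Cost that the commented-out action in the full program was aiming for. This is not from the thread and makes several assumptions: it destructures (Dim ?a) directly in the pattern instead of calling idim, it reads Cost only in the facts (where a match exists only after a value has been set), and it adds a separate rule to demand (get_shape ?x) so the inner dimension b can be matched.

```egglog
(datatype Dimension (Dim i64))
(datatype Shape (Shape2D Dimension Dimension))
(datatype MatOp
    (Mat String Shape)
    (MatMul MatOp MatOp Shape))

; Default merge: looking up (get_shape m) creates a fresh e-class.
(function get_shape (MatOp) Shape)
(rewrite (get_shape (Mat ?n ?s)) ?s)
(rewrite (get_shape (MatMul ?x ?y ?s)) ?s)

; i64 output: no default value, so Cost must be set before it can be read.
(function Cost (MatOp) i64 :merge (max old new))

(rule ((= ?m (Mat ?n ?s))) ((set (Cost ?m) 0)))

; Demand the left operand's shape so the cost rule below can match it.
(rule ((= ?lhs (MatMul ?x ?y ?s))) ((get_shape ?x)))

; (a x b) * (b x c) costs a*b*c; operand costs are matched in the facts,
; so this rule only fires once both have been set.
(rule
    ((= ?lhs (MatMul ?x ?y (Shape2D (Dim ?a) (Dim ?c))))
     (= (get_shape ?x) (Shape2D ?rx (Dim ?b)))
     (= ?cx (Cost ?x))
     (= ?cy (Cost ?y)))
    ((set (Cost ?lhs) (+ (+ ?cx ?cy) (* (* ?a ?b) ?c)))))

(let mx (Mat "x" (Shape2D (Dim 5) (Dim 10))))
(let my (Mat "y" (Shape2D (Dim 10) (Dim 15))))
(let mxy (MatMul mx my (Shape2D (Dim 5) (Dim 15))))

(run 10)

; 0 + 0 + 5*10*15
(check (= (Cost mxy) 750))
```

The :merge (max old new) from the original program is kept here; with a single MatMul node it never comes into play. If several products ended up in the same e-class, (min old new) would keep the cheapest cost instead.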

@oflatt
Member

oflatt commented Aug 8, 2024

Note that default is being removed in egglog soon! It's too confusing.
