From 4c7c3df9c3d675f0faeff3e7a6b7271088cabd5f Mon Sep 17 00:00:00 2001 From: Pavel Panchekha Date: Tue, 21 Jan 2025 15:26:08 -0700 Subject: [PATCH 01/13] Rename make-egg-runner to make-egraph --- src/core/egg-herbie.rkt | 14 +++++++------- src/core/localize.rkt | 12 ++++++------ src/core/mainloop.rkt | 2 +- src/core/patch.rkt | 4 ++-- src/core/preprocess.rkt | 10 +++++----- src/core/simplify.rkt | 8 ++++---- 6 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/core/egg-herbie.rkt b/src/core/egg-herbie.rkt index 06c6e0c8f..a74f63113 100644 --- a/src/core/egg-herbie.rkt +++ b/src/core/egg-herbie.rkt @@ -22,7 +22,7 @@ "batch.rkt") (provide (struct-out egg-runner) - make-egg-runner + make-egraph run-egg) (module+ test @@ -57,7 +57,7 @@ id->spec)) ; map from e-class id to an approx-spec or #f ; Makes a new egraph that is managed by Racket's GC -(define (make-egraph) +(define (make-egraph-data) (egraph-data (egraph_create) (make-hash) (make-hash) (make-hash))) ; Creates a new runner using an existing egraph. @@ -347,7 +347,7 @@ (cons '(cos.f32 (PI.f32)) '(cos.f32 (PI.f32))) (cons '(if (TRUE) x y) '(if (TRUE) $h1 $h0)))) - (let ([egg-graph (make-egraph)]) + (let ([egg-graph (make-egraph-data)]) (for ([(in expected-out) (in-dict test-exprs)]) (define out (expr->egg-expr in egg-graph (*context*))) (define computed-in (egg-expr->expr out egg-graph (context-repr (*context*)))) @@ -375,7 +375,7 @@ `(*.f64 ,(literal 23/54 'binary64) r) `(+.f64 ,(literal 3/2 'binary64) ,(literal 1.4 'binary64)))) - (let ([egg-graph (make-egraph)]) + (let ([egg-graph (make-egraph-data)]) (for ([expr extended-expr-list]) (define egg-expr (expr->egg-expr expr egg-graph (*context*))) (check-equal? (egg-expr->expr egg-expr egg-graph (context-repr (*context*))) expr)))) @@ -1161,7 +1161,7 @@ (define (egraph-run-schedule batch roots schedule ctx) ; allocate the e-graph - (define egg-graph (make-egraph)) + (define egg-graph (make-egraph-data)) ; insert expressions into the e-graph (define root-ids (egraph-add-exprs egg-graph batch roots ctx)) @@ -1191,7 +1191,7 @@ ;; Public API ;; ;; Most calls to egg should be done through this interface. -;; - `make-egg-runner`: creates a struct that describes a _reproducible_ egg instance +;; - `make-egraph`: creates a struct that describes an egraph ;; - `run-egg`: takes an egg runner and performs an extraction (exprs or proof) ;; Herbie's version of an egg runner. @@ -1213,7 +1213,7 @@ ;; - scheduler: `(scheduler . )` [default: backoff] ;; - `simple`: run all rules without banning ;; - `backoff`: ban rules if the fire too much -(define (make-egg-runner batch roots reprs schedule #:context [ctx (*context*)]) +(define (make-egraph batch roots reprs schedule #:context [ctx (*context*)]) (define (oops! fmt . args) (apply error 'verify-schedule! fmt args)) ; verify the schedule diff --git a/src/core/localize.rkt b/src/core/localize.rkt index 531e849a3..b092c0c5c 100644 --- a/src/core/localize.rkt +++ b/src/core/localize.rkt @@ -83,12 +83,12 @@ ; egg runner (2-phases for real rewrites and implementation selection) (define batch (progs->batch progs)) (define runner - (make-egg-runner batch - (batch-roots batch) - reprs - `((,lifting-rules . ((iteration . 1) (scheduler . simple))) - (,rules . ((node . ,(*node-limit*)))) - (,lowering-rules . ((iteration . 1) (scheduler . simple)))))) + (make-egraph batch + (batch-roots batch) + reprs + `((,lifting-rules . ((iteration . 1) (scheduler . simple))) + (,rules . ((node . ,(*node-limit*)))) + (,lowering-rules . ((iteration . 1) (scheduler . simple)))))) ; run egg (define simplified (map (compose debatchref last) (simplify-batch runner batch))) diff --git a/src/core/mainloop.rkt b/src/core/mainloop.rkt index c447ff901..007043372 100644 --- a/src/core/mainloop.rkt +++ b/src/core/mainloop.rkt @@ -360,7 +360,7 @@ (define exprs (map alt-expr alts)) (define reprs (map (lambda (expr) (repr-of expr (*context*))) exprs)) (define batch (progs->batch exprs)) - (define runner (make-egg-runner batch (batch-roots batch) reprs schedule)) + (define runner (make-egraph batch (batch-roots batch) reprs schedule)) ; run egg (define simplified (map (compose debatchref last) (simplify-batch runner batch))) diff --git a/src/core/patch.rkt b/src/core/patch.rkt index c912e95e3..63ab439a4 100644 --- a/src/core/patch.rkt +++ b/src/core/patch.rkt @@ -44,7 +44,7 @@ (batchref-idx (alt-expr approx)))) ; run egg - (define runner (make-egg-runner global-batch roots reprs schedule)) + (define runner (make-egraph global-batch roots reprs schedule)) (define simplification-options (simplify-batch runner global-batch)) ; convert to altns @@ -135,7 +135,7 @@ (define reprs (map (curryr repr-of (*context*)) exprs)) (timeline-push! 'inputs (map ~a exprs)) - (define runner (make-egg-runner global-batch roots reprs schedule)) + (define runner (make-egraph global-batch roots reprs schedule)) ; batchrefss is a (listof (listof batchref)) (define batchrefss (run-egg runner (cons 'multi global-batch))) diff --git a/src/core/preprocess.rkt b/src/core/preprocess.rkt index 3c8cb9c9b..20452e8d5 100644 --- a/src/core/preprocess.rkt +++ b/src/core/preprocess.rkt @@ -64,7 +64,7 @@ ; egg query (define batch (progs->batch (list expr))) - (define runner (make-egg-runner batch (batch-roots batch) (list (context-repr ctx)) schedule)) + (define runner (make-egraph batch (batch-roots batch) (list (context-repr ctx)) schedule)) ; run egg (define simplified (simplify-batch runner batch)) @@ -99,10 +99,10 @@ (define batch (progs->batch specs)) (define runner - (make-egg-runner batch - (batch-roots batch) - (map (lambda (_) (context-repr ctx)) specs) - `((,rules . ((node . ,(*node-limit*))))))) + (make-egraph batch + (batch-roots batch) + (map (lambda (_) (context-repr ctx)) specs) + `((,rules . ((node . ,(*node-limit*))))))) ;; run egg to check for identities (define expr-pairs (map (curry cons spec) specs)) diff --git a/src/core/simplify.rkt b/src/core/simplify.rkt index d2e4f41a4..d8654d0a7 100644 --- a/src/core/simplify.rkt +++ b/src/core/simplify.rkt @@ -40,10 +40,10 @@ (define (test-simplify . args) (define batch (progs->batch args)) (define runner - (make-egg-runner batch - (batch-roots batch) - (map (lambda (_) 'real) args) - `((,(*simplify-rules*) . ((node . ,(*node-limit*))))))) + (make-egraph batch + (batch-roots batch) + (map (lambda (_) 'real) args) + `((,(*simplify-rules*) . ((node . ,(*node-limit*))))))) (parameterize ([*egraph-platform-cost* #f]) (map (compose debatchref last) (simplify-batch runner batch)))) From c546be101039e53262c0e212905c2755e16d9a99 Mon Sep 17 00:00:00 2001 From: Pavel Panchekha Date: Tue, 21 Jan 2025 15:35:11 -0700 Subject: [PATCH 02/13] Split run-egg into four different functions --- src/core/derivations.rkt | 2 +- src/core/egg-herbie.rkt | 132 ++++++++++++++++++++------------------- src/core/patch.rkt | 2 +- src/core/preprocess.rkt | 4 +- src/core/simplify.rkt | 2 +- 5 files changed, 74 insertions(+), 68 deletions(-) diff --git a/src/core/derivations.rkt b/src/core/derivations.rkt index 3a2e5e0ac..d4e584d13 100644 --- a/src/core/derivations.rkt +++ b/src/core/derivations.rkt @@ -20,7 +20,7 @@ [(alt expr (list (or 'simplify 'rr) loc (? egg-runner? runner) #f) `(,prev) _) (define start-expr (location-get loc (alt-expr prev))) (define end-expr (location-get loc expr)) - (define proof (first (run-egg runner `(proofs ,(cons start-expr end-expr))))) + (define proof (egraph-prove runner start-expr end-expr)) (define proof* (canonicalize-proof (alt-expr altn) proof loc)) (alt expr `(rr ,loc ,runner ,proof*) `(,prev) '())] diff --git a/src/core/egg-herbie.rkt b/src/core/egg-herbie.rkt index a74f63113..e7a99e333 100644 --- a/src/core/egg-herbie.rkt +++ b/src/core/egg-herbie.rkt @@ -23,7 +23,10 @@ (provide (struct-out egg-runner) make-egraph - run-egg) + egraph-equal? + egraph-prove + egraph-best + egraph-variations) (module+ test (require rackunit) @@ -1191,8 +1194,11 @@ ;; Public API ;; ;; Most calls to egg should be done through this interface. -;; - `make-egraph`: creates a struct that describes an egraph -;; - `run-egg`: takes an egg runner and performs an extraction (exprs or proof) +;; - `make-egraph`: constructs an egraph and runs rules on it +;; - `egraph-equal?`: test if two expressions are equal +;; - `egraph-prove`: return a proof that two expressions are equal +;; - `egraph-best`: return a batch with the best versions of another batch +;; - `egraph-variations`: return a batch with all versions of another batch ;; Herbie's version of an egg runner. ;; Defines parameters for running rewrite rules with egg @@ -1259,67 +1265,65 @@ #:exists 'replace (lambda (p) (write-json (hash 'nodes nodes 'root_eclasses (map ~a roots) 'class_data (hash)) p)))) -;; Runs egg using an egg runner. -;; -;; Argument `cmd` specifies what to get from the e-graph: -;; - single extraction: `(single . )` -;; - multi extraction: `(multi . )` -;; - proofs: `(proofs . (( . ) ...))` -(define (run-egg runner cmd) - ;; Run egg using runner +(define (egraph-equal? runner start end) + (define ctx (egg-runner-ctx runner)) + (define egg-graph (egg-runner-egg-graph runner)) + (egraph-expr-equal? egg-graph start end ctx)) + +(define (egraph-prove runner start end) + (define ctx (egg-runner-ctx runner)) + (define egg-graph (egg-runner-egg-graph runner)) + + (unless (egraph-expr-equal? egg-graph start end ctx) + (error 'egraph-prove + "cannot prove ~a is equal to ~a; not equal" + start + end)) + (define proof (egraph-get-proof egg-graph start end ctx)) + (when (null? proof) + (error 'egraph-prove "proof extraction failed between`~a` and `~a`" start end)) + proof) + +(define (egraph-best runner batch) (define ctx (egg-runner-ctx runner)) (define root-ids (egg-runner-new-roots runner)) (define egg-graph (egg-runner-egg-graph runner)) - ; Perform extraction - (match cmd - [`(single . ,batch) ; single expression extraction - (define regraph (make-regraph egg-graph)) - (define reprs (egg-runner-reprs runner)) - (when (flag-set? 'dump 'egg) - (regraph-dump regraph root-ids reprs)) - - (define extract-id ((typed-egg-batch-extractor batch) regraph)) - (define finalize-batch (last extract-id)) - - ; (Listof (Listof batchref)) - (define out - (for/list ([id (in-list root-ids)] - [repr (in-list reprs)]) - (regraph-extract-best regraph extract-id id repr))) - ; commit changes to the batch - (finalize-batch) - out] - [`(multi . ,batch) ; multi expression extraction - (define regraph (make-regraph egg-graph)) - (define reprs (egg-runner-reprs runner)) - (when (flag-set? 'dump 'egg) - (regraph-dump regraph root-ids reprs)) - - (define extract-id ((typed-egg-batch-extractor batch) regraph)) - (define finalize-batch (last extract-id)) - - ; (Listof (Listof batchref)) - (define out - (for/list ([id (in-list root-ids)] - [repr (in-list reprs)]) - (regraph-extract-variants regraph extract-id id repr))) - ; commit changes to the batch - (finalize-batch) - out] - [`(proofs . ((,start-exprs . ,end-exprs) ...)) ; proof extraction - (for/list ([start (in-list start-exprs)] - [end (in-list end-exprs)]) - (unless (egraph-expr-equal? egg-graph start end ctx) - (error 'run-egg - "cannot find proof; start and end are not equal.\n start: ~a \n end: ~a" - start - end)) - (define proof (egraph-get-proof egg-graph start end ctx)) - (when (null? proof) - (error 'run-egg "proof extraction failed between`~a` and `~a`" start end)) - proof)] - [`(equal? . ((,start-exprs . ,end-exprs) ...)) ; term equality? - (for/list ([start (in-list start-exprs)] - [end (in-list end-exprs)]) - (egraph-expr-equal? egg-graph start end ctx))] - [_ (error 'run-egg "unknown command `~a`\n" cmd)])) + + (define regraph (make-regraph egg-graph)) + (define reprs (egg-runner-reprs runner)) + (when (flag-set? 'dump 'egg) + (regraph-dump regraph root-ids reprs)) + + (define extract-id ((typed-egg-batch-extractor batch) regraph)) + (define finalize-batch (last extract-id)) + + ; (Listof (Listof batchref)) + (define out + (for/list ([id (in-list root-ids)] + [repr (in-list reprs)]) + (regraph-extract-best regraph extract-id id repr))) + ; commit changes to the batch + (finalize-batch) + out) + +(define (egraph-variations runner batch) + (define ctx (egg-runner-ctx runner)) + (define root-ids (egg-runner-new-roots runner)) + (define egg-graph (egg-runner-egg-graph runner)) + + (define regraph (make-regraph egg-graph)) + (define reprs (egg-runner-reprs runner)) + (when (flag-set? 'dump 'egg) + (regraph-dump regraph root-ids reprs)) + + (define extract-id ((typed-egg-batch-extractor batch) regraph)) + (define finalize-batch (last extract-id)) + + ; (Listof (Listof batchref)) + (define out + (for/list ([id (in-list root-ids)] + [repr (in-list reprs)]) + (regraph-extract-variants regraph extract-id id repr))) + ; commit changes to the batch + (finalize-batch) + out) diff --git a/src/core/patch.rkt b/src/core/patch.rkt index 63ab439a4..7ffa314d1 100644 --- a/src/core/patch.rkt +++ b/src/core/patch.rkt @@ -137,7 +137,7 @@ (define runner (make-egraph global-batch roots reprs schedule)) ; batchrefss is a (listof (listof batchref)) - (define batchrefss (run-egg runner (cons 'multi global-batch))) + (define batchrefss (egraph-variations runner global-batch)) ; apply changelists (define rewritten diff --git a/src/core/preprocess.rkt b/src/core/preprocess.rkt index 20452e8d5..fb382313f 100644 --- a/src/core/preprocess.rkt +++ b/src/core/preprocess.rkt @@ -106,7 +106,9 @@ ;; run egg to check for identities (define expr-pairs (map (curry cons spec) specs)) - (define equal?-lst (run-egg runner `(equal? . ,expr-pairs))) + (define equal?-lst + (for/list ([(start end) (in-dict expr-pairs)]) + (egraph-equal? runner start end))) ;; collect equalities (define abs-instrs '()) diff --git a/src/core/simplify.rkt b/src/core/simplify.rkt index d8654d0a7..eeae08e7d 100644 --- a/src/core/simplify.rkt +++ b/src/core/simplify.rkt @@ -19,7 +19,7 @@ (define (simplify-batch runner batch) (timeline-push! 'inputs (map ~a (batch->progs (egg-runner-batch runner) (egg-runner-roots runner)))) - (define simplifieds (run-egg runner (cons 'single batch))) + (define simplifieds (egraph-best runner batch)) (define out (for/list ([simplified (in-list simplifieds)] [root (egg-runner-roots runner)]) From 556c856f3183e2dcde752fa776ae96a8a2998637 Mon Sep 17 00:00:00 2001 From: Pavel Panchekha Date: Tue, 21 Jan 2025 15:48:04 -0700 Subject: [PATCH 03/13] Simplify preprocessing code using egraph-equal? --- src/core/preprocess.rkt | 41 ++++++++++++++++------------------------- 1 file changed, 16 insertions(+), 25 deletions(-) diff --git a/src/core/preprocess.rkt b/src/core/preprocess.rkt index fb382313f..bc34512dd 100644 --- a/src/core/preprocess.rkt +++ b/src/core/preprocess.rkt @@ -87,40 +87,31 @@ (define swap-identities (make-swap-identities spec ctx)) (define identities (append even-identities odd-identities swap-identities)) - (define specs - (for/list ([ident (in-list identities)]) - (match ident - [(list 'even _ spec) spec] - [(list 'odd _ spec) spec] - [(list 'swap _ spec) spec]))) - ;; make egg runner (define rules (*simplify-rules*)) - (define batch (progs->batch specs)) + (define batch (progs->batch (map third identities))) (define runner (make-egraph batch (batch-roots batch) - (map (lambda (_) (context-repr ctx)) specs) + (map (const (context-repr ctx)) identities) `((,rules . ((node . ,(*node-limit*))))))) - ;; run egg to check for identities - (define expr-pairs (map (curry cons spec) specs)) - (define equal?-lst - (for/list ([(start end) (in-dict expr-pairs)]) - (egraph-equal? runner start end))) - ;; collect equalities - (define abs-instrs '()) - (define negabs-instrs '()) - (define swaps '()) - (for ([ident (in-list identities)] - [expr-equal? (in-list equal?-lst)] - #:when expr-equal?) - (match ident - [(list 'even var _) (set! abs-instrs (cons (list 'abs var) abs-instrs))] - [(list 'odd var _) (set! negabs-instrs (cons (list 'negabs var) negabs-instrs))] - [(list 'swap pair _) (set! swaps (cons pair swaps))])) + (define abs-instrs + (for/list ([ident (in-list even-identities)] + #:when (egraph-equal? runner spec (third ident))) + (list 'abs (second ident)))) + + (define negabs-instrs + (for/list ([ident (in-list odd-identities)] + #:when (egraph-equal? runner spec (third ident))) + (list 'negabs (second ident)))) + + (define swaps + (for/list ([ident (in-list swap-identities)] + #:when (egraph-equal? runner spec (third ident))) + (second ident))) (define components (connected-components (context-vars ctx) swaps)) (define sort-instrs From 2ac7a898bd3a31972e701557956a6d55f5cfa980 Mon Sep 17 00:00:00 2001 From: Pavel Panchekha Date: Tue, 21 Jan 2025 15:54:54 -0700 Subject: [PATCH 04/13] Unify identity names, parameters, and instructions in preprocessing --- src/core/preprocess.rkt | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/core/preprocess.rkt b/src/core/preprocess.rkt index bc34512dd..e97e8f2cd 100644 --- a/src/core/preprocess.rkt +++ b/src/core/preprocess.rkt @@ -33,7 +33,7 @@ (for/list ([var (in-list (context-vars ctx))] [repr (in-list (context-var-reprs ctx))] #:when (has-fabs-neg-impls? repr)) - (list 'even var (replace-expression spec var `(neg ,var))))) + (cons `(abs ,var) (replace-expression spec var `(neg ,var))))) ;; The odd identities: f(x) = -f(-x) ;; Requires `neg` and `fabs` operator implementations. @@ -41,14 +41,14 @@ (for/list ([var (in-list (context-vars ctx))] [repr (in-list (context-var-reprs ctx))] #:when (and (has-fabs-neg-impls? repr) (has-copysign-impl? repr))) - (list 'odd var (replace-expression `(neg ,spec) var `(neg ,var))))) + (cons `(negabs ,var) (replace-expression `(neg ,spec) var `(neg ,var))))) ;; Swap identities: f(a, b) = f(b, a) (define (make-swap-identities spec ctx) (define pairs (combinations (context-vars ctx) 2)) (for/list ([pair (in-list pairs)]) (match-define (list a b) pair) - (list 'swap pair (replace-vars `((,a . ,b) (,b . ,a)) spec)))) + (cons `(swap ,a ,b) (replace-vars `((,a . ,b) (,b . ,a)) spec)))) ;; Initial simplify (define (initial-simplify expr ctx) @@ -90,29 +90,29 @@ ;; make egg runner (define rules (*simplify-rules*)) - (define batch (progs->batch (map third identities))) + (define batch (progs->batch (cons spec (map cdr identities)))) (define runner (make-egraph batch (batch-roots batch) - (map (const (context-repr ctx)) identities) + (make-list (vector-length (batch-roots batch)) (context-repr ctx)) `((,rules . ((node . ,(*node-limit*))))))) ;; collect equalities (define abs-instrs - (for/list ([ident (in-list even-identities)] - #:when (egraph-equal? runner spec (third ident))) - (list 'abs (second ident)))) + (for/list ([(ident spec*) (in-dict even-identities)] + #:when (egraph-equal? runner spec spec*)) + ident)) (define negabs-instrs - (for/list ([ident (in-list odd-identities)] - #:when (egraph-equal? runner spec (third ident))) - (list 'negabs (second ident)))) + (for/list ([(ident spec*) (in-dict odd-identities)] + #:when (egraph-equal? runner spec spec*)) + ident)) (define swaps - (for/list ([ident (in-list swap-identities)] - #:when (egraph-equal? runner spec (third ident))) - (second ident))) - + (for/list ([(ident spec*) (in-dict swap-identities)] + #:when (egraph-equal? runner spec spec*)) + (match-define (list 'swap a b) ident) + (list a b))) (define components (connected-components (context-vars ctx) swaps)) (define sort-instrs (for/list ([component (in-list components)] From 0045aa33025be9cb30a3299969c84e1228d90ec7 Mon Sep 17 00:00:00 2001 From: Pavel Panchekha Date: Tue, 21 Jan 2025 15:58:58 -0700 Subject: [PATCH 05/13] fmt --- src/core/egg-herbie.rkt | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/core/egg-herbie.rkt b/src/core/egg-herbie.rkt index e7a99e333..57453b086 100644 --- a/src/core/egg-herbie.rkt +++ b/src/core/egg-herbie.rkt @@ -1275,10 +1275,7 @@ (define egg-graph (egg-runner-egg-graph runner)) (unless (egraph-expr-equal? egg-graph start end ctx) - (error 'egraph-prove - "cannot prove ~a is equal to ~a; not equal" - start - end)) + (error 'egraph-prove "cannot prove ~a is equal to ~a; not equal" start end)) (define proof (egraph-get-proof egg-graph start end ctx)) (when (null? proof) (error 'egraph-prove "proof extraction failed between`~a` and `~a`" start end)) @@ -1315,10 +1312,10 @@ (define reprs (egg-runner-reprs runner)) (when (flag-set? 'dump 'egg) (regraph-dump regraph root-ids reprs)) - + (define extract-id ((typed-egg-batch-extractor batch) regraph)) (define finalize-batch (last extract-id)) - + ; (Listof (Listof batchref)) (define out (for/list ([id (in-list root-ids)] From b4b924fe32c61370b922d253fa46c6e33a9eae3e Mon Sep 17 00:00:00 2001 From: Pavel Panchekha Date: Tue, 21 Jan 2025 16:10:22 -0700 Subject: [PATCH 06/13] Remove the herbie->egg-dict field in egraph-data --- src/core/egg-herbie.rkt | 50 ++++++++++++++++------------------------- 1 file changed, 19 insertions(+), 31 deletions(-) diff --git a/src/core/egg-herbie.rkt b/src/core/egg-herbie.rkt index 57453b086..9c29f18f9 100644 --- a/src/core/egg-herbie.rkt +++ b/src/core/egg-herbie.rkt @@ -55,13 +55,12 @@ ;; Wrapper around Rust-allocated egg runner (struct egraph-data (egraph-pointer ; FFI pointer to runner - herbie->egg-dict ; map from symbols to canonicalized names egg->herbie-dict ; inverse map id->spec)) ; map from e-class id to an approx-spec or #f ; Makes a new egraph that is managed by Racket's GC (define (make-egraph-data) - (egraph-data (egraph_create) (make-hash) (make-hash) (make-hash))) + (egraph-data (egraph_create) (make-hash) (make-hash))) ; Creates a new runner using an existing egraph. ; Useful for multi-phased rule application @@ -72,17 +71,14 @@ ; Adds expressions returning the root ids (define (egraph-add-exprs egg-data batch roots ctx) - (match-define (egraph-data ptr herbie->egg-dict egg->herbie-dict id->spec) egg-data) + (match-define (egraph-data ptr egg->herbie-dict id->spec) egg-data) ; lookups the egg name of a variable (define (normalize-var x) - (hash-ref! herbie->egg-dict - x - (lambda () - (define id (hash-count herbie->egg-dict)) - (define replacement (string->symbol (format "$h~a" id))) - (hash-set! egg->herbie-dict replacement (cons x (context-lookup ctx x))) - replacement))) + (define idx (index-of (context-vars ctx) x)) + (define replacement (string->symbol (format "$var~a" idx))) + (hash-set! egg->herbie-dict replacement (cons x (context-lookup ctx x))) + replacement) ; normalizes an approx spec (define (normalize-spec expr) @@ -266,19 +262,15 @@ ;; Result is the expression. (define (expr->egg-expr expr egg-data ctx) (define egg->herbie-dict (egraph-data-egg->herbie-dict egg-data)) - (define herbie->egg-dict (egraph-data-herbie->egg-dict egg-data)) (let loop ([expr expr]) (match expr [(? number?) expr] [(? literal?) (literal-value expr)] - [(? symbol?) - (hash-ref! herbie->egg-dict - expr - (lambda () - (define id (hash-count herbie->egg-dict)) - (define replacement (string->symbol (format "$h~a" id))) - (hash-set! egg->herbie-dict replacement (cons expr (context-lookup ctx expr))) - replacement))] + [(? symbol? x) + (define idx (index-of (context-vars ctx) x)) + (define replacement (string->symbol (format "$var~a" idx))) + (hash-set! egg->herbie-dict replacement (cons x (context-lookup ctx x))) + replacement] [(approx spec impl) (list '$approx (loop spec) (loop impl))] [(hole precision spec) (loop spec)] [(list op args ...) (cons op (map loop args))]))) @@ -333,22 +325,18 @@ (egg-parsed->expr (flatten-let egg-expr) egg->herbie type)) (module+ test - (define repr (get-representation 'binary64)) - (*context* (make-debug-context '())) - (*context* (context-extend (*context*) 'x repr)) - (*context* (context-extend (*context*) 'y repr)) - (*context* (context-extend (*context*) 'z repr)) + (*context* (make-debug-context '(x y z))) (define test-exprs - (list (cons '(+.f64 y x) '(+.f64 $h0 $h1)) - (cons '(+.f64 x y) '(+.f64 $h1 $h0)) - (cons '(-.f64 #s(literal 2 binary64) (+.f64 x y)) '(-.f64 2 (+.f64 $h1 $h0))) + (list (cons '(+.f64 y x) '(+.f64 $var1 $var0)) + (cons '(+.f64 x y) '(+.f64 $var0 $var1)) + (cons '(-.f64 #s(literal 2 binary64) (+.f64 x y)) '(-.f64 2 (+.f64 $var0 $var1))) (cons '(-.f64 z (+.f64 (+.f64 y #s(literal 2 binary64)) x)) - '(-.f64 $h2 (+.f64 (+.f64 $h0 2) $h1))) - (cons '(*.f64 x y) '(*.f64 $h1 $h0)) - (cons '(+.f64 (*.f64 x y) #s(literal 2 binary64)) '(+.f64 (*.f64 $h1 $h0) 2)) + '(-.f64 $var2 (+.f64 (+.f64 $var1 2) $var0))) + (cons '(*.f64 x y) '(*.f64 $var0 $var1)) + (cons '(+.f64 (*.f64 x y) #s(literal 2 binary64)) '(+.f64 (*.f64 $var0 $var1) 2)) (cons '(cos.f32 (PI.f32)) '(cos.f32 (PI.f32))) - (cons '(if (TRUE) x y) '(if (TRUE) $h1 $h0)))) + (cons '(if (TRUE) x y) '(if (TRUE) $var0 $var1)))) (let ([egg-graph (make-egraph-data)]) (for ([(in expected-out) (in-dict test-exprs)]) From 4dc05965b818df041d261cf5bff13f4da56aa355 Mon Sep 17 00:00:00 2001 From: Pavel Panchekha Date: Tue, 21 Jan 2025 16:34:43 -0700 Subject: [PATCH 07/13] Get rid of egg->herbie-dict, use normal context instead --- src/core/batch.rkt | 86 ------------------- src/core/egg-herbie.rkt | 184 ++++++++++++++++++++++++++++------------ 2 files changed, 131 insertions(+), 139 deletions(-) diff --git a/src/core/batch.rkt b/src/core/batch.rkt index eeb881902..d78af6176 100644 --- a/src/core/batch.rkt +++ b/src/core/batch.rkt @@ -13,7 +13,6 @@ batch-ref ; Batch -> Idx -> Expr deref ; Batchref -> Expr batch-replace ; Batch -> (Expr -> Expr) -> Batch - egg-nodes->batch ; Nodes -> Spec-maps -> Batch -> (Listof Batchref) debatchref ; Batchref -> Expr batch-remove-zombie ; Batch -> ?(Vectorof Root) -> Batch mutable-batch-munge! ; Mutable-batch -> Expr -> Root @@ -177,91 +176,6 @@ [n (in-naturals)]) (cons node n)))) -(define (egg-nodes->batch egg-nodes id->spec input-batch rename-dict) - (define out (batch->mutable-batch input-batch)) - ; This fuction here is only because of cycles in loads:( Can not be imported from egg-herbie.rkt - (define (egg-parsed->expr expr rename-dict type) - (let loop ([expr expr] - [type type]) - (match expr - [(? number?) - (if (representation? type) - (literal expr (representation-name type)) - expr)] - [(? symbol?) - (if (hash-has-key? rename-dict expr) - (car (hash-ref rename-dict expr)) - (list expr))] - [(list '$approx spec impl) - (define spec-type - (if (representation? type) - (representation-type type) - type)) - (approx (loop spec spec-type) (loop impl type))] - [(list 'if cond ift iff) - (if (representation? type) - (list 'if (loop cond (get-representation 'bool)) (loop ift type) (loop iff type)) - (list 'if (loop cond 'bool) (loop ift type) (loop iff type)))] - [(list (? impl-exists? impl) args ...) (cons impl (map loop args (impl-info impl 'itype)))] - [(list op args ...) (cons op (map loop args (operator-info op 'itype)))]))) - - (define (eggref id) - (cdr (vector-ref egg-nodes id))) - - (define (add-enode enode type) - (define idx - (let loop ([enode enode] - [type type]) - (define enode* - (match enode - [(? number?) - (if (representation? type) - (literal enode (representation-name type)) - enode)] - [(? symbol?) - (if (hash-has-key? rename-dict enode) - (car (hash-ref rename-dict enode)) - enode)] - [(list '$approx spec (app eggref impl)) - (define spec* (vector-ref id->spec spec)) - (unless spec* - (error 'regraph-extract-variants "no initial approx node in eclass")) - (define spec-type - (if (representation? type) - (representation-type type) - type)) - (define final-spec (egg-parsed->expr spec* rename-dict spec-type)) - (define final-spec-idx (mutable-batch-munge! out final-spec)) - (approx final-spec-idx (loop impl type))] - [(list 'if (app eggref cond) (app eggref ift) (app eggref iff)) - (if (representation? type) - (list 'if (loop cond (get-representation 'bool)) (loop ift type) (loop iff type)) - (list 'if (loop cond 'bool) (loop ift type) (loop iff type)))] - [(list (? impl-exists? impl) (app eggref args) ...) - (define args* - (for/list ([arg (in-list args)] - [type (in-list (impl-info impl 'itype))]) - (loop arg type))) - (cons impl args*)] - [(list (? operator-exists? op) (app eggref args) ...) - (define args* - (for/list ([arg (in-list args)] - [type (in-list (operator-info op 'itype))]) - (loop arg type))) - (cons op args*)])) - (mutable-batch-push! out enode*))) - (batchref input-batch idx)) - - ; same as add-enode but works with index as an input instead of enode - (define (add-id id type) - (add-enode (eggref id) type)) - - ; Commit changes to the input-batch - (define (finalize-batch) - (batch-copy-mutable-nodes! input-batch out)) - - (values add-id add-enode finalize-batch)) - ; Tests for progs->batch and batch->progs (module+ test (require rackunit) diff --git a/src/core/egg-herbie.rkt b/src/core/egg-herbie.rkt index 9c29f18f9..c1ef9d37b 100644 --- a/src/core/egg-herbie.rkt +++ b/src/core/egg-herbie.rkt @@ -55,12 +55,11 @@ ;; Wrapper around Rust-allocated egg runner (struct egraph-data (egraph-pointer ; FFI pointer to runner - egg->herbie-dict ; inverse map id->spec)) ; map from e-class id to an approx-spec or #f ; Makes a new egraph that is managed by Racket's GC (define (make-egraph-data) - (egraph-data (egraph_create) (make-hash) (make-hash))) + (egraph-data (egraph_create) (make-hash))) ; Creates a new runner using an existing egraph. ; Useful for multi-phased rule application @@ -71,20 +70,13 @@ ; Adds expressions returning the root ids (define (egraph-add-exprs egg-data batch roots ctx) - (match-define (egraph-data ptr egg->herbie-dict id->spec) egg-data) - - ; lookups the egg name of a variable - (define (normalize-var x) - (define idx (index-of (context-vars ctx) x)) - (define replacement (string->symbol (format "$var~a" idx))) - (hash-set! egg->herbie-dict replacement (cons x (context-lookup ctx x))) - replacement) + (match-define (egraph-data ptr id->spec) egg-data) ; normalizes an approx spec (define (normalize-spec expr) (match expr [(? number?) expr] - [(? symbol?) (normalize-var expr)] + [(? symbol?) (var->egg-var expr ctx)] [(list op args ...) (cons op (map normalize-spec args))])) ; pre-allocated id vectors for all the common cases @@ -135,7 +127,7 @@ (match node [(literal v _) (insert-node! v root?)] [(? number?) (insert-node! node root?)] - [(? symbol?) (insert-node! (normalize-var node) root?)] + [(? symbol?) (insert-node! (var->egg-var node ctx) root?)] [(hole prec spec) (remap spec)] ; "hole" terms currently disappear [(approx spec impl) (hash-ref! id->spec @@ -171,13 +163,13 @@ (define (egraph-get-simplest egraph-data node-id iteration ctx) (define expr (egraph_get_simplest (egraph-data-egraph-pointer egraph-data) node-id iteration)) - (egg-expr->expr expr egraph-data (context-repr ctx))) + (egg-expr->expr expr egraph-data ctx (context-repr ctx))) (define (egraph-get-variants egraph-data node-id orig-expr ctx) (define egg-expr (expr->egg-expr orig-expr egraph-data ctx)) (define exprs (egraph_get_variants (egraph-data-egraph-pointer egraph-data) node-id egg-expr)) (for/list ([expr (in-list exprs)]) - (egg-expr->expr expr egraph-data (context-repr ctx)))) + (egg-expr->expr expr egraph-data ctx (context-repr ctx)))) (define (egraph-is-unsound-detected egraph-data) (egraph_is_unsound_detected (egraph-data-egraph-pointer egraph-data))) @@ -206,12 +198,11 @@ ;; where each enode is either a symbol, number, or list (define (egraph-get-eclass egraph-data id) (define ptr (egraph-data-egraph-pointer egraph-data)) - (define egg->herbie (egraph-data-egg->herbie-dict egraph-data)) (define eclass (egraph_get_eclass ptr id)) ; need to fix up any constant operators (for ([enode (in-vector eclass)] [i (in-naturals)]) - (when (and (symbol? enode) (not (hash-has-key? egg->herbie enode))) + (when (and (symbol? enode) (not (equal? (substring (symbol->string enode) 0 4) "$var"))) (vector-set! eclass i (cons enode empty-u32vec)))) eclass) @@ -232,7 +223,7 @@ [(<= (string-length str) (*proof-max-string-length*)) (define converted (for/list ([expr (in-port read (open-input-string str))]) - (egg-expr->expr expr egraph-data (context-repr ctx)))) + (egg-expr->expr expr egraph-data ctx (context-repr ctx)))) (define expanded (expand-proof converted (box (*proof-max-length*)))) (if (member #f expanded) #f expanded)] [else #f])) @@ -257,20 +248,23 @@ [(approx spec impl) (list '$approx (loop spec) (loop impl))] [(list op args ...) (cons op (map loop args))]))) +(define (var->egg-var var ctx) + (define idx (index-of (context-vars ctx) var)) + (string->symbol (format "$var~a" idx))) + +(define (egg-var->var egg-var ctx) + (define idx (string->number (substring (symbol->string egg-var) 4))) + (list-ref (context-vars ctx) idx)) + ;; Translates a Herbie expression into an expression usable by egg. ;; Updates translation dictionary upon encountering variables. ;; Result is the expression. (define (expr->egg-expr expr egg-data ctx) - (define egg->herbie-dict (egraph-data-egg->herbie-dict egg-data)) (let loop ([expr expr]) (match expr [(? number?) expr] [(? literal?) (literal-value expr)] - [(? symbol? x) - (define idx (index-of (context-vars ctx) x)) - (define replacement (string->symbol (format "$var~a" idx))) - (hash-set! egg->herbie-dict replacement (cons x (context-lookup ctx x))) - replacement] + [(? symbol? x) (var->egg-var x ctx)] [(approx spec impl) (list '$approx (loop spec) (loop impl))] [(hole precision spec) (loop spec)] [(list op args ...) (cons op (map loop args))]))) @@ -291,7 +285,7 @@ ;; TODO: typing information is confusing since proofs mean ;; we may process mixed spec/impl expressions; ;; only need `type` to correctly interpret numbers -(define (egg-parsed->expr expr rename-dict type) +(define (egg-parsed->expr expr ctx type) (let loop ([expr expr] [type type]) (match expr @@ -299,10 +293,10 @@ (if (representation? type) (literal expr (representation-name type)) expr)] + [(? symbol? (regexp #rx"^\\$var")) + (egg-var->var expr)] [(? symbol?) - (if (hash-has-key? rename-dict expr) - (car (hash-ref rename-dict expr)) ; variable (extract uncanonical name) - (list expr))] ; constant function + (list expr)] ; constant function [(list '$approx spec impl) ; approx (define spec-type (if (representation? type) @@ -320,9 +314,8 @@ [(list op args ...) (cons op (map loop args (operator-info op 'itype)))]))) ;; Parses a string from egg into a single S-expr. -(define (egg-expr->expr egg-expr egraph-data type) - (define egg->herbie (egraph-data-egg->herbie-dict egraph-data)) - (egg-parsed->expr (flatten-let egg-expr) egg->herbie type)) +(define (egg-expr->expr egg-expr egraph-data ctx type) + (egg-parsed->expr (flatten-let egg-expr) ctx type)) (module+ test (*context* (make-debug-context '(x y z))) @@ -341,7 +334,7 @@ (let ([egg-graph (make-egraph-data)]) (for ([(in expected-out) (in-dict test-exprs)]) (define out (expr->egg-expr in egg-graph (*context*))) - (define computed-in (egg-expr->expr out egg-graph (context-repr (*context*)))) + (define computed-in (egg-expr->expr out egg-graph (*context*) (context-repr (*context*)))) (check-equal? out expected-out) (check-equal? computed-in in))) @@ -369,7 +362,7 @@ (let ([egg-graph (make-egraph-data)]) (for ([expr extended-expr-list]) (define egg-expr (expr->egg-expr expr egg-graph (*context*))) - (check-equal? (egg-expr->expr egg-expr egg-graph (context-repr (*context*))) expr)))) + (check-equal? (egg-expr->expr egg-expr egg-graph (*context*) (context-repr (*context*))) expr)))) ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; Proofs @@ -526,8 +519,8 @@ ;; - specs: vector-map from e-class to an approx spec or #f ;; - parents: vector-map from e-class to its parent e-classes (as a vector) ;; - canon: map from (Rust) e-class, type to (Racket) e-class -;; - egg->herbie: data to translate egg IR to herbie IR -(struct regraph (eclasses types leaf? constants specs parents canon egg->herbie)) +;; - ctx: the standard variable context +(struct regraph (eclasses types leaf? constants specs parents canon ctx)) ;; Returns all representatations (and their types) in the current platform. (define (all-reprs/types [pform (*active-platform*)]) @@ -537,16 +530,17 @@ ;; Returns the type(s) of an enode so it can be placed in the proper e-class. ;; Typing rules: ;; - numbers: every real representation (or real type) -;; - variables: lookup in the `egg->herbie` renaming dictionary +;; - variables: lookup in the context ;; - `if`: type is every representation (or type) [can prune incorrect ones] ;; - `approx`: every real representation [can prune incorrect ones] ;; - ops/impls: its output type/representation ;; NOTE: we can constrain "every" type by using the platform. -(define (enode-type enode egg->herbie) +(define (enode-type enode ctx) (match enode [(? number?) (cons 'real (platform-reprs (*active-platform*)))] ; number [(? symbol?) ; variable - (match-define (cons _ repr) (hash-ref egg->herbie enode)) + (define var (egg-var->var enode ctx)) + (define repr (context-lookup ctx var)) (list repr (representation-type repr))] [(cons f _) ; application (cond @@ -596,7 +590,7 @@ ;; Splits untyped eclasses into typed eclasses. ;; Nodes are duplicated across their possible types. -(define (split-untyped-eclasses egraph-data egg->herbie) +(define (split-untyped-eclasses egraph-data ctx) (define eclass-ids (egraph-eclasses egraph-data)) (define max-id (for/fold ([current-max 0]) ([egg-id (in-u32vector eclass-ids)]) @@ -639,7 +633,7 @@ (for ([enode (in-vector enodes)]) ; get all possible types for the enode ; lookup its correct eclass and add the rebuilt node - (define types (enode-type enode egg->herbie)) + (define types (enode-type enode ctx)) (for ([type (in-list types)]) (define id (idx+type->id idx type)) (define enode* (rebuild-enode enode type lookup-id)) @@ -766,10 +760,10 @@ ;; Splits untyped eclasses into typed eclasses, ;; keeping only the subset of enodes that are well-typed. -(define (make-typed-eclasses egraph-data egg->herbie) +(define (make-typed-eclasses egraph-data ctx) ;; Step 1: split Rust-eclasses by type (define-values (id->eclass id->parents id->leaf? eclass-ids egg-id->idx type->idx) - (split-untyped-eclasses egraph-data egg->herbie)) + (split-untyped-eclasses egraph-data ctx)) ;; Step 2: keep well-typed e-nodes ;; An e-class is well-typed if it has one well-typed node @@ -813,12 +807,11 @@ ;; Constructs a Racket egraph from an S-expr representation of ;; an egraph and data to translate egg IR to herbie IR. -(define (make-regraph egraph-data) - (define egg->herbie (egraph-data-egg->herbie-dict egraph-data)) +(define (make-regraph egraph-data ctx) (define id->spec (egraph-data-id->spec egraph-data)) ;; split the e-classes by type - (define-values (eclasses types canon) (make-typed-eclasses egraph-data egg->herbie)) + (define-values (eclasses types canon) (make-typed-eclasses egraph-data ctx)) (define n (vector-length eclasses)) ;; analyze each eclass @@ -832,7 +825,7 @@ (vector-set! specs id* spec)) ; construct the `regraph` instance - (regraph eclasses types leaf? constants specs parents canon egg->herbie)) + (regraph eclasses types leaf? constants specs parents canon ctx)) (define (regraph-nodes->json regraph) (define cost (platform-node-cost-proc (*active-platform*))) @@ -1014,12 +1007,98 @@ (define id->spec (regraph-specs regraph)) - (define egg->herbie (regraph-egg->herbie regraph)) + (define ctx (regraph-ctx regraph)) (define-values (add-id add-enode finalize-batch) - (egg-nodes->batch costs id->spec batch-extract-to egg->herbie)) + (egg-nodes->batch costs id->spec batch-extract-to ctx)) ;; These functions provide a setup to extract nodes into batch-extract-to from nodes (list add-id add-enode finalize-batch)) +(define (egg-nodes->batch egg-nodes id->spec input-batch ctx) + (define out (batch->mutable-batch input-batch)) + ; This fuction here is only because of cycles in loads:( Can not be imported from egg-herbie.rkt + (define (egg-parsed->expr expr type) + (let loop ([expr expr] + [type type]) + (match expr + [(? number?) + (if (representation? type) + (literal expr (representation-name type)) + expr)] + [(? symbol?) + (if (equal? (substring (symbol->string expr) 0 4) "$var") + (egg-var->var expr ctx) + (list expr))] + [(list '$approx spec impl) + (define spec-type + (if (representation? type) + (representation-type type) + type)) + (approx (loop spec spec-type) (loop impl type))] + [(list 'if cond ift iff) + (if (representation? type) + (list 'if (loop cond (get-representation 'bool)) (loop ift type) (loop iff type)) + (list 'if (loop cond 'bool) (loop ift type) (loop iff type)))] + [(list (? impl-exists? impl) args ...) (cons impl (map loop args (impl-info impl 'itype)))] + [(list op args ...) (cons op (map loop args (operator-info op 'itype)))]))) + + (define (eggref id) + (cdr (vector-ref egg-nodes id))) + + (define (add-enode enode type) + (define idx + (let loop ([enode enode] + [type type]) + (define enode* + (match enode + [(? number?) + (if (representation? type) + (literal enode (representation-name type)) + enode)] + [(? symbol?) + (if (equal? (substring (symbol->string enode) 0 4) "$var") + (egg-var->var enode ctx) + enode)] + [(list '$approx spec (app eggref impl)) + (define spec* (vector-ref id->spec spec)) + (unless spec* + (error 'regraph-extract-variants "no initial approx node in eclass")) + (define spec-type + (if (representation? type) + (representation-type type) + type)) + (define final-spec (egg-parsed->expr spec* spec-type)) + (define final-spec-idx (mutable-batch-munge! out final-spec)) + (approx final-spec-idx (loop impl type))] + [(list 'if (app eggref cond) (app eggref ift) (app eggref iff)) + (if (representation? type) + (list 'if (loop cond (get-representation 'bool)) (loop ift type) (loop iff type)) + (list 'if (loop cond 'bool) (loop ift type) (loop iff type)))] + [(list (? impl-exists? impl) (app eggref args) ...) + (define args* + (for/list ([arg (in-list args)] + [type (in-list (impl-info impl 'itype))]) + (loop arg type))) + (cons impl args*)] + [(list (? operator-exists? op) (app eggref args) ...) + (define args* + (for/list ([arg (in-list args)] + [type (in-list (operator-info op 'itype))]) + (loop arg type))) + (cons op args*)])) + (mutable-batch-push! out enode*))) + (batchref input-batch idx)) + + ; same as add-enode but works with index as an input instead of enode + (define (add-id id type) + (add-enode (eggref id) type)) + + ; Commit changes to the input-batch + (define (finalize-batch) + (batch-copy-mutable-nodes! input-batch out)) + + (values add-id add-enode finalize-batch)) + + ;; Is fractional with odd denominator. (define (fraction-with-odd-denominator? frac) (and (rational? frac) (let ([denom (denominator frac)]) (and (> denom 1) (odd? denom))))) @@ -1064,13 +1143,13 @@ (define (platform-egg-cost-proc regraph cache node type rec) (cond [(representation? type) - (define egg->herbie (regraph-egg->herbie regraph)) + (define ctx (regraph-ctx regraph)) (define node-cost-proc (platform-node-cost-proc (*active-platform*))) (match node ; numbers (repr is unused) [(? number? n) ((node-cost-proc (literal n type) type))] - [(? symbol?) ; variables (`egg->herbie` has the repr) - (define repr (cdr (hash-ref egg->herbie node))) + [(? symbol?) ; variables + (define repr (context-lookup ctx (egg-var->var node ctx))) ((node-cost-proc node repr))] ; approx node [(list '$approx _ impl) (rec impl)] @@ -1085,7 +1164,6 @@ ;; Extracts the best expression according to the extractor. ;; Result is a single element list. (define (regraph-extract-best regraph extract id type) - (define egg->herbie (regraph-egg->herbie regraph)) (define canon (regraph-canon regraph)) ; Extract functions to extract exprs from egraph (match-define (list extract-id _ _) extract) @@ -1274,7 +1352,7 @@ (define root-ids (egg-runner-new-roots runner)) (define egg-graph (egg-runner-egg-graph runner)) - (define regraph (make-regraph egg-graph)) + (define regraph (make-regraph egg-graph ctx)) (define reprs (egg-runner-reprs runner)) (when (flag-set? 'dump 'egg) (regraph-dump regraph root-ids reprs)) @@ -1296,7 +1374,7 @@ (define root-ids (egg-runner-new-roots runner)) (define egg-graph (egg-runner-egg-graph runner)) - (define regraph (make-regraph egg-graph)) + (define regraph (make-regraph egg-graph ctx)) (define reprs (egg-runner-reprs runner)) (when (flag-set? 'dump 'egg) (regraph-dump regraph root-ids reprs)) From 4333e829421e85a670b4b4b3c92edb6e136d50b7 Mon Sep 17 00:00:00 2001 From: Pavel Panchekha Date: Tue, 21 Jan 2025 16:39:23 -0700 Subject: [PATCH 08/13] Build id->spec hash as a separate phase --- src/core/egg-herbie.rkt | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/core/egg-herbie.rkt b/src/core/egg-herbie.rkt index c1ef9d37b..7f46d624e 100644 --- a/src/core/egg-herbie.rkt +++ b/src/core/egg-herbie.rkt @@ -130,17 +130,21 @@ [(? symbol?) (insert-node! (var->egg-var node ctx) root?)] [(hole prec spec) (remap spec)] ; "hole" terms currently disappear [(approx spec impl) - (hash-ref! id->spec - (remap spec) - (lambda () - (define spec* (normalize-spec (batch-ref insert-batch spec))) - (define type (representation-type (repr-of-node insert-batch impl ctx))) - (cons spec* type))) ; preserved spec and type for extraction (insert-node! (list '$approx (remap spec) (remap impl)) root?)] [(list op (app remap args) ...) (insert-node! (cons op args) root?)])) (vector-set! mappings n idx)) + (for ([node (in-vector (batch-nodes insert-batch))] + #:when (approx? node)) + (match-define (approx spec impl) node) + (hash-ref! id->spec + (remap spec) + (lambda () + (define spec* (normalize-spec (batch-ref insert-batch spec))) + (define type (representation-type (repr-of-node insert-batch impl ctx))) + (cons spec* type)))) + (for/list ([root (in-vector (batch-roots insert-batch))]) (remap root))) From d195961f670af449a6fb272edc5f9f379305c063 Mon Sep 17 00:00:00 2001 From: Pavel Panchekha Date: Tue, 21 Jan 2025 16:42:26 -0700 Subject: [PATCH 09/13] Simplify parameters to expr<->egg-expr functions --- src/core/egg-herbie.rkt | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/core/egg-herbie.rkt b/src/core/egg-herbie.rkt index 7f46d624e..66943ad76 100644 --- a/src/core/egg-herbie.rkt +++ b/src/core/egg-herbie.rkt @@ -167,13 +167,13 @@ (define (egraph-get-simplest egraph-data node-id iteration ctx) (define expr (egraph_get_simplest (egraph-data-egraph-pointer egraph-data) node-id iteration)) - (egg-expr->expr expr egraph-data ctx (context-repr ctx))) + (egg-expr->expr expr ctx)) (define (egraph-get-variants egraph-data node-id orig-expr ctx) - (define egg-expr (expr->egg-expr orig-expr egraph-data ctx)) + (define egg-expr (expr->egg-expr orig-expr ctx)) (define exprs (egraph_get_variants (egraph-data-egraph-pointer egraph-data) node-id egg-expr)) (for/list ([expr (in-list exprs)]) - (egg-expr->expr expr egraph-data ctx (context-repr ctx)))) + (egg-expr->expr expr ctx))) (define (egraph-is-unsound-detected egraph-data) (egraph_is_unsound_detected (egraph-data-egraph-pointer egraph-data))) @@ -220,14 +220,14 @@ ;; returns a flattened list of terms or #f if it failed to expand the proof due to budget (define (egraph-get-proof egraph-data expr goal ctx) - (define egg-expr (expr->egg-expr expr egraph-data ctx)) - (define egg-goal (expr->egg-expr goal egraph-data ctx)) + (define egg-expr (expr->egg-expr expr ctx)) + (define egg-goal (expr->egg-expr goal ctx)) (define str (egraph_get_proof (egraph-data-egraph-pointer egraph-data) egg-expr egg-goal)) (cond [(<= (string-length str) (*proof-max-string-length*)) (define converted (for/list ([expr (in-port read (open-input-string str))]) - (egg-expr->expr expr egraph-data ctx (context-repr ctx)))) + (egg-expr->expr expr ctx))) (define expanded (expand-proof converted (box (*proof-max-length*)))) (if (member #f expanded) #f expanded)] [else #f])) @@ -263,7 +263,7 @@ ;; Translates a Herbie expression into an expression usable by egg. ;; Updates translation dictionary upon encountering variables. ;; Result is the expression. -(define (expr->egg-expr expr egg-data ctx) +(define (expr->egg-expr expr ctx) (let loop ([expr expr]) (match expr [(? number?) expr] @@ -318,8 +318,8 @@ [(list op args ...) (cons op (map loop args (operator-info op 'itype)))]))) ;; Parses a string from egg into a single S-expr. -(define (egg-expr->expr egg-expr egraph-data ctx type) - (egg-parsed->expr (flatten-let egg-expr) ctx type)) +(define (egg-expr->expr egg-expr ctx) + (egg-parsed->expr (flatten-let egg-expr) ctx (context-repr ctx))) (module+ test (*context* (make-debug-context '(x y z))) @@ -337,8 +337,8 @@ (let ([egg-graph (make-egraph-data)]) (for ([(in expected-out) (in-dict test-exprs)]) - (define out (expr->egg-expr in egg-graph (*context*))) - (define computed-in (egg-expr->expr out egg-graph (*context*) (context-repr (*context*)))) + (define out (expr->egg-expr in (*context*))) + (define computed-in (egg-expr->expr out (*context*))) (check-equal? out expected-out) (check-equal? computed-in in))) @@ -365,8 +365,8 @@ (let ([egg-graph (make-egraph-data)]) (for ([expr extended-expr-list]) - (define egg-expr (expr->egg-expr expr egg-graph (*context*))) - (check-equal? (egg-expr->expr egg-expr egg-graph (*context*) (context-repr (*context*))) expr)))) + (define egg-expr (expr->egg-expr expr (*context*))) + (check-equal? (egg-expr->expr egg-expr (*context*)) expr)))) ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; Proofs From 928b7e977d64825dda08ea211b2b34d0deed32d0 Mon Sep 17 00:00:00 2001 From: Pavel Panchekha Date: Tue, 21 Jan 2025 16:49:58 -0700 Subject: [PATCH 10/13] fmt --- src/core/egg-herbie.rkt | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/core/egg-herbie.rkt b/src/core/egg-herbie.rkt index 66943ad76..b63509dc9 100644 --- a/src/core/egg-herbie.rkt +++ b/src/core/egg-herbie.rkt @@ -129,8 +129,7 @@ [(? number?) (insert-node! node root?)] [(? symbol?) (insert-node! (var->egg-var node ctx) root?)] [(hole prec spec) (remap spec)] ; "hole" terms currently disappear - [(approx spec impl) - (insert-node! (list '$approx (remap spec) (remap impl)) root?)] + [(approx spec impl) (insert-node! (list '$approx (remap spec) (remap impl)) root?)] [(list op (app remap args) ...) (insert-node! (cons op args) root?)])) (vector-set! mappings n idx)) @@ -297,10 +296,8 @@ (if (representation? type) (literal expr (representation-name type)) expr)] - [(? symbol? (regexp #rx"^\\$var")) - (egg-var->var expr)] - [(? symbol?) - (list expr)] ; constant function + [(? symbol? (regexp #rx"^\\$var")) (egg-var->var expr)] + [(? symbol?) (list expr)] ; constant function [(list '$approx spec impl) ; approx (define spec-type (if (representation? type) @@ -1102,7 +1099,6 @@ (values add-id add-enode finalize-batch)) - ;; Is fractional with odd denominator. (define (fraction-with-odd-denominator? frac) (and (rational? frac) (let ([denom (denominator frac)]) (and (> denom 1) (odd? denom))))) From 71bd9b477ad4618f971cc57829e5ae7be176ec5d Mon Sep 17 00:00:00 2001 From: Pavel Panchekha Date: Tue, 21 Jan 2025 16:51:42 -0700 Subject: [PATCH 11/13] Fix substring code --- src/core/egg-herbie.rkt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/core/egg-herbie.rkt b/src/core/egg-herbie.rkt index b63509dc9..8dd596db4 100644 --- a/src/core/egg-herbie.rkt +++ b/src/core/egg-herbie.rkt @@ -205,7 +205,7 @@ ; need to fix up any constant operators (for ([enode (in-vector eclass)] [i (in-naturals)]) - (when (and (symbol? enode) (not (equal? (substring (symbol->string enode) 0 4) "$var"))) + (when (and (symbol? enode) (not (string-prefix? (symbol->string expr) "$var"))) (vector-set! eclass i (cons enode empty-u32vec)))) eclass) @@ -1026,7 +1026,7 @@ (literal expr (representation-name type)) expr)] [(? symbol?) - (if (equal? (substring (symbol->string expr) 0 4) "$var") + (if (string-prefix? (symbol->string expr) "$var") (egg-var->var expr ctx) (list expr))] [(list '$approx spec impl) @@ -1056,7 +1056,7 @@ (literal enode (representation-name type)) enode)] [(? symbol?) - (if (equal? (substring (symbol->string enode) 0 4) "$var") + (if (string-prefix? (symbol->string expr) "$var") (egg-var->var enode ctx) enode)] [(list '$approx spec (app eggref impl)) From 71ceba769d02d96b78cdd3647541034c90bfe8dd Mon Sep 17 00:00:00 2001 From: Pavel Panchekha Date: Tue, 21 Jan 2025 16:57:09 -0700 Subject: [PATCH 12/13] Woops, typos / C&P errors --- src/core/egg-herbie.rkt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/core/egg-herbie.rkt b/src/core/egg-herbie.rkt index 8dd596db4..48c4299b3 100644 --- a/src/core/egg-herbie.rkt +++ b/src/core/egg-herbie.rkt @@ -205,7 +205,7 @@ ; need to fix up any constant operators (for ([enode (in-vector eclass)] [i (in-naturals)]) - (when (and (symbol? enode) (not (string-prefix? (symbol->string expr) "$var"))) + (when (and (symbol? enode) (not (string-prefix? (symbol->string enode) "$var"))) (vector-set! eclass i (cons enode empty-u32vec)))) eclass) @@ -1056,7 +1056,7 @@ (literal enode (representation-name type)) enode)] [(? symbol?) - (if (string-prefix? (symbol->string expr) "$var") + (if (string-prefix? (symbol->string enode) "$var") (egg-var->var enode ctx) enode)] [(list '$approx spec (app eggref impl)) From 10c82be6706c7b67c8cc546482ec01adcca8342c Mon Sep 17 00:00:00 2001 From: Pavel Panchekha Date: Tue, 21 Jan 2025 17:05:23 -0700 Subject: [PATCH 13/13] Fix egg-expr->expr --- src/core/egg-herbie.rkt | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/core/egg-herbie.rkt b/src/core/egg-herbie.rkt index 48c4299b3..7265ec8b0 100644 --- a/src/core/egg-herbie.rkt +++ b/src/core/egg-herbie.rkt @@ -296,8 +296,10 @@ (if (representation? type) (literal expr (representation-name type)) expr)] - [(? symbol? (regexp #rx"^\\$var")) (egg-var->var expr)] - [(? symbol?) (list expr)] ; constant function + [(? symbol?) + (if (string-prefix? (symbol->string expr) "$var") + (egg-var->var expr ctx) + (list expr))] [(list '$approx spec impl) ; approx (define spec-type (if (representation? type)