From 8548c1e92edfb0d7293662aad650d9579e0dc77c Mon Sep 17 00:00:00 2001 From: AYadrov Date: Sat, 25 Jan 2025 18:59:10 -0700 Subject: [PATCH] removed lower-approximations, and putting taylor series straight to run-rr --- src/core/egg-herbie.rkt | 3 +-- src/core/patch.rkt | 56 ++++------------------------------------- src/core/programs.rkt | 1 + 3 files changed, 7 insertions(+), 53 deletions(-) diff --git a/src/core/egg-herbie.rkt b/src/core/egg-herbie.rkt index 7265ec8b0..38b3a1eff 100644 --- a/src/core/egg-herbie.rkt +++ b/src/core/egg-herbie.rkt @@ -104,6 +104,7 @@ ; node -> natural ; inserts an expression into the e-graph, returning its e-class id. + (define (insert-node! node root?) (match node [(list op ids ...) (egraph_add_node ptr (symbol->string op) (list->u32vec ids) root?)] @@ -111,7 +112,6 @@ [(? number? n) (egraph_add_node ptr (number->string n) 0-vec root?)])) (define insert-batch (batch-remove-zombie batch roots)) - (define mappings (build-vector (batch-length insert-batch) values)) (define (remap x) (vector-ref mappings x)) @@ -131,7 +131,6 @@ [(hole prec spec) (remap spec)] ; "hole" terms currently disappear [(approx spec impl) (insert-node! (list '$approx (remap spec) (remap impl)) root?)] [(list op (app remap args) ...) (insert-node! (cons op args) root?)])) - (vector-set! mappings n idx)) (for ([node (in-vector (batch-nodes insert-batch))] diff --git a/src/core/patch.rkt b/src/core/patch.rkt index 7ffa314d1..1e9a79e2a 100644 --- a/src/core/patch.rkt +++ b/src/core/patch.rkt @@ -16,52 +16,6 @@ (provide generate-candidates) -;;;;;;;;;;;;;;;;;;;;;;;;;;;; Simplify ;;;;;;;;;;;;;;;;;;;;;;;;;;;; - -(define (lower-approximations approxs global-batch) - (timeline-event! 'simplify) - - (define reprs - (for/list ([approx (in-list approxs)]) - (define prev (car (alt-prevs approx))) - (repr-of (debatchref (alt-expr prev)) (*context*)))) - - ; generate real rules - (define rules (*simplify-rules*)) - (define lowering-rules (platform-lowering-rules)) - - ; egg runner - (define schedule - (if (flag-set? 'generate 'simplify) - ; if simplify enabled, 2-phases for real rewrites and implementation selection - `((,rules . ((node . ,(*node-limit*)))) - (,lowering-rules . ((iteration . 1) (scheduler . simple)))) - ; if disabled, only implementation selection - `((,lowering-rules . ((iteration . 1) (scheduler . simple)))))) - - (define roots - (for/vector ([approx (in-list approxs)]) - (batchref-idx (alt-expr approx)))) - - ; run egg - (define runner (make-egraph global-batch roots reprs schedule)) - (define simplification-options (simplify-batch runner global-batch)) - - ; convert to altns - (define simplified - (reap [sow] - (define global-batch-mutable (batch->mutable-batch global-batch)) ; Create mutable batch - (for ([altn (in-list approxs)] - [outputs (in-list simplification-options)]) - (match-define (cons _ simplified) outputs) - (define prev (car (alt-prevs altn))) - (for ([bref (in-list simplified)]) - (sow (alt bref `(simplify ,runner #f) (list altn) '())))) - (batch-copy-mutable-nodes! global-batch global-batch-mutable))) ; Update global-batch - - (timeline-push! 'count (length approxs) (length simplified)) - simplified) - ;;;;;;;;;;;;;;;;;;;;;;;;;;;; Taylor ;;;;;;;;;;;;;;;;;;;;;;;;;;;; (define transforms-to-try @@ -111,7 +65,7 @@ (timeline-push! 'outputs (map ~a (map (compose debatchref alt-expr) approxs))) (timeline-push! 'count (length altns) (length approxs)) - (lower-approximations approxs global-batch)) + approxs) ;;;;;;;;;;;;;;;;;;;;;;;;;;;; Recursive Rewrite ;;;;;;;;;;;;;;;;;;;;;;;;;;;; @@ -134,7 +88,6 @@ (define roots (list->vector (map (compose batchref-idx alt-expr) altns))) (define reprs (map (curryr repr-of (*context*)) exprs)) (timeline-push! 'inputs (map ~a exprs)) - (define runner (make-egraph global-batch roots reprs schedule)) ; batchrefss is a (listof (listof batchref)) (define batchrefss (egraph-variations runner global-batch)) @@ -161,7 +114,7 @@ ; Starting alternatives (define start-altns (for/list ([expr (in-list exprs)] - [root (batch-roots global-batch)]) + [root (in-vector (batch-roots global-batch))]) (define repr (repr-of expr (*context*))) (alt (batchref global-batch root) (list 'patch expr repr) '() '()))) @@ -170,10 +123,11 @@ (if (flag-set? 'generate 'taylor) (run-taylor exprs start-altns global-batch) '())) + ; Recursive rewrite (define rewritten (if (flag-set? 'generate 'rr) - (run-rr start-altns global-batch) + (run-rr (append start-altns approximations) global-batch) '())) - (remove-duplicates (append approximations rewritten) #:key (λ (x) (batchref-idx (alt-expr x))))) + (remove-duplicates rewritten #:key (λ (x) (batchref-idx (alt-expr x))))) diff --git a/src/core/programs.rkt b/src/core/programs.rkt index b3e58e8a1..a870c3cad 100644 --- a/src/core/programs.rkt +++ b/src/core/programs.rkt @@ -30,6 +30,7 @@ [(literal val precision) (get-representation precision)] [(? variable?) (context-lookup ctx expr)] [(approx _ impl) (repr-of impl ctx)] + [(hole precision spec) (get-representation precision)] [(list 'if cond ift iff) (repr-of ift ctx)] [(list op args ...) (impl-info op 'otype)]))