Skip to content

Commit

Permalink
Merge pull request #1143 from herbie-fp/simplify-rewrite-in-one-shot
Browse files Browse the repository at this point in the history
Creating one egraph in `generate-candidates`
  • Loading branch information
pavpanchekha authored Jan 27, 2025
2 parents c73a45c + 8548c1e commit 328a3e6
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 53 deletions.
3 changes: 1 addition & 2 deletions src/core/egg-herbie.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,14 @@

; node -> natural
; inserts an expression into the e-graph, returning its e-class id.

(define (insert-node! node root?)
(match node
[(list op ids ...) (egraph_add_node ptr (symbol->string op) (list->u32vec ids) root?)]
[(? symbol? x) (egraph_add_node ptr (symbol->string x) 0-vec root?)]
[(? number? n) (egraph_add_node ptr (number->string n) 0-vec root?)]))

(define insert-batch (batch-remove-zombie batch roots))

(define mappings (build-vector (batch-length insert-batch) values))
(define (remap x)
(vector-ref mappings x))
Expand All @@ -131,7 +131,6 @@
[(hole prec spec) (remap spec)] ; "hole" terms currently disappear
[(approx spec impl) (insert-node! (list '$approx (remap spec) (remap impl)) root?)]
[(list op (app remap args) ...) (insert-node! (cons op args) root?)]))

(vector-set! mappings n idx))

(for ([node (in-vector (batch-nodes insert-batch))]
Expand Down
56 changes: 5 additions & 51 deletions src/core/patch.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -16,52 +16,6 @@

(provide generate-candidates)

;;;;;;;;;;;;;;;;;;;;;;;;;;;; Simplify ;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(define (lower-approximations approxs global-batch)
(timeline-event! 'simplify)

(define reprs
(for/list ([approx (in-list approxs)])
(define prev (car (alt-prevs approx)))
(repr-of (debatchref (alt-expr prev)) (*context*))))

; generate real rules
(define rules (*simplify-rules*))
(define lowering-rules (platform-lowering-rules))

; egg runner
(define schedule
(if (flag-set? 'generate 'simplify)
; if simplify enabled, 2-phases for real rewrites and implementation selection
`((,rules . ((node . ,(*node-limit*))))
(,lowering-rules . ((iteration . 1) (scheduler . simple))))
; if disabled, only implementation selection
`((,lowering-rules . ((iteration . 1) (scheduler . simple))))))

(define roots
(for/vector ([approx (in-list approxs)])
(batchref-idx (alt-expr approx))))

; run egg
(define runner (make-egraph global-batch roots reprs schedule))
(define simplification-options (simplify-batch runner global-batch))

; convert to altns
(define simplified
(reap [sow]
(define global-batch-mutable (batch->mutable-batch global-batch)) ; Create mutable batch
(for ([altn (in-list approxs)]
[outputs (in-list simplification-options)])
(match-define (cons _ simplified) outputs)
(define prev (car (alt-prevs altn)))
(for ([bref (in-list simplified)])
(sow (alt bref `(simplify ,runner #f) (list altn) '()))))
(batch-copy-mutable-nodes! global-batch global-batch-mutable))) ; Update global-batch

(timeline-push! 'count (length approxs) (length simplified))
simplified)

;;;;;;;;;;;;;;;;;;;;;;;;;;;; Taylor ;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(define transforms-to-try
Expand Down Expand Up @@ -111,7 +65,7 @@
(timeline-push! 'outputs (map ~a (map (compose debatchref alt-expr) approxs)))
(timeline-push! 'count (length altns) (length approxs))

(lower-approximations approxs global-batch))
approxs)

;;;;;;;;;;;;;;;;;;;;;;;;;;;; Recursive Rewrite ;;;;;;;;;;;;;;;;;;;;;;;;;;;;

Expand All @@ -134,7 +88,6 @@
(define roots (list->vector (map (compose batchref-idx alt-expr) altns)))
(define reprs (map (curryr repr-of (*context*)) exprs))
(timeline-push! 'inputs (map ~a exprs))

(define runner (make-egraph global-batch roots reprs schedule))
; batchrefss is a (listof (listof batchref))
(define batchrefss (egraph-variations runner global-batch))
Expand All @@ -161,7 +114,7 @@
; Starting alternatives
(define start-altns
(for/list ([expr (in-list exprs)]
[root (batch-roots global-batch)])
[root (in-vector (batch-roots global-batch))])
(define repr (repr-of expr (*context*)))
(alt (batchref global-batch root) (list 'patch expr repr) '() '())))

Expand All @@ -170,10 +123,11 @@
(if (flag-set? 'generate 'taylor)
(run-taylor exprs start-altns global-batch)
'()))

; Recursive rewrite
(define rewritten
(if (flag-set? 'generate 'rr)
(run-rr start-altns global-batch)
(run-rr (append start-altns approximations) global-batch)
'()))

(remove-duplicates (append approximations rewritten) #:key (λ (x) (batchref-idx (alt-expr x)))))
(remove-duplicates rewritten #:key (λ (x) (batchref-idx (alt-expr x)))))
1 change: 1 addition & 0 deletions src/core/programs.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
[(literal val precision) (get-representation precision)]
[(? variable?) (context-lookup ctx expr)]
[(approx _ impl) (repr-of impl ctx)]
[(hole precision spec) (get-representation precision)]
[(list 'if cond ift iff) (repr-of ift ctx)]
[(list op args ...) (impl-info op 'otype)]))

Expand Down

0 comments on commit 328a3e6

Please sign in to comment.