-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into new-encoding
- Loading branch information
Showing
7 changed files
with
87 additions
and
190 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,84 +1,34 @@ | ||
#lang racket | ||
|
||
(require "../utils/alternative.rkt" | ||
"points.rkt" | ||
"programs.rkt" | ||
"egg-herbie.rkt" | ||
"../syntax/sugar.rkt" | ||
"../syntax/syntax.rkt") | ||
"egg-herbie.rkt") | ||
|
||
(provide add-derivations) | ||
|
||
(define (canonicalize-proof prog proof loc pcontext ctx) | ||
(define (canonicalize-proof prog proof loc) | ||
(and proof | ||
;; Proofs are actually on subexpressions, | ||
;; we need to construct the proof for the full expression | ||
(for/list ([step (in-list proof)]) | ||
(location-do loc prog (const step))))) | ||
|
||
;; Computes a `equal?`-based hash table key for an alternative | ||
(define (altn->key altn) | ||
(match altn | ||
[(alt expr `(rr ,loc ,method ,_) prevs _) (list expr (list 'rr loc method) (map alt-expr prevs))] | ||
[(alt expr `(simplify ,loc ,method ,_) prevs _) | ||
(list expr (list 'simplify loc method) (map alt-expr prevs))] | ||
[_ (error 'altn->key "unimplemented ~a" altn)])) | ||
|
||
;; Creates two tables: | ||
;; - map from alternative to a pair (e, l ~> r) where `e` is an `egg-runner` | ||
;; and `l ~> r` is the rewrite we want a proof for. | ||
;; - map from egg query to list of proofs | ||
(define (make-proof-tables altns) | ||
(define alt->query&rws (make-hash)) | ||
(define query->rws (make-hash)) | ||
|
||
(define (build! altn) | ||
(match altn | ||
; recursive rewrite using egg (impl -> impl) | ||
[(alt expr `(,(or 'rr 'simplify) ,loc ,(? egg-runner? runner) #f) `(,prev) _) | ||
(define start-expr (location-get loc (alt-expr prev))) | ||
(define end-expr (location-get loc expr)) | ||
(define rewrite (cons start-expr end-expr)) | ||
(hash-set! alt->query&rws (altn->key altn) (cons runner rewrite)) | ||
(hash-update! query->rws runner (lambda (rws) (set-add rws rewrite)) '())] | ||
|
||
; everything else | ||
[_ (void)]) | ||
|
||
altn) | ||
|
||
; build the table | ||
(for ([altn (in-list altns)]) | ||
(alt-for-each build! altn)) | ||
(values alt->query&rws query->rws)) | ||
|
||
;; Runs proof extraction. | ||
;; Result is a map from egg query to rewrites. | ||
(define (compute-proofs query->rws) | ||
(for/hash ([(runner rws) (in-hash query->rws)]) | ||
(define proofs (run-egg runner `(proofs . ,rws))) | ||
(values runner (map cons rws proofs)))) | ||
|
||
;; Lookups a proof based on an alternative | ||
(define ((lookup-proof alt->query&rws query->proofs) altn) | ||
(match-define (cons runner rw) (hash-ref alt->query&rws (altn->key altn))) | ||
(cdr (assoc rw (hash-ref query->proofs runner)))) | ||
|
||
;; Adds proof information to alternatives. | ||
(define (add-derivations-to altn pcontext ctx alt->proof) | ||
(define (add-derivations-to altn) | ||
(match altn | ||
; recursive rewrite or simplify, both using egg | ||
[(alt expr (list phase loc (? egg-runner? runner) #f) `(,prev) _) | ||
#:when (or (equal? phase 'simplify) (equal? phase 'rr)) | ||
(define proof (canonicalize-proof (alt-expr altn) (alt->proof altn) loc pcontext ctx)) | ||
(alt expr `(rr ,loc ,runner ,proof) `(,prev) '())] | ||
[(alt expr (list (or 'simplify 'rr) loc (? egg-runner? runner) #f) `(,prev) _) | ||
(define start-expr (location-get loc (alt-expr prev))) | ||
(define end-expr (location-get loc expr)) | ||
(define proof (first (run-egg runner `(proofs ,(cons start-expr end-expr))))) | ||
(define proof* (canonicalize-proof (alt-expr altn) proof loc)) | ||
(alt expr `(rr ,loc ,runner ,proof*) `(,prev) '())] | ||
|
||
; everything else | ||
[_ altn])) | ||
|
||
(define (add-derivations alts pcontext ctx) | ||
(define-values (alt->query&rws query->rws) (make-proof-tables alts)) | ||
(define query->proofs (compute-proofs query->rws)) | ||
(define lookup-proc (lookup-proof alt->query&rws query->proofs)) | ||
(define (add-derivations alts) | ||
(define cache (make-hash)) | ||
(for/list ([altn (in-list alts)]) | ||
(alt-map (curryr add-derivations-to pcontext ctx lookup-proc) altn))) | ||
;; We need to cache this because we'll see the same alt several times | ||
(alt-map (lambda (altn) (hash-ref! cache altn (lambda () (add-derivations-to altn)))) altn))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.