From 9a9c65da0bd598fb3212c5e0f9616be29eeb3f87 Mon Sep 17 00:00:00 2001 From: AYadrov Date: Fri, 31 Jan 2025 17:55:06 -0700 Subject: [PATCH] extremely revolutionary ideas on extraction --- src/core/egg-herbie.rkt | 59 +++++++++++++++++++++++------------------ 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/src/core/egg-herbie.rkt b/src/core/egg-herbie.rkt index 5d93a9ed4..819654789 100644 --- a/src/core/egg-herbie.rkt +++ b/src/core/egg-herbie.rkt @@ -585,17 +585,17 @@ [(cons f ids) ; application (cond [(eq? f '$approx) ; approx node - (define spec (u32vector-ref ids 0)) - (define impl (u32vector-ref ids 1)) + (define spec (lookup (u32vector-ref ids 0))) + (define impl (lookup (u32vector-ref ids 1))) (list '$approx spec impl)] [(eq? f 'if) ; if expression - (define cond (u32vector-ref ids 0)) - (define ift (u32vector-ref ids 1)) - (define iff (u32vector-ref ids 2)) + (define cond (lookup (u32vector-ref ids 0))) + (define ift (lookup (u32vector-ref ids 1))) + (define iff (lookup (u32vector-ref ids 2))) (list 'if cond ift iff)] [(eq? f '$hole) ; hole expression - (define repr (u32vector-ref ids 0)) - (define val (u32vector-ref ids 1)) + (define repr (lookup (u32vector-ref ids 0))) + (define val (lookup (u32vector-ref ids 1))) (list '$hole repr val)] ; it is just a repr node [(repr-exists? f) f] @@ -606,14 +606,16 @@ (operator-info f 'itype))) ; unsafe since we don't check that |itypes| = |ids| ; optimize for common cases to avoid extra allocations - (cons - f - (match itypes - [(list) '()] - [(list t1) (list (u32vector-ref ids 0))] - [(list t1 t2) (list (u32vector-ref ids 0) (u32vector-ref ids 1))] - [(list t1 t2 t3) (list (u32vector-ref ids 0) (u32vector-ref ids 1) (u32vector-ref ids 2))] - [_ (u32vector->list ids)]))])])) + (cons f + (match itypes + [(list) '()] + [(list t1) (list (lookup (u32vector-ref ids 0)))] + [(list t1 t2) (list (lookup (u32vector-ref ids 0)) (lookup (u32vector-ref ids 1)))] + [(list t1 t2 t3) + (list (lookup (u32vector-ref ids 0)) + (lookup (u32vector-ref ids 1)) + (lookup (u32vector-ref ids 2)))] + [_ (map lookup (u32vector->list ids))]))])])) ;; Splits untyped eclasses into typed eclasses. ;; Nodes are duplicated across their possible types. @@ -639,16 +641,11 @@ (+ (* idx num-types) (hash-ref type->idx type))) ; maps (untyped eclass id, type) to typed eclass id - (define (lookup-id eid type) - (idx+type->id (u32vector-ref egg-id->idx eid) type)) + (define (lookup-id eid) + (u32vector-ref egg-id->idx eid)) - ;;; - (define (lookup-id2 eid) - (vector-ref (egraph-get-eclass egraph-data eid) 0)) - ;;; - - ; allocate enough eclasses for every (egg-id, type) combination - (define n (* (u32vector-length eclass-ids) num-types)) + ; allocate enough eclasses for every egg-id combination + (define n (u32vector-length eclass-ids)) (define id->eclass (make-vector n '())) (define id->parents (make-vector n '())) (define id->leaf? (make-vector n #f)) @@ -663,8 +660,18 @@ (for ([eid (in-u32vector eclass-ids)] [idx (in-naturals)]) (define enodes (egraph-get-eclass egraph-data eid)) - (for/list ([enode (in-vector enodes)]) - (rebuild-enode enode lookup-id2))) + (for ([enode (in-vector enodes)]) + (define enode* (rebuild-enode enode lookup-id)) + (match enode* + [(list _ ids ...) + (if (null? ids) + (vector-set! id->leaf? idx #t) + (for ([child-id (in-list ids)]) + (vector-set! id->parents child-id (cons idx (vector-ref id->parents child-id)))))] + [(? symbol?) (vector-set! id->leaf? idx #t)] + [(? number?) (vector-set! id->leaf? idx #t)]) + (vector-set! id->eclass idx (cons enode* (vector-ref id->eclass idx)))) + (printf "Eclass ~a: ~a\n" idx (vector-ref id->eclass idx))) #;(for ([eid (in-u32vector eclass-ids)] [idx (in-naturals)])