-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathCodeGen.scala
619 lines (543 loc) · 21.3 KB
/
CodeGen.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
/** Renders the expression language as Python source text.
  *
  * Implementations supply the `prelude` that is emitted once at the top of the
  * output and the rendering of whole computations; sequences, expressions and
  * predicates are rendered here. The emitted text uses `xrange`, i.e. it is
  * written for a Python 2 interpreter.
  */
trait PythonPrinter {
  /** Python source emitted once, before any computation. */
  def prelude: String

  /** Renders a sequence as a Python list expression: a one-element list, a list
    * concatenation, or a list comprehension over an `xrange`.
    */
  def print(s: Seq): String = s match {
    case Singleton(e) => "[" + print(e) + "]"
    case Join(l, r) => print(l) + " + " + print(r)
    case Compr(e, v, Range(l, h)) =>
      "[" + print(e) + " for " + print(v) +
        " in xrange(" + print(l) + ", " + print(h) + ")]"
  }

  /** Renders an expression as a Python expression. */
  def print(e: Expr): String = e match {
    case Var(n, _) => n
    // Negative literals are parenthesised. The value is printed directly rather
    // than negated and prefixed with "-", so the most negative value (whose
    // negation overflows back to itself) no longer prints as "(--N)".
    case Const(i) => if (i >= 0) i.toString else "(" + i.toString + ")"
    case Plus(l, r) => "(" + print(l) + " + " + print(r) + ")"
    case Minus(l, r) => "(" + print(l) + " - " + print(r) + ")"
    // Products and quotients are parenthesised like sums: `*` and `/` share one
    // left-associative precedence level in Python, so an unparenthesised right
    // operand such as Div(a, Times(b, c)) would be regrouped as (a/b)*c.
    case Times(l, r) => "(" + print(l) + "*" + print(r) + ")"
    case Div(l, r) => "(" + print(l) + "/" + print(r) + ")"
    // Python spells the remainder operator `%`; an infix `mod` is a syntax error.
    case Mod(l, r) => "(" + print(l) + " % " + print(r) + ")"
    case App(v, args) => print(v) + "(" + args.map(print(_)).mkString(", ") + ")"
    case Op(l, r) => "plus(" + print(l) + ", " + print(r) + ")"
    case Reduce(r) => "reduce(plus, " + print(r) + ", zero)"
    case Zero => "zero"
    case Havoc => "random.randint(0, 1000)"
    case OpVar(v, args, exprs) =>
      "(lambda " + args.map(print(_)).mkString(", ") +
        ": " + print(v) + "(" + exprs.map(print(_)).mkString(", ") + "))"
    // A conditional becomes a chain of Python conditional expressions,
    // `(e1 if p1 else (e2 if p2 else default))`.
    case Cond(cases, default) => cases match {
      case (pred, expr) :: rest =>
        "(" + print(expr) + " if " + print(pred) + " else " + (rest match {
          case Nil => print(default)
          case _ => print(Cond(rest, default))
        }) + ")"
      // No guarded alternatives: only the default can apply.
      case _ => print(default)
    }
  }

  /** Renders a predicate as a fully parenthesised Python boolean expression. */
  def print(p: Pred): String = p match {
    case True => "True"
    case False => "False"
    case And(l, r) => "(" + print(l) + " and " + print(r) + ")"
    case Or(l, r) => "(" + print(l) + " or " + print(r) + ")"
    case Not(l) => "(not " + print(l) + ")"
    case Eq(l, r) => "(" + print(l) + " == " + print(r) + ")"
    case LT(l, r) => "(" + print(l) + " < " + print(r) + ")"
    case GT(l, r) => "(" + print(l) + " > " + print(r) + ")"
    case LE(l, r) => "(" + print(l) + " <= " + print(r) + ")"
    case GE(l, r) => "(" + print(l) + " >= " + print(r) + ")"
  }

  /** Renders one computation (an input or an algorithm definition). */
  def print(c: Computation): String

  /** Writes the prelude followed by every computation of `p`, one per
    * `println`, to `out`.
    */
  def print(p: List[Computation], out: java.io.PrintStream): Unit = {
    out.println(prelude)
    for (c <- p)
      out.println(print(c))
  }
}
// Functional style code output
/** Functional-style output: every algorithm becomes a recursive Python function
  * and every non-scalar input a memoized function.
  */
object Python extends PythonPrinter {
  // The body of the `memoize` class and of its methods must be indented for the
  // emitted text to be valid Python; `stripMargin` keeps that indentation
  // explicit and independent of how this Scala file itself is indented.
  override val prelude =
    """class memoize(dict):
      |  def __init__(self, func):
      |    self.func = func
      |  def __call__(self, *args):
      |    return self[args]
      |  def __missing__(self, key):
      |    result = self[key] = self.func(*key)
      |    return result
      |plus = min
      |zero = pow(10, 16)
      |import random
      |import sys
      |sys.setrecursionlimit(2 ** 16)
      |""".stripMargin

  /** Renders a computation.
    *
    * An algorithm becomes a `def` whose body asserts the precondition and then
    * returns its expression; a top-level conditional is unfolded into an
    * `if`/`elif`/`else` statement chain. An input of arity 0 becomes an
    * assignment, any other input a `@memoize`d function of `v1 .. vN`.
    */
  def print(c: Computation): String = c match {
    case a: Algorithm =>
      "def " + print(a.v) +
        "(" + a.args.map(print(_)).mkString(", ") + "):\n" +
        "  assert " + print(a.pre) +
        "\n" + { a.expr match {
          // Without guarded alternatives only the default applies; emitting
          // an `if` with an empty condition would not be valid Python.
          case Cond(Nil, default) => "  return " + print(default)
          // Function body statements sit at two spaces, branch bodies at four,
          // so each `return` is nested inside its `if`/`elif`/`else`.
          case Cond(cases, default) =>
            "  if " + cases.map { case (pred, expr) =>
              print(pred) + ":\n    return " + print(expr) }.mkString("\n  elif ") +
              "\n  else:\n    return " + print(default)
          case e => "  return " + print(e)
        }}
    case Input(v, e) =>
      if (v.arity == 0)
        print(v) + " = " + print(e)
      else
        "@memoize\ndef " + print(v) + "(" +
          (1 to v.arity).map("v" + _).mkString(", ") + "):\n" +
          "  return " + print(e)
  }
}
// Imperative style code output
// Turn functions into NumPy tables
// Treat first 'dom' arguments as indexes into tables
// Expects input programs to be flattened
class NumPython(smt: Proof, val dom: Int) extends PythonPrinter with Logger {
assert (dom > 0)
// Python prelude for the table-based output. It binds `plus`/`zero` (the names
// the expression printer emits for Op, Reduce and Zero), defines the DIM, MIN
// and MAX constants referenced when inputs are rendered, and star-imports numpy.
// `sys.maxint` exists only in Python 2.
def prelude = """# autogenerated by bellmaniac
import sys
plus = min
zero = sys.maxint
DIM = 128
MIN = 0
MAX = 1000
from numpy import *
"""
// Renders a list of expressions as one comma-separated string, used for
// argument lists, index tuples and table shapes.
private def print(l: List[Expr]): String = {
  val rendered = for (e <- l) yield print(e)
  rendered.mkString(", ")
}
// Leading whitespace for `tabs` nesting levels (one copy of the unit per level);
// defaults to the depth of a statement nested inside `dom` loops.
private def indent(tabs: Int = dom+1) = " " * tabs
import Transform.{transform, visit}
import scala.language.postfixOps
import scala.language.implicitConversions
// Name of the output table: a `dom`-ary variable called "T".
val T = Var("T", dom)
// Algorithm currently being printed; assigned at the start of print(Algorithm).
// Null until then, so printing helpers must not be used before that point.
private var scope: Algorithm = _
// Offset parameters of the current algorithm: one "o"-prefixed variable per
// table dimension, assigned together with `scope`.
private var offv: List[Var] = _
// The whole program, assigned by print(List[Computation], PrintStream).
private var all: List[Computation] = _
// Views over `all`; both dereference it, so they fail while `all` is still null.
def inputs = all.collect { case i: Input => i }
def algorithms = all.collect { case a: Algorithm => a }
// Table-based rendering of expressions: applying a variable becomes an index
// lookup `v[args]`, and a function-valued variable is printed by its bare name.
// Everything else is delegated to the generic Python printer.
override def print(e: Expr) = e match {
  case App(v: Var, args) =>
    val table = print(v)
    val index = print(args)
    table + "[" + index + "]"
  case _: OpVar =>
    error("should be unreachable")
  case v: Var if v.arity > 0 =>
    val isKnownTable =
      inputs.exists(_.v == v) || scope.args.contains(v) || v.name.startsWith("T")
    // TODO: fix this -- a variable that is not a known table should be rejected:
    // error("can only pass tables as parameters: " + v)
    // Until then both outcomes print the name.
    if (isKnownTable) v.name else v.name
  case _ =>
    super.print(e)
}
// Write to T at offset off by invoking v with rest.
//
// One statement of a split program: a call of algorithm `v` that fills table
// `T`. `off` holds one offset expression per table dimension, `rest` the
// remaining actual arguments, and `deps` the writes that must run first because
// they fill temporary tables mentioned in `rest`.
case class Write(v: Var, T: Var, off: List[Expr], rest: List[Expr], deps: List[Write]) {
assert (v.arity == off.size + rest.size, "must match arity")
assert (off.size == dom && T.arity == dom)
// Rendered as the call `v(T, off..., rest...)`.
override def toString = v.name + "(" + print(T :: off ::: rest) + ")"
// All writes in execution order: dependencies (recursively) first, this one last.
def all: List[Write] =
deps.flatMap(_.all) ::: this :: Nil
// The statements for `all`, one string per write.
def stmts: List[String] = all.map(_.toString)
// same up to renaming of all Ts and this.off
// Two writes match when they call the same function and their `rest` arguments
// agree pairwise; a pair of differing arguments is still accepted when both
// name tables filled by dependency writes with equal offsets that are
// themselves `same`. Note that `zip` ignores surplus arguments of the longer list.
def same(that: Write): Boolean =
this.v == that.v &&
(this.rest zip that.rest).forall {
case (e1, e2) =>
e1 == e2 ||
((this.deps.find(_.T == e1), that.deps.find(_.T == e2)) match {
case (Some(w1), Some(w2)) if w1.off == w2.off =>
w1.same(w2)
case _ => false
})
}
// Reuses the table written by `that` wherever a dependency recomputes the same
// thing. If some dependency `w` is `same` as `that`, `w` is dropped, every
// application taking `w.T` takes `that.T` instead, and the search repeats on
// the result. Otherwise the replacement is attempted inside each dependency.
def replace(implicit that: Write): Write = {
deps.find { case w => w.same(that) } match {
case Some(w) => // replace and repeat
// offset corrections
// one `that.off - w.off` difference per table dimension
val offsets = that.off zip w.off map Minus.tupled;
copy(deps = deps.filter(_ != w),
rest = transform(App(v, off ::: rest)) {
case App(v, args) if args.contains(w.T) =>
// add offsets according to position of w.T
// the callee must be a known algorithm; `.get` throws otherwise
val a = algorithms.find(_.v == v).get
// find additions
// keyed by variables named `<formal>_<i>`, built from the formal that
// receives w.T and the dimension index i -- presumably the callee's
// offset parameters for that table; confirm against the naming scheme
var add: Map[Var, Expr] = Map()
for ((formal, actual) <- a.args zip args;
if actual == w.T;
(o, i) <- offsets.zipWithIndex)
add = add + (Var(formal.name + "_" + i) -> o)
// add them
App(v, for ((formal, actual) <- a.args zip args) yield {
if (actual == w.T)
that.T
else if (add.contains(formal))
Linear.make(actual + add(formal)).expr
else
actual
})
case v: Var if v == w.T =>
error("must appear as a parameter in App")
} match {
// strip the `dom` offset arguments that were prepended for the traversal
case App(_, args) => args.drop(dom)
case _ => ???
}
).replace
case None => // recurse
copy(deps = deps.map(_.replace))
}
}
}
// Factory for Write values built from applications of the current algorithm's body.
object Write {
// top-level write
// Targets the output table T; the offsets passed down are the current
// algorithm's offset variables added to its first `dom` parameters.
// Reads `scope` and `offv`, so an algorithm must be being printed.
def make(app: App): Write =
make(app, T, offv zip scope.args.take(dom) map Plus.tupled)
// off is compared to first arguments of app for offsets
// Returns the single write for this "app"; writes for nested OpVar arguments
// become its `deps`, in order of discovery.
// each child T is used only once in the parent
def make(app: App, T: Var, off: List[Expr]): Write = app match {
case App(v: Var, args) =>
// offset of this write: `off - argument`, per table dimension
val offsets =
for ((o, a) <- off zip args.take(dom))
yield Linear.make(o - a).expr
// filled as a side effect of the traversal below (most recent first)
var deps: List[Write] = Nil
// recurse on dependency OpVars
val rest = for (arg <- args.drop(dom)) yield transform(arg) {
case OpVar(v, args, exprs) =>
assert(v.isInstanceOf[Var], "must be flattened")
assert(exprs.startsWith(args), "must be fully linearized")
//assert(args.size == dom, "must match dom")
// the OpVar is replaced by a fresh temporary table filled by its own write
val T1 = T.fresh
// todo: extract offsets from the opvar
// deps = make(App(v, exprs), T1, args) :: deps
// XXX: hack below
// pads the offset list to `dom` entries with leading `exprs` when the
// OpVar binds fewer than `dom` arguments
deps = make(App(v, exprs), T1, args ::: exprs.take(dom - args.size)) :: deps
T1
}
Write(v, T, offsets, rest, deps.reverse)
}
}
// use partitions in other partitions writes
// Lets each write reuse tables produced by the other writes (via
// Write.replace), repeating until nothing changes, and then returns the writes
// reordered so that every producer precedes the writes that reuse it.
def reuse(writes: List[Write]) = {
// edges (producer index, consumer index), recorded whenever a replacement happens
var deps: List[(Int, Int)] = Nil
// one pass: try every other write as a replacement source for each write
def replaceAll(writes: List[Write]) =
for ((w1,i) <- writes.zipWithIndex) yield {
var out = w1
for ((w2,j) <- writes.zipWithIndex if i != j) {
val next = out.replace(w2)
if (out != next) {
out = next
deps = (j, i) :: deps
}
}
out
}
// iterate to a fixpoint
var result: List[Write] = writes
var prev: List[Write] = Nil
while (result != prev) {
prev = result
result = replaceAll(result)
}
// permute according to deps
// Repeatedly emits a node with no incoming edge and deletes its outgoing
// edges (consuming `deps`); fails if every remaining node has one.
def order(g: List[Int]): List[Int] = g match {
case Nil => Nil
case _ =>
// find node without incoming edges
g.find { case i => ! deps.exists(_._2 == i) } match {
case Some(i) =>
// remove from graph and recurse
deps = deps.filter(_._1 != i)
i :: order(g.filter(_ != i))
case _ => error("loop detected")
}
}
order(0 to (writes.size - 1) toList).map(result(_))
}
// TODO: unsound
// reuse memory tables in a write sequence
// Applies only to a write with exactly one dependency: the dependency's
// temporary table is eliminated by making it write into this write's own table
// `t` at a shifted offset, and the rewrite continues down the chain of
// dependencies. Any other write is returned unchanged.
def overwrite(w: Write): Write = w match {
case Write(v, t, off, rest, List(w1)) =>
// must use only one temporary table once
val t1 = w1.T
if (! w1.off.forall(_ == Const(0)))
error("must have 0-based tables")
// must write where it reads from
// assigned inside the traversal below; stays Nil if no application takes t1
var off1: List[Expr] = Nil
// find where it's used
val rest1 = transform(App(v, off ::: rest)) {
case App(v, args) if args.contains(t1) =>
// the callee must be a known algorithm; `.get` throws otherwise
val a = algorithms.find(_.v == v).get
// find formal argument for t1
val formal = (a.args zip args).filter(_._2 == t1).map(_._1) match {
case f :: Nil => f
case _ => error("must have exactly one argument " + t1)
}
// find read signatures for a
// the single index vector at which the callee reads `formal`, with the
// callee's formals substituted by the actual arguments
val fc = memory(a).read.filter(_.v == formal).map(_.c) match {
case fc :: Nil => fc.map(_.s(a.args zip args))
case _ => error("must have exactly one read vector " + t1)
}
// read at off1 from invocation point
off1 = (0 to (dom-1)).map(i =>
Linear.make(off(i) + args(i) - fc(i)).expr).toList
// correct offsets to use t instead of t1:
// keyed by variables named `<formal>_<i>`, one per table dimension
var add = {for (i <- (0 to (dom - 1)))
yield (Var(formal.name + "_" + i), off1(i))
}.toMap
// reuse this t in place of t1
App(v, a.args zip args map {
case (f, a) if f == formal => t
case (f, a) if add.contains(f) => Linear.make(a + add(f)).expr
case (_, a) => a
})
case v: Var if v == t1 =>
error("must appear as a parameter in App")
} match {
// strip the `dom` offset arguments that were prepended for the traversal
case a: App => a.args.drop(dom)
case _ => error("unexpected")
}
// the dependency now writes into `t` at `off1`; continue down its own chain
Write(v, t, off, rest1, List(overwrite(w1.copy(T = t, off = off1))))
case _ => w
}
// Renders one computation as table-based Python.
//
// An input becomes an assignment: a random table for Havoc with positive
// arity, the plain expression for arity 0, and otherwise an `empty` table
// filled with the expression.
//
// An algorithm becomes a procedure taking the output table T, one offset per
// table dimension and the remaining parameters. Printing it first records the
// algorithm in `scope` and its offset variables in `offv`, which the
// expression printer and Write.make read. Two shapes are supported:
//  - split program: a conditional whose alternatives are all applications and
//    whose default is Havoc. The first alternative is emitted as a guarded
//    early return; the others become a sequence of table allocations and writes.
//  - abstract spec: anything else, emitted as a loop nest over the first `dom`
//    parameters that assigns one table cell per iteration.
override def print(c: Computation) = c match {
  case Input(v, e) =>
    print(v) + " = " + (e match {
      case Havoc if v.arity > 0 =>
        "random.randint(MIN, MAX, size=(" + (1 to v.arity).map(_ => "DIM").mkString(", ") + "))"
      case e if v.arity == 0 =>
        print(e)
      case e if v.arity > 0 =>
        "empty((" + (1 to v.arity).map(_ => "DIM").mkString(", ") + "), int)\n" +
          print(v) + ".fill(" + print(e) + ")"
    })
  case a @ Algorithm(v, args, pre, e) =>
    this.scope = a
    this.offv = args.take(dom).map { case Var(n, i) => Var("o" + n, i) }
    val loop = new LoopConstruct(a)
    val header = "def " + v.name + "(" + print(T :: offv ::: args.drop(dom)) + "):\n"
    val body = e match {
      case Cond(cases, Havoc) if cases.forall(_._2.isInstanceOf[App]) =>
        // SPLIT PROGRAM
        // compute with write sequences
        cases match {
          case (minp, mine: App) :: splits =>
            val (preds, es) = splits.unzip
            val apps = es.map {
              case app: App => app
              case _ => error("expected an app here")
            }
            // double check that splits are coming from the tactic:
            // under the precondition the split guards must be pairwise exclusive
            for (p1 <- preds; p2 <- preds; if p1 != p2)
              assert(smt.prove((pre and p1) implies (! p2)))
            // reuse and overwrite
            val writes = reuse(apps.map(Write.make(_))).map(overwrite(_))
            // Statements of the function body sit one level deep; the body of
            // the guard must sit one level deeper than the `if` it belongs to,
            // otherwise the emitted text is not valid Python.
            val outer = indent(1)
            val inner = indent(2)
            val guard = outer + "if " + print(minp) + ":\n"
            val baseCase = Write.make(mine).stmts.mkString(
              inner, "\n" + inner, "\n" + inner + "return\n")
            // one allocation per temporary table, i.e. per write not targeting T
            val allocations =
              for (w1 <- writes;
                   w <- w1.all if w.T != T;
                   callee <- algorithms if callee.v == w.v)
                yield print(w.T) + " = " + memory(callee).alloc
            val partitions =
              for (w <- writes; stmt <- "# partition" :: w.stmts)
                yield stmt
            guard + baseCase +
              (allocations ::: partitions).mkString(outer, "\n" + outer, "")
          case _ => ???
        }
      case _ =>
        // ABSTRACT SPEC PROGRAM
        // compute with the loop construct
        val (lvars, lbounds, lmap) = loop.generate
        // must be self-recursive for this to work:
        // recursive calls become reads of T shifted by the offset variables
        def makeTs(e: Expr): Expr = transform(e) {
          case App(v, args) if a.v == v =>
            if (args.drop(dom) == a.args.drop(dom))
              App(T, args.take(dom).map(makeTs) zip offv map { case (a, o) => a + o })
            else
              error("self-recursion must project to same value")
        }

        // one `for` header per loop variable, each nested one level deeper
        val loops = {
          for (((lv, Range(l, h)), i) <- (lvars zip lbounds).zipWithIndex)
            yield indent(i+1) + "for " + print(lv) + " in xrange(" + print(l) + "," + print(h) + "):\n"
        }.mkString("")
        // recover the algorithm's own parameters from the loop variables
        // NOTE(review): both sides are rendered with toString rather than
        // print -- confirm that yields valid Python for Var and Expr.
        val assignments = {
          for ((l, i) <- lmap.zipWithIndex)
            yield indent() + args(i) + " = " + l + "\n"
        }.mkString("")
        loops + assignments +
          indent() + "assert " + print(pre) + "\n" +
          indent() + print(makeTs(App(v, args))) + " = " + print(makeTs(e))
    }
    header + body
}
// Records the program first, because `inputs` and `algorithms` are derived
// from it while the individual computations are rendered, then emits it.
override def print(p: List[Computation], out: java.io.PrintStream): Unit = {
  all = p
  super.print(p, out)
}
// Generate loop construct for an algorithm
// Additional references:
// polyhedral model on wiki
// Omega library tutorials (e.g. SUIF)
// A choice of sign per coordinate: applying it negates (as `0 - x`) exactly
// the coordinates whose flag is set. Negating twice restores a coordinate,
// so a rotation serves as its own inverse.
case class Rotation(flips: List[Boolean]) {
  def apply(a: List[Expr]) = a.zip(flips).map {
    case (coordinate, negate) =>
      if (negate) Const(0) - coordinate else coordinate
  }
  def inverse = this
}
object Rotation {
  // Enumerates every rotation of `d` coordinates: first those that keep the
  // leading coordinate, then those that flip it.
  def all(d: Int): Iterator[Rotation] = {
    // rotations of d-1 coordinates with one more leading flag
    def extended(leading: Boolean): Iterator[Rotation] =
      all(d - 1).map(r => Rotation(leading :: r.flips))
    if (d == 0) List(Rotation(Nil)).iterator
    else extended(false) ++ extended(true)
  }
}
// A reference to table `v` at index expressions `c`, together with the path
// condition `path` under which the reference occurs.
case class Vector(path: Pred, v: Var, c: List[Expr])
class LoopConstruct(a: Algorithm) extends Logger {
val pre = a.pre
val c = a.args.take(dom)
// find all recursion references
// Collects one Vector per self-application found in the algorithm, keeping the
// path condition and the first `dom` arguments. An application of anything
// that is neither a local nor a declared input is an error. Every visited
// application is replaced by Havoc; only the collected list is returned.
def vectors = {
var out: List[Vector] = Nil
transform(a) {
case (path, locals, App(v, vargs)) =>
if (v == a.v)
out = Vector(path, a.v, vargs.take(dom)) :: out
else if (! locals.contains(v) && ! inputs.exists(_.v == v))
error("unexpected: " + v + " in " + a.v)
Havoc
}
out
}
// Domination order: the predicate stating that `a` lexicographically precedes
// `b`, coordinate by coordinate; two empty vectors compare as True. The lists
// must have equal length.
def lexorder(a: List[Expr], b: List[Expr]): Pred =
  if (a.isEmpty && b.isEmpty)
    True
  else if (a.isEmpty || b.isEmpty)
    ???
  else {
    val x = a.head
    val y = b.head
    x < y or (x === y and lexorder(a.tail, b.tail))
  }
// find domination order orientation
// Searches every permutation of the `dom` coordinates combined with every
// rotation, and returns the first pair for which the solver proves, for each
// vector under its path condition, that the rotated and permuted vector
// precedes the rotated and permuted parameter vector in `lexorder`.
// Fails when no pair qualifies. The `return` inside the loop is a non-local return.
def orient(vs: List[Vector]): (Rotation, List[Int]) = {
for (p <- (0 to dom-1).toList.permutations;
r <- Rotation.all(dom))
if (vs.forall { case Vector(path, _, vc) =>
smt.prove(path implies lexorder(p.map(r(vc)(_)), p.map(r(c)(_))))
})
return (r, p)
error("can't orient in domination order: " + vs.mkString(", "))
}
// Implicit conversions that let integer literals stand for rationals,
// constant expressions and constant linear forms in the arithmetic below.
implicit def int2rat(n: Int) = new Rational(n, 1)
implicit def int2expr(n: Int) = Const(n)
implicit def int2linear(n: Int) = Linear.make(Map(), new Rational(n, 1))
// Picks, from a non-empty list, an expression the solver can prove to be
// greater than or equal to all others under `pred`. The candidates are folded
// from the right; the search fails as soon as two of them cannot be ordered.
def MAX(p: List[Expr], pred: Pred): Expr =
  if (p.isEmpty)
    ???
  else if (p.tail.isEmpty)
    p.head
  else {
    val first = p.head
    val restMax = MAX(p.tail, pred)
    if (smt.prove(pred implies (restMax <= first)))
      first
    else if (smt.prove(pred implies (first <= restMax)))
      restMax
    else
      error("can't find max of " + p + " under " + pred)
  }
// Picks, from a non-empty list, an expression the solver can prove to be
// less than or equal to all others under `pred`. The candidates are folded
// from the right; the search fails as soon as two of them cannot be ordered.
def MIN(p: List[Expr], pred: Pred): Expr =
  if (p.isEmpty)
    ???
  else if (p.tail.isEmpty)
    p.head
  else {
    val first = p.head
    val restMin = MIN(p.tail, pred)
    if (smt.prove(pred implies (first <= restMin)))
      first
    else if (smt.prove(pred implies (restMin <= first)))
      restMin
    else
      error("can't find min of " + p + " under " + pred)
  }
// Infer range constraints from linear constraints in predicate
// Returns one Range per variable of `p`, in order, or None as soon as some
// variable lacks a lower or an upper bound.
//   use        - when set, a variable that has been bounded may appear in the
//                bounds of the variables after it
//   transitive - when set, additionally derives constraints by combining a
//                constraint with a positive and one with a negative
//                coefficient of the same variable so that the variable cancels
// NOTE(review): the sign handling suggests each linear form is read as
// `form >= 0` -- confirm against Linear.equations.
def inferBounds(p: List[Var], pred: Pred,
use: Boolean = true, transitive: Boolean = false): Option[List[Range]] = {
var eqs = Linear.equations(pred)
// compute transitive equations by eliminating one variable at a time
if (transitive)
eqs = eqs ::: {
for (v <- p;
e1 <- eqs if e1.proj(v) > 0;
e2 <- eqs if e2.proj(v) < 0) yield
e1 * e2.proj(v) * (-1) + e2 * e1.proj(v)
}
// free variables (allowed to appear in result range)
var free = Transform.vars(pred) -- p
var out: List[Range] = Nil
for (v <- p) {
// find constraints only contains "free" vars and having "v"
val bounds = eqs.filter { case eq => eq.has(v) && eq.vars.subsetOf(free + v) }
// upper and lower bound expressions
// (the upper bound is shifted by one relative to the lower bound formula)
val lower = bounds.filter(_.proj(v) > 0).map {
case eq => 0 - (eq.drop(v) / eq.proj(v)) }.map(_.expr)
val upper = bounds.filter(_.proj(v) < 0).map {
case eq => 1 - (eq.drop(v) / eq.proj(v)) }.map(_.expr)
// non-local return out of the loop body
if (lower.size == 0 || upper.size == 0)
return None
// use previous variables in generating subsequent range bound
if (use)
free = free + v
// tightest bounds: the greatest lower bound and the least upper bound
out = Range(MAX(lower, pred), MIN(upper, pred)) :: out
}
return Some(out.reverse)
}
// generate looping construct for first "dom" parameters of "a"
// returns (list of iteration variables, list of their ranges, assignment to actual variables)
// The iteration variables are listed in loop order (outermost first, as used
// by the caller); the assignments are indexed like the original parameters.
// Fails when no orientation exists or the bounds cannot be inferred.
def generate: (List[Var], List[Range], List[Expr]) = {
// orient dependency vectors by flipping +/- coordinates so that they point
// into lower-left corner in the lexicographic order
// p applied after r
val (r, p) = orient(vectors)
// find iteration order and bounds
// create fresh variables: c1 = r(c)
// (named after the original parameters with a "0" suffix)
val c1 = c.map { case Var(n, i) => Var(n +"0", i) }
// formulate pre in terms of c1
val exprs = r.inverse(c1)
val pre1 = pre.s(c zip exprs)
// solve for bounds
// p1 lists the fresh variables in the permuted (iteration) order
val p1 = p.map(c1(_))
inferBounds(p1, pre1) match {
case Some(ranges) => (p1, ranges, exprs)
case _ => error("can't infer bounds: " + pre1)
}
}
}
// Specifies read/write ranges
// (reads are applications of arguments under path conditions for given x in DOM)
// (writes are bounds on DOM)
// In particular, we can write into tables used in read as long as it's same x
class MemorySpec(a: Algorithm) {
// table dimensions (range from 0 to Expr)
// One extent per table dimension: the upper bound inferred for each of the
// first `dom` parameters from the precondition, with each parameter bounded
// independently of the others and with transitive constraints enabled. Fails
// unless the solver proves every lower bound non-negative.
def write: List[Expr] = new LoopConstruct(a).inferBounds(a.args.take(dom), a.pre, false, true) match {
case Some(ranges) =>
for (Range(l, h) <- ranges) yield {
if (! smt.prove(a.pre implies l >= Const(0)))
error("can't allocate memory with possibly negative index")
h
}
case None =>
error("can't infer memory allocation bounds")
}
// Python expression allocating a zero-filled integer table of the `write` extents.
def alloc = "zeros((" + print(write) + "), int)"
// All index vectors at which the algorithm reads its `dom`-ary table
// parameters, each with its path condition. Recursive calls are skipped, and
// passing such a table parameter along unapplied is rejected.
def read: List[Vector] = {
var out: List[Vector] = Nil
transform(a) {
case (path, _, App(v: Var, vargs)) if a.args.contains(v) && v.arity == dom =>
out = Vector(path, v, vargs) :: out
Havoc
case (_, _, App(v, _)) if v == a.v =>
// TODO: block recursive passes of arguments, need to think about why it's ok
Havoc
case (_, _, v) if a.args.contains(v) && v.arity == dom =>
// dumps the offending algorithm to stdout before failing
println(a)
error(s"can't infer memory spec if an argument function $v is passed in a call in ${a.v}")
}
out
}
}
// Memory specification of an algorithm. A conditional with a Havoc default
// whose first alternative is an application (a split program) borrows the
// specification of the algorithm applied there, following such links until an
// algorithm of any other shape is reached.
def memory(a: Algorithm): MemorySpec = a.expr match {
  case Cond((_, App(v, _)) :: _, Havoc) =>
    val delegate = algorithms.find(candidate => candidate.v == v).get
    memory(delegate)
  case _ =>
    new MemorySpec(a)
}
}