Skip to content

Commit

Permalink
Add python bindings for RDom::where
Browse files Browse the repository at this point in the history
  • Loading branch information
vksnk committed May 9, 2016
1 parent 00f68fe commit 5fed5d6
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 2 deletions.
15 changes: 13 additions & 2 deletions python_bindings/python/RDom.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ h::RDom *RDom_constructor4(h::Expr min0, h::Expr extent0,
void defineRDom()
{
using Halide::RDom;


defineRVar();

Expand Down Expand Up @@ -218,7 +217,19 @@ void defineRDom()
"Compare two reduction domains for equality of reference")
.def("dimensions", &RDom::dimensions, p::arg("self"),
"Get the dimensionality of a reduction domain")

.def("where", &RDom::where, p::args("self", "predicate"),
"Add a predicate to the RDom. An RDom may have multiple"
"predicates associated with it. An update definition that uses"
"an RDom only iterates over the subset points in the domain for"
"which all of its predicates are true. The predicate expression"
"obeys the same rules as the expressions used on the"
"right-hand-side of the corresponding update definition. It may"
"refer to the RDom's variables and free variables in the Func's"
"update definition. It may include calls to other Funcs, or make"
"recursive calls to the same Func. This permits iteration over"
"non-rectangular domains, or domains with sizes that vary with"
"some free variable, or domains with shapes determined by some"
"other Func. ")
//"Get at one of the dimensions of the reduction domain"
//EXPORT RVar operator[](int) const;

Expand Down
33 changes: 33 additions & 0 deletions python_bindings/tests/test_rdom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#!/usr/bin/python3

import halide as h

def test_rdom():
x = h.Var("x")
y = h.Var("y")

diagonal = h.Func("diagonal")
diagonal[x, y] = 1

domain_width = 10
domain_height = 10

r = h.RDom(0, domain_width, 0, domain_height)
r.where(r.x <= r.y)

diagonal[r.x, r.y] = 2
output = diagonal.realize(domain_width, domain_height)
output = h.Image(h.Int(32), output)

for iy in range(domain_height):
for ix in range(domain_width):
if ix <= iy:
assert output(ix, iy) == 2
else:
assert output(ix, iy) == 1

print("Success!")
return 0

if __name__ == "__main__":
test_rdom()

0 comments on commit 5fed5d6

Please sign in to comment.