Skip to content

Commit

Permalink
Increase code coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
kehemo committed Oct 31, 2024
1 parent 3416cbd commit e0c1039
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 29 deletions.
38 changes: 9 additions & 29 deletions src/exo/frontend/pyparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,6 @@ def __init__(self, nm):
self.nm = nm


def str_to_mem(name):
return getattr(sys.modules[__name__], name)


@dataclass
class SourceInfo:
src_file: str
Expand Down Expand Up @@ -1354,10 +1350,6 @@ def unquote_to_index(unquoted, ref_node, srcinfo, top_level):
if isinstance(e, pyast.Slice):
idxs.append(self.parse_slice(e, node))
srcinfo_for_idxs.append(srcinfo)
unquote_eval_result = self.try_eval_unquote(e)
if len(unquote_eval_result) == 1:
unquoted = unquote_eval_result[0]

else:
unquote_eval_result = self.try_eval_unquote(e)
if len(unquote_eval_result) == 1:
Expand Down Expand Up @@ -1396,19 +1388,16 @@ def parse_slice(self, e, node):
else:
srcinfo = self.getsrcinfo(node)

if isinstance(e, pyast.Slice):
lo = None if e.lower is None else self.parse_expr(e.lower)
hi = None if e.upper is None else self.parse_expr(e.upper)
if e.step is not None:
self.err(
e,
"expected windowing to have the form x[:], "
"x[i:], x[:j], or x[i:j], but not x[i:j:k]",
)
lo = None if e.lower is None else self.parse_expr(e.lower)
hi = None if e.upper is None else self.parse_expr(e.upper)
if e.step is not None:
self.err(

Check warning on line 1394 in src/exo/frontend/pyparser.py

View check run for this annotation

Codecov / codecov/patch

src/exo/frontend/pyparser.py#L1394

Added line #L1394 was not covered by tests
e,
"expected windowing to have the form x[:], "
"x[i:], x[:j], or x[i:j], but not x[i:j:k]",
)

return UAST.Interval(lo, hi, srcinfo)
else:
return UAST.Point(self.parse_expr(e), srcinfo)
return UAST.Interval(lo, hi, srcinfo)

# parse expressions, including values, indices, and booleans
def parse_expr(self, e):
Expand All @@ -1433,17 +1422,8 @@ def parse_expr(self, e):
else:
return PAST.Read(nm, idxs, self.getsrcinfo(e))
else:
parent_globals = self.parent_scope.get_globals()
parent_locals = self.parent_scope.read_locals()
if nm_node.id in self.exo_locals:
nm = self.exo_locals[nm_node.id]
elif (
nm_node.id in parent_locals
and parent_locals[nm_node.id] is not None
):
nm = parent_locals[nm_node.id].val
elif nm_node.id in parent_globals:
nm = parent_globals[nm_node.id]
else:
self.err(nm_node, f"variable '{nm_node.id}' undefined")

Expand Down
13 changes: 13 additions & 0 deletions tests/golden/test_metaprogramming/test_local_externs.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#include "test.h"

#include <stdio.h>
#include <stdlib.h>

#include <math.h>
// foo(
// a : f64 @DRAM
// )
void foo( void *ctxt, double* a ) {
*a = log((double)*a);
}

12 changes: 12 additions & 0 deletions tests/golden/test_metaprogramming/test_unary_ops.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#include "test.h"

#include <stdio.h>
#include <stdlib.h>

// foo(
// a : i32 @DRAM
// )
void foo( void *ctxt, int32_t* a ) {
*a = ((int32_t) -2);
}

144 changes: 144 additions & 0 deletions tests/test_metaprogramming.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from exo.API_scheduling import rename
from exo.frontend.pyparser import ParseError
import pytest
import warnings
from exo.core.extern import Extern, _EErr


def test_unrolling(golden):
Expand Down Expand Up @@ -376,3 +378,145 @@ def test_outer_return_disallowed():
def foo(a: i32):
with python:
return


def test_with_block():
@proc
def foo(a: i32):
with python:

def issue_warning():
warnings.warn("deprecated", DeprecationWarning)

with warnings.catch_warnings(record=True) as recorded_warnings:
issue_warning()
assert len(recorded_warnings) == 1
pass


def test_unary_ops(golden):
@proc
def foo(a: i32):
with python:
x = ~1
with exo:
a = x

c_file, _ = compile_procs_to_strings([foo], "test.h")
assert c_file == golden


def test_return_in_async():
@proc
def foo(a: i32):
with python:

async def bar():
return 1

pass


def test_local_externs(golden):
class _Log(Extern):
def __init__(self):
super().__init__("log")

def typecheck(self, args):
if len(args) != 1:
raise _EErr(f"expected 1 argument, got {len(args)}")

Check warning on line 427 in tests/test_metaprogramming.py

View check run for this annotation

Codecov / codecov/patch

tests/test_metaprogramming.py#L427

Added line #L427 was not covered by tests

atyp = args[0].type
if not atyp.is_real_scalar():
raise _EErr(

Check warning on line 431 in tests/test_metaprogramming.py

View check run for this annotation

Codecov / codecov/patch

tests/test_metaprogramming.py#L431

Added line #L431 was not covered by tests
f"expected argument 1 to be a real scalar value, but "
f"got type {atyp}"
)
return atyp

def globl(self, prim_type):
return "#include <math.h>"

def compile(self, args, prim_type):
return f"log(({prim_type}){args[0]})"

log = _Log()

@proc
def foo(a: f64):
a = log(a)

c_file, _ = compile_procs_to_strings([foo], "test.h")
assert c_file == golden


def test_unquote_multiple_exprs():
with pytest.raises(ParseError, match="Unquote must take 1 argument"):
x = 0

@proc
def foo(a: i32):
a = {x, x}


def test_disallow_with_in_exo():
with pytest.raises(ParseError, match="Expected unquote"):

@proc
def foo(a: i32):
with a:
pass


def test_unquote_multiple_stmts():
with pytest.raises(ParseError, match="Unquote must take 1 argument"):

@proc
def foo(a: i32):
with python:
with exo as s:
a += 1
with exo:
{s, s}


def test_unquote_non_statement():
with pytest.raises(
ParseError,
match="Statement-level unquote expression must return Exo statements",
):

@proc
def foo(a: i32):
with python:
x = ~{a}
with exo:
{x}


def test_unquote_slice_with_step():
with pytest.raises(ParseError, match="Unquote returned slice index with step"):

@proc
def bar(a: [i32][10]):
a[0] = 0

@proc
def foo(a: i32[20]):
with python:
x = slice(0, 20, 2)
with exo:
bar(a[x])


def test_typecheck_unquote_index():
with pytest.raises(
ParseError, match="Unquote received input that couldn't be unquoted"
):

@proc
def foo(a: i32[20]):
with python:
x = "0"
with exo:
a[x] = 0

0 comments on commit e0c1039

Please sign in to comment.