Skip to content

Commit

Permalink
Typing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
simonlindholm committed Sep 25, 2020
1 parent 8a2c931 commit c703fbf
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 32 deletions.
5 changes: 3 additions & 2 deletions src/randomizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,7 @@ def visit_Assignment(self, node: ca.Assignment) -> None:

node = random.choice(cands)

assert isinstance(node.rvalue, ca.BinaryOp)
node.op = node.rvalue.op + node.op
node.rvalue = node.rvalue.right

Expand Down Expand Up @@ -1239,10 +1240,10 @@ def perm_float_literal(
"""Converts a Float Literal"""
typemap = build_typemap(ast)

cands: List[Expression] = []
cands: List[ca.Constant] = []

class Visitor(ca.NodeVisitor):
def visit_Constant(self, node) -> None:
def visit_Constant(self, node: ca.Constant) -> None:
if node.type == "float":
cands.append(node)

Expand Down
9 changes: 6 additions & 3 deletions strip_other_fns.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
import argparse
from typing import Optional
from pathlib import Path


Expand All @@ -26,7 +27,7 @@ def strip_other_fns(source: str, keep_fn_name: str) -> str:
while True:
fn_regex = re.compile(r"^.*\s+\**(\w+)\(.*\)\s*?{", re.M)
fn = re.search(fn_regex, remain)
if fn == None:
if fn is None:
result += remain
remain = ""
break
Expand All @@ -45,7 +46,9 @@ def strip_other_fns(source: str, keep_fn_name: str) -> str:
return result


def strip_other_fns_and_write(source: str, fn_name: str, out_filename=None) -> None:
def strip_other_fns_and_write(
source: str, fn_name: str, out_filename: Optional[str] = None
) -> None:
stripped = strip_other_fns(source, fn_name)

if out_filename is None:
Expand All @@ -55,7 +58,7 @@ def strip_other_fns_and_write(source: str, fn_name: str, out_filename=None) -> N
f.write(stripped)


def main():
def main() -> None:
parser = argparse.ArgumentParser(
description="Remove all but a single function definition from a file."
)
Expand Down
56 changes: 29 additions & 27 deletions test/test_perm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# type: ignore
import unittest
import os
import tempfile
Expand All @@ -11,31 +12,32 @@
from src import main

c_files_list = [
['test_general.c', 'test_general'],
['test_general.c', 'test_general_3'],
['test_general.c', 'test_general_multiple'],
['test_ternary.c', 'test_ternary1'],
['test_ternary.c', 'test_ternary2'],
['test_type.c', 'test_type1'],
['test_type.c', 'test_type2'],
['test_type.c', 'test_type3'],
['test_randomizer.c', 'test_randomizer'],
["test_general.c", "test_general"],
["test_general.c", "test_general_3"],
["test_general.c", "test_general_multiple"],
["test_ternary.c", "test_ternary1"],
["test_ternary.c", "test_ternary2"],
["test_type.c", "test_type1"],
["test_type.c", "test_type2"],
["test_type.c", "test_type3"],
["test_randomizer.c", "test_randomizer"],
]


class TestStringMethods(unittest.TestCase):
@classmethod
def setUpClass(cls):
compiler = Compiler('test/compile.sh')
compiler = Compiler("test/compile.sh")
cls.tmp_dirs = {}
for test_c, test_fn in c_files_list:
d = tempfile.TemporaryDirectory()
file_test = os.path.join('test', test_c)
file_test = os.path.join("test", test_c)
file_actual = os.path.join(d.name, "actual.c")
file_base = os.path.join(d.name, "base.c")
file_target = os.path.join(d.name, "target.o")

actual_preprocessed = preprocess(file_test, cpp_args=['-DACTUAL'])
base_preprocessed = preprocess(file_test, cpp_args=['-UACTUAL'])
actual_preprocessed = preprocess(file_test, cpp_args=["-DACTUAL"])
base_preprocessed = preprocess(file_test, cpp_args=["-UACTUAL"])

strip_other_fns_and_write(actual_preprocessed, test_fn, file_actual)
strip_other_fns_and_write(base_preprocessed, test_fn, file_base)
Expand All @@ -48,61 +50,61 @@ def setUpClass(cls):

shutil.copy2("test/compile.sh", d.name)
cls.tmp_dirs[(test_c, test_fn)] = d

@classmethod
def tearDownClass(cls):
for d in cls.tmp_dirs.values():
d.cleanup()

def go(self, filename, fn_name, **kwargs) -> int:
d = self.tmp_dirs[(filename, fn_name)].name
score, = main.run(main.Options(directories=[d], stop_on_zero=True, **kwargs))
(score,) = main.run(main.Options(directories=[d], stop_on_zero=True, **kwargs))
return score

def test_general(self):
score = self.go('test_general.c', 'test_general')
score = self.go("test_general.c", "test_general")
self.assertEqual(score, 0)

def test_general_3(self):
score = self.go('test_general.c', 'test_general_3')
score = self.go("test_general.c", "test_general_3")
self.assertEqual(score, 0)

def test_general_multiple(self):
score = self.go('test_general.c', 'test_general_multiple')
score = self.go("test_general.c", "test_general_multiple")
self.assertEqual(score, 0)

def test_ternary1(self):
score = self.go('test_ternary.c', 'test_ternary1')
score = self.go("test_ternary.c", "test_ternary1")
self.assertEqual(score, 0)

def test_ternary2(self):
score = self.go('test_ternary.c', 'test_ternary2')
score = self.go("test_ternary.c", "test_ternary2")
self.assertEqual(score, 0)

def test_type1(self):
score = self.go('test_type.c', 'test_type1')
score = self.go("test_type.c", "test_type1")
self.assertEqual(score, 0)

def test_type2(self):
score = self.go('test_type.c', 'test_type2')
score = self.go("test_type.c", "test_type2")
self.assertEqual(score, 0)

def test_type3(self):
score = self.go('test_type.c', 'test_type3')
score = self.go("test_type.c", "test_type3")
self.assertEqual(score, 0)

def test_type3_threaded(self):
score = self.go('test_type.c', 'test_type3', threads=2)
score = self.go("test_type.c", "test_type3", threads=2)
self.assertEqual(score, 0)

def test_randomizer(self):
score = self.go('test_randomizer.c', 'test_randomizer')
score = self.go("test_randomizer.c", "test_randomizer")
self.assertEqual(score, 0)

def test_randomizer_threaded(self):
score = self.go('test_randomizer.c', 'test_randomizer', threads=2)
score = self.go("test_randomizer.c", "test_randomizer", threads=2)
self.assertEqual(score, 0)


if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

0 comments on commit c703fbf

Please sign in to comment.