Skip to content

Commit

Permalink
default_to_not_stripping_fn_defs, flag to re-enable
Browse files Browse the repository at this point in the history
  • Loading branch information
snuffysasa committed Jul 4, 2022
1 parent c88c350 commit c09b05c
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 10 deletions.
4 changes: 2 additions & 2 deletions import.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ def prune_source(
Returns (source, compilable_source)."""
try:
ast = ast_util.parse_c(source, from_import=True)
orig_fn, _ = ast_util.extract_fn(ast, func_name)
orig_fn, _ = ast_util.extract_fn(ast, func_name, True)
if should_prune:
try:
ast_util.prune_ast(orig_fn, ast)
Expand Down Expand Up @@ -516,7 +516,7 @@ def prune_and_separate_context(
Returns (source, context)."""
try:
ast = ast_util.parse_c(source, from_import=True)
orig_fn, ind = ast_util.extract_fn(ast, func_name)
orig_fn, ind = ast_util.extract_fn(ast, func_name, True)
if should_prune:
try:
ind = ast_util.prune_ast(orig_fn, ast)
Expand Down
7 changes: 5 additions & 2 deletions src/ast_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,16 @@ def visit_If(self, n: ca.If) -> str:
return super().visit_If(n2) # type: ignore


def extract_fn(ast: ca.FileAST, fn_name: str) -> Tuple[ca.FuncDef, int]:
def extract_fn(
ast: ca.FileAST, fn_name: str, strip_other_fn_defs: bool
) -> Tuple[ca.FuncDef, int]:
ret = []
for i, node in enumerate(ast.ext):
if isinstance(node, ca.FuncDef):
if node.decl.name == fn_name:
ret.append((node, i))
elif "inline" not in node.decl.funcspec:
break
elif strip_other_fn_defs and "inline" not in node.decl.funcspec:
node = node.decl
ast.ext[i] = node
if isinstance(node, ca.Decl) and isinstance(node.type, ca.FuncDecl):
Expand Down
13 changes: 9 additions & 4 deletions src/candidate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class Candidate:
ast: ca.FileAST

fn_name: str
strip_other_fn_defs: bool
rng_seed: int
randomizer: Randomizer
score_value: Optional[int] = field(init=False, default=None)
Expand All @@ -45,10 +46,10 @@ class Candidate:
@staticmethod
@functools.lru_cache(maxsize=16)
def _cached_shared_ast(
source: str, fn_name: str
source: str, fn_name: str, strip_other_fn_defs: bool
) -> Tuple[ca.FuncDef, int, ca.FileAST]:
ast = ast_util.parse_c(source)
orig_fn, fn_index = ast_util.extract_fn(ast, fn_name)
orig_fn, fn_index = ast_util.extract_fn(ast, fn_name, strip_other_fn_defs)
ast_util.normalize_ast(orig_fn, ast)
return orig_fn, fn_index, ast

Expand All @@ -57,14 +58,17 @@ def from_source(
source: str,
eval_state: EvalState,
fn_name: str,
strip_other_fn_defs: bool,
randomization_weights: Mapping[str, float],
rng_seed: int,
) -> "Candidate":
# Use the same AST for all instances of the same original source, but
# with the target function deeply copied. Since we never change the
# AST outside of the target function, this is fine, and it saves us
# performance (deepcopy is really slow).
orig_fn, fn_index, ast = Candidate._cached_shared_ast(source, fn_name)
orig_fn, fn_index, ast = Candidate._cached_shared_ast(
source, fn_name, strip_other_fn_defs
)
ast = copy.copy(ast)
ast.ext = copy.copy(ast.ext)
fn_copy = copy.deepcopy(orig_fn)
Expand All @@ -73,12 +77,13 @@ def from_source(
return Candidate(
ast=ast,
fn_name=fn_name,
strip_other_fn_defs=strip_other_fn_defs,
rng_seed=rng_seed,
randomizer=Randomizer(randomization_weights, rng_seed),
)

def randomize_ast(self) -> None:
self.randomizer.randomize(self.ast, self.fn_name)
self.randomizer.randomize(self.ast, self.fn_name, self.strip_other_fn_defs)
self._cache_source = None

def get_source(self) -> str:
Expand Down
9 changes: 9 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class Options:
network_priority: float = 1.0
no_context_output: bool = False
debug_mode: bool = False
strip_other_fn_defs: bool = False


def restricted_float(lo: float, hi: float) -> Callable[[str], float]:
Expand Down Expand Up @@ -356,6 +357,7 @@ def run_inner(options: Options, heartbeat: Callable[[], None]) -> List[int]:
better_only=options.better_only,
score_threshold=options.score_threshold,
debug_mode=options.debug_mode,
strip_other_fn_defs=options.strip_other_fn_defs,
)
except CandidateConstructionFailure as e:
print(e.message, file=sys.stderr)
Expand Down Expand Up @@ -708,6 +710,12 @@ def main() -> None:
action="store_true",
help="Debug mode, only compiles and scores the base for debugging issues",
)
parser.add_argument(
"--strip-extra_fn_defs",
dest="strip_other_fn_defs",
action="store_true",
help="Strip all function defs except for the target function and inlines into just function declarations",
)

args = parser.parse_args()

Expand Down Expand Up @@ -735,6 +743,7 @@ def main() -> None:
network_priority=args.network_priority,
no_context_output=args.no_context_output,
debug_mode=args.debug_mode,
strip_other_fn_defs=args.strip_other_fn_defs,
)

run(options)
Expand Down
1 change: 1 addition & 0 deletions src/net/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def _create_permuter(
best_only=False,
score_threshold=None,
debug_mode=False,
strip_other_fn_defs=False,
)
except:
os.unlink(path)
Expand Down
4 changes: 4 additions & 0 deletions src/permuter.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(
better_only: bool,
score_threshold: Optional[int],
debug_mode: bool,
strip_other_fn_defs: bool,
) -> None:
self.dir = dir
self.compiler = compiler
Expand Down Expand Up @@ -126,6 +127,7 @@ def __init__(
self._better_only = better_only
self._score_threshold = score_threshold
self._debug_mode = debug_mode
self._strip_other_fn_defs = strip_other_fn_defs
(
self.base_score,
self.base_hash,
Expand All @@ -142,6 +144,7 @@ def _create_and_score_base(self) -> Tuple[int, str, str]:
base_source,
eval_state,
self.fn_name,
self._strip_other_fn_defs,
self.randomization_weights,
rng_seed=0,
)
Expand Down Expand Up @@ -186,6 +189,7 @@ def _eval_candidate(self, seed: int) -> CandidateResult:
cand_c,
eval_state,
self.fn_name,
self._strip_other_fn_defs,
self.randomization_weights,
rng_seed=rng_seed,
)
Expand Down
6 changes: 4 additions & 2 deletions src/randomizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2232,8 +2232,10 @@ def __init__(
for method in RANDOMIZATION_PASSES
]

def randomize(self, ast: ca.FileAST, fn_name: str) -> None:
fn = ast_util.extract_fn(ast, fn_name)[0]
def randomize(
self, ast: ca.FileAST, fn_name: str, strip_other_fn_defs: bool
) -> None:
fn = ast_util.extract_fn(ast, fn_name, strip_other_fn_defs)[0]
indices = ast_util.compute_node_indices(fn)
region = get_randomization_region(fn, indices, self.random)
while True:
Expand Down

0 comments on commit c09b05c

Please sign in to comment.