Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

default_to_not_stripping_fn_defs, flag to re-enable #144

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions import.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ def prune_source(
Returns (source, compilable_source)."""
try:
ast = ast_util.parse_c(source, from_import=True)
orig_fn, _ = ast_util.extract_fn(ast, func_name)
orig_fn, _ = ast_util.extract_fn(ast, func_name, True)
if should_prune:
try:
ast_util.prune_ast(orig_fn, ast)
Expand Down Expand Up @@ -516,7 +516,7 @@ def prune_and_separate_context(
Returns (source, context)."""
try:
ast = ast_util.parse_c(source, from_import=True)
orig_fn, ind = ast_util.extract_fn(ast, func_name)
orig_fn, ind = ast_util.extract_fn(ast, func_name, True)
if should_prune:
try:
ind = ast_util.prune_ast(orig_fn, ast)
Expand Down
6 changes: 4 additions & 2 deletions src/ast_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,15 @@ def visit_If(self, n: ca.If) -> str:
return super().visit_If(n2) # type: ignore


def extract_fn(ast: ca.FileAST, fn_name: str) -> Tuple[ca.FuncDef, int]:
def extract_fn(
ast: ca.FileAST, fn_name: str, strip_other_fn_defs: bool
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ast: ca.FileAST, fn_name: str, strip_other_fn_defs: bool
ast: ca.FileAST, fn_name: str, strip_other_fn_defs: bool = True

I think it'd make sense to give this arg a default value so it's clear what exactly the default behavior is. I'm a bit confused about this PR becuase I see you passing True in sometimes and False others, so it seems like a behavioral change as opposed to merely the introduction of a new option.

For IDO, I know sometimes function defs are required and affect codegen, but for GCC this isn't the case. I'm thinking it's okay to change default behavior based on the compiler provided in the settings file, but we probably should at least use that information instead of just changing the default behavior for everyone.

) -> Tuple[ca.FuncDef, int]:
ret = []
for i, node in enumerate(ast.ext):
if isinstance(node, ca.FuncDef):
if node.decl.name == fn_name:
ret.append((node, i))
elif "inline" not in node.decl.funcspec:
elif strip_other_fn_defs and "inline" not in node.decl.funcspec:
node = node.decl
ast.ext[i] = node
if isinstance(node, ca.Decl) and isinstance(node.type, ca.FuncDecl):
Expand Down
13 changes: 9 additions & 4 deletions src/candidate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class Candidate:
ast: ca.FileAST

fn_name: str
strip_other_fn_defs: bool
rng_seed: int
randomizer: Randomizer
score_value: Optional[int] = field(init=False, default=None)
Expand All @@ -45,10 +46,10 @@ class Candidate:
@staticmethod
@functools.lru_cache(maxsize=16)
def _cached_shared_ast(
source: str, fn_name: str
source: str, fn_name: str, strip_other_fn_defs: bool
) -> Tuple[ca.FuncDef, int, ca.FileAST]:
ast = ast_util.parse_c(source)
orig_fn, fn_index = ast_util.extract_fn(ast, fn_name)
orig_fn, fn_index = ast_util.extract_fn(ast, fn_name, strip_other_fn_defs)
ast_util.normalize_ast(orig_fn, ast)
return orig_fn, fn_index, ast

Expand All @@ -57,14 +58,17 @@ def from_source(
source: str,
eval_state: EvalState,
fn_name: str,
strip_other_fn_defs: bool,
randomization_weights: Mapping[str, float],
rng_seed: int,
) -> "Candidate":
# Use the same AST for all instances of the same original source, but
# with the target function deeply copied. Since we never change the
# AST outside of the target function, this is fine, and it saves us
# performance (deepcopy is really slow).
orig_fn, fn_index, ast = Candidate._cached_shared_ast(source, fn_name)
orig_fn, fn_index, ast = Candidate._cached_shared_ast(
source, fn_name, strip_other_fn_defs
)
ast = copy.copy(ast)
ast.ext = copy.copy(ast.ext)
fn_copy = copy.deepcopy(orig_fn)
Expand All @@ -73,12 +77,13 @@ def from_source(
return Candidate(
ast=ast,
fn_name=fn_name,
strip_other_fn_defs=strip_other_fn_defs,
rng_seed=rng_seed,
randomizer=Randomizer(randomization_weights, rng_seed),
)

def randomize_ast(self) -> None:
self.randomizer.randomize(self.ast, self.fn_name)
self.randomizer.randomize(self.ast, self.fn_name, self.strip_other_fn_defs)
self._cache_source = None

def get_source(self) -> str:
Expand Down
9 changes: 9 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class Options:
network_priority: float = 1.0
no_context_output: bool = False
debug_mode: bool = False
strip_other_fn_defs: bool = False


def restricted_float(lo: float, hi: float) -> Callable[[str], float]:
Expand Down Expand Up @@ -356,6 +357,7 @@ def run_inner(options: Options, heartbeat: Callable[[], None]) -> List[int]:
better_only=options.better_only,
score_threshold=options.score_threshold,
debug_mode=options.debug_mode,
strip_other_fn_defs=options.strip_other_fn_defs,
)
except CandidateConstructionFailure as e:
print(e.message, file=sys.stderr)
Expand Down Expand Up @@ -708,6 +710,12 @@ def main() -> None:
action="store_true",
help="Debug mode, only compiles and scores the base for debugging issues",
)
parser.add_argument(
"--strip-extra_fn_defs",
dest="strip_other_fn_defs",
action="store_true",
help="Strip all function defs except for the target function and inlines into just function declarations",
)

args = parser.parse_args()

Expand Down Expand Up @@ -735,6 +743,7 @@ def main() -> None:
network_priority=args.network_priority,
no_context_output=args.no_context_output,
debug_mode=args.debug_mode,
strip_other_fn_defs=args.strip_other_fn_defs,
)

run(options)
Expand Down
1 change: 1 addition & 0 deletions src/net/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def _create_permuter(
best_only=False,
score_threshold=None,
debug_mode=False,
strip_other_fn_defs=False,
)
except:
os.unlink(path)
Expand Down
4 changes: 4 additions & 0 deletions src/permuter.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(
better_only: bool,
score_threshold: Optional[int],
debug_mode: bool,
strip_other_fn_defs: bool,
) -> None:
self.dir = dir
self.compiler = compiler
Expand Down Expand Up @@ -126,6 +127,7 @@ def __init__(
self._better_only = better_only
self._score_threshold = score_threshold
self._debug_mode = debug_mode
self._strip_other_fn_defs = strip_other_fn_defs
(
self.base_score,
self.base_hash,
Expand All @@ -142,6 +144,7 @@ def _create_and_score_base(self) -> Tuple[int, str, str]:
base_source,
eval_state,
self.fn_name,
self._strip_other_fn_defs,
self.randomization_weights,
rng_seed=0,
)
Expand Down Expand Up @@ -186,6 +189,7 @@ def _eval_candidate(self, seed: int) -> CandidateResult:
cand_c,
eval_state,
self.fn_name,
self._strip_other_fn_defs,
self.randomization_weights,
rng_seed=rng_seed,
)
Expand Down
6 changes: 4 additions & 2 deletions src/randomizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2232,8 +2232,10 @@ def __init__(
for method in RANDOMIZATION_PASSES
]

def randomize(self, ast: ca.FileAST, fn_name: str) -> None:
fn = ast_util.extract_fn(ast, fn_name)[0]
def randomize(
self, ast: ca.FileAST, fn_name: str, strip_other_fn_defs: bool
) -> None:
fn = ast_util.extract_fn(ast, fn_name, strip_other_fn_defs)[0]
indices = ast_util.compute_node_indices(fn)
region = get_randomization_region(fn, indices, self.random)
while True:
Expand Down