Skip to content

Commit

Permalink
[BC] Refactor _KEEP_TEMPS for reusability (#376)
Browse files Browse the repository at this point in the history
Patch to env.py and compilation_runner.py which adds working_dir to
TimeStep. The patch also gives the option to keep the temporary
working_dir by setting keep_temps in compilation_runner.py to a
directory where all temporary working_dirs will be saved.
  • Loading branch information
tvmarino authored Oct 2, 2024
1 parent dafed47 commit 08e6242
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 8 deletions.
18 changes: 14 additions & 4 deletions compiler_opt/rl/compilation_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,19 @@ def __exit__(self, exc, value, tb):
pass


def get_workdir_context():
"""Return a context which manages how the temperory directories are handled.
When the flag keep_temps is specified temporary directories are stored in
keep_temps.
"""
if _KEEP_TEMPS.value is not None:
tempdir_context = NonTemporaryDirectory(dir=_KEEP_TEMPS.value)
else:
tempdir_context = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
return tempdir_context


def _overwrite_trajectory_reward(sequence_example: tf.train.SequenceExample,
reward: float) -> tf.train.SequenceExample:
"""Overwrite the reward in the trace (sequence_example) with the given one.
Expand Down Expand Up @@ -401,10 +414,7 @@ def collect_data(self,
compilation_runner.ProcessKilledException is passed through.
ValueError if example under default policy and ml policy does not match.
"""
if _KEEP_TEMPS.present:
tempdir_context = NonTemporaryDirectory(dir=_KEEP_TEMPS.value)
else:
tempdir_context = tempfile.TemporaryDirectory()
tempdir_context = get_workdir_context()

with tempdir_context as tempdir:
final_cmd_line = loaded_module_spec.build_command_line(tempdir)
Expand Down
17 changes: 13 additions & 4 deletions compiler_opt/rl/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
import contextlib
import io
import os
import tempfile
from typing import Callable, Generator, List, Optional, Tuple, Type

import numpy as np

from compiler_opt.rl import corpus
from compiler_opt.rl import log_reader
from compiler_opt.rl import compilation_runner


class StepType(Enum):
Expand All @@ -47,6 +47,7 @@ class TimeStep:
score_default: Optional[dict[str, float]]
context: Optional[str]
module_name: str
working_dir: str
obs_id: Optional[int]
step_type: StepType

Expand Down Expand Up @@ -115,10 +116,12 @@ class ClangProcess:
"""

def __init__(self, proc: subprocess.Popen,
get_scores_fn: Callable[[], dict[str, float]], module_name):
get_scores_fn: Callable[[], dict[str, float]], module_name: str,
working_dir: str):
self._proc = proc
self._get_scores_fn = get_scores_fn
self._module_name = module_name
self._working_dir = working_dir

def get_scores(self, timeout: Optional[int] = None):
self._proc.wait(timeout=timeout)
Expand All @@ -133,10 +136,11 @@ def __init__(
proc: subprocess.Popen,
get_scores_fn: Callable[[], dict[str, float]],
module_name: str,
working_dir: str,
reader_pipe: io.BufferedReader,
writer_pipe: io.BufferedWriter,
):
super().__init__(proc, get_scores_fn, module_name)
super().__init__(proc, get_scores_fn, module_name, working_dir)
self._reader_pipe = reader_pipe
self._writer_pipe = writer_pipe
self._obs_gen = log_reader.read_log_from_file(self._reader_pipe)
Expand All @@ -150,6 +154,7 @@ def __init__(
score_default=None,
context=None,
module_name=module_name,
working_dir=working_dir,
obs_id=None,
step_type=StepType.LAST,
)
Expand Down Expand Up @@ -180,6 +185,7 @@ def _get_step_type() -> StepType:
score_default=None,
context=obs.context,
module_name=self._module_name,
working_dir=self._working_dir,
obs_id=obs.observation_id,
step_type=_get_step_type(),
)
Expand Down Expand Up @@ -235,7 +241,8 @@ def clang_session(
Yields:
Either the constructed InteractiveClang or DefaultClang object.
"""
with tempfile.TemporaryDirectory() as td:
tempdir_context = compilation_runner.get_workdir_context()
with tempdir_context as td:
task_working_dir = os.path.join(td, '__task_working_dir__')
os.mkdir(task_working_dir)
task = task_type()
Expand Down Expand Up @@ -264,6 +271,7 @@ def _get_scores() -> dict[str, float]:
proc,
_get_scores,
module.name,
task_working_dir,
reader_pipe,
writer_pipe,
)
Expand All @@ -272,6 +280,7 @@ def _get_scores() -> dict[str, float]:
proc,
_get_scores,
module.name,
task_working_dir,
)

finally:
Expand Down
27 changes: 27 additions & 0 deletions compiler_opt/rl/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
import ctypes
from unittest import mock
import subprocess
import os
import tempfile
from absl.testing import flagsaver

from typing import Dict, List, Optional

Expand Down Expand Up @@ -161,6 +164,30 @@ def test_interactive_clang_session(self, mock_popen):
self.assertEqual(obs.context, f'context_{idx}')
mock_popen.assert_called_once()

@mock.patch('subprocess.Popen')
def test_interactive_clang_temp_dir(self, mock_popen):
mock_popen.side_effect = mock_interactive_clang
working_dir = None

with env.clang_session(
_CLANG_PATH, _MOCK_MODULE, MockTask, interactive=True) as clang_session:
for _ in range(_NUM_STEPS):
obs = clang_session.get_observation()
working_dir = obs.working_dir
self.assertEqual(os.path.exists(working_dir), True)
self.assertEqual(os.path.exists(working_dir), False)

with tempfile.TemporaryDirectory() as td:
with flagsaver.flagsaver((env.compilation_runner._KEEP_TEMPS, td)): # pylint: disable=protected-access
with env.clang_session(
_CLANG_PATH, _MOCK_MODULE, MockTask,
interactive=True) as clang_session:
for _ in range(_NUM_STEPS):
obs = clang_session.get_observation()
working_dir = obs.working_dir
self.assertEqual(os.path.exists(working_dir), True)
self.assertEqual(os.path.exists(working_dir), True)


class MLGOEnvironmentTest(tf.test.TestCase):

Expand Down

0 comments on commit 08e6242

Please sign in to comment.