Skip to content

Commit

Permalink
feat: crash safety in fuzzing loop (#54)
Browse files Browse the repository at this point in the history
* feat: crash safety

* fix: typo in known issues
  • Loading branch information
ganler authored Sep 25, 2022
1 parent 0f9fc34 commit a8307f2
Show file tree
Hide file tree
Showing 18 changed files with 235 additions and 91 deletions.
2 changes: 1 addition & 1 deletion doc/concept.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
- **Meta information**: `meta.json` meta information.
- `"system"`: like `"tvm-gpu"` and `"ort-cpu"`
- `"version"`: a version string hooked from `${SystemPackage}.__version__`
- `"symptom"`: `"crash"` or `"inconsistency"`
- `"symptom"`: `"crash"` or `"inconsistency"` or `"timeout"`
- `"version_id"` (optional): an identifier of the system's version (e.g., git hash or version strings)

## Abstract Operators (AO)
Expand Down
9 changes: 9 additions & 0 deletions doc/known_issues.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
## TensorFlow Lite CUDA Init Error in Crash-safe Mode

If we run `tflite` in the fuzzing loop with `fuzz.crash_safe=true`, you may encounter many crashes like:

```txt
F tensorflow/stream_executor/cuda/cuda_driver.cc:219] Failed setting context: CUDA_ERROR_NOT_INITIALIZED: initialization error
```

It is temporarily "fixed" by setting the environment variable `CUDA_VISIBLE_DEVICES=-1` when we detect that CUDA is not used in the fuzzing loop. Nevertheless, this appears to be a TensorFlow bug that should be fixed upstream.
145 changes: 121 additions & 24 deletions nnsmith/backends/factory.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import multiprocessing as mp
import os
import sys
import traceback
from abc import ABC, abstractmethod
Expand All @@ -18,13 +20,10 @@


class BackendFactory(ABC):
def __init__(self, target="cpu", optmax: bool = False, catch_process_crash=True):
def __init__(self, target="cpu", optmax: bool = False):
    """Create a backend factory.

    Parameters
    ----------
    target : str
        Compilation/execution target, e.g. "cpu" (default); valid values
        are backend-specific — see the concrete factory subclasses.
    optmax : bool
        Whether to enable the backend's maximum optimization level.
    """
    super().__init__()
    self.target = target
    self.optmax = optmax
# If true, will run the compilation and execution in a subprocess.
# and catch segfaults returned as BugReport.
self.catch_process_crash = catch_process_crash

@property
@abstractmethod
Expand Down Expand Up @@ -53,16 +52,18 @@ def skip_dtypes(cls) -> List[DType]:
def make_backend(self, model: Model) -> BackendCallable:
raise NotImplementedError

def checked_make_backend(self, model: Model) -> BackendCallable:
if self.make_backend.dispatch(type(model)):
return self.make_backend(model)
else:
def critical_assert_dispatchable(self, model: Model) -> None:
    """Abort the process if `make_backend` cannot handle this model type.

    `make_backend` is a dispatch-based method (multipledispatch style);
    if no overload is registered for ``type(model)``, log a critical
    error with a pointer to the support matrix and exit, rather than
    failing with an obscure error deep inside the fuzzing loop.
    """
    if not self.make_backend.dispatch(type(model)):
        CORE_LOG.critical(
            f"[Not implemented] {type(self).__name__} for {type(model).__name__}!\n"
            # Fixed typo in the user-facing message: "compatile" -> "compatible".
            "Check https://github.com/ise-uiuc/nnsmith#backend-model-support for compatible `model.type` and `backend.type`."
        )
        sys.exit(1)

def checked_make_backend(self, model: Model) -> BackendCallable:
    """Build an executable backend for `model`.

    Exits the process (via `critical_assert_dispatchable`) when the
    model type has no registered `make_backend` overload.
    """
    # Hard-exit on unsupported model types before attempting a build.
    self.critical_assert_dispatchable(model)
    backend = self.make_backend(model)
    return backend

def checked_compile(self, testcase: TestCase) -> Union[BackendCallable, BugReport]:
try: # compilation
return self.checked_make_backend(testcase.model)
Expand Down Expand Up @@ -100,11 +101,98 @@ def checked_exec(
log=traceback.format_exc(),
)

def checked_compile_and_exec(self, testcase: TestCase):
executable = self.checked_compile(testcase)
if isinstance(executable, BugReport):
return executable
return self.checked_exec(executable, testcase)
def checked_compile_and_exec(
    self, testcase: TestCase, crash_safe=False, timeout=None
) -> Union[Dict[str, np.ndarray], BugReport]:
    """Compile and execute `testcase`, returning outputs or a BugReport.

    Parameters
    ----------
    testcase : TestCase
        The model (+ oracle inputs) to compile and run.
    crash_safe : bool
        If True, run compile/exec in a child process so that backend
        crashes (segfaults, aborts) are reported as BugReports instead
        of killing the fuzzer process.
    timeout : int | float | None
        Wall-clock budget in seconds for the child process; a non-None
        value implies the subprocess path even when `crash_safe` is
        False. Generalized from int-only: fractional seconds are valid
        since `multiprocessing.Process.join` accepts floats.

    Returns
    -------
    Either the execution outputs (name -> np.ndarray) or a BugReport
    describing the crash / failure / timeout.
    """
    # Fail fast (with a readable critical log) on unsupported model types.
    self.critical_assert_dispatchable(testcase.model)
    if (
        not crash_safe and timeout is None
    ):  # not crash safe, compile & exec natively in current process.
        bug_or_exec = self.checked_compile(testcase)
        if isinstance(bug_or_exec, BugReport):
            return bug_or_exec
        return self.checked_exec(bug_or_exec, testcase)
    else:  # crash safe, compile & exec in a separate process.
        if timeout is not None:
            assert isinstance(
                timeout, (int, float)
            ), "`timeout` is in seconds => must be a number."

        # TODO: optimize to shared memory in the future (Python 3.8+)
        # https://docs.python.org/3/library/multiprocessing.shared_memory.html
        # NOTE: Similar implementation as Tzer.
        with mp.Manager() as manager:
            shared_dict = manager.dict(
                {
                    "symptom": None,
                    "stage": Stage.COMPILATION,
                    "log": None,
                    "output": None,
                    "uncaught_exception": None,
                }
            )

            def crash_safe_compile_exec(sdict):
                # Runs in the child: record progress into the managed
                # dict so the parent can reconstruct what happened even
                # if the child dies mid-way.
                try:
                    bug_or_exec = self.checked_compile(testcase)
                    if isinstance(bug_or_exec, BugReport):
                        sdict["symptom"] = bug_or_exec.symptom
                        sdict["log"] = bug_or_exec.log
                        return

                    sdict["stage"] = Stage.EXECUTION
                    bug_or_result = self.checked_exec(bug_or_exec, testcase)
                    if isinstance(bug_or_result, BugReport):
                        sdict["symptom"] = bug_or_result.symptom
                        sdict["log"] = bug_or_result.log
                        return

                    sdict["output"] = bug_or_result
                except Exception as e:
                    # NOTE(review): `e` must be picklable to cross the
                    # process boundary — verify for exotic backend errors.
                    sdict["uncaught_exception"] = e

            p = mp.Process(target=crash_safe_compile_exec, args=(shared_dict,))

            p.start()
            p.join(timeout=timeout)
            if p.is_alive():
                # Robustness: terminate() sends SIGTERM, which a child
                # stuck in native code may ignore; escalate to kill()
                # (SIGKILL) before asserting the child is gone.
                p.terminate()
                p.join(timeout=1)
                if p.is_alive():
                    p.kill()
                    p.join()
                assert not p.is_alive()
                return BugReport(
                    testcase=testcase,
                    system=self.system_name,
                    symptom=Symptom.TIMEOUT,
                    stage=shared_dict["stage"],
                    log=f"Timeout after {timeout} seconds.",
                )

            # Success: the child recorded real outputs.
            if shared_dict["output"] is not None:
                return shared_dict["output"]

            # A Python-level exception escaped checked_compile/checked_exec:
            # this is a harness bug, not a backend bug — re-raise loudly.
            if shared_dict["uncaught_exception"] is not None:
                CORE_LOG.critical(
                    f"Found uncaught {type(shared_dict['uncaught_exception'])} in crash safe mode."
                )
                raise shared_dict["uncaught_exception"]

            if p.exitcode != 0:
                # Child died without recording anything: native crash.
                return BugReport(
                    testcase=testcase,
                    system=self.system_name,
                    symptom=Symptom.SEGFAULT,
                    stage=shared_dict["stage"],
                    log=f"Process crashed with exit code: {p.exitcode}",
                )
            else:
                # Child exited cleanly after recording a BugReport itself.
                return BugReport(
                    testcase=testcase,
                    system=self.system_name,
                    symptom=shared_dict["symptom"],
                    stage=shared_dict["stage"],
                    log=shared_dict["log"],
                )

def verify_results(
self, output: Dict[str, np.ndarray], testcase: TestCase, equal_nan=True
Expand All @@ -129,9 +217,6 @@ def verify_results(
def verify_testcase(
self, testcase: TestCase, equal_nan=True
) -> Optional[BugReport]:
# TODO(@ganler): impl fault catching in subprocess
assert not self.catch_process_crash, "not implemented"

executable = self.checked_compile(testcase)
if isinstance(executable, BugReport):
return executable
Expand All @@ -146,8 +231,26 @@ def verify_testcase(
return None

def make_testcase(
self, model: Model, input: Dict[str, np.ndarray] = None
self,
model: Model,
input: Dict[str, np.ndarray] = None,
crash_safe=False,
timeout=None,
) -> Union[BugReport, TestCase]:
if input is None:
input = self.make_random_input(model.input_like)

partial_testcase = TestCase(
model=model, oracle=Oracle(input=input, output=None)
)
bug_or_res = self.checked_compile_and_exec(
partial_testcase, crash_safe=crash_safe, timeout=timeout
)
if isinstance(bug_or_res, BugReport):
return bug_or_res
else:
partial_testcase.oracle.output = bug_or_res

try: # compilation
executable = self.checked_make_backend(model)
except InternalError as e:
Expand Down Expand Up @@ -183,7 +286,7 @@ def make_testcase(
)

@staticmethod
def init(name, target="cpu", optmax=True, catch_process_crash=False, **kwargs):
def init(name, target="cpu", optmax=True, **kwargs):
if name is None:
raise ValueError(
"Backend type cannot be None. Specify via `backend.type=[onnxruntime | tvm | tensorrt | tflite | xla | iree]`"
Expand All @@ -198,7 +301,6 @@ def init(name, target="cpu", optmax=True, catch_process_crash=False, **kwargs):
return ORTFactory(
target=target,
optmax=optmax,
catch_process_crash=catch_process_crash,
**kwargs,
)
elif name == "tvm":
Expand All @@ -209,7 +311,6 @@ def init(name, target="cpu", optmax=True, catch_process_crash=False, **kwargs):
return TVMFactory(
target=target,
optmax=optmax,
catch_process_crash=catch_process_crash,
**kwargs,
)
elif name == "tensorrt":
Expand All @@ -218,7 +319,6 @@ def init(name, target="cpu", optmax=True, catch_process_crash=False, **kwargs):
return TRTFactory(
target=target,
optmax=optmax,
catch_process_crash=catch_process_crash,
**kwargs,
)
elif name == "tflite":
Expand All @@ -227,7 +327,6 @@ def init(name, target="cpu", optmax=True, catch_process_crash=False, **kwargs):
return TFLiteFactory(
target=target,
optmax=optmax,
catch_process_crash=catch_process_crash,
**kwargs,
)
elif name == "xla":
Expand All @@ -236,7 +335,6 @@ def init(name, target="cpu", optmax=True, catch_process_crash=False, **kwargs):
return XLAFactory(
target=target,
optmax=optmax,
catch_process_crash=catch_process_crash,
**kwargs,
)
elif name == "iree":
Expand All @@ -245,7 +343,6 @@ def init(name, target="cpu", optmax=True, catch_process_crash=False, **kwargs):
return IREEFactory(
target=target,
optmax=optmax,
catch_process_crash=catch_process_crash,
**kwargs,
)
else:
Expand Down
8 changes: 2 additions & 6 deletions nnsmith/backends/iree.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class IREEFactory(BackendFactory):
def __init__(self, target="cpu", optmax: bool = False, catch_process_crash=True):
def __init__(self, target="cpu", optmax: bool = False):
"""
Initialize the IREE backend factory.
Expand All @@ -21,8 +21,6 @@ def __init__(self, target="cpu", optmax: bool = False, catch_process_crash=True)
The compilation target including "cpu" (same as "llvm-cpu"), "vmvx", "vulkan-spirv", by default "cpu"
optmax : bool, optional
Release mode or not, by default False
catch_process_crash : bool, optional
Doing compilation without forking (may crash), by default True
"""
if target == "cpu":
target = "llvm-cpu"
Expand All @@ -34,9 +32,7 @@ def __init__(self, target="cpu", optmax: bool = False, catch_process_crash=True)
assert (
target in supported_backends
), f"Unsupported target {target}. Consider one of {supported_backends}"
super().__init__(
target=target, optmax=optmax, catch_process_crash=catch_process_crash
)
super().__init__(target=target, optmax=optmax)

@property
def system_name(self) -> str:
Expand Down
15 changes: 4 additions & 11 deletions nnsmith/backends/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,7 @@
from multipledispatch import dispatch

from nnsmith.backends.factory import BackendCallable, BackendFactory
from nnsmith.materialize.tensorflow import (
TFModel,
TFNetCallable,
np_dict_from_tf,
tf_dict_from_np,
)
from nnsmith.materialize.tensorflow import TFModel, TFNetCallable


class TFLiteRunner:
Expand All @@ -20,12 +15,7 @@ def __init__(self, tfnet_callable: TFNetCallable) -> None:

def __call__(self, input: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Run the wrapped TF callable on numpy inputs keyed by tensor name.

    NOTE: `input` shadows the builtin; kept for interface compatibility.
    """
    return self.tfnet_callable(**input)
# https://github.com/tensorflow/tensorflow/issues/34536#issuecomment-565632906
# TFLite doesn't support NVIDIA GPU.
# It can automatically convert input args to np.ndarray, and it outputs np.ndarray.
tf_input = tf_dict_from_np(input)
tf_output = self.tfnet_callable(**tf_input)
return np_dict_from_tf(tf_output)


class TFLiteFactory(BackendFactory):
Expand All @@ -39,6 +29,9 @@ class TFLiteFactory(BackendFactory):
"""

def __init__(self, target, optmax, **kwargs) -> None:
    """Create a TFLite backend factory.

    Parameters
    ----------
    target : str
        Execution target; must not be "cuda" (see note below).
    optmax : bool
        Optimization flag forwarded to the base factory.
    """
    # https://github.com/tensorflow/tensorflow/issues/34536#issuecomment-565632906
    # TFLite doesn't support NVIDIA GPU.
    # Added an assertion message so misconfiguration is self-explanatory.
    assert target != "cuda", "TFLite does not support the `cuda` target."
    super().__init__(target, optmax, **kwargs)

@property
Expand Down
4 changes: 2 additions & 2 deletions nnsmith/backends/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@


class XLAFactory(BackendFactory):
def __init__(self, target="cpu", optmax: bool = False, catch_process_crash=True):
super().__init__(target, optmax, catch_process_crash)
def __init__(self, target="cpu", optmax: bool = False):
super().__init__(target, optmax)

if self.target == "cpu":
self.device = tf.device(tf.config.list_logical_devices("CPU")[0].name)
Expand Down
Loading

0 comments on commit a8307f2

Please sign in to comment.