Support gradient checking / autodiff validation (#115)
ganler authored May 9, 2023
1 parent bad7426 commit 881cd7f
Showing 20 changed files with 258 additions and 116 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yaml
@@ -35,6 +35,8 @@ jobs:
yes | python nnsmith/cli/model_gen.py debug.viz=true model.type=torch backend.type="pt2 backend@inductor" mgen.method=concolic
yes | python nnsmith/cli/model_gen.py model.type=torch mgen.method=symbolic-cinit mgen.rank_choices="[4]" mgen.dtype_choices="[f32]" mgen.include="[core.NCHWConv2d, core.ReLU]" mgen.patch_requires=./tests/mock/requires_patch.py
yes | python nnsmith/cli/model_gen.py model.type=torch mgen.method=symbolic-cinit mgen.rank_choices="[4]" mgen.dtype_choices="[f32]" mgen.include="[core.NCHWConv2d, core.ReLU]" mgen.patch_requires=./tests/mock/requires_patch.py backend.type=torchjit
yes | python nnsmith/cli/model_gen.py model.type=torch mgen.method=symbolic-cinit mgen.rank_choices="[4]" mgen.dtype_choices="[f32]" mgen.include="[core.NCHWConv2d, core.ReLU]" mgen.patch_requires=./tests/mock/requires_patch.py backend.type=torchjit mgen.grad_check=true
yes | python nnsmith/cli/model_gen.py model.type=torch mgen.method=symbolic-cinit mgen.rank_choices="[4]" mgen.dtype_choices="[f32]" mgen.include="[core.NCHWConv2d, core.ReLU]" mgen.patch_requires=./tests/mock/requires_patch.py backend.type=pt2 mgen.grad_check=true
- name: Test ONNX + ONNXRuntime
run: |
pytest tests/onnxruntime
1 change: 1 addition & 0 deletions .gitignore
@@ -161,3 +161,4 @@ cython_debug/

# Exclude
!requirements/**/*.txt
.DS_Store
4 changes: 2 additions & 2 deletions README.md
@@ -17,10 +17,10 @@
| Models | [`tvm`](https://github.com/apache/tvm) | [`pt2`](https://pytorch.org/get-started/pytorch-2.0/) | [`torchjit`](https://pytorch.org/docs/stable/jit.html) | [`tensorrt`](https://github.com/NVIDIA/TensorRT) | [`onnxruntime`](https://github.com/microsoft/onnxruntime) | [`xla`](https://www.tensorflow.org/xla) | [`tflite`](https://www.tensorflow.org/lite) |
| ------------ | ------------------------------------ | ----------------------------------------------- | ---------------------------------------------- | ----------------------------------------- | ------------------------------------- | ----------------------------------------------------- | ------------ |
| ONNX | ✅ | | | ✅ | ✅ | | |
| PyTorch | 🔨 | ✅ | ✅ | | | | |
| PyTorch | 🔨 | ✅📈 | ✅📈 | | | | |
| TensorFlow | 🔨 | | | | | ✅ | ✅ |

✅: Supported; 🔨: Coming soon;
✅: Supported; 📈: Supports gradient check; 🔨: Coming soon;

</div>

6 changes: 6 additions & 0 deletions doc/cli.md
@@ -61,6 +61,12 @@ nnsmith.model_exec model.type=onnx \
cmp.with='{type:tvm, optmax:true, target:cpu}'
```

## Experimental: Gradient checking

For `pt2` and `torchjit`, we have initial support for checking gradients.

To enable it, simply append `mgen.grad_check=true` to the examples illustrated above.
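
For example, adapting one of the CI commands from `.github/workflows/ci.yaml` (every flag other than `mgen.grad_check=true` is just one possible generator configuration):

```shell
python nnsmith/cli/model_gen.py model.type=torch backend.type=pt2 \
    mgen.method=symbolic-cinit mgen.rank_choices="[4]" mgen.dtype_choices="[f32]" \
    mgen.include="[core.NCHWConv2d, core.ReLU]" mgen.grad_check=true
```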

## Data type testing

Many compilers do not support the full operator set of ONNX or TensorFlow. Thus, we infer the supported set by testing operators one at a time.
6 changes: 2 additions & 4 deletions nnsmith/abstract/dtype.py
@@ -46,10 +46,8 @@ def short(self) -> str:
DType.bool: "b",
}[self]

@staticmethod
def is_float(dtype): # Don't use string. Make it well-formed.
assert isinstance(dtype, DType)
return dtype in [DType.float32, DType.float64]
def is_float(self):
return self in [DType.float16, DType.float32, DType.float64]

@staticmethod
def from_str(s):
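
A quick note on the `DType.is_float` change: it is now an instance method (no assertion on the argument type) and treats `float16` as floating point as well. A minimal usage sketch:

```python
from nnsmith.abstract.dtype import DType

assert DType.float16.is_float()   # float16 now counts as a float
assert DType.float32.is_float()
assert not DType.bool.is_float()
```
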
15 changes: 12 additions & 3 deletions nnsmith/backends/factory.py
@@ -140,7 +140,7 @@ def checked_exec(
def checked_compile_and_exec(
self, testcase: TestCase, crash_safe=False, timeout=None
) -> Union[Dict[str, np.ndarray], BugReport]:
# pre-check model dispatchability
# pre-check if model is dispatchable
self.critical_assert_dispatchable(testcase.model)
if (
not crash_safe and timeout is None
@@ -354,7 +354,12 @@ def emit_run(self, out_name: str, opt_name: str, inp_name: str) -> str:

@staticmethod
def init(
name: str, target: str = "cpu", optmax: bool = True, parse_name=False, **kwargs
name: str,
target: str = "cpu",
ad: str = None,
optmax: bool = True,
parse_name=False,
**kwargs,
):
if name is None:
raise ValueError(
@@ -414,9 +419,13 @@ def init(
from nnsmith.backends.torchjit import TorchJIT

return TorchJIT(target=target, optmax=optmax, **kwargs)
elif name == "torchjitAD":
from nnsmith.backends.torchjitAD import TorchJITAD

return TorchJITAD(target=target, optmax=optmax, ad=ad, **kwargs)
elif name == "pt2":
from nnsmith.backends.pt2 import PT2

return PT2(target=target, optmax=optmax, **kwargs)
return PT2(target=target, optmax=optmax, ad=ad, **kwargs)
else:
raise ValueError(f"unknown backend: {name}")
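
A sketch of constructing a factory with the new `ad` argument (here `ad=None`, matching the `ad.type: null` default added to `nnsmith/config/main.yaml`; other accepted values are backend-specific and not shown in this diff):

```python
from nnsmith.backends.factory import BackendFactory

# `ad` is forwarded to AD-aware backends such as `pt2` and `torchjitAD`;
# passing None keeps the previous behavior (no explicit autodiff mode).
factory = BackendFactory.init("pt2", target="cpu", optmax=True, ad=None)
```
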
47 changes: 37 additions & 10 deletions nnsmith/backends/pt2.py
@@ -2,10 +2,11 @@

import numpy as np
import torch
import torch.fx
from multipledispatch import dispatch

from nnsmith.backends.factory import BackendCallable, BackendFactory
from nnsmith.materialize.torch import TorchModel
from nnsmith.materialize.torch import TorchModel, numpify
from nnsmith.materialize.torch.symbolnet import FxTracing


@@ -16,6 +17,8 @@ def __init__(self, target: str = "cpu", optmax: bool = True, **kwargs):
self.device = torch.device("cpu")
elif self.target == "cuda":
self.device = torch.device("cuda")
elif self.target == "mps":
self.device = torch.device("mps")
else:
raise ValueError(
f"Unknown target: {self.target}. Only `cpu`, `cuda`, and `mps` are supported."
@@ -35,7 +38,15 @@ def import_libs(self) -> List[str]:

@dispatch(TorchModel)
def make_backend(self, model: TorchModel) -> BackendCallable:
torch_net = model.torch_model.to(self.device).eval()
torch_net = model.torch_model.to(self.device)

do_grad_check = model.needs_grad_check()

if do_grad_check:
torch_net = torch_net.train()
else:
torch_net = torch_net.eval()

with torch.no_grad():
with FxTracing():
traced = torch.fx.symbolic_trace(torch_net)
@@ -44,15 +55,31 @@ def make_backend(self, model: TorchModel) -> BackendCallable:
)

def closure(inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
nonlocal do_grad_check
input_ts = [torch.from_numpy(v).to(self.device) for _, v in inputs.items()]
with torch.no_grad():
output: Tuple[torch.Tensor] = compiled(*input_ts)
return {
k: v.cpu().detach().resolve_conj().numpy()
if v.is_conj()
else v.cpu().detach().numpy()
for k, v in zip(torch_net.output_like.keys(), output)
}
if do_grad_check:
outputs: List[torch.Tensor] = compiled(*input_ts)
params = {k: v for k, v in compiled.named_parameters()}
ret = {}

for name, output in zip(torch_net.output_like.keys(), outputs):
ret[name] = numpify(output)
if output.requires_grad:
# get Vector-Jacobian product
out_grad = torch.autograd.grad(
outputs=output,
inputs=params.values(),
grad_outputs=torch.ones_like(output),
retain_graph=True,
allow_unused=True,
)
for k, v in zip(params.keys(), out_grad):
ret[name + "_vjp_" + k] = numpify(v)
else:
with torch.no_grad():
outputs: Tuple[torch.Tensor] = compiled(*input_ts)
ret = {k: numpify(v) for k, v in zip(torch_net.output_like, outputs)}
return ret

return closure

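The heart of the gradient check in both `pt2` and `torchjit` is a vector-Jacobian product (VJP) per output: `torch.autograd.grad` is seeded with an all-ones `grad_outputs`, and `allow_unused=True` returns `None` for parameters that do not influence that output. A self-contained sketch of the same pattern on a toy module (module and key names are illustrative, not from the repository):

```python
import torch

net = torch.nn.Linear(3, 2)             # toy stand-in for the generated model
net.train()

x = torch.randn(4, 3)
out = net(x)

params = dict(net.named_parameters())
vjps = torch.autograd.grad(
    outputs=out,
    inputs=list(params.values()),
    grad_outputs=torch.ones_like(out),   # seed the VJP with ones
    retain_graph=True,
    allow_unused=True,                   # unused parameters yield None instead of raising
)
# Mirror the `<output>_vjp_<param>` key convention used above.
grads = {
    f"out_vjp_{k}": None if g is None else g.detach().numpy()
    for k, g in zip(params, vjps)
}
```
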
71 changes: 44 additions & 27 deletions nnsmith/backends/torchjit.py
@@ -8,7 +8,7 @@
from torch.utils.mobile_optimizer import optimize_for_mobile

from nnsmith.backends.factory import BackendCallable, BackendFactory
from nnsmith.materialize.torch import TorchModel
from nnsmith.materialize.torch import TorchModel, numpify

# Check https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html
# for more PyTorch-internal options.
@@ -24,43 +24,60 @@ def __init__(self, target="cpu", optmax: bool = False, **kwargs):
self.device = torch.device("cuda")
torch.backends.cudnn.benchmark = True
else:
raise ValueError(
f"Unknown target: {self.target}. Only `cpu` and `cuda` are supported."
)
raise ValueError(f"Unknown {target=}. Only `cpu` and `cuda` are supported.")

@property
def system_name(self) -> str:
return "torchjit"

@dispatch(TorchModel)
def make_backend(self, model: TorchModel) -> BackendCallable:
torch_net = model.torch_model.to(self.device).eval()
torch_net = model.torch_model.to(self.device)
trace_inp = [ts.to(self.device) for ts in torch_net.get_random_inps().values()]
with torch.no_grad():
with warnings.catch_warnings():
warnings.simplefilter(
"ignore",
category=torch.jit.TracerWarning,
)
exported = torch.jit.trace(
torch_net,
trace_inp,
)
exported = torch.jit.freeze(exported) # Frozen graph.
exported = torch.jit.optimize_for_inference(exported)
if self.target == "cpu" and NNSMITH_PTJIT_OPT_MOBILE:
exported = optimize_for_mobile(exported)

do_grad_check = model.needs_grad_check()

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=torch.jit.TracerWarning)
if do_grad_check:
torch_net = torch_net.train()
compiled = torch.jit.trace(torch_net, trace_inp)
else:
torch_net = torch_net.eval()
with torch.no_grad():
compiled = torch.jit.trace(torch_net, trace_inp)
compiled = torch.jit.freeze(compiled) # Frozen graph
compiled = torch.jit.optimize_for_inference(compiled)

if self.target == "cpu" and NNSMITH_PTJIT_OPT_MOBILE:
compiled = optimize_for_mobile(compiled)

def closure(inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
nonlocal do_grad_check
input_ts = [torch.from_numpy(v).to(self.device) for _, v in inputs.items()]
with torch.no_grad():
output: Tuple[torch.Tensor] = exported(*input_ts)
return {
k: v.cpu().detach().resolve_conj().numpy()
if v.is_conj()
else v.cpu().detach().numpy()
for k, v in zip(torch_net.output_like.keys(), output)
}
if do_grad_check:
outputs: List[torch.Tensor] = compiled(*input_ts)
params = {k: v for k, v in compiled.named_parameters()}
ret = {}

for name, output in zip(torch_net.output_like.keys(), outputs):
ret[name] = numpify(output)
if output.requires_grad:
# get Vector-Jacobian product
out_grad = torch.autograd.grad(
outputs=output,
inputs=params.values(),
grad_outputs=torch.ones_like(output),
retain_graph=True,
allow_unused=True,
)
for k, v in zip(params.keys(), out_grad):
ret[name + "_vjp_" + k] = numpify(v)
else:
with torch.no_grad():
outputs: Tuple[torch.Tensor] = compiled(*input_ts)
ret = {k: numpify(v) for k, v in zip(torch_net.output_like, outputs)}
return ret

return closure

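Both backends now convert outputs and VJPs through a shared `numpify` helper imported from `nnsmith.materialize.torch`. Its definition is not part of this excerpt; based on the inline conversion it replaces (and the fact that it must accept `None` gradients), a plausible sketch is the following, offered as an assumption rather than the repository's actual code:

```python
from typing import Optional

import numpy as np
import torch


def numpify(t: Optional[torch.Tensor]) -> Optional[np.ndarray]:
    """Tensor -> NumPy array; None passes through (e.g., VJPs of unused parameters)."""
    if t is None:
        return None
    t = t.cpu().detach()
    if t.is_conj():          # conjugate views must be materialized before .numpy()
        t = t.resolve_conj()
    return t.numpy()
```
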
10 changes: 9 additions & 1 deletion nnsmith/cli/fuzz.py
@@ -128,6 +128,7 @@ def __init__(

self.factory = BackendFactory.init(
cfg["backend"]["type"],
ad=cfg["ad"]["type"],
target=cfg["backend"]["target"],
optmax=cfg["backend"]["optmax"],
parse_name=True,
@@ -138,8 +139,14 @@
model_cfg["type"], backend_target=cfg["backend"]["target"]
)
self.ModelType.add_seed_setter()

self.opset = op_filter(
auto_opset(self.ModelType, self.factory, vulops=cfg["mgen"]["vulops"]),
auto_opset(
self.ModelType,
self.factory,
vulops=cfg["mgen"]["vulops"],
grad=cfg["mgen"]["grad_check"],
),
cfg["mgen"]["include"],
cfg["mgen"]["exclude"],
)
@@ -186,6 +193,7 @@ def make_testcase(self, seed) -> TestCase:
model.attach_viz(ir)

model.refine_weights() # either random generated or gradient-based.
model.set_grad_check(self.cfg["mgen"]["grad_check"])
oracle = model.make_oracle()
return TestCase(model, oracle)

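With these hooks in place, a fuzzing session with gradient checking enabled can presumably be launched like the standard fuzzing command from the project README, plus the new flag (time budget and report directory are illustrative):

```shell
python nnsmith/cli/fuzz.py fuzz.time=30s fuzz.root=fuzz_report \
    model.type=torch backend.type=pt2 mgen.grad_check=true
```
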
8 changes: 7 additions & 1 deletion nnsmith/cli/model_gen.py
@@ -41,7 +41,12 @@ def main(cfg: DictConfig):
factory = None

# GENERATION
opset = auto_opset(ModelType, factory, vulops=mgen_cfg["vulops"])
opset = auto_opset(
ModelType,
factory,
vulops=mgen_cfg["vulops"],
grad=mgen_cfg["grad_check"],
)
opset = op_filter(opset, mgen_cfg["include"], mgen_cfg["exclude"])
hijack_patch_requires(mgen_cfg["patch_requires"])
activate_ext(opset=opset, factory=factory)
@@ -82,6 +87,7 @@

model = ModelType.from_gir(ir)
model.refine_weights() # either random generated or gradient-based.
model.set_grad_check(mgen_cfg["grad_check"])
oracle = model.make_oracle()
tmat = time.time() - tmat_begin

4 changes: 4 additions & 0 deletions nnsmith/config/main.yaml
@@ -15,13 +15,17 @@ mgen: # model gen.
include: null # ops to include; example mgen.include="[core.NCHWConv2d, core.ReLU]"
exclude: null # ops to exclude;
patch_requires: [] # files that with @patch_requires
grad_check: false # additionally check gradients

# backend config
backend:
type: null
optmax: true
target: "cpu"

ad:
type: null

cache:
topset: true # Run dtype test with automatically maintained cache

35 changes: 21 additions & 14 deletions nnsmith/difftest.py
@@ -22,19 +22,26 @@ def assert_allclose(
lhs = actual[key]
rhs = desired[key]

# check if lhs is np.ndarray
if not isinstance(lhs, np.ndarray):
raise TypeError(f"{actual_name}[{key}] is not np.ndarray but {type(lhs)}")
if lhs is not None and rhs is not None:
# check if lhs is np.ndarray
if lhs is not None and not isinstance(lhs, np.ndarray):
raise TypeError(
f"{actual_name}[{key}] is not np.ndarray but {type(lhs)}"
)

# check if rhs is np.ndarray
if not isinstance(rhs, np.ndarray):
raise TypeError(f"{oracle_name}[{key}] is not np.ndarray but {type(rhs)}")
# check if rhs is np.ndarray
if rhs is not None and not isinstance(rhs, np.ndarray):
raise TypeError(
f"{oracle_name}[{key}] is not np.ndarray but {type(rhs)}"
)

testing.assert_allclose(
lhs,
rhs,
equal_nan=equal_nan,
rtol=rtol,
atol=atol,
err_msg=f"{actual_name} != {oracle_name} at {key}",
)
testing.assert_allclose(
lhs,
rhs,
equal_nan=equal_nan,
rtol=rtol,
atol=atol,
err_msg=f"{actual_name} != {oracle_name} at {key}",
)
else:
return lhs is None and rhs is None
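
One consequence of `allow_unused=True` is that a VJP entry may legitimately be `None` on both the actual and the oracle side, which is why the comparison now tolerates matching `None`s. A standalone illustration of that rule (an ad-hoc helper, not the repository's `assert_allclose`; tolerances are illustrative):

```python
import numpy as np


def entries_match(lhs, rhs, rtol=1e-2, atol=1e-3):
    """Both None -> match (unused-parameter VJP); one None -> mismatch; otherwise numeric check."""
    if lhs is None or rhs is None:
        return lhs is None and rhs is None
    np.testing.assert_allclose(lhs, rhs, rtol=rtol, atol=atol)
    return True


assert entries_match(None, None)                      # gradient absent on both sides
assert entries_match(np.ones(3), np.ones(3) + 1e-4)   # ordinary numeric comparison
assert not entries_match(None, np.ones(3))            # missing on one side only -> mismatch
```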
