Skip to content

Commit

Permalink
move write_vis into contrib
Browse files Browse the repository at this point in the history
  • Loading branch information
zdevito committed Sep 18, 2017
1 parent 9fa935f commit c70d723
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 105 deletions.
Empty file added torch/contrib/__init__.py
Empty file.
112 changes: 112 additions & 0 deletions torch/contrib/_graph_vis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""
Experimental. Tools for visualizing the torch.jit.Graph objects.
"""
import string
import json

_vis_template = string.Template("""
<!doctype html>
<html>
<head>
<title>$name</title>
<script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/vis/4.20.1/vis.min.js"></script>
<link href="https://cdnjs.cloudflare.com/ajax/libs/vis/4.20.1/vis.min.css" rel="stylesheet" type="text/css" />
<style type="text/css">
#mynetwork {
height: 100vh;
}
</style>
</head>
<body>
<div id="mynetwork"></div>
<script type="text/javascript">
// create an array with nodes
var nodes = new vis.DataSet(
$nodes
);
// create an array with edges
var edges = new vis.DataSet(
$edges
);
// create a network
var container = document.getElementById('mynetwork');
var data = {
nodes: nodes,
edges: edges
};
var options = $options;
var network = new vis.Network(container, data, options);
</script>
</body>
</html>
""")


def write(self, filename):
"""
Write an html file that visualizes a torch.jit.Graph using vis.js
Arguments:
self (torch.jit.Graph): the graph.
filename (string): the output filename, an html-file.
"""

nodes = []
edges = []
options = {}
for n, i in enumerate(self.inputs()):
nodes.append({
'id': i.unique(),
'label': 'input {}'.format(n),
'shape': 'square',
})

existing = set()

def add_edge(i_, n):
i = i_ if i_.kind() != 'Select' else i_.input()
if (i, n) in existing:
return
existing.add((i, n))
e = {
'from': n.unique(),
'to': i.unique(),
'arrows': 'from',
}
if i.stage() != n.stage():
e['color'] = 'green'
edges.append(e)

counts = {}
offset = 0
for n in self.nodes():
if len(n.uses()) == 0 or n.kind() == 'Select' or n.kind() == 'Undefined':
continue
ident = counts.get(n.kind(), 0)
counts[n.kind()] = ident + 1
d = {
'id': n.unique(),
'label': '{}_{}'.format(n.kind(), ident),
'y': offset,
'fixed': {'y': True},
}
if n in self.outputs():
d['shape'] = 'triangle'

for i in n.inputs():
add_edge(i, n)

nodes.append(d)
offset += 30

result = _vis_template.substitute(nodes=json.dumps(nodes),
edges=json.dumps(edges),
options=json.dumps(options),
name=filename)
with open(filename, 'w') as f:
f.write(result)
3 changes: 2 additions & 1 deletion torch/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import types
import contextlib
import os
import torch.contrib._graph_vis as graph_vis
# Example how to use:
#
# import torch.jit
Expand Down Expand Up @@ -97,7 +98,7 @@ def _dump_trace(trace_name, name, suffix, complete_trace):
filename = "{}_{}_{}".format(trace_name, name, suffix)
with open(filename + ".ir", "w") as f:
f.write(str(complete_trace))
complete_trace.graph().write_vis(filename + ".html")
graph_vis.write(complete_trace.graph(), filename + ".html")


# holds run() to run the function and self.inputs which
Expand Down
2 changes: 1 addition & 1 deletion torch/nn/_functions/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def RNNTanhCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):


def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
if input.is_cuda:
if False and input.is_cuda:
igates = F.linear(input, w_ih)
hgates = F.linear(hidden[0], w_hh)
state = fusedBackend.LSTMFused()
Expand Down
103 changes: 0 additions & 103 deletions torch/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,106 +93,3 @@ def _op(self, opname, *args, **kwargs):


torch._C.Graph.op = _op


_vis_template = string.Template("""
<!doctype html>
<html>
<head>
<title>Network | Basic usage</title>
<script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/vis/4.20.1/vis.min.js"></script>
<link href="https://cdnjs.cloudflare.com/ajax/libs/vis/4.20.1/vis.min.css" rel="stylesheet" type="text/css" />
<style type="text/css">
#mynetwork {
height: 100vh;
}
</style>
</head>
<body>
<div id="mynetwork"></div>
<script type="text/javascript">
// create an array with nodes
var nodes = new vis.DataSet(
$nodes
);
// create an array with edges
var edges = new vis.DataSet(
$edges
);
// create a network
var container = document.getElementById('mynetwork');
var data = {
nodes: nodes,
edges: edges
};
var options = $options;
var network = new vis.Network(container, data, options);
</script>
</body>
</html>
""")


def _write_vis(self, filename):
nodes = []
edges = []
options = {}
for n, i in enumerate(self.inputs()):
nodes.append({
'id': i.unique(),
'label': 'input {}'.format(n),
'shape': 'square',
})

existing = set()

def add_edge(i_, n):
i = i_ if i_.kind() != 'Select' else i_.input()
if (i, n) in existing:
return
existing.add((i, n))
e = {
'from': n.unique(),
'to': i.unique(),
'arrows': 'from',
}
if i.stage() != n.stage():
e['color'] = 'green'
edges.append(e)

counts = {}
offset = 0
for n in self.nodes():
if len(n.uses()) == 0 or n.kind() == 'Select' or n.kind() == 'Undefined':
continue
ident = counts.get(n.kind(), 0)
counts[n.kind()] = ident + 1
d = {
'id': n.unique(),
'label': '{}_{}'.format(n.kind(), ident),
'y': offset,
'fixed': {'y': True},
}
if n in self.outputs():
d['shape'] = 'triangle'

for i in n.inputs():
add_edge(i, n)

nodes.append(d)
offset += 30

result = _vis_template.substitute(nodes=json.dumps(nodes),
edges=json.dumps(edges),
options=json.dumps(options))
with open(filename, 'w') as f:
f.write(result)


torch._C.Graph.write_vis = _write_vis

0 comments on commit c70d723

Please sign in to comment.