Fuser Improvements #224
Conversation
Force-pushed from dba7474 to e8ec1a5
Good point with the dual fuser. We could add tracking of function -> derivative edges to the tracer natively (like we do with Handle edges, but now for all nodes), and slice off the backward fusion groups that way. On the other hand, backward fusion groups could also include simple maps appearing in neighboring ops.
Force-pushed from e8ec1a5 to e966676
Some numbers now that the fusion in the backward pass works (caveat: because of #230 I cannot verify correctness, though it doesn't crash immediately...):
This still needs to be rebased onto the jit branch before it is ready to merge. I was waiting because of the hypothesis that there were pybind issues. I'll see what happens tomorrow.
Force-pushed from e966676 to 6ea766f
Force-pushed rebase.
Some timings for the kernels themselves. These are not precisely the same because the boundaries are not precisely the same, but they are pretty close:
Force-pushed from 6ea766f to f44a9f2
Force-pushed from f44a9f2 to b16e0aa
@@ -191,6 +206,8 @@ void emitCompilationUnit(std::ostream & out,
      body << format("auto ${node} = ${access};\n",env);
    }
    for(auto n : subgraph.nodes()) {
      if(n->kind() == kConcat)
        continue; // Concat nodes by narrowing the output Tensors before the kernel runs
nit: should it be nodes -> works?
struct ConcatDesc {
  size_t nSubtensors; // == 1 for outputs that are not concats, otherwise it is the number of tensors concatenated
  size_t dim; // dimension along which the concat occurs
  std::unique_ptr<TensorDesc> subtensorDesc; // descriptor for the subtensor, if it exists
Wait, why is there a single subtensor descriptor for multiple subtensors? Do they have to be of the same size?
They do, yes, because if they are not they cannot be produced by a simple map that works over one size. This is checked for in the optimization pass.
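For concreteness, here is a minimal standalone sketch of that invariant (names and types are made up for illustration; the real check lives in the fusion pass and uses the JIT's own types):

```cpp
#include <vector>

// Illustrative stand-in for the shape information the pass compares.
struct SubtensorShape {
  std::vector<long> sizes;
  bool operator==(const SubtensorShape& other) const {
    return sizes == other.sizes;
  }
};

// A Concat is only fusable if every concatenated input has the same shape;
// otherwise the inputs cannot all be produced by one simple map over one size.
bool concatInputsMatch(const std::vector<SubtensorShape>& inputs) {
  if (inputs.empty())
    return false;
  for (const auto& s : inputs)
    if (!(s == inputs.front()))
      return false;
  return true;
}
```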
  // meaning: stride[dim - 1] != stride[dim]*size[dim]
  // so dim - 1 is no longer contiguous
  cont[dim - 1] = false;
}
Perhaps assert that desc.size[dim] % nSubtensors == 0?
TensorDesc does not store actual sizes, just contiguity information, so while this assertion is true, it is hard to check. This is because the fusion group is not specialized on sizes, only on contiguity.
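Roughly what that means in code (a hedged sketch, not the actual TensorDesc class): the descriptor keeps only a per-dimension contiguity mask derived from sizes and strides, and then forgets the sizes themselves, so one compiled kernel can be reused across all inputs with the same layout.

```cpp
#include <vector>

// Sketch of a contiguity-only descriptor (illustrative, not the real class).
// cont[i] == true means stride[i] == stride[i+1] * size[i+1], i.e. dimension i
// can be collapsed into dimension i+1 when indexing.
struct ContiguityDesc {
  std::vector<bool> cont;

  static ContiguityDesc fromSizesStrides(const std::vector<long>& size,
                                         const std::vector<long>& stride) {
    ContiguityDesc d;
    size_t n = size.size();
    d.cont.resize(n);
    for (size_t i = 0; i < n; ++i) {
      d.cont[i] = (i + 1 == n)
          ? stride[i] == 1                             // innermost dimension
          : stride[i] == stride[i + 1] * size[i + 1];  // collapsible pair
    }
    return d;
  }
};
```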
void emitCompilationUnit(std::ostream & out,
                         const std::string & name,
                         AnnotatedGraph & agraph) {
std::vector<ConcatDesc> emitCompilationUnit(std::ostream & out,
It feels a bit weird that we fill in concat descriptors only when we emit the code 😕 I'd feel tempted to inspect them earlier if I were debugging, and would be really surprised to find that there are no descriptors.
emitCompilationUnit is called in the constructor, so the descriptors are basically being initialized at the earliest possible moment.
Also, I don't feel like committing that graph visualizer into PyTorch 😕 Can we just have a mode that emits JSON and then a separate script (you can put it in a gist) that has all the HTML strings and can produce the visualization?
I'm happy to hide the visualizer more (it shouldn't be in
Is there a problem with doing what I suggested? Let's just have a JSON printer, and then we can develop a whole toolkit for visualizing/inspecting this data on the side.
I don't want to have to dig up code that isn't in the repo just because I wanted to write a file to disk to debug the IR. I guess I don't see a serious cost to keeping this code in the repo, but I do see a serious one for not. Longer-term we can invest in better debugging, but right now that doesn't exist and I need something to use.
I really think this code doesn't belong in the core repo, and I can't see any problem with keeping it on the side. It's not like you have a high overhead for using it. Just paste it in the place where you want to use it (probably the word language model) and you're done.
Well, I can't put it in the application, because I can't get to the code that dumps the IR between optimization passes. It is inside the Traceable class. I also want the benefits of source control for debugging code so that I don't have to dig up the right version of it every time I want to do something with it.
if the visualizer code can go into
…d update the fusion compiler to support code that includes final concats
Force-pushed from b16e0aa to c70d723
Force-pushed from c70d723 to 2dcd39a
I looked into what the fuser does with the backward pass and have a plan to fix it. So far this PR just adds code to visualize the graph and demonstrates how to fix the problem; I will add the actual code support in a later update.
For the word language model the full graph looks like this:
https://rawgit.com/zdevito/3d619ef61f698815fe80525f0ad42e97/raw/8430fab99cbfedaba6e0fb43996db75b2a82e69d/before.html
The forward pass cell, FusionGroup_0, is fully fused, but the backward pass is still split across 4 separate fusion groups (6, 7, 8, 9). This is a byproduct of the heuristic that tries to fuse many producers into a single consumer. In the backward case, Concat is not a simple map and hence is not fusable. Because of this, the backward pass of the LSTM cell does not have a single output, but rather 4 outputs, each starting a new 'seed' fusion group that cannot merge with the others.
One approach would be to have better handling for merging adjacent fusion groups. This can get tricky: for instance, 7 can merge with 9, but only if you observe that 8 can be moved after 9. We should do this eventually, but we don't need to do it now.
The approach I want to go with is to allow Concat nodes to be an exit node of a FusionGroup. This fixes the issue above. Unlike simply merging the existing fusion groups, it also ensures the Concat is not done as a separate pass (which would add kernel launches and use more memory).
If we allow this then the trace is what we want:
https://rawgit.com/zdevito/104ea16a7234e5688fc62b87cc4da711/raw/4e330d888dad4fef8ae949ffe6e3856dd5ba3faf/after.html
It is valid to fuse a Concat into a group as long as the output of the concat (which no longer has the simple-map size) is not used inside the group, and each individual tensor being concatenated is the same size (which will be true in this case). The implementation strategy is pretty easy as well: allocate the Concat output before the fusion group runs, narrow it to produce the tensors that form the body of the Concat, and then pass those views into the fusion kernel as normal outputs.
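A sketch of that launch-side logic, using ATen's narrow to build views into the preallocated output (illustrative names, not the fusion compiler's actual interface; it assumes, as discussed in the review comments, that all concatenated pieces are the same size along the concat dimension):

```cpp
#include <ATen/ATen.h>
#include <vector>

// Hedged sketch: the Concat result is allocated before the kernel runs, and
// the kernel receives narrowed views of it as ordinary simple-map outputs.
std::vector<at::Tensor> prepareConcatOutputs(at::Tensor& concat_out,
                                             int64_t dim,
                                             int64_t nSubtensors) {
  std::vector<at::Tensor> kernel_outputs;
  int64_t sub_size = concat_out.size(dim) / nSubtensors;
  for (int64_t i = 0; i < nSubtensors; ++i) {
    // narrow() returns a view, so the kernel writing into it fills the right
    // slice of the concatenated result; no separate concat pass or copy.
    kernel_outputs.push_back(concat_out.narrow(dim, i * sub_size, sub_size));
  }
  return kernel_outputs;
}
```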
Finally, a thought about fusions: if we have a valid fusion in the forward pass, then there is always a corresponding fusion for the backward pass. The gradient of a simple map is still a simple map. This suggests that if we find a forward fusion we like, then even without adding new fusion heuristics we should be able to find the corresponding fusion for the gradient. Or equivalently, there exists a dual of our fusion engine that fuses consumers into producers, scanning in the opposite direction from our current pass.
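A toy example of that duality (my own illustration, not code from this PR): an elementwise forward map and its backward are loops of exactly the same shape, so any pointwise region we can fuse forward has a pointwise counterpart backward.

```cpp
#include <cmath>

// Forward: y = tanh(x) * c, a simple map.
void forward(const float* x, float c, float* y, int n) {
  for (int i = 0; i < n; ++i)
    y[i] = std::tanh(x[i]) * c;
}

// Backward: grad_x = grad_y * c * (1 - tanh(x)^2), still a simple map with
// the same loop structure, so it admits the same kind of fusion.
void backward(const float* x, const float* grad_y, float c,
              float* grad_x, int n) {
  for (int i = 0; i < n; ++i) {
    float t = std::tanh(x[i]);
    grad_x[i] = grad_y[i] * c * (1.0f - t * t);
  }
}
```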