Fuser Improvements #224

Merged: 5 commits merged into jit from pr/fuser_improvements on Sep 18, 2017

Conversation

@zdevito (Collaborator) commented Sep 13, 2017

I looked into what the fuser does with the backwards pass and have a plan to fix it. So far this PR just adds code to visualize the graph and demonstrates how to fix the problem; I will later add an update with actual code support.

For the word language model the full graph looks like this:
https://rawgit.com/zdevito/3d619ef61f698815fe80525f0ad42e97/raw/8430fab99cbfedaba6e0fb43996db75b2a82e69d/before.html

The forward pass cell, FusionGroup_0, is fully fused, but the reverse is still in 4 separate fusion groups (6, 7, 8, 9). This is a byproduct of the heuristic that tries to fuse many producers into a single consumer. In the backward case, Concat is not a simple map and hence is not fusable. Because of this, there is not a single output of the backward-pass LSTM cell, but rather 4 outputs, each starting a new 'seed' fusion group that cannot merge with the other fusion groups.

One approach would be to have better handling for merging adjacent fusion groups together. This can get tricky: for instance, 7 can merge with 9, but only if you observe that 8 can go after 9. We should do this eventually, but we don't need to do it now.

The approach I want to go with is to allow Concat nodes to be an exit node of a FusionGroup. This fixes the issue above. Unlike simply merging adjacent fusion groups, it also makes sure the Concat is not done in a separate pass (which would add kernel launches and use more memory).

If we allow this then the trace is what we want:

https://rawgit.com/zdevito/104ea16a7234e5688fc62b87cc4da711/raw/4e330d888dad4fef8ae949ffe6e3856dd5ba3faf/after.html

It is valid to fuse a Concat into a group as long as the output of the Concat (which no longer has the simple-map size) is not used inside the group, and each individual element being concatenated is the same size (which will be true in this case). The implementation strategy is pretty easy as well: allocate the Concat output before the fusion group runs, narrow it into the tensors that form the body of the Concat, and then pass those into the fusion kernel as normal outputs.
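
As a rough sketch of that strategy (written against the modern ATen tensor API for readability; ConcatDesc here is simplified, and allocateOutputs and the launch-time plumbing are hypothetical names, not the code in this PR):

#include <ATen/ATen.h>
#include <vector>

// Simplified stand-in for the ConcatDesc introduced in this PR.
struct ConcatDesc {
  size_t nSubtensors; // 1 for plain outputs, otherwise the number of concatenated tensors
  size_t dim;         // dimension along which the concat occurs
};

// Allocate outputs for one fusion-group launch. A Concat output is allocated
// whole before the kernel runs, then narrowed into per-subtensor views that
// the fused kernel writes into directly, so no separate concat pass is needed.
void allocateOutputs(const std::vector<int64_t> & map_size,
                     const std::vector<ConcatDesc> & descs,
                     const at::TensorOptions & options,
                     std::vector<at::Tensor> & kernel_outputs, // handed to the kernel
                     std::vector<at::Tensor> & user_outputs) { // returned to the graph
  for (const ConcatDesc & desc : descs) {
    if (desc.nSubtensors == 1) {
      // Ordinary simple-map output.
      at::Tensor out = at::empty(map_size, options);
      kernel_outputs.push_back(out);
      user_outputs.push_back(out);
      continue;
    }
    // The Concat result is nSubtensors times larger along desc.dim.
    std::vector<int64_t> sizes(map_size);
    sizes[desc.dim] *= desc.nSubtensors;
    at::Tensor whole = at::empty(sizes, options);
    user_outputs.push_back(whole);
    // Narrow the whole tensor into the subtensors forming the body of the
    // Concat; each view is passed to the kernel as a normal output.
    int64_t step = map_size[desc.dim];
    for (size_t j = 0; j < desc.nSubtensors; ++j) {
      kernel_outputs.push_back(whole.narrow(desc.dim, j * step, step));
    }
  }
}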

Finally, a thought about fusions: if we have a valid fusion in the forward pass, then there is always a corresponding fusion for the backward, because the gradient of a simple map is still a simple map. This suggests that if we find a forward fusion we like, even without adding new fusion heuristics, we should be able to find the fusion for its gradient. Or, equivalently, there exists a dual of our fusion engine that works by fusing consumers into producers, scanning in the opposite direction from our current pass.
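
To make the symmetry concrete, here is a purely illustrative sketch; Graph, Node, reverse, isSimpleMap, and tryFuse are hypothetical stand-ins rather than the fuser's actual interface:

// Current pass: scan from outputs toward inputs, pulling simple-map
// producers into the consumer's fusion group.
void fuseProducersIntoConsumers(Graph & graph) {
  for (Node * consumer : reverse(graph.nodes())) {
    for (Node * producer : consumer->inputs()) {
      if (isSimpleMap(producer)) tryFuse(consumer, producer);
    }
  }
}

// Hypothetical dual pass: scan from inputs toward outputs, pulling
// simple-map consumers into the producer's fusion group. Run over the
// backward graph, this would recover the fusion mirroring the forward one.
void fuseConsumersIntoProducers(Graph & graph) {
  for (Node * producer : graph.nodes()) {
    for (Node * consumer : producer->uses()) {
      if (isSimpleMap(consumer)) tryFuse(producer, consumer);
    }
  }
}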

@zdevito force-pushed the pr/fuser_improvements branch 2 times, most recently from dba7474 to e8ec1a5, on September 13, 2017 07:17
@apaszke (Collaborator) commented Sep 13, 2017

Good point about the dual fuser. We could add tracking of function -> derivative edges to the tracer natively (like we do with Handle edges, but now for all nodes), and slice off the backward fusion groups that way. On the other hand, backward fusion groups could also include simple maps appearing in neighboring ops.

@zdevito (Collaborator, author) commented Sep 14, 2017

Some numbers now that the fusion in the backward pass works (caveat: because of #230 I cannot verify correctness, though it doesn't crash immediately...):

  • custom LSTM: 67ms/iteration (forward+backward)
  • our auto-fused LSTM: 70ms/iteration (forward+backward) (only a few % worse!)
  • naive LSTM: 130 ms/iteration (forward+backward)

@zdevito (Collaborator, author) commented Sep 14, 2017

This still needs to be rebased onto the jit branch before it is ready to merge. I was waiting because of the hypothesis that there were pybind issues. I'll see what happens tomorrow.

@ezyang force-pushed the pr/fuser_improvements branch from e966676 to 6ea766f on September 14, 2017 14:10
@ezyang (Owner) commented Sep 14, 2017

Force pushed rebase

@zdevito (Collaborator, author) commented Sep 15, 2017

Some timings for the kernels themselves. These are not precisely comparable because the fusion boundaries are not precisely the same, but they are pretty close:

HAND FUSED:
Time(%)      Time     Calls       Avg       Min       Max  Name
1.56%  187.63ms     22330  8.4020us  7.7760us  15.521us  void THNN_CudaLSTMForward<float, unsigned int, int=-2>(TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, unsigned int, unsigned int)
1.08%  129.24ms     22322  5.7900us  5.3760us  13.472us  void THNN_CudaLSTMBackward<float, unsigned int, int=-2>(TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, TensorInfo<float, unsigned int>, unsigned int, unsigned int)

OURS:
1.16%  204.53ms     33460  6.1120us  5.6000us  15.456us  kernel_0 // roughly equivalent to forward
0.92%  161.96ms     31548  5.1330us  4.8320us  16.992us  kernel_5 //roughly equivalent to backward

@zdevito force-pushed the pr/fuser_improvements branch from 6ea766f to f44a9f2 on September 15, 2017 02:41
@zdevito changed the title from "[WIP] Fuser Improvements" to "Fuser Improvements" on Sep 15, 2017
@zdevito force-pushed the pr/fuser_improvements branch from f44a9f2 to b16e0aa on September 15, 2017 02:51
@@ -191,6 +206,8 @@ void emitCompilationUnit(std::ostream & out,
body << format("auto ${node} = ${access};\n",env);
}
for(auto n : subgraph.nodes()) {
if(n->kind() == kConcat)
continue; // Concat nodes by narrowing the output Tensors before the kernel runs
Review comment (Collaborator):
nit: should it be nodes -> works?

struct ConcatDesc {
size_t nSubtensors; // == 1 for outputs that are not concats, otherwise the number of tensors concatenated
size_t dim; // dimension along which the concat occurs
std::unique_ptr<TensorDesc> subtensorDesc; // descriptor for the subtensor, if it exists
Review comment (Collaborator):
Wait, why is there a single subtensor descriptor for multiple subtensors? Do they have to be of the same size?

Reply (zdevito, author):
They do, yes: if they were not the same size, they could not be produced by a simple map that works over a single size. This is checked in the optimization pass.
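
Roughly, the legality check amounts to something like the following sketch (isInGroup and the type()/sizes() accessors are illustrative stand-ins, not the pass's real helpers):

// A Concat can be the exit node of a fusion group only if (1) nothing inside
// the group consumes the Concat's output, since that output no longer has the
// simple-map size, and (2) every concatenated input has the same size, so one
// subtensor descriptor and one simple map cover them all.
bool isFusableConcat(Node * concat, Node * group) {
  for (Node * use : concat->uses()) {
    if (isInGroup(use, group)) return false;               // violates (1)
  }
  auto expected = concat->inputs().at(0)->type()->sizes();
  for (Node * input : concat->inputs()) {
    if (input->type()->sizes() != expected) return false;  // violates (2)
  }
  return true;
}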

// meaning: stride[dim - 1] != stride[dim]*size[dim]
// so dim - 1 is no longer contiguous
cont[dim - 1] = false;
}
Review comment (Collaborator):
Perhaps assert that desc.size[dim] % nSubtensors == 0?

Reply (zdevito, author):
TensorDesc does not store actual sizes, just contiguity information, so while this assertion is true, it is hard to check: the fusion group is specialized on contiguity, not on sizes.
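
For context, a TensorDesc in that spirit records only the scalar type and a per-dimension contiguity pattern; the following is a paraphrase of the idea rather than the actual struct:

// The generated kernel is specialized on dtype and contiguity, never on
// concrete sizes, so a size-based assertion has no stored data to check.
struct TensorDescSketch {
  at::ScalarType scalar_type;
  // contiguity[i] is true when stride[i] == stride[i+1] * size[i+1],
  // i.e. dimension i can be collapsed into dimension i+1.
  std::vector<bool> contiguity;
  // Note: no sizes or strides are stored here.
};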

void emitCompilationUnit(std::ostream & out,
const std::string & name,
AnnotatedGraph & agraph) {
std::vector<ConcatDesc> emitCompilationUnit(std::ostream & out,
@apaszke commented Sep 15, 2017:
It feels a bit weird that we fill in concat descriptors only when we emit the code 😕 I'd be tempted to inspect them earlier if I were debugging, and would be really surprised to find that there are no descriptors.

Reply (zdevito, author):
emitCompilationUnit is called in the constructor, so the descriptors are basically being initialized at the earliest possible moment.

@apaszke (Collaborator) commented Sep 15, 2017

Also, I don't feel like committing that graph visualizer into PyTorch 😕 Can we just have a mode that emits JSON, and then a separate script (you can put it in a gist) that has all the HTML strings and can produce the visualization?

@zdevito (Collaborator, author) commented Sep 15, 2017

I'm happy to hide the visualizer more (it shouldn't be in onnx.py), but I'd like it to be in the code so it is easily accessible to me. I needed it to figure out how to make our code go fast, and I will need it in the future.

@apaszke (Collaborator) commented Sep 15, 2017

Is there a problem with doing what I suggested? Let's just have a JSON printer, and then we can develop a whole toolkit for visualizing/inspecting this data on the side.

@zdevito (Collaborator, author) commented Sep 15, 2017

I don't want to have to dig up code that isn't in the repo just because I want to write a file to disk to debug the IR. I guess I don't see a serious cost to keeping this code in the repo, but I do see a serious one in not keeping it. Longer term we can invest in better debugging, but right now that doesn't exist and I need something to use.

@apaszke (Collaborator) commented Sep 16, 2017

I really think this code doesn't belong in the core repo, and I can't see any problem with keeping it on the side. It's not like you have a high overhead for using it: just paste it into the place where you want to use it (probably the word language model) and you're done.

@zdevito (Collaborator, author) commented Sep 17, 2017

Well, I can't put it in the application, because I can't get to the code that dumps the IR between optimization passes; it is inside the Traceable class. I also want the benefits of source control for debugging code, so that I don't have to dig up the right version of it every time I want to do something with it.

@soumith (Collaborator) commented Sep 18, 2017

  1. The visualizer code goes into torch.utils. Adam, it's pretty dumb to have to dig up this code every time you have to visualize JIT traces (which we'll be doing pretty often).

  2. Having the core jit.py code contain only a JSON dumper, with the utils code (or the HTML directly) consuming it, seems cleaner. Though at this particular point, speed of development is more important than stability, so don't put huge emphasis on such decisions.

@soumith (Collaborator) commented Sep 18, 2017

If the visualizer code can go into torch.contrib.utils, that'd be nicer than torch.utils. We should actively create and start using contrib for such things.

@zdevito force-pushed the pr/fuser_improvements branch from b16e0aa to c70d723 on September 18, 2017 22:05
@zdevito force-pushed the pr/fuser_improvements branch from c70d723 to 2dcd39a on September 18, 2017 22:06
@zdevito merged commit e5438cc into jit on Sep 18, 2017