diff --git a/torch/csrc/autograd/functions/jit_closure.cpp b/torch/csrc/autograd/functions/jit_closure.cpp
index 93707b5eb..f8014ae31 100644
--- a/torch/csrc/autograd/functions/jit_closure.cpp
+++ b/torch/csrc/autograd/functions/jit_closure.cpp
@@ -120,8 +120,8 @@ struct EmitNull : public Function {
 struct LambdaFunction : public Function {
   LambdaFunction(int num_inputs, std::function<variable_list(const variable_list&)> fn)
     : fn(fn) {
-    is_executable = true;
-    num_inputs = num_inputs;
+    this->is_executable = true;
+    this->num_inputs = num_inputs;
   }
 
   virtual variable_list apply(const variable_list& inputs) {
@@ -272,7 +272,7 @@ struct FusionGroupFunction : public Function {
     std::vector<at::Tensor> outputs;
     outputs.reserve(function->outputDescriptors().size());
     for(auto & od : function->outputDescriptors()) {
-      outputs.push_back(at::CUDA(od.scalar_type).tensor(data.back().sizes()));
+      outputs.push_back(at::CUDA(od.scalar_type).tensor());
     }
     function->launch(data, outputs);
     return wrap_outputs(inputs, std::move(outputs), [](FunctionFlags f) {
diff --git a/torch/csrc/jit/fusion_compiler.cpp b/torch/csrc/jit/fusion_compiler.cpp
index 7e04f9379..58b8d68c2 100644
--- a/torch/csrc/jit/fusion_compiler.cpp
+++ b/torch/csrc/jit/fusion_compiler.cpp
@@ -151,9 +151,9 @@ const char * scalarTypeName(at::ScalarType type) {
   }
 }
 
-void emitCompilationUnit(std::ostream & out,
-                         const std::string & name,
-                         AnnotatedGraph & agraph) {
+std::vector<ConcatDesc> emitCompilationUnit(std::ostream & out,
+                                            const std::string & name,
+                                            AnnotatedGraph & agraph) {
   Graph& subgraph = *agraph.graph;
   TemplateEnv env;
   env.s("kernelName",name);
@@ -177,10 +177,25 @@ void emitCompilationUnit(std::ostream & out,
     for(auto p : subgraph.inputs())
      emitFormal(p,agraph.input_desc[i++]);
   }
+  std::vector<ConcatDesc> concat_desc;
+  std::vector<Node*> flat_output_nodes;
   {
     size_t i = 0;
-    for(auto o : subgraph.outputs())
-      emitFormal(o,agraph.output_desc[i++]);
+    for(auto o : subgraph.outputs()) {
+      auto & desc = agraph.output_desc[i++];
+      if(o->kind() != kConcat) {
+        emitFormal(o, desc);
+        concat_desc.emplace_back();
+        flat_output_nodes.push_back(o);
+      } else {
+        size_t nInputs = o->inputs().size();
+        concat_desc.emplace_back(desc, nInputs, o->i(kaxis));
+        for(auto c : o->inputs()) {
+          emitFormal(c, *concat_desc.back().subtensorDesc);
+          flat_output_nodes.push_back(c);
+        }
+      }
+    }
   }
   size_t formal_count = 0;
   for(auto p : subgraph.inputs()) {
@@ -191,6 +206,8 @@ void emitCompilationUnit(std::ostream & out,
     body << format("auto ${node} = ${access};\n",env);
   }
   for(auto n : subgraph.nodes()) {
+    if(n->kind() == kConcat)
+      continue; // Concat is handled by narrowing the output Tensors before the kernel runs
     size_t i = 0;
     for(auto in : n->inputs()) {
       env.s(std::to_string(i++),nodeName(in));
@@ -199,7 +216,7 @@ void emitCompilationUnit(std::ostream & out,
     env.s("rhs",format(simple_map_ops.at(n->kind())(n),env));
     body << format("auto ${node} = ${rhs};\n",env);
   }
-  for(auto o : subgraph.outputs()) {
+  for(auto o : flat_output_nodes) {
     env.d("formal",formal_count++);
     env.s("access",format("t${formal}.data[t${formal}_offset]",env));
     env.s("node",nodeName(o));
@@ -209,6 +226,7 @@ void emitCompilationUnit(std::ostream & out,
   env.s("kernelBody",body.str());
   env.v("formals",formals);
   out << compilation_unit_template.format(env);
+  return concat_desc;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -234,7 +252,7 @@ CompiledFusionFunction::CompiledFusionFunction(const std::string & name, AnnotatedGraph & agraph)
   JIT_CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
 
   std::stringstream cu;
-  codegen::emitCompilationUnit(cu, name, agraph);
+  concat_desc = codegen::emitCompilationUnit(cu, name, agraph);
   compliation_unit = cu.str();
 
   nvrtcProgram program;
@@ -325,19 +343,23 @@ void compressContiguous(
 void CompiledFusionFunction::launch(at::ArrayRef<at::Tensor> inputs, at::ArrayRef<at::Tensor> outputs) {
   JIT_ASSERT(inputs.size() == input_desc.size());
   JIT_ASSERT(outputs.size() == output_desc.size());
+  size_t flat_outputs_size = 0;
+  for(auto & c : concat_desc)
+    flat_outputs_size += c.nSubtensors;
   // XXX: this code assumes that inputs are 32-bit addressable
   // XXX: this code assumes that all inputs are of the same size
   JIT_ASSERT(inputs[0].numel() <= std::numeric_limits<uint32_t>::max());
   uint32_t numel = inputs[0].numel();
+  at::IntList map_size = inputs[0].sizes();
   // Compute the storage needed to store TensorInfo structs for inputs and outputs.
   size_t uncompressedDim = input_desc.at(0).contiguity.size();
   size_t maxPossibleTensorInfoSize = sizeof(TensorInfo) + 2 * sizeof(uint32_t) * uncompressedDim;
-  size_t maxPossibleBufferSize = maxPossibleTensorInfoSize * (inputs.size() + outputs.size());
+  size_t maxPossibleBufferSize = maxPossibleTensorInfoSize * (inputs.size() + flat_outputs_size);
   std::vector<char> buffer(maxPossibleBufferSize);
   char * buffer_next = buffer.data();
   // A vector of arguments to the kernel. It's (numel, *input_descs, *output_descs)
   std::vector<void*> arguments;
-  arguments.reserve(1 + inputs.size() + outputs.size());
+  arguments.reserve(1 + inputs.size() + flat_outputs_size);
   // Asserts that t's dims can be compressed in the same way as in desc
   // (that's what the kernel assumes), and appends it to the arguments vector.
   auto addTensorInfo = [&](TensorDesc & desc, const at::Tensor & t) {
@@ -352,8 +374,28 @@ void CompiledFusionFunction::launch(at::ArrayRef<at::Tensor> inputs, at::ArrayRef<at::Tensor> outputs) {
   arguments.push_back(&numel);
   for (std::size_t i = 0; i < input_desc.size(); ++i)
     addTensorInfo(input_desc[i], inputs[i]);
-  for (std::size_t i = 0; i < output_desc.size(); ++i)
-    addTensorInfo(output_desc[i], outputs[i]);
+  for (std::size_t i = 0; i < output_desc.size(); ++i) {
+    auto & c = concat_desc[i];
+    at::Tensor o = outputs[i];
+    if(c.nSubtensors == 1) {
+      o.resize_(map_size);
+      addTensorInfo(output_desc[i], outputs[i]);
+    } else {
+      size_t small_size = map_size[c.dim];
+      std::vector<int64_t> concat_size(map_size.begin(), map_size.end());
+      concat_size[c.dim] = small_size * c.nSubtensors;
+      o.resize_(concat_size);
+      size_t offset = 0;
+      for(size_t j = 0; j < c.nSubtensors; ++j) {
+        // because the concatenated output stays live, the underlying data
+        // in this view remains live through the end of this function,
+        // so there is no need to hold onto this tensor
+        auto view = o.narrow(c.dim, offset, small_size);
+        addTensorInfo(*c.subtensorDesc, view);
+        offset += small_size;
+      }
+    }
+  }
 
   launch(numel, arguments.data());
 }
diff --git a/torch/csrc/jit/fusion_compiler.h b/torch/csrc/jit/fusion_compiler.h
index 2b81f5d93..b3ca24ec1 100644
--- a/torch/csrc/jit/fusion_compiler.h
+++ b/torch/csrc/jit/fusion_compiler.h
@@ -18,11 +18,13 @@ struct TensorDesc {
   at::ScalarType scalar_type;
   std::vector<bool> contiguity;
 
-  TensorDesc(const at::ScalarType& type, const at::IntList& sizes, const at::IntList& strides)
-    : scalar_type(type)
-    , contiguity(TensorDesc::findContiguous(sizes, strides)) {
+  TensorDesc(const at::ScalarType& type, const std::vector<bool>& contiguity)
+    : scalar_type(type), contiguity(contiguity) {
     nDim_ = std::count(contiguity.begin(), contiguity.end(), false) + (lastIsContiguous() ? 1 : 0);
   }
+
+  TensorDesc(const at::ScalarType& type, const at::IntList& sizes, const at::IntList& strides)
+    : TensorDesc(type, TensorDesc::findContiguous(sizes, strides)) {}
   TensorDesc(const at::Tensor& t)
     : TensorDesc(t.type().scalarType(), t.sizes(), t.strides()) {}
   TensorDesc(TensorType *type)
@@ -56,6 +58,27 @@ struct AnnotatedGraph {
   std::vector<TensorDesc> output_desc;
 };
 
+struct ConcatDesc {
+  size_t nSubtensors; // == 1 for outputs that are not concats, otherwise it is the number of tensors concatenated
+  size_t dim; // dimension along which the concat occurs
+  std::unique_ptr<TensorDesc> subtensorDesc; // descriptor for the subtensor, if it exists
+  ConcatDesc()
+    : nSubtensors(1), dim(0) {}
+  ConcatDesc(const TensorDesc & desc, size_t nSubtensors, size_t dim)
+    : nSubtensors(nSubtensors), dim(dim) {
+    JIT_ASSERT(nSubtensors > 1);
+    std::vector<bool> cont = desc.contiguity;
+    if(dim > 0) {
+      // when we narrow the concatenated output,
+      // we make size[dim] smaller while keeping stride[dim] the same,
+      // meaning: stride[dim - 1] != stride[dim] * size[dim],
+      // so dim - 1 is no longer contiguous
+      cont[dim - 1] = false;
+    }
+    subtensorDesc.reset(new TensorDesc(desc.scalar_type, cont));
+  }
+};
+
 struct CompiledFusionFunction {
   TH_DISALLOW_COPY_AND_ASSIGN(CompiledFusionFunction);
 
@@ -84,6 +107,11 @@ struct CompiledFusionFunction {
   std::vector<TensorDesc> input_desc;
   std::vector<TensorDesc> output_desc;
+
+  // same size as output_desc, describes whether
+  // an output is actually a concatenation of
+  // many subtensors that the fusion group produces
+  std::vector<ConcatDesc> concat_desc;
 };
 
 // caching compiler
diff --git a/torch/csrc/jit/interned_strings.h b/torch/csrc/jit/interned_strings.h
index 10bb976c6..5cd8897f3 100644
--- a/torch/csrc/jit/interned_strings.h
+++ b/torch/csrc/jit/interned_strings.h
@@ -29,7 +29,6 @@ _(Transpose) \
 _(Concat) \
 _(Reshape) \
 _(split) \
-_(Dim) \
 _(Offset) \
 _(value) \
 _(Subgraph) \
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index 7ccba8e64..35e2bc98c 100644
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -12,7 +12,6 @@ std::unordered_set<NodeKind> simple_mappable = {
   kAdd,
   kNeg,
   kAddConstant,
-  kConcat,
 };
 
 bool isSimpleMap(Node *node) {
@@ -42,6 +41,27 @@ struct GraphFuser {
     return isSimpleMap(node) && isCuda(node);
   }
 
+  // Can this node produce an _output_ of a fusion group?
+  // All fusable nodes can do this, but additionally Concat, which normally cannot be fused
+  // because it is not a simple map, can be put in a fusion group
+  // as long as no items in the group read the output of concat
+  bool isFusableAsExitNode(Node * node) {
+    if(isFusable(node))
+      return true;
+    if(node->kind() != kConcat || !isCuda(node))
+      return false;
+
+    // this concat fusion only works when all the inputs are the same size,
+    // otherwise they cannot participate in the same map
+    auto sizes = node->inputs().at(0)->type()->expect<TensorType>()->sizes();
+    for(auto i : node->inputs()) {
+      if(sizes != i->type()->expect<TensorType>()->sizes()){
+        return false;
+      }
+    }
+    return true;
+  }
+
   // necessary condition for fusion. If all of the uses of producer are consumer
   // then it is safe to merge producer into consumer, because it doesn't have any other uses
   // If there are other uses, but they occur _after_ consumer, then we can still merge in producer
@@ -232,7 +252,7 @@ struct GraphFuser {
   // returns where to continue scanning
   graph_node_list::iterator scanNode(Node * consumer) {
     graph->setStage(consumer->stage());
-    if(isFusable(consumer)) {
+    if(isFusableAsExitNode(consumer)) {
      // handle inputs in reverse topological order as well...
      // otherwise in f(a,a+b) it will appear a is used twice if we consider
      // the f-a fusion before the f-(a+b) fusion first.
diff --git a/torch/csrc/jit/test_jit.cpp b/torch/csrc/jit/test_jit.cpp
index 733de0351..bfaff5e64 100644
--- a/torch/csrc/jit/test_jit.cpp
+++ b/torch/csrc/jit/test_jit.cpp
@@ -173,7 +173,34 @@ static void fusionTests() {
   testOne(0,1,1,2);
   testOne(1,2,0,2);
+
+
+  auto testConcat = [&](int dim) {
+    Graph graph;
+    Node * i0 = graph.addInput();
+    Node * i1 = graph.addInput();
+    auto o0 = appendNewNode(kMul,graph,{i0, i1});
+    graph.registerOutput(o0);
+    graph.registerOutput(appendNewNode(kConcat, graph, {i0,o0})->i_(kaxis, dim));
+    auto a = at::CUDA(at::kFloat).rand({3,4,5});
+    auto b = at::CUDA(at::kFloat).rand({4,3,5}).transpose(0,1);
+    auto o = at::CUDA(at::kFloat).zeros({3,4,5});
+
+    auto o_r = a*b;
+    auto o2_r = at::cat({a, o_r}, dim);
+    auto o2 = at::CUDA(at::kFloat).zeros(o2_r.sizes());
+    comp.debugLaunchGraph(graph, {a,b}, {o, o2});
+
+    float max_diff = (o_r - o).abs().max().toDouble();
+    JIT_ASSERT(max_diff == 0);
+    float max_diff2 = (o2_r - o2).abs().max().toDouble();
+    JIT_ASSERT(max_diff2 == 0);
+  };
+  testConcat(0);
+  testConcat(1);
+  testConcat(2);
 }
+
 #else //WITH_CUDA
 void fusionTests() {}
 #endif
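
Note on the generated kernel for the concat case: for the testConcat graph above (o0 = i0 * i1, then Concat(i0, o0) along dim), emitCompilationUnit now emits one formal per graph input, one per non-Concat output, and one per input of each fused Concat, and the kernel body simply stores into every output formal. A rough per-element sketch of what compilation_unit_template would produce is below; the n0/n1/n2 names and the exact expressions are illustrative assumptions, not the actual generated source.

  // hypothetical per-element kernel body for testConcat
  // t0, t1 = inputs i0, i1; t2 = plain output o0;
  // t3, t4 = narrowed views making up the Concat output
  auto n0 = t0.data[t0_offset];   // load i0
  auto n1 = t1.data[t1_offset];   // load i1
  auto n2 = n0 * n1;              // o0 = mul(i0, i1)
  t2.data[t2_offset] = n2;        // write o0
  t3.data[t3_offset] = n0;        // first half of the Concat output (copy of i0)
  t4.data[t4_offset] = n2;        // second half of the Concat output (copy of o0)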
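
The subtlest part of ConcatDesc is the cont[dim - 1] = false adjustment: launch() resizes one contiguous buffer whose size[dim] covers all subtensors and hands the kernel narrow()ed views of it, and because those views keep the big buffer's strides, the dimension just before the concat dimension can no longer be collapsed into it. The standalone sketch below (plain C++, no ATen; the {3,4,5} shape, dim = 1, and nSubtensors = 2 mirror the testConcat case but are otherwise made-up values) works through that stride arithmetic.

// contiguity_sketch.cpp -- standalone illustration with hypothetical values
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  // Each kernel output is a "map-sized" subtensor, e.g. {3,4,5}.
  std::vector<size_t> map_size = {3, 4, 5};
  size_t dim = 1;          // concat dimension
  size_t nSubtensors = 2;  // number of tensors being concatenated

  // launch() resizes one output buffer so size[dim] holds all subtensors
  // ({3,8,5} here); a freshly resized buffer has ordinary row-major strides.
  std::vector<size_t> concat_size = map_size;
  concat_size[dim] *= nSubtensors;
  std::vector<size_t> stride(concat_size.size());
  size_t s = 1;
  for (size_t i = concat_size.size(); i-- > 0;) {
    stride[i] = s;
    s *= concat_size[i];
  }

  // Each narrow()ed view shrinks size[dim] back to map_size[dim] but keeps
  // the big buffer's strides. Check which dims can still be collapsed:
  // cont[d] holds iff stride[d] == stride[d + 1] * size[d + 1].
  const std::vector<size_t>& view_size = map_size;
  for (size_t d = 0; d + 1 < view_size.size(); ++d) {
    size_t collapsed = stride[d + 1] * view_size[d + 1];
    std::printf("cont[%zu]: stride[%zu]=%zu vs stride[%zu]*size[%zu]=%zu -> %s\n",
                d, d, stride[d], d + 1, d + 1, collapsed,
                stride[d] == collapsed ? "true" : "false");
  }
  // Prints cont[0] = false (40 vs 20) and cont[1] = true (5 vs 5), which is
  // exactly why ConcatDesc clears cont[dim - 1] whenever dim > 0.
  return 0;
}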