
Commit

allow Concat operators to be the final operator in a fusion group, and update the fusion compiler to support code that includes final concats
zdevito committed Sep 14, 2017
1 parent b570758 commit e966676
Showing 6 changed files with 136 additions and 20 deletions.
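
The idea of the change: a Concat that ends a fusion group is no longer emitted as an operator of its own. Instead, each input of the Concat becomes one flat output of the generated kernel, and the host preallocates the concatenated tensor and hands the kernel narrowed views into it. A minimal host-side sketch of that narrowing (helper name and signature are hypothetical, for illustration only; the commit does this inline in CompiledFusionFunction::launch):

#include <ATen/ATen.h>
#include <vector>

// Each view aliases `concatenated`, so when the kernel fills a view it fills
// that slice of the final output and no separate at::cat pass is needed.
std::vector<at::Tensor> concatViews(at::Tensor & concatenated, int64_t dim,
                                    int64_t nSubtensors, int64_t small_size) {
  std::vector<at::Tensor> views;
  int64_t offset = 0;
  for (int64_t j = 0; j < nSubtensors; ++j) {
    views.push_back(concatenated.narrow(dim, offset, small_size));
    offset += small_size;
  }
  return views;
}
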
6 changes: 3 additions & 3 deletions torch/csrc/autograd/functions/jit_closure.cpp
@@ -120,8 +120,8 @@ struct EmitNull : public Function {
struct LambdaFunction : public Function {
LambdaFunction(int num_inputs, std::function<variable_list(const variable_list&)> fn)
: fn(fn) {
is_executable = true;
num_inputs = num_inputs;
this->is_executable = true;
this->num_inputs = num_inputs;
}

virtual variable_list apply(const variable_list& inputs) {
@@ -272,7 +272,7 @@ struct FusionGroupFunction : public Function {
std::vector<at::Tensor> outputs;
outputs.reserve(function->outputDescriptors().size());
for(auto & od : function->outputDescriptors()) {
outputs.push_back(at::CUDA(od.scalar_type).tensor(data.back().sizes()));
outputs.push_back(at::CUDA(od.scalar_type).tensor());
}
function->launch(data, outputs);
return wrap_outputs(inputs, std::move(outputs), [](FunctionFlags f) {
64 changes: 53 additions & 11 deletions torch/csrc/jit/fusion_compiler.cpp
@@ -151,9 +151,9 @@ const char * scalarTypeName(at::ScalarType type) {
}
}

void emitCompilationUnit(std::ostream & out,
const std::string & name,
AnnotatedGraph & agraph) {
std::vector<ConcatDesc> emitCompilationUnit(std::ostream & out,
const std::string & name,
AnnotatedGraph & agraph) {
Graph& subgraph = *agraph.graph;
TemplateEnv env;
env.s("kernelName",name);
@@ -177,10 +177,25 @@ void emitCompilationUnit(std::ostream & out,
for(auto p : subgraph.inputs())
emitFormal(p,agraph.input_desc[i++]);
}
std::vector<ConcatDesc> concat_desc;
std::vector<Node*> flat_output_nodes;
{
size_t i = 0;
for(auto o : subgraph.outputs())
emitFormal(o,agraph.output_desc[i++]);
for(auto o : subgraph.outputs()) {
auto & desc = agraph.output_desc[i++];
if(o->kind() != kConcat) {
emitFormal(o, desc);
concat_desc.emplace_back();
flat_output_nodes.push_back(o);
} else {
size_t nInputs = o->inputs().size();
concat_desc.emplace_back(desc, nInputs, o->i(kaxis));
for(auto c : o->inputs()) {
emitFormal(c, *concat_desc.back().subtensorDesc);
flat_output_nodes.push_back(c);
}
}
}
}
size_t formal_count = 0;
for(auto p : subgraph.inputs()) {
@@ -191,6 +206,8 @@ void emitCompilationUnit(std::ostream & out,
body << format("auto ${node} = ${access};\n",env);
}
for(auto n : subgraph.nodes()) {
if(n->kind() == kConcat)
continue; // Concat is handled by narrowing the output Tensors before the kernel runs
size_t i = 0;
for(auto in : n->inputs()) {
env.s(std::to_string(i++),nodeName(in));
@@ -199,7 +216,7 @@ void emitCompilationUnit(std::ostream & out,
env.s("rhs",format(simple_map_ops.at(n->kind())(n),env));
body << format("auto ${node} = ${rhs};\n",env);
}
for(auto o : subgraph.outputs()) {
for(auto o : flat_output_nodes) {
env.d("formal",formal_count++);
env.s("access",format("t${formal}.data[t${formal}_offset]",env));
env.s("node",nodeName(o));
@@ -209,6 +226,7 @@ void emitCompilationUnit(std::ostream & out,
env.s("kernelBody",body.str());
env.v("formals",formals);
out << compilation_unit_template.format(env);
return concat_desc;
}

////////////////////////////////////////////////////////////////////////////////
@@ -234,7 +252,7 @@ CompiledFusionFunction::CompiledFusionFunction(const std::string & name, Annotat
JIT_CUDA_CHECK(cudaGetDeviceProperties(&prop, device));

std::stringstream cu;
codegen::emitCompilationUnit(cu, name, agraph);
concat_desc = codegen::emitCompilationUnit(cu, name, agraph);
compliation_unit = cu.str();

nvrtcProgram program;
@@ -325,19 +343,23 @@ void compressContiguous(
void CompiledFusionFunction::launch(at::ArrayRef<at::Tensor> inputs, at::ArrayRef<at::Tensor> outputs) {
JIT_ASSERT(inputs.size() == input_desc.size());
JIT_ASSERT(outputs.size() == output_desc.size());
size_t flat_outputs_size = 0;
for(auto & c : concat_desc)
flat_outputs_size += c.nSubtensors;
// XXX: this code assumes that inputs are 32-bit addressable
// XXX: this code assumes that all inputs are of the same size
JIT_ASSERT(inputs[0].numel() <= std::numeric_limits<uint32_t>::max());
uint32_t numel = inputs[0].numel();
at::IntList map_size = inputs[0].sizes();
// Compute the storage needed to store TensorInfo structs for inputs and outputs.
size_t uncompressedDim = input_desc.at(0).contiguity.size();
size_t maxPossibleTensorInfoSize = sizeof(TensorInfo) + 2 * sizeof(uint32_t) * uncompressedDim;
size_t maxPossibleBufferSize = maxPossibleTensorInfoSize * (inputs.size() + outputs.size());
size_t maxPossibleBufferSize = maxPossibleTensorInfoSize * (inputs.size() + flat_outputs_size);
std::vector<char> buffer(maxPossibleBufferSize);
char * buffer_next = buffer.data();
// A vector of arguments to the kernel. It's (numel, *input_descs, *output_descs)
std::vector<void*> arguments;
arguments.reserve(1 + inputs.size() + outputs.size());
arguments.reserve(1 + inputs.size() + flat_outputs_size);
// Asserts that t's dims can be compressed in the same way as in desc
// (that's what the kernel assumes), and appends it to the arguments vector.
auto addTensorInfo = [&](TensorDesc & desc, const at::Tensor & t) {
@@ -352,8 +374,28 @@ void CompiledFusionFunction::launch(at::ArrayRef<at::Tensor> inputs, at::ArrayRe
arguments.push_back(&numel);
for (std::size_t i = 0; i < input_desc.size(); ++i)
addTensorInfo(input_desc[i], inputs[i]);
for (std::size_t i = 0; i < output_desc.size(); ++i)
addTensorInfo(output_desc[i], outputs[i]);
for (std::size_t i = 0; i < output_desc.size(); ++i) {
auto & c = concat_desc[i];
at::Tensor o = outputs[i];
if(c.nSubtensors == 1) {
o.resize_(map_size);
addTensorInfo(output_desc[i], outputs[i]);
} else {
size_t small_size = map_size[c.dim];
std::vector<int64_t> concat_size(map_size.begin(), map_size.end());
concat_size[c.dim] = small_size * c.nSubtensors;
o.resize_(concat_size);
size_t offset = 0;
for(size_t j = 0; j < c.nSubtensors; ++j) {
// because the concatenated_output stays live, the underlying data
// in this view remains live through the end of this function
// so there is no need to hold onto this tensor
auto view = o.narrow(c.dim, offset, small_size);
addTensorInfo(*c.subtensorDesc, view);
offset += small_size;
}
}
}
launch(numel, arguments.data());
}

34 changes: 31 additions & 3 deletions torch/csrc/jit/fusion_compiler.h
@@ -18,11 +18,13 @@ struct TensorDesc {
at::ScalarType scalar_type;
std::vector<bool> contiguity;

TensorDesc(const at::ScalarType& type, const at::IntList& sizes, const at::IntList& strides)
: scalar_type(type)
, contiguity(TensorDesc::findContiguous(sizes, strides)) {
TensorDesc(const at::ScalarType& type, const std::vector<bool>& contiguity)
: scalar_type(type), contiguity(contiguity) {
nDim_ = std::count(contiguity.begin(), contiguity.end(), false) + (lastIsContiguous() ? 1 : 0);
}

TensorDesc(const at::ScalarType& type, const at::IntList& sizes, const at::IntList& strides)
: TensorDesc(type, TensorDesc::findContiguous(sizes, strides)) {}
TensorDesc(const at::Tensor& t)
: TensorDesc(t.type().scalarType(), t.sizes(), t.strides()) {}
TensorDesc(TensorType *type)
@@ -56,6 +58,27 @@ struct AnnotatedGraph {
std::vector<TensorDesc> output_desc;
};

struct ConcatDesc {
size_t nSubtensors; // == 1 for outputs that are not concats, otherwise it is the number of tensors concatenated
size_t dim; // dimension along which the concat occurs
std::unique_ptr<TensorDesc> subtensorDesc; // descriptor for the subtensor, if it exists
ConcatDesc()
: nSubtensors(1), dim(0) {}
ConcatDesc(const TensorDesc & desc, size_t nSubtensors, size_t dim)
: nSubtensors(nSubtensors), dim(dim) {
JIT_ASSERT(nSubtensors > 1);
std::vector<bool> cont = desc.contiguity;
if(dim > 0) {
// when we narrow the concatenated output
// we make the size[dim] smaller while keeping the stride[dim] the same,
// meaning: stride[dim - 1] != stride[dim]*size[dim]
// so dim - 1 is no longer contiguous
cont[dim - 1] = false;
}
subtensorDesc.reset(new TensorDesc(desc.scalar_type, cont));
}
};
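
A worked example of the contiguity argument in the ConcatDesc constructor above (sizes are arbitrary, chosen only for illustration): concatenating two {2,3,4} subtensors along dim 1 produces a contiguous {2,6,4} output with strides {24,4,1}; narrowing it back to a {2,3,4} view keeps those strides, so stride[0] (24) no longer equals stride[1]*size[1] (4*3 = 12) and dim 0 must be marked non-contiguous.

// illustrative only, assuming a CUDA float tensor
auto full = at::CUDA(at::kFloat).zeros({2, 6, 4});            // contiguous, strides {24, 4, 1}
auto view = full.narrow(/*dim=*/1, /*start=*/0, /*length=*/3);
// view has sizes {2, 3, 4} but strides {24, 4, 1}:
// stride[0] (24) != stride[1] * size[1] (12), so dim 0 is no longer contiguous
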

struct CompiledFusionFunction {
TH_DISALLOW_COPY_AND_ASSIGN(CompiledFusionFunction);

@@ -84,6 +107,11 @@ struct CompiledFusionFunction {

std::vector<TensorDesc> input_desc;
std::vector<TensorDesc> output_desc;

// same size as output_desc, describes whether
// an output is actually a concatenation of
// many subtensors that the fusion group produces
std::vector<ConcatDesc> concat_desc;
};

// caching compiler
1 change: 0 additions & 1 deletion torch/csrc/jit/interned_strings.h
@@ -29,7 +29,6 @@ _(Transpose) \
_(Concat) \
_(Reshape) \
_(split) \
_(Dim) \
_(Offset) \
_(value) \
_(Subgraph) \
24 changes: 22 additions & 2 deletions torch/csrc/jit/passes/graph_fuser.cpp
@@ -12,7 +12,6 @@ std::unordered_set<NodeKind> simple_mappable = {
kAdd,
kNeg,
kAddConstant,
kConcat,
};

bool isSimpleMap(Node *node) {
@@ -42,6 +41,27 @@ struct GraphFuser {
return isSimpleMap(node) && isCuda(node);
}

// Can this node produce an _output_ of a fusion group?
// All fusable nodes can do this, but additionally Concat, which normally cannot be fused
// because it is not a simple map, can be put in a fusion group
// as long as no node in the group reads the output of the Concat
bool isFusableAsExitNode(Node * node) {
if(isFusable(node))
return true;
if(node->kind() != kConcat || !isCuda(node))
return false;

// this concat fusion only works when all the inputs are the same size
// otherwise they cannot participate in the same map
auto sizes = node->inputs().at(0)->type()->expect<TensorType>()->sizes();
for(auto i : node->inputs()) {
if(sizes != i->type()->expect<TensorType>()->sizes()){
return false;
}
}
return true;
}
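
For example (this is the shape exercised by testConcat in test_jit.cpp further down): a group that computes a pointwise Mul(i0, i1) and then Concat(i0, mul_result) can take the Concat as its exit node, because nothing inside the group reads the Concat's output; if some fused node did consume it, the Concat would have to remain outside the group.
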

// necessary condition for fusion. If all of the uses of producer are consumer
// then it is safe to merge producer into consumer, because it doesn't have any other uses
// If there are other uses, but they occur _after_ consumer, then we can still merge in producer
@@ -232,7 +252,7 @@ struct GraphFuser {
// returns where to continue scanning
graph_node_list::iterator scanNode(Node * consumer) {
graph->setStage(consumer->stage());
if(isFusable(consumer)) {
if(isFusableAsExitNode(consumer)) {
// handle inputs in reverse topological order as well...
// otherwise in f(a,a+b) it will appear a is used twice if we consider
// the f-a fusion before the f-(a+b) fusion first.
27 changes: 27 additions & 0 deletions torch/csrc/jit/test_jit.cpp
@@ -173,7 +173,34 @@ static void fusionTests() {
testOne(0,1,1,2);
testOne(1,2,0,2);



auto testConcat = [&](int dim) {
Graph graph;
Node * i0 = graph.addInput();
Node * i1 = graph.addInput();
auto o0 = appendNewNode(kMul,graph,{i0, i1});
graph.registerOutput(o0);
graph.registerOutput(appendNewNode(kConcat, graph, {i0,o0})->i_(kaxis, dim));
auto a = at::CUDA(at::kFloat).rand({3,4,5});
auto b = at::CUDA(at::kFloat).rand({4,3,5}).transpose(0,1);
auto o = at::CUDA(at::kFloat).zeros({3,4,5});

auto o_r = a*b;
auto o2_r = at::cat({a, o_r}, dim);
auto o2 = at::CUDA(at::kFloat).zeros(o2_r.sizes());
comp.debugLaunchGraph(graph, {a,b}, {o, o2});

float max_diff = (o_r - o).abs().max().toDouble();
JIT_ASSERT(max_diff == 0);
float max_diff2 = (o2_r - o2).abs().max().toDouble();
JIT_ASSERT(max_diff2 == 0);
};
testConcat(0);
testConcat(1);
testConcat(2);
}

#else //WITH_CUDA
void fusionTests() {}
#endif
