Skip to content

Commit

Permalink
Merge pull request #111 from pfnet-research/refactor_generic_context
Browse files Browse the repository at this point in the history
Refactor generic context
  • Loading branch information
okdshin authored Sep 28, 2018
2 parents 59f356d + 1fb9702 commit 23287ce
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,23 +103,40 @@ namespace menoh_impl {
assert(is_found_from_other_context);
} while(false);
}
std::vector<array> output_list;
for(auto const& output_name : node.output_name_list) {
auto found = required_output_table.find(output_name);
if(found == required_output_table.end()) {
// allocate new array by using profile
output_list.push_back(
array(output_profile_table.at(output_name)));
} else {
// use already allocated array
output_list.push_back(found->second);
}
}

procedure op_proc;
std::vector<std::pair<std::string, array>> new_outputs;
try {
auto factory =
procedure_factory_table_.at(node.op_type);
std::tie(op_proc, new_outputs) =
factory.operator()(current_index, node_list,
input_list, required_output_table);
} catch(...) { break; }
op_proc =
factory.operator()(node, input_list, output_list);
} catch(std::exception const& e) {
*logger << e.what() << std::endl;
break;
}
new_op_proc_list.push_back(op_proc);
procedure_list.insert(
procedure_list.end(),
std::make_move_iterator(new_copy_procedure_list.begin()),
std::make_move_iterator(new_copy_procedure_list.end()));
variable_table_.insert(
std::make_move_iterator(new_outputs.begin()),
std::make_move_iterator(new_outputs.end()));

assert(node.output_name_list.size() == output_list.size());
for(int i = 0; i < node.output_name_list.size(); ++i) {
variable_table_.emplace(node.output_name_list.at(i),
output_list.at(i));
}
}

// when no nodes are processed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,11 @@ namespace menoh_impl {
return variable_table_.at(name);
}

using procedure_factory = std::function<std::tuple<
procedure, std::vector<std::pair<std::string, array>>>(
int, std::vector<node> const&, std::vector<array> const&,
std::unordered_map<std::string, array> const&)>;
using procedure_factory = std::function<procedure(
node const&, // node
std::vector<array> const&, // input list
std::vector<array> const& // output list
)>;
optional<std::function<void()>>
try_to_get_input_from_common_table(
std::string const& input_name,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,46 +1,27 @@
#ifndef MENOH_IMPL_MKLDNN_WITH_GENERIC_FALLBACK_BACKEND_BACKEND_GENERIC_OPERATOR_RELU_HPP
#define MENOH_IMPL_MKLDNN_WITH_GENERIC_FALLBACK_BACKEND_BACKEND_GENERIC_OPERATOR_RELU_HPP

#include <menoh/array.hpp>
#include <menoh/mkldnn_with_generic_fallback/procedure.hpp>

namespace menoh_impl {
namespace mkldnn_with_generic_fallback_backend {
namespace generic_backend {
inline std::tuple<procedure,
std::vector<std::pair<std::string, array>>>
make_relu(int node_index, std::vector<node> const& node_list,
std::vector<array> const& input_list,
std::unordered_map<std::string, array> const&
required_output_table) {
inline procedure make_relu(node const& node,
std::vector<array> const& input_list,
std::vector<array> const& output_list) {
assert(input_list.size() == 1);
auto const& node = node_list.at(node_index);
assert(output_list.size() == 1);

auto const& x_arr = input_list.at(0);

auto found =
required_output_table.find(node.output_name_list.at(0));
optional<array> output_opt;
if(found == required_output_table.end()) {
output_opt = array(dtype_t::float_,
x_arr.dims()); // TODO check inplace-able
} else {
output_opt =
found->second; // output is required so not inplace-able
}

auto procedure = [x_arr, output = *output_opt]() {
for(decltype(total_size(x_arr)) i = 0;
i < total_size(x_arr); ++i) {
fat(output, i) = std::max(fat(x_arr, i), 0.f);
auto procedure = [input = input_list.at(0),
output = output_list.at(0)]() {
for(decltype(total_size(input)) i = 0;
i < total_size(input); ++i) {
fat(output, i) = std::max(fat(input, i), 0.f);
}
};

std::vector<std::pair<std::string, array>> outputs;
if(found == required_output_table.end()) {
outputs.push_back(std::pair<std::string, array>(
node.output_name_list.at(0), *output_opt));
}
return std::make_tuple(procedure, outputs);
return procedure;
}

} // namespace generic_backend
Expand Down
6 changes: 3 additions & 3 deletions menoh/mkldnn_with_generic_fallback/model_core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ namespace menoh_impl {
}
// if any context can not process the node
if(!is_found) {
*logger_ << "failed to interpret"
<< graph.node_list().at(current_index).op_type
<< "with all context";
*logger_
<< "failed to interpret: no contexts can interpret '"
<< node.op_type << "'";
throw unsupported_operator(node.op_type);
}
}
Expand Down

0 comments on commit 23287ce

Please sign in to comment.