-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
115 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
#pragma once | ||
|
||
#include <iostream> | ||
#include <boost/yap/yap.hpp> | ||
#include <boost/yap/print.hpp> | ||
#include <boost/yap/expression.hpp> | ||
#include <boost/hana.hpp> | ||
#include "xforms/AllocTensor.hpp" | ||
#include "xforms/CodeGen.hpp" | ||
|
||
template <typename ...Exprs> | ||
struct ExprBlock { | ||
// static_assert(all_are yap_expr_type) | ||
boost::hana::tuple<Exprs...> expr_block_; | ||
|
||
constexpr ExprBlock(Exprs const &... exprs) : expr_block_(exprs...) {} | ||
|
||
constexpr ExprBlock(ExprBlock const& other) : expr_block_(other.expr_block_) {} | ||
|
||
constexpr ExprBlock(ExprBlock&& other) noexcept : expr_block_(other.expr_block_) {} | ||
|
||
// template <typename... BodyExprs> | ||
// constexpr auto operator[] (BodyExprs &&... body_exprs) const { | ||
// auto body_list = hana::tuple(body_exprs...); | ||
// } | ||
|
||
template <typename ...Args> | ||
constexpr auto gen_code(Args &&... args) const { | ||
namespace yap = boost::yap; | ||
|
||
// First replace all the placeholders in expr with args | ||
auto ir_list = hana::transform(expr_block_, [&args...](auto const& expr) { | ||
return yap::replace_placeholders(expr, static_cast<Args &&>(args)...); | ||
}); | ||
// auto constexpr dumping = Dumping(); | ||
// if constexpr (need_dump(dumping)) { | ||
// yap::print(std::cout, ast); | ||
// } | ||
|
||
// Go through a set of transforms | ||
auto xforms = hana::make_tuple( | ||
CodeGen() | ||
); | ||
|
||
// Call each xform's transform() method. Each xform's output serves as input of next xform. | ||
auto codes = hana::fold_left(xforms, ir_list, [](auto &&ir, auto &&xform) { | ||
return xform.transform(ir, with_dump{}); | ||
}); | ||
|
||
return codes; | ||
} | ||
|
||
friend std::ostream& operator<< (std::ostream &os, ExprBlock const& L) { | ||
namespace hana = boost::hana; | ||
hana::for_each(L.expr_block_, [&os](auto const &expr) { | ||
boost::yap::print(os, expr); | ||
}); | ||
return os; | ||
} | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
#include "gtest/gtest.h" | ||
#include "ExprBlock.hpp" | ||
using namespace boost::yap::literals; | ||
|
||
|
||
TEST(TestExprList, Test1) { | ||
auto list1 = ExprBlock { | ||
1_p + 2_p, | ||
2_p * 3_p | ||
}; | ||
std::cout << list1 << std::endl; | ||
|
||
list1.gen_code(1, 2, 3); | ||
} | ||
|
||
TEST(TestExprBlock, Test2) { | ||
auto format1 = make_format(Dim2(2, 4), RowMajorLayout()); | ||
auto tensor1 = TensorE(float(), format1, MemSpace::GM(), 0x10); | ||
auto tensor2 = TensorE(float(), format1, MemSpace::GM(), 0x20); | ||
auto tensor3 = TensorE(float(), format1, MemSpace::GM(), 0x30); | ||
auto temp_1 = TensorE(float(), format1, MemSpace::GM(), 0x40); | ||
|
||
auto add_mul = ExprBlock { | ||
temp_1 = tensor1 + tensor2, | ||
tensor3 = temp_1 * tensor3 | ||
}; | ||
|
||
add_mul.gen_code()(); | ||
} | ||
|
||
//TEST(TestExprBlock, Test3) { | ||
// auto format1 = make_format(Dim2(2, 4), RowMajorLayout()); | ||
// auto tensor1 = Tensor(float(), format1, MemSpace::GM(), 0x10); | ||
// auto tensor2 = Tensor(float(), format1, MemSpace::GM(), 0x20); | ||
// auto tensor3 = Tensor(float(), format1, MemSpace::GM(), 0x30); | ||
// auto temp_1 = Tensor(float(), format1, MemSpace::GM(), 0x40); | ||
// | ||
// auto add_mul = ExprBlock { | ||
// _1 = 1_p + 2_p, | ||
// // sync(), | ||
// _1 * 3_p | ||
// }; | ||
// | ||
// add_mul.gen_code(tensor1, tensor2, tensor3)(); | ||
//} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters