diff --git a/.gitignore b/.gitignore index 16389f34e..5511edc86 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ CMakeCache.txt doc apps/tensor_times_vector/tensor_times_vector +tags diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h index cacd1411c..bd177c24c 100644 --- a/include/taco/index_notation/index_notation.h +++ b/include/taco/index_notation/index_notation.h @@ -22,6 +22,8 @@ #include "taco/lower/iterator.h" #include "taco/index_notation/provenance_graph.h" +#include "taco/linalg_notation/linalg_notation_nodes_abstract.h" + namespace taco { class Type; @@ -59,6 +61,9 @@ struct SuchThatNode; class IndexExprVisitorStrict; class IndexStmtVisitorStrict; +struct VarNode; +class LinalgAssignment; + /// A tensor index expression describes a tensor computation as a scalar /// expression where tensors are indexed by index variables (`IndexVar`). The /// index variables range over the tensor dimensions they index, and the scalar diff --git a/include/taco/index_notation/index_notation_nodes.h b/include/taco/index_notation/index_notation_nodes.h index 95439cd6b..e7374a4b2 100644 --- a/include/taco/index_notation/index_notation_nodes.h +++ b/include/taco/index_notation/index_notation_nodes.h @@ -68,6 +68,13 @@ struct NegNode : public UnaryExprNode { } }; +struct TransposeNode : public UnaryExprNode { + TransposeNode(IndexExpr operand) : UnaryExprNode(operand) {} + + void accept (IndexExprVisitorStrict* v) const { + v->visit(this); + } +}; struct BinaryExprNode : public IndexExprNode { virtual std::string getOperatorString() const = 0; diff --git a/include/taco/linalg.h b/include/taco/linalg.h new file mode 100644 index 000000000..3664547b0 --- /dev/null +++ b/include/taco/linalg.h @@ -0,0 +1,271 @@ +#ifndef TACO_LINALG_H +#define TACO_LINALG_H + +#include "taco/type.h" +#include "taco/tensor.h" +#include "taco/format.h" + +#include "taco/linalg_notation/linalg_notation.h" +#include "taco/linalg_notation/linalg_notation_nodes.h" +#include "taco/linalg_notation/linalg_notation_printer.h" + + +namespace taco { + +class LinalgBase : public LinalgExpr { +protected: + std::string name; + Type tensorType; + + LinalgAssignment assignment; + IndexStmt indexAssignment; + + int idxcount = 0; + std::vector indexVarNameList = {"i","j","k","l","m","n","o","p","q","r","s","t","u","v","w","x","y","z"}; + + IndexExpr rewrite(LinalgExpr linalg, std::vector indices); + + IndexVar getUniqueIndex(); + + std::vector getUniqueIndices(size_t order); + +public: + LinalgBase(std::string name, Type tensorType, Datatype dtype, std::vector dims, Format format, bool isColVec = false); + + /// [LINALG NOTATION] + LinalgAssignment operator=(const LinalgExpr &expr); + + const LinalgAssignment getAssignment() const; + + const IndexStmt getIndexAssignment() const; + + IndexStmt rewrite(); + + typedef LinalgVarNode Node; +}; + +std::ostream &operator<<(std::ostream &os, const LinalgBase &linalg); + +IndexExpr rewrite(LinalgExpr linalg, std::vector); + +IndexStmt rewrite(LinalgStmt linalg); + +// ------------------------------------------------------------ +// Matrix class +// ------------------------------------------------------------ + +template +class Matrix : public LinalgBase { +public: + explicit Matrix(std::string name); + + Matrix(std::string name, size_t dim1, size_t dim2); + + Matrix(std::string name, std::vector dimensions); + + Matrix(std::string name, size_t dim1, size_t dim2, Format format); + + Matrix(std::string name, std::vector dimensions, Format format); + + Matrix(std::string name, size_t dim1, size_t dim2, ModeFormat format1, ModeFormat format2); + + LinalgAssignment operator=(const LinalgExpr &expr) { + return LinalgBase::operator=(expr); + } + + // Read method + CType at(int coord_x, int coord_y); + + // Write method + void insert(int coord_x, int coord_y, CType value); + + // ScalarAccess supports reading/assigning to single element + ScalarAccess operator()(int i, int j); + + // Access methods + const Access operator()(const IndexVar i, const IndexVar j) const; + Access operator()(const IndexVar i, const IndexVar j); + + // Allow to be cast to a TensorBase for the sake of ASSERT_TENSOR_EQ + operator TensorBase() const { return *tensorBase; } +}; + +// ------------------------------------------------------------ +// Matrix template method implementations +// ------------------------------------------------------------ + +template +Matrix::Matrix(std::string name) : + LinalgBase(name, Type(type(), {42, 42}), type(), {42, 42}, Format({dense, dense})) {} + +template +Matrix::Matrix(std::string name, std::vector dimensions) : + LinalgBase(name, Type(type(), Shape(std::vector(dimensions.begin(), dimensions.end()))), type(), std::vector(dimensions.begin(), dimensions.end()), Format({dense,dense})) {} + +template +Matrix::Matrix(std::string name, size_t dim1, size_t dim2) : + LinalgBase(name, Type(type(), {dim1, dim2}), type(), {(int) dim1, (int) dim2}, Format({dense,dense})) {} + +template +Matrix::Matrix(std::string name, size_t dim1, size_t dim2, Format format) : + LinalgBase(name, Type(type(), {dim1, dim2}), type(), {(int) dim1, (int) dim2}, format) {} + +template +Matrix::Matrix(std::string name, std::vector dimensions, Format format) : + LinalgBase(name, Type(type(), Shape(std::vector(dimensions.begin(), dimensions.end()))), type(), std::vector(dimensions.begin(), dimensions.end()), format) {} + +template +Matrix::Matrix(std::string name, size_t dim1, size_t dim2, ModeFormat format1, ModeFormat format2) : + LinalgBase(name, Type(type(), {dim1, dim2}), type(), {(int)dim1, (int)dim2}, Format({format1, format2}), false) {} + +// Definition of Read methods +template +CType Matrix::at(int coord_x, int coord_y) { + return tensorBase->at({coord_x, coord_y}); +} + +// Definition of Write methods +template +void Matrix::insert(int coord_x, int coord_y, CType value) { + tensorBase->insert({coord_x, coord_y}, value); +} + +template +ScalarAccess Matrix::operator()(int i, int j) { + return ScalarAccess(tensorBase, {i, j}); +} + +// Definition of Access methods +template +const Access Matrix::operator()(const IndexVar i, const IndexVar j) const { + return (*tensorBase)({i,j}); +} + +template +Access Matrix::operator()(const IndexVar i, const IndexVar j) { + return (*tensorBase)({i,j}); +} + + +// ------------------------------------------------------------ +// Vector class +// ------------------------------------------------------------ + +template +class Vector : public LinalgBase { + std::string name; + Datatype ctype; +public: + explicit Vector(std::string name); + + Vector(std::string name, int dim, bool isColVec = true); + + Vector(std::string name, size_t dim, Format format, bool isColVec = true); + + Vector(std::string name, size_t dim, ModeFormat format, bool isColVec = true); + + LinalgAssignment operator=(const LinalgExpr &expr) { + return LinalgBase::operator=(expr); + } + + // Support some Write methods + void insert(int coord, CType value); + + ScalarAccess operator()(int i); + + // Support some Read methods too + CType at(int coord); + + // Access methods for use in IndexExprs + const Access operator()(const IndexVar i) const; + Access operator()(const IndexVar i); + + // Allow to be cast to a TensorBase for the sake of ASSERT_TENSOR_EQ + operator TensorBase() const { return *tensorBase; } +}; + +// ------------------------------------------------------------ +// Vector template method implementations +// ------------------------------------------------------------ + +template +Vector::Vector(std::string name) : + LinalgBase(name, Type(type(), {42}), type(), {42}, Format({dense}), true) {} + +template +Vector::Vector(std::string name, int dim, bool isColVec) : + LinalgBase(name, Type(type(), {(size_t)dim}), type(), {(int)dim}, Format({dense}), isColVec) {} + +template +Vector::Vector(std::string name, size_t dim, Format format, bool isColVec) : LinalgBase(name, + Type(type(), {dim}), + type(), + {(int) dim}, + format, isColVec) {} + +template +Vector::Vector(std::string name, size_t dim, ModeFormat format, bool isColVec) : + LinalgBase(name, Type(type(), {dim}), type(), {(int)dim}, Format(format), isColVec) {} + +// Vector write methods +template +void Vector::insert(int coord, CType value) { + tensorBase->insert({coord}, value); +} + +template +ScalarAccess Vector::operator()(int i) { + return ScalarAccess(tensorBase, {i}); +} + +template +CType Vector::at(int coord) { + return tensorBase->at({coord}); +} + +// Definition of Access methods +template +const Access Vector::operator()(const IndexVar i) const { + return (*tensorBase)({i}); +} + +template +Access Vector::operator()(const IndexVar i) { + return (*tensorBase)({i}); +} + +// ------------------------------------------------------------ +// Scalar class +// ------------------------------------------------------------ + +template +class Scalar : public LinalgBase { + std::string name; + Datatype ctype; +public: + explicit Scalar(std::string name); + + LinalgAssignment operator=(const LinalgExpr &expr) { + return LinalgBase::operator=(expr); + } + + void operator=(const IndexExpr& expr) { + (*tensorBase) = expr; + } + + CType operator=(CType x) { + tensorBase->insert({}, x); + return x; + } + + operator CType() const { return tensorBase->at({}); } + + operator TensorBase() const { return *tensorBase; } +}; + +template +Scalar::Scalar(std::string name) : + LinalgBase(name, Type(type(), {}) , type(), {}, Format(), false) {} + +} // namespace taco +#endif diff --git a/include/taco/linalg_notation/linalg_notation.h b/include/taco/linalg_notation/linalg_notation.h new file mode 100644 index 000000000..8e761c1f3 --- /dev/null +++ b/include/taco/linalg_notation/linalg_notation.h @@ -0,0 +1,201 @@ +// +// Created by Olivia Hsu on 10/30/20. +// + +#ifndef TACO_LINALG_NOTATION_H +#define TACO_LINALG_NOTATION_H +#include +#include +#include +#include +#include +#include +#include + +#include "taco/format.h" +#include "taco/error.h" +#include "taco/util/intrusive_ptr.h" +#include "taco/util/comparable.h" +#include "taco/type.h" +#include "taco/ir/ir.h" +#include "taco/codegen/module.h" + +#include "taco/ir_tags.h" +#include "taco/lower/iterator.h" +#include "taco/index_notation/provenance_graph.h" + +#include "taco/linalg_notation/linalg_notation_nodes_abstract.h" +#include "taco/linalg.h" + +#include "taco/tensor.h" + +namespace taco { + +class Type; + +class Dimension; + +class Format; + +class Schedule; + +class TensorVar; + +class LinalgBase; + +class LinalgExpr; + +class LinalgAssignment; + +class Access; + +struct LinalgVarNode; +struct LinalgLiteralNode; +struct LinalgNegNode; +struct LinalgTransposeNode; +struct LinalgAddNode; +struct LinalgSubNode; +struct LinalgMatMulNode; +struct LinalgElemMulNode; +struct LinalgDivNode; +struct LinalgUnaryExprNode; +struct LinalgBinaryExprNode; + +class LinalgExprVisitorStrict; + + +class LinalgExpr : public util::IntrusivePtr { +public: + LinalgExpr() : util::IntrusivePtr(nullptr) {} + + LinalgExpr(const LinalgExprNode *n) : util::IntrusivePtr(n) {} + + /// Construct a scalar tensor access. + /// ``` + /// A(i,j) = b; + /// ``` + explicit LinalgExpr(TensorVar); + + LinalgExpr(TensorVar, bool isColVec, TensorBase* tensorBase); + + explicit LinalgExpr(TensorBase* _tensorBase, bool isColVec=false); + + LinalgExpr(TensorVar var, bool isColVec); + /// Consturct an integer literal. + /// ``` + /// A(i,j) = 1; + /// ``` + LinalgExpr(char); + + LinalgExpr(int8_t); + + LinalgExpr(int16_t); + + LinalgExpr(int32_t); + + LinalgExpr(int64_t); + + /// Consturct an unsigned integer literal. + /// ``` + /// A(i,j) = 1u; + /// ``` + LinalgExpr(uint8_t); + + LinalgExpr(uint16_t); + + LinalgExpr(uint32_t); + + LinalgExpr(uint64_t); + + /// Consturct double literal. + /// ``` + /// A(i,j) = 1.0; + /// ``` + LinalgExpr(float); + + LinalgExpr(double); + + /// Construct complex literal. + /// ``` + /// A(i,j) = complex(1.0, 1.0); + /// ``` + LinalgExpr(std::complex); + + LinalgExpr(std::complex); + + Datatype getDataType() const; + int getOrder() const; + bool isColVector() const; + void setColVector(bool) const; + + /// Visit the linalg expression's sub-expressions. + void accept(LinalgExprVisitorStrict *) const; + + /// Print the index expression. + friend std::ostream &operator<<(std::ostream &, const LinalgExpr &); + + TensorBase *tensorBase; +}; + +/// Compare two index expressions by value. +bool equals(LinalgExpr, LinalgExpr); + +/// Construct and returns an expression that negates this expression. +LinalgExpr operator-(const LinalgExpr&); + +/// Add two linear algebra expressions. +LinalgExpr operator+(const LinalgExpr&, const LinalgExpr&); + +/// Subtract a linear algebra expressions from another. +LinalgExpr operator-(const LinalgExpr&, const LinalgExpr&); + +/// Matrix Multiply two linear algebra expressions. +LinalgExpr operator*(const LinalgExpr&, const LinalgExpr&); + +/// Divide a linear expression by another. +LinalgExpr operator/(const LinalgExpr&, const LinalgExpr&); + +/// Element-wise multiply two linear algebra expressions +// FIXME: May want to be consistent with eigen library in c++ and change to cmul +LinalgExpr elemMul(const LinalgExpr& lhs, const LinalgExpr& rhs); + +/// Construct and returns an expression that transposes this expression +// FIXME: May want to change this with '^T' in the future +LinalgExpr transpose(const LinalgExpr& lhs); +//LinalgExpr operator^(const LinalgExpr&, const T); + +/// Check to make sure operators are legal (shape-wise) +int getMatMulOrder(const LinalgExpr &lhs, const LinalgExpr &rhs); + +void checkCompatibleShape(const LinalgExpr &lhs, const LinalgExpr &rhs); +/// A an index statement computes a tensor. The index statements are +/// assignment, forall, where, multi, and sequence. +class LinalgStmt : public util::IntrusivePtr { +public: + LinalgStmt(); + LinalgStmt(const LinalgStmtNode* n); + + /// Visit the tensor expression + void accept(LinalgStmtVisitorStrict *) const; +}; + +class LinalgAssignment : public LinalgStmt { +public: + LinalgAssignment() = default; + LinalgAssignment(const LinalgAssignmentNode*); + + /// Create an assignment. + LinalgAssignment(TensorVar lhs, LinalgExpr rhs); + + /// Return the assignment's left-hand side. + TensorVar getLhs() const; + + /// Return the assignment's right-hand side. + LinalgExpr getRhs() const; + + typedef LinalgAssignmentNode Node; +}; + +} + +#endif //TACO_LINALG_NOTATION_H diff --git a/include/taco/linalg_notation/linalg_notation_nodes.h b/include/taco/linalg_notation/linalg_notation_nodes.h new file mode 100644 index 000000000..44e3342ea --- /dev/null +++ b/include/taco/linalg_notation/linalg_notation_nodes.h @@ -0,0 +1,232 @@ +#ifndef TACO_LINALG_NOTATION_NODES_H +#define TACO_LINALG_NOTATION_NODES_H + +#include +#include +#include + +#include "taco/type.h" +#include "taco/index_notation/index_notation.h" +#include "taco/index_notation/index_notation_nodes_abstract.h" +#include "taco/index_notation/index_notation_visitor.h" +#include "taco/index_notation/intrinsic.h" +#include "taco/util/strings.h" +#include "taco/linalg_notation/linalg_notation.h" +#include "taco/linalg_notation/linalg_notation_nodes_abstract.h" +#include "taco/linalg_notation/linalg_notation_visitor.h" + +#include "taco/tensor.h" + +namespace taco { + + + struct LinalgVarNode : public LinalgExprNode { + LinalgVarNode(TensorVar tensorVar) + : LinalgExprNode(tensorVar.getType().getDataType(), tensorVar.getOrder()), tensorVar(tensorVar) {} + LinalgVarNode(TensorVar tensorVar, bool isColVec) + : LinalgExprNode(tensorVar.getType().getDataType(), tensorVar.getOrder(), isColVec), tensorVar(tensorVar) {} + + void accept(LinalgExprVisitorStrict* v) const override { + v->visit(this); + } + + virtual void setAssignment(const LinalgAssignment& assignment) {} + + TensorVar tensorVar; + }; + + struct LinalgTensorBaseNode : public LinalgVarNode { + LinalgTensorBaseNode(TensorVar tensorVar, TensorBase *tensorBase) + : LinalgVarNode(tensorVar), tensorBase(tensorBase) {} + LinalgTensorBaseNode(TensorVar tensorVar, TensorBase *tensorBase, bool isColVec) + : LinalgVarNode(tensorVar, isColVec), tensorBase(tensorBase) {} + void accept(LinalgExprVisitorStrict* v) const override { + v->visit(this); + } + + virtual void setAssignment(const LinalgAssignment& assignment) {} + + TensorBase* tensorBase; + }; + + struct LinalgLiteralNode : public LinalgExprNode { + template LinalgLiteralNode(T val) : LinalgExprNode(type()) { + this->val = malloc(sizeof(T)); + *static_cast(this->val) = val; + } + + ~LinalgLiteralNode() { + free(val); + } + + void accept(LinalgExprVisitorStrict* v) const override { + v->visit(this); + } + + template T getVal() const { + taco_iassert(getDataType() == type()) + << "Attempting to get data of wrong type"; + return *static_cast(val); + } + + void* val; + }; + + + struct LinalgUnaryExprNode : public LinalgExprNode { + LinalgExpr a; + + protected: + LinalgUnaryExprNode(LinalgExpr a) : LinalgExprNode(a.getDataType(), a.getOrder(), a.isColVector()), a(a) {} + LinalgUnaryExprNode(LinalgExpr a, bool isColVec) : LinalgExprNode(a.getDataType(), a.getOrder(), isColVec), a(a) {} + }; + + + struct LinalgNegNode : public LinalgUnaryExprNode { + LinalgNegNode(LinalgExpr operand) : LinalgUnaryExprNode(operand) {} + + void accept(LinalgExprVisitorStrict* v) const override{ + v->visit(this); + } + }; + + struct LinalgTransposeNode : public LinalgUnaryExprNode { + LinalgTransposeNode(LinalgExpr operand) : LinalgUnaryExprNode(operand) {} + LinalgTransposeNode(LinalgExpr operand, bool isColVec) : LinalgUnaryExprNode(operand, isColVec) {} + + void accept (LinalgExprVisitorStrict* v) const override{ + v->visit(this); + } + }; + + struct LinalgBinaryExprNode : public LinalgExprNode { + virtual std::string getOperatorString() const = 0; + + LinalgExpr a; + LinalgExpr b; + + protected: + LinalgBinaryExprNode() : LinalgExprNode() {} + LinalgBinaryExprNode(LinalgExpr a, LinalgExpr b, int order) + : LinalgExprNode(max_type(a.getDataType(), b.getDataType()), order), a(a), b(b) {} + LinalgBinaryExprNode(LinalgExpr a, LinalgExpr b, int order, bool isColVec) + : LinalgExprNode(max_type(a.getDataType(), b.getDataType()), order, isColVec), a(a), b(b) {} + }; + + + struct LinalgAddNode : public LinalgBinaryExprNode { + LinalgAddNode() : LinalgBinaryExprNode() {} + LinalgAddNode(LinalgExpr a, LinalgExpr b, int order) : LinalgBinaryExprNode(a, b, order) {} + LinalgAddNode(LinalgExpr a, LinalgExpr b, int order, bool isColVec) : LinalgBinaryExprNode(a, b, order, isColVec) {} + + std::string getOperatorString() const override{ + return "+"; + } + + void accept(LinalgExprVisitorStrict* v) const override{ + v->visit(this); + } + }; + + + struct LinalgSubNode : public LinalgBinaryExprNode { + LinalgSubNode() : LinalgBinaryExprNode() {} + LinalgSubNode(LinalgExpr a, LinalgExpr b, int order) : LinalgBinaryExprNode(a, b, order) {} + LinalgSubNode(LinalgExpr a, LinalgExpr b, int order, bool isColVec) : LinalgBinaryExprNode(a, b, order, isColVec) {} + + std::string getOperatorString() const override{ + return "-"; + } + + void accept(LinalgExprVisitorStrict* v) const override{ + v->visit(this); + } + }; + + + struct LinalgMatMulNode : public LinalgBinaryExprNode { + LinalgMatMulNode() : LinalgBinaryExprNode() {} + LinalgMatMulNode(LinalgExpr a, LinalgExpr b, int order) : LinalgBinaryExprNode(a, b, order) {} + LinalgMatMulNode(LinalgExpr a, LinalgExpr b, int order, bool isColVec) : LinalgBinaryExprNode(a, b, order, isColVec) {} + + std::string getOperatorString() const override{ + return "*"; + } + + void accept(LinalgExprVisitorStrict* v) const override{ + v->visit(this); + } + }; + +struct LinalgElemMulNode : public LinalgBinaryExprNode { + LinalgElemMulNode() : LinalgBinaryExprNode() {} + LinalgElemMulNode(LinalgExpr a, LinalgExpr b, int order) : LinalgBinaryExprNode(a, b, order) {} + LinalgElemMulNode(LinalgExpr a, LinalgExpr b, int order, bool isColVec) : LinalgBinaryExprNode(a, b, order, isColVec) {} + + std::string getOperatorString() const override{ + return "elemMul"; + } + + void accept(LinalgExprVisitorStrict* v) const override{ + v->visit(this); + } +}; + +struct LinalgDivNode : public LinalgBinaryExprNode { + LinalgDivNode() : LinalgBinaryExprNode() {} + LinalgDivNode(LinalgExpr a, LinalgExpr b, int order) : LinalgBinaryExprNode(a, b, order) {} + LinalgDivNode(LinalgExpr a, LinalgExpr b, int order, bool isColVec) : LinalgBinaryExprNode(a, b, order, isColVec) {} + + std::string getOperatorString() const override{ + return "/"; + } + + void accept(LinalgExprVisitorStrict* v) const override{ + v->visit(this); + } +}; + +// Linalg Statements +struct LinalgAssignmentNode : public LinalgStmtNode { + LinalgAssignmentNode(const TensorVar& lhs, const LinalgExpr& rhs) + : lhs(lhs), rhs(rhs) { isColVec = false;} + + LinalgAssignmentNode(const TensorVar& lhs, bool isColVec, const LinalgExpr& rhs) + : lhs(lhs), rhs(rhs), isColVec(isColVec) {} + + void accept(LinalgStmtVisitorStrict* v) const { + v->visit(this); + } + + TensorVar lhs; + LinalgExpr rhs; + bool isColVec; +}; + +/// Returns true if expression e is of type E. +template +inline bool isa(const LinalgExprNode* e) { + return e != nullptr && dynamic_cast(e) != nullptr; +} + +/// Casts the expression e to type E. +template +inline const E* to(const LinalgExprNode* e) { + taco_iassert(isa(e)) << + "Cannot convert " << typeid(e).name() << " to " << typeid(E).name(); + return static_cast(e); +} + +/// Returns true if statement e is of type S. +template +inline bool isa(const LinalgStmtNode* s) { + return s != nullptr && dynamic_cast(s) != nullptr; +} + +//template +//inline const typename I::Node* getNode(const I& stmt) { +// taco_iassert(isa(stmt.ptr)); +// return static_cast(stmt.ptr); +//} +} +#endif //TACO_LINALG_NOTATION_NODES_H diff --git a/include/taco/linalg_notation/linalg_notation_nodes_abstract.h b/include/taco/linalg_notation/linalg_notation_nodes_abstract.h new file mode 100644 index 000000000..136034053 --- /dev/null +++ b/include/taco/linalg_notation/linalg_notation_nodes_abstract.h @@ -0,0 +1,63 @@ +// +// Created by Olivia Hsu on 10/30/20. +// + +#ifndef TACO_LINALG_NOTATION_NODES_ABSTRACT_H +#define TACO_LINALG_NOTATION_NODES_ABSTRACT_H + +#include +#include + +#include "taco/type.h" +#include "taco/util/uncopyable.h" +#include "taco/util/intrusive_ptr.h" +#include "taco/linalg_notation/linalg_notation_visitor.h" + +namespace taco { + +class TensorVar; +class LinalgExprVisitorStrict; +class Precompute; + +/// A node of a scalar index expression tree. +struct LinalgExprNode : public util::Manageable, + private util::Uncopyable { +public: + LinalgExprNode() = default; + explicit LinalgExprNode(Datatype type); + LinalgExprNode(Datatype type, int order); + LinalgExprNode(Datatype type, int order, bool isColVec); + + virtual ~LinalgExprNode() = default; + + virtual void accept(LinalgExprVisitorStrict*) const = 0; + + /// Return the scalar data type of the index expression. + Datatype getDataType() const; + int getOrder() const; + bool isColVector() const; + void setColVector(bool val); + +private: + Datatype dataType; + int order; + bool isColVec; +}; + +struct LinalgStmtNode : public util::Manageable, + private util::Uncopyable { +public: + LinalgStmtNode() = default; + LinalgStmtNode(Type type); + virtual ~LinalgStmtNode() = default; + virtual void accept(LinalgStmtVisitorStrict*) const = 0; + + Type getType() const; + +private: + Type type; +}; + +} + +#endif //TACO_LINALG_NOTATION_NODES_ABSTRACT_H diff --git a/include/taco/linalg_notation/linalg_notation_printer.h b/include/taco/linalg_notation/linalg_notation_printer.h new file mode 100644 index 000000000..7a4754c2d --- /dev/null +++ b/include/taco/linalg_notation/linalg_notation_printer.h @@ -0,0 +1,54 @@ +#ifndef TACO_LINALG_NOTATION_PRINTER_H +#define TACO_LINALG_NOTATION_PRINTER_H + +#include +#include "taco/linalg_notation/linalg_notation_visitor.h" + +namespace taco { + +class LinalgNotationPrinter : public LinalgNotationVisitorStrict { +public: + explicit LinalgNotationPrinter(std::ostream& os); + + void print(const LinalgExpr& expr); + void print(const LinalgStmt& expr); + + using LinalgExprVisitorStrict::visit; + + // Scalar Expressions + void visit(const LinalgVarNode*); + void visit(const LinalgTensorBaseNode*); + void visit(const LinalgLiteralNode*); + void visit(const LinalgNegNode*); + void visit(const LinalgAddNode*); + void visit(const LinalgSubNode*); + void visit(const LinalgMatMulNode*); + void visit(const LinalgElemMulNode*); + void visit(const LinalgDivNode*); + void visit(const LinalgTransposeNode*); + + void visit(const LinalgAssignmentNode*); + +private: + std::ostream& os; + + enum class Precedence { + ACCESS = 2, + VAR = 2, + FUNC = 2, + NEG = 3, + TRANSPOSE = 3, + MATMUL = 5, + ELEMMUL = 5, + DIV = 5, + ADD = 6, + SUB = 6, + TOP = 20 + }; + Precedence parentPrecedence; + + template void visitBinary(Node op, Precedence p); +}; + +} +#endif //TACO_LINALG_NOTATION_PRINTER_H diff --git a/include/taco/linalg_notation/linalg_notation_visitor.h b/include/taco/linalg_notation/linalg_notation_visitor.h new file mode 100644 index 000000000..001dbdd7a --- /dev/null +++ b/include/taco/linalg_notation/linalg_notation_visitor.h @@ -0,0 +1,99 @@ +#ifndef TACO_LINALG_NOTATION_VISITOR_H +#define TACO_LINALG_NOTATION_VISITOR_H +namespace taco { + +class LinalgExpr; +class LinalgStmt; + +class TensorVar; + +struct LinalgVarNode; +struct LinalgTensorBaseNode; +struct LinalgLiteralNode; +struct LinalgNegNode; +struct LinalgTransposeNode; +struct LinalgAddNode; +struct LinalgSubNode; +struct LinalgMatMulNode; +struct LinalgElemMulNode; +struct LinalgDivNode; +struct LinalgUnaryExprNode; +struct LinalgBinaryExprNode; + +struct LinalgAssignmentNode; + +/// Visit the nodes in an expression. This visitor provides some type safety +/// by requiring all visit methods to be overridden. +class LinalgExprVisitorStrict { +public: + virtual ~LinalgExprVisitorStrict() = default; + + void visit(const LinalgExpr &); + + virtual void visit(const LinalgVarNode *) = 0; + + virtual void visit(const LinalgTensorBaseNode*) = 0; + + virtual void visit(const LinalgLiteralNode *) = 0; + + virtual void visit(const LinalgNegNode *) = 0; + + virtual void visit(const LinalgAddNode *) = 0; + + virtual void visit(const LinalgSubNode *) = 0; + + virtual void visit(const LinalgMatMulNode *) = 0; + + virtual void visit(const LinalgElemMulNode *) = 0; + + virtual void visit(const LinalgDivNode *) = 0; + + virtual void visit(const LinalgTransposeNode *) = 0; +}; + +class LinalgStmtVisitorStrict { +public: + virtual ~LinalgStmtVisitorStrict() = default; + + void visit(const LinalgStmt&); + + virtual void visit(const LinalgAssignmentNode*) = 0; +}; + +/// Visit nodes in linalg notation +class LinalgNotationVisitorStrict : public LinalgExprVisitorStrict, + public LinalgStmtVisitorStrict { +public: + virtual ~LinalgNotationVisitorStrict() = default; + + using LinalgExprVisitorStrict::visit; + using LinalgStmtVisitorStrict::visit; +}; + +/// Visit nodes in an expression. +class LinalgNotationVisitor : public LinalgNotationVisitorStrict { +public: + virtual ~LinalgNotationVisitor() = default; + + using LinalgNotationVisitorStrict::visit; + + // Index Expressions + virtual void visit(const LinalgVarNode* node); + virtual void visit(const LinalgTensorBaseNode* node); + virtual void visit(const LinalgLiteralNode* node); + virtual void visit(const LinalgNegNode* node); + virtual void visit(const LinalgAddNode* node); + virtual void visit(const LinalgSubNode* node); + virtual void visit(const LinalgMatMulNode* node); + virtual void visit(const LinalgElemMulNode* node); + virtual void visit(const LinalgDivNode* node); + virtual void visit(const LinalgUnaryExprNode* node); + virtual void visit(const LinalgBinaryExprNode* node); + virtual void visit(const LinalgTransposeNode* node); + + // Index Statments + virtual void visit(const LinalgAssignmentNode* node); +}; + +} +#endif //TACO_LINALG_NOTATION_VISITOR_H diff --git a/include/taco/linalg_notation/linalg_rewriter.h b/include/taco/linalg_notation/linalg_rewriter.h new file mode 100644 index 000000000..6f8a37578 --- /dev/null +++ b/include/taco/linalg_notation/linalg_rewriter.h @@ -0,0 +1,75 @@ +#ifndef TACO_LINALG_REWRITER_H +#define TACO_LINALG_REWRITER_H + +#include +#include +#include +#include +#include + +#include "taco/lower/iterator.h" +#include "taco/util/scopedset.h" +#include "taco/util/uncopyable.h" +#include "taco/ir_tags.h" + +namespace taco { + +class TensorVar; + +class IndexVar; + +class IndexExpr; + +class LinalgBase; + +class LinalgRewriter : public util::Uncopyable { +public: + LinalgRewriter(); + + virtual ~LinalgRewriter() = default; + + /// Lower an index statement to an IR function. + IndexStmt rewrite(LinalgBase linalgBase); + + void setLiveIndices(std::vector indices); +protected: + + virtual IndexExpr rewriteSub(const LinalgSubNode* sub); + + virtual IndexExpr rewriteAdd(const LinalgAddNode* add); + + virtual IndexExpr rewriteElemMul(const LinalgElemMulNode* elemMul); + + virtual IndexExpr rewriteMatMul(const LinalgMatMulNode* matMul); + + virtual IndexExpr rewriteDiv(const LinalgDivNode* div); + + virtual IndexExpr rewriteNeg(const LinalgNegNode* neg); + + virtual IndexExpr rewriteTranspose(const LinalgTransposeNode* transpose); + + virtual IndexExpr rewriteLiteral(const LinalgLiteralNode* literal); + + virtual IndexExpr rewriteVar(const LinalgVarNode* var); + + virtual IndexExpr rewriteTensorBase(const LinalgTensorBaseNode* node); + + virtual IndexStmt rewriteAssignment(const LinalgAssignmentNode* node); + + IndexExpr rewrite(LinalgExpr linalgExpr); + +private: + std::vector liveIndices; + + int idxcount = 0; + std::vector indexVarNameList = {"i","j","k","l","m","n","o","p","q","r","s","t","u","v","w","x","y","z"}; + + IndexVar getUniqueIndex(); + + class Visitor; + friend class Visitor; + std::shared_ptr visitor; +}; + +} // namespace taco +#endif //TACO_LINALG_REWRITER_H diff --git a/include/taco/parser/lexer.h b/include/taco/parser/lexer.h index 55dc74410..c5a626cac 100644 --- a/include/taco/parser/lexer.h +++ b/include/taco/parser/lexer.h @@ -23,6 +23,9 @@ enum class Token { mul, div, eq, + caretT, + elemMul, + transpose, eot, // End of tokens error }; diff --git a/include/taco/parser/linalg_parser.h b/include/taco/parser/linalg_parser.h new file mode 100644 index 000000000..892db30db --- /dev/null +++ b/include/taco/parser/linalg_parser.h @@ -0,0 +1,110 @@ +#ifndef TACO_LINALG_PARSER_H +#define TACO_LINALG_PARSER_H + +#include +#include +#include +#include +#include + +#include "taco/tensor.h" +#include "taco/util/uncopyable.h" +#include "taco/type.h" +#include "taco/parser/parser.h" +#include "taco/linalg_notation/linalg_notation_nodes.h" + +namespace taco { +class TensorBase; +class LinalgBase; +class Format; +class IndexVar; +class LinalgExpr; + +class LinalgStmt; +class LinalgAssignment; + +namespace parser { +enum class Token; + +class LinalgParser : public AbstractParser { + +public: + + /// Create a parser object from linalg notation + LinalgParser(std::string expression, const std::map& formats, + const std::map& dataTypes, + const std::map>& tensorDimensions, + const std::map& tensors, + const std::map& linalgShapes, const std::map& linalgVecShapes, + int defaultDimension=5); + + /// Parses the linalg expression and sets the result tensor to the result of that expression + /// @throws ParserError if there is an error with parsing the linalg string + void parse() override; + + /// Gets the result tensor after parsing is complete. + const TensorBase& getResultTensor() const override; + + /// Gets all tensors + const std::map& getTensors() const override; + + /// Retrieve the tensor with the given name + const TensorBase& getTensor(std::string name) const override; + + /// Returns true if the tensor appeared in the expression + bool hasTensor(std::string name) const; + + /// Returns true if the index variable appeared in the expression + bool hasIndexVar(std::string name) const; + + /// Retrieve the index variable with the given name + IndexVar getIndexVar(std::string name) const; + +private: + Datatype outType; + Format format; + + struct Content; + std::shared_ptr content; + std::vector names; + + /// assign ::= var '=' expr + LinalgBase parseAssign(); + + /// expr ::= term {('+' | '-') term} + LinalgExpr parseExpr(); + + /// term ::= factor {('*' | '/') factor} + LinalgExpr parseTerm(); + + /// factor ::= final + /// | '(' expr ')' + /// | '-' factor + /// | factor '^T' + LinalgExpr parseFactor(); + + /// final ::= var + /// | scalar + LinalgExpr parseFinal(); + + LinalgExpr parseCall(); + + /// var ::= identifier + LinalgBase parseVar(); + + std::string currentTokenString(); + + void consume(Token expected); + + /// Retrieve the next token from the lexer + void nextToken(); + + /// FIXME: REMOVE LATER, temporary workaround to use Tensor API and TensorBase + std::vector getUniqueIndices(size_t order); + + int idxcount; +}; + + +}} +#endif //TACO_LINALG_PARSER_H diff --git a/include/taco/parser/parser.h b/include/taco/parser/parser.h index 9a3c4cfff..a2b0d94d9 100644 --- a/include/taco/parser/parser.h +++ b/include/taco/parser/parser.h @@ -20,11 +20,19 @@ class Access; namespace parser { enum class Token; +class AbstractParser : public util::Uncopyable { +public: + virtual void parse() = 0; + virtual const TensorBase& getResultTensor() const = 0; + virtual const std::map& getTensors() const = 0; + virtual const TensorBase& getTensor(std::string name) const = 0; +}; + /// A simple index expression parser. The parser can parse an index expression /// string, where tensor access expressions are in the form (e.g.) `A(i,j)`, /// A_{i,j} or A_i. A variable is taken to be free if it is used to index the /// lhs, and taken to be a summation variable otherwise. -class Parser : public util::Uncopyable { +class Parser : public AbstractParser { public: Parser(std::string expression, const std::map& formats, const std::map& dataTypes, @@ -34,10 +42,10 @@ class Parser : public util::Uncopyable { /// Parse the expression. /// @throws ParseError if there's a parser error - void parse(); + void parse() override; /// Returns the result (lhs) tensor of the index expression. - const TensorBase& getResultTensor() const; + const TensorBase& getResultTensor() const override; /// Returns true if the index variable appeared in the expression bool hasIndexVar(std::string name) const; @@ -49,10 +57,10 @@ class Parser : public util::Uncopyable { bool hasTensor(std::string name) const; /// Retrieve the tensor with the given name - const TensorBase& getTensor(std::string name) const; + const TensorBase& getTensor(std::string name) const override; /// Retrieve a map from tensor names to tensors. - const std::map& getTensors() const; + const std::map& getTensors() const override; /// Retrieve a list of names in the order they occurred in the expression const std::vector getNames() const; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f68d4e4c7..f63016d84 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -6,7 +6,7 @@ else() message("-- Static library") endif() -set(TACO_SRC_DIRS . parser index_notation lower ir codegen storage error util) +set(TACO_SRC_DIRS . parser index_notation lower ir codegen storage error util linalg_notation) foreach(dir ${TACO_SRC_DIRS}) file(GLOB TACO_HEADERS ${TACO_HEADERS} ${dir}/*.h) diff --git a/src/index_notation/transformations.cpp b/src/index_notation/transformations.cpp index 5310455f6..99458a9ac 100644 --- a/src/index_notation/transformations.cpp +++ b/src/index_notation/transformations.cpp @@ -844,7 +844,6 @@ IndexStmt reorderLoopsTopologically(IndexStmt stmt) { varOrderFromTensorLevels(tensorLevelVar.second); } const auto hardDeps = depsFromVarOrders(tensorVarOrders); - struct CollectSoftDependencies : public IndexNotationVisitor { using IndexNotationVisitor::visit; diff --git a/src/linalg.cpp b/src/linalg.cpp new file mode 100644 index 000000000..5cf3eb861 --- /dev/null +++ b/src/linalg.cpp @@ -0,0 +1,101 @@ +#include "taco/linalg.h" + +#include "taco/index_notation/index_notation.h" +#include "taco/index_notation/index_notation_nodes.h" +#include "taco/linalg_notation/linalg_notation_nodes.h" +#include "taco/linalg_notation/linalg_rewriter.h" + +using namespace std; + +namespace taco { + +LinalgBase::LinalgBase(string name, Type tensorType, Datatype dtype, std::vector dims, Format format, bool isColVec) : + LinalgExpr(TensorVar(name, tensorType, format), isColVec, new TensorBase(name, dtype, dims, format)), name(name), + tensorType(tensorType), idxcount(0) { +} + +LinalgAssignment LinalgBase::operator=(const LinalgExpr& expr) { + taco_iassert(isa(this->ptr)); + TensorVar var = to(this->get())->tensorVar; + + taco_uassert(var.getOrder() == expr.getOrder()) << "LHS (" << var.getOrder() << ") and RHS (" << expr.getOrder() + << ") of linalg assignment must match order"; + if (var.getOrder() == 1) + taco_uassert(this->isColVector() == expr.isColVector()) << "RHS and LHS of linalg assignment must match vector type"; + + LinalgAssignment assignment = LinalgAssignment(var, expr); + this->assignment = assignment; + this->rewrite(); + return assignment; +} + +const LinalgAssignment LinalgBase::getAssignment() const{ + return this->assignment; +} +const IndexStmt LinalgBase::getIndexAssignment() const { + if (this->indexAssignment.defined()) { + return this->indexAssignment; + } + return IndexStmt(); +} + +vector LinalgBase::getUniqueIndices(size_t order) { + vector result; + for (int i = idxcount; i < (idxcount + (int)order); i++) { + string name = "i" + to_string(i); + IndexVar indexVar(name); + result.push_back(indexVar); + } + idxcount += order; + return result; +} + +IndexVar LinalgBase::getUniqueIndex() { + int loc = idxcount % indexVarNameList.size(); + int num = idxcount / indexVarNameList.size(); + + string indexVarName; + if (num == 0) + indexVarName = indexVarNameList.at(loc); + else + indexVarName = indexVarNameList.at(loc) + to_string(num); + + idxcount += 1; + IndexVar result(indexVarName); + return result; +} + +IndexExpr LinalgBase::rewrite(LinalgExpr linalg, vector indices) { + return IndexExpr(); +} + +IndexStmt rewrite(LinalgStmt linalg) { + return IndexStmt(); +} + +IndexStmt LinalgBase::rewrite() { + if (this->assignment.defined()) { + auto linalgRewriter = new LinalgRewriter(); + //linalgRewriter->setLiveIndices(indices); + IndexStmt stmt = linalgRewriter->rewrite(*this); + this->indexAssignment = stmt; + return stmt; + } + return IndexStmt(); +} + +std::ostream& operator<<(std::ostream& os, const LinalgBase& linalg) { + LinalgAssignment assignment = linalg.getAssignment(); + + // If TensorBase exists, print the storage + if (linalg.tensorBase != nullptr) { + return os << *(linalg.tensorBase); + } + + if (!assignment.defined()) return os << getNode(linalg)->tensorVar.getName(); + LinalgNotationPrinter printer(os); + printer.print(assignment); + return os; +} + +} diff --git a/src/linalg_notation/linalg_notation.cpp b/src/linalg_notation/linalg_notation.cpp new file mode 100644 index 000000000..a90e52dad --- /dev/null +++ b/src/linalg_notation/linalg_notation.cpp @@ -0,0 +1,227 @@ +//#include "taco/linalg_notation/linalg_notation.h" + +#include +#include +#include +#include +#include +#include +#include +#include "lower/mode_access.h" + +#include "error/error_checks.h" +#include "taco/error/error_messages.h" +#include "taco/type.h" + +//#include "taco/linalg_notation/linalg_notation.h" +#include "taco/linalg.h" +#include "taco/linalg_notation/linalg_notation_nodes.h" + +#include "taco/index_notation/schedule.h" +#include "taco/index_notation/transformations.h" +#include "taco/index_notation/index_notation_nodes.h" + +#include "taco/ir/ir.h" + + +using namespace std; + +namespace taco { + +LinalgExpr::LinalgExpr(TensorVar var) : LinalgExpr(new LinalgVarNode(var)) { +} + +LinalgExpr::LinalgExpr(TensorVar var, bool isColVec, TensorBase* _tensorBase) : LinalgExpr(new LinalgTensorBaseNode(var, _tensorBase, isColVec)) { + tensorBase = _tensorBase; +} + +LinalgExpr::LinalgExpr(TensorBase* _tensorBase, bool isColVec) : LinalgExpr(new LinalgTensorBaseNode(_tensorBase->getTensorVar(), _tensorBase, isColVec)) { + +} + +LinalgExpr::LinalgExpr(TensorVar var, bool isColVec) : LinalgExpr(new LinalgVarNode(var, isColVec)) { +} + +LinalgExpr::LinalgExpr(char val) : LinalgExpr(new LinalgLiteralNode(val)) { +} + +LinalgExpr::LinalgExpr(int8_t val) : LinalgExpr(new LinalgLiteralNode(val)) { +} + +LinalgExpr::LinalgExpr(int16_t val) : LinalgExpr(new LinalgLiteralNode(val)) { +} + +LinalgExpr::LinalgExpr(int32_t val) : LinalgExpr(new LinalgLiteralNode(val)) { +} + +LinalgExpr::LinalgExpr(int64_t val) : LinalgExpr(new LinalgLiteralNode(val)) { +} + +LinalgExpr::LinalgExpr(uint8_t val) : LinalgExpr(new LinalgLiteralNode(val)) { +} + +LinalgExpr::LinalgExpr(uint16_t val) : LinalgExpr(new LinalgLiteralNode(val)) { +} + +LinalgExpr::LinalgExpr(uint32_t val) : LinalgExpr(new LinalgLiteralNode(val)) { +} + +LinalgExpr::LinalgExpr(uint64_t val) : LinalgExpr(new LinalgLiteralNode(val)) { +} + +LinalgExpr::LinalgExpr(float val) : LinalgExpr(new LinalgLiteralNode(val)) { +} + +LinalgExpr::LinalgExpr(double val) : LinalgExpr(new LinalgLiteralNode(val)) { +} + +LinalgExpr::LinalgExpr(std::complex val) : LinalgExpr(new LinalgLiteralNode(val)) { +} + +LinalgExpr::LinalgExpr(std::complex val) : LinalgExpr(new LinalgLiteralNode(val)) { +} + +Datatype LinalgExpr::getDataType() const { + return const_cast(this->ptr)->getDataType(); +} + +int LinalgExpr::getOrder() const { + return const_cast(this->ptr)->getOrder(); +} + +bool LinalgExpr::isColVector() const { + return const_cast(this->ptr)->isColVector(); +} + +void LinalgExpr::setColVector(bool val) const { + const_cast(this->ptr)->setColVector(val); +} + +void LinalgExpr::accept(LinalgExprVisitorStrict *v) const { + ptr->accept(v); +} + +std::ostream& operator<<(std::ostream& os, const LinalgExpr& expr) { + if (!expr.defined()) return os << "LinalgExpr()"; + LinalgNotationPrinter printer(os); + printer.print(expr); + return os; +} + +void checkCompatibleShape(const LinalgExpr &lhs, const LinalgExpr &rhs) { + if (lhs.getOrder() != 0 && rhs.getOrder() != 0) + taco_uassert(lhs.getOrder() == rhs.getOrder()) << "RHS and LHS order do not match for linear algebra " + "binary operation" << endl; + if (lhs.getOrder() == 1) + taco_uassert(lhs.isColVector() == rhs.isColVector()) << "RHS and LHS vector type do not match for linear algebra " + "binary operation" << endl; +} + +LinalgExpr operator-(const LinalgExpr &expr) { + return LinalgExpr(new LinalgNegNode(expr.ptr)); +} + +LinalgExpr operator+(const LinalgExpr &lhs, const LinalgExpr &rhs) { + checkCompatibleShape(lhs, rhs); + if (lhs.getOrder() == 0) + return new LinalgAddNode(lhs, rhs, rhs.getOrder(), rhs.isColVector()); + return new LinalgAddNode(lhs, rhs, lhs.getOrder(), lhs.isColVector()); +} + +LinalgExpr operator-(const LinalgExpr &lhs, const LinalgExpr &rhs) { + checkCompatibleShape(lhs, rhs); + if (lhs.getOrder() == 0) + return new LinalgSubNode(lhs, rhs, rhs.getOrder(), rhs.isColVector()); + return new LinalgSubNode(lhs, rhs, lhs.getOrder(), lhs.isColVector()); +} + +LinalgExpr operator*(const LinalgExpr &lhs, const LinalgExpr &rhs) { + int order = 0; + bool isColVec = false; + // Matrix-matrix mult + if (lhs.getOrder() == 2 && rhs.getOrder() == 2) { + order = 2; + } + // Matrix-column vector multiply + else if (lhs.getOrder() == 2 && rhs.getOrder() == 1 && rhs.isColVector()) { + order = 1; + isColVec = true; + } + // Row-vector Matrix multiply + else if (lhs.getOrder() == 1 && !lhs.isColVector() && rhs.getOrder() == 2) { + order = 1; + } + // Inner product + else if (lhs.getOrder() == 1 && !lhs.isColVector() && rhs.getOrder() == 1 && rhs.isColVector()) { + order = 0; + } + // Outer product + else if (lhs.getOrder() == 1 && lhs.isColVector() && rhs.getOrder() == 1 && !rhs.isColVector()) { + order = 2; + } + // Scalar product + else if (lhs.getOrder() == 0) { + order = rhs.getOrder(); + isColVec = rhs.isColVector(); + } + else if (rhs.getOrder() == 0) { + order = lhs.getOrder(); + isColVec = lhs.isColVector(); + } + else { + taco_uassert(lhs.getOrder() != rhs.getOrder()) << "LHS (" << lhs.getOrder() << "," << lhs.isColVector() + << ") and RHS (" << rhs.getOrder() << "," << rhs.isColVector() + << ") order/vector type do not match " + "for linear algebra matrix multiply" << endl; + } + return new LinalgMatMulNode(lhs, rhs, order, isColVec); +} + +LinalgExpr operator/(const LinalgExpr &lhs, const LinalgExpr &rhs) { + checkCompatibleShape(lhs, rhs); + if (lhs.getOrder() == 0) + return new LinalgDivNode(lhs, rhs, rhs.getOrder(), rhs.isColVector()); + return new LinalgDivNode(lhs, rhs, lhs.getOrder(), lhs.isColVector()); +} + +LinalgExpr elemMul(const LinalgExpr &lhs, const LinalgExpr &rhs) { + checkCompatibleShape(lhs, rhs); + if (lhs.getOrder() == 0) + return new LinalgElemMulNode(lhs, rhs, rhs.getOrder(), rhs.isColVector()); + return new LinalgElemMulNode(lhs, rhs, lhs.getOrder(), lhs.isColVector()); +} + +LinalgExpr transpose(const LinalgExpr &lhs) { + return new LinalgTransposeNode(lhs, !lhs.isColVector()); +} + +// class LinalgStmt +LinalgStmt::LinalgStmt() : util::IntrusivePtr(nullptr) { +} + +LinalgStmt::LinalgStmt(const LinalgStmtNode* n) + : util::IntrusivePtr(n) { +} + +void LinalgStmt::accept(LinalgStmtVisitorStrict *v) const { + ptr->accept(v); +} + + +// class LinalgAssignment +LinalgAssignment::LinalgAssignment(const LinalgAssignmentNode* n) : LinalgStmt(n) { +} + +LinalgAssignment::LinalgAssignment(TensorVar lhs, LinalgExpr rhs) + : LinalgAssignment(new LinalgAssignmentNode(lhs, rhs)) { +} + +TensorVar LinalgAssignment::getLhs() const { + return getNode(*this)->lhs; +} + +LinalgExpr LinalgAssignment::getRhs() const { + return getNode(*this)->rhs; +} + +} // namespace taco diff --git a/src/linalg_notation/linalg_notation_nodes_abstract.cpp b/src/linalg_notation/linalg_notation_nodes_abstract.cpp new file mode 100644 index 000000000..52cf9ba3e --- /dev/null +++ b/src/linalg_notation/linalg_notation_nodes_abstract.cpp @@ -0,0 +1,42 @@ +#include "taco/linalg_notation/linalg_notation_nodes_abstract.h" + +using namespace std; + +namespace taco { + +LinalgExprNode::LinalgExprNode(Datatype type) + : dataType(type), order(0), isColVec(false) { +} + +LinalgExprNode::LinalgExprNode(Datatype type, int order) + : dataType(type), order(order) { + if (order != 1) + isColVec = false; + else + isColVec = true; +} + +LinalgExprNode::LinalgExprNode(Datatype type, int order, bool isColVec) + : dataType(type), order(order) { + if (order != 1) + this->isColVec = false; + else + this->isColVec = isColVec; +} + +Datatype LinalgExprNode::getDataType() const { + return dataType; +} + +int LinalgExprNode::getOrder() const { + return order; +} + +bool LinalgExprNode::isColVector() const { + return isColVec; +} + +void LinalgExprNode::setColVector(bool val) { + isColVec = val; +} +} diff --git a/src/linalg_notation/linalg_notation_printer.cpp b/src/linalg_notation/linalg_notation_printer.cpp new file mode 100644 index 000000000..0c8041895 --- /dev/null +++ b/src/linalg_notation/linalg_notation_printer.cpp @@ -0,0 +1,163 @@ +#include "taco/linalg_notation/linalg_notation_printer.h" +#include "taco/linalg_notation/linalg_notation_nodes.h" + +using namespace std; + +namespace taco { + +LinalgNotationPrinter::LinalgNotationPrinter(std::ostream& os) : os(os) { +} + +void LinalgNotationPrinter::print(const LinalgExpr& expr) { + parentPrecedence = Precedence::TOP; + expr.accept(this); +} + +void LinalgNotationPrinter::print(const LinalgStmt& expr) { + parentPrecedence = Precedence::TOP; + expr.accept(this); +} + +void LinalgNotationPrinter::visit(const LinalgVarNode* op) { + os << op->tensorVar.getName(); +} + +void LinalgNotationPrinter::visit(const LinalgTensorBaseNode* op) { + os << op->tensorBase->getName(); +} + +void LinalgNotationPrinter::visit(const LinalgLiteralNode* op) { + switch (op->getDataType().getKind()) { + case Datatype::Bool: + os << op->getVal(); + break; + case Datatype::UInt8: + os << op->getVal(); + break; + case Datatype::UInt16: + os << op->getVal(); + break; + case Datatype::UInt32: + os << op->getVal(); + break; + case Datatype::UInt64: + os << op->getVal(); + break; + case Datatype::UInt128: + taco_not_supported_yet; + break; + case Datatype::Int8: + os << op->getVal(); + break; + case Datatype::Int16: + os << op->getVal(); + break; + case Datatype::Int32: + os << op->getVal(); + break; + case Datatype::Int64: + os << op->getVal(); + break; + case Datatype::Int128: + taco_not_supported_yet; + break; + case Datatype::Float32: + os << op->getVal(); + break; + case Datatype::Float64: + os << op->getVal(); + break; + case Datatype::Complex64: + os << op->getVal>(); + break; + case Datatype::Complex128: + os << op->getVal>(); + break; + case Datatype::Undefined: + break; + } +} + +void LinalgNotationPrinter::visit(const LinalgNegNode* op) { + Precedence precedence = Precedence::NEG; + bool parenthesize = precedence > parentPrecedence; + parentPrecedence = precedence; + os << "-"; + if (parenthesize) { + os << "("; + } + op->a.accept(this); + if (parenthesize) { + os << ")"; + } +} + +void LinalgNotationPrinter::visit(const LinalgTransposeNode* op) { + Precedence precedence = Precedence::TRANSPOSE; + bool parenthesize = precedence > parentPrecedence; + parentPrecedence = precedence; + if (parenthesize) { + os << "("; + } + op->a.accept(this); + if (parenthesize) { + os << ")"; + } + os << "^T"; +} + +template +void LinalgNotationPrinter::visitBinary(Node op, Precedence precedence) { + bool parenthesize = precedence > parentPrecedence; + if (parenthesize) { + os << "("; + } + parentPrecedence = precedence; + op->a.accept(this); + os << " " << op->getOperatorString() << " "; + parentPrecedence = precedence; + op->b.accept(this); + if (parenthesize) { + os << ")"; + } +} + +void LinalgNotationPrinter::visit(const LinalgAddNode* op) { + visitBinary(op, Precedence::ADD); +} + +void LinalgNotationPrinter::visit(const LinalgSubNode* op) { + visitBinary(op, Precedence::SUB); +} + +void LinalgNotationPrinter::visit(const LinalgMatMulNode* op) { + visitBinary(op, Precedence::MATMUL); +} + +void LinalgNotationPrinter::visit(const LinalgElemMulNode* op) { + visitBinary(op, Precedence::ELEMMUL); +} + +void LinalgNotationPrinter::visit(const LinalgDivNode* op) { + visitBinary(op, Precedence::DIV); +} + +template +static inline void acceptJoin(LinalgNotationPrinter* printer, + std::ostream& stream, const std::vector& nodes, + std::string sep) { + if (nodes.size() > 0) { + nodes[0].accept(printer); + } + for (size_t i = 1; i < nodes.size(); ++i) { + stream << sep; + nodes[i].accept(printer); + } +} + +void LinalgNotationPrinter::visit(const LinalgAssignmentNode* op) { + os << op->lhs.getName() << " " << "= "; + op->rhs.accept(this); +} + +} diff --git a/src/linalg_notation/linalg_notation_visitor.cpp b/src/linalg_notation/linalg_notation_visitor.cpp new file mode 100644 index 000000000..c71001d06 --- /dev/null +++ b/src/linalg_notation/linalg_notation_visitor.cpp @@ -0,0 +1,17 @@ + +#include "taco/linalg_notation/linalg_notation_visitor.h" +#include "taco/linalg_notation/linalg_notation_nodes.h" + +using namespace std; + +namespace taco { + +void LinalgExprVisitorStrict::visit(const LinalgExpr &expr) { + expr.accept(this); +} + +void LinalgStmtVisitorStrict::visit(const LinalgStmt& stmt) { + stmt.accept(this); +} + +} diff --git a/src/linalg_notation/linalg_rewriter.cpp b/src/linalg_notation/linalg_rewriter.cpp new file mode 100644 index 000000000..d6f6c6a82 --- /dev/null +++ b/src/linalg_notation/linalg_rewriter.cpp @@ -0,0 +1,242 @@ +#include "taco/linalg_notation/linalg_rewriter.h" + +#include "taco/linalg_notation/linalg_notation_nodes.h" +#include "taco/index_notation/index_notation_nodes.h" + +using namespace std; +using namespace taco; + +class LinalgRewriter::Visitor : public LinalgNotationVisitorStrict { +public: + Visitor(LinalgRewriter* rewriter ) : rewriter(rewriter) {} + IndexExpr rewrite(LinalgExpr linalgExpr) { + this->expr = IndexExpr(); + LinalgNotationVisitorStrict::visit(linalgExpr); + return this->expr; + } + IndexStmt rewrite(LinalgStmt linalgStmt) { + this->stmt = IndexStmt(); + LinalgNotationVisitorStrict::visit(linalgStmt); + return this->stmt; + } +private: + LinalgRewriter* rewriter; + IndexExpr expr; + IndexStmt stmt; + using LinalgNotationVisitorStrict::visit; + void visit(const LinalgSubNode* node) { expr = rewriter->rewriteSub(node); } + void visit(const LinalgAddNode* node) { expr = rewriter->rewriteAdd(node); } + void visit(const LinalgElemMulNode* node) { expr = rewriter->rewriteElemMul(node); } + void visit(const LinalgMatMulNode* node) { expr = rewriter->rewriteMatMul(node); } + void visit(const LinalgDivNode* node) { expr = rewriter->rewriteDiv(node); } + void visit(const LinalgNegNode* node) { expr = rewriter->rewriteNeg(node); } + void visit(const LinalgTransposeNode* node) { expr = rewriter->rewriteTranspose(node); } + void visit(const LinalgLiteralNode* node) { expr = rewriter->rewriteLiteral(node); } + void visit(const LinalgVarNode* node) { expr = rewriter->rewriteVar(node); } + void visit(const LinalgTensorBaseNode* node) { expr = rewriter->rewriteTensorBase(node); } + void visit(const LinalgAssignmentNode* node) { stmt = rewriter->rewriteAssignment(node); } + +}; + +LinalgRewriter::LinalgRewriter() : visitor(new Visitor(this)) { +} + +IndexExpr LinalgRewriter::rewriteSub(const LinalgSubNode* sub) { + auto originalIndices = liveIndices; + IndexExpr indexA = rewrite(sub->a); + liveIndices = originalIndices; + IndexExpr indexB = rewrite(sub->b); + return new SubNode(indexA, indexB); +} + +IndexExpr LinalgRewriter::rewriteAdd(const LinalgAddNode* add) { + auto originalIndices = liveIndices; + IndexExpr indexA = rewrite(add->a); + liveIndices = originalIndices; + IndexExpr indexB = rewrite(add->b); + return new AddNode(indexA, indexB); +} + +IndexExpr LinalgRewriter::rewriteElemMul(const LinalgElemMulNode* elemMul) { + auto originalIndices = liveIndices; + IndexExpr indexA = rewrite(elemMul->a); + liveIndices = originalIndices; + IndexExpr indexB = rewrite(elemMul->b); + return new MulNode(indexA, indexB); +} + +IndexExpr LinalgRewriter::rewriteMatMul(const LinalgMatMulNode *matMul) { + IndexVar index = getUniqueIndex(); + vector indicesA; + vector indicesB; + if (matMul->a.getOrder() == 2 && matMul->b.getOrder() == 2) { + indicesA = {liveIndices[0], index}; + indicesB = {index, liveIndices[1]}; + } + else if (matMul->a.getOrder() == 1 && matMul->b.getOrder() == 2) { + indicesA = {index}; + indicesB = {index, liveIndices[0]}; + } + else if (matMul->a.getOrder() == 2 && matMul->b.getOrder() == 1) { + indicesA = {liveIndices[0], index}; + indicesB = {index}; + } + else if (matMul->a.getOrder() == 1 && matMul->a.isColVector() && matMul->b.getOrder() == 1) { + indicesA = {liveIndices[0]}; + indicesB = {liveIndices[1]}; + } else if (matMul->a.getOrder() == 0) { + indicesA = {}; + indicesB = liveIndices; + } else if (matMul->b.getOrder() == 0) { + indicesA = liveIndices; + indicesB = {}; + } else { + indicesA = {index}; + indicesB = {index}; + } + liveIndices = indicesA; + IndexExpr indexA = rewrite(matMul->a); + liveIndices = indicesB; + IndexExpr indexB = rewrite(matMul->b); + return new MulNode(indexA, indexB); +} + +IndexExpr LinalgRewriter::rewriteDiv(const LinalgDivNode *div) { + auto originalIndices = liveIndices; + IndexExpr indexA = rewrite(div->a); + liveIndices = originalIndices; + IndexExpr indexB = rewrite(div->b); + return new DivNode(indexA, indexB); +} + +IndexExpr LinalgRewriter::rewriteNeg(const LinalgNegNode *neg) { + IndexExpr index = rewrite(neg->a); + return new NegNode(index); +} + +IndexExpr LinalgRewriter::rewriteTranspose(const LinalgTransposeNode *transpose) { + if (transpose->a.getOrder() == 2) { + liveIndices = {liveIndices[1], liveIndices[0]}; + return rewrite(transpose->a); + } + else if (transpose->a.getOrder() == 1) { + liveIndices = {liveIndices[0]}; + return rewrite(transpose->a); + } + liveIndices = {}; + return rewrite(transpose->a); +} + +IndexExpr LinalgRewriter::rewriteLiteral(const LinalgLiteralNode *lit) { + LiteralNode* value; + switch (lit->getDataType().getKind()) { + case Datatype::Bool: + value = new LiteralNode(lit->getVal()); + break; + case Datatype::UInt8: + value = new LiteralNode(lit->getVal()); + break; + case Datatype::UInt16: + value = new LiteralNode(lit->getVal()); + break; + case Datatype::UInt32: + value = new LiteralNode(lit->getVal()); + break; + case Datatype::UInt64: + value = new LiteralNode(lit->getVal()); + break; + case Datatype::UInt128: + taco_not_supported_yet; + break; + case Datatype::Int8: + value = new LiteralNode(lit->getVal()); + break; + case Datatype::Int16: + value = new LiteralNode(lit->getVal()); + break; + case Datatype::Int32: + value = new LiteralNode(lit->getVal()); + break; + case Datatype::Int64: + value = new LiteralNode(lit->getVal()); + break; + case Datatype::Int128: + taco_not_supported_yet; + break; + case Datatype::Float32: + value = new LiteralNode(lit->getVal()); + break; + case Datatype::Float64: + value = new LiteralNode(lit->getVal()); + break; + case Datatype::Complex64: + value = new LiteralNode(lit->getVal>()); + break; + case Datatype::Complex128: + value = new LiteralNode(lit->getVal>()); + break; + case Datatype::Undefined: + taco_uerror << "unsupported Datatype"; + break; + } + return value; +} + +IndexExpr LinalgRewriter::rewriteVar(const LinalgVarNode *var) { + return new AccessNode(var->tensorVar, liveIndices); +} + +IndexExpr LinalgRewriter::rewriteTensorBase(const LinalgTensorBaseNode *node) { + return node->tensorBase->operator()(liveIndices); +} + +IndexVar LinalgRewriter::getUniqueIndex() { + int loc = idxcount % indexVarNameList.size(); + int num = idxcount / indexVarNameList.size(); + + string indexVarName; + if (num == 0) + indexVarName = indexVarNameList.at(loc); + else + indexVarName = indexVarNameList.at(loc) + to_string(num); + + idxcount += 1; + IndexVar result(indexVarName); + return result; +} + +IndexStmt LinalgRewriter::rewriteAssignment(const LinalgAssignmentNode *node) { + return IndexStmt(); +} + +void LinalgRewriter::setLiveIndices(std::vector indices) { + liveIndices = indices; +} + +IndexExpr LinalgRewriter::rewrite(LinalgExpr linalgExpr) { + return visitor->rewrite(linalgExpr); +} + +IndexStmt LinalgRewriter::rewrite(LinalgBase linalgBase) { + TensorVar tensor = linalgBase.getAssignment().getLhs(); + + vector indices = {}; + if (tensor.getOrder() == 1) { + indices.push_back(getUniqueIndex()); + } else if (tensor.getOrder() == 2) { + indices.push_back(getUniqueIndex()); + indices.push_back(getUniqueIndex()); + } + + Access lhs = Access(tensor, indices); + + liveIndices = indices; + auto rhs = rewrite(linalgBase.getAssignment().getRhs()); + + if(linalgBase.tensorBase != nullptr) { + linalgBase.tensorBase->operator()(indices) = rhs; + } + + Assignment indexAssign = Assignment(lhs, rhs); + return indexAssign; +} diff --git a/src/lower/expr_tools.cpp b/src/lower/expr_tools.cpp index ded5c53dd..e6a5a9b8c 100644 --- a/src/lower/expr_tools.cpp +++ b/src/lower/expr_tools.cpp @@ -215,6 +215,10 @@ class SubExprVisitor : public IndexExprVisitorStrict { subExpr = unarySubExpr(op); } + void visit(const TransposeNode* op) { + subExpr = unarySubExpr(op); + } + void visit(const SqrtNode* op) { subExpr = unarySubExpr(op); } diff --git a/src/lower/lowerer_impl.cpp b/src/lower/lowerer_impl.cpp index 17a4dab3b..a66473ced 100644 --- a/src/lower/lowerer_impl.cpp +++ b/src/lower/lowerer_impl.cpp @@ -1700,6 +1700,7 @@ Stmt LowererImpl::lowerSuchThat(SuchThat suchThat) { Expr LowererImpl::lowerAccess(Access access) { + TensorVar var = access.getTensorVar(); if (isScalar(var.getType())) { diff --git a/src/parser/lexer.cpp b/src/parser/lexer.cpp index a490840a8..8b0913dd3 100644 --- a/src/parser/lexer.cpp +++ b/src/parser/lexer.cpp @@ -31,6 +31,10 @@ Token Lexer::getToken() { lastChar = getNextChar(); return Token::complex_scalar; } + if (identifier == "transpose") + return Token::transpose; + if (identifier == "elemMul") + return Token::elemMul; return Token::identifier; } if(isdigit(lastChar)) { @@ -90,6 +94,13 @@ Token Lexer::getToken() { case EOF: token = Token::eot; break; + case '^': + lastChar = getNextChar(); + if (lastChar == 'T') + token = Token::caretT; + else + token = Token::error; + break; default: token = Token::error; break; @@ -161,6 +172,15 @@ std::string Lexer::tokenString(const Token& token) { case Token::error: str = "error"; break; + case Token::caretT: + str = "^T"; + break; + case Token::elemMul: + str = "elemMul"; + break; + case Token::transpose: + str = "transpose"; + break; case Token::eot: default: taco_ierror; diff --git a/src/parser/linalg_parser.cpp b/src/parser/linalg_parser.cpp new file mode 100644 index 000000000..42b4a20fe --- /dev/null +++ b/src/parser/linalg_parser.cpp @@ -0,0 +1,385 @@ +#include "taco/parser/linalg_parser.h" +#include "taco/parser/parser.h" + +#include + +#include "taco/parser/lexer.h" +#include "taco/tensor.h" +#include "taco/format.h" + +#include "taco/linalg_notation/linalg_notation.h" +#include "taco/linalg_notation/linalg_notation_nodes.h" +#include "taco/linalg.h" + +#include "taco/util/collections.h" + +using namespace std; + +namespace taco { +namespace parser { + +struct LinalgParser::Content { + /// Tensor formats + map formats; + map dataTypes; + + /// Tensor dimensions + map> tensorDimensions; + map indexVarDimensions; + + int defaultDimension; + + /// Track which modes have default values, so that we can change them + /// to values inferred from other tensors (read from files). + set> modesWithDefaults; + + Lexer lexer; + Token currentToken; + bool parsingLhs = false; + + map indexVars; + std::map linalgShapes; + std::map linalgVecShapes; + + TensorBase resultTensor; + map tensors; +}; + + LinalgParser::LinalgParser(string expression, const map& formats, + const map& dataTypes, + const map>& tensorDimensions, + const std::map& tensors, + const std::map& linalgShapes, const std::map& linalgVecShapes, + int defaultDimension) + : content(new LinalgParser::Content) { + content->lexer = Lexer(expression); + content->formats = formats; + content->tensorDimensions = tensorDimensions; + content->defaultDimension = defaultDimension; + content->tensors = tensors; + content->dataTypes = dataTypes; + + content->linalgShapes = linalgShapes; + content->linalgVecShapes = linalgVecShapes; + idxcount = 0; + + nextToken(); + } + +void LinalgParser::parse() { + LinalgBase linalgBase = parseAssign(); + linalgBase.rewrite(); + content->resultTensor = *linalgBase.tensorBase; +} + +const TensorBase& LinalgParser::getResultTensor() const { + return content->resultTensor; +} + +LinalgBase LinalgParser::parseAssign() { + content->parsingLhs = true; + LinalgBase lhs = parseVar(); + const TensorVar var = lhs.tensorBase->getTensorVar(); + content->parsingLhs = false; + + consume(Token::eq); + LinalgExpr rhs = parseExpr(); + lhs = rhs; + + return lhs; +} + +LinalgExpr LinalgParser::parseExpr() { + LinalgExpr expr = parseTerm(); + while (content->currentToken == Token::add || + content->currentToken == Token::sub) { + switch (content->currentToken) { + case Token::add: + consume(Token::add); + expr = expr + parseTerm(); + break; + case Token::sub: + consume(Token::sub); + expr = expr - parseTerm(); + break; + default: + taco_unreachable; + } + } + return expr; +} + +LinalgExpr LinalgParser::parseTerm() { + + + LinalgExpr term = parseFactor(); + while (content->currentToken == Token::mul || + content->currentToken == Token::div) { + switch (content->currentToken) { + case Token::mul: { + consume(Token::mul); + term = term * parseFactor(); + break; + } + case Token::div: { + consume(Token::div); + term = term / parseFactor(); + break; + } + + default: + taco_unreachable; + } + } + return term; +} + +LinalgExpr LinalgParser::parseFactor() { + switch (content->currentToken) { + case Token::lparen: { + consume(Token::lparen); + LinalgExpr factor = parseExpr(); + consume(Token::rparen); + return factor; + } + case Token::sub: { + consume(Token::sub); + return -parseFactor(); + } + case Token::transpose: { + consume(Token::transpose); + consume(Token::lparen); + LinalgExpr factor = parseExpr(); + consume(Token::rparen); + return transpose(factor); + } + default: + break; + } + + + if (content->currentToken == Token::caretT) { + LinalgExpr factor = parseFactor(); + consume(Token::caretT); + return transpose(factor); + } + + LinalgExpr final = parseFinal(); + return final; +} + +LinalgExpr LinalgParser::parseFinal() { + std::istringstream value (content->lexer.getIdentifier()); + switch (content->currentToken) { + case Token::complex_scalar: + { + consume(Token::complex_scalar); + std::complex complex_value; + value >> complex_value; + return LinalgExpr(complex_value); + } + case Token::int_scalar: + { + consume(Token::int_scalar); + int64_t int_value; + value >> int_value; + return LinalgExpr(int_value); + } + case Token::uint_scalar: + { + consume(Token::uint_scalar); + uint64_t uint_value; + value >> uint_value; + return LinalgExpr(uint_value); + } + case Token::float_scalar: + { + consume(Token::float_scalar); + double float_value; + value >> float_value; + return LinalgExpr(float_value); + } + default: + return parseCall(); + } +} + +LinalgExpr LinalgParser::parseCall() { + switch (content->currentToken) { + case Token::elemMul: { + consume(Token::elemMul); + consume(Token::lparen); + LinalgExpr term = parseExpr(); + consume(Token::comma); + term = elemMul(term, parseExpr()); + consume(Token::rparen); + return term; + } + case Token::transpose: { + consume(Token::transpose); + consume(Token::lparen); + LinalgExpr term = parseExpr(); + consume(Token::rparen); + return transpose(term); + } + default: + break; + } + return parseVar(); +} + +LinalgBase LinalgParser::parseVar() { + + if(content->currentToken != Token::identifier) { + cout << currentTokenString(); + throw ParseError("Expected linalg name"); + } + string tensorName = content->lexer.getIdentifier(); + consume(Token::identifier); + names.push_back(tensorName); + + size_t order = 0; + bool isColVec = false; + // LinalgParser: By default assume capital variables are Matrices and lower case variables are vectors + if (isupper(tensorName.at(0))) { + order = 2; + } else { + order = 1; + isColVec = true; + } + + if (content->linalgShapes.find(tensorName) != content->linalgShapes.end()) { + if (content->formats.find(tensorName) != content->formats.end()) { + taco_uassert(content->linalgShapes.at(tensorName) == content->formats.at(tensorName).getOrder()) + << "Linalg shape and tensor format must match" << endl; + + } + if (content->tensorDimensions.find(tensorName) != content->tensorDimensions.end()) + taco_uassert(content->linalgShapes.at(tensorName) == (int)content->tensorDimensions.at(tensorName).size()) + << "Linalg shape and the number of tensor dimensions must match" << endl; + + order = content->linalgShapes.at(tensorName); + isColVec = content->linalgVecShapes.at(tensorName); + } + else if (content->formats.find(tensorName) != content->formats.end()) { + + if (content->tensorDimensions.find(tensorName) != content->tensorDimensions.end()) + taco_uassert(content->formats.at(tensorName).getOrder() == (int)content->tensorDimensions.at(tensorName).size()) + << "Tensor format and tensor dimensions must match" << endl; + + order = content->formats.at(tensorName).getOrder(); + } else { + if (content->tensorDimensions.find(tensorName) != content->tensorDimensions.end()) + order = content->tensorDimensions.at(tensorName).size(); + } + + Format format; + if (util::contains(content->formats, tensorName)) { + format = content->formats.at(tensorName); + } + else { + format = Format(std::vector(order, Dense)); + } + + TensorBase tensor; + if (util::contains(content->tensors, tensorName)) { + tensor = content->tensors.at(tensorName); + } + else { + vector tensorDimensions(order); + vector modesWithDefaults(order, false); + for (size_t i = 0; i < tensorDimensions.size(); i++) { + if (util::contains(content->tensorDimensions, tensorName)) { + tensorDimensions[i] = content->tensorDimensions.at(tensorName)[i]; + } + else { + tensorDimensions[i] = content->defaultDimension; + modesWithDefaults[i] = true; + } + } + Datatype dataType = Float(); + if (util::contains(content->dataTypes, tensorName)) { + dataType = content->dataTypes.at(tensorName); + } + tensor = TensorBase(tensorName,dataType,tensorDimensions,format); + for (size_t i = 0; i < tensorDimensions.size(); i++) { + if (modesWithDefaults[i]) { + content->modesWithDefaults.insert({tensor.getTensorVar(), i}); + } + } + + content->tensors.insert({tensorName,tensor}); + } + LinalgBase resultLinalg(tensor.getName(), tensor.getTensorVar().getType(), tensor.getComponentType(), + tensor.getDimensions(), tensor.getFormat(), isColVec); + return resultLinalg; + //return LinalgBase(tensor.getName(), tensor.getComponentType(), tensor.getFormat() ); +} + +vector LinalgParser::getUniqueIndices(size_t order) { + vector result; + for (int i = idxcount; i < (idxcount + (int)order); i++) { + string name = "i" + to_string(i); + IndexVar indexVar = getIndexVar(name); + result.push_back(indexVar); + } + idxcount += order; + return result; +} + +IndexVar LinalgParser::getIndexVar(string name) const { + taco_iassert(name != ""); + if (!hasIndexVar(name)) { + IndexVar var(name); + content->indexVars.insert({name, var}); + + // tensorDimensions can also store index var dimensions + if (util::contains(content->tensorDimensions, name)) { + content->indexVarDimensions.insert({var, content->tensorDimensions.at(name)[0]}); + } + } + return content->indexVars.at(name); +} + +bool LinalgParser::hasIndexVar(std::string name) const { + return util::contains(content->indexVars, name); +} + +void LinalgParser::consume(Token expected) { + if(content->currentToken != expected) { + string error = "Expected \'" + content->lexer.tokenString(expected) + + "\' but got \'" + currentTokenString() + "\'"; + throw ParseError(error); + } + nextToken(); +} + +const std::map& LinalgParser::getTensors() const { + return content->tensors; +} + +// FIXME: Remove this redundancy and try to add it to abstract parser class... +void LinalgParser::nextToken() { + content->currentToken = content->lexer.getToken(); +} + +string LinalgParser::currentTokenString() { + return (content->currentToken == Token::identifier) + ? content->lexer.getIdentifier() + : content->lexer.tokenString(content->currentToken); +} + +const TensorBase& LinalgParser::getTensor(string name) const { + taco_iassert(name != ""); + if (!hasTensor(name)) { + taco_uerror << "Parser error: Tensor name " << name << + " not found in expression" << endl; + } + return content->tensors.at(name); +} + +bool LinalgParser::hasTensor(std::string name) const { + return util::contains(content->tensors, name); +} +} // namespace parser +} // namespace taco diff --git a/src/parser/parser.cpp b/src/parser/parser.cpp index 472914c60..53df97415 100644 --- a/src/parser/parser.cpp +++ b/src/parser/parser.cpp @@ -54,6 +54,7 @@ Parser::Parser(string expression, const map& formats, content->defaultDimension = defaultDimension; content->tensors = tensors; content->dataTypes = dataTypes; + nextToken(); } diff --git a/test/tests-linalg.cpp b/test/tests-linalg.cpp new file mode 100644 index 000000000..1e1cc613e --- /dev/null +++ b/test/tests-linalg.cpp @@ -0,0 +1,588 @@ +#include "test.h" + +#include "taco/linalg.h" + +using namespace taco; + +TEST(linalg, matmul_index_expr) { + Tensor B("B", {2,2}); + Matrix C("C", 2, 2, dense, dense); + Matrix A("A", 2, 2, dense, dense); + + B(0,0) = 2; + B(1,1) = 1; + B(0,1) = 2; + C(0,0) = 2; + C(1,1) = 2; + + IndexVar i, j, k; + A(i,j) = B(i,k) * C(k,j); + + ASSERT_EQ((double) A(0,0), 4); + ASSERT_EQ((double) A(0,1), 4); + ASSERT_EQ((double) A(1,0), 0); + ASSERT_EQ((double) A(1,1), 2); +} + +TEST(linalg, vecmat_mul_index_expr) { + Vector x("x", 2, dense, false); + Vector b("b", 2, dense, false); + Matrix A("A", 2, 2, dense, dense); + + b(0) = 3; + b(1) = -2; + + A(0,0) = 5; + A(0,1) = 2; + A(1,0) = -1; + + IndexVar i, j; + x(i) = b(j) * A(j,i); + + ASSERT_EQ((double) x(0), 17); + ASSERT_EQ((double) x(1), 6); +} + + +TEST(linalg, inner_mul_index_expr) { + Scalar x("x"); + Vector b("b", 2, dense, false); + Vector a("a", 2, dense, true); + + b(0) = 2; + b(1) = 3; + + a(0) = -3; + a(1) = 5; + + IndexVar i; + x = b(i) * a(i); + + ASSERT_EQ((double) x, 9); +} + +TEST(linalg, matmul) { + Matrix B("B", 2, 2, dense, dense); + Matrix C("C", 2, 2, sparse, sparse); + Matrix A("A", 2, 2, dense, dense); + + B(0,0) = 2; + B(1,1) = 1; + B(0,1) = 2; + C(0,0) = 2; + C(1,1) = 2; + + A = B * C; + + ASSERT_EQ((double) A(0,0), 4); + ASSERT_EQ((double) A(0,1), 4); + ASSERT_EQ((double) A(1,0), 0); + ASSERT_EQ((double) A(1,1), 2); + + // Equivalent Tensor API computation + Tensor tB("B", {2, 2}, dense); + Tensor tC("C", {2, 2}, dense); + Tensor tA("A", {2, 2}, dense); + + tB(0,0) = 2; + tB(1,1) = 1; + tB(0,1) = 2; + tC(0,0) = 2; + tC(1,1) = 2; + + IndexVar i,j,k; + tA(i,j) = tB(i,k) * tC(k,j); + + ASSERT_TENSOR_EQ(A,tA); +} + +TEST(linalg, matmat_add) { + Matrix B("B", 2, 2, dense, dense); + Matrix C("C", 2, 2, dense, dense); + Matrix A("A", 2, 2, dense, dense); + + B(0,0) = 1; + B(1,1) = 4; + + C(0,1) = 2; + C(1,0) = 3; + + A = B + C; + + ASSERT_EQ((double) A(0,0), 1); + ASSERT_EQ((double) A(0,1), 2); + ASSERT_EQ((double) A(1,0), 3); + ASSERT_EQ((double) A(1,1), 4); +} + +TEST(linalg, matvec_mul) { + Vector x("x", 2, dense); + Vector b("b", 2, dense); + Matrix A("A", 2, 2, dense, dense); + + b(0) = 2; + b(1) = 1; + + A(0,0) = 1; + A(0,1) = 3; + A(1,1) = 2; + + x = A*b; + + ASSERT_EQ((double) x(0), 5); + ASSERT_EQ((double) x(1), 2); +} + +TEST(linalg, vecmat_mul) { + Vector x("x", 2, dense, false); + Vector b("b", 2, dense, false); + Matrix A("A", 2, 2, dense, dense); + + b(0) = 3; + b(1) = -2; + + A(0,0) = 5; + A(0,1) = 2; + A(1,0) = -1; + + x = b * A; + + ASSERT_EQ((double) x(0), 17); + ASSERT_EQ((double) x(1), 6); +} + +TEST(linalg, inner_mul) { + Scalar x("x"); + Vector b("b", 2, dense, false); + Vector a("a", 2, dense, true); + + b(0) = 2; + b(1) = 3; + + a(0) = -3; + a(1) = 5; + + x = b * a; + + ASSERT_EQ((double) x, 9); +} + +TEST(linalg, outer_mul) { + Matrix X("X", 2, 2, dense, dense); + Vector b("b", 2, dense, false); + Vector a("a", 2, dense, true); + + b(0) = 2; + b(1) = 3; + + a(0) = -3; + a(1) = 5; + + X = a * b; + + // Tensor API equivalent + Tensor tX("X", {2, 2}, dense); + Tensor tb("b", {2}, dense); + Tensor ta("a", {2}, dense); + + tb(0) = 2; + tb(1) = 3; + + ta(0) = -3; + ta(1) = 5; + + IndexVar i,j; + tX(i,j) = a(i) * b(j); + + ASSERT_TENSOR_EQ(X,tX); +} + +TEST(linalg, rowvec_transpose) { + Vector b("b", 2, dense, true); + Matrix A("A", 2, 2, dense, dense); + Scalar a("a"); + + b(0) = 2; + b(1) = 5; + + A(0,0) = 1; + A(0,1) = 2; + A(1,1) = 4; + + a = transpose(transpose(b) * A * b); + + ASSERT_EQ((double) a, 124); +} + +TEST(linalg, compound_expr_elemmul_elemadd) { + Matrix A("A", 2, 2, dense, dense); + Matrix B("B", 2, 2, dense, dense); + Matrix C("C", 2, 2, dense, dense); + Matrix D("D", 2, 2, dense, dense); + + A(0,0) = 1; + A(0,1) = 2; + A(0,2) = 3; + + D(0,0) = 2; + D(0,1) = 3; + D(0,2) = 4; + + A = elemMul(B+C, D); + + // Tensor API equivalent + Tensor tA("A", {2,2}, dense); + Tensor tB("B", {2,2}, dense); + Tensor tC("C", {2,2}, dense); + Tensor tD("D", {2,2}, dense); + + tA(0,0) = 1; + tA(0,1) = 2; + tA(0,2) = 3; + + tD(0,0) = 2; + tD(0,1) = 3; + tD(0,2) = 4; + + IndexVar i,j; + tA(i,j) = (tB(i,j) + tC(i,j)) * tD(i,j); + + ASSERT_TENSOR_EQ(A,tA); +} + +TEST(linalg, compound_sparse_matmul_transpose_outer) { + Matrix A("A", 16, 16, dense, sparse); + Matrix B("B", 16, 16, dense, sparse); + Matrix C("C", 16, 16, dense, sparse); + Matrix D("D", 16, 16, dense, dense); + Vector e("e", 16, sparse); + Vector f("f", 16, sparse); + + A(0,0) = 1; + A(0,1) = 2; + A(0,2) = 3; + B(0,0) = 1; + B(1,1) = 2; + B(2,2) = 3; + C(0, 0) = 8; + D(0,0) = 2; + D(0,1) = 3; + D(0,2) = 4; + + e(0) = 43; + f(1) = 2; + A = ((B*C)*D) + transpose(e*transpose(f)); + + // Tensor API equivalent + Tensor tA("tA", {16,16}, {dense, sparse}); + Tensor tB("tB", {16,16}, {dense, sparse}); + Tensor tC("tC", {16,16}, {dense, sparse}); + Tensor tD("tD", {16,16}, dense); + Tensor te("te", {16}, {sparse}); + Tensor tf("tf", {16}, {sparse}); + tA(0,0) = 1; + tA(0,1) = 2; + tA(0,2) = 3; + tB(0,0) = 1; + tB(1,1) = 2; + tB(2,2) = 3; + tC(0, 0) = 8; + tD(0,0) = 2; + tD(0,1) = 3; + tD(0,2) = 4; + + te(0) = 43; + tf(1) = 2; + IndexVar i,j, k, l; + tA(i,j) = ((tB(i,k) * tC(k,l)) * tD(l,j)) + (te(j)*tf(i)); + + ASSERT_TENSOR_EQ(tA, A); +} + +TEST(linalg, compound_ATCA) { + Matrix A("A", 16, 16, dense, dense); + Matrix B("B", 16, 16, dense, dense); + Matrix C("C", 16, 16, dense, dense); + + for (int i = 0; i < 16; i++) { + for (int j = 0; j < 16; j++) { + C(i, j) = i*j; + } + } + + for (int i = 0; i < 16; i++) { + A(i, i) = i; + } + + B = (transpose(A) * C) * A; + + // Tensor API equivalent + Tensor tA("tA", {16,16}, {dense, dense}); + Tensor tB("tB", {16,16}, {dense, dense}); + Tensor tC("tC", {16,16}, {dense, dense}); + + for (int i = 0; i < 16; i++) { + for (int j = 0; j < 16; j++) { + tC(i, j) = i*j; + } + } + + for (int i = 0; i < 16; i++) { + tA(i, i) = i; + } + + IndexVar i, j, k, l; + tB(i, j) = (tA(k, i) * tC(k, l)) * tA(l, j); + + ASSERT_TENSOR_EQ(tB, B); +} + +TEST(linalg, print) { + Matrix A("A", 16, 16, dense, dense); + Matrix B("B", 16, 16, dense, dense); + Matrix C("C", 16, 16, dense, dense); + + for (int i = 0; i < 16; i++) { + for (int j = 0; j < 16; j++) { + C(i, j) = i*j; + } + } + + for (int i = 0; i < 16; i++) { + A(i, i) = i; + } + + B = (transpose(A) * C) * A; + + std::stringstream linalgBuffer; + linalgBuffer << B << endl; + for (int i = 1; i < 16; i++) { + for (int j = 1; j < 16; j++) { + linalgBuffer << i << ", " << j << ": "; + linalgBuffer << B(i,j) << endl; + } + } + linalgBuffer << B << endl; + + // Tensor API equivalent + Tensor tA("A", {16,16}, {dense, dense}); + Tensor tB("B", {16,16}, {dense, dense}); + Tensor tC("C", {16,16}, {dense, dense}); + + for (int i = 0; i < 16; i++) { + for (int j = 0; j < 16; j++) { + tC(i, j) = i*j; + } + } + + for (int i = 0; i < 16; i++) { + tA(i, i) = i; + } + IndexVar i, j, k, l; + tB(i, j) = (tA(k, i) * tC(k, l)) * tA(l, j); + + std::stringstream tensorBuffer; + tensorBuffer << tB << endl; + for (int i = 1; i < 16; i++) { + for (int j = 1; j < 16; j++) { + tensorBuffer << i << ", " << j << ": "; + tensorBuffer << tB(i,j) << endl; + } + } + tensorBuffer << tB << endl; + ASSERT_EQ(tensorBuffer.str(), linalgBuffer.str()); +} + +TEST(linalg, matrix_constructors) { + Matrix A("A"); + Matrix B("B", {2, 2}); + Matrix C("C", 2, 2, dense, dense); + Matrix D("D", 2, 2); + Matrix E("E", 2, 2, {dense, dense}); + Matrix F("F", {2, 2}, {dense, dense}); + + Vector a("a"); + Vector b("b", 2, false); + Vector c("c", 2, dense); + Vector d("d", 2, {dense}); +} + +TEST(linalg, reassignment) { + Matrix A("A", {2,2}); + Matrix B1("B1", {2,2}); + Matrix B2("B2", {2,2}); + Matrix B3("B3", {2,2}); + Matrix C1("C1", {2,2}); + Matrix C2("C2", {2,2}); + Matrix C3("C3", {2,2}); + + B1(0,0) = 1; + B1(0,1) = 2; + B1(1,0) = 3; + B1(1,1) = 4; + C1(0,0) = 1; + C1(0,1) = 2; + C1(1,0) = 3; + C1(1,1) = 4; + + A = B1 * C1; + + ASSERT_EQ((double) A(0,0), 7); + ASSERT_EQ((double) A(0,1), 10); + ASSERT_EQ((double) A(1,0), 15); + ASSERT_EQ((double) A(1,1), 22); + + B2(0,0) = 2; + B2(0,1) = 1; + B2(1,0) = 4; + B2(1,1) = 3; + C2(0,0) = 2; + C2(0,1) = 1; + C2(1,0) = 4; + C2(1,1) = 3; + + IndexVar i,j,k; + A(i,j) = B2(i,k) * C2(k,j); + + ASSERT_EQ((double) A(0,0), 8); + ASSERT_EQ((double) A(0,1), 5); + ASSERT_EQ((double) A(1,0), 20); + ASSERT_EQ((double) A(1,1), 13); + + B3(0,0) = 2; + B3(0,1) = 1; + B3(1,0) = 5; + B3(1,1) = 3; + C3(0,0) = 2; + C3(0,1) = 1; + C3(1,0) = 5; + C3(1,1) = 3; + + A = B3 * C3; + + ASSERT_EQ((double) A(0,0), 9); + ASSERT_EQ((double) A(0,1), 5); + ASSERT_EQ((double) A(1,0), 25); + ASSERT_EQ((double) A(1,1), 14); +} + +TEST(linalg, tensor_comparison) { + Matrix A("A", {2,2}); + Tensor B("B", {2,2}); + + A(0,0) = 1; + A(1,1) = 1; + + B(0,0) = 1; + B(1,1) = 1; + + ASSERT_TENSOR_EQ(A,B); + + Vector a("a", 2); + Tensor ta("ta", {2}); + + a(0) = 1; + ta(0) = 1; + + ASSERT_TENSOR_EQ(a,ta); +} + +TEST(linalg, scalar_assignment) { + Scalar x("x"); + Scalar y("y"); + Scalar z("z"); + x = 1; + y = x; + z = 1; + ASSERT_TENSOR_EQ(x,y); + ASSERT_TENSOR_EQ(x,z); +} + +TEST(linalg, scalar_coeff_vector) { + Scalar x("x"); + x = 2; + Vector y("y", 5); + for(int i=0;i<5;i++) { + y(i) = i; + } + + Vector z("z", 5); + for(int i=0;i<5;i++) { + z(i) = 2*i; + } + + Vector xy("xy", 5); + xy = x * y; + + ASSERT_TENSOR_EQ(xy,z); +} + +TEST(linalg, scalar_coeff_matrix) { + Scalar x("x"); + x = 2; + Matrix A("A", 5,5); + Matrix xA("xA", 5,5); + + A(0,0) = 1; + A(2,3) = 2; + A(4,0) = 3; + + xA = x * A; + + Matrix B("B", 5,5); + + B(0,0) = 2; + B(2,3) = 4; + B(4,0) = 6; + + ASSERT_TENSOR_EQ(xA,B); +} + +TEST(linalg, compound_scalar_expr) { + Matrix A("A", {3,3}); + Matrix B("B", {3,3}); + Vector x("x", 3); + Vector y("y", 3, false); + Scalar a("a"); + Scalar b("b"); + Scalar c("c"); + + a = 2; + b = 3; + c = 4; + B(0,0) = 2; + B(0,1) = 3; + B(1,2) = 4; + x(0) = 2; + x(1) = 5; + x(2) = 1; + y(0) = 4; + y(1) = 5; + y(2) = 1; + + A = a*B + b*x*y*c; + + // Tensor API equivalent + Tensor tA("tA", {3,3}); + Tensor tB("tB", {3,3}); + Tensor tx("tx", {3}); + Tensor ty("ty", {3}); + Tensor ta(2); + Tensor tb(3); + Tensor tc(4); + + tB(0,0) = 2; + tB(0,1) = 3; + tB(1,2) = 4; + tx(0) = 2; + tx(1) = 5; + tx(2) = 1; + ty(0) = 4; + ty(1) = 5; + ty(2) = 1; + + IndexVar i,j; + tA(i,j) = ta() * tB(i,j) + tb() * tx(i) * ty(j) * tc(); + + ASSERT_TENSOR_EQ(A,tA); +} diff --git a/tools/taco.cpp b/tools/taco.cpp index fcc654e08..f492b8ccf 100644 --- a/tools/taco.cpp +++ b/tools/taco.cpp @@ -11,6 +11,7 @@ #include "taco/error.h" #include "taco/parser/lexer.h" #include "taco/parser/parser.h" +#include "taco/parser/linalg_parser.h" #include "taco/parser/schedule_parser.h" #include "taco/storage/storage.h" #include "taco/ir/ir.h" @@ -193,6 +194,15 @@ static void printUsageInfo() { printFlag("nthreads", "Specify number of threads for parallel execution"); cout << endl; printFlag("prefix", "Specify a prefix for generated function names"); + cout << endl; + printFlag("linalg", "Specify if the input should be in Linear Algebra (not index) Notation"); + cout << endl; + printFlag("k=:,", + "[LINALG NOTATION ONLY -linalg] Specify the shape of the linear algebra var. " + "Specify the number of dimensions, shape (0, 1, or 2), and an optional is col vec" + "flag when order == 1 (1 or 0). " + "Examples: A:2, A:0, A:1,1, A:1,0"); + cout << endl; } static int reportError(string errorMessage, int errorCode) { @@ -212,7 +222,7 @@ static void printCommandLine(ostream& os, int argc, char* argv[]) { } } -static bool setSchedulingCommands(vector> scheduleCommands, parser::Parser& parser, IndexStmt& stmt) { +static bool setSchedulingCommands(vector> scheduleCommands, parser::AbstractParser& parser, IndexStmt& stmt) { auto findVar = [&stmt](string name) { ProvenanceGraph graph(stmt); for (auto v : graph.getAllIndexVars()) { @@ -493,6 +503,7 @@ int main(int argc, char* argv[]) { bool cuda = false; bool setSchedule = false; + bool linalg = false; ParallelSchedule sched = ParallelSchedule::Static; int chunkSize = 0; @@ -521,6 +532,9 @@ int main(int argc, char* argv[]) { string writeTimeFilename; vector declaredTensors; + map linalgShapes; + map linalgVecShapes; + vector kernelFilenames; vector> scheduleCommands; @@ -823,6 +837,33 @@ int main(int argc, char* argv[]) { else if ("-prefix" == argName) { prefix = argValue; } + else if ("-linalg" == argName) { + linalg = true; + } + else if ("-k" == argName) { + vector descriptor = util::split(argValue, ":"); + string tensorName = descriptor[0]; + vector shapes = util::split(descriptor[1], ","); + + int linalgShape = 0; + bool linalgVecShape = false; + if (shapes.size() == 1) { + linalgShape = std::stoi(shapes[0]); + taco_uassert(linalgShape >= 0 && linalgShape <= 2) << "Shape is not compatible with linalg notation" << endl; + if (linalgShape == 1) + linalgVecShape = true; + } else if (shapes.size() == 2) { + linalgShape = std::stoi(shapes[0]); + taco_uassert(linalgShape >= 0 && linalgShape <= 2) << "Shape is not compatible with linalg notation" << endl; + linalgVecShape = (bool) std::stoi(shapes[1]); + taco_uassert(linalgVecShape == 0 || linalgVecShape == 1) << "Vector type is not compatible with linalg notation" << endl; + if (linalgShape != 1 ) { + linalgVecShape = false; + } + } + linalgShapes.insert({tensorName, linalgShape}); + linalgVecShapes.insert({tensorName, linalgVecShape}); + } else { if (exprStr.size() != 0) { printUsageInfo(); @@ -872,17 +913,22 @@ int main(int argc, char* argv[]) { } TensorBase tensor; - parser::Parser parser(exprStr, formats, dataTypes, tensorsDimensions, loadedTensors, 42); + parser::AbstractParser *parser; + if (linalg) + parser = new parser::LinalgParser(exprStr, formats, dataTypes, tensorsDimensions, loadedTensors, linalgShapes, linalgVecShapes, 42); + else + parser = new parser::Parser(exprStr, formats, dataTypes, tensorsDimensions, loadedTensors, 42); + try { - parser.parse(); - tensor = parser.getResultTensor(); + parser->parse(); + tensor = parser->getResultTensor(); } catch (parser::ParseError& e) { return reportError(e.getMessage(), 6); } // Generate tensors for (auto& fills : tensorsFill) { - TensorBase tensor = parser.getTensor(fills.first); + TensorBase tensor = parser->getTensor(fills.first); util::fillTensor(tensor,fills.second); loadedTensors.insert({fills.first, tensor}); @@ -894,8 +940,8 @@ int main(int argc, char* argv[]) { // If all input tensors have been initialized then we should evaluate bool benchmark = true; - for (auto& tensor : parser.getTensors()) { - if (tensor.second == parser.getResultTensor()) { + for (auto& tensor : parser->getTensors()) { + if (tensor.second == parser->getResultTensor()) { continue; } if (!util::contains(loadedTensors, tensor.second.getName())) { @@ -915,7 +961,7 @@ int main(int argc, char* argv[]) { stmt = reorderLoopsTopologically(stmt); if (setSchedule) { - cuda |= setSchedulingCommands(scheduleCommands, parser, stmt); + cuda |= setSchedulingCommands(scheduleCommands, *parser, stmt); } else { stmt = insertTemporaries(stmt); @@ -980,8 +1026,8 @@ int main(int argc, char* argv[]) { // TODO: Replace this redundant parsing with just a call to set the expr try { - auto operands = parser.getTensors(); - operands.erase(parser.getResultTensor().getName()); + auto operands = parser->getTensors(); + operands.erase(parser->getResultTensor().getName()); parser::Parser parser2(exprStr, formats, dataTypes, tensorsDimensions, operands, 42); parser2.parse(); @@ -1252,7 +1298,7 @@ int main(int argc, char* argv[]) { write(outputFileName, FileType::tns, tensor); TensorBase paramTensor; for (const auto &fills : tensorsFill ) { - paramTensor = parser.getTensor(fills.first); + paramTensor = parser->getTensor(fills.first); outputFileName = outputDirectory + "/" + paramTensor.getName() + ".tns"; write(outputFileName, FileType::tns, paramTensor); }