Skip to content

Commit

Permalink
Merge pull request #2 from CBalaa/fegen-chh
Browse files Browse the repository at this point in the history
[FrontendGen]  Add fegen manager data structs and fix antlr4 dependency.
  • Loading branch information
CBalaa authored Jun 14, 2024
2 parents 73ac0a7 + 9918525 commit 329930b
Show file tree
Hide file tree
Showing 11 changed files with 1,485 additions and 195 deletions.
4 changes: 1 addition & 3 deletions examples/FrontendGen/.gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@
Toy.g4
MLIRToyVisitor.h

test/
2 changes: 1 addition & 1 deletion examples/FrontendGen/example.fegen
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ typedef struct {
parameters [list<Type> elementTypes]
}

Type Toy_Type = any<[tensor, struct]>;
Type Toy_Type = any<[Tensor, struct]>;

opdef constant {
arguments [attribute double numberAttr]
Expand Down
264 changes: 217 additions & 47 deletions frontend/FrontendGen/include/FegenManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,45 @@

#include "FegenParser.h"

#define FEGEN_PLACEHOLDER "Placeholder"
#define FEGEN_TYPE "Type"
#define FEGEN_TYPETEMPLATE "TypeTemplate"
#define FEGEN_INTEGER "Integer"
#define FEGEN_FLOATPOINT "FloatPoint"
#define FEGEN_CHAR "Char"
#define FEGEN_STRING "String"
#define FEGEN_VECTOR "Vector"
#define FEGEN_TENSOR "Tensor"
#define FEGEN_LIST "List"
#define FEGEN_OPTINAL "Optional"
#define FEGEN_ANY "Any"

namespace fegen {

class FegenType;
class FegenManager;

// binary operation

enum class FegenOperator {
OR,
AND,
EQUAL,
NOT_EQUAL,
LESS,
LESS_EQUAL,
GREATER,
GREATER_EQUAL,
ADD,
SUB,
MUL,
DIV,
MOD,
POWER,
NEG,
NOT
};

// user defined function
class FegenFunction {
private:
Expand Down Expand Up @@ -61,6 +95,8 @@ class FegenOperation {
~FegenOperation() = default;
};

class FegenTypeDefination;

class FegenType {
friend class FegenValue;

Expand All @@ -69,16 +105,27 @@ class FegenType {

private:
TypeKind kind;
std::string dialectName;
std::string typeName;
std::vector<FegenValue> parameters;
std::vector<FegenValue*> parameters;
FegenTypeDefination *typeDefine;
bool ifTemplate;

public:
FegenType(TypeKind kind, std::string dialectName, std::string typeName,
std::vector<FegenValue> parameters);
FegenType(TypeKind kind, std::vector<FegenValue*> parameters,
FegenTypeDefination *tyDef, bool isTemplate);
FegenType(const FegenType &);
FegenType(FegenType &&);

TypeKind getTypeKind();
void setTypeKind(TypeKind kind);
std::vector<FegenValue*> &getParameters();
void setParameters(std::vector<FegenValue*> &params);
FegenTypeDefination *getTypeDefination();
void setTypeDefination(FegenTypeDefination *tyDef);
std::string getTypeName();
bool isTemplate();
~FegenType();
// placeholder
static FegenType getPlaceHolder();
// Type
static FegenType getMetaType();

Expand All @@ -98,10 +145,10 @@ class FegenType {
static FegenType getBoolType();

// Integer<size>
static FegenType getIntegerType(FegenValue size);
static FegenType getIntegerType(FegenValue* size);

// FloatPoint<size>
static FegenType getFloatPointType(FegenValue size);
static FegenType getFloatPointType(FegenValue* size);

// char
static FegenType getCharType();
Expand All @@ -110,49 +157,159 @@ class FegenType {
static FegenType getStringType();

// Vector<size, elementType>
static FegenType getVectorType(FegenValue size, FegenValue elementType);
static FegenType getVectorType(FegenValue size, FegenType elementType);
static FegenType getVectorType(FegenValue* size, FegenValue* elementType);
static FegenType getVectorType(FegenValue* size, FegenType elementType);

// Tensor<shape, elementType>
static FegenType getTensorType(FegenValue shape, FegenValue elementType);
static FegenType getTensorType(FegenValue shape, FegenType elementType);
static FegenType getTensorType(FegenValue* shape, FegenValue* elementType);
static FegenType getTensorType(FegenValue* shape, FegenType elementType);

// List<elementType>
static FegenType getListType(FegenValue* elementType);
static FegenType getListType(FegenType elementType);

// Optional<elementType>
static FegenType getOptionalType(FegenValue* elementType);
static FegenType getOptionalType(FegenType elementType);

// Any<elementType1, elementType2, ...>
static FegenType getAnyType(std::vector<FegenValue*> elementTypes);
static FegenType getAnyType(std::vector<FegenType> elementTypes);

static FegenType getIntegerTemplate();
static FegenType getFloatPointTemplate();
static FegenType getVectorTemplate();
static FegenType getTensorTemplate();
static FegenType getListTemplate();
static FegenType getOptionalTemplate();
static FegenType getAnyTemplate();

static FegenType getInstanceType(FegenTypeDefination *typeDefination,
std::vector<FegenValue*> parameters);

static FegenType getTemplateType(FegenTypeDefination *typeDefination);
};

class FegenTypeDefination {
friend class FegenManager;

// static FegenType get(TypeKind kind, std::string dialectName,
// std::string typeName,
// std::vector<FegenValue> parameters);
private:
std::string dialectName;
std::string name;
std::vector<fegen::FegenValue*> parameters;
FegenParser::TypeDefinationDeclContext *ctx;
bool ifCustome;

~FegenType() = default;
public:
FegenTypeDefination(std::string dialectName, std::string name,
std::vector<fegen::FegenValue*> parameters,
FegenParser::TypeDefinationDeclContext *ctx,
bool ifCustome);
static FegenTypeDefination *get(std::string dialectName, std::string name,
std::vector<fegen::FegenValue*> parameters,
FegenParser::TypeDefinationDeclContext *ctx,
bool ifCustome = true);
std::string getDialectName();
std::string getName();
const std::vector<fegen::FegenValue*> &getParameters();
FegenParser::TypeDefinationDeclContext *getCtx();
bool isCustome();
};

class FegenLiteral {
/// @brief Represent right value, and pass by value.
class FegenRightValue {
friend class FegenType;
friend class FegenValue;
using literalType = std::variant<int, float, std::string, FegenType,
std::vector<FegenLiteral>>;

private:
literalType content;
public:
enum class LiteralKind {
MONOSTATE,
INT,
FLOAT,
STRING,
TYPE,
VECTOR,
EXPRESSION,
LEFT_VAR
};

struct Expression {
bool ifTerminal;
LiteralKind kind;
FegenType exprType;
Expression(bool, LiteralKind, FegenType&);
virtual ~Expression() = default;
virtual bool isTerminal();
virtual std::string toString() = 0;
LiteralKind getKind();
virtual std::any getContent() = 0;
};

struct ExpressionNode : public Expression {
using opType =
std::variant<FegenFunction *, FegenOperation *, FegenOperator>;
opType op;
std::vector<Expression*> params;
ExpressionNode(std::vector<Expression*>, opType, FegenType&);
ExpressionNode(ExpressionNode&)=default;
~ExpressionNode();
virtual std::string toString() override;
virtual std::any getContent() override;

/// @brief operate lhs and rhs using binary operator.
static ExpressionNode *binaryOperation(Expression *lhs, Expression *rhs,
FegenOperator op);
/// @brief operate expr using unary operator
static ExpressionNode *unaryOperation(Expression *, FegenOperator);

// TODO: callFunction
static ExpressionNode* callFunction(std::vector<Expression*>, FegenFunction*);

// TODO: callOperation
static ExpressionNode* callOperation(std::vector<Expression*>, FegenOperation*);
};

struct ExpressionTerminal : public Expression {
// monostate, int literal, float literal, string literal, type literal, list
// literal, reference of variable
using primLiteralType =
std::variant<std::monostate, int, float, std::string, FegenType,
std::vector<Expression*>, FegenValue *>;
primLiteralType content;
ExpressionTerminal(primLiteralType, LiteralKind, FegenType);
ExpressionTerminal(ExpressionTerminal&)=default;
~ExpressionTerminal();
virtual std::string toString() override;
virtual std::any getContent() override;
static ExpressionTerminal *get(std::monostate);
static ExpressionTerminal *get(int);
static ExpressionTerminal *get(float);
static ExpressionTerminal *get(std::string);
static ExpressionTerminal *get(FegenType &);
static ExpressionTerminal *get(std::vector<Expression*> &);
static ExpressionTerminal *get(fegen::FegenValue *);
};

public:
enum class LiteralKind { INT, FLOAT, STRING, TYPE, VECTOR };
FegenLiteral(literalType content);
FegenLiteral(const FegenLiteral &);
FegenLiteral(FegenLiteral &&);
static FegenLiteral get(int content);
static FegenLiteral get(float content);
static FegenLiteral get(std::string content);
static FegenLiteral get(FegenType content);

/// @brief receive vector of number string, FegenType or vector and build it
/// to FegenLiteral
/// @tparam T element type, should be one of int, float, std::string,
/// FegenType or std::vector
template <typename T> static FegenLiteral get(std::vector<T> content);

template <typename T> T getContent() { return std::get<T>(this->content); }
FegenRightValue(Expression *content);
FegenRightValue(const FegenRightValue&);
FegenRightValue(FegenRightValue&&);
FegenRightValue::LiteralKind getKind();
std::string toString();
std::any getContent();

static FegenRightValue get();
static FegenRightValue get(int content);
static FegenRightValue get(float content);
static FegenRightValue get(std::string content);
static FegenRightValue get(FegenType& content);
static FegenRightValue get(std::vector<Expression*> & content);
static FegenRightValue get(fegen::FegenValue * content);
static FegenRightValue get(Expression* expr);
~FegenRightValue();

private:
LiteralKind kind;
Expression *content;
};

class FegenValue {
Expand All @@ -161,20 +318,22 @@ class FegenValue {
private:
FegenType type;
std::string name;
FegenLiteral content;
FegenRightValue content;

public:
FegenValue(FegenType type, std::string name, FegenLiteral content);
FegenValue(FegenType type, std::string name, FegenRightValue content);
FegenValue(const FegenValue &rhs);
FegenValue(FegenValue &&rhs);

static FegenValue *get(FegenType type, std::string name,
FegenLiteral constant);

llvm::StringRef getName();

template <typename T> T getContent() { return this->content.getContent<T>(); }

FegenRightValue constant);

std::string getName();
FegenType &getType();
/// @brief return content of right value, get ExprssionNode* if kind is EXPRESSION.
template <typename T> T getContent() { return std::any_cast<T>(this->content.getContent()); }
FegenRightValue::LiteralKind getContentKind();
std::string getContentString();
~FegenValue() = default;
};

Expand Down Expand Up @@ -234,20 +393,31 @@ class FegenManager {
friend class FegenVisitor;

private:
FegenManager();
FegenManager(const FegenManager &) = delete;
const FegenManager &operator=(const FegenManager &) = delete;
// release nodes, type, operation, function
~FegenManager();
std::string moduleName;
std::vector<std::string> headFiles;
std::map<std::string, FegenNode *> nodeMap;
llvm::StringMap<FegenType *> typeMap;
std::map<std::string, FegenTypeDefination *> typeDefMap;
llvm::StringMap<FegenOperation *> operationMap;
llvm::StringMap<FegenFunction *> functionMap;
void initbuiltinTypes();

public:
static FegenManager &getManager();
void setModuleName(std::string name);

FegenTypeDefination *getTypeDefination(std::string name);
bool addTypeDefination(FegenTypeDefination *tyDef);
std::string emitG4();
// release nodes, type, operation, function
~FegenManager();
};

FegenType inferenceType(std::vector<FegenRightValue::Expression*>, FegenOperator);

} // namespace fegen

#endif
Loading

0 comments on commit 329930b

Please sign in to comment.