Skip to content

Commit

Permalink
Add parser for new IR
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiafi committed Feb 5, 2025
1 parent 5f31555 commit bf82af6
Show file tree
Hide file tree
Showing 10 changed files with 605 additions and 0 deletions.
120 changes: 120 additions & 0 deletions core/trino-grammar/src/main/antlr4/io/trino/grammar/newir/NewIr.g4
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Combined (lexer + parser) ANTLR4 grammar for the textual form of the new IR.
// The parser rules describe a versioned program consisting of a single root operation;
// operations may contain nested regions, blocks and further operations.
grammar NewIr;

// Declares the DELIMITER token type without a matching lexer rule.
// No rule in this grammar references it.
tokens {
DELIMITER
}

// Entry rule: a header of the form `IR version = <integer>` followed by exactly one
// root operation and the end of input.
program
: IR VERSION EQ version=INTEGER_VALUE
operation EOF
;

// A single operation, written as:
//   %result = [dialect.]name (%arg, ...) : (argType, ...) -> resultType ({region}, ...) [{attribute, ...}]
// The argument name list, argument type list and region list are each parenthesized and may be empty.
// The grammar collects argument names and argument types into separate lists and does not
// itself require them to have the same length.
// The braces holding the attribute list are optional as a whole.
operation
: resultName=VALUE_NAME EQ operationName
'(' (argumentNames+=VALUE_NAME (',' argumentNames+=VALUE_NAME)*)? ')'
':' '(' (argumentTypes+=type (',' argumentTypes+=type)*)? ')'
'->' resultType=type
'(' (region (',' region)*)? ')'
('{' (attribute (',' attribute)*)? '}')? // does not roundtrip: we don't print empty attributes list // TODO test
;

// A region is a brace-delimited sequence of one or more blocks.
region
: '{' block+ '}'
;

// A block has an optional name (`^name`), an optional parenthesized, non-empty list of
// parameters, and one or more operations.
// Blocks are not separated by any delimiter, so an unnamed, parameterless block cannot be
// told apart syntactically from a continuation of the preceding block's operations.
block
: BLOCK_NAME?
('(' blockParameter (',' blockParameter)* ')')?
operation+
;

// A block parameter is a value name followed by its type: `%name : type`.
blockParameter
: VALUE_NAME ':' type
;

// An attribute is a (possibly dialect-qualified) name bound to a string literal: `[dialect.]name = "value"`.
attribute
: attributeName EQ STRING
;

// An identifier is either a plain IDENTIFIER token or one of the keywords listed in `nonReserved`.
identifier
: IDENTIFIER
| nonReserved
;

// Name of a dialect; used as the optional qualifier of operation names, attribute names and types.
dialectName
: identifier
;

// Operation name with an optional `dialect.` prefix.
operationName
: (dialectName '.')? identifier
;

// Attribute name with an optional `dialect.` prefix.
attributeName
: (dialectName '.')? identifier
;

// A type is written as a string literal with an optional `dialect.` prefix.
type
: (dialectName '.')? STRING
;

// Keywords that are also accepted wherever an identifier is expected.
// ProgramParser.PostProcessor requires every alternative here to be a single token and
// replaces the matched keyword token with an IDENTIFIER token in the parse tree.
nonReserved
: IR | VERSION
;

// ---- Lexer rules ----
// The order of the rules below is significant: when several rules match input of the same
// length, the ANTLR lexer picks the rule that is declared first.

// Keywords. Declared before IDENTIFIER so that the exact texts `IR` and `version` become
// keyword tokens rather than identifiers.
IR: 'IR';
VERSION: 'version';

EQ: '=';

// Double-quoted string literal. A double quote inside the literal is written as two
// consecutive double quotes. Any other character, including line breaks, is accepted as is.
STRING
: '"' ( ~'"' | '""' )* '"'
;

// Name of a value: `%` followed by one or more letters, digits or underscores.
VALUE_NAME
: '%' PREFIXED_IDENTIFIER
;

// Name of a block: `^` followed by one or more letters, digits or underscores.
BLOCK_NAME
: '^' PREFIXED_IDENTIFIER
;

// Unsigned sequence of decimal digits. Leading zeros are accepted by the lexer.
INTEGER_VALUE
: DIGIT+
;

// A letter or underscore followed by any number of letters, digits or underscores.
IDENTIFIER
: (LETTER | '_') (LETTER | DIGIT | '_')*
;

// Body of VALUE_NAME and BLOCK_NAME: one or more letters, digits or underscores.
// NOTE(review): this rule is not declared `fragment`, so it is also a token type of its own.
// Input such as `1abc`, which neither INTEGER_VALUE nor IDENTIFIER matches in full, would be
// emitted as a PREFIXED_IDENTIFIER token that no parser rule accepts - confirm this is intended.
PREFIXED_IDENTIFIER
: (LETTER | DIGIT | '_')+
;

fragment DIGIT
: [0-9]
;

// ASCII letters only.
fragment LETTER
: [a-z] | [A-Z]
;

// Spaces, tabs and line breaks are sent to the hidden channel, so the parser does not see them.
WS
: [ \r\n\t]+ -> channel(HIDDEN)
;

// Catch-all for anything we can't recognize.
// Matches a single character. Being the last rule, it applies only when no rule above matches.
UNRECOGNIZED: .;
123 changes: 123 additions & 0 deletions core/trino-parser/src/main/java/io/trino/newir/ProgramParser.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.newir;

import io.trino.grammar.newir.NewIrBaseListener;
import io.trino.grammar.newir.NewIrLexer;
import io.trino.grammar.newir.NewIrParser;
import io.trino.newir.tree.ProgramNode;
import io.trino.sql.parser.ParsingException;
import io.trino.sql.tree.NodeLocation;
import org.antlr.v4.runtime.ANTLRErrorListener;
import org.antlr.v4.runtime.BaseErrorListener;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonToken;
import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.ParserRuleContext;
import org.antlr.v4.runtime.RecognitionException;
import org.antlr.v4.runtime.Recognizer;
import org.antlr.v4.runtime.Token;
import org.antlr.v4.runtime.atn.PredictionMode;
import org.antlr.v4.runtime.misc.Pair;
import org.antlr.v4.runtime.tree.TerminalNode;

import java.util.Arrays;
import java.util.List;

/**
 * Parses the textual form of a new-IR program into a {@link ProgramNode} tree.
 * <p>
 * This is a static utility class and cannot be instantiated.
 */
public final class ProgramParser
{
    // Installed on both the lexer and the parser in place of ANTLR's default listeners.
    // Turns the first reported syntax error into a ParsingException. The column is shifted
    // by one because ANTLR reports 0-based positions within the line.
    private static final ANTLRErrorListener ERROR_LISTENER = new BaseErrorListener()
    {
        @Override
        public void syntaxError(Recognizer<?, ?> recognizer, Object offendingSymbol, int line, int charPositionInLine, String message, RecognitionException e)
        {
            throw new ParsingException("parsing of the program failed: " + message, e, line, charPositionInLine + 1);
        }
    };

    private ProgramParser() {}

    /**
     * Parses {@code program} and builds its tree representation.
     *
     * @param program the program text
     * @return the root node of the parsed program
     * @throws ParsingException if the text is not a syntactically valid program, or if parsing
     * exhausts the stack
     */
    public static ProgramNode parseProgram(String program)
    {
        try {
            NewIrLexer lexer = new NewIrLexer(CharStreams.fromString(program));
            CommonTokenStream tokenStream = new CommonTokenStream(lexer);
            NewIrParser parser = new NewIrParser(tokenStream);

            parser.addParseListener(new PostProcessor(Arrays.asList(parser.getRuleNames()), parser));

            lexer.removeErrorListeners();
            lexer.addErrorListener(ERROR_LISTENER);

            parser.removeErrorListeners();
            parser.addErrorListener(ERROR_LISTENER);

            ParserRuleContext tree;
            try {
                // first, try parsing with potentially faster SLL mode
                parser.getInterpreter().setPredictionMode(PredictionMode.SLL);
                tree = parser.program();
            }
            catch (ParsingException ex) {
                // if we fail, parse with LL mode
                tokenStream.seek(0); // rewind input stream
                parser.reset();

                parser.getInterpreter().setPredictionMode(PredictionMode.LL);
                tree = parser.program();
            }

            return (ProgramNode) new ProgramTreeBuilder().visit(tree);
        }
        catch (StackOverflowError e) {
            // no meaningful location is available at this point; report the start of the program
            throw new ParsingException("stack overflow while parsing the program", new NodeLocation(1, 1));
        }
    }

    /**
     * Parse listener that rewrites the parse tree while it is being built: every keyword matched
     * through the {@code nonReserved} rule is replaced with a plain IDENTIFIER terminal.
     */
    private static class PostProcessor
            extends NewIrBaseListener
    {
        private final List<String> ruleNames;
        private final NewIrParser parser;

        public PostProcessor(List<String> ruleNames, NewIrParser parser)
        {
            this.ruleNames = ruleNames;
            this.parser = parser;
        }

        @Override
        public void exitNonReserved(NewIrParser.NonReservedContext context)
        {
            // only a terminal can be replaced during rule exit event handling. Make sure that the nonReserved item is a token
            if (!(context.getChild(0) instanceof TerminalNode)) {
                int rule = ((ParserRuleContext) context.getChild(0)).getRuleIndex();
                throw new AssertionError("nonReserved can only contain tokens. Found nested rule: " + ruleNames.get(rule));
            }

            // replace nonReserved keyword with IDENTIFIER token
            context.getParent().removeLastChild();

            Token token = (Token) context.getChild(0).getPayload();
            CommonToken newToken = new CommonToken(
                    new Pair<>(token.getTokenSource(), token.getInputStream()),
                    NewIrLexer.IDENTIFIER,
                    token.getChannel(),
                    token.getStartIndex(),
                    token.getStopIndex());
            // The constructor above initializes line and column from the current position of the
            // token source, which has typically moved past this token. Copy them from the
            // keyword token so that the replacement reports the keyword's actual location.
            newToken.setLine(token.getLine());
            newToken.setCharPositionInLine(token.getCharPositionInLine());

            context.getParent().addChild(parser.createTerminalNode(context.getParent(), newToken));
        }
    }
}
128 changes: 128 additions & 0 deletions core/trino-parser/src/main/java/io/trino/newir/ProgramTreeBuilder.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.newir;

import io.trino.grammar.newir.NewIrBaseVisitor;
import io.trino.grammar.newir.NewIrParser;
import io.trino.newir.tree.AttributeNode;
import io.trino.newir.tree.BlockNode;
import io.trino.newir.tree.NewIrNode;
import io.trino.newir.tree.OperationNode;
import io.trino.newir.tree.ProgramNode;
import io.trino.newir.tree.RegionNode;
import io.trino.newir.tree.TypeNode;
import io.trino.sql.parser.ParsingException;
import org.antlr.v4.runtime.RuleContext;
import org.antlr.v4.runtime.Token;
import org.antlr.v4.runtime.tree.ParseTree;

import java.util.Optional;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.lang.Integer.parseInt;

/**
 * Visitor that converts the ANTLR parse tree of a new-IR program into {@link NewIrNode} objects.
 * <p>
 * String literals (types and attribute values) are passed on exactly as they appear in the
 * token text; no unquoting is performed here.
 */
public class ProgramTreeBuilder
        extends NewIrBaseVisitor<NewIrNode>
{
    /**
     * Builds the program node from the declared version and the root operation.
     *
     * @throws ParsingException if the version text starts with '0' or does not fit in an int
     */
    @Override
    public ProgramNode visitProgram(NewIrParser.ProgramContext context)
    {
        Token versionToken = context.version;
        String text = versionToken.getText();

        // rejects "0" itself as well as any number written with leading zeros
        if (text.startsWith("0")) {
            throw new ParsingException(
                    "invalid version. starts with '0': " + text,
                    null,
                    versionToken.getLine(),
                    versionToken.getCharPositionInLine() + 1);
        }

        int parsedVersion;
        try {
            parsedVersion = parseInt(text);
        }
        catch (NumberFormatException e) {
            throw new ParsingException(
                    "invalid version. not an integer: " + text,
                    null,
                    versionToken.getLine(),
                    versionToken.getCharPositionInLine() + 1);
        }

        OperationNode rootOperation = (OperationNode) visit(context.operation());
        return new ProgramNode(parsedVersion, rootOperation);
    }

    /**
     * Builds an operation node: optional dialect, name, result name and type, argument names,
     * argument types, nested regions and attributes, in that order.
     */
    @Override
    public NewIrNode visitOperation(NewIrParser.OperationContext context)
    {
        NewIrParser.OperationNameContext qualifiedName = context.operationName();
        Optional<String> dialect = Optional.ofNullable(qualifiedName.dialectName())
                .map(dialectName -> dialectName.getText());
        String name = qualifiedName.identifier().getText();
        String resultName = context.resultName.getText();
        TypeNode resultType = (TypeNode) visit(context.resultType);

        return new OperationNode(
                dialect,
                name,
                resultName,
                resultType,
                context.argumentNames.stream()
                        .map(argumentName -> argumentName.getText())
                        .collect(toImmutableList()),
                context.argumentTypes.stream()
                        .map(argumentType -> (TypeNode) visit(argumentType))
                        .collect(toImmutableList()),
                context.region().stream()
                        .map(region -> (RegionNode) visit(region))
                        .collect(toImmutableList()),
                context.attribute().stream()
                        .map(attribute -> (AttributeNode) visit(attribute))
                        .collect(toImmutableList()));
    }

    /**
     * Builds a type node from the optional dialect and the raw text of the STRING token.
     */
    @Override
    public NewIrNode visitType(NewIrParser.TypeContext context)
    {
        Optional<String> dialect = Optional.ofNullable(context.dialectName())
                .map(dialectName -> dialectName.getText());
        String typeText = context.STRING().getText();
        return new TypeNode(dialect, typeText);
    }

    /**
     * Builds a region node holding the region's blocks in source order.
     */
    @Override
    public NewIrNode visitRegion(NewIrParser.RegionContext context)
    {
        return new RegionNode(
                context.block().stream()
                        .map(block -> (BlockNode) visit(block))
                        .collect(toImmutableList()));
    }

    /**
     * Builds a block node: optional block name, parameter names, parameter types and operations.
     * Parameter names and parameter types are passed as two parallel lists.
     */
    @Override
    public NewIrNode visitBlock(NewIrParser.BlockContext context)
    {
        Optional<String> blockName = Optional.ofNullable(context.BLOCK_NAME())
                .map(nameToken -> nameToken.getText());

        return new BlockNode(
                blockName,
                context.blockParameter().stream()
                        .map(parameter -> parameter.VALUE_NAME())
                        .map(parameterName -> parameterName.getText())
                        .collect(toImmutableList()),
                context.blockParameter().stream()
                        .map(parameter -> parameter.type())
                        .map(parameterType -> (TypeNode) visit(parameterType))
                        .collect(toImmutableList()),
                context.operation().stream()
                        .map(operation -> (OperationNode) visit(operation))
                        .collect(toImmutableList()));
    }

    /**
     * Builds an attribute node from the optional dialect, the attribute name and the raw text
     * of the STRING token.
     */
    @Override
    public NewIrNode visitAttribute(NewIrParser.AttributeContext context)
    {
        NewIrParser.AttributeNameContext qualifiedName = context.attributeName();
        Optional<String> dialect = Optional.ofNullable(qualifiedName.dialectName())
                .map(dialectName -> dialectName.getText());
        String name = qualifiedName.identifier().getText();
        String value = context.STRING().getText();
        return new AttributeNode(dialect, name, value);
    }
}
Loading

0 comments on commit bf82af6

Please sign in to comment.