Skip to content

Commit

Permalink
update code based on comments
Browse files Browse the repository at this point in the history
  • Loading branch information
tonghaining committed Feb 17, 2025
1 parent 27a92a3 commit fe99e77
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@
import com.dat3m.dartagnan.parsers.SpirvBaseVisitor;
import com.dat3m.dartagnan.parsers.SpirvParser;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.builders.ProgramBuilder;
import com.dat3m.dartagnan.program.event.Event;

import java.util.List;
import java.util.Set;

public class VisitorOpsComposite extends SpirvBaseVisitor<Event> {
public class VisitorOpsComposite extends SpirvBaseVisitor<Void> {

private final ProgramBuilder builder;

Expand All @@ -23,36 +22,37 @@ public VisitorOpsComposite(ProgramBuilder builder) {
}

@Override
public Event visitOpCompositeExtract(SpirvParser.OpCompositeExtractContext ctx) {
extractCompositeElement(ctx.idResult().getText(), ctx.idResultType().getText(),
ctx.composite().getText(), ctx.indexesLiteralInteger());
return null;
}

private void extractCompositeElement(String id, String typeId, String compositeId,
List<SpirvParser.IndexesLiteralIntegerContext> idxContexts) {
Expression compositeExpression = builder.getExpression(compositeId);
public Void visitOpCompositeExtract(SpirvParser.OpCompositeExtractContext ctx) {
String id = ctx.idResult().getText();
Expression compositeExpression = builder.getExpression(ctx.composite().getText());
if (!(compositeExpression instanceof ConstructExpr)) {
throw new ParsingException("Composite extraction is only supported for ConstructExpr");
}
Type type = builder.getType(typeId);
List<Integer> indexes = idxContexts.stream()
Type type = builder.getType(ctx.idResultType().getText());
List<Integer> indexes = ctx.indexesLiteralInteger().stream()
.map(SpirvParser.IndexesLiteralIntegerContext::getText)
.map(Integer::parseInt)
.toList();
Expression element = compositeExpression;
for (Integer index : indexes) {
element = element.getOperands().get(index);
if (!(element instanceof ConstructExpr)) {
throw new ParsingException("Element is not a ConstructExpr at index: " + index);
}
List<Expression> operands = element.getOperands();
if (index >= operands.size()) {
throw new ParsingException("Index out of bounds: " + index);
}
element = operands.get(index);
}
if (type.equals(element.getType())) {
builder.addExpression(id, element);
return;
return null;
}
if (type instanceof ScopedPointerType scopedPointerType) {
Type pointedType = scopedPointerType.getPointedType();
if (pointedType == element.getType() || TypeFactory.isStaticTypeOf(element.getType(), pointedType)) {
builder.addExpression(id, element);
return;
return null;
}
}
throw new ParsingException("Type mismatch in composite extraction: %s", id);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ private Event visitLoopBranch(String labelId) {

private Event visitConditionalJump(Expression guard, String trueLabelId, String falseLabelId) {
if (cfBuilder.isBlockStarted(trueLabelId)) {
if (cfBuilder.isBlockStarted(falseLabelId)) {
throw new ParsingException("Unsupported conditional branch " +
"with two backward jumps to '%s' and '%s'", trueLabelId, falseLabelId);
}
String labelId = trueLabelId;
trueLabelId = falseLabelId;
falseLabelId = labelId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import com.dat3m.dartagnan.exception.ParsingException;
import com.dat3m.dartagnan.expression.Expression;
import com.dat3m.dartagnan.expression.ExpressionFactory;
import com.dat3m.dartagnan.expression.Type;
import com.dat3m.dartagnan.expression.type.ArrayType;
import com.dat3m.dartagnan.expression.type.IntegerType;
import com.dat3m.dartagnan.expression.type.ScopedPointerType;
import com.dat3m.dartagnan.parsers.SpirvBaseVisitor;
Expand All @@ -12,7 +14,6 @@
import com.dat3m.dartagnan.program.event.EventFactory;
import com.dat3m.dartagnan.program.event.Tag;
import com.dat3m.dartagnan.program.event.core.Local;
import com.dat3m.dartagnan.program.memory.ScopedPointerVariable;

import java.util.Set;

Expand All @@ -29,31 +30,34 @@ public VisitorOpsConversion(ProgramBuilder builder) {
public Void visitOpBitcast(SpirvParser.OpBitcastContext ctx) {
String id = ctx.idResult().getText();
String typeId = ctx.idResultType().getText();
if (builder.getType(typeId) instanceof ScopedPointerType pointerType) {
String operand = ctx.operand().getText();
Expression operandExpr = builder.getExpression(operand);
if (!(operandExpr instanceof ScopedPointerVariable scopedPointerVariable)) {
throw new ParsingException("Type '%s' is not a pointer type", operand);
}
if (!pointerType.getScopeId().equals(scopedPointerVariable.getScopeId())) {
throw new ParsingException("Storage class mismatch in OpBitcast between '%s' and '%s'", typeId, operand);
}
Expression convertedPointer = expressions.makeCast(operandExpr, pointerType.getPointedType());
Register reg = builder.addRegister(id, convertedPointer.getType());
builder.addEvent(new Local(reg, convertedPointer));
return null;
} else {
// TODO: Add support for scalar or vector of numerical-type bitcasts
throw new ParsingException("Type '%s' is not a pointer type", typeId);
String operand = ctx.operand().getText();
Type resultType = builder.getType(typeId);
Expression operandExpr = builder.getExpression(operand);
Type operandType = operandExpr.getType();

if (resultType instanceof ArrayType || operandType instanceof ArrayType ||
(operandType instanceof ScopedPointerType pointerType && pointerType.getPointedType() instanceof ArrayType)) {
// TODO: Support bitcast between arrays
throw new ParsingException("Bitcast between arrays is not supported for id '%s'", id);
}

if (resultType instanceof ScopedPointerType pointerType1 && operandType instanceof ScopedPointerType pointerType2
&& !(pointerType1.getScopeId().equals(pointerType2.getScopeId()))) {
throw new ParsingException("Storage class mismatch in OpBitcast between '%s' and '%s' for id '%s'", typeId, operand, id);
}

Expression convertedExpr = expressions.makeCast(operandExpr, resultType);
Register reg = builder.addRegister(id, convertedExpr.getType());
builder.addEvent(new Local(reg, convertedExpr));
return null;
}

@Override
public Void visitOpConvertPtrToU(SpirvParser.OpConvertPtrToUContext ctx) {
String id = ctx.idResult().getText();
String typeId = ctx.idResultType().getText();
if (!(builder.getType(typeId) instanceof IntegerType)) {
throw new ParsingException("Type '%s' is not an integer type", typeId);
throw new ParsingException("Type '%s' is not an integer type for id '%s'", typeId, id);
}
Expression pointerExpr = builder.getExpression(ctx.pointer().getText());
Expression convertedPointer = expressions.makeCast(pointerExpr, builder.getType(typeId), false);
Expand All @@ -67,23 +71,23 @@ public Void visitOpPtrCastToGeneric(SpirvParser.OpPtrCastToGenericContext ctx) {
String id = ctx.idResult().getText();
String typeId = ctx.idResultType().getText();
if (!(builder.getType(typeId) instanceof ScopedPointerType genericType)) {
throw new ParsingException("Type '%s' is not a pointer type", typeId);
throw new ParsingException("Type '%s' is not a pointer type for id '%s'", typeId, id);
}
if (!genericType.getScopeId().equals(Tag.Spirv.SC_GENERIC)) {
throw new ParsingException("Invalid storage class '%s' for OpPtrCastToGeneric", genericType.getScopeId());
throw new ParsingException("Invalid storage class '%s' for OpPtrCastToGeneric for id '%s'", genericType.getScopeId(), id);
}
String pointerId = ctx.pointer().getText();
Expression pointer = builder.getExpression(pointerId);
if (!(pointer.getType() instanceof ScopedPointerType pointerType)) {
throw new ParsingException("Type '%s' is not a pointer type", pointerId);
throw new ParsingException("Type '%s' is not a pointer type for id '%s'", pointerId, id);
}
String pointerSC = pointerType.getScopeId();
Set<String> supportedSC = Set.of(
Tag.Spirv.SC_CROSS_WORKGROUP,
Tag.Spirv.SC_WORKGROUP,
Tag.Spirv.SC_FUNCTION);
if (!supportedSC.contains(pointerSC)) {
throw new ParsingException("Invalid storage class '%s' for OpPtrCastToGeneric", pointerSC);
throw new ParsingException("Invalid storage class '%s' for OpPtrCastToGeneric for id '%s'", pointerSC, id);
}
Expression convertedExpr = expressions.makeCast(pointer, genericType);
Register reg = builder.addRegister(id, genericType);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.dat3m.dartagnan.parsers.program.visitors.spirv;

import com.dat3m.dartagnan.exception.ParsingException;
import com.dat3m.dartagnan.expression.Expression;
import com.dat3m.dartagnan.expression.ExpressionFactory;
import com.dat3m.dartagnan.expression.integers.IntLiteral;
Expand Down Expand Up @@ -169,6 +170,69 @@ public void doTestCompositeExtractRuntimeArray() {
assertEquals(array, compositeExtract1);
}

@Test
public void compositeExtractElementNotConstructExpr() {
String input = "%extract_value = OpCompositeExtract %uint %composite_value 0";

builder.mockFunctionStart(true);
builder.mockIntType("%uint", 32);
Expression nonConstructExpr = expressions.makeValue(1, (IntegerType) builder.getType("%uint"));
builder.addExpression("%composite_value", nonConstructExpr);

try {
visit(input);
fail("Should throw exception");
} catch (ParsingException e) {
assertEquals("Composite extraction is only supported for ConstructExpr", e.getMessage());
}
}

@Test
public void compositeExtractIndexOutOfBounds() {
String input = "%extract_value = OpCompositeExtract %uint %composite_value 5";

builder.mockFunctionStart(true);
builder.mockIntType("%uint", 32);
List<Expression> elements = List.of(
expressions.makeValue(1, (IntegerType) builder.getType("%uint")),
expressions.makeValue(2, (IntegerType) builder.getType("%uint")),
expressions.makeValue(3, (IntegerType) builder.getType("%uint")),
expressions.makeValue(4, (IntegerType) builder.getType("%uint"))
);
Expression array = expressions.makeArray(builder.getType("%uint"), elements, true);
builder.addExpression("%composite_value", array);

try {
visit(input);
fail("Should throw exception");
} catch (ParsingException e) {
assertEquals("Index out of bounds: 5", e.getMessage());
}
}

@Test
public void compositeExtractIndexTooDeep() {
String input = "%extract_value = OpCompositeExtract %uint %composite_value 0 0";

builder.mockFunctionStart(true);
builder.mockIntType("%uint", 32);
List<Expression> elements = List.of(
expressions.makeValue(1, (IntegerType) builder.getType("%uint")),
expressions.makeValue(2, (IntegerType) builder.getType("%uint")),
expressions.makeValue(3, (IntegerType) builder.getType("%uint")),
expressions.makeValue(4, (IntegerType) builder.getType("%uint"))
);
Expression array = expressions.makeArray(builder.getType("%uint"), elements, true);
builder.addExpression("%composite_value", array);

try {
visit(input);
fail("Should throw exception");
} catch (ParsingException e) {
assertEquals("Element is not a ConstructExpr at index: 0", e.getMessage());
}
}

private void visit(String input) {
new MockSpirvParser(input).spv().accept(new VisitorOpsComposite(builder));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,53 +1,81 @@
package com.dat3m.dartagnan.parsers.program.visitors.spirv;

import com.dat3m.dartagnan.expression.type.ScopedPointerType;
import com.dat3m.dartagnan.exception.ParsingException;
import com.dat3m.dartagnan.expression.Expression;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.mocks.MockProgramBuilder;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.mocks.MockSpirvParser;
import com.dat3m.dartagnan.program.Register;
import com.dat3m.dartagnan.program.event.Event;
import com.dat3m.dartagnan.program.event.core.Local;
import com.dat3m.dartagnan.program.memory.ScopedPointerVariable;
import org.junit.Test;

import java.util.List;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;

public class VisitorOpsConversionTest {
private MockProgramBuilder builder = new MockProgramBuilder();

@Test
public void doTestOpsBitcastUint32ToUchar() {
// given
public void opBitcastValidPointerToPointer() {
builder.mockIntType("%uint", 32);
builder.mockIntType("%uchar", 8);
builder.mockPtrType("%_ptr_Function_uint", "%uint", "Function");
builder.mockPtrType("%_ptr_Function_uchar", "%uchar", "Function");
builder.mockVariable("%value1", "%_ptr_Function_uint");
String input = "%value2 = OpBitcast %_ptr_Function_uchar %value1";

// when
visit(input);
Expression reg = builder.getExpression("%value2");
assertEquals(builder.getType("%_ptr_Function_uchar"), reg.getType());
}

// then
ScopedPointerVariable value1 = (ScopedPointerVariable) builder.getExpression("%value1");
Register value2 = (Register) builder.getExpression("%value2");
assertEquals(((ScopedPointerType) builder.getType("%_ptr_Function_uint")).getPointedType(), value1.getInnerType());
Local local = (Local) getLastEvent();
assertEquals(value2, local.getResultRegister());
assertEquals(builder.getType("%uchar"), local.getExpr().getType());
@Test
public void opBitcastValidScalarToScalar() {
builder.mockIntType("%uint", 32);
builder.mockIntType("%uchar", 8);
builder.mockConstant("%value1", "%uint", 1);
String input = "%value2 = OpBitcast %uchar %value1";

visit(input);
Expression reg = builder.getExpression("%value2");
assertEquals(builder.getType("%uchar"), reg.getType());
}

private Event getLastEvent() {
return getLastNEvent(0);
@Test
public void opBitcastScalarToPointer() {
builder.mockIntType("%uint", 32);
builder.mockPtrType("%_ptr_Function_uint", "%uint", "Function");
builder.mockConstant("%value1", "%uint", 1);
String input = "%value2 = OpBitcast %_ptr_Function_uint %value1";

visit(input);
Expression reg = builder.getExpression("%value2");
assertEquals(builder.getType("%_ptr_Function_uint"), reg.getType());
}

private Event getLastNEvent(int n) {
List<Event> events = builder.getCurrentFunction().getEvents();
if (!events.isEmpty() && events.size() > n) {
return events.get(events.size() - 1 - n);
@Test
public void opBitcastStorageClassMismatch() {
builder.mockIntType("%uint", 32);
builder.mockPtrType("%_ptr_Function_uint", "%uint", "Function");
builder.mockPtrType("%_ptr_Workgroup_uint", "%uint", "Workgroup");
builder.mockVariable("%value1", "%_ptr_Function_uint");
String input = "%value2 = OpBitcast %_ptr_Workgroup_uint %value1";

try {
visit(input);
fail("Should throw exception");
} catch (ParsingException e) {
assertEquals("Storage class mismatch in OpBitcast between '%_ptr_Workgroup_uint' and '%value1' for id '%value2'", e.getMessage());
}
return null;
}

@Test
public void opConvertPtrToUValid() {
builder.mockIntType("%uint", 32);
builder.mockPtrType("%_ptr_Function_uint", "%uint", "Function");
builder.mockVariable("%value1", "%_ptr_Function_uint");
String input = "%value2 = OpConvertPtrToU %uint %value1";

visit(input);
Expression reg = builder.getExpression("%value2");
assertEquals(builder.getType("%uint"), reg.getType());
}

private void visit(String input) {
Expand Down

0 comments on commit fe99e77

Please sign in to comment.