Skip to content

Commit

Permalink
thread creation after inling
Browse files Browse the repository at this point in the history
  • Loading branch information
tonghaining committed Feb 19, 2025
1 parent d10d9aa commit 8067a55
Show file tree
Hide file tree
Showing 17 changed files with 178 additions and 209 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import com.dat3m.dartagnan.expression.Type;
import com.dat3m.dartagnan.expression.integers.IntLiteral;
import com.dat3m.dartagnan.expression.type.FunctionType;
import com.dat3m.dartagnan.expression.type.IntegerType;
import com.dat3m.dartagnan.expression.type.TypeFactory;
import com.dat3m.dartagnan.program.*;
import com.dat3m.dartagnan.program.Thread;
Expand Down Expand Up @@ -52,7 +51,7 @@ public class ProgramBuilder {
// ----------------------------------------------------------------------------------------------------------------
// Construction
private ProgramBuilder(SourceLanguage format) {
this.program = new Program(new Memory(), format);
this.program = new Program(new Memory(), format, null);
}

public static ProgramBuilder forArch(SourceLanguage format, Arch arch) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public class VisitorLlvm extends LLVMIRBaseVisitor<Expression> {
private static final Logger logger = LogManager.getLogger(VisitorLlvm.class);

// Global context
private final Program program = new Program(new Memory(), Program.SourceLanguage.LLVM);
private final Program program = new Program(new Memory(), Program.SourceLanguage.LLVM, null);
private final TypeFactory types = TypeFactory.getInstance();
private final ExpressionFactory expressions = ExpressionFactory.getInstance();
private final Type pointerType = types.getPointerType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import com.dat3m.dartagnan.parsers.SpirvParser;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.*;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.builders.ProgramBuilder;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.utils.ThreadGrid;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.utils.SpirvThreadGrid;
import com.dat3m.dartagnan.program.Program;
import com.dat3m.dartagnan.program.Register;
import org.antlr.v4.runtime.tree.ParseTree;
Expand Down Expand Up @@ -61,7 +61,7 @@ public Program visitOp(SpirvParser.OpContext ctx) {
}

private ProgramBuilder createBuilder(SpirvParser.SpvContext ctx) {
ThreadGrid grid = new ThreadGrid(1, 1, 1, 1);
SpirvThreadGrid grid = new SpirvThreadGrid(1, 1, 1, 1);
boolean hasConfig = false;
for (SpirvParser.SpvHeaderContext header : ctx.spvHeaders().spvHeader()) {
SpirvParser.ConfigHeaderContext cfgCtx = header.configHeader();
Expand All @@ -74,7 +74,7 @@ private ProgramBuilder createBuilder(SpirvParser.SpvContext ctx) {
int sg = Integer.parseInt(literals.get(0).getText());
int wg = Integer.parseInt(literals.get(1).getText());
int qf = Integer.parseInt(literals.get(2).getText());
grid = new ThreadGrid(sg, wg, qf, 1);
grid = new SpirvThreadGrid(sg, wg, qf, 1);
}
}
return new ProgramBuilder(grid);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import com.dat3m.dartagnan.exception.ParsingException;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.*;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.utils.ThreadGrid;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.utils.SpirvThreadGrid;

import java.util.EnumMap;

Expand All @@ -12,7 +12,7 @@ public class DecorationsBuilder {

private final EnumMap<DecorationType, Decoration> mapping = new EnumMap<>(DecorationType.class);

public DecorationsBuilder(ThreadGrid grid) {
public DecorationsBuilder(SpirvThreadGrid grid) {
mapping.put(BUILT_IN, new BuiltIn(grid));
mapping.put(OFFSET, new Offset());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import com.dat3m.dartagnan.expression.type.*;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.BuiltIn;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.helpers.HelperTags;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.utils.ThreadCreator;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.utils.ThreadGrid;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.utils.MemoryTransformer;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.utils.SpirvThreadGrid;
import com.dat3m.dartagnan.program.event.core.Local;
import com.dat3m.dartagnan.program.event.functions.FunctionCall;
import com.dat3m.dartagnan.program.memory.*;
Expand All @@ -28,7 +28,7 @@ public class ProgramBuilder {
protected final Map<String, Expression> parameterValues = new HashMap<>();
protected final Map<String, Expression> inputs = new HashMap<>();
protected final Map<String, String> debugInfos = new HashMap<>();
protected final ThreadGrid grid;
protected final SpirvThreadGrid grid;
protected final Program program;
protected ControlFlowBuilder controlFlowBuilder;
protected DecorationsBuilder decorationsBuilder;
Expand All @@ -37,9 +37,9 @@ public class ProgramBuilder {
protected Arch arch;
protected Set<String> nextOps;

public ProgramBuilder(ThreadGrid grid) {
public ProgramBuilder(SpirvThreadGrid grid) {
this.grid = grid;
this.program = new Program(new Memory(), Program.SourceLanguage.SPV);
this.program = new Program(new Memory(), Program.SourceLanguage.SPV, grid);
this.controlFlowBuilder = new ControlFlowBuilder(expressions);
this.decorationsBuilder = new DecorationsBuilder(grid);
}
Expand All @@ -48,15 +48,12 @@ public Program build() {
validateBeforeBuild();
controlFlowBuilder.build();
BuiltIn builtIn = (BuiltIn) decorationsBuilder.getDecoration(BUILT_IN);
Function entryFunction = getEntryPointFunction();
Set<Function> subFunctions = program.getFunctions().stream()
.filter(f -> !f.equals(entryFunction))
.collect(Collectors.toSet());
new ThreadCreator(grid, entryFunction, subFunctions, getVariables(), builtIn).create();
MemoryTransformer transformer = new MemoryTransformer(grid, getEntryPointFunction(), builtIn, getVariables());
program.addTransformer(transformer);
return program;
}

public ThreadGrid getThreadGrid() {
public SpirvThreadGrid getThreadGrid() {
return grid;
}

Expand Down Expand Up @@ -88,6 +85,7 @@ public void setEntryPointId(String id) {
throw new ParsingException("Multiple entry points are not supported");
}
entryPointId = id;
program.setEntryPoint(id);
}

public String getEntryPointId() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import com.dat3m.dartagnan.expression.type.ArrayType;
import com.dat3m.dartagnan.expression.type.IntegerType;
import com.dat3m.dartagnan.expression.type.TypeFactory;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.utils.ThreadGrid;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.utils.SpirvThreadGrid;
import com.dat3m.dartagnan.program.memory.MemoryObject;

import java.util.ArrayList;
Expand All @@ -20,11 +20,11 @@ public class BuiltIn implements Decoration {

private static final TypeFactory types = TypeFactory.getInstance();
private static final ExpressionFactory expressions = ExpressionFactory.getInstance();
private final ThreadGrid grid;
private final SpirvThreadGrid grid;
private final Map<String, String> mapping;
private int tid;

public BuiltIn(ThreadGrid grid) {
public BuiltIn(SpirvThreadGrid grid) {
this.grid = grid;
this.mapping = new HashMap<>();
}
Expand Down Expand Up @@ -78,14 +78,14 @@ private Expression getDecorationExpressions(String id, Type type) {
case "SubgroupLocalInvocationId" -> makeScalar(id, type, tid % grid.sgSize());
case "LocalInvocationId" -> makeArray(id, type, tid % grid.wgSize(), 0, 0);
case "LocalInvocationIndex" -> makeScalar(id, type, tid % grid.wgSize()); // scalar of LocalInvocationId
case "GlobalInvocationId" -> makeArray(id, type, tid % grid.dvSize(), 0, 0);
case "GlobalInvocationId" -> makeArray(id, type, tid % grid.threadPoolSize(), 0, 0);
case "DeviceIndex" -> makeScalar(id, type, 0);
case "SubgroupId" -> makeScalar(id, type, grid.sgId(tid));
case "WorkgroupId" -> makeArray(id, type, grid.wgId(tid), 0, 0);
case "SubgroupSize" -> makeScalar(id, type, grid.sgSize());
case "WorkgroupSize" -> makeArray(id, type, grid.wgSize(), 1, 1);
case "GlobalSize" -> makeArray(id, type, grid.dvSize(), 1, 1);
case "NumWorkgroups" -> makeArray(id, type, grid.dvSize() / grid.wgSize(), 1, 1);
case "GlobalSize" -> makeArray(id, type, grid.threadPoolSize(), 1, 1);
case "NumWorkgroups" -> makeArray(id, type, grid.threadPoolSize() / grid.wgSize(), 1, 1);
default -> throw new ParsingException("Unsupported decoration '%s'", mapping.get(id));
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import com.dat3m.dartagnan.expression.type.*;
import com.dat3m.dartagnan.parsers.SpirvParser;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.builders.ProgramBuilder;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.utils.ThreadGrid;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.utils.SpirvThreadGrid;
import com.dat3m.dartagnan.program.event.Tag;
import com.dat3m.dartagnan.program.memory.ScopedPointerVariable;

Expand Down Expand Up @@ -118,9 +118,9 @@ private Void setPushConstantValue(String argument, String offsetId, String sizeI
}

private List<Integer> computePushConstantValue(String command) {
ThreadGrid grid = builder.getThreadGrid();
SpirvThreadGrid grid = builder.getThreadGrid();
return switch (command) {
case "PushConstantGlobalSize" -> List.of(grid.dvSize(), 1, 1);
case "PushConstantGlobalSize" -> List.of(grid.threadPoolSize(), 1, 1);
case "PushConstantEnqueuedLocalSize" -> List.of(grid.wgSize(), 1, 1);
case "PushConstantNumWorkgroups" -> List.of(grid.qfSize() / grid.wgSize(), 1, 1);
case "PushConstantGlobalOffset",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,14 @@
import com.dat3m.dartagnan.program.Function;
import com.dat3m.dartagnan.program.Program;
import com.dat3m.dartagnan.program.Register;
import com.dat3m.dartagnan.program.Thread;
import com.dat3m.dartagnan.program.event.Tag;
import com.dat3m.dartagnan.program.memory.MemoryObject;
import com.dat3m.dartagnan.program.memory.ScopedPointerVariable;
import com.dat3m.dartagnan.program.memory.VirtualMemoryObject;
import com.dat3m.dartagnan.program.misc.NonDetValue;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.*;
import java.util.function.IntUnaryOperator;
import java.util.stream.Collectors;
import java.util.stream.Stream;
Expand All @@ -28,9 +26,8 @@ public class MemoryTransformer extends ExprTransformer {
private static final List<String> namePrefixes = List.of("T", "S", "W", "Q", "D");

private final Program program;
private final Function entryFunction;
private final Function function;
private final BuiltIn builtIn;
private final Set<Function> subFunctions;
private final List<? extends Map<MemoryObject, MemoryObject>> scopeMapping;
private final Map<MemoryObject, ScopedPointerVariable> pointerMapping;
private final List<IntUnaryOperator> scopeIdProvider;
Expand All @@ -39,12 +36,10 @@ public class MemoryTransformer extends ExprTransformer {
private Map<NonDetValue, NonDetValue> nonDetMapping;
private int tid;

public MemoryTransformer(ThreadGrid grid, Function entryFunction, Set<Function> subFunctions,
BuiltIn builtIn, Set<ScopedPointerVariable> variables) {
this.program = entryFunction.getProgram();
this.entryFunction = entryFunction;
public MemoryTransformer(SpirvThreadGrid grid, Function function, BuiltIn builtIn, Set<ScopedPointerVariable> variables) {
this.program = function.getProgram();
this.function = function;
this.builtIn = builtIn;
this.subFunctions = subFunctions;
this.scopeMapping = Stream.generate(() -> new HashMap<MemoryObject, MemoryObject>()).limit(namePrefixes.size()).toList();
this.pointerMapping = variables.stream().collect(Collectors.toMap((ScopedPointerVariable::getAddress), (v -> v)));
this.scopeIdProvider = List.of(grid::thId, grid::sgId, grid::wgId, grid::qfId, grid::dvId);
Expand All @@ -53,30 +48,36 @@ public MemoryTransformer(ThreadGrid grid, Function entryFunction, Set<Function>
i -> i / grid.sgSize(),
i -> i / grid.wgSize(),
i -> i / grid.qfSize(),
i -> i / grid.dvSize());
i -> i / grid.threadPoolSize());
}

public Register getRegisterMapping(Register register) {
return registerMapping.get(register);
}

public void setTransferFunction(Function function) {
int newTid = function.getId();
public void setThread(Thread thread) {
int newTid = thread.getId();
int depth = getScopeIdx(newTid, scopeIdProvider);
for (int i = 0; i <= depth; i++) {
scopeMapping.get(i).clear();
}
tid = newTid;
builtIn.setThreadId(tid);
registerMapping = entryFunction.getRegisters().stream().collect(
toMap(r -> r, r -> function.getOrNewRegister(r.getName(), r.getType())));
for (Function subfunction : subFunctions) {
registerMapping.putAll(subfunction.getRegisters().stream().collect(
toMap(r -> r, r -> function.getOrNewRegister(r.getName(), r.getType()))));
}
registerMapping = function.getRegisters().stream().collect(
toMap(r -> r, r -> thread.getOrNewRegister(r.getName(), r.getType())));
nonDetMapping = new HashMap<>();
}

public List<MemoryObject> getThreadLocalMemoryObjects() {
List<MemoryObject> threadLocalMemoryObjects = new ArrayList<>();
for (MemoryObject memoryObject : pointerMapping.keySet()) {
if (memoryObject.isThreadLocal()) {
threadLocalMemoryObjects.add(memoryObject);
}
}
return threadLocalMemoryObjects;
}

@Override
public Expression visitRegister(Register register) {
return registerMapping.get(register);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@

import com.dat3m.dartagnan.exception.ParsingException;
import com.dat3m.dartagnan.program.ScopeHierarchy;
import com.dat3m.dartagnan.program.ThreadGrid;

import java.util.List;

public class ThreadGrid {
public class SpirvThreadGrid implements ThreadGrid {

private final int sg;
private final int wg;
private final int qf;
private final int dv;

public ThreadGrid(int sg, int wg, int qf, int dv) {
public SpirvThreadGrid(int sg, int wg, int qf, int dv) {
List<Integer> elements = List.of(sg, wg, qf, dv);
if (elements.stream().anyMatch(i -> i <= 0)) {
throw new ParsingException("Thread grid dimensions must be positive");
Expand All @@ -35,7 +36,8 @@ public int qfSize() {
return sg * wg * qf;
}

public int dvSize() {
@Override
public int threadPoolSize() {
return sg * wg * qf * dv;
}

Expand All @@ -52,15 +54,15 @@ public int wgId(int tid) {
}

public int qfId(int tid) {
return (tid % dvSize()) / qfSize();
return (tid % threadPoolSize()) / qfSize();
}

public int dvId(int tid) {
return tid / dvSize();
return tid / threadPoolSize();
}

@Override
public ScopeHierarchy getScoreHierarchy(int tid) {
return ScopeHierarchy.ScopeHierarchyForVulkan(qfId(tid), wgId(tid), sgId(tid));

}
}
Loading

0 comments on commit 8067a55

Please sign in to comment.