Skip to content

Commit

Permalink
Reflecting review feedback.
Browse files Browse the repository at this point in the history
  • Loading branch information
lahodaj committed Nov 3, 2023
1 parent 263bd3d commit fd3bc68
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 120 deletions.
249 changes: 129 additions & 120 deletions src/java.base/share/classes/java/lang/runtime/SwitchBootstraps.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import java.lang.Enum.EnumDesc;
import java.lang.constant.ClassDesc;
import java.lang.constant.ConstantDescs;
import java.lang.constant.MethodTypeDesc;
import java.lang.invoke.CallSite;
import java.lang.invoke.ConstantCallSite;
Expand All @@ -37,7 +38,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.BiPredicate;
import java.util.stream.Stream;
import jdk.internal.access.SharedSecrets;
import jdk.internal.classfile.Classfile;
Expand Down Expand Up @@ -69,6 +70,9 @@ private SwitchBootstraps() {}
private static final MethodHandle CHECK_INDEX;
private static final MethodHandle MAPPED_ENUM_LOOKUP;

private static final MethodTypeDesc typesSwitchDescriptor =
MethodTypeDesc.ofDescriptor("(Ljava/lang/Object;ILjava/util/function/BiPredicate;[Ljava/lang/Class;)I");

static {
try {
NULL_CHECK = LOOKUP.findStatic(Objects.class, "isNull",
Expand Down Expand Up @@ -321,7 +325,7 @@ private static MethodHandle withIndexCheck(MethodHandle target, int labelsCount)
return MethodHandles.filterArguments(target, 1, checkIndex);
}

private static final class ResolvedEnumLabels implements BiFunction<Integer, Object, Object> {
private static final class ResolvedEnumLabels implements BiPredicate<Integer, Object> {

private final MethodHandles.Lookup lookup;
private final EnumDesc<?>[] enumDescs;
Expand All @@ -335,7 +339,7 @@ public ResolvedEnumLabels(MethodHandles.Lookup lookup, EnumDesc<?>[] enumDescs)
}

@Override
public Object apply(Integer labelIndex, Object value) {
public boolean test(Integer labelIndex, Object value) {
Object result = resolvedEnum[labelIndex];

if (result == null) {
Expand All @@ -344,7 +348,7 @@ public Object apply(Integer labelIndex, Object value) {
Class<?> clazz = label.constantType().resolveConstantDesc(lookup);

if (value.getClass() != clazz) {
return SENTINEL;
return false;
}

result = label.resolveConstantDesc(lookup);
Expand All @@ -355,7 +359,7 @@ public Object apply(Integer labelIndex, Object value) {
resolvedEnum[labelIndex] = result;
}

return result;
return result == value;
}
}

Expand All @@ -375,125 +379,128 @@ private static final class EnumMap {
@SuppressWarnings("removal")
private static MethodHandle generateInnerClass(MethodHandles.Lookup caller, Object[] labels) {
List<EnumDesc<?>> enumDescs = new ArrayList<>();
List<Class<?>> extraClassLabels = new ArrayList<>();

byte[] classBytes = Classfile.of().build(ClassDesc.of(typeSwitchClassName(caller.lookupClass())), clb -> {
clb.withFlags(AccessFlag.FINAL, AccessFlag.SYNTHETIC)
.withMethod("typeSwitch", MethodTypeDesc.ofDescriptor("(Ljava/lang/Object;ILjava/util/function/BiFunction;)I"), Classfile.ACC_STATIC, mb -> {
mb.withFlags(AccessFlag.PUBLIC, AccessFlag.FINAL, AccessFlag.STATIC)
.withCode(cb -> {
cb.aload(0);
Label nonNullLabel = cb.newLabel();
cb.if_nonnull(nonNullLabel);
cb.iconst_m1();
cb.ireturn();
cb.labelBinding(nonNullLabel);
if (labels.length == 0) {
cb.constantInstruction(0)
.ireturn();
return ;
}
cb.iload(1);
Label dflt = cb.newLabel();
record Element(Label target, Label next, Object label) {}
List<Element> cases = new ArrayList<>();
List<SwitchCase> switchCases = new ArrayList<>();
Object lastLabel = null;
for (int idx = labels.length - 1; idx >= 0; idx--) {
Object currentLabel = labels[idx];
Label target = cb.newLabel();
Label next;
if (lastLabel == null) {
next = dflt;
} else if (lastLabel.equals(currentLabel)) {
next = cases.getLast().next();
} else {
next = cases.getLast().target();
}
lastLabel = currentLabel;
cases.add(new Element(target, next, currentLabel));
switchCases.add(SwitchCase.of(idx, target));
.withMethodBody("typeSwitch",
typesSwitchDescriptor,
Classfile.ACC_FINAL | Classfile.ACC_PUBLIC | Classfile.ACC_STATIC,
cb -> {
cb.aload(0);
Label nonNullLabel = cb.newLabel();
cb.if_nonnull(nonNullLabel);
cb.iconst_m1();
cb.ireturn();
cb.labelBinding(nonNullLabel);
if (labels.length == 0) {
cb.constantInstruction(0)
.ireturn();
return ;
}
cb.iload(1);
Label dflt = cb.newLabel();
record Element(Label target, Label next, Object caseLabel) {}
List<Element> cases = new ArrayList<>();
List<SwitchCase> switchCases = new ArrayList<>();
Object lastLabel = null;
for (int idx = labels.length - 1; idx >= 0; idx--) {
Object currentLabel = labels[idx];
Label target = cb.newLabel();
Label next;
if (lastLabel == null) {
next = dflt;
} else if (lastLabel.equals(currentLabel)) {
next = cases.getLast().next();
} else {
next = cases.getLast().target();
}
cases = cases.reversed();
switchCases = switchCases.reversed();
cb.tableswitch(0, labels.length - 1, dflt, switchCases);
for (int idx = 0; idx < cases.size(); idx++) {
Element element = cases.get(idx);
Label next = element.next();
cb.labelBinding(element.target());
if (element.label() instanceof Class<?> classLabel) {
cb.aload(0);
cb.instanceof_(classLabel.describeConstable().get());
cb.ifeq(next);
} else if (element.label() instanceof EnumDesc<?> enumLabel) {
int enumIdx = enumDescs.size();
enumDescs.add(enumLabel);
cb.aload(2);
cb.constantInstruction(enumIdx);
cb.invokestatic(Integer.class.describeConstable().get(),
"valueOf",
MethodType.methodType(Integer.class,
int.class)
.describeConstable()
.get());
cb.aload(0);
cb.invokeinterface(BiFunction.class.describeConstable().get(),
"apply",
MethodType.methodType(Object.class,
Object.class,
Object.class)
.describeConstable()
.get());
cb.aload(0);
cb.if_acmpne(next);
} else if (element.label() instanceof String stringLabel) {
cb.ldc(stringLabel);
cb.aload(0);
cb.invokevirtual(Object.class.describeConstable().get(),
"equals",
MethodType.methodType(boolean.class,
Object.class)
.describeConstable()
.get());
cb.ifeq(next);
} else if (element.label() instanceof Integer integerLabel) {
Label compare = cb.newLabel();
Label notNumber = cb.newLabel();
cb.aload(0);
cb.instanceof_(Number.class.describeConstable().get());
cb.ifeq(notNumber);
cb.aload(0);
cb.checkcast(Number.class.describeConstable().get());
cb.invokevirtual(Number.class.describeConstable().get(),
"intValue",
MethodType.methodType(int.class)
.describeConstable()
.get());
cb.goto_(compare);
cb.labelBinding(notNumber);
cb.aload(0);
cb.instanceof_(Character.class.describeConstable().get());
cb.ifeq(next);
cb.aload(0);
cb.checkcast(Character.class.describeConstable().get());
cb.invokevirtual(Character.class.describeConstable().get(),
"charValue",
MethodType.methodType(char.class)
.describeConstable()
.get());
cb.labelBinding(compare);
cb.ldc(integerLabel);
cb.if_icmpne(next);
} else {
throw new InternalError("Unsupported label type: " + element.label().getClass());
}
cb.constantInstruction(idx);
cb.ireturn();
lastLabel = currentLabel;
cases.add(new Element(target, next, currentLabel));
switchCases.add(SwitchCase.of(idx, target));
}
cases = cases.reversed();
switchCases = switchCases.reversed();
cb.tableswitch(0, labels.length - 1, dflt, switchCases);
for (int idx = 0; idx < cases.size(); idx++) {
Element element = cases.get(idx);
Label next = element.next();
cb.labelBinding(element.target());
if (element.caseLabel() instanceof Class<?> classLabel &&
classLabel.describeConstable().isPresent()) {
cb.aload(0);
cb.instanceof_(classLabel.describeConstable().orElseThrow());
cb.ifeq(next);
} else if (element.caseLabel() instanceof Class<?> classLabel) {
cb.aload(3);
cb.constantInstruction(extraClassLabels.size());
cb.aaload();
cb.aload(0);
cb.invokevirtual(ConstantDescs.CD_Class,
"isInstance",
MethodTypeDesc.of(ConstantDescs.CD_boolean,
ConstantDescs.CD_Object));
cb.ifeq(next);
extraClassLabels.add(classLabel);
} else if (element.caseLabel() instanceof EnumDesc<?> enumLabel) {
int enumIdx = enumDescs.size();
enumDescs.add(enumLabel);
cb.aload(2);
cb.constantInstruction(enumIdx);
cb.invokestatic(ConstantDescs.CD_Integer,
"valueOf",
MethodTypeDesc.of(ConstantDescs.CD_Integer,
ConstantDescs.CD_int));
cb.aload(0);
cb.invokeinterface(BiPredicate.class.describeConstable().get(),
"test",
MethodTypeDesc.of(ConstantDescs.CD_boolean,
ConstantDescs.CD_Object,
ConstantDescs.CD_Object));
cb.ifeq(next);
} else if (element.caseLabel() instanceof String stringLabel) {
cb.ldc(stringLabel);
cb.aload(0);
cb.invokevirtual(ConstantDescs.CD_Object,
"equals",
MethodTypeDesc.of(ConstantDescs.CD_boolean,
ConstantDescs.CD_Object));
cb.ifeq(next);
} else if (element.caseLabel() instanceof Integer integerLabel) {
Label compare = cb.newLabel();
Label notNumber = cb.newLabel();
cb.aload(0);
cb.instanceof_(ConstantDescs.CD_Number);
cb.ifeq(notNumber);
cb.aload(0);
cb.checkcast(ConstantDescs.CD_Number);
cb.invokevirtual(ConstantDescs.CD_Number,
"intValue",
MethodTypeDesc.of(ConstantDescs.CD_int));
cb.goto_(compare);
cb.labelBinding(notNumber);
cb.aload(0);
cb.instanceof_(ConstantDescs.CD_Character);
cb.ifeq(next);
cb.aload(0);
cb.checkcast(ConstantDescs.CD_Character);
cb.invokevirtual(ConstantDescs.CD_Character,
"charValue",
MethodTypeDesc.of(ConstantDescs.CD_char));
cb.labelBinding(compare);
cb.ldc(integerLabel);
cb.if_icmpne(next);
} else {
throw new InternalError("Unsupported label type: " +
element.caseLabel().getClass());
}
cb.labelBinding(dflt);
cb.constantInstruction(cases.size());
cb.constantInstruction(idx);
cb.ireturn();
});
});
}
cb.labelBinding(dflt);
cb.constantInstruction(cases.size());
cb.ireturn();
});
});

try {
Expand All @@ -505,8 +512,10 @@ record Element(Label target, Label next, Object label) {}
MethodType.methodType(int.class,
Object.class,
int.class,
BiFunction.class));
return MethodHandles.insertArguments(typeSwitch, 2, new ResolvedEnumLabels(caller, enumDescs.toArray(s -> new EnumDesc<?>[s])));
BiPredicate.class,
Class[].class));
return MethodHandles.insertArguments(typeSwitch, 2, new ResolvedEnumLabels(caller, enumDescs.toArray(s -> new EnumDesc<?>[s])),
extraClassLabels.toArray(s -> new Class<?>[s]));
} catch (Throwable t) {
throw new IllegalArgumentException(t);
}
Expand Down
Loading

0 comments on commit fd3bc68

Please sign in to comment.