From e5b6d32b13d43254d7296c1c8044404074520828 Mon Sep 17 00:00:00 2001 From: Jan Lahoda Date: Wed, 1 Nov 2023 09:04:33 +0100 Subject: [PATCH] Attempting to speeding SwitchBootstraps. --- .../java/lang/runtime/SwitchBootstraps.java | 317 +++++++++++------- .../lang/runtime/SwitchBootstrapsTest.java | 9 +- 2 files changed, 211 insertions(+), 115 deletions(-) diff --git a/src/java.base/share/classes/java/lang/runtime/SwitchBootstraps.java b/src/java.base/share/classes/java/lang/runtime/SwitchBootstraps.java index 4743b00997030..a80b1e3257266 100644 --- a/src/java.base/share/classes/java/lang/runtime/SwitchBootstraps.java +++ b/src/java.base/share/classes/java/lang/runtime/SwitchBootstraps.java @@ -25,18 +25,31 @@ package java.lang.runtime; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; import java.lang.Enum.EnumDesc; +import java.lang.constant.ClassDesc; +import java.lang.constant.MethodTypeDesc; import java.lang.invoke.CallSite; import java.lang.invoke.ConstantCallSite; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodType; +import java.lang.reflect.AccessFlag; +import java.util.ArrayList; import java.util.List; import java.util.Objects; +import java.util.function.BiFunction; import java.util.stream.Stream; import jdk.internal.access.SharedSecrets; +import jdk.internal.classfile.Classfile; +import jdk.internal.classfile.Label; +import jdk.internal.classfile.instruction.SwitchCase; import jdk.internal.vm.annotation.Stable; +import static java.lang.invoke.MethodHandles.Lookup.ClassOption.NESTMATE; +import static java.lang.invoke.MethodHandles.Lookup.ClassOption.STRONG; import static java.util.Objects.requireNonNull; /** @@ -54,10 +67,6 @@ private SwitchBootstraps() {} private static final Object SENTINEL = new Object(); private static final MethodHandles.Lookup LOOKUP = MethodHandles.lookup(); - private static final MethodHandle INSTANCEOF_CHECK; - private static final MethodHandle INTEGER_EQ_CHECK; - private static final MethodHandle OBJECT_EQ_CHECK; - private static final MethodHandle ENUM_EQ_CHECK; private static final MethodHandle NULL_CHECK; private static final MethodHandle IS_ZERO; private static final MethodHandle CHECK_INDEX; @@ -65,15 +74,6 @@ private SwitchBootstraps() {} static { try { - INSTANCEOF_CHECK = MethodHandles.permuteArguments(LOOKUP.findVirtual(Class.class, "isInstance", - MethodType.methodType(boolean.class, Object.class)), - MethodType.methodType(boolean.class, Object.class, Class.class), 1, 0); - INTEGER_EQ_CHECK = LOOKUP.findStatic(SwitchBootstraps.class, "integerEqCheck", - MethodType.methodType(boolean.class, Object.class, Integer.class)); - OBJECT_EQ_CHECK = LOOKUP.findStatic(Objects.class, "equals", - MethodType.methodType(boolean.class, Object.class, Object.class)); - ENUM_EQ_CHECK = LOOKUP.findStatic(SwitchBootstraps.class, "enumEqCheck", - MethodType.methodType(boolean.class, Object.class, EnumDesc.class, MethodHandles.Lookup.class, ResolvedEnumLabel.class)); NULL_CHECK = LOOKUP.findStatic(Objects.class, "isNull", MethodType.methodType(boolean.class, Object.class)); IS_ZERO = LOOKUP.findStatic(SwitchBootstraps.class, "isZero", @@ -155,7 +155,9 @@ public static CallSite typeSwitch(MethodHandles.Lookup lookup, labels = labels.clone(); Stream.of(labels).forEach(SwitchBootstraps::verifyLabel); - MethodHandle target = createMethodHandleSwitch(lookup, labels); + MethodHandle target = generateInnerClass(lookup, labels); + + target = withIndexCheck(target, labels.length); return new ConstantCallSite(target); } @@ -173,79 +175,6 @@ private static void verifyLabel(Object label) { } } - /* - * Construct test chains for labels inside switch, to handle switch repeats: - * switch (idx) { - * case 0 -> if (selector matches label[0]) return 0; else if (selector matches label[1]) return 1; else ... - * case 1 -> if (selector matches label[1]) return 1; else ... - * ... - * } - */ - private static MethodHandle createRepeatIndexSwitch(MethodHandles.Lookup lookup, Object[] labels) { - MethodHandle def = MethodHandles.dropArguments(MethodHandles.constant(int.class, labels.length), 0, Object.class); - MethodHandle[] testChains = new MethodHandle[labels.length]; - List labelsList = List.of(labels).reversed(); - - for (int i = 0; i < labels.length; i++) { - MethodHandle test = def; - int idx = labels.length - 1; - List currentLabels = labelsList.subList(0, labels.length - i); - - for (int j = 0; j < currentLabels.size(); j++, idx--) { - Object currentLabel = currentLabels.get(j); - if (j + 1 < currentLabels.size() && currentLabels.get(j + 1) == currentLabel) continue; - MethodHandle currentTest; - if (currentLabel instanceof Class) { - currentTest = INSTANCEOF_CHECK; - } else if (currentLabel instanceof Integer) { - currentTest = INTEGER_EQ_CHECK; - } else if (currentLabel instanceof EnumDesc) { - currentTest = MethodHandles.insertArguments(ENUM_EQ_CHECK, 2, lookup, new ResolvedEnumLabel()); - } else { - currentTest = OBJECT_EQ_CHECK; - } - test = MethodHandles.guardWithTest(MethodHandles.insertArguments(currentTest, 1, currentLabel), - MethodHandles.dropArguments(MethodHandles.constant(int.class, idx), 0, Object.class), - test); - } - testChains[i] = MethodHandles.dropArguments(test, 0, int.class); - } - - return MethodHandles.tableSwitch(MethodHandles.dropArguments(def, 0, int.class), testChains); - } - - /* - * Construct code that maps the given selector and repeat index to a case label number: - * if (selector == null) return -1; - * else return "createRepeatIndexSwitch(labels)" - */ - private static MethodHandle createMethodHandleSwitch(MethodHandles.Lookup lookup, Object[] labels) { - MethodHandle mainTest; - MethodHandle def = MethodHandles.dropArguments(MethodHandles.constant(int.class, labels.length), 0, Object.class); - if (labels.length > 0) { - mainTest = createRepeatIndexSwitch(lookup, labels); - } else { - mainTest = MethodHandles.dropArguments(def, 0, int.class); - } - MethodHandle body = - MethodHandles.guardWithTest(MethodHandles.dropArguments(NULL_CHECK, 0, int.class), - MethodHandles.dropArguments(MethodHandles.constant(int.class, -1), 0, int.class, Object.class), - mainTest); - MethodHandle switchImpl = - MethodHandles.permuteArguments(body, MethodType.methodType(int.class, Object.class, int.class), 1, 0); - return withIndexCheck(switchImpl, labels.length); - } - - private static boolean integerEqCheck(Object value, Integer constant) { - if (value instanceof Number input && constant.intValue() == input.intValue()) { - return true; - } else if (value instanceof Character input && constant.intValue() == input.charValue()) { - return true; - } - - return false; - } - private static boolean isZero(int value) { return value == 0; } @@ -330,16 +259,16 @@ public static CallSite enumSwitch(MethodHandles.Lookup lookup, //If all labels are enum constants, construct an optimized handle for repeat index 0: //if (selector == null) return -1 //else if (idx == 0) return mappingArray[selector.ordinal()]; //mapping array created lazily - //else return "createRepeatIndexSwitch(labels)" + //else return "typeSwitch(labels)" MethodHandle body = MethodHandles.guardWithTest(MethodHandles.dropArguments(NULL_CHECK, 0, int.class), MethodHandles.dropArguments(MethodHandles.constant(int.class, -1), 0, int.class, Object.class), MethodHandles.guardWithTest(MethodHandles.dropArguments(IS_ZERO, 1, Object.class), - createRepeatIndexSwitch(lookup, labels), + generateInnerClass(lookup, labels), MethodHandles.insertArguments(MAPPED_ENUM_LOOKUP, 1, lookup, enumClass, labels, new EnumMap()))); target = MethodHandles.permuteArguments(body, MethodType.methodType(int.class, Object.class, int.class), 1, 0); } else { - target = createMethodHandleSwitch(lookup, labels); + target = generateInnerClass(lookup, labels); } target = target.asType(invocationType); @@ -389,41 +318,205 @@ private static > int mappedEnumLookup(T value, MethodHandles.L return enumMap.map[value.ordinal()]; } - private static boolean enumEqCheck(Object value, EnumDesc label, MethodHandles.Lookup lookup, ResolvedEnumLabel resolvedEnum) { - if (resolvedEnum.resolvedEnum == null) { - Object resolved; + private static MethodHandle withIndexCheck(MethodHandle target, int labelsCount) { + MethodHandle checkIndex = MethodHandles.insertArguments(CHECK_INDEX, 1, labelsCount + 1); - try { - Class clazz = label.constantType().resolveConstantDesc(lookup); + return MethodHandles.filterArguments(target, 1, checkIndex); + } - if (value.getClass() != clazz) { - return false; - } + private static final class ResolvedEnumLabels implements BiFunction { - resolved = label.resolveConstantDesc(lookup); - } catch (IllegalArgumentException | ReflectiveOperationException ex) { - resolved = SENTINEL; - } + private final MethodHandles.Lookup lookup; + private final EnumDesc[] enumDescs; + @Stable + private Object[] resolvedEnum; - resolvedEnum.resolvedEnum = resolved; + public ResolvedEnumLabels(MethodHandles.Lookup lookup, EnumDesc[] enumDescs) { + this.lookup = lookup; + this.enumDescs = enumDescs; + this.resolvedEnum = new Object[enumDescs.length]; } - return value == resolvedEnum.resolvedEnum; - } + @Override + public Object apply(Integer labelIndex, Object value) { + Object result = resolvedEnum[labelIndex]; - private static MethodHandle withIndexCheck(MethodHandle target, int labelsCount) { - MethodHandle checkIndex = MethodHandles.insertArguments(CHECK_INDEX, 1, labelsCount + 1); + if (result == null) { + try { + EnumDesc label = enumDescs[labelIndex]; + Class clazz = label.constantType().resolveConstantDesc(lookup); - return MethodHandles.filterArguments(target, 1, checkIndex); - } + if (value.getClass() != clazz) { + return SENTINEL; + } - private static final class ResolvedEnumLabel { - @Stable - public Object resolvedEnum; + result = label.resolveConstantDesc(lookup); + } catch (IllegalArgumentException | ReflectiveOperationException ex) { + result = SENTINEL; + } + + resolvedEnum[labelIndex] = result; + } + + return result; + } } private static final class EnumMap { @Stable public int[] map; } + + /* + * Construct test chains for labels inside switch, to handle switch repeats: + * switch (idx) { + * case 0 -> if (selector matches label[0]) return 0; + * case 1 -> if (selector matches label[1]) return 1; + * ... + * } + */ + @SuppressWarnings("removal") + private static MethodHandle generateInnerClass(MethodHandles.Lookup caller, Object[] labels) { + List> enumDescs = new ArrayList<>(); + + byte[] classBytes = Classfile.of().build(ClassDesc.of(typeSwitchClassName(caller.lookupClass())), clb -> { + clb.withFlags(AccessFlag.FINAL, AccessFlag.SYNTHETIC) + .withMethod("typeSwitch", MethodTypeDesc.ofDescriptor("(Ljava/lang/Object;ILjava/util/function/BiFunction;)I"), Classfile.ACC_STATIC, mb -> { + mb.withFlags(AccessFlag.PUBLIC, AccessFlag.FINAL, AccessFlag.STATIC) + .withCode(cb -> { + cb.aload(0); + Label nonNullLabel = cb.newLabel(); + cb.if_nonnull(nonNullLabel) + .iconst_m1() + .ireturn() + .labelBinding(nonNullLabel); + if (labels.length == 0) { + cb.constantInstruction(0) + .ireturn(); + } + cb.iload(1); + Label dflt = cb.newLabel(); + record Element(Label target, Label next, Object label) {} + List cases = new ArrayList<>(); + List switchCases = new ArrayList<>(); + Object lastLabel = null; + for (int idx = labels.length - 1; idx >= 0; idx--) { + Object currentLabel = labels[idx]; + Label target = cb.newLabel(); + Label next; + if (lastLabel == null) { + next = dflt; + } else if (lastLabel.equals(currentLabel)) { + next = cases.getLast().next(); + } else { + next = cases.getLast().target(); + } + lastLabel = currentLabel; + cases.add(new Element(target, next, currentLabel)); + switchCases.add(SwitchCase.of(idx, target)); + } + cases = cases.reversed(); + switchCases = switchCases.reversed(); + cb.tableswitch(0, labels.length - 1, dflt, switchCases); + for (int idx = 0; idx < cases.size(); idx++) { + Element element = cases.get(idx); + Label next = element.next(); + cb.labelBinding(element.target()); + if (element.label() instanceof Class classLabel) { + cb.aload(0); + cb.instanceof_(classLabel.describeConstable().get()); + cb.ifeq(next); + } else if (element.label() instanceof EnumDesc enumLabel) { + int enumIdx = enumDescs.size(); + enumDescs.add(enumLabel); + cb.aload(2); + cb.constantInstruction(enumIdx); + cb.invokestatic(Integer.class.describeConstable().get(), + "valueOf", + MethodType.methodType(Integer.class, + int.class) + .describeConstable() + .get()); + cb.aload(0); + cb.invokeinterface(BiFunction.class.describeConstable().get(), + "apply", + MethodType.methodType(Object.class, + Object.class, + Object.class) + .describeConstable() + .get()); + cb.aload(0); + cb.if_acmpne(next); + } else if (element.label() instanceof String stringLabel) { + cb.ldc(stringLabel); + cb.aload(0); + cb.invokevirtual(Object.class.describeConstable().get(), + "equals", + MethodType.methodType(boolean.class, + Object.class) + .describeConstable() + .get()); + cb.ifeq(next); + } else if (element.label() instanceof Integer integerLabel) { + Label compare = cb.newLabel(); + Label notNumber = cb.newLabel(); + cb.aload(0); + cb.instanceof_(Number.class.describeConstable().get()); + cb.ifeq(notNumber); + cb.aload(0); + cb.checkcast(Number.class.describeConstable().get()); + cb.invokevirtual(Number.class.describeConstable().get(), + "intValue", + MethodType.methodType(int.class) + .describeConstable() + .get()); + cb.goto_(compare); + cb.labelBinding(notNumber); + cb.aload(0); + cb.instanceof_(Character.class.describeConstable().get()); + cb.ifeq(next); + cb.aload(0); + cb.checkcast(Character.class.describeConstable().get()); + cb.invokevirtual(Character.class.describeConstable().get(), + "charValue", + MethodType.methodType(char.class) + .describeConstable() + .get()); + cb.labelBinding(compare); + cb.ldc(integerLabel); + cb.if_icmpne(next); + } else { + throw new UnsupportedOperationException("not supported label: " + element.label().getClass()); + } + cb.constantInstruction(idx); + cb.ireturn(); + } + cb.labelBinding(dflt); + cb.constantInstruction(cases.size()); + cb.ireturn(); + }); + }); + }); + + try { + // this class is linked at the indy callsite; so define a hidden nestmate + MethodHandles.Lookup lookup; + lookup = caller.defineHiddenClass(classBytes, true, NESTMATE, STRONG); + MethodHandle typeSwitch = lookup.findStatic(lookup.lookupClass(), "typeSwitch", MethodType.methodType(int.class, Object.class, int.class, BiFunction.class)); + return MethodHandles.insertArguments(typeSwitch, 2, new ResolvedEnumLabels(caller, enumDescs.toArray(s -> new EnumDesc[s]))); + } catch (Throwable t) { + throw new IllegalArgumentException(t); + } + } + + private static int idx = 0; + //based on src/java.base/share/classes/java/lang/invoke/InnerClassLambdaMetafactory.java: + private static String typeSwitchClassName(Class targetClass) { + String name = targetClass.getName(); + if (targetClass.isHidden()) { + // use the original class name + name = name.replace('/', '_'); + } + return name + "$$TypeSwitch"; + } } diff --git a/test/jdk/java/lang/runtime/SwitchBootstrapsTest.java b/test/jdk/java/lang/runtime/SwitchBootstrapsTest.java index cbd728e33d5d1..6fada33f2b652 100644 --- a/test/jdk/java/lang/runtime/SwitchBootstrapsTest.java +++ b/test/jdk/java/lang/runtime/SwitchBootstrapsTest.java @@ -108,9 +108,12 @@ public void testTypes() throws Throwable { } catch (IllegalArgumentException ex) { //OK } - testType("", 0, 0, String.class, String.class, String.class); - testType("", 1, 1, String.class, String.class, String.class); - testType("", 2, 2, String.class, String.class, String.class); + testType("", 0, 0, String.class, String.class, String.class, String.class, String.class); + testType("", 1, 1, String.class, String.class, String.class, String.class, String.class); + testType("", 2, 2, String.class, String.class, String.class, String.class, String.class); + testType("", 3, 3, String.class, String.class, String.class, String.class, String.class); + testType("", 3, 3, String.class, String.class, String.class, String.class, String.class); + testType("", 4, 4, String.class, String.class, String.class, String.class, String.class); testType("", 0, 0); }