From f03fb4e4ee4d59ed692d0c26ddce260511f544e7 Mon Sep 17 00:00:00 2001 From: Jatin Bhateja Date: Tue, 19 Sep 2023 16:42:00 +0000 Subject: [PATCH] 8308363: Initial compiler support for FP16 scalar operations. Reviewed-by: sviswanathan --- make/common/JavaCompilation.gmk | 2 +- src/hotspot/cpu/x86/assembler_x86.cpp | 34 +++- src/hotspot/cpu/x86/assembler_x86.hpp | 6 + src/hotspot/cpu/x86/vm_version_x86.cpp | 5 + src/hotspot/cpu/x86/vm_version_x86.hpp | 8 +- src/hotspot/cpu/x86/x86.ad | 60 ++++++ src/hotspot/share/adlc/formssel.cpp | 2 +- .../share/classfile/classFileParser.cpp | 20 +- .../share/classfile/classFileParser.hpp | 4 + src/hotspot/share/classfile/vmClassMacros.hpp | 1 + src/hotspot/share/classfile/vmIntrinsics.hpp | 6 + src/hotspot/share/classfile/vmSymbols.hpp | 2 + src/hotspot/share/oops/inlineKlass.cpp | 3 + src/hotspot/share/opto/addnode.hpp | 10 +- src/hotspot/share/opto/c2compiler.cpp | 4 +- src/hotspot/share/opto/classes.hpp | 4 + src/hotspot/share/opto/convertnode.cpp | 32 ++++ src/hotspot/share/opto/convertnode.hpp | 19 ++ src/hotspot/share/opto/inlinetypenode.hpp | 6 +- src/hotspot/share/opto/library_call.cpp | 30 +++ src/hotspot/share/opto/library_call.hpp | 1 + src/hotspot/share/opto/superword.cpp | 10 + src/hotspot/share/opto/vectornode.cpp | 21 +- src/hotspot/share/opto/vectornode.hpp | 12 ++ src/hotspot/share/runtime/vmStructs.cpp | 4 + .../share/classes/java/lang/Float16.java | 180 ++++++++++++++++++ .../share/classes/jdk/vm/ci/amd64/AMD64.java | 1 + .../intrinsics/float16/TestFP16ScalarAdd.java | 77 ++++++++ .../compiler/lib/ir_framework/IRNode.java | 20 ++ .../ir_framework/test/IREncodingPrinter.java | 1 + .../vectorization/TestFloat16VectorSum.java | 76 ++++++++ .../lang/Float16/FP16ReductionOperations.java | 135 +++++++++++++ .../lang/Float16/FP16ScalarOperations.java | 97 ++++++++++ 33 files changed, 880 insertions(+), 13 deletions(-) create mode 100644 src/java.base/share/classes/java/lang/Float16.java create mode 100644 
test/hotspot/jtreg/compiler/intrinsics/float16/TestFP16ScalarAdd.java create mode 100644 test/hotspot/jtreg/compiler/vectorization/TestFloat16VectorSum.java create mode 100644 test/jdk/java/lang/Float16/FP16ReductionOperations.java create mode 100644 test/jdk/java/lang/Float16/FP16ScalarOperations.java diff --git a/make/common/JavaCompilation.gmk b/make/common/JavaCompilation.gmk index ff7c90e5785..190a19905c9 100644 --- a/make/common/JavaCompilation.gmk +++ b/make/common/JavaCompilation.gmk @@ -274,7 +274,7 @@ define SetupJavaCompilationBody PARANOIA_FLAGS := -implicit:none -Xprefer:source -XDignore.symbol.file=true -encoding ascii $1_FLAGS += -g -Xlint:all $$($1_TARGET_RELEASE) $$(PARANOIA_FLAGS) $$(JAVA_WARNINGS_ARE_ERRORS) - $1_FLAGS += $$($1_JAVAC_FLAGS) + $1_FLAGS += $$($1_JAVAC_FLAGS) -XDenablePrimitiveClasses ifneq ($$($1_DISABLED_WARNINGS), ) $1_FLAGS += -Xlint:$$(call CommaList, $$(addprefix -, $$($1_DISABLED_WARNINGS))) diff --git a/src/hotspot/cpu/x86/assembler_x86.cpp b/src/hotspot/cpu/x86/assembler_x86.cpp index 0219a5419c1..408c78cc510 100644 --- a/src/hotspot/cpu/x86/assembler_x86.cpp +++ b/src/hotspot/cpu/x86/assembler_x86.cpp @@ -3144,6 +3144,22 @@ void Assembler::vmovdqu(XMMRegister dst, XMMRegister src) { emit_int16(0x6F, (0xC0 | encode)); } +void Assembler::vmovw(XMMRegister dst, Register src) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + InstructionAttr attributes(AVX_128bit, false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ false); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), 0, src->encoding(), VEX_SIMD_66, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x6E, (0xC0 | encode)); +} + +void Assembler::vmovw(Register dst, XMMRegister src) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + InstructionAttr attributes(AVX_128bit, false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ false); + 
attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(src->encoding(), 0, dst->encoding(), VEX_SIMD_66, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x7E, (0xC0 | encode)); +} + void Assembler::vmovdqu(XMMRegister dst, Address src) { assert(UseAVX > 0, ""); InstructionMark im(this); @@ -7311,6 +7327,22 @@ void Assembler::vpaddq(XMMRegister dst, XMMRegister nds, Address src, int vector emit_operand(dst, src, 0); } +void Assembler::evaddph(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + InstructionAttr attributes(vector_len, false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ true); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_NONE, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x58, (0xC0 | encode)); +} + +void Assembler::evaddsh(XMMRegister dst, XMMRegister nds, XMMRegister src) { + assert(VM_Version::supports_avx512_fp16(), "requires AVX512-FP16"); + InstructionAttr attributes(AVX_128bit, false, /* legacy_mode */ false, /* no_mask_reg */ true, /* uses_vl */ false); + attributes.set_is_evex_instruction(); + int encode = vex_prefix_and_encode(dst->encoding(), nds->encoding(), src->encoding(), VEX_SIMD_F3, VEX_OPCODE_MAP5, &attributes); + emit_int16(0x58, (0xC0 | encode)); +} + void Assembler::psubb(XMMRegister dst, XMMRegister src) { NOT_LP64(assert(VM_Version::supports_sse2(), "")); InstructionAttr attributes(AVX_128bit, /* rex_w */ false, /* legacy_mode */ _legacy_mode_bw, /* no_mask_reg */ true, /* uses_vl */ true); @@ -11480,7 +11512,7 @@ void Assembler::evex_prefix(bool vex_r, bool vex_b, bool vex_x, bool evex_r, boo int byte2 = (vex_r ? VEX_R : 0) | (vex_x ? VEX_X : 0) | (vex_b ? VEX_B : 0) | (evex_r ? 
EVEX_Rb : 0); byte2 = (~byte2) & 0xF0; // confine opc opcode extensions in mm bits to lower two bits - // of form {0F, 0F_38, 0F_3A} + // of form {0F, 0F_38, 0F_3A, MAP5} byte2 |= opc; // P1: byte 3 as Wvvvv1pp diff --git a/src/hotspot/cpu/x86/assembler_x86.hpp b/src/hotspot/cpu/x86/assembler_x86.hpp index 25101c0f052..c59b5fc02c3 100644 --- a/src/hotspot/cpu/x86/assembler_x86.hpp +++ b/src/hotspot/cpu/x86/assembler_x86.hpp @@ -547,6 +547,7 @@ class Assembler : public AbstractAssembler { VEX_OPCODE_0F = 0x1, VEX_OPCODE_0F_38 = 0x2, VEX_OPCODE_0F_3A = 0x3, + VEX_OPCODE_MAP5 = 0x5, VEX_OPCODE_MASK = 0x1F }; @@ -1649,6 +1650,9 @@ class Assembler : public AbstractAssembler { void movsbl(Register dst, Address src); void movsbl(Register dst, Register src); + void vmovw(XMMRegister dst, Register src); + void vmovw(Register dst, XMMRegister src); + #ifdef _LP64 void movsbq(Register dst, Address src); void movsbq(Register dst, Register src); @@ -2394,6 +2398,8 @@ class Assembler : public AbstractAssembler { void vpaddw(XMMRegister dst, XMMRegister nds, Address src, int vector_len); void vpaddd(XMMRegister dst, XMMRegister nds, Address src, int vector_len); void vpaddq(XMMRegister dst, XMMRegister nds, Address src, int vector_len); + void evaddsh(XMMRegister dst, XMMRegister nds, XMMRegister src); + void evaddph(XMMRegister dst, XMMRegister nds, XMMRegister src, int vector_len); // Leaf level assembler routines for masked operations. 
void evpaddb(XMMRegister dst, KRegister mask, XMMRegister nds, XMMRegister src, bool merge, int vector_len); diff --git a/src/hotspot/cpu/x86/vm_version_x86.cpp b/src/hotspot/cpu/x86/vm_version_x86.cpp index 8e19c09f9b4..46e84e4ecb3 100644 --- a/src/hotspot/cpu/x86/vm_version_x86.cpp +++ b/src/hotspot/cpu/x86/vm_version_x86.cpp @@ -949,6 +949,7 @@ void VM_Version::get_processor_features() { _features &= ~CPU_AVX512_VBMI2; _features &= ~CPU_AVX512_BITALG; _features &= ~CPU_AVX512_IFMA; + _features &= ~CPU_AVX512_FP16; } if (UseAVX < 2) @@ -982,6 +983,7 @@ void VM_Version::get_processor_features() { _features &= ~CPU_GFNI; _features &= ~CPU_AVX512_BITALG; _features &= ~CPU_AVX512_IFMA; + _features &= ~CPU_AVX512_FP16; } } @@ -3017,6 +3019,9 @@ uint64_t VM_Version::feature_flags() { } if (_cpuid_info.sef_cpuid7_edx.bits.serialize != 0) result |= CPU_SERIALIZE; + + if (_cpuid_info.sef_cpuid7_edx.bits.avx512_fp16 != 0) + result |= CPU_AVX512_FP16; } // ZX features. diff --git a/src/hotspot/cpu/x86/vm_version_x86.hpp b/src/hotspot/cpu/x86/vm_version_x86.hpp index cb9e806999b..f1a62295c8f 100644 --- a/src/hotspot/cpu/x86/vm_version_x86.hpp +++ b/src/hotspot/cpu/x86/vm_version_x86.hpp @@ -275,7 +275,9 @@ class VM_Version : public Abstract_VM_Version { serialize : 1, : 5, cet_ibt : 1, - : 11; + : 2, + avx512_fp16 : 1, + : 8; } bits; }; @@ -390,7 +392,8 @@ class VM_Version : public Abstract_VM_Version { decl(OSPKE, "ospke", 55) /* OS enables protection keys */ \ decl(CET_IBT, "cet_ibt", 56) /* Control Flow Enforcement - Indirect Branch Tracking */ \ decl(CET_SS, "cet_ss", 57) /* Control Flow Enforcement - Shadow Stack */ \ - decl(AVX512_IFMA, "avx512_ifma", 58) /* Integer Vector FMA instructions*/ + decl(AVX512_IFMA, "avx512_ifma", 58) /* Integer Vector FMA instructions*/ \ + decl(AVX512_FP16, "avx512_fp16", 59) /* AVX512 FP16 ISA support*/ #define DECLARE_CPU_FEATURE_FLAG(id, name, bit) CPU_##id = (1ULL << bit), CPU_FEATURE_FLAGS(DECLARE_CPU_FEATURE_FLAG) @@ -696,6 +699,7 
@@ class VM_Version : public Abstract_VM_Version { static bool supports_avx512_bitalg() { return (_features & CPU_AVX512_BITALG) != 0; } static bool supports_avx512_vbmi() { return (_features & CPU_AVX512_VBMI) != 0; } static bool supports_avx512_vbmi2() { return (_features & CPU_AVX512_VBMI2) != 0; } + static bool supports_avx512_fp16() { return (_features & CPU_AVX512_FP16) != 0; } static bool supports_hv() { return (_features & CPU_HV) != 0; } static bool supports_serialize() { return (_features & CPU_SERIALIZE) != 0; } static bool supports_f16c() { return (_features & CPU_F16C) != 0; } diff --git a/src/hotspot/cpu/x86/x86.ad b/src/hotspot/cpu/x86/x86.ad index 20fdac3f978..e4daf90d3ba 100644 --- a/src/hotspot/cpu/x86/x86.ad +++ b/src/hotspot/cpu/x86/x86.ad @@ -1451,6 +1451,13 @@ bool Matcher::match_rule_supported(int opcode) { return false; } break; + case Op_AddHF: + case Op_ReinterpretS2HF: + case Op_ReinterpretHF2S: + if (!VM_Version::supports_avx512_fp16()) { + return false; + } + break; case Op_VectorLoadShuffle: case Op_VectorRearrange: case Op_MulReductionVI: @@ -1722,6 +1729,11 @@ bool Matcher::match_rule_supported_vector(int opcode, int vlen, BasicType bt) { // * 128bit vroundpd instruction is present only in AVX1 int size_in_bits = vlen * type2aelembytes(bt) * BitsPerByte; switch (opcode) { + case Op_AddVHF: + if (!VM_Version::supports_avx512_fp16()) { + return false; + } + break; case Op_AbsVF: case Op_NegVF: if ((vlen == 16) && (VM_Version::supports_avx512dq() == false)) { @@ -10149,4 +10161,52 @@ instruct DoubleClassCheck_reg_reg_vfpclass(rRegI dst, regD src, kReg ktmp, rFlag ins_pipe(pipe_slow); %} +instruct reinterpretS2H (regF dst, rRegI src) +%{ + match(Set dst (ReinterpretS2HF src)); + format %{ "vmovw $dst, $src" %} + ins_encode %{ + __ vmovw($dst$$XMMRegister, $src$$Register); + %} + ins_pipe(pipe_slow); +%} + +instruct convF2HFAndS2HF (regF dst, regF src) +%{ + match(Set dst (ReinterpretS2HF (ConvF2HF src))); + format %{ "convF2HFAndS2HF 
$dst, $src" %} + ins_encode %{ + __ vcvtps2ph($dst$$XMMRegister, $src$$XMMRegister, 0x04, Assembler::AVX_128bit); + %} + ins_pipe(pipe_slow); +%} + +instruct reinterpretH2S (rRegI dst, regF src) +%{ + match(Set dst (ReinterpretHF2S src)); + format %{ "vmovw $dst, $src" %} + ins_encode %{ + __ vmovw($dst$$Register, $src$$XMMRegister); + %} + ins_pipe(pipe_slow); +%} +instruct addFP16_scalar (regF dst, regF src1, regF src2) +%{ + match(Set dst (AddHF src1 src2)); + format %{ "vaddsh $dst, $src1, $src2" %} + ins_encode %{ + __ evaddsh($dst$$XMMRegister, $src1$$XMMRegister, $src2$$XMMRegister); + %} + ins_pipe(pipe_slow); +%} +instruct vaddVHF (vec dst, vec src1, vec src2) +%{ + match(Set dst (AddVHF src1 src2)); + format %{ "vaddph $dst, $src1, $src2" %} + ins_encode %{ + int vlen_enc = vector_length_encoding(this); + __ evaddph($dst$$XMMRegister, $src1$$XMMRegister, $src2$$XMMRegister, vlen_enc); + %} + ins_pipe(pipe_slow); +%} diff --git a/src/hotspot/share/adlc/formssel.cpp b/src/hotspot/share/adlc/formssel.cpp index fedda555a52..870607537ea 100644 --- a/src/hotspot/share/adlc/formssel.cpp +++ b/src/hotspot/share/adlc/formssel.cpp @@ -4201,7 +4201,7 @@ Form::DataType MatchRule::is_ideal_load() const { bool MatchRule::is_vector() const { static const char *vector_list[] = { - "AddVB","AddVS","AddVI","AddVL","AddVF","AddVD", + "AddVB","AddVHF", "AddVS","AddVI","AddVL","AddVF","AddVD", "SubVB","SubVS","SubVI","SubVL","SubVF","SubVD", "MulVB","MulVS","MulVI","MulVL","MulVF","MulVD", "DivVF","DivVD", diff --git a/src/hotspot/share/classfile/classFileParser.cpp b/src/hotspot/share/classfile/classFileParser.cpp index be31cb11a68..ad94914ab1e 100644 --- a/src/hotspot/share/classfile/classFileParser.cpp +++ b/src/hotspot/share/classfile/classFileParser.cpp @@ -4695,6 +4695,22 @@ static void check_illegal_static_method(const InstanceKlass* this_klass, TRAPS) } } +// utility function to skip over internal jdk primitive classes used to override the need for passing +// an 
explicit JVM flag EnablePrimitiveClasses. +bool ClassFileParser::is_jdk_internal_class(const Symbol* class_name) const { + if (vmSymbols::java_lang_Float16() == class_name) { + return (EnablePrimitiveClasses = true); + } + return false; +} + +bool ClassFileParser::is_jdk_internal_class_sig(const char* sig) const { + if (strstr(sig, vmSymbols::java_lang_Float16_signature()->as_C_string())) { + return true; + } + return false; +} + // utility methods for format checking void ClassFileParser::verify_legal_class_modifiers(jint flags, const char* name, bool is_Object, TRAPS) const { @@ -4725,7 +4741,7 @@ void ClassFileParser::verify_legal_class_modifiers(jint flags, const char* name, return; } - if (is_primitive_class && !EnablePrimitiveClasses) { + if (is_primitive_class && !is_jdk_internal_class(_class_name) && !EnablePrimitiveClasses) { ResourceMark rm(THREAD); Exceptions::fthrow( THREAD_AND_LOCATION, @@ -5157,7 +5173,7 @@ const char* ClassFileParser::skip_over_field_signature(const char* signature, case JVM_SIGNATURE_PRIMITIVE_OBJECT: // Can't enable this check fully until JDK upgrades the bytecode generators (TODO: JDK-8270852). // For now, compare to class file version 51 so old verifier doesn't see Q signatures. 
- if ( (_major_version < 51 /* CONSTANT_CLASS_DESCRIPTORS */ ) || (!EnablePrimitiveClasses)) { + if ( (_major_version < 51 /* CONSTANT_CLASS_DESCRIPTORS */ ) || (!EnablePrimitiveClasses && !is_jdk_internal_class_sig(signature))) { classfile_parse_error("Class name contains illegal Q-signature " "in descriptor in class file %s, requires option -XX:+EnablePrimitiveClasses", CHECK_0); diff --git a/src/hotspot/share/classfile/classFileParser.hpp b/src/hotspot/share/classfile/classFileParser.hpp index 49971dd0b14..19c03ab3989 100644 --- a/src/hotspot/share/classfile/classFileParser.hpp +++ b/src/hotspot/share/classfile/classFileParser.hpp @@ -217,6 +217,10 @@ class ClassFileParser { bool _has_vanilla_constructor; int _max_bootstrap_specifier_index; // detects BSS values + bool is_jdk_internal_class(const Symbol* class_name) const; + + bool is_jdk_internal_class_sig(const char* sig) const; + void parse_stream(const ClassFileStream* const stream, TRAPS); void mangle_hidden_class_name(InstanceKlass* const ik); diff --git a/src/hotspot/share/classfile/vmClassMacros.hpp b/src/hotspot/share/classfile/vmClassMacros.hpp index 89aaa1298e1..268c6bbd3d3 100644 --- a/src/hotspot/share/classfile/vmClassMacros.hpp +++ b/src/hotspot/share/classfile/vmClassMacros.hpp @@ -171,6 +171,7 @@ do_klass(Boolean_klass, java_lang_Boolean ) \ do_klass(Character_klass, java_lang_Character ) \ do_klass(Float_klass, java_lang_Float ) \ + do_klass(Float16_klass, java_lang_Float16 ) \ do_klass(Double_klass, java_lang_Double ) \ do_klass(Byte_klass, java_lang_Byte ) \ do_klass(Short_klass, java_lang_Short ) \ diff --git a/src/hotspot/share/classfile/vmIntrinsics.hpp b/src/hotspot/share/classfile/vmIntrinsics.hpp index 34b7c077653..1230b82a8b5 100644 --- a/src/hotspot/share/classfile/vmIntrinsics.hpp +++ b/src/hotspot/share/classfile/vmIntrinsics.hpp @@ -195,6 +195,12 @@ class methodHandle; do_intrinsic(_dsignum, java_lang_Math, signum_name, double_double_signature, F_S) \ do_intrinsic(_fsignum, 
java_lang_Math, signum_name, float_float_signature, F_S) \ \ + \ + /* Float16 intrinsics, similar to what we have in Math. */ \ + do_intrinsic(_sum_float16, java_lang_Float16, sum_name, floa16_float16_signature, F_S) \ + do_name(sum_name, "sum") \ + do_signature(floa16_float16_signature, "(Qjava/lang/Float16;Qjava/lang/Float16;)Qjava/lang/Float16;") \ + \ /* StrictMath intrinsics, similar to what we have in Math. */ \ do_intrinsic(_min_strict, java_lang_StrictMath, min_name, int2_int_signature, F_S) \ do_intrinsic(_max_strict, java_lang_StrictMath, max_name, int2_int_signature, F_S) \ diff --git a/src/hotspot/share/classfile/vmSymbols.hpp b/src/hotspot/share/classfile/vmSymbols.hpp index 623c81f2d4c..0644287b6dd 100644 --- a/src/hotspot/share/classfile/vmSymbols.hpp +++ b/src/hotspot/share/classfile/vmSymbols.hpp @@ -81,6 +81,8 @@ class SerializeClosure; template(java_lang_Character_CharacterCache, "java/lang/Character$CharacterCache") \ template(java_lang_CharacterDataLatin1, "java/lang/CharacterDataLatin1") \ template(java_lang_Float, "java/lang/Float") \ + template(java_lang_Float16, "java/lang/Float16") \ + template(java_lang_Float16_signature, "Qjava/lang/Float16;") \ template(java_lang_Double, "java/lang/Double") \ template(java_lang_Byte, "java/lang/Byte") \ template(java_lang_Byte_ByteCache, "java/lang/Byte$ByteCache") \ diff --git a/src/hotspot/share/oops/inlineKlass.cpp b/src/hotspot/share/oops/inlineKlass.cpp index 9b9d5e3ad52..c40460c0a1d 100644 --- a/src/hotspot/share/oops/inlineKlass.cpp +++ b/src/hotspot/share/oops/inlineKlass.cpp @@ -542,6 +542,9 @@ void InlineKlass::restore_unshareable_info(ClassLoaderData* loader_data, Handle if (value_array_klasses() != nullptr) { value_array_klasses()->restore_unshareable_info(ClassLoaderData::the_null_class_loader_data(), Handle(), CHECK); } + if (vmSymbols::java_lang_Float16() == name()) { + EnablePrimitiveClasses = true; + } } // oop verify diff --git a/src/hotspot/share/opto/addnode.hpp 
b/src/hotspot/share/opto/addnode.hpp index 709958b6abf..b34a43bc7e9 100644 --- a/src/hotspot/share/opto/addnode.hpp +++ b/src/hotspot/share/opto/addnode.hpp @@ -115,7 +115,7 @@ class AddLNode : public AddNode { }; //------------------------------AddFNode--------------------------------------- -// Add 2 floats +// Add 2 single-precision floats class AddFNode : public AddNode { public: AddFNode( Node *in1, Node *in2 ) : AddNode(in1,in2) {} @@ -131,6 +131,14 @@ class AddFNode : public AddNode { virtual uint ideal_reg() const { return Op_RegF; } }; +//------------------------------AddHFNode--------------------------------------- +// Add 2 half-precision floats +class AddHFNode : public AddFNode { +public: + AddHFNode( Node *in1, Node *in2 ) : AddFNode(in1,in2) {} + virtual int Opcode() const; +}; + //------------------------------AddDNode--------------------------------------- // Add 2 doubles class AddDNode : public AddNode { diff --git a/src/hotspot/share/opto/c2compiler.cpp b/src/hotspot/share/opto/c2compiler.cpp index 2be8dbd2864..4f18a8c77a2 100644 --- a/src/hotspot/share/opto/c2compiler.cpp +++ b/src/hotspot/share/opto/c2compiler.cpp @@ -749,7 +749,9 @@ bool C2Compiler::is_intrinsic_supported(const methodHandle& method) { case vmIntrinsics::_Preconditions_checkLongIndex: case vmIntrinsics::_getObjectSize: break; - + case vmIntrinsics::_sum_float16: + if (!Matcher::match_rule_supported(Op_AddHF)) return false; + break; case vmIntrinsics::_VectorCompressExpand: case vmIntrinsics::_VectorUnaryOp: case vmIntrinsics::_VectorBinaryOp: diff --git a/src/hotspot/share/opto/classes.hpp b/src/hotspot/share/opto/classes.hpp index ad7f26db27e..dd771cb35b2 100644 --- a/src/hotspot/share/opto/classes.hpp +++ b/src/hotspot/share/opto/classes.hpp @@ -36,6 +36,7 @@ macro(AddF) macro(AddI) macro(AddL) macro(AddP) +macro(AddHF) macro(Allocate) macro(AllocateArray) macro(AndI) @@ -373,6 +374,7 @@ macro(XorL) macro(InlineType) macro(Vector) macro(AddVB) +macro(AddVHF) macro(AddVS) macro(AddVI) 
macro(AddReductionVI) @@ -486,6 +488,8 @@ macro(ExtractF) macro(ExtractD) macro(Digit) macro(LowerCase) +macro(ReinterpretS2HF) +macro(ReinterpretHF2S) macro(UpperCase) macro(Whitespace) macro(VectorBox) diff --git a/src/hotspot/share/opto/convertnode.cpp b/src/hotspot/share/opto/convertnode.cpp index 91803f394ad..e1884431cd7 100644 --- a/src/hotspot/share/opto/convertnode.cpp +++ b/src/hotspot/share/opto/convertnode.cpp @@ -853,3 +853,35 @@ const Type* RoundDoubleModeNode::Value(PhaseGVN* phase) const { return Type::DOUBLE; } //============================================================================= + +const Type* ReinterpretS2HFNode::Value(PhaseGVN* phase) const { + const Type* type = phase->type( in(1) ); + // Convert FP16 constant value to Float constant value, this will allow + // further constant folding to be done at float granularity by value routines + // of FP16 IR nodes. + if (type->isa_int() && type->is_int()->is_con()) { + jshort hfval = type->is_int()->get_con(); + jfloat fval = StubRoutines::hf2f(hfval); + return TypeF::make(fval); + } + return Type::FLOAT; +} + +Node* ReinterpretS2HFNode::Identity(PhaseGVN* phase) { + if (in(1)->Opcode() == Op_ReinterpretHF2S) { + assert(in(1)->in(1)->bottom_type()->isa_float(), ""); + return in(1)->in(1); + } + return this; +} + +const Type* ReinterpretHF2SNode::Value(PhaseGVN* phase) const { + const Type* type = phase->type( in(1) ); + // Convert Float constant value to FP16 constant value. 
+ if (type->isa_float_constant()) { + jfloat fval = type->is_float_constant()->_f; + jshort hfval = StubRoutines::f2hf(fval); + return TypeInt::make(hfval); + } + return TypeInt::SHORT; +} diff --git a/src/hotspot/share/opto/convertnode.hpp b/src/hotspot/share/opto/convertnode.hpp index 45277aead8d..59680389f7b 100644 --- a/src/hotspot/share/opto/convertnode.hpp +++ b/src/hotspot/share/opto/convertnode.hpp @@ -173,6 +173,25 @@ class ConvI2FNode : public Node { virtual uint ideal_reg() const { return Op_RegF; } }; +class ReinterpretS2HFNode : public Node { + public: + ReinterpretS2HFNode(Node* in1) : Node(0, in1) {} + virtual int Opcode() const; + virtual const Type* bottom_type() const { return Type::FLOAT; } + virtual const Type* Value(PhaseGVN* phase) const; + virtual Node* Identity(PhaseGVN* phase); + virtual uint ideal_reg() const { return Op_RegF; } +}; + +class ReinterpretHF2SNode : public Node { + public: + ReinterpretHF2SNode( Node *in1 ) : Node(0,in1) {} + virtual int Opcode() const; + virtual const Type* Value(PhaseGVN* phase) const; + virtual const Type* bottom_type() const { return TypeInt::SHORT; } + virtual uint ideal_reg() const { return Op_RegI; } +}; + class RoundFNode : public Node { public: RoundFNode( Node *in1 ) : Node(0,in1) {} diff --git a/src/hotspot/share/opto/inlinetypenode.hpp b/src/hotspot/share/opto/inlinetypenode.hpp index 4ef042a090a..872a391b759 100644 --- a/src/hotspot/share/opto/inlinetypenode.hpp +++ b/src/hotspot/share/opto/inlinetypenode.hpp @@ -50,9 +50,6 @@ class InlineTypeNode : public TypeNode { // Nodes are connected in increasing order of the index of the field they correspond to. 
}; - // Get the klass defining the field layout of the inline type - ciInlineKlass* inline_klass() const { return type()->inline_klass(); } - void make_scalar_in_safepoint(PhaseIterGVN* igvn, Unique_Node_List& worklist, SafePointNode* sfpt); const TypePtr* field_adr_type(Node* base, int offset, ciInstanceKlass* holder, DecoratorSet decorators, PhaseGVN& gvn) const; @@ -77,6 +74,9 @@ class InlineTypeNode : public TypeNode { static InlineTypeNode* make_from_flattened_impl(GraphKit* kit, ciInlineKlass* vk, Node* obj, Node* ptr, ciInstanceKlass* holder, int holder_offset, DecoratorSet decorators, GrowableArray& visited); public: + // Get the klass defining the field layout of the inline type + ciInlineKlass* inline_klass() const { return type()->inline_klass(); } + // Create with default field values static InlineTypeNode* make_default(PhaseGVN& gvn, ciInlineKlass* vk); // Create uninitialized diff --git a/src/hotspot/share/opto/library_call.cpp b/src/hotspot/share/opto/library_call.cpp index 7604718ccf8..3b089324b99 100644 --- a/src/hotspot/share/opto/library_call.cpp +++ b/src/hotspot/share/opto/library_call.cpp @@ -544,6 +544,8 @@ bool LibraryCallKit::try_to_inline(int predicate) { case vmIntrinsics::_floatToFloat16: case vmIntrinsics::_float16ToFloat: return inline_fp_conversions(intrinsic_id()); + case vmIntrinsics::_sum_float16: return inline_fp16_operations(intrinsic_id()); + case vmIntrinsics::_floatIsFinite: case vmIntrinsics::_floatIsInfinite: case vmIntrinsics::_doubleIsFinite: @@ -4894,6 +4896,34 @@ bool LibraryCallKit::inline_native_Reflection_getCallerClass() { return false; // bail-out; let JVM_GetCallerClass do the work } +bool LibraryCallKit::inline_fp16_operations(vmIntrinsics::ID id) { + if (!Matcher::match_rule_supported(Op_ReinterpretS2HF) || + !Matcher::match_rule_supported(Op_ReinterpretHF2S)) { + return false; + } + + Node* result = nullptr; + Node* val1 = argument(0); // first operand (static intrinsic, no receiver) + Node* val2 = argument(1); // second operand + 
assert(val1->is_InlineType() && val2->is_InlineType(), ""); + + Node* fld1 = _gvn.transform(new ReinterpretS2HFNode(val1->as_InlineType()->field_value(0))); + Node* fld2 = _gvn.transform(new ReinterpretS2HFNode(val2->as_InlineType()->field_value(0))); + + switch (id) { + case vmIntrinsics::_sum_float16: result = _gvn.transform(new AddHFNode(fld1, fld2)); break; + + default: + fatal_unexpected_iid(id); + break; + } + InlineTypeNode* box = InlineTypeNode::make_uninitialized(_gvn, val1->as_InlineType()->inline_klass(), true); + Node* short_result = _gvn.transform(new ReinterpretHF2SNode(result)); + box->set_field_value(0, short_result); + set_result(_gvn.transform(box)); + return true; +} + bool LibraryCallKit::inline_fp_conversions(vmIntrinsics::ID id) { Node* arg = argument(0); Node* result = nullptr; diff --git a/src/hotspot/share/opto/library_call.hpp b/src/hotspot/share/opto/library_call.hpp index bd0ed383367..8b6a7f03399 100644 --- a/src/hotspot/share/opto/library_call.hpp +++ b/src/hotspot/share/opto/library_call.hpp @@ -303,6 +303,7 @@ class LibraryCallKit : public GraphKit { bool inline_unsafe_load_store(BasicType type, LoadStoreKind kind, AccessKind access_kind); bool inline_unsafe_fence(vmIntrinsics::ID id); bool inline_onspinwait(); + bool inline_fp16_operations(vmIntrinsics::ID id); bool inline_fp_conversions(vmIntrinsics::ID id); bool inline_fp_range_check(vmIntrinsics::ID id); bool inline_number_methods(vmIntrinsics::ID id); diff --git a/src/hotspot/share/opto/superword.cpp b/src/hotspot/share/opto/superword.cpp index f665089d0bf..36a79123397 100644 --- a/src/hotspot/share/opto/superword.cpp +++ b/src/hotspot/share/opto/superword.cpp @@ -2903,6 +2903,13 @@ bool SuperWord::output() { int vopc = VectorCastNode::opcode(opc, in->bottom_type()->is_vect()->element_basic_type()); vn = VectorCastNode::make(vopc, in, bt, vlen); vlen_in_bytes = vn->as_Vector()->length_in_bytes(); + } else if (opc == Op_ReinterpretS2HF || opc == Op_ReinterpretHF2S) { + 
assert(n->req() == 2, "only one input expected"); + BasicType bt = velt_basic_type(n); + const TypeVect* vt = TypeVect::make(bt, vlen); + Node* in = vector_opd(p, 1); + vn = VectorReinterpretNode::make(in, vt, vt); + vlen_in_bytes = vn->as_Vector()->length_in_bytes(); } else if (opc == Op_FmaD || opc == Op_FmaF) { // Promote operands to vector Node* in1 = vector_opd(p, 1); @@ -3758,6 +3765,9 @@ const Type* SuperWord::container_type(Node* n) { // propagating the type of memory operations. return TypeInt::INT; } + if (VectorNode::is_float16_node(n->Opcode())) { + return TypeInt::SHORT; + } return t; } diff --git a/src/hotspot/share/opto/vectornode.cpp b/src/hotspot/share/opto/vectornode.cpp index fdfa0718a45..b55ee51faac 100644 --- a/src/hotspot/share/opto/vectornode.cpp +++ b/src/hotspot/share/opto/vectornode.cpp @@ -46,6 +46,7 @@ int VectorNode::opcode(int sopc, BasicType bt) { case T_INT: return Op_AddVI; default: return 0; } + case Op_AddHF: return (bt == T_SHORT ? Op_AddVHF : 0); case Op_AddL: return (bt == T_LONG ? Op_AddVL : 0); case Op_AddF: return (bt == T_FLOAT ? Op_AddVF : 0); case Op_AddD: return (bt == T_DOUBLE ? Op_AddVD : 0); @@ -267,6 +268,9 @@ int VectorNode::opcode(int sopc, BasicType bt) { return Op_SignumVF; case Op_SignumD: return Op_SignumVD; + case Op_ReinterpretS2HF: + case Op_ReinterpretHF2S: + return Op_VectorReinterpret; default: assert(!VectorNode::is_convert_opcode(sopc), @@ -605,6 +609,16 @@ bool VectorNode::is_rotate_opcode(int opc) { } } +bool VectorNode::is_float16_node(int opc) { + switch (opc) { + case Op_AddHF: + case Op_ReinterpretS2HF: + return true; + default: + return false; + } +} + bool VectorNode::is_scalar_rotate(Node* n) { if (is_rotate_opcode(n->Opcode())) { return true; @@ -674,7 +688,7 @@ void VectorNode::vector_operands(Node* n, uint* start, uint* end) { *start = 1; *end = (n->is_Con() && Matcher::supports_vector_constant_rotates(n->get_int())) ? 
2 : 3; break; - case Op_AddI: case Op_AddL: case Op_AddF: case Op_AddD: + case Op_AddI: case Op_AddHF: case Op_AddL: case Op_AddF: case Op_AddD: case Op_SubI: case Op_SubL: case Op_SubF: case Op_SubD: case Op_MulI: case Op_MulL: case Op_MulF: case Op_MulD: case Op_DivF: case Op_DivD: @@ -732,6 +746,7 @@ VectorNode* VectorNode::make(int vopc, Node* n1, Node* n2, const TypeVect* vt, b switch (vopc) { case Op_AddVB: return new AddVBNode(n1, n2, vt); + case Op_AddVHF: return new AddVHFNode(n1, n2, vt); case Op_AddVS: return new AddVSNode(n1, n2, vt); case Op_AddVI: return new AddVINode(n1, n2, vt); case Op_AddVL: return new AddVLNode(n1, n2, vt); @@ -1733,6 +1748,10 @@ Node* VectorReinterpretNode::Identity(PhaseGVN *phase) { return this; } +VectorNode* VectorReinterpretNode::make(Node* n, const TypeVect* dst_vt, const TypeVect* src_vt) { + return new VectorReinterpretNode(n, dst_vt, src_vt); +} + Node* VectorInsertNode::make(Node* vec, Node* new_val, int position) { assert(position < (int)vec->bottom_type()->is_vect()->length(), "pos in range"); ConINode* pos = ConINode::make(position); diff --git a/src/hotspot/share/opto/vectornode.hpp b/src/hotspot/share/opto/vectornode.hpp index 39a5b49454d..8633872ca3d 100644 --- a/src/hotspot/share/opto/vectornode.hpp +++ b/src/hotspot/share/opto/vectornode.hpp @@ -106,6 +106,8 @@ class VectorNode : public TypeNode { static bool is_muladds2i(Node* n); static bool is_roundopD(Node* n); static bool is_scalar_rotate(Node* n); + static bool is_float16_node(int opc); + static bool is_vector_rotate_supported(int opc, uint vlen, BasicType bt); static bool is_vector_integral_negate_supported(int opc, uint vlen, BasicType bt, bool use_predicate); static bool is_populate_index_supported(BasicType bt); @@ -186,6 +188,14 @@ class AddVFNode : public VectorNode { virtual int Opcode() const; }; +//------------------------------AddVHFNode-------------------------------------- +// Vector add half-precision float +class AddVHFNode : public VectorNode { +public: + 
AddVHFNode(Node* in1, Node* in2, const TypeVect* vt) : VectorNode(in1, in2, vt) {} + virtual int Opcode() const; +}; + //------------------------------AddVDNode-------------------------------------- // Vector add double class AddVDNode : public VectorNode { @@ -1544,6 +1554,8 @@ class VectorReinterpretNode : public VectorNode { virtual Node* Identity(PhaseGVN* phase); virtual int Opcode() const; + + static VectorNode* make(Node* n, const TypeVect* dst_vt, const TypeVect* src_vt); }; class VectorCastNode : public VectorNode { diff --git a/src/hotspot/share/runtime/vmStructs.cpp b/src/hotspot/share/runtime/vmStructs.cpp index 0f94ebcdaf7..b479bfbcdb8 100644 --- a/src/hotspot/share/runtime/vmStructs.cpp +++ b/src/hotspot/share/runtime/vmStructs.cpp @@ -1405,6 +1405,7 @@ declare_c2_type(AddINode, AddNode) \ declare_c2_type(AddLNode, AddNode) \ declare_c2_type(AddFNode, AddNode) \ + declare_c2_type(AddHFNode, AddFNode) \ declare_c2_type(AddDNode, AddNode) \ declare_c2_type(AddPNode, Node) \ declare_c2_type(OrINode, AddNode) \ @@ -1423,6 +1424,8 @@ declare_c2_type(StartNode, MultiNode) \ declare_c2_type(StartOSRNode, StartNode) \ declare_c2_type(ParmNode, ProjNode) \ + declare_c2_type(ReinterpretS2HFNode, Node) \ + declare_c2_type(ReinterpretHF2SNode, Node) \ declare_c2_type(ReturnNode, Node) \ declare_c2_type(RethrowNode, Node) \ declare_c2_type(TailCallNode, ReturnNode) \ @@ -1680,6 +1683,7 @@ declare_c2_type(AddVLNode, VectorNode) \ declare_c2_type(AddReductionVLNode, ReductionNode) \ declare_c2_type(AddVFNode, VectorNode) \ + declare_c2_type(AddVHFNode, VectorNode) \ declare_c2_type(AddReductionVFNode, ReductionNode) \ declare_c2_type(AddVDNode, VectorNode) \ declare_c2_type(AddReductionVDNode, ReductionNode) \ diff --git a/src/java.base/share/classes/java/lang/Float16.java b/src/java.base/share/classes/java/lang/Float16.java new file mode 100644 index 00000000000..9df51d81f2f --- /dev/null +++ b/src/java.base/share/classes/java/lang/Float16.java @@ -0,0 +1,180 @@ +/* 
+ * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package java.lang; + +import java.lang.invoke.MethodHandles; +import java.lang.constant.Constable; +import java.lang.constant.ConstantDesc; +import java.util.Optional; + +import jdk.internal.math.FloatConsts; +import jdk.internal.math.FloatingDecimal; +import jdk.internal.math.FloatToDecimal; +import jdk.internal.vm.annotation.IntrinsicCandidate; + +/** + * The {@code Float16} is a primitive value class holding 16-bit data in IEEE 754 binary16 format + * {@code Float16} contains a single field whose type is {@code short}. + * + * Binary16 Format: + * S EEEEE MMMMMMMMMM + * Sign - 1 bit + * Exponent - 5 bits + * Significand - 10 bits + * + *

This is a primitive value class and its objects are + * identity-less non-nullable value objects. + * + * @author Jatin Bhateja + * @since 20 + */ + +// Currently Float16 is a primitive class but in future will be aligned with +// Enhanced Primitive Boxes described by JEP-402 (https://openjdk.org/jeps/402) +public primitive class Float16 extends Number { + private final short value; + + /** + * Constructs a {@code Float16} instance wrapping IEEE 754 binary16 + * encoded {@code short} value. + * + * @param value a short value. + * @since 20 + */ + private Float16(short value) { + this.value = value; + } + + /** + * Returns a {@code Float16} instance wrapping IEEE 754 binary16 + * encoded {@code short} value. + * + * @param value a short value holding the binary16 bit pattern; it is stored as is. + * @return a {@code Float16} instance representing {@code value}. + * @since 20 + */ + public static Float16 valueOf(short value) { + return new Float16(value); + } + + /** + * Returns the value of this {@code Float16} as a {@code byte} after + * a narrowing primitive conversion. + * + * @return the binary16 encoded {@code short} value represented by this object + * converted to type {@code byte} + * @jls 5.1.3 Narrowing Primitive Conversion + */ + public byte byteValue() { + return (byte)Float.float16ToFloat(value); + } + + /** + * Returns the value of this {@code Float16} as a {@code short} + * after a narrowing primitive conversion. + * + * @return the binary16 encoded {@code short} value represented by this object + * converted to type {@code short} + * @jls 5.1.3 Narrowing Primitive Conversion + * @since 20 + */ + public short shortValue() { + return (short)Float.float16ToFloat(value); + } + + /** + * Returns the value of this {@code Float16} as an {@code int} after + * a narrowing primitive conversion.
+ * + * @return the binary16 encoded {@code short} value represented by this object + * converted to type {@code int} + * @jls 5.1.3 Narrowing Primitive Conversion + */ + public int intValue() { + return (int)Float.float16ToFloat(value); + } + + /** + * Returns the value of this {@code Float16} as a {@code long} after a + * narrowing primitive conversion. + * + * @return the binary16 encoded {@code short} value represented by this object + * converted to type {@code long} + * @jls 5.1.3 Narrowing Primitive Conversion + */ + public long longValue() { + return (long)Float.float16ToFloat(value); + } + + /** + * Returns the {@code float} value of this {@code Float16} object. + * + * @return the binary16 encoded {@code short} value represented by this object + * converted to type {@code float} + */ + public float floatValue() { + return Float.float16ToFloat(value); + } + + /** + * Returns the value of this {@code Float16} as a {@code double} + * after a widening primitive conversion. + * + * @apiNote + * This method corresponds to the convertFormat operation defined + * in IEEE 754. + * + * @return the binary16 encoded {@code short} value represented by this + * object converted to type {@code double} + * @jls 5.1.2 Widening Primitive Conversion + */ + public double doubleValue() { + return (double)Float.float16ToFloat(value); + } + + /** + * Adds two {@code Float16} values together as per the + operator semantics. + * + * @apiNote This method corresponds to the addition operation + * defined in IEEE 754. + * + * @param a the first operand + * @param b the second operand + * @return the sum of {@code a} and {@code b} + * @since 20 + */ + @IntrinsicCandidate + public static Float16 sum(Float16 a, Float16 b) { + return Float16.valueOf(Float.floatToFloat16(Float.float16ToFloat(a.float16ToRawShortBits()) + Float.float16ToFloat(b.float16ToRawShortBits()))); + } + + /** + * Returns the wrapped binary16 bit pattern unchanged; unlike {@link #shortValue()} no numeric conversion is applied. + * @return raw binary16 encoded {@code short} value represented by this object.
+ * @since 20 + */ + public short float16ToRawShortBits() { return value; } +} diff --git a/src/jdk.internal.vm.ci/share/classes/jdk/vm/ci/amd64/AMD64.java b/src/jdk.internal.vm.ci/share/classes/jdk/vm/ci/amd64/AMD64.java index 9b6b0e9c352..2b6d988bbfa 100644 --- a/src/jdk.internal.vm.ci/share/classes/jdk/vm/ci/amd64/AMD64.java +++ b/src/jdk.internal.vm.ci/share/classes/jdk/vm/ci/amd64/AMD64.java @@ -232,6 +232,7 @@ public enum CPUFeature implements CPUFeatureName { CET_IBT, CET_SS, AVX512_IFMA, + AVX512_FP16, } private final EnumSet features; diff --git a/test/hotspot/jtreg/compiler/intrinsics/float16/TestFP16ScalarAdd.java b/test/hotspot/jtreg/compiler/intrinsics/float16/TestFP16ScalarAdd.java new file mode 100644 index 00000000000..d2a214d0059 --- /dev/null +++ b/test/hotspot/jtreg/compiler/intrinsics/float16/TestFP16ScalarAdd.java @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions.
+ */ + +/** +* @test +* @bug 8308363 +* @summary Validate compiler IR for FP16 scalar operations. +* @requires vm.compiler2.enabled +* @library /test/lib / +* @compile -XDenablePrimitiveClasses TestFP16ScalarAdd.java +* @run driver compiler.intrinsics.float16.TestFP16ScalarAdd +*/ + +package compiler.intrinsics.float16; +import compiler.lib.ir_framework.*; +import java.util.Random; + +public class TestFP16ScalarAdd { + private static final int count = 1024; + + private short[] src; + private short[] dst; + private short res; + + public static void main(String args[]) { + TestFramework.run(TestFP16ScalarAdd.class); + } + + public TestFP16ScalarAdd() { + src = new short[count]; + dst = new short[count]; + for (int i = 0; i < count; i++) { + src[i] = Float.floatToFloat16(i); /* binary16 encoding of the index value */ + } + } + + @Test + @IR(applyIfCPUFeature = {"avx512_fp16", "true"}, counts = {IRNode.ADD_HF, "> 0", IRNode.REINTERPRET_S2HF, "> 0", IRNode.REINTERPRET_HF2S, "> 0"}) + public void test1() { /* running sum over non-constant inputs: the IR rule expects AddHF and both reinterpret nodes */ + Float16 res = Float16.valueOf((short)0); + for (int i = 0; i < count; i++) { + res = Float16.sum(res, Float16.valueOf(src[i])); + dst[i] = res.float16ToRawShortBits(); + } + } + + @Test + @IR(applyIfCPUFeature = {"avx512_fp16", "true"}, failOn = {IRNode.ADD_HF, IRNode.REINTERPRET_S2HF, IRNode.REINTERPRET_HF2S}) + public void test2() { /* every operand is a constant: the IR rule requires that no AddHF or reinterpret node remains */ + Float16 hf0 = Float16.valueOf((short)0); /* 0.0 */ + Float16 hf1 = Float16.valueOf((short)15360); /* 0x3C00 = 1.0 */ + Float16 hf2 = Float16.valueOf((short)16384); /* 0x4000 = 2.0 */ + Float16 hf3 = Float16.valueOf((short)16896); /* 0x4200 = 3.0 */ + Float16 hf4 = Float16.valueOf((short)17408); /* 0x4400 = 4.0 */ + res = Float16.sum(Float16.sum(Float16.sum(Float16.sum(hf0, hf1), hf2), hf3), hf4).float16ToRawShortBits(); + } +} diff --git a/test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java b/test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java index 023cdb305ea..446b11a840b 100644 --- a/test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java +++ b/test/hotspot/jtreg/compiler/lib/ir_framework/IRNode.java @@ -151,6 +151,11 @@ public class IRNode {
beforeMatchingNameRegex(ADD_L, "AddL"); } + public static final String ADD_HF = PREFIX + "ADD_HF" + POSTFIX; + static { + beforeMatchingNameRegex(ADD_HF, "AddHF"); + } + public static final String ADD_V = PREFIX + "ADD_V" + POSTFIX; static { beforeMatchingNameRegex(ADD_V, "AddV(B|S|I|L|F|D)"); @@ -171,6 +176,11 @@ public class IRNode { beforeMatchingNameRegex(ADD_VF, "AddVF"); } + public static final String ADD_VHF = PREFIX + "ADD_VHF" + POSTFIX; + static { + beforeMatchingNameRegex(ADD_VHF, "AddVHF"); + } + public static final String ADD_REDUCTION_V = PREFIX + "ADD_REDUCTION_V" + POSTFIX; static { beforeMatchingNameRegex(ADD_REDUCTION_V, "AddReductionV(B|S|I|L|F|D)"); @@ -893,6 +903,16 @@ public class IRNode { trapNodes(RANGE_CHECK_TRAP,"range_check"); } + public static final String REINTERPRET_S2HF = PREFIX + "REINTERPRET_S2HF" + POSTFIX; + static { + beforeMatchingNameRegex(REINTERPRET_S2HF, "ReinterpretS2HF"); + } + + public static final String REINTERPRET_HF2S = PREFIX + "REINTERPRET_HF2S" + POSTFIX; + static { + beforeMatchingNameRegex(REINTERPRET_HF2S, "ReinterpretHF2S"); + } + public static final String REPLICATE_B = PREFIX + "REPLICATE_B" + POSTFIX; static { String regex = START + "ReplicateB" + MID + END; diff --git a/test/hotspot/jtreg/compiler/lib/ir_framework/test/IREncodingPrinter.java b/test/hotspot/jtreg/compiler/lib/ir_framework/test/IREncodingPrinter.java index f293adba8b4..766b65a8e12 100644 --- a/test/hotspot/jtreg/compiler/lib/ir_framework/test/IREncodingPrinter.java +++ b/test/hotspot/jtreg/compiler/lib/ir_framework/test/IREncodingPrinter.java @@ -75,6 +75,7 @@ public class IREncodingPrinter { "avx512dq", "avx512vl", "avx512f", + "avx512_fp16", // AArch64 "sha3", "asimd", diff --git a/test/hotspot/jtreg/compiler/vectorization/TestFloat16VectorSum.java b/test/hotspot/jtreg/compiler/vectorization/TestFloat16VectorSum.java new file mode 100644 index 00000000000..8b915ed24f4 --- /dev/null +++ 
b/test/hotspot/jtreg/compiler/vectorization/TestFloat16VectorSum.java @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/** +* @test +* @summary Test vectorization of Float16.sum operation. 
+* @requires vm.compiler2.enabled +* @library /test/lib / +* @compile -XDenablePrimitiveClasses TestFloat16VectorSum.java +* @run driver compiler.vectorization.TestFloat16VectorSum +*/ + +package compiler.vectorization; +import compiler.lib.ir_framework.*; +import java.util.Random; + + +public class TestFloat16VectorSum { + private Float16[] input; + private Float16[] output; + private static final int LEN = 2048; + private Random rng; + + public static void main(String args[]) { + TestFramework.run(TestFloat16VectorSum.class); + } + + public TestFloat16VectorSum() { + input = new Float16[LEN]; + output = new Float16[LEN]; + rng = new Random(42); /* fixed seed keeps the inputs reproducible */ + for (int i = 0; i < LEN; ++i) { + input[i] = Float16.valueOf(Float.floatToFloat16(rng.nextFloat())); + } + } + + @Test + @Warmup(10000) + @IR(applyIfCPUFeature = {"avx512_fp16" , "true"}, counts = {IRNode.ADD_VHF, " >= 1"}) + public void vectorSumFloat16() { /* the IR rule expects at least one AddVHF node for this element-wise loop */ + for (int i = 0; i < LEN; ++i) { + output[i] = Float16.sum(input[i], input[i]); + } + checkResult(); + } + + public void checkResult() { /* reference goes through float arithmetic instead of Float16.sum, so the operation under test is not compared with itself */ + for (int i = 0; i < LEN; ++i) { + short expected = Float.floatToFloat16(Float.float16ToFloat(input[i].float16ToRawShortBits()) + Float.float16ToFloat(input[i].float16ToRawShortBits())); + if (output[i].float16ToRawShortBits() != expected) { + throw new RuntimeException("Invalid result: output[" + i + "] = " + output[i].float16ToRawShortBits() + " != " + expected); + } + } + } +} + + diff --git a/test/jdk/java/lang/Float16/FP16ReductionOperations.java b/test/jdk/java/lang/Float16/FP16ReductionOperations.java new file mode 100644 index 00000000000..fa208cb8478 --- /dev/null +++ b/test/jdk/java/lang/Float16/FP16ReductionOperations.java @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @bug 8308363 + * @summary Test FP16 reduction operations. 
+ * @compile -XDenablePrimitiveClasses FP16ReductionOperations.java + * @run main/othervm -XX:+EnablePrimitiveClasses -XX:-TieredCompilation -Xbatch FP16ReductionOperations + */ + +import java.util.Random; + +public class FP16ReductionOperations { + + public static Random r = new Random(1024); /* fixed seed: reproducible inputs */ + + public static short test_reduction_add_constants() { /* chained Float16.sum over constant operands; returns the raw binary16 bits */ + Float16 hf0 = Float16.valueOf((short)0); /* 0.0 */ + Float16 hf1 = Float16.valueOf((short)15360); /* 0x3C00 = 1.0 */ + Float16 hf2 = Float16.valueOf((short)16384); /* 0x4000 = 2.0 */ + Float16 hf3 = Float16.valueOf((short)16896); /* 0x4200 = 3.0 */ + Float16 hf4 = Float16.valueOf((short)17408); /* 0x4400 = 4.0 */ + return Float16.sum(Float16.sum(Float16.sum(Float16.sum(hf0, hf1), hf2), hf3), hf4).float16ToRawShortBits(); + } + + public static short expected_reduction_add_constants() { /* reference: adds the same constants as float and rounds to binary16 once at the end; the total (10.0) is exactly representable */ + Float16 hf0 = Float16.valueOf((short)0); + Float16 hf1 = Float16.valueOf((short)15360); + Float16 hf2 = Float16.valueOf((short)16384); + Float16 hf3 = Float16.valueOf((short)16896); + Float16 hf4 = Float16.valueOf((short)17408); + return Float.floatToFloat16(Float.float16ToFloat(hf0.float16ToRawShortBits()) + + Float.float16ToFloat(hf1.float16ToRawShortBits()) + + Float.float16ToFloat(hf2.float16ToRawShortBits()) + + Float.float16ToFloat(hf3.float16ToRawShortBits()) + + Float.float16ToFloat(hf4.float16ToRawShortBits())); + } + + public static boolean compare(short actual, short expected) { /* NOTE: returns true when the 16-bit patterns DIFFER, i.e. on mismatch */ + return !((0xFFFF & actual) == (0xFFFF & expected)); + } + + public static void test_reduction_constants(char oper) { /* only the addition operator is handled; anything else throws AssertionError */ + short actual = 0; + short expected = 0; + switch(oper) { + case '+' -> { + actual = test_reduction_add_constants(); + expected = expected_reduction_add_constants(); + } + default -> throw new AssertionError("Unsupported Operation."); + } + if (compare(actual,expected)) { + throw new AssertionError("Result mismatch!, expected = " + expected + " actual = " + actual); + } + } + + public static short test_reduction_add(short [] arr) { /* folds arr with Float16.sum starting from 0.0; returns the raw binary16 bits */ + Float16 res = Float16.valueOf((short)0); + for (int i = 0; i < arr.length; i++) { +
res = Float16.sum(res, Float16.valueOf(arr[i])); + } + return res.float16ToRawShortBits(); + } + + public static short expected_reduction_add(short [] arr) { /* reference: accumulates through float and rounds back to binary16 after every step */ + short res = 0; + for (int i = 0; i < arr.length; i++) { + res = Float.floatToFloat16(Float.float16ToFloat(res) + Float.float16ToFloat(arr[i])); + } + return res; + } + + public static void test_reduction(char oper, short [] arr) { /* compares the Float16.sum reduction of arr with the float-based reference; throws AssertionError when the bits differ */ + short actual = 0; + short expected = 0; + switch(oper) { + case '+' -> { + actual = test_reduction_add(arr); + expected = expected_reduction_add(arr); + } + default -> throw new AssertionError("Unsupported Operation."); + } + if (compare(actual,expected)) { + throw new AssertionError("Result mismatch!, expected = " + expected + " actual = " + actual); + } + } + + public static short [] get_fp16_array(int size) { /* size binary16 values produced from r.nextFloat() */ + short [] arr = new short[size]; + for (int i = 0; i < arr.length; i++) { + arr[i] = Float.floatToFloat16(r.nextFloat()); + } + return arr; + } + + public static void main(String [] args) { + int res = 0; /* NOTE(review): never read in this method */ + short [] input = get_fp16_array(1024); + short [] special_values = { + 32256, // NAN + 31744, // +Inf + (short)-1024, // -Inf + 0, // +0.0 + (short)-32768, // -0.0 + }; + for (int i = 0; i < 1000; i++) { /* NOTE(review): presumably repeated so the methods get JIT-compiled under -Xbatch; confirm */ + test_reduction('+', input); + test_reduction('+', special_values); + test_reduction_constants('+'); + } + System.out.println("PASS"); + } +} diff --git a/test/jdk/java/lang/Float16/FP16ScalarOperations.java b/test/jdk/java/lang/Float16/FP16ScalarOperations.java new file mode 100644 index 00000000000..72d9efb65b6 --- /dev/null +++ b/test/jdk/java/lang/Float16/FP16ScalarOperations.java @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation.
+ * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +/* + * @test + * @bug 8308363 + * @summary Initial compiler support for the Float16.sum operation. + * @compile -XDenablePrimitiveClasses FP16ScalarOperations.java + * @run main/othervm -XX:+EnablePrimitiveClasses -XX:-TieredCompilation -Xbatch FP16ScalarOperations + */ + +import java.util.Random; + +public class FP16ScalarOperations { + + public static Random r = new Random(1024); /* fixed seed: reproducible inputs */ + + public static short actual_value(char oper, short val1, short val2) { /* result of the Float16 API for oper, as raw binary16 bits; only addition is handled */ + Float16 obj1 = Float16.valueOf(val1); + Float16 obj2 = Float16.valueOf(val2); + switch ((int)oper) { + case '+' : return Float16.sum(obj1, obj2).float16ToRawShortBits(); + default : throw new AssertionError("Unsupported Operation!"); + } + } + + public static void test_add(short [] arr1, short arr2[]) { /* validates addition element by element; indexes arr2 up to arr1.length */ + for (int i = 0; i < arr1.length; i++) { + validate('+', arr1[i], arr2[i]); + } + } + + public static short expected_value(char oper, short input1, short input2) { /* reference: widen both operands to float, add, round back to binary16 */ + switch((int)oper) { + case '+' : return Float.floatToFloat16(Float.float16ToFloat(input1) + Float.float16ToFloat(input2)); + default : throw new AssertionError("Unsupported Operation!"); + } + } + + public static boolean compare(short actual, short expected) { /* NOTE: returns true when the 16-bit patterns DIFFER, i.e. on mismatch */ + return !((0xFFFF & actual) == (0xFFFF & expected)); + } + + public static
void validate(char oper, short input1, short input2) { /* throws AssertionError when the API result and the float-based reference differ bitwise */ + short actual = actual_value(oper, input1, input2); + short expected = expected_value(oper, input1, input2); + if (compare(actual, expected)) { + throw new AssertionError("Test Failed: " + input1 + " + " + input2 + " : " + actual + " != " + expected); + } + } + + public static short [] get_fp16_array(int size) { /* size binary16 values produced from r.nextFloat() */ + short [] arr = new short[size]; + for (int i = 0; i < arr.length; i++) { + arr[i] = Float.floatToFloat16(r.nextFloat()); + } + return arr; + } + + public static void main(String [] args) { + int res = 0; /* NOTE(review): never read in this method */ + short [] input1 = get_fp16_array(1024); + short [] input2 = get_fp16_array(1024); + short [] special_values = { + 32256, // NAN + 31744, // +Inf + (short)-1024, // -Inf + 0, // +0.0 + (short)-32768, // -0.0 + }; + for (int i = 0; i < 1000; i++) { + test_add(input1, input2); + test_add(special_values, special_values); /* each special value is paired only with itself */ + } + System.out.println("PASS"); + } +}