From 161db240b51b2d480e2109aa7d33a4ded0a34128 Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Wed, 25 Oct 2023 10:41:58 -0700 Subject: [PATCH] Allow the copy constructor to use copy-by-value as float8 types are quite small No functional change is intended. PiperOrigin-RevId: 576576088 --- ml_dtypes/include/float8.h | 42 ++++++++++++++------------------------ 1 file changed, 15 insertions(+), 27 deletions(-) diff --git a/ml_dtypes/include/float8.h b/ml_dtypes/include/float8.h index f3f557b7..e073b84f 100644 --- a/ml_dtypes/include/float8.h +++ b/ml_dtypes/include/float8.h @@ -59,9 +59,9 @@ class float8_base { public: constexpr float8_base() : rep_(0) {} - template >> - explicit EIGEN_DEVICE_FUNC float8_base(T f) + template + explicit EIGEN_DEVICE_FUNC float8_base( + T f, std::enable_if_t, int> = 0) : float8_base(ConvertFrom(static_cast(f)).rep(), ConstructFromRepTag{}) {} explicit EIGEN_DEVICE_FUNC float8_base(double f64) @@ -239,6 +239,10 @@ class float8_base { uint8_t rep_; }; +template +using RequiresIsDerivedFromFloat8Base = + std::enable_if_t, T>, int>; + class float8_e4m3fn : public float8_base { // Exponent: 4, Mantissa: 3, bias: 7. // Extended range: no inf, NaN represented by 0bS111'1111. @@ -252,9 +256,8 @@ class float8_e4m3fn : public float8_base { using Base::Base; public: - explicit EIGEN_DEVICE_FUNC float8_e4m3fn(const float8_e5m2& f8) - : float8_e4m3fn(ConvertFrom(f8)) {} - explicit EIGEN_DEVICE_FUNC float8_e4m3fn(const float8_e4m3b11fnuz& f8) + template = 0> + explicit EIGEN_DEVICE_FUNC float8_e4m3fn(T f8) : float8_e4m3fn(ConvertFrom(f8)) {} }; @@ -267,13 +270,8 @@ class float8_e4m3b11fnuz : public float8_base { using Base::Base; public: - explicit EIGEN_DEVICE_FUNC float8_e4m3b11fnuz(const float8_e5m2& f8) - : float8_e4m3b11fnuz(ConvertFrom(f8)) {} - explicit EIGEN_DEVICE_FUNC float8_e4m3b11fnuz(const float8_e5m2fnuz& f8) - : float8_e4m3b11fnuz(ConvertFrom(f8)) {} - explicit EIGEN_DEVICE_FUNC float8_e4m3b11fnuz(const float8_e4m3fn& f8) - : float8_e4m3b11fnuz(ConvertFrom(f8)) {} - explicit EIGEN_DEVICE_FUNC float8_e4m3b11fnuz(const float8_e4m3fnuz& f8) + template = 0> + explicit EIGEN_DEVICE_FUNC float8_e4m3b11fnuz(T f8) : float8_e4m3b11fnuz(ConvertFrom(f8)) {} constexpr float8_e4m3b11fnuz operator-() const { @@ -315,13 +313,8 @@ class float8_e4m3fnuz : public float8_base { using Base::Base; public: - explicit EIGEN_DEVICE_FUNC float8_e4m3fnuz(const float8_e5m2& f8) - : float8_e4m3fnuz(ConvertFrom(f8)) {} - explicit EIGEN_DEVICE_FUNC float8_e4m3fnuz(const float8_e5m2fnuz& f8) - : float8_e4m3fnuz(ConvertFrom(f8)) {} - explicit EIGEN_DEVICE_FUNC float8_e4m3fnuz(const float8_e4m3b11fnuz& f8) - : float8_e4m3fnuz(ConvertFrom(f8)) {} - explicit EIGEN_DEVICE_FUNC float8_e4m3fnuz(const float8_e4m3fn& f8) + template = 0> + explicit EIGEN_DEVICE_FUNC float8_e4m3fnuz(T f8) : float8_e4m3fnuz(ConvertFrom(f8)) {} constexpr float8_e4m3fnuz operator-() const { @@ -347,13 +340,8 @@ class float8_e5m2 : public float8_base { using Base::Base; public: - explicit EIGEN_DEVICE_FUNC float8_e5m2(float8_e4m3fn f8) - : float8_e5m2(ConvertFrom(f8)) {} - explicit EIGEN_DEVICE_FUNC float8_e5m2(float8_e4m3fnuz f8) - : float8_e5m2(ConvertFrom(f8)) {} - explicit EIGEN_DEVICE_FUNC float8_e5m2(float8_e4m3b11fnuz f8) - : float8_e5m2(ConvertFrom(f8)) {} - explicit EIGEN_DEVICE_FUNC float8_e5m2(float8_e5m2fnuz& f8) + template = 0> + explicit EIGEN_DEVICE_FUNC float8_e5m2(T f8) : float8_e5m2(ConvertFrom(f8)) {} };