Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 42f445c7 authored by Wen-yi Chu's avatar Wen-yi Chu Committed by Android (Google) Code Review
Browse files

Merge "Add support for denormalized half-floats" into main

parents e357a6a9 d2c0dac4
Loading
Loading
Loading
Loading
+82 −59
Original line number Diff line number Diff line
@@ -18,7 +18,6 @@

#include <stdint.h>
#include <functional>
#include <iosfwd>
#include <limits>
#include <type_traits>

@@ -60,12 +59,12 @@ class half {
        uint16_t bits;
        explicit constexpr fp16() noexcept : bits(0) {}
        explicit constexpr fp16(uint16_t b) noexcept : bits(b) {}
        void setS(unsigned int s) noexcept { bits = uint16_t((bits & 0x7FFF) | (s<<15)); }
        void setE(unsigned int s) noexcept { bits = uint16_t((bits & 0xE3FF) | (s<<10)); }
        void setM(unsigned int s) noexcept { bits = uint16_t((bits & 0xFC00) | (s<< 0)); }
        constexpr unsigned int getS() const noexcept { return  bits >> 15u; }
        constexpr unsigned int getE() const noexcept { return (bits >> 10u) & 0x1Fu; }
        constexpr unsigned int getM() const noexcept { return  bits         & 0x3FFu; }
        void setS(uint32_t s) noexcept { bits = uint16_t((bits & 0x7FFF) | (s << 15)); }
        void setE(uint32_t s) noexcept { bits = uint16_t((bits & 0x83FF) | (s << 10)); }
        void setM(uint32_t s) noexcept { bits = uint16_t((bits & 0xFC00) | (s << 0)); }
        constexpr uint32_t getS() const noexcept { return bits >> 15u; }
        constexpr uint32_t getE() const noexcept { return (bits >> 10u) & 0x1Fu; }
        constexpr uint32_t getM() const noexcept { return bits & 0x3FFu; }
    };
    struct fp32 {
        union {
@@ -74,12 +73,12 @@ class half {
        };
        explicit constexpr fp32() noexcept : bits(0) {}
        explicit constexpr fp32(float f) noexcept : fp(f) {}
        void setS(unsigned int s) noexcept { bits = uint32_t((bits & 0x7FFFFFFF) | (s<<31)); }
        void setE(unsigned int s) noexcept { bits = uint32_t((bits & 0x807FFFFF) | (s<<23)); }
        void setM(unsigned int s) noexcept { bits = uint32_t((bits & 0xFF800000) | (s<< 0)); }
        constexpr unsigned int getS() const noexcept { return  bits >> 31u; }
        constexpr unsigned int getE() const noexcept { return (bits >> 23u) & 0xFFu; }
        constexpr unsigned int getM() const noexcept { return  bits         & 0x7FFFFFu; }
        void setS(uint32_t s) noexcept { bits = uint32_t((bits & 0x7FFFFFFF) | (s << 31)); }
        void setE(uint32_t s) noexcept { bits = uint32_t((bits & 0x807FFFFF) | (s << 23)); }
        void setM(uint32_t s) noexcept { bits = uint32_t((bits & 0xFF800000) | (s << 0)); }
        constexpr uint32_t getS() const noexcept { return bits >> 31u; }
        constexpr uint32_t getE() const noexcept { return (bits >> 23u) & 0xFFu; }
        constexpr uint32_t getM() const noexcept { return bits & 0x7FFFFFu; }
    };

public:
@@ -88,8 +87,8 @@ public:
    CONSTEXPR operator float() const noexcept { return htof(mBits); }

    uint16_t getBits() const noexcept { return mBits.bits; }
    unsigned int getExponent() const noexcept { return mBits.getE(); }
    unsigned int getMantissa() const noexcept { return mBits.getM(); }
    uint32_t getExponent() const noexcept { return mBits.getE(); }
    uint32_t getMantissa() const noexcept { return mBits.getM(); }

private:
    friend class std::numeric_limits<half>;
@@ -109,21 +108,20 @@ inline CONSTEXPR half::fp16 half::ftoh(float v) noexcept {
        out.setE(0x1F);
        out.setM(in.getM() ? 0x200 : 0);
    } else {
        int e = static_cast<int>(in.getE()) - 127 + 15;
        if (e >= 0x1F) {
        uint32_t e = in.getE();
        uint32_t m = in.getM() + 0x1000; // Rounding to the nearest even.
        if (e > 143) {
            // overflow
            out.setE(0x31); // +/- inf
        } else if (e <= 0) {
            // underflow
            // flush to +/- 0
        } else {
            unsigned int m = in.getM();
            out.setE(uint16_t(e));
            out.setE(0x1F);
        } else if (e > 112) {
            // normalized
            out.setE(e - 112);
            out.setM(m >> 13);
            if (m & 0x1000) {
                // rounding
                out.bits++;
            }
        } else if (e > 101) {
            // denormalized
            out = fp16(static_cast<uint16_t>((((0x007FF000 + m) >> (125 - e)) + 1) >> 1));
        } else {
            // underflow
        }
    }
    out.setS(in.getS());
@@ -138,12 +136,19 @@ inline CONSTEXPR float half::htof(half::fp16 in) noexcept {
    } else {
        if (in.getE() == 0) {
            if (in.getM()) {
                // TODO: denormal half float, treat as zero for now
                // (it's stupid because they can be represented as regular float)
                uint32_t m = in.getM();
                uint32_t e = 127 - 14;
                m <<= 13;
                while (m < 0x800000) {
                    m <<= 1;
                    e--;
                }
                out.setE(e);
                out.setM(m & 0x7FFFFF);
            }
        } else {
            int e = static_cast<int>(in.getE()) - 15 + 127;
            unsigned int m = in.getM();
            uint32_t m = in.getM();
            out.setE(uint32_t(e));
            out.setM(m << 13);
        }
@@ -160,7 +165,8 @@ inline CONSTEXPR android::half operator"" _hf(long double v) {

namespace std {

template<> struct is_floating_point<android::half> : public std::true_type {};
template <>
struct is_floating_point<android::half> : public std::true_type {};

template <>
class numeric_limits<android::half> {
@@ -192,21 +198,38 @@ public:
    static constexpr const int max_exponent = 16;
    static constexpr const int max_exponent10 = 4;

    inline static constexpr type round_error() noexcept { return android::half(android::half::binary, 0x3800); }
    inline static constexpr type min() noexcept { return android::half(android::half::binary, 0x0400); }
    inline static constexpr type max() noexcept { return android::half(android::half::binary, 0x7bff); }
    inline static constexpr type lowest() noexcept { return android::half(android::half::binary, 0xfbff); }
    inline static constexpr type epsilon() noexcept { return android::half(android::half::binary, 0x1400); }
    inline static constexpr type infinity() noexcept { return android::half(android::half::binary, 0x7c00); }
    inline static constexpr type quiet_NaN() noexcept { return android::half(android::half::binary, 0x7fff); }
    inline static constexpr type denorm_min() noexcept { return android::half(android::half::binary, 0x0001); }
    inline static constexpr type signaling_NaN() noexcept { return android::half(android::half::binary, 0x7dff); }
    inline static constexpr type round_error() noexcept {
        return android::half(android::half::binary, 0x3800);
    }
    inline static constexpr type min() noexcept {
        return android::half(android::half::binary, 0x0400);
    }
    inline static constexpr type max() noexcept {
        return android::half(android::half::binary, 0x7bff);
    }
    inline static constexpr type lowest() noexcept {
        return android::half(android::half::binary, 0xfbff);
    }
    inline static constexpr type epsilon() noexcept {
        return android::half(android::half::binary, 0x1400);
    }
    inline static constexpr type infinity() noexcept {
        return android::half(android::half::binary, 0x7c00);
    }
    inline static constexpr type quiet_NaN() noexcept {
        return android::half(android::half::binary, 0x7fff);
    }
    inline static constexpr type denorm_min() noexcept {
        return android::half(android::half::binary, 0x0001);
    }
    inline static constexpr type signaling_NaN() noexcept {
        return android::half(android::half::binary, 0x7dff);
    }
};

template<> struct hash<android::half> {
    size_t operator()(const android::half& half) {
        return std::hash<float>{}(half);
    }
template <>
struct hash<android::half> {
    size_t operator()(const android::half& half) { return std::hash<float>{}(half); }
};

} // namespace std
+54 −41
Original line number Diff line number Diff line
@@ -26,53 +26,63 @@

namespace android {

class HalfTest : public testing::Test {
protected:
};
TEST(HalfTest, TestHalfSize) {
    EXPECT_EQ(2UL, sizeof(half));
}

TEST_F(HalfTest, Basics) {
TEST(HalfTest, TestZero) {
    EXPECT_EQ(half().getBits(), 0x0000);
    EXPECT_EQ(half(0.0f).getBits(), 0x0000);
    EXPECT_EQ(half(-0.0f).getBits(), 0x8000);
}

    EXPECT_EQ(2UL, sizeof(half));
TEST(HalfTest, TestNaN) {
    EXPECT_EQ(half(NAN).getBits(), 0x7e00);
    EXPECT_EQ(std::numeric_limits<half>::quiet_NaN().getBits(), 0x7FFF);
    EXPECT_EQ(std::numeric_limits<half>::signaling_NaN().getBits(), 0x7DFF);
}

TEST(HalfTest, TestInfinity) {
    EXPECT_EQ(std::numeric_limits<half>::infinity().getBits(), 0x7C00);
}

    // test +/- zero
    EXPECT_EQ(0x0000, half().getBits());
    EXPECT_EQ(0x0000, half( 0.0f).getBits());
    EXPECT_EQ(0x8000, half(-0.0f).getBits());
TEST(HalfTest, TestNumericLimits) {
    EXPECT_EQ(std::numeric_limits<half>::min().getBits(), 0x0400);
    EXPECT_EQ(std::numeric_limits<half>::max().getBits(), 0x7BFF);
    EXPECT_EQ(std::numeric_limits<half>::lowest().getBits(), 0xFBFF);
}

    // test nan
    EXPECT_EQ(0x7e00, half(NAN).getBits());
TEST(HalfTest, TestEpsilon) {
    EXPECT_EQ(std::numeric_limits<half>::epsilon().getBits(), 0x1400);
}

    // test +/- infinity
    EXPECT_EQ(0x7C00, half( std::numeric_limits<float>::infinity()).getBits());
    EXPECT_EQ(0xFC00, half(-std::numeric_limits<float>::infinity()).getBits());
TEST(HalfTest, TestDenormals) {
    EXPECT_EQ(half(std::numeric_limits<half>::denorm_min()).getBits(), 0x0001);
    EXPECT_EQ(half(std::numeric_limits<half>::denorm_min() * 2).getBits(), 0x0002);
    EXPECT_EQ(half(std::numeric_limits<half>::denorm_min() * 3).getBits(), 0x0003);
    // test a few known denormals
    EXPECT_EQ(half(6.09756e-5).getBits(), 0x03FF);
    EXPECT_EQ(half(5.96046e-8).getBits(), 0x0001);
    EXPECT_EQ(half(-6.09756e-5).getBits(), 0x83FF);
    EXPECT_EQ(half(-5.96046e-8).getBits(), 0x8001);
}

TEST(HalfTest, TestNormal) {
    // test a few known values
    EXPECT_EQ(0x3C01, half(1.0009765625).getBits());
    EXPECT_EQ(0xC000, half(-2).getBits());
    EXPECT_EQ(0x0400, half(6.10352e-5).getBits());
    EXPECT_EQ(0x7BFF, half(65504).getBits());
    EXPECT_EQ(0x3555, half(1.0f/3).getBits());

    // numeric limits
    EXPECT_EQ(0x7C00, std::numeric_limits<half>::infinity().getBits());
    EXPECT_EQ(0x0400, std::numeric_limits<half>::min().getBits());
    EXPECT_EQ(0x7BFF, std::numeric_limits<half>::max().getBits());
    EXPECT_EQ(0xFBFF, std::numeric_limits<half>::lowest().getBits());

    // denormals (flushed to zero)
    EXPECT_EQ(0x0000, half( 6.09756e-5).getBits());      // if handled, should be: 0x03FF
    EXPECT_EQ(0x0000, half( 5.96046e-8).getBits());      // if handled, should be: 0x0001
    EXPECT_EQ(0x8000, half(-6.09756e-5).getBits());      // if handled, should be: 0x83FF
    EXPECT_EQ(0x8000, half(-5.96046e-8).getBits());      // if handled, should be: 0x8001
    EXPECT_EQ(half(1.0009765625).getBits(), 0x3C01);
    EXPECT_EQ(half(-2).getBits(), 0xC000);
    EXPECT_EQ(half(6.10352e-5).getBits(), 0x0400);
    EXPECT_EQ(half(65504).getBits(), 0x7BFF);
    EXPECT_EQ(half(1.0f / 3).getBits(), 0x3555);

    // test all exactly representable integers
    for (int i = -2048; i <= 2048; ++i) {
        half h = i;
        EXPECT_EQ(i, float(h));
        EXPECT_EQ(float(h), i);
    }
}

TEST_F(HalfTest, Literals) {
TEST(HalfTest, TestLiterals) {
    half one = 1.0_hf;
    half pi = 3.1415926_hf;
    half minusTwo = -2.0_hf;
@@ -82,8 +92,7 @@ TEST_F(HalfTest, Literals) {
    EXPECT_EQ(half(-2.0f), minusTwo);
}


TEST_F(HalfTest, Vec) {
TEST(HalfTest, TestVec) {
    float4 f4(1, 2, 3, 4);
    half4 h4(f4);
    half3 h3(f4.xyz);
@@ -94,8 +103,7 @@ TEST_F(HalfTest, Vec) {
    EXPECT_EQ(f4.xy, h2);
}


TEST_F(HalfTest, Hash) {
TEST(HalfTest, TestHash) {
    float4 f4a(1, 2, 3, 4);
    float4 f4b(2, 2, 3, 4);
    half4 h4a(f4a), h4b(f4b);
@@ -103,4 +111,9 @@ TEST_F(HalfTest, Hash) {
    EXPECT_NE(std::hash<half4>{}(h4a), std::hash<half4>{}(h4b));
}

TEST(HalfTest, TestHalfToFloat) {
    EXPECT_EQ(4.25f, float(4.25_hf));
    EXPECT_EQ(3.05175781e-05f, float(3.05175781e-05_hf));
}

}; // namespace android