Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit d6364416 authored by Treehugger Robot's avatar Treehugger Robot Committed by Android (Google) Code Review
Browse files

Merge "[Lut] correct 3D Lut trilinear algorithm" into main

parents abceb16f 5eae32d4
Loading
Loading
Loading
Loading
+57 −37
Original line number Diff line number Diff line
@@ -70,49 +70,65 @@ static const SkString kShader = SkString(R"(
            }
        } else if (dimension == 3) {
            if (key == 0) {
                float tx = linear.r * float(size - 1);
                float ty = linear.g * float(size - 1);
                float tz = linear.b * float(size - 1);
                // index
                float x = linear.r * float(size - 1);
                float y = linear.g * float(size - 1);
                float z = linear.b * float(size - 1);

                // calculate lower and upper bounds for each dimension
                int x = int(tx);
                int y = int(ty);
                int z = int(tz);
                // lower bound
                float x0 = floor(x);
                float y0 = floor(y);
                float z0 = floor(z);

                int i000 = x + y * size + z * size * size;
                int i100 = i000 + 1;
                int i010 = i000 + size;
                int i110 = i000 + size + 1;
                int i001 = i000 + size * size;
                int i101 = i000 + size * size + 1;
                int i011 = i000 + size * size + size;
                int i111 = i000 + size * size + size + 1;
                // upper bound
                float x1 = min(x0 + 1.0, float(size - 1));
                float y1 = min(y0 + 1.0, float(size - 1));
                float z1 = min(z0 + 1.0, float(size - 1));

                // get 1d normalized indices
                float c000 = float(i000) / float(size * size * size);
                float c100 = float(i100) / float(size * size * size);
                float c010 = float(i010) / float(size * size * size);
                float c110 = float(i110) / float(size * size * size);
                float c001 = float(i001) / float(size * size * size);
                float c101 = float(i101) / float(size * size * size);
                float c011 = float(i011) / float(size * size * size);
                float c111 = float(i111) / float(size * size * size);
                // weight
                // if the value reaches to upper bound, x1 == x0, then weight is 0
                // if no, x1 - x0 should always be 1.0
                float tx = x1 == x0 ? 0 : x - x0;
                float ty = y1 == y0 ? 0 : y - y0;
                float tz = z1 == z0 ? 0 : z - z0;

                // get indices
                // this follows 3d flatten policy described in API/AIDL interface
                // i.e., `FLAT[z + DEPTH * (y + HEIGHT * x)] = ORIGINAL[x][y][z]`
                float i000 = z0 + (y0 * float(size)) + (x0 * float(size) * float(size));
                float i001 = z1 + (y0 * float(size)) + (x0 * float(size) * float(size));
                float i010 = z0 + (y1 * float(size)) + (x0 * float(size) * float(size));
                float i011 = z1 + (y1 * float(size)) + (x0 * float(size) * float(size));
                float i100 = z0 + (y0 * float(size)) + (x1 * float(size) * float(size));
                float i101 = z1 + (y0 * float(size)) + (x1 * float(size) * float(size));
                float i110 = z0 + (y1 * float(size)) + (x1 * float(size) * float(size));
                float i111 = z1 + (y1 * float(size)) + (x1 * float(size) * float(size));

                // TODO(b/377984618): support Tetrahedral interpolation
                // perform trilinear interpolation
                float3 c00 = mix(lut.eval(vec2(c000, 0.0) + 0.5).rgb,
                                 lut.eval(vec2(c100, 0.0) + 0.5).rgb, linear.r);
                float3 c01 = mix(lut.eval(vec2(c001, 0.0) + 0.5).rgb,
                                 lut.eval(vec2(c101, 0.0) + 0.5).rgb, linear.r);
                float3 c10 = mix(lut.eval(vec2(c010, 0.0) + 0.5).rgb,
                                 lut.eval(vec2(c110, 0.0) + 0.5).rgb, linear.r);
                float3 c11 = mix(lut.eval(vec2(c011, 0.0) + 0.5).rgb,
                                 lut.eval(vec2(c111, 0.0) + 0.5).rgb, linear.r);
                // see https://en.wikipedia.org/wiki/Trilinear_interpolation
                float3 c000 = lut.eval(vec2(i000, 0.0) + 0.5).rgb;
                float3 c001 = lut.eval(vec2(i001, 0.0) + 0.5).rgb;
                float3 c010 = lut.eval(vec2(i010, 0.0) + 0.5).rgb;
                float3 c011 = lut.eval(vec2(i011, 0.0) + 0.5).rgb;
                float3 c100 = lut.eval(vec2(i100, 0.0) + 0.5).rgb;
                float3 c101 = lut.eval(vec2(i101, 0.0) + 0.5).rgb;
                float3 c110 = lut.eval(vec2(i110, 0.0) + 0.5).rgb;
                float3 c111 = lut.eval(vec2(i111, 0.0) + 0.5).rgb;

                // mix(x, y, a) = x * (1 - a) + y * a
                // interpolate along the z-axis
                float3 c00 = mix(c000, c001, tz);
                float3 c01 = mix(c010, c011, tz);
                float3 c10 = mix(c100, c101, tz);
                float3 c11 = mix(c110, c111, tz);

                float3 c0 = mix(c00, c10, linear.g);
                float3 c1 = mix(c01, c11, linear.g);
                // interpolate along the y-axis
                float3 c0 = mix(c00, c01, ty);
                float3 c1 = mix(c10, c11, ty);

                linear = mix(c0, c1, linear.b);
                // interpolate along the x-axis
                linear = mix(c0, c1, tx);
            }
        }
        return float4(fromLinearSrgb(linear), rgba.a);
@@ -170,16 +186,20 @@ sk_sp<SkShader> LutShader::generateLutShader(sk_sp<SkShader> input,
     * 1D Lut RGB/MAX_RGB
     * (R0, 0, 0, 0)
     * (R1, 0, 0, 0)
     * ...
     * (R_length-1, 0, 0, 0)
     *
     * 1D Lut CIE_Y
     * (Y0, 0, 0, 0)
     * (Y1, 0, 0, 0)
     * ...
     * (Y_length-1, 0, 0, 0)
     *
     * 3D Lut MAX_RGB
     * (R0, G0, B0, 0)
     * (R1, G1, B1, 0)
     * ...
     * (R_length-1, G_length-1, B_length-1, 0)
     */
    SkImageInfo info = SkImageInfo::Make(length /* the number of rgba */, 1, kRGBA_F16_SkColorType,
                                         kPremul_SkAlphaType);
@@ -214,7 +234,7 @@ sk_sp<SkShader> LutShader::generateLutShader(sk_sp<SkShader> input,
        default:
            normalizeScalar = 1.0;
    }
    const int uSize = static_cast<int>(size);
    const int uSize = static_cast<int>(size); // the size per dimension
    const int uKey = static_cast<int>(samplingKey);
    const int uDimension = static_cast<int>(dimension);
    const float uNormalizeScalar = static_cast<float>(normalizeScalar);