Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit b58afe6d authored by Alex Strelnikov's avatar Alex Strelnikov
Browse files

Fix Mesh UAF, uniform handling, and performance bug.

Prior to this change, the render thread would only receive a non-owning
pointer or reference to a Mesh. Also because the Mesh itself was passed
by pointer, the refcount in its uniform sk_sp<SkData> would not be
incremented until an SkMesh was updated on the render thread, causing
uniform setting calls on the UI thread to impact prior draw calls with
the same mesh. The dirty flag used for the uniforms was also signaling a
reupload of vertex and index data to the GPU.

This fix adds a MeshBufferData class to handle keeping Skia buffers
up-to-date, and a Mesh::Snapshot class that carries shared ownership of
all pieces needed to construct an SkMesh. The snapshot is stored on the
render thread by value and increments the refcount of the uniform
sk_sp<SkData>.

Because the current Android Mesh API does not support partial buffer
updates, there is no need for a dirty flag, as comparing the
DirectContextID and checking if the buffers have been created is
sufficient. Creating an SkMesh is performed lazily inside the SkMesh
getter of the snapshot.

BUG: 328507000
Test: atest CtsUiRenderingTestCases:MeshTest
Change-Id: Iabe83dca462d4526c118047621b131009032d35b
parent d2a2fae7
Loading
Loading
Loading
Loading
+29 −19
Original line number Diff line number Diff line
@@ -21,6 +21,8 @@

#include "SafeMath.h"

namespace android {

static size_t min_vcount_for_mode(SkMesh::Mode mode) {
    switch (mode) {
        case SkMesh::Mode::kTriangles:
@@ -28,6 +30,7 @@ static size_t min_vcount_for_mode(SkMesh::Mode mode) {
        case SkMesh::Mode::kTriangleStrip:
            return 3;
    }
    return 1;
}

// Re-implementation of SkMesh::validate to validate user side that their mesh is valid.
@@ -36,29 +39,30 @@ std::tuple<bool, SkString> Mesh::validate() {
    if (!mMeshSpec) {
        FAIL_MESH_VALIDATE("MeshSpecification is required.");
    }
    if (mVertexBufferData.empty()) {
    if (mBufferData->vertexData().empty()) {
        FAIL_MESH_VALIDATE("VertexBuffer is required.");
    }

    auto meshStride = mMeshSpec->stride();
    auto meshMode = SkMesh::Mode(mMode);
    size_t vertexStride = mMeshSpec->stride();
    size_t vertexCount = mBufferData->vertexCount();
    size_t vertexOffset = mBufferData->vertexOffset();
    SafeMath sm;
    size_t vsize = sm.mul(meshStride, mVertexCount);
    if (sm.add(vsize, mVertexOffset) > mVertexBufferData.size()) {
    size_t vertexSize = sm.mul(vertexStride, vertexCount);
    if (sm.add(vertexSize, vertexOffset) > mBufferData->vertexData().size()) {
        FAIL_MESH_VALIDATE(
                "The vertex buffer offset and vertex count reads beyond the end of the"
                " vertex buffer.");
    }

    if (mVertexOffset % meshStride != 0) {
    if (vertexOffset % vertexStride != 0) {
        FAIL_MESH_VALIDATE("The vertex offset (%zu) must be a multiple of the vertex stride (%zu).",
                           mVertexOffset, meshStride);
                           vertexOffset, vertexStride);
    }

    if (size_t uniformSize = mMeshSpec->uniformSize()) {
        if (!mBuilder->fUniforms || mBuilder->fUniforms->size() < uniformSize) {
        if (!mUniformBuilder.fUniforms || mUniformBuilder.fUniforms->size() < uniformSize) {
            FAIL_MESH_VALIDATE("The uniform data is %zu bytes but must be at least %zu.",
                               mBuilder->fUniforms->size(), uniformSize);
                               mUniformBuilder.fUniforms->size(), uniformSize);
        }
    }

@@ -69,29 +73,33 @@ std::tuple<bool, SkString> Mesh::validate() {
            case SkMesh::Mode::kTriangleStrip:
                return "triangle-strip";
        }
        return "unknown";
    };
    if (!mIndexBufferData.empty()) {
        if (mIndexCount < min_vcount_for_mode(meshMode)) {

    size_t indexCount = mBufferData->indexCount();
    size_t indexOffset = mBufferData->indexOffset();
    if (!mBufferData->indexData().empty()) {
        if (indexCount < min_vcount_for_mode(mMode)) {
            FAIL_MESH_VALIDATE("%s mode requires at least %zu indices but index count is %zu.",
                               modeToStr(meshMode), min_vcount_for_mode(meshMode), mIndexCount);
                               modeToStr(mMode), min_vcount_for_mode(mMode), indexCount);
        }
        size_t isize = sm.mul(sizeof(uint16_t), mIndexCount);
        if (sm.add(isize, mIndexOffset) > mIndexBufferData.size()) {
        size_t isize = sm.mul(sizeof(uint16_t), indexCount);
        if (sm.add(isize, indexOffset) > mBufferData->indexData().size()) {
            FAIL_MESH_VALIDATE(
                    "The index buffer offset and index count reads beyond the end of the"
                    " index buffer.");
        }
        // If we allow 32 bit indices then this should enforce 4 byte alignment in that case.
        if (!SkIsAlign2(mIndexOffset)) {
        if (!SkIsAlign2(indexOffset)) {
            FAIL_MESH_VALIDATE("The index offset must be a multiple of 2.");
        }
    } else {
        if (mVertexCount < min_vcount_for_mode(meshMode)) {
        if (vertexCount < min_vcount_for_mode(mMode)) {
            FAIL_MESH_VALIDATE("%s mode requires at least %zu vertices but vertex count is %zu.",
                               modeToStr(meshMode), min_vcount_for_mode(meshMode), mVertexCount);
                               modeToStr(mMode), min_vcount_for_mode(mMode), vertexCount);
        }
        LOG_ALWAYS_FATAL_IF(mIndexCount != 0);
        LOG_ALWAYS_FATAL_IF(mIndexOffset != 0);
        LOG_ALWAYS_FATAL_IF(indexCount != 0);
        LOG_ALWAYS_FATAL_IF(indexOffset != 0);
    }

    if (!sm.ok()) {
@@ -100,3 +108,5 @@ std::tuple<bool, SkString> Mesh::validate() {
#undef FAIL_MESH_VALIDATE
    return {true, {}};
}

}  // namespace android
+143 −82
Original line number Diff line number Diff line
@@ -25,6 +25,8 @@

#include <utility>

namespace android {

class MeshUniformBuilder {
public:
    struct MeshUniform {
@@ -103,111 +105,170 @@ private:
    sk_sp<SkMeshSpecification> fMeshSpec;
};

class Mesh {
// Storage for CPU and GPU copies of the vertex and index data of a mesh.
class MeshBufferData {
public:
    Mesh(const sk_sp<SkMeshSpecification>& meshSpec, int mode,
         std::vector<uint8_t>&& vertexBufferData, jint vertexCount, jint vertexOffset,
         std::unique_ptr<MeshUniformBuilder> builder, const SkRect& bounds)
            : mMeshSpec(meshSpec)
            , mMode(mode)
            , mVertexBufferData(std::move(vertexBufferData))
            , mVertexCount(vertexCount)
            , mVertexOffset(vertexOffset)
            , mBuilder(std::move(builder))
            , mBounds(bounds) {}

    Mesh(const sk_sp<SkMeshSpecification>& meshSpec, int mode,
         std::vector<uint8_t>&& vertexBufferData, jint vertexCount, jint vertexOffset,
         std::vector<uint8_t>&& indexBuffer, jint indexCount, jint indexOffset,
         std::unique_ptr<MeshUniformBuilder> builder, const SkRect& bounds)
            : mMeshSpec(meshSpec)
            , mMode(mode)
            , mVertexBufferData(std::move(vertexBufferData))
            , mVertexCount(vertexCount)
    MeshBufferData(std::vector<uint8_t> vertexData, int32_t vertexCount, int32_t vertexOffset,
                   std::vector<uint8_t> indexData, int32_t indexCount, int32_t indexOffset)
            : mVertexCount(vertexCount)
            , mVertexOffset(vertexOffset)
            , mIndexBufferData(std::move(indexBuffer))
            , mIndexCount(indexCount)
            , mIndexOffset(indexOffset)
            , mBuilder(std::move(builder))
            , mBounds(bounds) {}
            , mVertexData(std::move(vertexData))
            , mIndexData(std::move(indexData)) {}

    Mesh(Mesh&&) = default;

    Mesh& operator=(Mesh&&) = default;

    [[nodiscard]] std::tuple<bool, SkString> validate();

    void updateSkMesh(GrDirectContext* context) const {
        GrDirectContext::DirectContextID genId = GrDirectContext::DirectContextID();
        if (context) {
            genId = context->directContextID();
    void updateBuffers(GrDirectContext* context) const {
        GrDirectContext::DirectContextID currentId = context == nullptr
                                                             ? GrDirectContext::DirectContextID()
                                                             : context->directContextID();
        if (currentId == mSkiaBuffers.fGenerationId && mSkiaBuffers.fVertexBuffer != nullptr) {
            // Nothing to update since the Android API does not support partial updates yet.
            return;
        }

        if (mIsDirty || genId != mGenerationId) {
            auto vertexData = reinterpret_cast<const void*>(mVertexBufferData.data());
        mSkiaBuffers.fVertexBuffer =
#ifdef __ANDROID__
            auto vb = SkMeshes::MakeVertexBuffer(context,
                                                 vertexData,
                                                 mVertexBufferData.size());
                SkMeshes::MakeVertexBuffer(context, mVertexData.data(), mVertexData.size());
#else
            auto vb = SkMeshes::MakeVertexBuffer(vertexData,
                                                 mVertexBufferData.size());
                SkMeshes::MakeVertexBuffer(mVertexData.data(), mVertexData.size());
#endif
            auto meshMode = SkMesh::Mode(mMode);
            if (!mIndexBufferData.empty()) {
                auto indexData = reinterpret_cast<const void*>(mIndexBufferData.data());
        if (mIndexCount != 0) {
            mSkiaBuffers.fIndexBuffer =
#ifdef __ANDROID__
                auto ib = SkMeshes::MakeIndexBuffer(context,
                                                    indexData,
                                                    mIndexBufferData.size());
                    SkMeshes::MakeIndexBuffer(context, mIndexData.data(), mIndexData.size());
#else
                auto ib = SkMeshes::MakeIndexBuffer(indexData,
                                                    mIndexBufferData.size());
                    SkMeshes::MakeIndexBuffer(mIndexData.data(), mIndexData.size());
#endif
                mMesh = SkMesh::MakeIndexed(mMeshSpec, meshMode, vb, mVertexCount, mVertexOffset,
                                            ib, mIndexCount, mIndexOffset, mBuilder->fUniforms,
                                            SkSpan<SkRuntimeEffect::ChildPtr>(), mBounds)
                                .mesh;
            } else {
                mMesh = SkMesh::Make(mMeshSpec, meshMode, vb, mVertexCount, mVertexOffset,
                                     mBuilder->fUniforms, SkSpan<SkRuntimeEffect::ChildPtr>(),
                                     mBounds)
                                .mesh;
            }
            mIsDirty = false;
            mGenerationId = genId;
        }
        mSkiaBuffers.fGenerationId = currentId;
    }

    SkMesh& getSkMesh() const {
        LOG_FATAL_IF(mIsDirty,
                     "Attempt to obtain SkMesh when Mesh is dirty, did you "
                     "forget to call updateSkMesh with a GrDirectContext? "
                     "Defensively creating a CPU mesh");
    SkMesh::VertexBuffer* vertexBuffer() const { return mSkiaBuffers.fVertexBuffer.get(); }

    sk_sp<SkMesh::VertexBuffer> refVertexBuffer() const { return mSkiaBuffers.fVertexBuffer; }
    int32_t vertexCount() const { return mVertexCount; }
    int32_t vertexOffset() const { return mVertexOffset; }

    sk_sp<SkMesh::IndexBuffer> refIndexBuffer() const { return mSkiaBuffers.fIndexBuffer; }
    int32_t indexCount() const { return mIndexCount; }
    int32_t indexOffset() const { return mIndexOffset; }

    const std::vector<uint8_t>& vertexData() const { return mVertexData; }
    const std::vector<uint8_t>& indexData() const { return mIndexData; }

private:
    struct CachedSkiaBuffers {
        sk_sp<SkMesh::VertexBuffer> fVertexBuffer;
        sk_sp<SkMesh::IndexBuffer> fIndexBuffer;
        GrDirectContext::DirectContextID fGenerationId = GrDirectContext::DirectContextID();
    };

    mutable CachedSkiaBuffers mSkiaBuffers;
    int32_t mVertexCount = 0;
    int32_t mVertexOffset = 0;
    int32_t mIndexCount = 0;
    int32_t mIndexOffset = 0;
    std::vector<uint8_t> mVertexData;
    std::vector<uint8_t> mIndexData;
};

class Mesh {
public:
    // A snapshot of the mesh for use by the render thread.
    //
    // After a snapshot is taken, future uniform changes to the original Mesh will not modify the
    // uniforms returned by makeSkMesh.
    class Snapshot {
    public:
        Snapshot() = delete;
        Snapshot(const Snapshot&) = default;
        Snapshot(Snapshot&&) = default;
        Snapshot& operator=(const Snapshot&) = default;
        Snapshot& operator=(Snapshot&&) = default;
        ~Snapshot() = default;

        const SkMesh& getSkMesh() const {
            SkMesh::VertexBuffer* vertexBuffer = mBufferData->vertexBuffer();
            LOG_FATAL_IF(vertexBuffer == nullptr,
                         "Attempt to obtain SkMesh when vertexBuffer has not been created, did you "
                         "forget to call MeshBufferData::updateBuffers with a GrDirectContext?");
            if (vertexBuffer != mMesh.vertexBuffer()) mMesh = makeSkMesh();
            return mMesh;
        }

    void markDirty() { mIsDirty = true; }
    private:
        friend class Mesh;

    MeshUniformBuilder* uniformBuilder() { return mBuilder.get(); }
        Snapshot(sk_sp<SkMeshSpecification> meshSpec, SkMesh::Mode mode,
                 std::shared_ptr<const MeshBufferData> bufferData, sk_sp<const SkData> uniforms,
                 const SkRect& bounds)
                : mMeshSpec(std::move(meshSpec))
                , mMode(mode)
                , mBufferData(std::move(bufferData))
                , mUniforms(std::move(uniforms))
                , mBounds(bounds) {}

private:
        SkMesh makeSkMesh() const {
            const MeshBufferData& d = *mBufferData;
            if (d.indexCount() != 0) {
                return SkMesh::MakeIndexed(mMeshSpec, mMode, d.refVertexBuffer(), d.vertexCount(),
                                           d.vertexOffset(), d.refIndexBuffer(), d.indexCount(),
                                           d.indexOffset(), mUniforms,
                                           SkSpan<SkRuntimeEffect::ChildPtr>(), mBounds)
                        .mesh;
            }
            return SkMesh::Make(mMeshSpec, mMode, d.refVertexBuffer(), d.vertexCount(),
                                d.vertexOffset(), mUniforms, SkSpan<SkRuntimeEffect::ChildPtr>(),
                                mBounds)
                    .mesh;
        }

        mutable SkMesh mMesh;
        sk_sp<SkMeshSpecification> mMeshSpec;
    int mMode = 0;
        SkMesh::Mode mMode;
        std::shared_ptr<const MeshBufferData> mBufferData;
        sk_sp<const SkData> mUniforms;
        SkRect mBounds;
    };

    std::vector<uint8_t> mVertexBufferData;
    size_t mVertexCount = 0;
    size_t mVertexOffset = 0;
    Mesh(sk_sp<SkMeshSpecification> meshSpec, SkMesh::Mode mode, std::vector<uint8_t> vertexData,
         int32_t vertexCount, int32_t vertexOffset, const SkRect& bounds)
            : Mesh(std::move(meshSpec), mode, std::move(vertexData), vertexCount, vertexOffset,
                   /* indexData = */ {}, /* indexCount = */ 0, /* indexOffset = */ 0, bounds) {}

    std::vector<uint8_t> mIndexBufferData;
    size_t mIndexCount = 0;
    size_t mIndexOffset = 0;
    Mesh(sk_sp<SkMeshSpecification> meshSpec, SkMesh::Mode mode, std::vector<uint8_t> vertexData,
         int32_t vertexCount, int32_t vertexOffset, std::vector<uint8_t> indexData,
         int32_t indexCount, int32_t indexOffset, const SkRect& bounds)
            : mMeshSpec(std::move(meshSpec))
            , mMode(mode)
            , mBufferData(std::make_shared<MeshBufferData>(std::move(vertexData), vertexCount,
                                                           vertexOffset, std::move(indexData),
                                                           indexCount, indexOffset))
            , mUniformBuilder(mMeshSpec)
            , mBounds(bounds) {}

    Mesh(Mesh&&) = default;

    Mesh& operator=(Mesh&&) = default;

    [[nodiscard]] std::tuple<bool, SkString> validate();

    std::unique_ptr<MeshUniformBuilder> mBuilder;
    SkRect mBounds{};
    std::shared_ptr<const MeshBufferData> refBufferData() const { return mBufferData; }

    mutable SkMesh mMesh{};
    mutable bool mIsDirty = true;
    mutable GrDirectContext::DirectContextID mGenerationId = GrDirectContext::DirectContextID();
    Snapshot takeSnapshot() const {
        return Snapshot(mMeshSpec, mMode, mBufferData, mUniformBuilder.fUniforms, mBounds);
    }

    MeshUniformBuilder* uniformBuilder() { return &mUniformBuilder; }

private:
    sk_sp<SkMeshSpecification> mMeshSpec;
    SkMesh::Mode mMode;
    std::shared_ptr<MeshBufferData> mBufferData;
    MeshUniformBuilder mUniformBuilder;
    SkRect mBounds;
};

}  // namespace android

#endif  // MESH_H_
+2 −11
Original line number Diff line number Diff line
@@ -573,9 +573,9 @@ struct DrawSkMesh final : Op {
struct DrawMesh final : Op {
    static const auto kType = Type::DrawMesh;
    DrawMesh(const Mesh& mesh, sk_sp<SkBlender> blender, const SkPaint& paint)
            : mesh(mesh), blender(std::move(blender)), paint(paint) {}
            : mesh(mesh.takeSnapshot()), blender(std::move(blender)), paint(paint) {}

    const Mesh& mesh;
    Mesh::Snapshot mesh;
    sk_sp<SkBlender> blender;
    SkPaint paint;

@@ -1296,14 +1296,5 @@ void RecordingCanvas::drawWebView(skiapipeline::FunctorDrawable* drawable) {
    fDL->drawWebView(drawable);
}

[[nodiscard]] const SkMesh& DrawMeshPayload::getSkMesh() const {
    LOG_FATAL_IF(!meshWrapper && !mesh, "One of Mesh or Mesh must be non-null");
    if (meshWrapper) {
        return meshWrapper->getSkMesh();
    } else {
        return *mesh;
    }
}

}  // namespace uirenderer
}  // namespace android
+3 −14
Original line number Diff line number Diff line
@@ -41,11 +41,12 @@

enum class SkBlendMode;
class SkRRect;
class Mesh;

namespace android {
namespace uirenderer {

class Mesh;

namespace uirenderer {
namespace skiapipeline {
class FunctorDrawable;
}
@@ -68,18 +69,6 @@ struct DisplayListOp {

static_assert(sizeof(DisplayListOp) == 4);

class DrawMeshPayload {
public:
    explicit DrawMeshPayload(const SkMesh* mesh) : mesh(mesh) {}
    explicit DrawMeshPayload(const Mesh* meshWrapper) : meshWrapper(meshWrapper) {}

    [[nodiscard]] const SkMesh& getSkMesh() const;

private:
    const SkMesh* mesh = nullptr;
    const Mesh* meshWrapper = nullptr;
};

struct DrawImagePayload {
    explicit DrawImagePayload(Bitmap& bitmap)
            : image(bitmap.makeImage()), palette(bitmap.palette()) {
+2 −2
Original line number Diff line number Diff line
@@ -596,8 +596,8 @@ void SkiaCanvas::drawMesh(const Mesh& mesh, sk_sp<SkBlender> blender, const Pain
    if (recordingContext) {
        context = recordingContext->asDirectContext();
    }
    mesh.updateSkMesh(context);
    mCanvas->drawMesh(mesh.getSkMesh(), blender, paint);
    mesh.refBufferData()->updateBuffers(context);
    mCanvas->drawMesh(mesh.takeSnapshot().getSkMesh(), blender, paint);
}

// ----------------------------------------------------------------------------
Loading