Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 06dea72f authored by En-Shuo Hsu's avatar En-Shuo Hsu
Browse files

floss: Implement the PLC(Packet Loss Concealment)

Implementing PLC is recommended for HFP WBS by the Bluetooth SIG's HFP
specification.

The PLC algorithm is based on HFP spec 1.7.1 and has been used by
Chrome OS since 2019/07. This CL follows the existing implementation,
with limited modifications to fit the code flow and architecture of
btm_sco_hci.

Bug: 232463744
Tag: #floss
Test: atest --host net_test_stack_btm --no-bazel-mode &&
Build and verify HFP WBS PLC works under different signal
strength. Print log to verify each state of the PLC works.

BYPASS_LONG_LINES_REASON: Bluetooth likes 120 char lines

Change-Id: I5d02df4df75509ff076ec02f022b37698d121420
parent cf6a4252
Loading
Loading
Loading
Loading
+254 −8
Original line number Diff line number Diff line
@@ -19,6 +19,7 @@
#include <sys/stat.h>
#include <unistd.h>

#include <cfloat>
#include <memory>

#include "btif/include/core_callbacks.h"
@@ -40,6 +41,22 @@
#define BTM_MSBC_PKT_FRAME_LEN 57 /* Packet length without the header */
#define BTM_MSBC_SYNC_WORD 0xAD

/* Used by PLC. Apart from BTM_MSBC_SAMPLE_SIZE, the lengths below are used as
 * element counts / indices into int16_t sample buffers. */
#define BTM_MSBC_SAMPLE_SIZE 2 /* Size of one PCM sample, in bytes */
#define BTM_MSBC_FS 120        /* Frame Size, in samples */

#define BTM_PLC_WL 256 /* 16ms - Window Length for pattern matching */
#define BTM_PLC_TL 64  /* 4ms - Template Length for matching */
#define BTM_PLC_HL \
  (BTM_PLC_WL + BTM_MSBC_FS - 1) /* Length of History buffer required */
#define BTM_PLC_SBCRL 36         /* SBC Reconvergence sample Length */
#define BTM_PLC_OLAL 16          /* OverLap-Add Length */

/* The PLC is bypassed when more than BTM_PLC_PL_THRESHOLD packets were lost
 * among the last BTM_PLC_WINDOW_SIZE packets. */
#define BTM_PLC_WINDOW_SIZE 5
#define BTM_PLC_PL_THRESHOLD 1

namespace {

std::unique_ptr<tUIPC_STATE> sco_uipc = nullptr;
@@ -128,16 +145,231 @@ constexpr size_t btm_wbs_supported_pkt_size[] = {BTM_MSBC_PKT_LEN, 72, 0};
 * BTM_MSBC_PKT_LEN for optimizing buffer copy. */
constexpr size_t btm_wbs_msbc_buffer_size[] = {BTM_MSBC_PKT_LEN, 360, 0};

/* The pre-computed zero input bit stream of mSBC codec, per HFP 1.7 spec.
 * This mSBC frame will be decoded into all-zero input PCM. */
/* The pre-computed SCO packet per HFP 1.7 spec. This mSBC packet will be
 * decoded into all-zero input PCM. */
static const uint8_t btm_msbc_zero_packet[] = {
    0x01, 0x08, /* Mock H2 header */
    0xad, 0x00, 0x00, 0xc5, 0x00, 0x00, 0x00, 0x00, 0x77, 0x6d, 0xb6, 0xdd,
    0xdb, 0x6d, 0xb7, 0x76, 0xdb, 0x6d, 0xdd, 0xb6, 0xdb, 0x77, 0x6d, 0xb6,
    0xdd, 0xdb, 0x6d, 0xb7, 0x76, 0xdb, 0x6d, 0xdd, 0xb6, 0xdb, 0x77, 0x6d,
    0xb6, 0xdd, 0xdb, 0x6d, 0xb7, 0x76, 0xdb, 0x6d, 0xdd, 0xb6, 0xdb, 0x77,
    0x6d, 0xb6, 0xdd, 0xdb, 0x6d, 0xb7, 0x76, 0xdb, 0x6c};
    0x6d, 0xb6, 0xdd, 0xdb, 0x6d, 0xb7, 0x76, 0xdb, 0x6c,
    /* End of Audio Samples */
    0x00 /* A padding byte defined by mSBC */};

/* Raised Cosine table for OLA (overlap-add).
 *
 * The entries descend from ~1 to ~0 and match
 *   rcos[i] = (1 + cos(pi * (i + 1) / (BTM_PLC_OLAL + 1))) / 2
 * so that rcos[i] + rcos[BTM_PLC_OLAL - 1 - i] == 1. Reading the table
 * forwards gives the fade-out weights and reading it backwards gives the
 * complementary fade-in weights. */
static const float rcos[BTM_PLC_OLAL] = {
    0.99148655f, 0.96623611f, 0.92510857f, 0.86950446f,
    0.80131732f, 0.72286918f, 0.63683150f, 0.54613418f,
    0.45386582f, 0.36316850f, 0.27713082f, 0.19868268f,
    0.13049554f, 0.07489143f, 0.03376389f, 0.00851345f};

/* Converts a float sample to a signed 16-bit sample. Values outside the
 * int16_t range saturate at INT16_MAX / INT16_MIN; in-range values are
 * truncated toward zero. */
static int16_t f_to_s16(float input) {
  if (input > INT16_MAX) {
    return INT16_MAX;
  }
  if (input < INT16_MIN) {
    return INT16_MIN;
  }
  return (int16_t)input;
}
/* Sliding record of whether each of the last BTM_PLC_WINDOW_SIZE packets was
 * lost. The members carry no initializers; the struct relies on being
 * zero-filled by whoever allocates it. */
struct tBTM_MSBC_BTM_PLC_WINDOW {
  /* Per-packet loss flags, used as a ring buffer. */
  bool loss_hist[BTM_PLC_WINDOW_SIZE];
  /* Ring-buffer slot that the next update overwrites. */
  unsigned int idx;
  /* Number of slots in loss_hist currently flagged as lost. */
  unsigned int count;

  /* Records the status of the newest packet in the current slot, keeps
   * `count` in sync with the flags, and advances to the next slot. */
  void update_plc_state(bool is_packet_loss) {
    bool& slot = loss_hist[idx];
    if (slot != is_packet_loss) {
      /* The slot flips state: a new loss adds one to the tally, while a good
       * packet replacing an old loss removes one. */
      if (is_packet_loss) {
        ++count;
      } else {
        --count;
      }
      slot = is_packet_loss;
    }
    idx = (idx + 1) % BTM_PLC_WINDOW_SIZE;
  }

  /* Returns true when the number of losses within the window exceeds
   * BTM_PLC_PL_THRESHOLD.
   *
   * The loss count over the window serves as a confidence indicator for the
   * PLC algorithm: concealment is known to sound poor and robotic when most
   * of the samples in the PLC history buffer are themselves concealment
   * results. */
  bool is_packet_loss_too_high() {
    const bool too_many_losses = count > BTM_PLC_PL_THRESHOLD;
    return too_many_losses;
  }
};

/* The PLC is specifically designed for mSBC. The algorithm searches the
 * history of received samples for the best-matching samples and constructs
 * substitutes for the lost samples. The selection is based on pattern
 * matching against a template, composed of a run of samples immediately
 * preceding the lost samples. The samples that follow the best match are then
 * used as the replacement samples, and Overlap-Add is applied to reduce
 * audible distortion.
 *
 * This structure holds the state needed to run the PLC algorithm.
 */
struct tBTM_MSBC_PLC {
  int16_t hist[BTM_PLC_HL + BTM_MSBC_FS + BTM_PLC_SBCRL +
               BTM_PLC_OLAL]; /* The history buffer for receiving samples, we
                                 also use it to buffer the processed
                                 replacement samples */
  unsigned best_lag;      /* The index of the best substitution samples in the
                             sample history */
  int handled_bad_frames; /* Number of bad frames handled since the last good
                             frame */
  int16_t decoded_buffer[BTM_MSBC_FS]; /* Used for storing the samples from
                                      decoding the mSBC zero frame packet and
                                      also constructed frames */
  tBTM_MSBC_BTM_PLC_WINDOW*
      pl_window; /* Used to monitor how many packets are bad within the recent
                    BTM_PLC_WINDOW_SIZE of packets. We use this to determine if
                    we want to disable the PLC temporarily */

  /* Cross-fades two BTM_PLC_OLAL-sample segments into `output`.
   *
   * `desc` is scaled by `scaler_d` and weighted by the raised-cosine table
   * read forwards (fading out); `asc` is scaled by `scaler_a` and weighted by
   * the same table read backwards (fading in). Each result is saturated to
   * int16_t. Each input sample is read before the output sample at the same
   * index is written. */
  void overlap_add(int16_t* output, float scaler_d, const int16_t* desc,
                   float scaler_a, const int16_t* asc) {
    for (int n = 0, mirror = BTM_PLC_OLAL - 1; n < BTM_PLC_OLAL;
         ++n, --mirror) {
      output[n] = f_to_s16(scaler_d * desc[n] * rcos[n] +
                           scaler_a * asc[n] * rcos[mirror]);
    }
  }

static const uint8_t btm_msbc_zero_frames[BTM_MSBC_CODE_SIZE] = {0};
  /* Computes the normalized cross-correlation of two BTM_PLC_TL-sample
   * windows: sum(x*y) / sqrt(sum(x*x) * sum(y*y)).
   *
   * Returns 0 when either window has zero energy (all samples are zero).
   * Dividing in that case would evaluate 0/0 and hand a NaN back to the
   * caller. */
  float cross_correlation(int16_t* x, int16_t* y) {
    float sum = 0, x2 = 0, y2 = 0;

    for (int i = 0; i < BTM_PLC_TL; i++) {
      sum += ((float)x[i]) * y[i];
      x2 += ((float)x[i]) * x[i];
      y2 += ((float)y[i]) * y[i];
    }

    const float energy = x2 * y2;
    /* x2 and y2 are sums of squares, so the product is zero exactly when one
     * of the windows is silent. */
    if (energy <= 0) return 0;

    return sum / sqrtf(energy);
  }

  /* Searches the first BTM_PLC_WL positions of `hist` for the window that
   * correlates best with the template, i.e. the last BTM_PLC_TL samples
   * before index BTM_PLC_HL.
   *
   * Returns the start index of the best-matching window, or 0 if no
   * candidate produced a comparable score. */
  int pattern_match(int16_t* hist) {
    int best = 0;
    /* Seed with the most negative finite float. FLT_MIN is the smallest
     * *positive* normalized float, so seeding with it would reject every
     * zero or negative correlation and silently fall back to index 0. */
    float max_cn = -FLT_MAX;
    int16_t* const match_template = &hist[BTM_PLC_HL - BTM_PLC_TL];

    for (int i = 0; i < BTM_PLC_WL; i++) {
      const float cn = cross_correlation(match_template, &hist[i]);
      /* A NaN score compares false here and is therefore never selected. */
      if (cn > max_cn) {
        best = i;
        max_cn = cn;
      }
    }
    return best;
  }

  /* Estimates the gain that brings the amplitude of frame `y` to that of
   * frame `x`, as the ratio of their summed absolute sample values over
   * BTM_MSBC_FS samples. The result is limited to [0.75, 1.2]; when `y` is
   * entirely zero the upper limit is returned. */
  float amplitude_match(int16_t* x, int16_t* y) {
    const float kMaxGain = 1.2f;
    const float kMinGain = 0.75f;

    uint32_t abs_total_x = 0;
    uint32_t abs_total_y = 0;
    for (int n = 0; n < BTM_MSBC_FS; ++n) {
      abs_total_x += abs(x[n]);
      abs_total_y += abs(y[n]);
    }

    /* Avoid dividing by zero for a silent reference frame. */
    if (abs_total_y == 0) {
      return kMaxGain;
    }

    const float gain = (float)abs_total_x / abs_total_y;
    if (gain > kMaxGain) {
      return kMaxGain;
    }
    if (gain < kMinGain) {
      return kMinGain;
    }
    return gain;
  }

 public:
  void init() {
    if (pl_window) osi_free(pl_window);
    pl_window = (tBTM_MSBC_BTM_PLC_WINDOW*)osi_calloc(sizeof(*pl_window));
  }

  /* Releases the packet-loss window, if any.
   *
   * The pointer is cleared after the free so that a later init() or deinit()
   * on the same object does not free it a second time. */
  void deinit() {
    if (pl_window) {
      osi_free(pl_window);
      pl_window = nullptr;
    }
  }

  /* Produces one frame of replacement PCM for a lost packet.
   *
   * On return, *output points at this object's decoded_buffer, which holds
   * BTM_MSBC_FS samples: either the concealment frame built from the history
   * or, when too many recent packets were lost, the result of decoding the
   * pre-computed zero packet. The history buffer is shifted by one frame and
   * the loss is recorded in pl_window. */
  void handle_bad_frames(const uint8_t** output) {
    float scaler;
    int16_t* best_match_hist;
    /* Start of the region just past the history, where the substitution frame
     * and its reconvergence / overlap-add tail are assembled. */
    int16_t* frame_head = &hist[BTM_PLC_HL];

    /* mSBC codec is stateful, the history of signal would contribute to the
     * decode result decoded_buffer. This should never fail. */
    GetInterfaceToProfiles()->msbcCodec->decodePacket(
        btm_msbc_zero_packet, decoded_buffer, sizeof(decoded_buffer));

    /* The PLC algorithm is more likely to generate bad results that sound
     * robotic after severe packet losses happened. Only applying it when
     * we are confident. */
    if (!pl_window->is_packet_loss_too_high()) {
      if (handled_bad_frames == 0) {
        /* Finds the best matching samples and amplitude */
        /* The substitution starts right after the matched template, hence the
         * BTM_PLC_TL offset. */
        best_lag = pattern_match(hist) + BTM_PLC_TL;
        best_match_hist = &hist[best_lag];
        scaler =
            amplitude_match(&hist[BTM_PLC_HL - BTM_MSBC_FS], best_match_hist);

        /* Constructs the substitution samples */
        /* Head: fade from the decoder output into the scaled match. */
        overlap_add(frame_head, 1.0, decoded_buffer, scaler, best_match_hist);
        /* Body: the scaled match, saturated to int16_t. */
        for (int i = BTM_PLC_OLAL; i < BTM_MSBC_FS; i++)
          hist[BTM_PLC_HL + i] = f_to_s16(scaler * best_match_hist[i]);
        /* Tail: fade the samples following the match from scaled back to
         * unscaled amplitude. */
        overlap_add(&frame_head[BTM_MSBC_FS], scaler,
                    &best_match_hist[BTM_MSBC_FS], 1.0,
                    &best_match_hist[BTM_MSBC_FS]);

        /* Append BTM_PLC_SBCRL further unscaled samples after the tail. */
        memmove(&frame_head[BTM_MSBC_FS + BTM_PLC_OLAL],
                &best_match_hist[BTM_MSBC_FS + BTM_PLC_OLAL],
                BTM_PLC_SBCRL * BTM_MSBC_SAMPLE_SIZE);
      } else {
        /* Using the existing best lag and copy the following frames */
        memmove(frame_head, &hist[best_lag],
                (BTM_MSBC_FS + BTM_PLC_SBCRL + BTM_PLC_OLAL) *
                    BTM_MSBC_SAMPLE_SIZE);
      }
      /* Copy the constructed frames to decoded buffer for caller to use */
      std::copy(frame_head, &frame_head[BTM_MSBC_FS], decoded_buffer);

      handled_bad_frames++;
    } else {
      /* This is a case similar to receiving a good frame with all zeros, we set
       * handled_bad_frames to zero to prevent the following good frame from
       * being concealed to reconverge with the zero frames we fill in. The
       * concealment result sounds more artificial and weird than simply writing
       * zeros and following samples.
       */
      std::copy(std::begin(decoded_buffer), std::end(decoded_buffer),
                frame_head);
      std::fill(&frame_head[BTM_MSBC_FS],
                &frame_head[BTM_MSBC_FS + BTM_PLC_SBCRL + BTM_PLC_OLAL], 0);
      /* No need to copy the frames as we'll use the decoded zero frames in the
       * decoded buffer as our concealment frames */

      handled_bad_frames = 0;
    }

    *output = (const uint8_t*)decoded_buffer;

    /* Shift the frames to update the history window */
    /* The copy starts one frame in and runs to the end of hist, so the frame
     * built above becomes the newest part of the history and its tail lands at
     * index BTM_PLC_HL. */
    memmove(hist, &hist[BTM_MSBC_FS],
            (BTM_PLC_HL + BTM_PLC_SBCRL + BTM_PLC_OLAL) * BTM_MSBC_SAMPLE_SIZE);
    pl_window->update_plc_state(1);
  }

  /* Feeds one successfully decoded frame (BTM_MSBC_FS samples in `input`)
   * into the PLC state.
   *
   * If the previous frame(s) were concealed, `input` is modified in place so
   * that it joins up with the concealment output. The frame is then appended
   * to the history and recorded as a good packet in pl_window. */
  void handle_good_frames(int16_t* input) {
    const bool follows_concealment = handled_bad_frames != 0;
    if (follows_concealment) {
      /* Samples stored past the history by the preceding concealment. */
      int16_t* const concealed_tail = hist + BTM_PLC_HL;
      int16_t* const blend_start = input + BTM_PLC_SBCRL;

      /* The first good frame after a loss has its leading BTM_PLC_SBCRL
       * samples replaced by the stored concealment samples, so the output
       * reconverges with the true signal. */
      std::copy(concealed_tail, concealed_tail + BTM_PLC_SBCRL, input);
      /* The next BTM_PLC_OLAL samples fade from the stored samples into the
       * received ones. */
      overlap_add(blend_start, 1.0, concealed_tail + BTM_PLC_SBCRL, 1.0,
                  blend_start);
      handled_bad_frames = 0;
    }

    /* Drop the oldest frame from the history and place the good frame in the
     * slot freed at its end. */
    int16_t* const newest_slot = hist + (BTM_PLC_HL - BTM_MSBC_FS);
    memmove(hist, hist + BTM_MSBC_FS,
            (BTM_PLC_HL - BTM_MSBC_FS) * BTM_MSBC_SAMPLE_SIZE);
    std::copy(input, input + BTM_MSBC_FS, newest_slot);
    pl_window->update_plc_state(false);
  }
};

/* Define the structure that contains mSBC data */
struct tBTM_MSBC_INFO {
@@ -153,9 +385,11 @@ struct tBTM_MSBC_INFO {
  size_t encode_buf_wo;     /* Write offset of the encode buffer */
  size_t encode_buf_ro;     /* Read offset of the encode buffer */

  int16_t decoded_pcm_buf[120]; /* Buffer to store decoded PCM */
  int16_t decoded_pcm_buf[BTM_MSBC_FS]; /* Buffer to store decoded PCM */

  uint8_t num_encoded_msbc_pkts; /* Number of the encoded mSBC packets */

  tBTM_MSBC_PLC* plc; /* PLC component to handle the packet loss of input */
  static size_t get_supported_packet_size(size_t pkt_size,
                                          size_t* buffer_size) {
    int i;
@@ -202,12 +436,23 @@ struct tBTM_MSBC_INFO {

    if (msbc_encode_buf) osi_free(msbc_encode_buf);
    msbc_encode_buf = (uint8_t*)osi_calloc(buf_size);

    if (plc) {
      plc->deinit();
      osi_free(plc);
    }
    plc = (tBTM_MSBC_PLC*)osi_calloc(sizeof(*plc));
    plc->init();
    return packet_size;
  }

  /* Releases the decode buffer, the encode buffer and the PLC component.
   *
   * Every pointer is cleared after its free. Code that re-initializes this
   * object frees any pointer that is still non-null, so leaving stale values
   * behind would lead to a double free on a later init(). */
  void deinit() {
    if (msbc_decode_buf) {
      osi_free(msbc_decode_buf);
      msbc_decode_buf = nullptr;
    }
    if (msbc_encode_buf) {
      osi_free(msbc_encode_buf);
      msbc_encode_buf = nullptr;
    }
    if (plc) {
      /* Let the PLC release its own allocation before it is freed. */
      plc->deinit();
      osi_free(plc);
      plc = nullptr;
    }
  }

  size_t decodable() { return decode_buf_wo - decode_buf_ro; }
@@ -393,12 +638,13 @@ size_t decode(const uint8_t** out_data) {
    goto packet_loss;
  }

  msbc_info->plc->handle_good_frames(msbc_info->decoded_pcm_buf);
  *out_data = (const uint8_t*)msbc_info->decoded_pcm_buf;
  msbc_info->mark_pkt_decoded();
  return BTM_MSBC_CODE_SIZE;

packet_loss:
  *out_data = btm_msbc_zero_frames;
  msbc_info->plc->handle_bad_frames(out_data);
  msbc_info->mark_pkt_decoded();
  return BTM_MSBC_CODE_SIZE;
}
@@ -434,8 +680,8 @@ size_t encode(int16_t* data, size_t len) {
      GetInterfaceToProfiles()->msbcCodec->encodePacket(data, pkt_body);
  if (encoded_size != BTM_MSBC_PKT_FRAME_LEN) {
    LOG_WARN("Encoding invalid packet size: %lu", (unsigned long)encoded_size);
    std::copy(std::begin(btm_msbc_zero_packet), std::end(btm_msbc_zero_packet),
              pkt_body);
    std::copy(&btm_msbc_zero_packet[BTM_MSBC_H2_HEADER_LEN],
              std::end(btm_msbc_zero_packet), pkt_body);
  }

  return BTM_MSBC_CODE_SIZE;
+66 −0
Original line number Diff line number Diff line
@@ -38,6 +38,9 @@ extern bool mock_uipc_send_ret;

namespace {

using testing::AllOf;
using testing::Ge;
using testing::Le;
using testing::Test;

const uint8_t msbc_zero_packet[] = {
@@ -230,6 +233,7 @@ TEST_F(ScoHciWbsWithInitCleanTest, WbsDecode) {
      sizeof(payload));

  // Return all zero frames when there comes an invalid packet.
  // This is expected even with PLC as there is no history in the PLC buffer.
  ASSERT_EQ(bluetooth::audio::sco::wbs::decode(&decoded),
            size_t(BTM_MSBC_CODE_SIZE));
  ASSERT_NE(decoded, nullptr);
@@ -311,4 +315,66 @@ TEST_F(ScoHciWbsWithInitCleanTest, WbsEncodeDequeuePackets) {
  }
}

// Verifies that concealing a single lost packet yields PCM close to what a
// loss-free run of the same signal decodes to.
TEST_F(ScoHciWbsWithInitCleanTest, WbsPlc) {
  constexpr size_t kFrameSamples = 120;
  constexpr size_t kPktLen = 60;
  constexpr size_t kTrianglePeriod = 16;

  const int16_t triangle[kTrianglePeriod] = {
      0, 100,  200,  300,  400,  300,  200,  100,
      0, -100, -200, -300, -400, -300, -200, -100};
  int16_t data[kFrameSamples];
  int16_t expect_data[kFrameSamples];
  const uint8_t* encoded = nullptr;
  const uint8_t* decoded = nullptr;
  uint8_t invalid_pkt[kPktLen] = {0};
  const size_t lost_pkt_idx = 17;

  // Reference pass: every packet up to and including lost_pkt_idx arrives
  // intact.
  size_t ref_sample_idx = 0;
  for (size_t pkt = 0; pkt <= lost_pkt_idx; ++pkt) {
    // Input data is a 1000Hz triangle wave, continuous across frames.
    for (size_t j = 0; j < kFrameSamples; ++j, ++ref_sample_idx) {
      data[j] = triangle[ref_sample_idx % kTrianglePeriod];
    }

    // Build the packet.
    ASSERT_EQ(bluetooth::audio::sco::wbs::encode(data, sizeof(data)),
              sizeof(data));
    ASSERT_EQ(bluetooth::audio::sco::wbs::dequeue_packet(&encoded), kPktLen);
    ASSERT_NE(encoded, nullptr);

    // Loop the packet back as if it had just been received.
    ASSERT_EQ(bluetooth::audio::sco::wbs::enqueue_packet(encoded, kPktLen),
              kPktLen);
    ASSERT_EQ(bluetooth::audio::sco::wbs::decode(&decoded),
              size_t(BTM_MSBC_CODE_SIZE));
    ASSERT_NE(decoded, nullptr);
  }

  // Keep the last decoded frame as the expectation for the lossy pass.
  std::copy(reinterpret_cast<const int16_t*>(decoded),
            reinterpret_cast<const int16_t*>(decoded + BTM_MSBC_CODE_SIZE),
            expect_data);

  // Lossy pass: restart from a fresh WBS state and replay the same signal.
  bluetooth::audio::sco::wbs::cleanup();
  bluetooth::audio::sco::wbs::init(kPktLen);
  size_t plc_sample_idx = 0;
  for (size_t pkt = 0; pkt <= lost_pkt_idx; ++pkt) {
    // Data is a 1000Hz triangle wave.
    for (size_t j = 0; j < kFrameSamples; ++j, ++plc_sample_idx) {
      data[j] = triangle[plc_sample_idx % kTrianglePeriod];
    }
    ASSERT_EQ(bluetooth::audio::sco::wbs::encode(data, sizeof(data)),
              sizeof(data));
    ASSERT_EQ(bluetooth::audio::sco::wbs::dequeue_packet(&encoded), kPktLen);
    ASSERT_NE(encoded, nullptr);

    // Swap in an invalid packet at lost_pkt_idx to simulate packet loss.
    const uint8_t* received = pkt != lost_pkt_idx ? encoded : invalid_pkt;
    ASSERT_EQ(bluetooth::audio::sco::wbs::enqueue_packet(received, kPktLen),
              kPktLen);
    ASSERT_EQ(bluetooth::audio::sco::wbs::decode(&decoded),
              size_t(BTM_MSBC_CODE_SIZE));
    ASSERT_NE(decoded, nullptr);
  }

  const int16_t* concealed = reinterpret_cast<const int16_t*>(decoded);
  for (size_t i = 0; i < kFrameSamples; ++i) {
    // The frames generated by PLC won't be perfect because:
    // 1. The mSBC decoder is stateful.
    // 2. Overlap-add is applied to glue the frames when packet loss happens.
    ASSERT_THAT(concealed[i] - expect_data[i], AllOf(Ge(-3), Le(3)))
        << "PLC data " << concealed[i] << " deviates from expected "
        << expect_data[i] << " at index " << i;
  }
}

}  // namespace