Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 1d76a6f2 authored by Treehugger Robot's avatar Treehugger Robot Committed by Gerrit Code Review
Browse files

Merge "floss: Implement the PLC(Packet Loss Concealment)"

parents d47dd0c2 06dea72f
Loading
Loading
Loading
Loading
+254 −8
Original line number Diff line number Diff line
@@ -19,6 +19,7 @@
#include <sys/stat.h>
#include <unistd.h>

#include <cfloat>
#include <memory>

#include "btif/include/core_callbacks.h"
@@ -40,6 +41,22 @@
#define BTM_MSBC_PKT_FRAME_LEN 57 /* Packet length without the header */
#define BTM_MSBC_SYNC_WORD 0xAD

/* Used by PLC */
#define BTM_MSBC_SAMPLE_SIZE 2 /* 2 bytes*/
#define BTM_MSBC_FS 120        /* Frame Size */

#define BTM_PLC_WL 256 /* 16ms - Window Length for pattern matching */
#define BTM_PLC_TL 64  /* 4ms - Template Length for matching */
#define BTM_PLC_HL \
  (BTM_PLC_WL + BTM_MSBC_FS - 1) /* Length of History buffer required */
#define BTM_PLC_SBCRL 36         /* SBC Reconvergence sample Length */
#define BTM_PLC_OLAL 16          /* OverLap-Add Length */

/* Disable the PLC when there are more than threshold of lost packets in the
 * window */
#define BTM_PLC_WINDOW_SIZE 5
#define BTM_PLC_PL_THRESHOLD 1

namespace {

std::unique_ptr<tUIPC_STATE> sco_uipc = nullptr;
@@ -128,16 +145,231 @@ constexpr size_t btm_wbs_supported_pkt_size[] = {BTM_MSBC_PKT_LEN, 72, 0};
 * BTM_MSBC_PKT_LEN for optimizing buffer copy. */
constexpr size_t btm_wbs_msbc_buffer_size[] = {BTM_MSBC_PKT_LEN, 360, 0};

/* The pre-computed zero input bit stream of mSBC codec, per HFP 1.7 spec.
 * This mSBC frame will be decoded into all-zero input PCM. */
/* The pre-computed SCO packet per HFP 1.7 spec. This mSBC packet will be
 * decoded into all-zero input PCM. */
static const uint8_t btm_msbc_zero_packet[] = {
    0x01, 0x08, /* Mock H2 header */
    0xad, 0x00, 0x00, 0xc5, 0x00, 0x00, 0x00, 0x00, 0x77, 0x6d, 0xb6, 0xdd,
    0xdb, 0x6d, 0xb7, 0x76, 0xdb, 0x6d, 0xdd, 0xb6, 0xdb, 0x77, 0x6d, 0xb6,
    0xdd, 0xdb, 0x6d, 0xb7, 0x76, 0xdb, 0x6d, 0xdd, 0xb6, 0xdb, 0x77, 0x6d,
    0xb6, 0xdd, 0xdb, 0x6d, 0xb7, 0x76, 0xdb, 0x6d, 0xdd, 0xb6, 0xdb, 0x77,
    0x6d, 0xb6, 0xdd, 0xdb, 0x6d, 0xb7, 0x76, 0xdb, 0x6c};
    0x6d, 0xb6, 0xdd, 0xdb, 0x6d, 0xb7, 0x76, 0xdb, 0x6c,
    /* End of Audio Samples */
    0x00 /* A padding byte defined by mSBC */};

/* Raised Cosine table for OLA */
static const float rcos[BTM_PLC_OLAL] = {
    0.99148655f, 0.96623611f, 0.92510857f, 0.86950446f,
    0.80131732f, 0.72286918f, 0.63683150f, 0.54613418f,
    0.45386582f, 0.36316850f, 0.27713082f, 0.19868268f,
    0.13049554f, 0.07489143f, 0.03376389f, 0.00851345f};

static int16_t f_to_s16(float input) {
  return input > INT16_MAX   ? INT16_MAX
         : input < INT16_MIN ? INT16_MIN
                             : (int16_t)input;
}
/* This structure tracks the packet loss for last PLC_WINDOW_SIZE of packets */
struct tBTM_MSBC_BTM_PLC_WINDOW {
  bool loss_hist[BTM_PLC_WINDOW_SIZE]; /* The packet loss history of receiving
                                      packets.*/
  unsigned int idx;   /* The index of the to be updated packet loss status. */
  unsigned int count; /* The count of lost packets in the window. */

 public:
  void update_plc_state(bool is_packet_loss) {
    bool* curr = &loss_hist[idx];
    if (is_packet_loss != *curr) {
      count += (is_packet_loss - *curr);
      *curr = is_packet_loss;
    }
    idx = (idx + 1) % BTM_PLC_WINDOW_SIZE;
  }

  bool is_packet_loss_too_high() {
    /* The packet loss count comes from a time window and we use it as an
     * indicator of our confidence of the PLC algorithm. It is known to
     * generate poorer and robotic feeling sounds, when the majority of
     * samples in the PLC history buffer are from the concealment results.
     */
    return count > BTM_PLC_PL_THRESHOLD;
  }
};

/* The PLC is specifically designed for mSBC. The algorithm searches the
 * history of receiving samples to find the best match samples and constructs
 * substitutions for the lost samples. The selection is based on pattern
 * matching a template, composed of a length of samples preceding to the lost
 * samples. It then uses the following samples after the best match as the
 * replacement samples and applies Overlap-Add to reduce the audible
 * distortion.
 *
 * This structure holds related info needed to conduct the PLC algorithm.
 */
struct tBTM_MSBC_PLC {
  int16_t hist[BTM_PLC_HL + BTM_MSBC_FS + BTM_PLC_SBCRL +
               BTM_PLC_OLAL]; /* The history buffer for receiving samples, we
                                 also use it to buffer the processed
                                 replacement samples */
  unsigned best_lag;      /* The index of the best substitution samples in the
                             sample history */
  int handled_bad_frames; /* Number of bad frames handled since the last good
                             frame */
  int16_t decoded_buffer[BTM_MSBC_FS]; /* Used for storing the samples from
                                      decoding the mSBC zero frame packet and
                                      also constructed frames */
  tBTM_MSBC_BTM_PLC_WINDOW*
      pl_window; /* Used to monitor how many packets are bad within the recent
                    BTM_PLC_WINDOW_SIZE of packets. We use this to determine if
                    we want to disable the PLC temporarily */

  void overlap_add(int16_t* output, float scaler_d, const int16_t* desc,
                   float scaler_a, const int16_t* asc) {
    for (int i = 0; i < BTM_PLC_OLAL; i++) {
      output[i] = f_to_s16(scaler_d * desc[i] * rcos[i] +
                           scaler_a * asc[i] * rcos[BTM_PLC_OLAL - 1 - i]);
    }
  }

static const uint8_t btm_msbc_zero_frames[BTM_MSBC_CODE_SIZE] = {0};
  float cross_correlation(int16_t* x, int16_t* y) {
    float sum = 0, x2 = 0, y2 = 0;

    for (int i = 0; i < BTM_PLC_TL; i++) {
      sum += ((float)x[i]) * y[i];
      x2 += ((float)x[i]) * x[i];
      y2 += ((float)y[i]) * y[i];
    }
    return sum / sqrtf(x2 * y2);
  }

  int pattern_match(int16_t* hist) {
    int best = 0;
    float cn, max_cn = FLT_MIN;

    for (int i = 0; i < BTM_PLC_WL; i++) {
      cn = cross_correlation(&hist[BTM_PLC_HL - BTM_PLC_TL], &hist[i]);
      if (cn > max_cn) {
        best = i;
        max_cn = cn;
      }
    }
    return best;
  }

  float amplitude_match(int16_t* x, int16_t* y) {
    uint32_t sum_x = 0, sum_y = 0;
    float scaler;
    for (int i = 0; i < BTM_MSBC_FS; i++) {
      sum_x += abs(x[i]);
      sum_y += abs(y[i]);
    }

    if (sum_y == 0) return 1.2f;

    scaler = (float)sum_x / sum_y;
    return scaler > 1.2f ? 1.2f : scaler < 0.75f ? 0.75f : scaler;
  }

 public:
  void init() {
    if (pl_window) osi_free(pl_window);
    pl_window = (tBTM_MSBC_BTM_PLC_WINDOW*)osi_calloc(sizeof(*pl_window));
  }

  void deinit() {
    if (pl_window) osi_free(pl_window);
  }

  void handle_bad_frames(const uint8_t** output) {
    float scaler;
    int16_t* best_match_hist;
    int16_t* frame_head = &hist[BTM_PLC_HL];

    /* mSBC codec is stateful, the history of signal would contribute to the
     * decode result decoded_buffer. This should never fail. */
    GetInterfaceToProfiles()->msbcCodec->decodePacket(
        btm_msbc_zero_packet, decoded_buffer, sizeof(decoded_buffer));

    /* The PLC algorithm is more likely to generate bad results that sound
     * robotic after severe packet losses happened. Only applying it when
     * we are confident. */
    if (!pl_window->is_packet_loss_too_high()) {
      if (handled_bad_frames == 0) {
        /* Finds the best matching samples and amplitude */
        best_lag = pattern_match(hist) + BTM_PLC_TL;
        best_match_hist = &hist[best_lag];
        scaler =
            amplitude_match(&hist[BTM_PLC_HL - BTM_MSBC_FS], best_match_hist);

        /* Constructs the substitution samples */
        overlap_add(frame_head, 1.0, decoded_buffer, scaler, best_match_hist);
        for (int i = BTM_PLC_OLAL; i < BTM_MSBC_FS; i++)
          hist[BTM_PLC_HL + i] = f_to_s16(scaler * best_match_hist[i]);
        overlap_add(&frame_head[BTM_MSBC_FS], scaler,
                    &best_match_hist[BTM_MSBC_FS], 1.0,
                    &best_match_hist[BTM_MSBC_FS]);

        memmove(&frame_head[BTM_MSBC_FS + BTM_PLC_OLAL],
                &best_match_hist[BTM_MSBC_FS + BTM_PLC_OLAL],
                BTM_PLC_SBCRL * BTM_MSBC_SAMPLE_SIZE);
      } else {
        /* Using the existing best lag and copy the following frames */
        memmove(frame_head, &hist[best_lag],
                (BTM_MSBC_FS + BTM_PLC_SBCRL + BTM_PLC_OLAL) *
                    BTM_MSBC_SAMPLE_SIZE);
      }
      /* Copy the constructed frames to decoded buffer for caller to use */
      std::copy(frame_head, &frame_head[BTM_MSBC_FS], decoded_buffer);

      handled_bad_frames++;
    } else {
      /* This is a case similar to receiving a good frame with all zeros, we set
       * handled_bad_frames to zero to prevent the following good frame from
       * being concealed to reconverge with the zero frames we fill in. The
       * concealment result sounds more artificial and weird than simply writing
       * zeros and following samples.
       */
      std::copy(std::begin(decoded_buffer), std::end(decoded_buffer),
                frame_head);
      std::fill(&frame_head[BTM_MSBC_FS],
                &frame_head[BTM_MSBC_FS + BTM_PLC_SBCRL + BTM_PLC_OLAL], 0);
      /* No need to copy the frames as we'll use the decoded zero frames in the
       * decoded buffer as our concealment frames */

      handled_bad_frames = 0;
    }

    *output = (const uint8_t*)decoded_buffer;

    /* Shift the frames to update the history window */
    memmove(hist, &hist[BTM_MSBC_FS],
            (BTM_PLC_HL + BTM_PLC_SBCRL + BTM_PLC_OLAL) * BTM_MSBC_SAMPLE_SIZE);
    pl_window->update_plc_state(1);
  }

  void handle_good_frames(int16_t* input) {
    int16_t* frame_head;
    if (handled_bad_frames != 0) {
      /* If there was a packet concealment before this good frame, we need to
       * reconverge the input frames */
      frame_head = &hist[BTM_PLC_HL];

      /* For the first good frame after packet loss, we need to conceal the
       * received samples to have it reconverge with the true output */
      std::copy(frame_head, &frame_head[BTM_PLC_SBCRL], input);
      /* Overlap the input frame with the previous output frame */
      overlap_add(&input[BTM_PLC_SBCRL], 1.0, &frame_head[BTM_PLC_SBCRL], 1.0,
                  &input[BTM_PLC_SBCRL]);
      handled_bad_frames = 0;
    }

    /* Shift the history and update the good frame to the end of it */
    memmove(hist, &hist[BTM_MSBC_FS],
            (BTM_PLC_HL - BTM_MSBC_FS) * BTM_MSBC_SAMPLE_SIZE);
    std::copy(input, &input[BTM_MSBC_FS], &hist[BTM_PLC_HL - BTM_MSBC_FS]);
    pl_window->update_plc_state(0);
  }
};

/* Define the structure that contains mSBC data */
struct tBTM_MSBC_INFO {
@@ -153,9 +385,11 @@ struct tBTM_MSBC_INFO {
  size_t encode_buf_wo;     /* Write offset of the encode buffer */
  size_t encode_buf_ro;     /* Read offset of the encode buffer */

  int16_t decoded_pcm_buf[120]; /* Buffer to store decoded PCM */
  int16_t decoded_pcm_buf[BTM_MSBC_FS]; /* Buffer to store decoded PCM */

  uint8_t num_encoded_msbc_pkts; /* Number of the encoded mSBC packets */

  tBTM_MSBC_PLC* plc; /* PLC component to handle the packet loss of input */
  static size_t get_supported_packet_size(size_t pkt_size,
                                          size_t* buffer_size) {
    int i;
@@ -202,12 +436,23 @@ struct tBTM_MSBC_INFO {

    if (msbc_encode_buf) osi_free(msbc_encode_buf);
    msbc_encode_buf = (uint8_t*)osi_calloc(buf_size);

    if (plc) {
      plc->deinit();
      osi_free(plc);
    }
    plc = (tBTM_MSBC_PLC*)osi_calloc(sizeof(*plc));
    plc->init();
    return packet_size;
  }

  void deinit() {
    if (msbc_decode_buf) osi_free(msbc_decode_buf);
    if (msbc_encode_buf) osi_free(msbc_encode_buf);
    if (plc) {
      plc->deinit();
      osi_free(plc);
    }
  }

  size_t decodable() { return decode_buf_wo - decode_buf_ro; }
@@ -393,12 +638,13 @@ size_t decode(const uint8_t** out_data) {
    goto packet_loss;
  }

  msbc_info->plc->handle_good_frames(msbc_info->decoded_pcm_buf);
  *out_data = (const uint8_t*)msbc_info->decoded_pcm_buf;
  msbc_info->mark_pkt_decoded();
  return BTM_MSBC_CODE_SIZE;

packet_loss:
  *out_data = btm_msbc_zero_frames;
  msbc_info->plc->handle_bad_frames(out_data);
  msbc_info->mark_pkt_decoded();
  return BTM_MSBC_CODE_SIZE;
}
@@ -434,8 +680,8 @@ size_t encode(int16_t* data, size_t len) {
      GetInterfaceToProfiles()->msbcCodec->encodePacket(data, pkt_body);
  if (encoded_size != BTM_MSBC_PKT_FRAME_LEN) {
    LOG_WARN("Encoding invalid packet size: %lu", (unsigned long)encoded_size);
    std::copy(std::begin(btm_msbc_zero_packet), std::end(btm_msbc_zero_packet),
              pkt_body);
    std::copy(&btm_msbc_zero_packet[BTM_MSBC_H2_HEADER_LEN],
              std::end(btm_msbc_zero_packet), pkt_body);
  }

  return BTM_MSBC_CODE_SIZE;
+66 −0
Original line number Diff line number Diff line
@@ -38,6 +38,9 @@ extern bool mock_uipc_send_ret;

namespace {

using testing::AllOf;
using testing::Ge;
using testing::Le;
using testing::Test;

const uint8_t msbc_zero_packet[] = {
@@ -230,6 +233,7 @@ TEST_F(ScoHciWbsWithInitCleanTest, WbsDecode) {
      sizeof(payload));

  // Return all zero frames when there comes an invalid packet.
  // This is expected even with PLC as there is no history in the PLC buffer.
  ASSERT_EQ(bluetooth::audio::sco::wbs::decode(&decoded),
            size_t(BTM_MSBC_CODE_SIZE));
  ASSERT_NE(decoded, nullptr);
@@ -311,4 +315,66 @@ TEST_F(ScoHciWbsWithInitCleanTest, WbsEncodeDequeuePackets) {
  }
}

TEST_F(ScoHciWbsWithInitCleanTest, WbsPlc) {
  int16_t triangle[16] = {0, 100,  200,  300,  400,  300,  200,  100,
                          0, -100, -200, -300, -400, -300, -200, -100};
  int16_t data[120];
  int16_t expect_data[120];
  const uint8_t* encoded = nullptr;
  const uint8_t* decoded = nullptr;
  uint8_t invalid_pkt[60] = {0};
  size_t lost_pkt_idx = 17;

  // Simulate a run without any packet loss
  for (size_t i = 0, sample_idx = 0; i <= lost_pkt_idx; i++) {
    // Input data is a 1000Hz triangle wave
    for (size_t j = 0; j < 120; j++, sample_idx++)
      data[j] = triangle[sample_idx % 16];
    // Build the packet
    ASSERT_EQ(bluetooth::audio::sco::wbs::encode(data, sizeof(data)),
              sizeof(data));
    ASSERT_EQ(bluetooth::audio::sco::wbs::dequeue_packet(&encoded), size_t(60));
    ASSERT_NE(encoded, nullptr);

    // Simulate the reception of the packet
    ASSERT_EQ(bluetooth::audio::sco::wbs::enqueue_packet(encoded, 60),
              size_t(60));
    ASSERT_EQ(bluetooth::audio::sco::wbs::decode(&decoded),
              size_t(BTM_MSBC_CODE_SIZE));
    ASSERT_NE(decoded, nullptr);
  }
  // Store the decoded data we expect to get
  std::copy((const int16_t*)decoded,
            (const int16_t*)(decoded + BTM_MSBC_CODE_SIZE), expect_data);
  // Start with the fresh WBS buffer
  bluetooth::audio::sco::wbs::cleanup();
  bluetooth::audio::sco::wbs::init(60);
  for (size_t i = 0, sample_idx = 0; i <= lost_pkt_idx; i++) {
    // Data is a 1000Hz triangle wave
    for (size_t j = 0; j < 120; j++, sample_idx++)
      data[j] = triangle[sample_idx % 16];
    ASSERT_EQ(bluetooth::audio::sco::wbs::encode(data, sizeof(data)),
              sizeof(data));
    ASSERT_EQ(bluetooth::audio::sco::wbs::dequeue_packet(&encoded), size_t(60));
    ASSERT_NE(encoded, nullptr);

    // Substitute to invalid packet to simulate packet loss.
    ASSERT_EQ(bluetooth::audio::sco::wbs::enqueue_packet(
                  i != lost_pkt_idx ? encoded : invalid_pkt, 60),
              size_t(60));
    ASSERT_EQ(bluetooth::audio::sco::wbs::decode(&decoded),
              size_t(BTM_MSBC_CODE_SIZE));
    ASSERT_NE(decoded, nullptr);
  }
  int16_t* ptr = (int16_t*)decoded;
  for (size_t i = 0; i < 120; i++) {
    // The frames generated by PLC won't be perfect due to:
    // 1. mSBC decoder is statefull
    // 2. We apply overlap-add to glue the frames when packet loss happens
    ASSERT_THAT(ptr[i] - expect_data[i], AllOf(Ge(-3), Le(3)))
        << "PLC data " << ptr[i] << " deviates from expected " << expect_data[i]
        << " at index " << i;
  }
}

}  // namespace