1/*
2 *  Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
3 *
4 *  Use of this source code is governed by a BSD-style license
5 *  that can be found in the LICENSE file in the root of the source
6 *  tree. An additional intellectual property rights grant can be found
7 *  in the file PATENTS.  All contributing project authors may
8 *  be found in the AUTHORS file in the root of the source tree.
9 */
10
11#include <math.h>
12#include <stdio.h>
13#include "webrtc/modules/audio_coding/neteq/tools/neteq_quality_test.h"
14
15namespace webrtc {
16namespace test {
17
// Duration of audio requested from NetEq on every GetAudio() call.
const int kOutputSizeMs = 10;
// Time granularity at which the packet loss models are drawn; a packet of
// block_duration_ms gets block_duration_ms / kPacketLossTimeUnitMs drawings.
const int kPacketLossTimeUnitMs = 10;
// RTP payload type used both when registering the decoder and when
// generating RTP headers.
const uint8_t kPayloadType = 95;
// Seed given to srand() so that all derived tests share one loss profile.
const int kInitSeed = 0x12345678;
22
23// Define switch for packet loss rate.
// Validator for --packet_loss_rate. Accepts a percentage in [0, 100]; for
// anything else it prints a newline-terminated diagnostic and returns false.
static bool ValidatePacketLossRate(const char* /* flag_name */, int32_t value) {
  if (value >= 0 && value <= 100)
    return true;
  printf("Invalid packet loss percentage, should be between 0 and 100.\n");
  return false;
}
30
31DEFINE_int32(packet_loss_rate, 10, "Percentile of packet loss.");
32
33static const bool packet_loss_rate_dummy =
34    RegisterFlagValidator(&FLAGS_packet_loss_rate, &ValidatePacketLossRate);
35
36// Define switch for random loss mode.
// Validator for --random_loss_mode. Accepts 0, 1 or 2; for anything else it
// prints a newline-terminated diagnostic and returns false.
static bool ValidateRandomLossMode(const char* /* flag_name */, int32_t value) {
  if (value >= 0 && value <= 2)
    return true;
  printf("Invalid random packet loss mode, should be between 0 and 2.\n");
  return false;
}
43
44DEFINE_int32(random_loss_mode, 1,
45    "Random loss mode: 0--no loss, 1--uniform loss, 2--Gilbert Elliot loss.");
46static const bool random_loss_mode_dummy =
47    RegisterFlagValidator(&FLAGS_random_loss_mode, &ValidateRandomLossMode);
48
49// Define switch for burst length.
50static bool ValidateBurstLength(const char* /* flag_name */, int32_t value) {
51  if (value >= kPacketLossTimeUnitMs)
52    return true;
53  printf("Invalid burst length, should be greater than %d ms.",
54         kPacketLossTimeUnitMs);
55  return false;
56}
57
58DEFINE_int32(burst_length, 30,
59    "Burst length in milliseconds, only valid for Gilbert Elliot loss.");
60
61static const bool burst_length_dummy =
62    RegisterFlagValidator(&FLAGS_burst_length, &ValidateBurstLength);
63
64// Define switch for drift factor.
// Validator for --drift_factor. Accepts values strictly greater than -0.1;
// otherwise it prints a newline-terminated diagnostic and returns false.
static bool ValidateDriftFactor(const char* /* flag_name */, double value) {
  if (value > -0.1)
    return true;
  printf("Invalid drift factor, should be greater than -0.1.\n");
  return false;
}
71
72DEFINE_double(drift_factor, 0.0, "Time drift factor.");
73
74static const bool drift_factor_dummy =
75    RegisterFlagValidator(&FLAGS_drift_factor, &ValidateDriftFactor);
76
// ProbTrans00Solver() calculates the transition probability from the no-loss
// state to itself (prob_trans_00) in a modified Gilbert Elliot packet loss
// model. The model is drawn |units| times within the duration of one packet,
// and a packet is not lost only if all of those drawings result in no-loss.
// Given the loss-to-no-loss transition probability |prob_trans_10|, the
// returned prob_trans_00 makes the packet loss rate equal |loss_rate|; the
// caller derives prob_trans_01 as 1 - prob_trans_00.
//
// In the stationary chain a packet survives with probability
//     pi_0 * prob_trans_00 ^ (units - 1) == 1 - loss_rate,
// where pi_0 = prob_trans_10 / (prob_trans_10 + 1 - prob_trans_00) is the
// stationary probability of the no-loss state. Rearranging gives
//     0 == prob_trans_00 ^ (units - 1) + (1 - loss_rate) / prob_trans_10 *
//          prob_trans_00 - (1 - loss_rate) * (1 + 1 / prob_trans_10).
static double ProbTrans00Solver(int units, double loss_rate,
                                double prob_trans_10) {
  if (units == 1) {
    // With a single drawing the equation is linear. Its solution is
    //     prob_trans_01 = prob_trans_10 / (1 - loss_rate) - prob_trans_10,
    // and this function must return the complement, prob_trans_00 (returning
    // prob_trans_01 here would invert the model). If the target loss rate is
    // unreachable with this |prob_trans_10| (prob_trans_01 would exceed 1),
    // clamp to the closest valid probability.
    const double prob_trans_01 =
        prob_trans_10 / (1.0 - loss_rate) - prob_trans_10;
    const double prob_trans_00 = 1.0 - prob_trans_01;
    if (prob_trans_00 < 0.0)
      return 0.0;
    if (prob_trans_00 > 1.0)
      return 1.0;
    return prob_trans_00;
  }
  // For units >= 2, write the equation as
  //     f(x) = x ^ (units - 1) + a x + b,
  // with derivative
  //     f'(x) = (units - 1) x ^ (units - 2) + a.
  // For 0 <= loss_rate < 1, f(0) = b < 0 and f(1) = loss_rate >= 0, and f' > 0
  // on [0, 1], so there is a unique root in [0, 1]. Solve it with Newton's
  // method: x(k+1) = x(k) - f(x) / f'(x).
  const double kPrecision = 0.001;
  const int kIterations = 100;
  const double a = (1.0 - loss_rate) / prob_trans_10;
  const double b = (loss_rate - 1.0) * (1.0 + 1.0 / prob_trans_10);
  double x = 0.0;  // Starting point.
  double f = b;    // f(0), since 0 ^ (units - 1) == 0 for units >= 2.
  for (int iter = 0; iter < kIterations && fabs(f) >= kPrecision; ++iter) {
    const double f_p = (units - 1.0) * pow(x, units - 2) + a;
    x -= f / f_p;
    // Keep the iterate inside [0, 1], where the root is known to lie.
    if (x > 1.0) {
      x = 1.0;
    } else if (x < 0.0) {
      x = 0.0;
    }
    f = pow(x, units - 1) + a * x + b;
  }
  return x;
}
118
119NetEqQualityTest::NetEqQualityTest(int block_duration_ms,
120                                   int in_sampling_khz,
121                                   int out_sampling_khz,
122                                   enum NetEqDecoder decoder_type,
123                                   int channels,
124                                   std::string in_filename,
125                                   std::string out_filename)
126    : decoded_time_ms_(0),
127      decodable_time_ms_(0),
128      drift_factor_(FLAGS_drift_factor),
129      packet_loss_rate_(FLAGS_packet_loss_rate),
130      block_duration_ms_(block_duration_ms),
131      in_sampling_khz_(in_sampling_khz),
132      out_sampling_khz_(out_sampling_khz),
133      decoder_type_(decoder_type),
134      channels_(channels),
135      in_filename_(in_filename),
136      out_filename_(out_filename),
137      log_filename_(out_filename + ".log"),
138      in_size_samples_(in_sampling_khz_ * block_duration_ms_),
139      out_size_samples_(out_sampling_khz_ * kOutputSizeMs),
140      payload_size_bytes_(0),
141      max_payload_bytes_(0),
142      in_file_(new InputAudioFile(in_filename_)),
143      out_file_(NULL),
144      log_file_(NULL),
145      rtp_generator_(new RtpGenerator(in_sampling_khz_, 0, 0,
146                                      decodable_time_ms_)),
147      total_payload_size_bytes_(0) {
148  NetEq::Config config;
149  config.sample_rate_hz = out_sampling_khz_ * 1000;
150  neteq_.reset(NetEq::Create(config));
151  max_payload_bytes_ = in_size_samples_ * channels_ * sizeof(int16_t);
152  in_data_.reset(new int16_t[in_size_samples_ * channels_]);
153  payload_.reset(new uint8_t[max_payload_bytes_]);
154  out_data_.reset(new int16_t[out_size_samples_ * channels_]);
155}
156
157bool NoLoss::Lost() {
158  return false;
159}
160
161UniformLoss::UniformLoss(double loss_rate)
162    : loss_rate_(loss_rate) {
163}
164
165bool UniformLoss::Lost() {
166  int drop_this = rand();
167  return (drop_this < loss_rate_ * RAND_MAX);
168}
169
170GilbertElliotLoss::GilbertElliotLoss(double prob_trans_11, double prob_trans_01)
171    : prob_trans_11_(prob_trans_11),
172      prob_trans_01_(prob_trans_01),
173      lost_last_(false),
174      uniform_loss_model_(new UniformLoss(0)) {
175}
176
177bool GilbertElliotLoss::Lost() {
178  // Simulate bursty channel (Gilbert model).
179  // (1st order) Markov chain model with memory of the previous/last
180  // packet state (lost or received).
181  if (lost_last_) {
182    // Previous packet was not received.
183    uniform_loss_model_->set_loss_rate(prob_trans_11_);
184    return lost_last_ = uniform_loss_model_->Lost();
185  } else {
186    uniform_loss_model_->set_loss_rate(prob_trans_01_);
187    return lost_last_ = uniform_loss_model_->Lost();
188  }
189}
190
191void NetEqQualityTest::SetUp() {
192  out_file_ = fopen(out_filename_.c_str(), "wb");
193  log_file_ = fopen(log_filename_.c_str(), "wt");
194  ASSERT_TRUE(out_file_ != NULL);
195  ASSERT_EQ(0, neteq_->RegisterPayloadType(decoder_type_, kPayloadType));
196  rtp_generator_->set_drift_factor(drift_factor_);
197
198  int units = block_duration_ms_ / kPacketLossTimeUnitMs;
199  switch (FLAGS_random_loss_mode) {
200    case 1: {
201      // |unit_loss_rate| is the packet loss rate for each unit time interval
202      // (kPacketLossTimeUnitMs). Since a packet loss event is generated if any
203      // of |block_duration_ms_ / kPacketLossTimeUnitMs| unit time intervals of
204      // a full packet duration is drawn with a loss, |unit_loss_rate| fulfills
205      // (1 - unit_loss_rate) ^ (block_duration_ms_ / kPacketLossTimeUnitMs) ==
206      // 1 - packet_loss_rate.
207      double unit_loss_rate = (1.0f - pow(1.0f - 0.01f * packet_loss_rate_,
208          1.0f / units));
209      loss_model_.reset(new UniformLoss(unit_loss_rate));
210      break;
211    }
212    case 2: {
213      // |FLAGS_burst_length| should be integer times of kPacketLossTimeUnitMs.
214      ASSERT_EQ(0, FLAGS_burst_length % kPacketLossTimeUnitMs);
215
216      // We do not allow 100 percent packet loss in Gilbert Elliot model, which
217      // makes no sense.
218      ASSERT_GT(100, packet_loss_rate_);
219
220      // To guarantee the overall packet loss rate, transition probabilities
221      // need to satisfy:
222      // pi_0 * (1 - prob_trans_01_) ^ units +
223      //     pi_1 * prob_trans_10_ ^ (units - 1) == 1 - loss_rate
224      // pi_0 = prob_trans_10 / (prob_trans_10 + prob_trans_01_)
225      //     is the stationary state probability of no-loss
226      // pi_1 = prob_trans_01_ / (prob_trans_10 + prob_trans_01_)
227      //     is the stationary state probability of loss
228      // After a derivation prob_trans_00 should satisfy:
229      // prob_trans_00 ^ (units - 1) = (loss_rate - 1) / prob_trans_10 *
230      //     prob_trans_00 + (1 - loss_rate) * (1 + 1 / prob_trans_10).
231      double loss_rate = 0.01f * packet_loss_rate_;
232      double prob_trans_10 = 1.0f * kPacketLossTimeUnitMs / FLAGS_burst_length;
233      double prob_trans_00 = ProbTrans00Solver(units, loss_rate, prob_trans_10);
234      loss_model_.reset(new GilbertElliotLoss(1.0f - prob_trans_10,
235                                              1.0f - prob_trans_00));
236      break;
237    }
238    default: {
239      loss_model_.reset(new NoLoss);
240      break;
241    }
242  }
243
244  // Make sure that the packet loss profile is same for all derived tests.
245  srand(kInitSeed);
246}
247
248void NetEqQualityTest::TearDown() {
249  fclose(out_file_);
250}
251
252bool NetEqQualityTest::PacketLost() {
253  int cycles = block_duration_ms_ / kPacketLossTimeUnitMs;
254
255  // The loop is to make sure that codecs with different block lengths share the
256  // same packet loss profile.
257  bool lost = false;
258  for (int idx = 0; idx < cycles; idx ++) {
259    if (loss_model_->Lost()) {
260      // The packet will be lost if any of the drawings indicates a loss, but
261      // the loop has to go on to make sure that codecs with different block
262      // lengths keep the same pace.
263      lost = true;
264    }
265  }
266  return lost;
267}
268
269int NetEqQualityTest::Transmit() {
270  int packet_input_time_ms =
271      rtp_generator_->GetRtpHeader(kPayloadType, in_size_samples_,
272                                   &rtp_header_);
273  if (payload_size_bytes_ > 0) {
274    fprintf(log_file_, "Packet at %d ms", packet_input_time_ms);
275    if (!PacketLost()) {
276      int ret = neteq_->InsertPacket(rtp_header_, &payload_[0],
277                                     payload_size_bytes_,
278                                     packet_input_time_ms * in_sampling_khz_);
279      if (ret != NetEq::kOK)
280        return -1;
281      fprintf(log_file_, " OK.\n");
282    } else {
283      fprintf(log_file_, " Lost.\n");
284    }
285  }
286  return packet_input_time_ms;
287}
288
289int NetEqQualityTest::DecodeBlock() {
290  int channels;
291  int samples;
292  int ret = neteq_->GetAudio(out_size_samples_ * channels_, &out_data_[0],
293                             &samples, &channels, NULL);
294
295  if (ret != NetEq::kOK) {
296    return -1;
297  } else {
298    assert(channels == channels_);
299    assert(samples == kOutputSizeMs * out_sampling_khz_);
300    fwrite(&out_data_[0], sizeof(int16_t), samples * channels, out_file_);
301    return samples;
302  }
303}
304
305void NetEqQualityTest::Simulate(int end_time_ms) {
306  int audio_size_samples;
307
308  while (decoded_time_ms_ < end_time_ms) {
309    // Assume 10 packets in packets buffer.
310    while (decodable_time_ms_ - 10 * block_duration_ms_ < decoded_time_ms_) {
311      ASSERT_TRUE(in_file_->Read(in_size_samples_ * channels_, &in_data_[0]));
312      payload_size_bytes_ = EncodeBlock(&in_data_[0],
313                                        in_size_samples_, &payload_[0],
314                                        max_payload_bytes_);
315      total_payload_size_bytes_ += payload_size_bytes_;
316      decodable_time_ms_ = Transmit() + block_duration_ms_;
317    }
318    audio_size_samples = DecodeBlock();
319    if (audio_size_samples > 0) {
320      decoded_time_ms_ += audio_size_samples / out_sampling_khz_;
321    }
322  }
323  fprintf(log_file_, "%f", 8.0f * total_payload_size_bytes_ / end_time_ms);
324}
325
326}  // namespace test
327}  // namespace webrtc
328