1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "remoting/host/audio_capturer_win.h"
6
7#include <windows.h>
8#include <avrt.h>
9#include <mmreg.h>
10#include <mmsystem.h>
11
12#include <algorithm>
13#include <stdlib.h>
14
15#include "base/logging.h"
16
namespace {

// Captured audio is always requested as 16-bit stereo.
const int kChannels = 2;
const int kBytesPerSample = 2;
const int kBitsPerSample = 8 * kBytesPerSample;

// Durations reported by the audio APIs are in 100-nanosecond units; this many
// of them make up one millisecond.
const int k100nsPerMillisecond = 10000;

// A packet whose samples all have an absolute value below this threshold is
// treated as silence. The threshold is 2 rather than 1 because Windows may
// deliver samples of 1 and -1 even when nothing is playing.
const int kSilenceThreshold = 2;

// Shortest interval, in milliseconds, that the capture timer may use.
const int kMinTimerInterval = 30;

// Worst-case lateness, in milliseconds, assumed for the capture timer.
// Timers should be accurate to roughly 20ms; 30ms leaves a safety margin.
const int kMaxExpectedTimerLag = 30;

}  // namespace
37
38namespace remoting {
39
40AudioCapturerWin::AudioCapturerWin()
41    : sampling_rate_(AudioPacket::SAMPLING_RATE_INVALID),
42      silence_detector_(kSilenceThreshold),
43      last_capture_error_(S_OK) {
44    thread_checker_.DetachFromThread();
45}
46
47AudioCapturerWin::~AudioCapturerWin() {
48}
49
50bool AudioCapturerWin::Start(const PacketCapturedCallback& callback) {
51  DCHECK(!audio_capture_client_.get());
52  DCHECK(!audio_client_.get());
53  DCHECK(!mm_device_.get());
54  DCHECK(static_cast<PWAVEFORMATEX>(wave_format_ex_) == NULL);
55  DCHECK(thread_checker_.CalledOnValidThread());
56
57  callback_ = callback;
58
59  // Initialize the capture timer.
60  capture_timer_.reset(new base::RepeatingTimer<AudioCapturerWin>());
61
62  HRESULT hr = S_OK;
63
64  base::win::ScopedComPtr<IMMDeviceEnumerator> mm_device_enumerator;
65  hr = mm_device_enumerator.CreateInstance(__uuidof(MMDeviceEnumerator));
66  if (FAILED(hr)) {
67    LOG(ERROR) << "Failed to create IMMDeviceEnumerator. Error " << hr;
68    return false;
69  }
70
71  // Get the audio endpoint.
72  hr = mm_device_enumerator->GetDefaultAudioEndpoint(eRender,
73                                                     eConsole,
74                                                     mm_device_.Receive());
75  if (FAILED(hr)) {
76    LOG(ERROR) << "Failed to get IMMDevice. Error " << hr;
77    return false;
78  }
79
80  // Get an audio client.
81  hr = mm_device_->Activate(__uuidof(IAudioClient),
82                            CLSCTX_ALL,
83                            NULL,
84                            audio_client_.ReceiveVoid());
85  if (FAILED(hr)) {
86    LOG(ERROR) << "Failed to get an IAudioClient. Error " << hr;
87    return false;
88  }
89
90  REFERENCE_TIME device_period;
91  hr = audio_client_->GetDevicePeriod(&device_period, NULL);
92  if (FAILED(hr)) {
93    LOG(ERROR) << "IAudioClient::GetDevicePeriod failed. Error " << hr;
94    return false;
95  }
96  // We round up, if |device_period| / |k100nsPerMillisecond|
97  // is not a whole number.
98  int device_period_in_milliseconds =
99      1 + ((device_period - 1) / k100nsPerMillisecond);
100  audio_device_period_ = base::TimeDelta::FromMilliseconds(
101      std::max(device_period_in_milliseconds, kMinTimerInterval));
102
103  // Get the wave format.
104  hr = audio_client_->GetMixFormat(&wave_format_ex_);
105  if (FAILED(hr)) {
106    LOG(ERROR) << "Failed to get WAVEFORMATEX. Error " << hr;
107    return false;
108  }
109
110  // Set the wave format
111  switch (wave_format_ex_->wFormatTag) {
112    case WAVE_FORMAT_IEEE_FLOAT:
113      // Intentional fall-through.
114    case WAVE_FORMAT_PCM:
115      if (!AudioCapturer::IsValidSampleRate(wave_format_ex_->nSamplesPerSec)) {
116        LOG(ERROR) << "Host sampling rate is neither 44.1 kHz nor 48 kHz.";
117        return false;
118      }
119      sampling_rate_ = static_cast<AudioPacket::SamplingRate>(
120          wave_format_ex_->nSamplesPerSec);
121
122      wave_format_ex_->wFormatTag = WAVE_FORMAT_PCM;
123      wave_format_ex_->nChannels = kChannels;
124      wave_format_ex_->wBitsPerSample = kBitsPerSample;
125      wave_format_ex_->nBlockAlign = kChannels * kBytesPerSample;
126      wave_format_ex_->nAvgBytesPerSec =
127          sampling_rate_ * kChannels * kBytesPerSample;
128      break;
129    case WAVE_FORMAT_EXTENSIBLE: {
130      PWAVEFORMATEXTENSIBLE wave_format_extensible =
131          reinterpret_cast<WAVEFORMATEXTENSIBLE*>(
132          static_cast<WAVEFORMATEX*>(wave_format_ex_));
133      if (IsEqualGUID(KSDATAFORMAT_SUBTYPE_IEEE_FLOAT,
134                      wave_format_extensible->SubFormat)) {
135        if (!AudioCapturer::IsValidSampleRate(
136                wave_format_extensible->Format.nSamplesPerSec)) {
137          LOG(ERROR) << "Host sampling rate is neither 44.1 kHz nor 48 kHz.";
138          return false;
139        }
140        sampling_rate_ = static_cast<AudioPacket::SamplingRate>(
141            wave_format_extensible->Format.nSamplesPerSec);
142
143        wave_format_extensible->SubFormat = KSDATAFORMAT_SUBTYPE_PCM;
144        wave_format_extensible->Samples.wValidBitsPerSample = kBitsPerSample;
145
146        wave_format_extensible->Format.nChannels = kChannels;
147        wave_format_extensible->Format.nSamplesPerSec = sampling_rate_;
148        wave_format_extensible->Format.wBitsPerSample = kBitsPerSample;
149        wave_format_extensible->Format.nBlockAlign =
150            kChannels * kBytesPerSample;
151        wave_format_extensible->Format.nAvgBytesPerSec =
152            sampling_rate_ * kChannels * kBytesPerSample;
153      } else {
154        LOG(ERROR) << "Failed to force 16-bit samples";
155        return false;
156      }
157      break;
158    }
159    default:
160      LOG(ERROR) << "Failed to force 16-bit PCM";
161      return false;
162  }
163
164  // Initialize the IAudioClient.
165  hr = audio_client_->Initialize(
166      AUDCLNT_SHAREMODE_SHARED,
167      AUDCLNT_STREAMFLAGS_LOOPBACK,
168      (kMaxExpectedTimerLag + audio_device_period_.InMilliseconds()) *
169      k100nsPerMillisecond,
170      0,
171      wave_format_ex_,
172      NULL);
173  if (FAILED(hr)) {
174    LOG(ERROR) << "Failed to initialize IAudioClient. Error " << hr;
175    return false;
176  }
177
178  // Get an IAudioCaptureClient.
179  hr = audio_client_->GetService(__uuidof(IAudioCaptureClient),
180                                 audio_capture_client_.ReceiveVoid());
181  if (FAILED(hr)) {
182    LOG(ERROR) << "Failed to get an IAudioCaptureClient. Error " << hr;
183    return false;
184  }
185
186  // Start the IAudioClient.
187  hr = audio_client_->Start();
188  if (FAILED(hr)) {
189    LOG(ERROR) << "Failed to start IAudioClient. Error " << hr;
190    return false;
191  }
192
193  silence_detector_.Reset(sampling_rate_, kChannels);
194
195  // Start capturing.
196  capture_timer_->Start(FROM_HERE,
197                        audio_device_period_,
198                        this,
199                        &AudioCapturerWin::DoCapture);
200  return true;
201}
202
203void AudioCapturerWin::Stop() {
204  DCHECK(thread_checker_.CalledOnValidThread());
205  DCHECK(IsStarted());
206
207  capture_timer_.reset();
208  mm_device_.Release();
209  audio_client_.Release();
210  audio_capture_client_.Release();
211  wave_format_ex_.Reset(NULL);
212
213  thread_checker_.DetachFromThread();
214}
215
216bool AudioCapturerWin::IsStarted() {
217  DCHECK(thread_checker_.CalledOnValidThread());
218  return capture_timer_.get() != NULL;
219}
220
221void AudioCapturerWin::DoCapture() {
222  DCHECK(AudioCapturer::IsValidSampleRate(sampling_rate_));
223  DCHECK(thread_checker_.CalledOnValidThread());
224  DCHECK(IsStarted());
225
226  // Fetch all packets from the audio capture endpoint buffer.
227  HRESULT hr = S_OK;
228  while (true) {
229    UINT32 next_packet_size;
230    HRESULT hr = audio_capture_client_->GetNextPacketSize(&next_packet_size);
231    if (FAILED(hr))
232      break;
233
234    if (next_packet_size <= 0) {
235      return;
236    }
237
238    BYTE* data;
239    UINT32 frames;
240    DWORD flags;
241    hr = audio_capture_client_->GetBuffer(&data, &frames, &flags, NULL, NULL);
242    if (FAILED(hr))
243      break;
244
245    if ((flags & AUDCLNT_BUFFERFLAGS_SILENT) == 0 &&
246        !silence_detector_.IsSilence(
247            reinterpret_cast<const int16*>(data), frames * kChannels)) {
248      scoped_ptr<AudioPacket> packet =
249          scoped_ptr<AudioPacket>(new AudioPacket());
250      packet->add_data(data, frames * wave_format_ex_->nBlockAlign);
251      packet->set_encoding(AudioPacket::ENCODING_RAW);
252      packet->set_sampling_rate(sampling_rate_);
253      packet->set_bytes_per_sample(AudioPacket::BYTES_PER_SAMPLE_2);
254      packet->set_channels(AudioPacket::CHANNELS_STEREO);
255
256      callback_.Run(packet.Pass());
257    }
258
259    hr = audio_capture_client_->ReleaseBuffer(frames);
260    if (FAILED(hr))
261      break;
262  }
263
264  // There is nothing to capture if the audio endpoint device has been unplugged
265  // or disabled.
266  if (hr == AUDCLNT_E_DEVICE_INVALIDATED)
267    return;
268
269  // Avoid reporting the same error multiple times.
270  if (FAILED(hr) && hr != last_capture_error_) {
271    last_capture_error_ = hr;
272    LOG(ERROR) << "Failed to capture an audio packet: 0x"
273               << std::hex << hr << std::dec << ".";
274  }
275}
276
277bool AudioCapturer::IsSupported() {
278  return true;
279}
280
281scoped_ptr<AudioCapturer> AudioCapturer::Create() {
282  return scoped_ptr<AudioCapturer>(new AudioCapturerWin());
283}
284
285}  // namespace remoting
286