# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""General shape ops for frames."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


from tensorflow.contrib.signal.python.ops import util_ops
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


def _infer_frame_shape(signal, frame_length, frame_step, pad_end, axis):
  """Infers the shape of the return value of `frame`."""
  frame_length = tensor_util.constant_value(frame_length)
  frame_step = tensor_util.constant_value(frame_step)
  axis = tensor_util.constant_value(axis)
  if signal.shape.ndims is None:
    return None
  if axis is None:
    return [None] * (signal.shape.ndims + 1)

  signal_shape = signal.shape.as_list()
  num_frames = None
  frame_axis = signal_shape[axis]
  outer_dimensions = signal_shape[:axis]
  inner_dimensions = signal_shape[axis:][1:]
  if signal_shape and frame_axis is not None:
    if frame_step is not None and frame_length is not None:
      if pad_end:
        # Double negative is so that we round up.
        num_frames = -(-frame_axis // frame_step)
      else:
        num_frames = (frame_axis - frame_length + frame_step) // frame_step
      num_frames = max(0, num_frames)
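  # For example, with frame_axis=9152, frame_length=512 and frame_step=180,
  # pad_end=True gives ceil(9152 / 180) = 51 frames, while pad_end=False
  # gives (9152 - 512 + 180) // 180 = 49 fully-overlapping frames.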
  return outer_dimensions + [num_frames, frame_length] + inner_dimensions


def frame(signal, frame_length, frame_step, pad_end=False, pad_value=0, axis=-1,
          name=None):
  """Expands `signal`'s `axis` dimension into frames of `frame_length`.

  Slides a window of size `frame_length` over `signal`'s `axis` dimension
  with a stride of `frame_step`, replacing the `axis` dimension with
  `[frames, frame_length]` frames.

  If `pad_end` is True, window positions that are past the end of the `axis`
  dimension are padded with `pad_value` until the window moves fully past the
  end of the dimension. Otherwise, only window positions that fully overlap the
  `axis` dimension are produced.

  For example:

  ```python
  pcm = tf.placeholder(tf.float32, [None, 9152])
  frames = tf.contrib.signal.frame(pcm, 512, 180)
  magspec = tf.abs(tf.spectral.rfft(frames, [512]))
  image = tf.expand_dims(magspec, 3)
  ```
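
  In this example, `frames` has shape `[None, 49, 512]`, since with the
  default `pad_end=False` the number of frames is
  `1 + (9152 - 512) // 180 = 49`; `magspec` then has shape `[None, 49, 257]`
  and `image` has shape `[None, 49, 257, 1]`. Passing `pad_end=True` would
  instead produce `ceil(9152 / 180) = 51` frames, with samples past the end
  of `pcm` filled with `pad_value`.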

  Args:
    signal: A `[..., samples, ...]` `Tensor`. The rank and dimensions
      may be unknown. Rank must be at least 1.
    frame_length: The frame length in samples. An integer or scalar `Tensor`.
    frame_step: The frame hop size in samples. An integer or scalar `Tensor`.
    pad_end: Whether to pad the end of `signal` with `pad_value`.
    pad_value: An optional scalar `Tensor` to use where the input signal
      does not exist when `pad_end` is True.
    axis: A scalar integer `Tensor` indicating the axis to frame. Defaults to
      the last axis. Supports negative values for indexing from the end.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of frames with shape `[..., frames, frame_length, ...]`.

  Raises:
    ValueError: If `frame_length`, `frame_step`, `pad_value`, or `axis` are not
      scalar.
  """
  with ops.name_scope(name, "frame", [signal, frame_length, frame_step,
                                      pad_value]):
    signal = ops.convert_to_tensor(signal, name="signal")
    frame_length = ops.convert_to_tensor(frame_length, name="frame_length")
    frame_step = ops.convert_to_tensor(frame_step, name="frame_step")
    axis = ops.convert_to_tensor(axis, name="axis")

    signal.shape.with_rank_at_least(1)
    frame_length.shape.assert_has_rank(0)
    frame_step.shape.assert_has_rank(0)
    axis.shape.assert_has_rank(0)

    result_shape = _infer_frame_shape(signal, frame_length, frame_step, pad_end,
                                      axis)

    # Axis can be negative. Convert it to positive.
    signal_rank = array_ops.rank(signal)
    axis = math_ops.range(signal_rank)[axis]
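    # For example, with a rank-3 signal, axis=-1 becomes range(3)[-1] == 2.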

    signal_shape = array_ops.shape(signal)
    outer_dimensions, length_samples, inner_dimensions = array_ops.split(
        signal_shape, [axis, 1, signal_rank - 1 - axis])
    length_samples = array_ops.reshape(length_samples, [])
    num_outer_dimensions = array_ops.size(outer_dimensions)
    num_inner_dimensions = array_ops.size(inner_dimensions)

    # If padding is requested, pad the input signal tensor with pad_value.
    if pad_end:
      pad_value = ops.convert_to_tensor(pad_value, signal.dtype)
      pad_value.shape.assert_has_rank(0)

      # Calculate number of frames, using double negatives to round up.
      num_frames = -(-length_samples // frame_step)

      # Pad the signal by up to frame_length samples so that the last frame,
      # which starts at frame_step * (num_frames - 1), is fully populated.
      pad_samples = math_ops.maximum(
          0, frame_length + frame_step * (num_frames - 1) - length_samples)
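      # For the docstring example (length_samples=9152, frame_length=512,
      # frame_step=180): num_frames = -(-9152 // 180) = 51 and
      # pad_samples = max(0, 512 + 180 * 50 - 9152) = 360.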

      # Pad the `axis` dimension of signal by pad_samples.
      paddings = array_ops.concat(
          [array_ops.zeros([num_outer_dimensions, 2], dtype=pad_samples.dtype),
           [[0, pad_samples]],
           array_ops.zeros([num_inner_dimensions, 2], dtype=pad_samples.dtype)],
          0)
      signal = array_ops.pad(signal, paddings, constant_values=pad_value)

      signal_shape = array_ops.shape(signal)
      length_samples = signal_shape[axis]
    else:
      num_frames = math_ops.maximum(
          0, 1 + (length_samples - frame_length) // frame_step)

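    # Rather than slicing every (possibly overlapping) frame out of the signal
    # directly, the signal is split into non-overlapping subframes of length
    # gcd(frame_length, frame_step). Each frame and each hop then spans a
    # whole number of subframes, so frames can be assembled with a single
    # gather over the subframe axis. For example, frame_length=512 and
    # frame_step=180 give subframe_length=4, subframes_per_frame=128 and
    # subframes_per_hop=45.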
    subframe_length = util_ops.gcd(frame_length, frame_step)
    subframes_per_frame = frame_length // subframe_length
    subframes_per_hop = frame_step // subframe_length
    num_subframes = length_samples // subframe_length

    slice_shape = array_ops.concat([outer_dimensions,
                                    [num_subframes * subframe_length],
                                    inner_dimensions], 0)
    subframe_shape = array_ops.concat([outer_dimensions,
                                       [num_subframes, subframe_length],
                                       inner_dimensions], 0)
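    # Drop any trailing samples that do not fill a whole subframe, then split
    # the framing axis into [num_subframes, subframe_length].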
    subframes = array_ops.reshape(array_ops.strided_slice(
        signal, array_ops.zeros_like(signal_shape),
        slice_shape), subframe_shape)

    # frame_selector is a [num_frames, 1] tensor holding the index of the
    # first subframe of each frame. Broadcast against subframe_selector below,
    # it covers [num_frames, subframes_per_frame]. For example:
    # [[0, 0, 0, 0], [2, 2, 2, 2], [4, 4, 4, 4]]
    frame_selector = array_ops.reshape(
        math_ops.range(num_frames) * subframes_per_hop, [num_frames, 1])

    # subframe_selector is a [1, subframes_per_frame] tensor holding the
    # offset of each subframe within a frame. Broadcast against frame_selector
    # above, it covers [num_frames, subframes_per_frame]. For example:
    # [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]
    subframe_selector = array_ops.reshape(
        math_ops.range(subframes_per_frame), [1, subframes_per_frame])

    # Adding the two selector tensors together (with broadcasting) produces a
    # [num_frames, subframes_per_frame] tensor of indices to use with
    # tf.gather to select subframes from the `subframes` tensor. We then
    # reshape the inner-most subframes_per_frame dimension to stitch the
    # subframes together into frames. For example:
    # [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7]].
    selector = frame_selector + subframe_selector

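    # Gathering with `selector` along `axis` yields shape
    # [..., num_frames, subframes_per_frame, subframe_length, ...]; the
    # reshape below merges subframes_per_frame * subframe_length back into
    # frame_length.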
    frames = array_ops.reshape(
        array_ops.gather(subframes, selector, axis=axis),
        array_ops.concat([outer_dimensions, [num_frames, frame_length],
                          inner_dimensions], 0))

    if result_shape:
      frames.set_shape(result_shape)
    return frames
