# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Signal reconstruction via overlapped addition of frames."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.contrib.signal.python.ops import shape_ops
22from tensorflow.contrib.signal.python.ops import util_ops
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import tensor_util
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import math_ops
27
28
def _shuffle_to_front(input_tensor, k):
  """Shuffles the last `k` dimensions of `input_tensor` to the front.

  Transposes `input_tensor` so that its last `k` dimensions come first. The
  input may have arbitrary rank and unknown shape.

  Args:
    input_tensor: A `Tensor` of arbitrary rank and unknown shape.
    k: A scalar `Tensor` specifying how many dimensions to shuffle to the
      front.

  Returns:
    A transposed version of `input_tensor` with its last `k` dimensions moved
    to the front.

  Raises:
    ValueError: If `input_tensor` has rank less than `k` or `k` is not scalar.
  """
  k = ops.convert_to_tensor(k, name="k")
  k.shape.with_rank(0)
  k_static = tensor_util.constant_value(k)
  if k_static is not None:
    input_tensor.shape.with_rank_at_least(k_static)

  rank = array_ops.rank(input_tensor)
  outer_indices, inner_indices = array_ops.split(math_ops.range(rank),
                                                 [rank - k, k])
  permutation = array_ops.concat([inner_indices, outer_indices], 0)

  return array_ops.transpose(input_tensor, perm=permutation)


def overlap_and_add(signal, frame_step, name=None):
  """Reconstructs a signal from a framed representation.

  Adds potentially overlapping frames of a signal with shape
  `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
  The resulting tensor has shape `[..., output_size]` where

      output_size = (frames - 1) * frame_step + frame_length
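
  For example, overlap-adding 4 frames of length 3 with a `frame_step` of 2
  yields a signal of length (4 - 1) * 2 + 3 = 9 samples.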

  Args:
    signal: A [..., frames, frame_length] `Tensor`. All dimensions may be
      unknown, and rank must be at least 2.
    frame_step: An integer or scalar `Tensor` denoting overlap offsets. Must be
      less than or equal to `frame_length`.
    name: An optional name for the operation.

  Returns:
    A `Tensor` with shape `[..., output_size]` containing the overlap-added
    frames of `signal`'s inner-most two dimensions.

  Raises:
    ValueError: If `signal`'s rank is less than 2, `frame_step` is not a scalar
      integer or `frame_step` is greater than `frame_length`.
  """
  with ops.name_scope(name, "overlap_and_add", [signal, frame_step]):
    signal = ops.convert_to_tensor(signal, name="signal")
    signal.shape.with_rank_at_least(2)
    frame_step = ops.convert_to_tensor(frame_step, name="frame_step")
    frame_step.shape.assert_has_rank(0)
    if not frame_step.dtype.is_integer:
      raise ValueError("frame_step must be an integer. Got %s" %
                       frame_step.dtype)

    # If frame_length and frame_step are known at graph construction time,
    # check that frame_step is less than or equal to frame_length.
    frame_step_static = tensor_util.constant_value(frame_step)
    if (frame_step_static is not None and signal.shape.ndims is not None and
        signal.shape[-1].value is not None and
        frame_step_static > signal.shape[-1].value):
      raise ValueError(
          "frame_step (%d) must be less than or equal to frame_length (%d)" % (
              frame_step_static, signal.shape[-1].value))

    signal_shape = array_ops.shape(signal)

    # All dimensions that are not part of the overlap-and-add. Can be empty for
    # rank 2 inputs.
    outer_dimensions = signal_shape[:-2]

    signal_rank = array_ops.rank(signal)
    frames = signal_shape[-2]
    frame_length = signal_shape[-1]

    subframe_length = util_ops.gcd(frame_length, frame_step)
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length
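
    # Illustrative numbers (not tied to any particular call): frame_length=6
    # and frame_step=4 give subframe_length=gcd(6, 4)=2, subframe_step=2, and
    # subframes_per_frame=3; with frames=3, output_size=4*2+6=14 and
    # output_subframes=7.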

    # To avoid overlap-adding sample-by-sample, we overlap-add at the
    # "subframe" level, where the subframe length is
    # gcd(frame_length, frame_step) samples. Reshape signal from
    # [..., frames, frame_length] into [..., subframes, subframe_length].
    subframe_shape = array_ops.concat(
        [outer_dimensions, [-1, subframe_length]], 0)
    subframe_signal = array_ops.reshape(signal, subframe_shape)
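
    # With the illustrative numbers above, a [..., 3, 6] signal is reshaped to
    # [..., 9, 2]: 3 frames of 3 subframes each, with 2 samples per subframe.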

    # Now we shuffle the last two dimensions, [subframes, subframe_length], to
    # the front.
    # TODO(rjryan): Add an axis argument to unsorted_segment_sum so we can
    # avoid this pair of transposes.
    subframe_signal = _shuffle_to_front(subframe_signal, 2)

    # Use unsorted_segment_sum to add overlapping subframes together.
    segment_ids = array_ops.reshape(shape_ops.frame(
        math_ops.range(output_subframes), subframes_per_frame, subframe_step,
        pad_end=False), [-1])
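
    # With the illustrative numbers above, segment_ids is the flattening of
    # frame([0, 1, 2, 3, 4, 5, 6], 3, 2), i.e. [0, 1, 2, 2, 3, 4, 4, 5, 6]:
    # subframe i of subframe_signal is summed into output subframe
    # segment_ids[i] below.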
    result = math_ops.unsorted_segment_sum(subframe_signal, segment_ids,
                                           num_segments=output_subframes)

    # result is an [output_subframes, subframe_length, ...outer_dimensions]
    # tensor. We return a [...outer_dimensions, output_size] tensor with a
    # transpose and reshape.
    result_shape = array_ops.concat([outer_dimensions, [output_size]], 0)
    return array_ops.reshape(_shuffle_to_front(result, signal_rank - 2),
                             result_shape)
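

# Example usage (a minimal sketch; assumes a TF 1.x environment where these
# ops are exposed as tf.contrib.signal.frame and
# tf.contrib.signal.overlap_and_add):
#
#   frames = tf.contrib.signal.frame(signal, frame_length=256, frame_step=128)
#   reconstructed = tf.contrib.signal.overlap_and_add(frames, frame_step=128)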