# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Signal reconstruction via overlapped addition of frames."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.signal.python.ops import shape_ops
from tensorflow.contrib.signal.python.ops import util_ops
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


def _shuffle_to_front(input_tensor, k):
  """Shuffles the last `k` indices of `input_tensor` to the front.

  Transposes `input_tensor` to have the last `k` indices at the front. The input
  may have arbitrary rank and unknown shape.

  Args:
    input_tensor: A `Tensor` of arbitrary rank and unknown shape.
    k: A scalar `Tensor` specifying how many indices to shuffle.

  Returns:
    A transposed version of `input_tensor` with `k` indices shuffled to the
    front.

  Raises:
    ValueError: If `input_tensor` is not at least rank `k` or `k` is not scalar.
  """
  k = ops.convert_to_tensor(k, name="k")
  k.shape.with_rank(0)
  # The rank check can only happen at graph construction time when `k` is a
  # constant; otherwise an invalid `k` surfaces when the graph runs.
  k_static = tensor_util.constant_value(k)
  if k_static is not None:
    input_tensor.shape.with_rank_at_least(k_static)

  # Build the permutation [rank-k, ..., rank-1, 0, ..., rank-k-1] dynamically
  # so this also works when the rank is unknown at graph construction time.
  rank = array_ops.rank(input_tensor)
  outer_indices, inner_indices = array_ops.split(math_ops.range(rank),
                                                 [rank - k, k])
  permutation = array_ops.concat([inner_indices, outer_indices], 0)

  return array_ops.transpose(input_tensor, perm=permutation)


def overlap_and_add(signal, frame_step, name=None):
  """Reconstructs a signal from a framed representation.

  Adds potentially overlapping frames of a signal with shape
  `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
  The resulting tensor has shape `[..., output_size]` where

      output_size = (frames - 1) * frame_step + frame_length

  Args:
    signal: A [..., frames, frame_length] `Tensor`. All dimensions may be
      unknown, and rank must be at least 2.
    frame_step: An integer or scalar `Tensor` denoting overlap offsets. Any
      integer dtype is accepted; it is cast to the dtype of the dynamic shape
      of `signal` before any arithmetic. Must be less than or equal to
      `frame_length`.
    name: An optional name for the operation.

  Returns:
    A `Tensor` with shape `[..., output_size]` containing the overlap-added
    frames of `signal`'s inner-most two dimensions.

  Raises:
    ValueError: If `signal`'s rank is less than 2, `frame_step` is not a scalar
      integer or `frame_step` is greater than `frame_length`.
  """
  with ops.name_scope(name, "overlap_and_add", [signal, frame_step]):
    signal = ops.convert_to_tensor(signal, name="signal")
    signal.shape.with_rank_at_least(2)
    frame_step = ops.convert_to_tensor(frame_step, name="frame_step")
    frame_step.shape.assert_has_rank(0)
    if not frame_step.dtype.is_integer:
      raise ValueError("frame_step must be an integer. Got %s" %
                       frame_step.dtype)

    # If frame_length and frame_step are known at graph construction time, check
    # frame_step is less than or equal to frame_length. When either is only
    # known at run time, this check is skipped.
    frame_step_static = tensor_util.constant_value(frame_step)
    if (frame_step_static is not None and signal.shape.ndims is not None and
        signal.shape[-1].value is not None and
        frame_step_static > signal.shape[-1].value):
      raise ValueError(
          "frame_step (%d) must be less than or equal to frame_length (%d)" % (
              frame_step_static, signal.shape[-1].value))

    signal_shape = array_ops.shape(signal)

    # The arithmetic below (gcd, `frame_step * (frames - 1)`, ...) combines
    # frame_step with values taken from `signal_shape`. TensorFlow binary ops
    # require both operands to share a dtype, so a frame_step of any other
    # integer dtype (e.g. int64) would otherwise fail here despite passing the
    # `is_integer` check above. The cast is a no-op when the dtypes already
    # match.
    frame_step = math_ops.cast(frame_step, signal_shape.dtype)

    # All dimensions that are not part of the overlap-and-add. Can be empty for
    # rank 2 inputs.
    outer_dimensions = signal_shape[:-2]

    signal_rank = array_ops.rank(signal)
    frames = signal_shape[-2]
    frame_length = signal_shape[-1]

    subframe_length = util_ops.gcd(frame_length, frame_step)
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    # To avoid overlap-adding sample-by-sample, we overlap-add at the "subframe"
    # level, where a subframe is gcd(frame_length, frame_step). Reshape signal
    # from [..., frames, frame_length] into [..., subframes, subframe_length].
    subframe_shape = array_ops.concat(
        [outer_dimensions, [-1, subframe_length]], 0)
    subframe_signal = array_ops.reshape(signal, subframe_shape)

    # Now we shuffle the last [subframes, subframe_length] dimensions to the
    # front.
    # TODO(rjryan): Add an axis argument to unsorted_segment_sum so we can
    # avoid this pair of transposes.
    subframe_signal = _shuffle_to_front(subframe_signal, 2)

    # Use unsorted_segment_sum to add overlapping subframes together. Framing
    # range(output_subframes) with the subframe geometry yields, for every
    # input subframe (in row-major frame order), the index of the output
    # subframe it lands on; flattening gives one segment id per input subframe.
    segment_ids = array_ops.reshape(shape_ops.frame(
        math_ops.range(output_subframes), subframes_per_frame, subframe_step,
        pad_end=False), [-1])
    result = math_ops.unsorted_segment_sum(subframe_signal, segment_ids,
                                           num_segments=output_subframes)

    # result is a [subframes, subframe_length, ...outer_dimensions] tensor. We
    # return a [...outer_dimensions, output_size] tensor with a transpose and
    # reshape.
    result_shape = array_ops.concat([outer_dimensions, [output_size]], 0)
    return array_ops.reshape(_shuffle_to_front(result, signal_rank - 2),
                             result_shape)