// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License.  You may obtain a copy
// of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
// License for the specific language governing permissions and limitations under
// the License.
// ==============================================================================

// TensorFlow kernels and Ops for constructing WALS normal equations.
// TODO(agarwal,rmlarsen): Add security checks to the code.

#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

// This is only used for std::this_thread::get_id()
#include <thread>  // NOLINT

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/mutex.h"

using tensorflow::DEVICE_CPU;
using tensorflow::DT_BOOL;
using tensorflow::DT_FLOAT;
using tensorflow::DT_INT64;
using tensorflow::OpKernel;
using tensorflow::OpKernelConstruction;
using tensorflow::OpKernelContext;
using tensorflow::Tensor;
using tensorflow::TensorShape;
using tensorflow::TensorShapeUtils;
using tensorflow::errors::InvalidArgument;

namespace tensorflow {

// TODO(ataei): Consider using RowMajor maps.
typedef Eigen::Map<
    Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>
    EigenMatrixFloatMap;
typedef Eigen::Map<
    const Eigen::Matrix<int64, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>
    ConstEigenMatrixInt64Map;
typedef Eigen::Map<
    const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>
    ConstEigenMatrixFloatMap;

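// Computes the partial normal equations for a block of rows (or columns, when
// input_is_transpose is set) of a sparse input matrix. For each input index i
// in the block it accumulates
//   lhs[i] += w_ij * v_j * v_j^T
//   rhs[i] += a_ij * (w_0 + w_ij) * v_j
// over the observed entries (i, j), where v_j is the j-th factor vector,
// w_ij = input_weights[i] * factor_weights[j] is the observation weight,
// w_0 is the unobserved weight, and a_ij is the observed value.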
class WALSComputePartialLhsAndRhsOp : public OpKernel {
 public:
  explicit WALSComputePartialLhsAndRhsOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->MatchSignature(
                                {DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT,
                                 DT_INT64, DT_FLOAT, DT_INT64, DT_BOOL},
                                {DT_FLOAT, DT_FLOAT}));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& factors = context->input(0);
    const Tensor& factor_weights = context->input(1);
    const Tensor& unobserved_weights = context->input(2);
    const Tensor& input_weights = context->input(3);
    const Tensor& input_indices = context->input(4);
    const Tensor& input_values = context->input(5);
    const Tensor& input_block_size = context->input(6);
    const Tensor& input_is_transpose = context->input(7);

    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(factors.shape()),
                InvalidArgument("Input factors should be a matrix."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(factor_weights.shape()),
                InvalidArgument("Input factor_weights should be a vector."));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(unobserved_weights.shape()),
        InvalidArgument("Input unobserved_weights should be a scalar."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_weights.shape()),
                InvalidArgument("Input input_weights should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices.shape()),
                InvalidArgument("Input input_indices should be a matrix."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values.shape()),
                InvalidArgument("Input input_values should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(input_block_size.shape()),
                InvalidArgument("Input input_block_size should be a scalar."));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(input_is_transpose.shape()),
        InvalidArgument("Input input_is_transpose should be a scalar."));

    const int64 factor_dim = factors.dim_size(1);
    const int64 factors_size = factors.dim_size(0);
    const int64 num_nonzero_elements = input_indices.dim_size(0);
    const int64 block_size = input_block_size.scalar<int64>()();
    const auto& factor_weights_vec = factor_weights.vec<float>();
    const auto& input_weights_vec = input_weights.vec<float>();
    const float w_0 = unobserved_weights.scalar<float>()();
    const auto& input_values_vec = input_values.vec<float>();

    ConstEigenMatrixFloatMap factors_mat(factors.matrix<float>().data(),
                                         factor_dim, factors_size);
    ConstEigenMatrixInt64Map indices_mat(input_indices.matrix<int64>().data(),
                                         2, num_nonzero_elements);

    Tensor* output_lhs_tensor;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       0, TensorShape({block_size, factor_dim, factor_dim}),
                       &output_lhs_tensor));
    auto output_lhs_t = output_lhs_tensor->tensor<float, 3>();
    output_lhs_t.setZero();
    Tensor* output_rhs_tensor;
    OP_REQUIRES_OK(context, context->allocate_output(
                                1, TensorShape({block_size, factor_dim}),
                                &output_rhs_tensor));
    EigenMatrixFloatMap rhs_mat(output_rhs_tensor->matrix<float>().data(),
                                factor_dim, block_size);
    rhs_mat.setZero();
    const bool is_transpose = input_is_transpose.scalar<bool>()();

    auto get_input_index = [is_transpose, &indices_mat](int64 i) {
      return is_transpose ? indices_mat(1, i) : indices_mat(0, i);
    };
    auto get_factor_index = [is_transpose, &indices_mat](int64 i) {
      return is_transpose ? indices_mat(0, i) : indices_mat(1, i);
    };
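    // When is_transpose is set, the two components of each index pair swap
    // roles: component 1 holds the input (row) index and component 0 holds
    // the factor index. The two accessors above encapsulate this.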

    // TODO(rmlarsen): In principle, we should be using the SparseTensor class
    // and machinery for iterating over groups, but the fact that class
    // SparseTensor makes a complete copy of the matrix makes me reluctant to
    // use it.
    std::vector<int64> perm(num_nonzero_elements);
    std::iota(perm.begin(), perm.end(), 0);

    typedef std::pair<int64, int64> Shard;
    std::vector<Shard> shards;
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    int64 shard_total = 0;
    // Compute a permutation such that get_input_index(perm[i]) is sorted;
    // stable_sort is used to preserve spatial locality.
    std::stable_sort(perm.begin(), perm.end(),
                     [&get_input_index](int64 i, int64 j) {
                       return get_input_index(i) < get_input_index(j);
                     });

    // Compute the start and end of runs with identical input_index.
    // These are the shards of work that can be processed in parallel
    // without locking.
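    // For example, if the (possibly transposed) input indices are
    // [3, 1, 3, 1, 2], the sorted permutation groups them as [1, 1 | 2 | 3, 3]
    // and produces three shards, one per distinct input index.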
    int64 start = 0;
    int64 end = 0;
    while (end < num_nonzero_elements) {
      start = end;
      while (end < num_nonzero_elements &&
             get_input_index(perm[start]) == get_input_index(perm[end])) {
        ++end;
      }
      shards.emplace_back(start, end);
      shard_total += end - start;
    }
    CHECK_EQ(shard_total, num_nonzero_elements);
    CHECK_LE(shards.size(), num_nonzero_elements);
    CHECK_GT(shards.size(), 0);

    // Batch the rank-one updates into a rank-k update to lower memory traffic.
    const int kMaxBatchSize = 128;
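    // Eigen's selfadjointView rankUpdate(M) accumulates M * M^T into the
    // stored (lower) triangle in a single pass, so buffering up to
    // kMaxBatchSize scaled factor columns turns that many rank-one updates
    // into one rank-k update.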

    // Since we do not have an easy way of generating thread ids within the
    // range [0,num_threads), we can instead call out to an std::unordered_map
    // of matrices and initialize the matrix on the first call.
    // However, this might have a performance penalty, as memory allocation can
    // cause the OS kernel to enter a critical section and temporarily disable
    // parallelism, and the unordered_map must be protected with a read/write
    // mutex.
    //
    // TODO(jpoulson): Simplify after the thread rank can be queried.
    std::unordered_map<size_t, Eigen::MatrixXf> factor_batch_map;
    mutex map_mutex;
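    // Each worker thread looks up (creating on first use) its own batching
    // matrix keyed by the hash of its thread id. map_mutex only guards the
    // map operations; the per-thread matrix is then used without locking.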

    BlockingCounter counter(shards.size());
    // Lambda encapsulating the per-shard computation.
    auto work = [&](const Shard& shard) {
      const std::thread::id thread_id = std::this_thread::get_id();
      const size_t id_hash = std::hash<std::thread::id>()(thread_id);
      // If this thread's unique factors_mat.rows() x kMaxBatchSize
      // batching matrix has not yet been created, then emplace it into the
      // map using the hash of the thread id as the key.
      //
      // TODO(jpoulson): Switch to try_emplace once C++17 is supported.
      map_mutex.lock();
      const auto key_count = factor_batch_map.count(id_hash);
      map_mutex.unlock();
      if (!key_count) {
        map_mutex.lock();
        factor_batch_map.emplace(
            std::piecewise_construct, std::forward_as_tuple(id_hash),
            std::forward_as_tuple(factors_mat.rows(), kMaxBatchSize));
        map_mutex.unlock();
      }
      map_mutex.lock();
      auto& factor_batch = factor_batch_map[id_hash];
      map_mutex.unlock();

      CHECK_GE(shard.first, 0);
      CHECK_LE(shard.second, perm.size());
      CHECK_LE(shard.first, shard.second);
      const int64 input_index = get_input_index(perm[shard.first]);
      // Accumulate the rhs and lhs terms in the normal equations
      // for the non-zero elements in the row or column of the sparse matrix
      // corresponding to input_index.
      int num_batched = 0;
      EigenMatrixFloatMap lhs_mat(output_lhs_tensor->flat<float>().data() +
                                      input_index * factor_dim * factor_dim,
                                  factor_dim, factor_dim);
      auto lhs_symm = lhs_mat.selfadjointView<Eigen::Lower>();
      for (int64 p = shard.first; p < shard.second; ++p) {
        const int64 i = perm[p];
        // Check that all entries in the shard have the same input index.
        CHECK_EQ(input_index, get_input_index(i));
        const int64 factor_index = get_factor_index(i);
        const float input_value = input_values_vec(i);
        const float weight =
            input_weights_vec(input_index) * factor_weights_vec(factor_index);
        CHECK_GE(weight, 0);
        factor_batch.col(num_batched) =
            factors_mat.col(factor_index) * std::sqrt(weight);
        ++num_batched;
        if (num_batched == kMaxBatchSize) {
          lhs_symm.rankUpdate(factor_batch);
          num_batched = 0;
        }

        rhs_mat.col(input_index) +=
            input_value * (w_0 + weight) * factors_mat.col(factor_index);
      }
      if (num_batched != 0) {
        auto factor_block =
            factor_batch.block(0, 0, factors_mat.rows(), num_batched);
        lhs_symm.rankUpdate(factor_block);
      }
      // Copy the lower triangular part to the upper triangular part of the
      // normal equation matrix.
      lhs_mat = lhs_symm;
      counter.DecrementCount();
    };
    for (size_t i = 1; i < shards.size(); ++i) {
      worker_threads.workers->Schedule(std::bind(work, shards[i]));
    }
    // Execute the first shard inline on the current thread.
    work(shards[0]);
    counter.Wait();
  }
};

REGISTER_KERNEL_BUILDER(Name("WALSComputePartialLhsAndRhs").Device(DEVICE_CPU),
                        WALSComputePartialLhsAndRhsOp);

}  // namespace tensorflow