// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy
// of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.
// ==============================================================================

// TensorFlow kernels and Ops for constructing WALS normal equations.
// TODO(agarwal,rmlarsen): Add security checks to the code.

#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

// This is only used for std::this_thread::get_id().
#include <thread>  // NOLINT

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"

using tensorflow::DEVICE_CPU;
using tensorflow::DT_BOOL;
using tensorflow::DT_FLOAT;
using tensorflow::DT_INT64;
using tensorflow::OpKernel;
using tensorflow::OpKernelConstruction;
using tensorflow::OpKernelContext;
using tensorflow::Tensor;
using tensorflow::TensorShape;
using tensorflow::TensorShapeUtils;
using tensorflow::errors::InvalidArgument;

namespace tensorflow {
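// For orientation (all of this is derived from the Compute() implementation
// below): for each row index i present in the sparse input block (column
// index, if input_is_transpose is set), the kernel accumulates the partial
// normal equations
//
//   partial_lhs[i] = sum_{j : (i, j) observed} w_ij * f_j * f_j^T
//   partial_rhs[i] = sum_{j : (i, j) observed} (w_0 + w_ij) * a_ij * f_j
//
// where f_j is row j of the factors input, a_ij is the observed input value,
// w_ij = input_weights[i] * factor_weights[j], and w_0 is the weight on
// unobserved entries. The outputs are "partial" in that the unobserved-entry
// Gramian term of the full left-hand side is presumably folded in by the
// caller.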
// TODO(ataei): Consider using RowMajor maps.
typedef Eigen::Map<
    Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>
    EigenMatrixFloatMap;
typedef Eigen::Map<const Eigen::Matrix<int64, Eigen::Dynamic, Eigen::Dynamic,
                                       Eigen::ColMajor>>
    ConstEigenMatrixInt64Map;
typedef Eigen::Map<const Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic,
                                       Eigen::ColMajor>>
    ConstEigenMatrixFloatMap;

class WALSComputePartialLhsAndRhsOp : public OpKernel {
 public:
  explicit WALSComputePartialLhsAndRhsOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->MatchSignature(
                                {DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT,
                                 DT_INT64, DT_FLOAT, DT_INT64, DT_BOOL},
                                {DT_FLOAT, DT_FLOAT}));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& factors = context->input(0);
    const Tensor& factor_weights = context->input(1);
    const Tensor& unobserved_weights = context->input(2);
    const Tensor& input_weights = context->input(3);
    const Tensor& input_indices = context->input(4);
    const Tensor& input_values = context->input(5);
    const Tensor& input_block_size = context->input(6);
    const Tensor& input_is_transpose = context->input(7);

    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(factors.shape()),
                InvalidArgument("Input factors should be a matrix."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(factor_weights.shape()),
                InvalidArgument("Input factor_weights should be a vector."));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(unobserved_weights.shape()),
        InvalidArgument("Input unobserved_weights should be a scalar."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_weights.shape()),
                InvalidArgument("Input input_weights should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices.shape()),
                InvalidArgument("Input input_indices should be a matrix."));
    OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values.shape()),
                InvalidArgument("Input input_values should be a vector."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(input_block_size.shape()),
                InvalidArgument("Input input_block_size should be a scalar."));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(input_is_transpose.shape()),
        InvalidArgument("Input input_is_transpose should be a scalar."));

    const int64 factor_dim = factors.dim_size(1);
    const int64 factors_size = factors.dim_size(0);
    const int64 num_nonzero_elements = input_indices.dim_size(0);
    const int64 block_size = input_block_size.scalar<int64>()();
    const auto& factor_weights_vec = factor_weights.vec<float>();
    const auto& input_weights_vec = input_weights.vec<float>();
    const float w_0 = unobserved_weights.scalar<float>()();
    const auto& input_values_vec = input_values.vec<float>();

    // Map the row-major factors and indices tensors as column-major Eigen
    // matrices, so that factors_mat.col(j) is row j of the factors input and
    // indices_mat.col(i) is the (row, column) pair of the i-th non-zero.
    ConstEigenMatrixFloatMap factors_mat(factors.matrix<float>().data(),
                                         factor_dim, factors_size);
    ConstEigenMatrixInt64Map indices_mat(input_indices.matrix<int64>().data(),
                                         2, num_nonzero_elements);

    Tensor* output_lhs_tensor;
    OP_REQUIRES_OK(context,
                   context->allocate_output(
                       0, TensorShape({block_size, factor_dim, factor_dim}),
                       &output_lhs_tensor));
    auto output_lhs_t = output_lhs_tensor->tensor<float, 3>();
    output_lhs_t.setZero();
    Tensor* output_rhs_tensor;
    OP_REQUIRES_OK(context, context->allocate_output(
                                1, TensorShape({block_size, factor_dim}),
                                &output_rhs_tensor));
    EigenMatrixFloatMap rhs_mat(output_rhs_tensor->matrix<float>().data(),
                                factor_dim, block_size);
    rhs_mat.setZero();
    const bool is_transpose = input_is_transpose.scalar<bool>()();

    auto get_input_index = [is_transpose, &indices_mat](int64 i) {
      return is_transpose ? indices_mat(1, i) : indices_mat(0, i);
    };
    auto get_factor_index = [is_transpose, &indices_mat](int64 i) {
      return is_transpose ? indices_mat(0, i) : indices_mat(1, i);
    };

    // TODO(rmlarsen): In principle, we should be using the SparseTensor class
    // and its machinery for iterating over groups, but the fact that class
    // SparseTensor makes a complete copy of the matrix makes me reluctant to
    // use it.
    std::vector<int64> perm(num_nonzero_elements);
    std::iota(perm.begin(), perm.end(), 0);

    typedef std::pair<int64, int64> Shard;
    std::vector<Shard> shards;
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    int64 shard_total = 0;
    // Compute a permutation such that get_input_index(perm[i]) is sorted;
    // stable_sort is used to preserve spatial locality.
    std::stable_sort(perm.begin(), perm.end(),
                     [&get_input_index](int64 i, int64 j) {
                       return get_input_index(i) < get_input_index(j);
                     });

    // Compute the start and end of runs with identical input_index.
    // These are the shards of work that can be processed in parallel
    // without locking.
    int64 start = 0;
    int64 end = 0;
    while (end < num_nonzero_elements) {
      start = end;
      while (end < num_nonzero_elements &&
             get_input_index(perm[start]) == get_input_index(perm[end])) {
        ++end;
      }
      shards.emplace_back(start, end);
      shard_total += end - start;
    }
    CHECK_EQ(shard_total, num_nonzero_elements);
    CHECK_LE(shards.size(), num_nonzero_elements);
    CHECK_GT(shards.size(), 0);

    // Batch the rank-one updates into a rank-k update to lower memory traffic.
    const int kMaxBatchSize = 128;

    // Since we do not have an easy way of generating thread ids within the
    // range [0, num_threads), we instead use an std::unordered_map of
    // matrices and initialize each matrix on the first call from its thread.
    // However, this might have a performance penalty, as memory allocation
    // can cause the OS kernel to enter a critical section and temporarily
    // disable parallelism, and the unordered_map must be protected with a
    // mutex.
    //
    // TODO(jpoulson): Simplify after the thread rank can be queried.
    std::unordered_map<size_t, Eigen::MatrixXf> factor_batch_map;
    mutex map_mutex;

    BlockingCounter counter(shards.size());
    // Lambda encapsulating the per-shard computation.
    auto work = [&](const Shard& shard) {
      const std::thread::id thread_id = std::this_thread::get_id();
      const size_t id_hash = std::hash<std::thread::id>()(thread_id);
      // If this thread's unique factors_mat.rows() x kMaxBatchSize
      // batching matrix has not yet been created, then emplace it into the
      // map using the hash of the thread id as the key.
      //
      // TODO(jpoulson): Switch to try_emplace once C++17 is supported.
      // A single critical section covers the lookup and (possible) insertion.
      // References into an unordered_map remain valid when other elements are
      // inserted, so factor_batch can safely be used outside the lock.
      map_mutex.lock();
      if (factor_batch_map.count(id_hash) == 0) {
        factor_batch_map.emplace(
            std::piecewise_construct, std::forward_as_tuple(id_hash),
            std::forward_as_tuple(factors_mat.rows(), kMaxBatchSize));
      }
      auto& factor_batch = factor_batch_map[id_hash];
      map_mutex.unlock();

      CHECK_GE(shard.first, 0);
      CHECK_LE(shard.second, perm.size());
      CHECK_LE(shard.first, shard.second);
      const int64 input_index = get_input_index(perm[shard.first]);
      // Accumulate the rhs and lhs terms in the normal equations
      // for the non-zero elements in the row or column of the sparse matrix
      // corresponding to input_index.
      int num_batched = 0;
      EigenMatrixFloatMap lhs_mat(output_lhs_tensor->flat<float>().data() +
                                      input_index * factor_dim * factor_dim,
                                  factor_dim, factor_dim);
      auto lhs_symm = lhs_mat.selfadjointView<Eigen::Lower>();
      for (int64 p = shard.first; p < shard.second; ++p) {
        const int64 i = perm[p];
        // Check that all entries in the shard have the same input index.
        CHECK_EQ(input_index, get_input_index(i));
        const int64 factor_index = get_factor_index(i);
        const float input_value = input_values_vec(i);
        const float weight =
            input_weights_vec(input_index) * factor_weights_vec(factor_index);
        CHECK_GE(weight, 0);
        // Scale the factor column by sqrt(weight) so that the rank-k update
        // below accumulates weight * f * f^T into the lhs.
        factor_batch.col(num_batched) =
            factors_mat.col(factor_index) * std::sqrt(weight);
        ++num_batched;
        if (num_batched == kMaxBatchSize) {
          lhs_symm.rankUpdate(factor_batch);
          num_batched = 0;
        }

        rhs_mat.col(input_index) +=
            input_value * (w_0 + weight) * factors_mat.col(factor_index);
      }
      // Fold in any remaining partially filled batch of columns.
      if (num_batched != 0) {
        auto factor_block =
            factor_batch.block(0, 0, factors_mat.rows(), num_batched);
        lhs_symm.rankUpdate(factor_block);
      }
      // Copy the lower triangular part to the upper triangular part of the
      // normal equation matrix.
      lhs_mat = lhs_symm;
      counter.DecrementCount();
    };
    for (size_t i = 1; i < shards.size(); ++i) {
      worker_threads.workers->Schedule(std::bind(work, shards[i]));
    }
    // Execute the first shard inline on the current thread.
    work(shards[0]);
    counter.Wait();
  }
};

REGISTER_KERNEL_BUILDER(Name("WALSComputePartialLhsAndRhs").Device(DEVICE_CPU),
                        WALSComputePartialLhsAndRhsOp);

}  // namespace tensorflow
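// Note: the matching op definition for "WALSComputePartialLhsAndRhs" is
// registered separately from this kernel file. As a minimal sketch, with the
// input/output names assumed here and the dtypes taken from the
// MatchSignature above, the registration could look like:
//
//   REGISTER_OP("WALSComputePartialLhsAndRhs")
//       .Input("factors: float")
//       .Input("factor_weights: float")
//       .Input("unobserved_weights: float")
//       .Input("input_weights: float")
//       .Input("input_indices: int64")
//       .Input("input_values: float")
//       .Input("input_block_size: int64")
//       .Input("input_is_transpose: bool")
//       .Output("partial_lhs: float")
//       .Output("partial_rhs: float");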