/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_SPARSE_CONDITIONAL_ACCUMULATOR_H_
#define TENSORFLOW_KERNELS_SPARSE_CONDITIONAL_ACCUMULATOR_H_

#include <algorithm>
#include <iterator>
#include <tuple>
#include <vector>

#include "tensorflow/core/kernels/typed_conditional_accumulator_base.h"

namespace tensorflow {

/**
 * An aggregation object for adding sparse gradients, represented as a tuple of
 * indices, values, and a (possibly empty) shape.
 *
 * The two main methods of this class are TryApplyGrad and TryTakeGrad.
 *
 * TryApplyGrad tries to add a gradient to the accumulator. The attempt is
 * successful if local_step >= global_step, i.e., if the gradient is not stale,
 * having been computed using up-to-date information. Otherwise, the gradient is
 * silently dropped.
 *
 * TryTakeGrad logs an attempt to read the average gradient. The attempt is
 * blocked until the number of gradients accumulated (via TryApplyGrad) equals
 * or exceeds the number requested by TryTakeGrad.
 * Once this condition is satisfied, the following actions are taken:
 * (1) the value of the average gradient is returned;
 * (2) the count of accumulated gradients is reset to 0;
 * (3) the internal global_step value (current_global_step_) is incremented
 *     by 1.
 *
 * SparseConditionalAccumulator is the datatype-dependent templated subclass of
 * TypedConditionalAccumulatorBase (itself a ConditionalAccumulatorBase). It
 * implements the virtual arithmetic methods used for aggregating, averaging,
 * allocating, and returning indexed slices.
 *
 * Averaging is per index: each returned row is divided by the number of
 * applied gradients that contained its index, not by the total number of
 * gradients. For example, applying indices {1, 3} and then {3, 5} yields
 * output indices {1, 3, 5}, with row 3 divided by 2 and rows 1 and 5 by 1.
 */
template <typename Device, typename T>
class SparseConditionalAccumulator
    : public TypedConditionalAccumulatorBase<
          std::tuple<const Tensor*, const Tensor*, const Tensor*>> {
 public:
  SparseConditionalAccumulator(const DataType& dtype,
                               const PartialTensorShape& shape,
                               const string& name)
      : TypedConditionalAccumulatorBase<
            std::tuple<const Tensor*, const Tensor*, const Tensor*>>(
            dtype, shape, name) {
    accum_idx_vec_ = nullptr;
    count_element_ = nullptr;
    accum_val_ = nullptr;
    accum_val_persistent_ = new PersistentTensor();
  }

  ~SparseConditionalAccumulator() override {
    if (accum_idx_vec_ != nullptr) delete accum_idx_vec_;
    if (count_element_ != nullptr) delete count_element_;
    if (accum_val_persistent_ != nullptr) delete accum_val_persistent_;
    // Do not delete accum_val_! It points to the Tensor held by
    // accum_val_persistent_ and is destroyed along with it.
  }

 protected:
  // Indices of the accumulated rows (assumed to be in increasing order).
  std::vector<int64>* accum_idx_vec_ = nullptr;
  // count_element_->at(r) is the number of applied gradients that contributed
  // to row r, i.e. the divisor used when averaging that row.
  std::vector<int>* count_element_ = nullptr;

  // Running sum of the values; row r corresponds to accum_idx_vec_->at(r).
  Tensor* accum_val_ = nullptr;
  // Owns the Tensor that accum_val_ points to.
  PersistentTensor* accum_val_persistent_ = nullptr;

  // Unaligned 1-D views over a single row (slice) of a values tensor.
  typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>,
                           Eigen::Unaligned>
      SliceT;
  typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
                           Eigen::Unaligned>
      SliceConstT;

  Status ValidateShape(
      std::tuple<const Tensor*, const Tensor*, const Tensor*>* tensor,
      bool has_known_shape) EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {
    const Tensor* tensor_idx = std::get<0>(*tensor);
    const Tensor* tensor_val = std::get<1>(*tensor);
    const Tensor* tensor_shape = std::get<2>(*tensor);
    int64 grad_val_dims = tensor_val->dims();
    int64 grad_dims = grad_val_dims;

    // Compare the gradient's dense shape (tensor_shape) against shape_
    if (has_known_shape) {
      if (shape_.dims() > tensor_shape->NumElements()) {
        return errors::InvalidArgument(
            "Shape mismatch: expected shape rank at least ", shape_.dims(),
            ", got ", tensor_shape->NumElements());
      }
      const auto tensor_shape_flat = tensor_shape->flat<int64>();
      for (int64 i = 0; i < shape_.dims(); i++) {
        if (shape_.dim_size(i) != -1 &&
            shape_.dim_size(i) != tensor_shape_flat(i)) {
          return errors::InvalidArgument("Shape mismatch: expected shape dim ",
                                         i, " to be ", shape_.dim_size(i),
                                         ", got ", tensor_shape_flat(i));
        }
      }
    }
    // Check that indices are within limits
    if (shape_.dims() > 0 && shape_.dim_size(0) != -1 &&
        tensor_idx->dims() > 0) {
      for (int64 i = 0; i < tensor_idx->dim_size(0); i++) {
        if (tensor_idx->vec<int64>()(i) >= shape_.dim_size(0)) {
          return errors::InvalidArgument(
              "Shape mismatch: index of slice ", i, " exceeded limits of shape",
              "; index is ", tensor_idx->vec<int64>()(i), " exceeded ",
              shape_.dim_size(0));
        }
      }
    }

    // Check that the values are compatible with the accumulated gradient, if
    // any
    if (counter_ > 0) {
      int64 accum_val_dims = accum_val_->dims();
      if (accum_val_dims != grad_val_dims) {
        return errors::InvalidArgument("Shape mismatch: expected values rank ",
                                       accum_val_dims, ", got ", grad_val_dims);
      }
      for (int64 i = 1; i < accum_val_dims; i++) {
        if (accum_val_->dim_size(i) != tensor_val->dim_size(i)) {
          return errors::InvalidArgument("Shape mismatch: expected values dim ",
                                         i, " to be ", accum_val_->dim_size(i),
                                         ", got ", tensor_val->dim_size(i));
        }
      }
    } else {
      // If there are no accumulated gradients, check against shape_
      if (shape_.dims() > grad_dims) {
        return errors::InvalidArgument(
            "Shape mismatch: expected values rank at least ", shape_.dims(),
            ", got ", grad_dims);
      }
      // Check that values have correct dimensions
      for (int64 i = 1; i < shape_.dims(); i++) {
        if (shape_.dim_size(i) != -1 &&
            shape_.dim_size(i) != tensor_val->dim_size(i)) {
          return errors::InvalidArgument("Shape mismatch: expected values dim ",
                                         i, " to be ", shape_.dim_size(i),
                                         ", got ", tensor_val->dim_size(i));
        }
      }
    }

    return Status::OK();
  }

  // Initializes the accumulator from a single gradient when nothing has been
  // accumulated yet: copies its indices and values and sets every count to 1.
  void AllocateAndAssignToAccumGradFunction(
      OpKernelContext* ctx,
      std::tuple<const Tensor*, const Tensor*, const Tensor*>* grad) override {
    const Tensor* grad_idx = std::get<0>(*grad);
    const Tensor* grad_val = std::get<1>(*grad);

    const int64 nnz = grad_idx->dim_size(0);

    // Assign indices
    if (accum_idx_vec_ != nullptr) delete accum_idx_vec_;
    accum_idx_vec_ = new std::vector<int64>();
    accum_idx_vec_->reserve(nnz);
    for (int64 i = 0; i < nnz; i++) {
      accum_idx_vec_->push_back(grad_idx->vec<int64>()(i));
    }

    // Assign values to accum_val_
    // TODO(b/32704451): Don't just ignore the ::tensorflow::Status object!
    ctx->allocate_persistent(dtype_, grad_val->shape(), accum_val_persistent_,
                             &accum_val_)
        .IgnoreError();
    accum_val_->flat<T>().device(ctx->template eigen_device<Device>()) =
        grad_val->flat<T>();

    // Assign count_element_: every index has been seen exactly once so far
    if (count_element_ != nullptr) {
      delete count_element_;
    }
    count_element_ = new std::vector<int>(nnz, 1);

    // The shape does not need to be stored; assume the op has already checked
    // that the shapes match, so grad's shape == shape_.
  }

  // Adds a gradient to the accumulator by merging two sparse tensors whose
  // indices are sorted in increasing order.
  void AddToAccumGradFunction(
      OpKernelContext* ctx,
      std::tuple<const Tensor*, const Tensor*, const Tensor*>* grad) override {
    // Modeled after third_party/tensorflow/core/kernels/sparse_add_op

    const Tensor* grad_idx = std::get<0>(*grad);
    const Tensor* grad_val = std::get<1>(*grad);

    const int64 accum_nnz = accum_idx_vec_->size();
    const int64 grad_nnz = grad_idx->dim_size(0);

    // Source enumerates the origin of a non-zero element: whether it is from
    // the new gradient, the accumulated gradient, or the sum of both.
    enum Source { from_accum, from_grad, from_accum_and_grad };

    // (1) Do a pass over the inputs and record, for each output row, its
    // source and the input rows it is read from (-1 when unused).
    std::vector<std::tuple<Source, int64, int64>> entries_to_copy;
    entries_to_copy.reserve(accum_nnz + grad_nnz);

    // Pass over all non-zero elements of both the gradient and the accumulated
    // value, to identify where each non-zero element of the sum comes from.
    // The input and output indexed slices are assumed to be sorted by
    // increasing index. For example, accumulated indices {1, 3} and gradient
    // indices {3, 5} produce the entries (from_accum, 0, -1),
    // (from_accum_and_grad, 1, 0) and (from_grad, -1, 1).
    int64 i = 0, j = 0;
    int64 sum_nnz = 0;
    while (i < accum_nnz && j < grad_nnz) {
      sum_nnz++;
      switch (cmp(accum_idx_vec_, grad_idx, i, j)) {
        case -1:
          entries_to_copy.emplace_back(from_accum, i, -1);
          ++i;
          break;
        case 0:
          entries_to_copy.emplace_back(from_accum_and_grad, i, j);
          ++i;
          ++j;
          break;
        case 1:
          entries_to_copy.emplace_back(from_grad, -1, j);
          ++j;
          break;
      }
    }

    // Handle leftovers: at most one of the two inputs still has entries.
    while (i < accum_nnz) {
      sum_nnz++;
      entries_to_copy.emplace_back(from_accum, i, -1);
      ++i;
    }
    while (j < grad_nnz) {
      sum_nnz++;
      entries_to_copy.emplace_back(from_grad, -1, j);
      ++j;
    }

    // (2) Copy or sum the non-zero elements into sum_indices_vec and
    // sum_tensor
    std::vector<int64>* sum_indices_vec = new std::vector<int64>();
    sum_indices_vec->reserve(sum_nnz);

    std::vector<int>* sum_counts = new std::vector<int>();
    sum_counts->reserve(sum_nnz);

    Tensor* sum_tensor = nullptr;
    PersistentTensor* tensor_sum_persistent = new PersistentTensor();

    TensorShape sum_shape = grad_val->shape();
    sum_shape.set_dim(0, sum_nnz);

    OP_REQUIRES_OK(
        ctx, ctx->allocate_persistent(dtype_, sum_shape, tensor_sum_persistent,
                                      &sum_tensor));
    auto sum_flat = sum_tensor->flat_outer_dims<T>();
    auto accum_flat = accum_val_->flat_outer_dims<T>();
    auto grad_flat = grad_val->flat_outer_dims<T>();

    const int64 num_col = grad_flat.dimension(1);

    Eigen::DSizes<Eigen::DenseIndex, 1> slice_shape(num_col);

    for (i = 0; i < sum_nnz; ++i) {
      const Source src = std::get<0>(entries_to_copy[i]);
      const int64 idx_a = std::get<1>(entries_to_copy[i]);
      const int64 idx_b = std::get<2>(entries_to_copy[i]);
      T* sum_slice_ptr = &sum_flat(i, 0);
      SliceT sum_slice(sum_slice_ptr, slice_shape);
      if (src == from_accum) {
        // Element comes from accumulator; directly copy data structures over
        sum_indices_vec->push_back(accum_idx_vec_->at(idx_a));
        T* accum_slice_ptr = &accum_flat(idx_a, 0);
        SliceT accum_slice(accum_slice_ptr, slice_shape);
        sum_slice = accum_slice;
        sum_counts->push_back(count_element_->at(idx_a));
      } else if (src == from_accum_and_grad) {
        // Element is a sum of accumulated value and new gradient;
        // compute sum here
        sum_indices_vec->push_back(accum_idx_vec_->at(idx_a));
        const T* grad_slice_ptr = &grad_flat(idx_b, 0);
        SliceConstT grad_slice(grad_slice_ptr, slice_shape);
        T* accum_slice_ptr = &accum_flat(idx_a, 0);
        SliceT accum_slice(accum_slice_ptr, slice_shape);
        sum_slice = grad_slice + accum_slice;
        sum_counts->push_back(count_element_->at(idx_a) + 1);
      } else if (src == from_grad) {
        // Element comes from new gradient; make a copy of indices and values
        sum_indices_vec->push_back(grad_idx->vec<int64>()(idx_b));
        const T* grad_slice_ptr = &grad_flat(idx_b, 0);
        SliceConstT grad_slice(grad_slice_ptr, slice_shape);
        sum_slice = grad_slice;
        sum_counts->push_back(1);
      }
    }

    // (3) Keep output, i.e., switch pointers to point to new data structures
    // representing the sum
    // Indices
    if (accum_idx_vec_ != nullptr) delete accum_idx_vec_;
    accum_idx_vec_ = sum_indices_vec;
    // Values
    accum_val_ = sum_tensor;
    delete accum_val_persistent_;
    accum_val_persistent_ = tensor_sum_persistent;
    // Counts
    if (count_element_ != nullptr) delete count_element_;
    count_element_ = sum_counts;

    // No need to copy shape, since shape remains the same after sum.
  }

  // Averages each accumulated row by its own count (count_element_), not by
  // the total number of applied gradients (counter_).
  void DivideAccumGradByCounter(OpKernelContext* ctx) override
      EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {
    const int64 nnz = count_element_->size();
    auto accum_flat = accum_val_->flat_outer_dims<T>();
    std::vector<T> count_typet;
    std::transform(count_element_->begin(), count_element_->end(),
                   std::back_inserter(count_typet),
                   TypeConverter<T, int>::ConvertUToT);

    // Option 1: divide all by counter
    /*
    std::transform(
        &accum_flat(0,0), &accum_flat(nnz,0), &accum_flat(0,0),
        std::bind2nd(std::divides<T>(),
                     TypeConverter<T, int>::ConvertUToT(this->counter_)));
    */

    // Option 2: average element-wise
    Eigen::DSizes<Eigen::DenseIndex, 1> slice_shape(accum_flat.dimension(1));
    for (int64 i = 0; i < nnz; i++) {
      T* accum_slice_ptr = &accum_flat(i, 0);
      SliceT accum_slice(accum_slice_ptr, slice_shape);
      accum_slice.device(ctx->template eigen_device<Device>()) =
          accum_slice / count_typet[i];
    }
  }

  bool SetOutput(OpKernelContext* ctx) override {
    bool is_successful = true;
    if (is_successful) is_successful = ReturnIdxTensor(ctx);
    if (is_successful) is_successful = ReturnValTensor(ctx);
    if (is_successful) is_successful = ReturnShapeTensor(ctx);
    return is_successful;
  }

  bool GetAndValidateTensorInputForApplyGrad(
      OpKernelContext* ctx,
      std::tuple<const Tensor*, const Tensor*, const Tensor*>** tensor) override
      EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {
    // TODO(xinghao, jmchen): The roundabout way of getting attr from
    // OpKernelContext (instead of OpKernelConstruction) is a hack, and should
    // be fixed if it affects efficiency.
    bool has_known_shape = false;
    OP_REQUIRES_OK_BOOLEAN(
        ctx, GetNodeAttr(ctx->op_kernel().def(), "has_known_shape",
                         &has_known_shape));

    // Get input gradient tensors
    const Tensor* grad_idx_tensor;
    OP_REQUIRES_OK_BOOLEAN(ctx,
                           ctx->input("gradient_indices", &grad_idx_tensor));
    const Tensor* grad_val_tensor;
    OP_REQUIRES_OK_BOOLEAN(ctx,
                           ctx->input("gradient_values", &grad_val_tensor));
    const Tensor* grad_shape_tensor = nullptr;
    if (has_known_shape) {
      OP_REQUIRES_OK_BOOLEAN(ctx,
                             ctx->input("gradient_shape", &grad_shape_tensor));
    }

    // Checks
    OP_REQUIRES_BOOLEAN(
        ctx, TensorShapeUtils::IsVector(grad_idx_tensor->shape()),
        errors::InvalidArgument(
            "Input indices should be vector but received shape: ",
            grad_idx_tensor->shape().DebugString()));
    const int64 nnz = grad_idx_tensor->dim_size(0);
    OP_REQUIRES_BOOLEAN(
        ctx, grad_val_tensor->dims() > 0,
        errors::InvalidArgument("Values cannot be 0-dimensional."));
    OP_REQUIRES_BOOLEAN(ctx, grad_val_tensor->dim_size(0) == nnz,
                        errors::InvalidArgument("Expected ", nnz,
                                                " non-empty input values, got ",
                                                grad_val_tensor->dim_size(0)));

    *tensor = new std::tuple<const Tensor*, const Tensor*, const Tensor*>(
        grad_idx_tensor, grad_val_tensor, grad_shape_tensor);

    OP_REQUIRES_OK_BOOLEAN(ctx, this->ValidateShape(*tensor, has_known_shape));

    return true;
  }

  void CleanUpGradTensor(std::tuple<const Tensor*, const Tensor*,
                                    const Tensor*>* tensor) override {
    if (tensor != nullptr) delete tensor;
  }

 private:
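  // Worked example of the sparse sum and averaging performed by this class
  // (illustrative only; it assumes both index lists are sorted in ascending
  // order, which the single-pass merge driven by cmp() relies on).
  // Suppose the accumulator holds indices {1, 4, 7} with per-row counts
  // {2, 1, 1}, and a new gradient arrives with indices {4, 9}:
  //   1 vs 4: cmp returns -1 -> row 1 copied from accumulator, count stays 2
  //   4 vs 4: cmp returns  0 -> row 4 = accum row + grad row, count 1 + 1 = 2
  //   7 vs 9: cmp returns -1 -> row 7 copied from accumulator, count stays 1
  //   leftover 9             -> row 9 copied from gradient, count starts at 1
  // The summed indices are {1, 4, 7, 9} with counts {2, 2, 1, 1}.
  // DivideAccumGradByCounter then divides each row by its own count, so
  // rows 1 and 4 are halved while rows 7 and 9 are left unchanged.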
  inline int cmp(std::vector<int64>* a_idx, const Tensor* b_idx,
                 const int64 a_row, const int64 b_row) {
    const int64 a = a_idx->at(a_row);
    const int64 b = b_idx->vec<int64>()(b_row);
    if (a < b) {
      return -1;
    } else if (a > b) {
      return 1;
    }
    return 0;
  }

  inline bool ReturnIdxTensor(OpKernelContext* ctx) {
    Tensor* idx_tensor;
    const int64 nnz = accum_idx_vec_->size();
    OP_REQUIRES_OK_BOOLEAN(ctx, ctx->allocate_output(0, {nnz}, &idx_tensor));
    // If allocate_output fails, OP_REQUIRES_OK_BOOLEAN will short-circuit
    // the remaining code and just return false
    auto idx_tensor_vec = idx_tensor->vec<int64>();
    // Use int64 to match nnz and avoid overflow for very large index counts.
    for (int64 i = 0; i < nnz; ++i) {
      idx_tensor_vec(i) = accum_idx_vec_->at(i);
    }
    return true;
  }

  inline bool ReturnValTensor(OpKernelContext* ctx) {
    ctx->set_output(1, *accum_val_);
    return true;
  }

  inline bool ReturnShapeTensor(OpKernelContext* ctx) {
    int64 accum_val_dims = accum_val_->dims();
    Tensor* shape_tensor;
    OP_REQUIRES_OK_BOOLEAN(
        ctx, ctx->allocate_output(2, {accum_val_dims}, &shape_tensor));
    // If allocate_output fails, OP_REQUIRES_OK_BOOLEAN will short-circuit
    // the remaining code and just return false

    // First dim of shape is defined by shape_, others by accum_val_->shape
    shape_tensor->flat<int64>()(0) =
        (shape_.dims() > 0) ? shape_.dim_size(0) : -1;
    for (int64 i = 1; i < accum_val_dims; i++) {
      shape_tensor->flat<int64>()(i) = accum_val_->dim_size(i);
    }
    return true;
  }

  TF_DISALLOW_COPY_AND_ASSIGN(SparseConditionalAccumulator);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_KERNELS_SPARSE_CONDITIONAL_ACCUMULATOR_H_