126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlowerLicensed under the Apache License, Version 2.0 (the "License");
426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFloweryou may not use this file except in compliance with the License.
526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlowerYou may obtain a copy of the License at
626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    http://www.apache.org/licenses/LICENSE-2.0
826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlowerUnless required by applicable law or agreed to in writing, software
1026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlowerdistributed under the License is distributed on an "AS IS" BASIS,
1126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlowerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlowerSee the License for the specific language governing permissions and
1326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlowerlimitations under the License.
1426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower==============================================================================*/
1526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
1626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower#ifndef TENSORFLOW_KERNELS_SPARSE_CONDITIONAL_ACCUMULATOR_H_
1726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower#define TENSORFLOW_KERNELS_SPARSE_CONDITIONAL_ACCUMULATOR_H_
1826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
1926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower#include "tensorflow/core/kernels/typed_conditional_accumulator_base.h"
2026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
2126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlowernamespace tensorflow {
2226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
2326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower/**
2426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower * An aggregation object for adding sparse gradients, represented as a tuple of
2526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower * indices, values, and a (possibly empty) shape.
2626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower *
2726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower * The two main methods of this class are TryApplyGrad and TryTakeGrad.
2826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower *
 * TryApplyGrad tries to add a gradient to the accumulator. The attempt is
3026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower * successful if local_step >= global_step, i.e., if the gradient is not stale,
3126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower * having been computed using up-to-date information. Otherwise, the gradient is
3226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower * silently dropped.
3326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower *
3426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower * TryTakeGrad logs an attempt to read the average gradient. The attempt is
 * blocked until the number of gradients accumulated (via TryApplyGrad) equals
 * or exceeds the number requested by TryTakeGrad.
3726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower * Once this condition is satisfied, the following actions are taken:
3826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower * (1) the value of the average gradient is returned
3926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower * (2) the count of accumulated gradients is reset to 0
4026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower * (3) the internal global_step value (current_global_step_) is incremented by 1
4126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower *
4226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower * SparseConditionalAccumulator is the datatype-dependent templated sub-class of
4326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower * ConditionalAccumulatorBase. It implements the virtual arithmetic methods that
 * are used for aggregating, averaging, allocating, and returning indexed slices.
4526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower */
4626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlowertemplate <typename Device, typename T>
4726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlowerclass SparseConditionalAccumulator
4826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    : public TypedConditionalAccumulatorBase<
4926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower          std::tuple<const Tensor*, const Tensor*, const Tensor*>> {
5026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower public:
5126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  SparseConditionalAccumulator(const DataType& dtype,
5226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower                               const PartialTensorShape& shape,
5326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower                               const string& name)
5426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      : TypedConditionalAccumulatorBase<
5526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower            std::tuple<const Tensor*, const Tensor*, const Tensor*>>(
5626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower            dtype, shape, name) {
5726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    accum_idx_vec_ = nullptr;
5826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    count_element_ = nullptr;
5926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    accum_val_ = nullptr;
6026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    accum_val_persistent_ = new PersistentTensor();
6126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  }
6226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
6326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  ~SparseConditionalAccumulator() override {
6426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    if (accum_idx_vec_ != nullptr) delete accum_idx_vec_;
6526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    if (count_element_ != nullptr) delete count_element_;
6626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    if (accum_val_persistent_ != nullptr) delete accum_val_persistent_;
6726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // Do not delete accum_val_! Will be automatically garbage collected
6826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  };
6926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
7026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower protected:
7126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  std::vector<int64>* accum_idx_vec_ = nullptr;
7226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  std::vector<int>* count_element_ = nullptr;
7326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
7426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  Tensor* accum_val_ = nullptr;
7526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  PersistentTensor* accum_val_persistent_ = nullptr;
7626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
7726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor>,
7826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower                           Eigen::Unaligned>
7926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      SliceT;
8026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  typedef Eigen::TensorMap<Eigen::Tensor<const T, 1, Eigen::RowMajor>,
8126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower                           Eigen::Unaligned>
8226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      SliceConstT;
8326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
8426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  Status ValidateShape(
8526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      std::tuple<const Tensor*, const Tensor*, const Tensor*>* tensor,
8626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      bool has_known_shape) EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {
8726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    const Tensor* tensor_idx = std::get<0>(*tensor);
8826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    const Tensor* tensor_val = std::get<1>(*tensor);
8926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    const Tensor* tensor_shape = std::get<2>(*tensor);
9026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    int64 grad_val_dims = tensor_val->dims();
9126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    int64 grad_dims = grad_val_dims;
9226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
9326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // Compare with provided shape
9426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    if (has_known_shape) {
9526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      if (shape_.dims() > tensor_shape->NumElements()) {
9626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        return errors::InvalidArgument(
9726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower            "Shape mismatch: expected shape rank at least ", shape_.dims(),
9826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower            ", got ", tensor_shape->NumElements());
9926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      }
10026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      const auto tensor_shape_flat = tensor_shape->flat<int64>();
10126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      for (int64 i = 0; i < shape_.dims(); i++) {
10226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        if (shape_.dim_size(i) != -1 &&
10326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower            shape_.dim_size(i) != tensor_shape_flat(i)) {
10426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower          return errors::InvalidArgument("Shape mismatch: expected shape dim ",
10526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower                                         i, " to be ", shape_.dim_size(i),
10626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower                                         ", got ", tensor_shape_flat(i));
10726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        }
10826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      }
10926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    }
11026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // Check that indices are within limits
11126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    if (shape_.dims() > 0 && shape_.dim_size(0) != -1 &&
11226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        tensor_idx->dims() > 0) {
11326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      for (int64 i = 0; i < tensor_idx->dim_size(0); i++) {
11426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        if (tensor_idx->vec<int64>()(i) >= shape_.dim_size(0)) {
11526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower          return errors::InvalidArgument(
11626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower              "Shape mismatch: index of slice ", i, " exceeded limits of shape",
11726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower              "; index is ", tensor_idx->vec<int64>()(i), " exceeded ",
11826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower              shape_.dim_size(0));
11926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        }
12026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      }
12126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    }
12226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
12326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // Check values compatibility with accumulated gradient if available
12426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    if (counter_ > 0) {
12526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      int64 accum_val_dims = accum_val_->dims();
12626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      if (accum_val_dims != grad_val_dims) {
12726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        return errors::InvalidArgument("Shape mismatch: expected values rank ",
12826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower                                       accum_val_dims, ", got ", grad_val_dims);
12926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      }
13026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      for (int64 i = 1; i < accum_val_dims; i++) {
13126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        if (accum_val_->dim_size(i) != tensor_val->dim_size(i)) {
13226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower          return errors::InvalidArgument("Shape mismatch: expected values dim ",
13326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower                                         i, " to be ", accum_val_->dim_size(i),
13426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower                                         ", got ", tensor_val->dim_size(i));
13526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        }
13626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      }
13726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    } else {
13826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      // If there are no accumulated gradients, check against shape_
13926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      if (shape_.dims() > grad_dims) {
14026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        return errors::InvalidArgument(
14126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower            "Shape mismatch: expected values rank at least ", shape_.dims(),
14226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower            ", got ", grad_dims);
14326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      }
14426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      // Check that values have correct dimensions
14526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      for (int64 i = 1; i < shape_.dims(); i++) {
14626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        if (shape_.dim_size(i) != -1 &&
14726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower            shape_.dim_size(i) != tensor_val->dim_size(i)) {
14826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower          return errors::InvalidArgument("Shape mismatch: expected values dim ",
14926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower                                         i, " to be ", shape_.dim_size(i),
15026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower                                         ", got ", tensor_val->dim_size(i));
15126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        }
15226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      }
15326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    }
15426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
15526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    return Status::OK();
15626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  }
15726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
15826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  void AllocateAndAssignToAccumGradFunction(
15926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      OpKernelContext* ctx,
16026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      std::tuple<const Tensor*, const Tensor*, const Tensor*>* grad) override {
16126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    const Tensor* grad_idx = std::get<0>(*grad);
16226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    const Tensor* grad_val = std::get<1>(*grad);
16326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
16426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    const int64 nnz = grad_idx->dim_size(0);
16526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
16626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // Assign indices
16726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    if (accum_idx_vec_ != nullptr) delete accum_idx_vec_;
16826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    accum_idx_vec_ = new std::vector<int64>();
16926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    accum_idx_vec_->reserve(nnz);
17026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    for (int i = 0; i < nnz; i++) {
17126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      accum_idx_vec_->push_back(grad_idx->vec<int64>()(i));
17226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    }
17326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
17426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // Assign values to accum_val_tensor
175ef7d75baec1b0b3861acef52f1973bbe379ae881Justin Lebar    // TODO(b/32704451): Don't just ignore the ::tensorflow::Status object!
17626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    ctx->allocate_persistent(dtype_, grad_val->shape(), accum_val_persistent_,
177ef7d75baec1b0b3861acef52f1973bbe379ae881Justin Lebar                             &accum_val_)
178ef7d75baec1b0b3861acef52f1973bbe379ae881Justin Lebar        .IgnoreError();
17926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    accum_val_->flat<T>().device(ctx->template eigen_device<Device>()) =
18026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        grad_val->flat<T>();
18126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
18226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // Assign count_element_
18326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    if (count_element_ != nullptr) {
18426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      delete count_element_;
18526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    }
18626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    count_element_ = new std::vector<int>(nnz, 1);
18726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
18826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // Do not need shape; Assume that the op has checked that the shapes match,
18926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // so grad's shape == shape_
19026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  }
19126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
19226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  void AddToAccumGradFunction(
19326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      OpKernelContext* ctx,
19426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      std::tuple<const Tensor*, const Tensor*, const Tensor*>* grad) override {
19526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // Modeled after third_party/tensorflow/core/kernels/sparse_add_op
19626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
19726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    const Tensor* grad_idx = std::get<0>(*grad);
19826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    const Tensor* grad_val = std::get<1>(*grad);
19926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
20026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    const int64 accum_nnz = accum_idx_vec_->size();
20126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    const int64 grad_nnz = grad_idx->dim_size(0);
20226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
20326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // Source enumerates the origin of a non-zero element: whether it is from
20426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // the new gradient, the accumulated gradient, or the sum of both.
20526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    enum Source { from_accum, from_grad, from_accum_and_grad };
20626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
20726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // (1) do a pass over inputs, and append values and indices to vectors
20826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    std::vector<std::tuple<Source, int64, int64>> entries_to_copy;
20926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    entries_to_copy.reserve(accum_nnz + grad_nnz);
21026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
21126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // Pass over all non-zero elements of both the gradient and the accumulated
21226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // value, to identify where each non-zero element of the sum comes from.
21326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // The input and output indexed slices are assumed to be ordered along
21426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // increasing dimension number.
21526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    int64 i = 0, j = 0;
21626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    int64 sum_nnz = 0;
21726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    while (i < accum_nnz && j < grad_nnz) {
21826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      sum_nnz++;
21926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      switch (cmp(accum_idx_vec_, grad_idx, i, j)) {
22026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        case -1:
22126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower          entries_to_copy.emplace_back(from_accum, i, -1);
22226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower          ++i;
22326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower          break;
22426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        case 0:
22526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower          entries_to_copy.emplace_back(from_accum_and_grad, i, j);
22626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower          ++i;
22726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower          ++j;
22826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower          break;
22926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        case 1:
23026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower          entries_to_copy.emplace_back(from_grad, -1, j);
23126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower          ++j;
23226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower          break;
23326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      }
23426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    }
23526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
23626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // Handle leftovers
23726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    while (i < accum_nnz) {
23826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      sum_nnz++;
23926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      entries_to_copy.emplace_back(from_accum, i, -1);
24026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      ++i;
24126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    }
24226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    while (j < grad_nnz) {
24326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      sum_nnz++;
24426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      entries_to_copy.emplace_back(from_grad, -1, j);
24526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      ++j;
24626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    }
24726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
24826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // (2) Copy or sum the non-zero elements into sum_indices and sum_tensor
24926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    std::vector<int64>* sum_indices_vec = new std::vector<int64>();
25026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    sum_indices_vec->reserve(sum_nnz);
25126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
25226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    std::vector<int>* sum_counts = new std::vector<int>();
25326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    sum_counts->reserve(sum_nnz);
25426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
25526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    Tensor* sum_tensor = nullptr;
25626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    PersistentTensor* tensor_sum_persistent = new PersistentTensor();
25726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
25826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    TensorShape sum_shape = grad_val->shape();
25926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    sum_shape.set_dim(0, sum_nnz);
26026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
26126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    OP_REQUIRES_OK(
26226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        ctx, ctx->allocate_persistent(dtype_, sum_shape, tensor_sum_persistent,
26326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower                                      &sum_tensor));
26426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    auto sum_flat = sum_tensor->flat_outer_dims<T>();
26526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    auto accum_flat = accum_val_->flat_outer_dims<T>();
26626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    auto grad_flat = grad_val->flat_outer_dims<T>();
26726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
26826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    const int64 num_col = grad_flat.dimension(1);
26926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
27026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    Eigen::DSizes<Eigen::DenseIndex, 1> slice_shape(num_col);
27126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
27226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    for (i = 0; i < sum_nnz; ++i) {
27326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      const Source src = std::get<0>(entries_to_copy[i]);
27426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      const int64 idx_a = std::get<1>(entries_to_copy[i]);
27526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      const int64 idx_b = std::get<2>(entries_to_copy[i]);
27626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      T* sum_slice_ptr = &sum_flat(i, 0);
27726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      SliceT sum_slice(sum_slice_ptr, slice_shape);
27826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      if (src == from_accum) {
27926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        // Element comes from accumulator; directly copy data structures over
28026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        sum_indices_vec->push_back(accum_idx_vec_->at(idx_a));
28126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        T* accum_slice_ptr = &accum_flat(idx_a, 0);
28226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        SliceT accum_slice(accum_slice_ptr, slice_shape);
28326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        sum_slice = accum_slice;
28426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        sum_counts->push_back(count_element_->at(idx_a));
28526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      } else if (src == from_accum_and_grad) {
28626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        // Element is a sum of accumulated value and new gradient;
28726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        // compute sum here
28826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        sum_indices_vec->push_back(accum_idx_vec_->at(idx_a));
28926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        const T* grad_slice_ptr = &grad_flat(idx_b, 0);
29026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        SliceConstT grad_slice(grad_slice_ptr, slice_shape);
29126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        T* accum_slice_ptr = &accum_flat(idx_a, 0);
29226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        SliceT accum_slice(accum_slice_ptr, slice_shape);
29326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        sum_slice = grad_slice + accum_slice;
29426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        sum_counts->push_back(count_element_->at(idx_a) + 1);
29526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      } else if (src == from_grad) {
29626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        // Element comes from new gradient; make a copy of indices and values
29726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        sum_indices_vec->push_back(grad_idx->vec<int64>()(idx_b));
29826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        const T* grad_slice_ptr = &grad_flat(idx_b, 0);
29926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        SliceConstT grad_slice(grad_slice_ptr, slice_shape);
30026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        sum_slice = grad_slice;
30126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        sum_counts->push_back(1);
30226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      }
30326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    }
30426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
30526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // (3) Keep output, i.e., switch pointers to point to new data structures
30626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // representing the sum
30726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // Indices
30826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    if (accum_idx_vec_ != nullptr) delete accum_idx_vec_;
30926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    accum_idx_vec_ = sum_indices_vec;
31026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // Values
31126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    accum_val_ = sum_tensor;
31226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    delete accum_val_persistent_;
31326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    accum_val_persistent_ = tensor_sum_persistent;
31426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // Counts
31526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    if (count_element_ != nullptr) delete count_element_;
31626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    count_element_ = sum_counts;
31726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
31826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // No need to copy shape, since shape remains the same after sum.
31926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  }
32026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
32126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  void DivideAccumGradByCounter(OpKernelContext* ctx) override
32226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {
32326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    const int64 nnz = count_element_->size();
32426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    auto accum_flat = accum_val_->flat_outer_dims<T>();
32526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    std::vector<T> count_typet;
32626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    std::transform(count_element_->begin(), count_element_->end(),
32726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower                   std::back_inserter(count_typet),
32826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower                   TypeConverter<T, int>::ConvertUToT);
32926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
33026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // Option 1: divide all by counter
33126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    /*
33226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    std::transform(
33326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        &accum_flat(0,0), &accum_flat(nnz,0), &accum_flat(0,0),
33426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        std::bind2nd(std::divides<T>(),
33526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower                     TypeConverter<T, int>::ConvertUToT(this->counter_)));
33626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    */
33726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
33826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // Option 2: average element-wise
33926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    Eigen::DSizes<Eigen::DenseIndex, 1> slice_shape(accum_flat.dimension(1));
34026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    for (int64 i = 0; i < nnz; i++) {
34126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      T* accum_slice_ptr = &accum_flat(i, 0);
34226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      SliceT accum_slice(accum_slice_ptr, slice_shape);
34326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      accum_slice.device(ctx->template eigen_device<Device>()) =
34426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower          accum_slice / count_typet[i];
34526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    }
34626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  }
34726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
34826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  bool SetOutput(OpKernelContext* ctx) override {
34926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    bool is_successful = true;
35026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    if (is_successful) is_successful = ReturnIdxTensor(ctx);
35126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    if (is_successful) is_successful = ReturnValTensor(ctx);
35226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    if (is_successful) is_successful = ReturnShapeTensor(ctx);
35326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    return is_successful;
35426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  }
35526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
35626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  bool GetAndValidateTensorInputForApplyGrad(
35726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      OpKernelContext* ctx,
35826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      std::tuple<const Tensor*, const Tensor*, const Tensor*>** tensor) override
35926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {
36026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // TODO(xinghao, jmchen): The roundabout way of getting attr from
36126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // OpKernelContext (instead of OpKernelConstruction) is a hack, and should
36226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // be fixed if it affects efficiency.
36326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    bool has_known_shape = false;
36426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    OP_REQUIRES_OK_BOOLEAN(
36526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        ctx, GetNodeAttr(ctx->op_kernel().def(), "has_known_shape",
36626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower                         &has_known_shape));
36726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
36826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // Get input gradient tensors
36926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    const Tensor* grad_idx_tensor;
37026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    OP_REQUIRES_OK_BOOLEAN(ctx,
37126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower                           ctx->input("gradient_indices", &grad_idx_tensor));
37226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    const Tensor* grad_val_tensor;
37326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    OP_REQUIRES_OK_BOOLEAN(ctx,
37426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower                           ctx->input("gradient_values", &grad_val_tensor));
37526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    const Tensor* grad_shape_tensor = nullptr;
37626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    if (has_known_shape) {
37726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      OP_REQUIRES_OK_BOOLEAN(ctx,
37826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower                             ctx->input("gradient_shape", &grad_shape_tensor));
37926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    }
38026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
38126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // Checks
38226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    OP_REQUIRES_BOOLEAN(
38326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        ctx, TensorShapeUtils::IsVector(grad_idx_tensor->shape()),
38426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        errors::InvalidArgument(
38526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower            "Input indices should be vector but received shape: ",
38626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower            grad_idx_tensor->shape().DebugString()));
38726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    const int64 nnz = grad_idx_tensor->dim_size(0);
38826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    OP_REQUIRES_BOOLEAN(
38926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        ctx, grad_val_tensor->dims() > 0,
39026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        errors::InvalidArgument("Values cannot be 0-dimensional."));
39126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    OP_REQUIRES_BOOLEAN(ctx, grad_val_tensor->dim_size(0) == nnz,
39226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower                        errors::InvalidArgument("Expected ", nnz,
39326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower                                                " non-empty input values, got ",
39426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower                                                grad_val_tensor->dim_size(0)));
39526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
39626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    *tensor = new std::tuple<const Tensor*, const Tensor*, const Tensor*>(
39726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        grad_idx_tensor, grad_val_tensor, grad_shape_tensor);
39826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
39926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    OP_REQUIRES_OK_BOOLEAN(ctx, this->ValidateShape(*tensor, has_known_shape));
40026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
40126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    return true;
40226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  }
40326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
40426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  void CleanUpGradTensor(std::tuple<const Tensor*, const Tensor*,
40526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower                                    const Tensor*>* tensor) override {
40626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    if (tensor != nullptr) delete tensor;
40726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  }
40826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
40926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower private:
41026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  inline int cmp(std::vector<int64>* a_idx, const Tensor* b_idx,
41126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower                 const int64 a_row, const int64 b_row) {
41226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    const int64 a = a_idx->at(a_row);
41326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    const int64 b = b_idx->vec<int64>()(b_row);
41426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    if (a < b) {
41526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      return -1;
41626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    } else if (a > b) {
41726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      return 1;
41826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    }
41926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    return 0;
42026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  }
42126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
42226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  inline bool ReturnIdxTensor(OpKernelContext* ctx) {
42326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    Tensor* idx_tensor;
42426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    const int64 nnz = accum_idx_vec_->size();
42526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    OP_REQUIRES_OK_BOOLEAN(ctx, ctx->allocate_output(0, {nnz}, &idx_tensor));
42626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // If allocate_output fails, OP_REQUIRES_OK_BOOLEAN will short-circuit
42726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // the remaining code and just return false
42826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    auto idx_tensor_vec = idx_tensor->vec<int64>();
42926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    for (int i = 0; i < nnz; ++i) {
43026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      idx_tensor_vec(i) = accum_idx_vec_->at(i);
43126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    }
43226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    return true;
43326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  }
43426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
43526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  inline bool ReturnValTensor(OpKernelContext* ctx) {
43626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    ctx->set_output(1, *accum_val_);
43726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    return true;
43826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  }
43926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
44026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  inline bool ReturnShapeTensor(OpKernelContext* ctx) {
44126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    int64 accum_val_dims = accum_val_->dims();
44226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    Tensor* shape_tensor;
44326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    OP_REQUIRES_OK_BOOLEAN(
44426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        ctx, ctx->allocate_output(2, {accum_val_dims}, &shape_tensor));
44526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // If allocate_output fails, OP_REQUIRES_OK_BOOLEAN will short-circuit
44626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // the remaining code and just return false
44726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
44826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    // First dim of shape is defined by shape_, others by accum_val_->shape
44926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    shape_tensor->flat<int64>()(0) =
45026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower        (shape_.dims() > 0) ? shape_.dim_size(0) : -1;
45126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    for (int64 i = 1; i < accum_val_dims; i++) {
45226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower      shape_tensor->flat<int64>()(i) = accum_val_->dim_size(i);
45326d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    }
45426d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower    return true;
45526d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  }
45626d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
45726d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower  TF_DISALLOW_COPY_AND_ASSIGN(SparseConditionalAccumulator);
45826d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower};
45926d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
46026d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower}  // namespace tensorflow
46126d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower
46226d51ae7aae4cc8fe4807bde1aeec86f0e0c0fc2A. Unique TensorFlower#endif  // TENSORFLOW_KERNELS_SPARSE_CONDITIONAL_ACCUMULATOR_H_
463