12098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
22098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng
32098b9abcf20d2c9694055bbfd6997bc00b73578Yifei FengLicensed under the Apache License, Version 2.0 (the "License");
42098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Fengyou may not use this file except in compliance with the License.
52098b9abcf20d2c9694055bbfd6997bc00b73578Yifei FengYou may obtain a copy of the License at
62098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng
72098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng    http://www.apache.org/licenses/LICENSE-2.0
82098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng
92098b9abcf20d2c9694055bbfd6997bc00b73578Yifei FengUnless required by applicable law or agreed to in writing, software
102098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Fengdistributed under the License is distributed on an "AS IS" BASIS,
112098b9abcf20d2c9694055bbfd6997bc00b73578Yifei FengWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122098b9abcf20d2c9694055bbfd6997bc00b73578Yifei FengSee the License for the specific language governing permissions and
132098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Fenglimitations under the License.
142098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng==============================================================================*/
152098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng
162098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng#ifndef TENSORFLOW_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
172098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng#define TENSORFLOW_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
182098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng
192098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng#include "tensorflow/core/kernels/conditional_accumulator_base.h"
202098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng
212098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Fengnamespace tensorflow {
222098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng
232098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng/*
242098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng * TypedConditionalAccumulatorBase is a templated companion of
25cc2b908fa6f2acf06ffccf341ad2af8cd24aa12fTaehoon Lee * ConditionalAccumulatorBase which allows for subclasses to use different
262098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng * types for the input gradients. (See ConditionalAccumulator and
272098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng * SparseConditionalAccumulator.)
282098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng *
292098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng * TypedConditionalAccumulatorBase defines virtual methods and implements
302098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng * methods which depend on the gradient type. These are mainly methods that are
312098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng * used for adding a new gradient to the accumulator.
322098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng */
332098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Fengtemplate <typename GradientTensorType>
342098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Fengclass TypedConditionalAccumulatorBase : public ConditionalAccumulatorBase {
352098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng public:
362098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng  TypedConditionalAccumulatorBase(const DataType& dtype,
372098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng                                  const PartialTensorShape& shape,
382098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng                                  const string& name)
392098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng      : ConditionalAccumulatorBase(dtype, shape, name) {}
402098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng
412098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng  /**
422098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng   * Attempts to add a gradient to the accumulator. An ApplyGrad attempt is
432098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng   * successful (i.e., has its gradient applied) if its local_step >=
442098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng   * current_global_step_ at the time the attempt is processed. Otherwise, if
452098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng   * local_step < current_global_step_, the stale gradient is silently dropped.
462098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng   *
472098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng   * local_step: Time-step at which the gradient was computed.
482098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng   * grad:       Gradient tensor to be added to the accumulator.
492098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng   * ctx:        Context in which the op is executed.
502098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng   */
512098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng  void TryApplyGrad(int64 local_step, OpKernelContext* ctx) override {
522098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng    {
532098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng      mutex_lock l(mu_);
542098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng      if (local_step >= current_global_step_) {
552098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng        GradientTensorType* grad = nullptr;
562098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng        bool is_valid = GetAndValidateTensorInputForApplyGrad(ctx, &grad);
572098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng        if (is_valid) {
582098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng          if (counter_ > 0) {
592098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng            AddToAccumGradFunction(ctx, grad);
602098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng          } else {
612098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng            AllocateAndAssignToAccumGradFunction(ctx, grad);
622098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng          }
632098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng          counter_++;
642098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng        }
652098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng        CleanUpGradTensor(grad);
662098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng      }
672098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng    }
682098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng    FlushUnlocked();
692098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng  }
702098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng
712098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng protected:
722098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng  // Virtual methods to be implemented by sub-classes for different datatypes.
732098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng  // Implements arithmetic operations specific to datatype.
742098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng  virtual void AllocateAndAssignToAccumGradFunction(
752098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng      OpKernelContext* ctx, GradientTensorType* grad) = 0;
762098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng
772098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng  virtual void AddToAccumGradFunction(OpKernelContext* ctx,
782098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng                                      GradientTensorType* grad) = 0;
792098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng
802098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng  // Method for extracting and validating input provided in an OpKernelContext.
812098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng  // Returns true if input was successfully retrieved and is valid.
822098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng  // Gradient is returned via the GradientTensorType** tensor.
832098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng  virtual bool GetAndValidateTensorInputForApplyGrad(
842098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng      OpKernelContext* ctx, GradientTensorType** tensor)
852098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng      EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0;
862098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng
872098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng  // Method for cleaning up any memory allocated in
882098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng  // GetAndValidateTensorInputForApplyGrad
892098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng  virtual void CleanUpGradTensor(GradientTensorType* tensor) = 0;
902098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng};
912098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng
922098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng}  // namespace tensorflow
932098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng
942098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng#endif  // TENSORFLOW_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
95