12098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 22098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng 32098b9abcf20d2c9694055bbfd6997bc00b73578Yifei FengLicensed under the Apache License, Version 2.0 (the "License"); 42098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Fengyou may not use this file except in compliance with the License. 52098b9abcf20d2c9694055bbfd6997bc00b73578Yifei FengYou may obtain a copy of the License at 62098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng 72098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng http://www.apache.org/licenses/LICENSE-2.0 82098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng 92098b9abcf20d2c9694055bbfd6997bc00b73578Yifei FengUnless required by applicable law or agreed to in writing, software 102098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Fengdistributed under the License is distributed on an "AS IS" BASIS, 112098b9abcf20d2c9694055bbfd6997bc00b73578Yifei FengWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 122098b9abcf20d2c9694055bbfd6997bc00b73578Yifei FengSee the License for the specific language governing permissions and 132098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Fenglimitations under the License. 142098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng==============================================================================*/ 152098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng 162098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng#ifndef TENSORFLOW_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_ 172098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng#define TENSORFLOW_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_ 182098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng 192098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng#include "tensorflow/core/kernels/conditional_accumulator_base.h" 202098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng 212098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Fengnamespace tensorflow { 222098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng 232098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng/* 242098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng * TypedConditionalAccumulatorBase is a templated companion of 25cc2b908fa6f2acf06ffccf341ad2af8cd24aa12fTaehoon Lee * ConditionalAccumulatorBase which allows for subclasses to use different 262098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng * types for the input gradients. (See ConditionalAccumulator and 272098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng * SparseConditionalAccumulator.) 282098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng * 292098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng * TypedConditionalAccumulatorBase defines virtual methods and implements 302098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng * methods which depend on the gradient type. These are mainly methods that are 312098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng * used for adding a new gradient to the accumulator. 322098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng */ 332098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Fengtemplate <typename GradientTensorType> 342098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Fengclass TypedConditionalAccumulatorBase : public ConditionalAccumulatorBase { 352098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng public: 362098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng TypedConditionalAccumulatorBase(const DataType& dtype, 372098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng const PartialTensorShape& shape, 382098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng const string& name) 392098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng : ConditionalAccumulatorBase(dtype, shape, name) {} 402098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng 412098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng /** 422098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng * Attempts to add a gradient to the accumulator. An ApplyGrad attempt is 432098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng * successful (i.e., has its gradient applied) if its local_step >= 442098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng * current_global_step_ at the time the attempt is processed. Otherwise, if 452098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng * local_step < current_global_step_, the stale gradient is silently dropped. 462098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng * 472098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng * local_step: Time-step at which the gradient was computed. 482098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng * grad: Gradient tensor to be added to the accumulator. 492098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng * ctx: Context in which the op is executed. 502098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng */ 512098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng void TryApplyGrad(int64 local_step, OpKernelContext* ctx) override { 522098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng { 532098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng mutex_lock l(mu_); 542098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng if (local_step >= current_global_step_) { 552098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng GradientTensorType* grad = nullptr; 562098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng bool is_valid = GetAndValidateTensorInputForApplyGrad(ctx, &grad); 572098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng if (is_valid) { 582098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng if (counter_ > 0) { 592098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng AddToAccumGradFunction(ctx, grad); 602098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng } else { 612098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng AllocateAndAssignToAccumGradFunction(ctx, grad); 622098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng } 632098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng counter_++; 642098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng } 652098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng CleanUpGradTensor(grad); 662098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng } 672098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng } 682098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng FlushUnlocked(); 692098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng } 702098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng 712098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng protected: 722098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng // Virtual methods to be implemented by sub-classes for different datatypes. 732098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng // Implements arithmetic operations specific to datatype. 742098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng virtual void AllocateAndAssignToAccumGradFunction( 752098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng OpKernelContext* ctx, GradientTensorType* grad) = 0; 762098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng 772098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng virtual void AddToAccumGradFunction(OpKernelContext* ctx, 782098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng GradientTensorType* grad) = 0; 792098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng 802098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng // Method for extracting and validating input provided in an OpKernelContext. 812098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng // Returns true if input was successfully retrieved and is valid. 822098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng // Gradient is returned via the GradientTensorType** tensor. 832098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng virtual bool GetAndValidateTensorInputForApplyGrad( 842098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng OpKernelContext* ctx, GradientTensorType** tensor) 852098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0; 862098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng 872098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng // Method for cleaning up any memory allocated in 882098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng // GetAndValidateTensorInputForApplyGrad 892098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng virtual void CleanUpGradTensor(GradientTensorType* tensor) = 0; 902098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng}; 912098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng 922098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng} // namespace tensorflow 932098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng 942098b9abcf20d2c9694055bbfd6997bc00b73578Yifei Feng#endif // TENSORFLOW_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_ 95