1// Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// ============================================================================= 15#include "tensorflow/contrib/tensor_forest/kernels/v4/params.h" 16#include <math.h> 17#include <stdlib.h> 18#include "tensorflow/core/platform/logging.h" 19 20namespace tensorflow { 21namespace tensorforest { 22 23float ResolveParam(const DepthDependentParam& param, int32 depth) { 24 float val; 25 switch (param.ParamType_case()) { 26 case DepthDependentParam::kConstantValue: 27 return param.constant_value(); 28 29 case DepthDependentParam::kLinear: 30 val = depth * param.linear().slope() + param.linear().y_intercept(); 31 return std::min(std::max(val, param.linear().min_val()), 32 param.linear().max_val()); 33 34 case DepthDependentParam::kExponential: 35 return param.exponential().bias() + 36 param.exponential().multiplier() * 37 static_cast<float>( 38 pow(param.exponential().base(), 39 param.exponential().depth_multiplier() * depth)); 40 41 case DepthDependentParam::kThreshold: 42 if (depth >= param.threshold().threshold()) { 43 return param.threshold().on_value(); 44 } else { 45 return param.threshold().off_value(); 46 } 47 48 default: 49 LOG(FATAL) << "unknown parameter type"; 50 } 51} 52 53} // namespace tensorforest 54} // namespace tensorflow 55