11588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Copyright 2017 The TensorFlow Authors. All Rights Reserved. 21588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// 31588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Licensed under the Apache License, Version 2.0 (the "License"); 41588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// you may not use this file except in compliance with the License. 51588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// You may obtain a copy of the License at 61588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// 71588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// http://www.apache.org/licenses/LICENSE-2.0 81588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// 91588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Unless required by applicable law or agreed to in writing, software 101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// distributed under the License is distributed on an "AS IS" BASIS, 111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// See the License for the specific language governing permissions and 131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// limitations under the License. 141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// ============================================================================= 15f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_STAT_UTILS_H_ 16f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_STAT_UTILS_H_ 171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h" 181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/core/platform/types.h" 191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowernamespace tensorflow { 211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowernamespace tensorforest { 221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Returns the smoothed, unweighted Gini impurity. 241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat GiniImpurity(const LeafStat& stats, int32 num_classes); 251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Returns the smoothed, weighted Gini impurity 271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat WeightedGiniImpurity(const LeafStat& stats, int32 num_classes); 281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Updates the GiniStats given the old and new values of a class count that 301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// was updated. 311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid UpdateGini(LeafStat* stats, float old_val, float weight); 321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Returns the variance in stats for the given output. 341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat Variance(const LeafStat& stats, int output); 351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Returns the variance sum for all outputs. 371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat TotalVariance(const LeafStat& stats); 381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// ------- functions used by C++ stats classes -------- // 401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Returns the smoothed gini score given the sum and sum of the squares of the 411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// class counts. 421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat SmoothedGini(float sum, float square, int num_classes); 431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Returns the smoothed gini score weighted by the sum. 451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat WeightedSmoothedGini(float sum, float square, int num_classes); 461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} // namespace tensorforest 481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} // namespace tensorflow 491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 50f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_STAT_UTILS_H_ 51