// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
15f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_STAT_UTILS_H_
16f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_STAT_UTILS_H_
171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h"
181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/core/platform/types.h"
191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowernamespace tensorflow {
211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowernamespace tensorforest {
221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Returns the smoothed, unweighted Gini impurity.
241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat GiniImpurity(const LeafStat& stats, int32 num_classes);
251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Returns the smoothed, weighted Gini impurity
271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat WeightedGiniImpurity(const LeafStat& stats, int32 num_classes);
281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Updates the GiniStats given the old and new values of a class count that
301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// was updated.
311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid UpdateGini(LeafStat* stats, float old_val, float weight);
321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Returns the variance in stats for the given output.
341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat Variance(const LeafStat& stats, int output);
351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Returns the variance sum for all outputs.
371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat TotalVariance(const LeafStat& stats);
381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// ------- functions used by C++ stats classes  -------- //
401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Returns the smoothed gini score given the sum and sum of the squares of the
411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// class counts.
421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat SmoothedGini(float sum, float square, int num_classes);
431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Returns the smoothed gini score weighted by the sum.
451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat WeightedSmoothedGini(float sum, float square, int num_classes);
461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}  // namespace tensorforest
481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}  // namespace tensorflow
491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
50f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#endif  // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_STAT_UTILS_H_
51