1// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// =============================================================================
15#include "tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h"
16#include <cfloat>
17
18#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
19
20namespace tensorflow {
21namespace tensorforest {
22
// When using smoothing but only tracking sum and squares, and we're adding
// num_classes for smoothing (one per class), then Gini looks more like this:
//   Gini = 1 - \sum_i (c_i + 1)^2 / C^2
//   = 1 - (1 / C^2) ( (\sum_i c_i^2) + 2 (\sum_i c_i) + (\sum_i 1))
//   = 1 - (1 / C^2) ( stats.square() + 2 stats.sum() + #_classes)
//   = 1 - ( stats.square() + 2 stats.sum() + #_classes) / (smoothed_sum *
//                                                          smoothed_sum)
//
//   where
//   C = smoothed_sum = stats.sum() + #_classes
33float GiniImpurity(const LeafStat& stats, int32 num_classes) {
34  const float smoothed_sum = num_classes + stats.weight_sum();
35  return 1.0 - ((stats.classification().gini().square() +
36                 2 * stats.weight_sum() + num_classes) /
37                (smoothed_sum * smoothed_sum));
38}
39
40float WeightedGiniImpurity(const LeafStat& stats, int32 num_classes) {
41  return stats.weight_sum() * GiniImpurity(stats, num_classes);
42}
43
44void UpdateGini(LeafStat* stats, float old_val, float weight) {
45  stats->set_weight_sum(stats->weight_sum() + weight);
46  // Equivalent to stats->square() - old_val * old_val + new_val * new_val,
47  // (for new_val = old_val + weight), but more numerically stable.
48  stats->mutable_classification()->mutable_gini()->set_square(
49      stats->classification().gini().square() + weight * weight +
50      2 * old_val * weight);
51}
52
53float Variance(const LeafStat& stats, int output) {
54  if (stats.weight_sum() == 0) {
55    return 0;
56  }
57  const float e_x =
58      stats.regression().mean_output().value(output).float_value() /
59      stats.weight_sum();
60  const auto e_x2 =
61      stats.regression().mean_output_squares().value(output).float_value() /
62      stats.weight_sum();
63  return e_x2 - e_x * e_x;
64}
65
66float TotalVariance(const LeafStat& stats) {
67  float sum = 0;
68  for (int i = 0; i < stats.regression().mean_output().value_size(); ++i) {
69    sum += Variance(stats, i);
70  }
71  return sum;
72}
73
// Returns the smoothed Gini value
//   1 - (square + 2 * sum + num_classes) / (sum + num_classes)^2
// where `sum` and `square` are the tracked sum and sum of squares, and
// num_classes is the amount added for smoothing.
float SmoothedGini(float sum, float square, int num_classes) {
  const float smoothed_sum = sum + num_classes;
  const float smoothed_squares = square + 2 * sum + num_classes;
  const float squares_fraction =
      smoothed_squares / (smoothed_sum * smoothed_sum);
  return 1.0 - squares_fraction;
}
79
80float WeightedSmoothedGini(float sum, float square, int num_classes) {
81  return sum * SmoothedGini(sum, square, num_classes);
82}
83
84}  // namespace tensorforest
85}  // namespace tensorflow
86