1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for cross entropy related functionality in tensorflow.ops.nn."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import math
22
23import numpy as np
24
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.ops import gradient_checker
28from tensorflow.python.ops import gradients_impl
29from tensorflow.python.ops import nn_impl
30import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
31from tensorflow.python.platform import test
32
# Short module-level aliases for math.exp and math.log.
exp, log = math.exp, math.log
35
36
class SigmoidCrossEntropyWithLogitsTest(test.TestCase):
  """Tests for nn_impl.sigmoid_cross_entropy_with_logits."""

  def _SigmoidCrossEntropyWithLogits(self, logits, targets):
    """Reference per-element sigmoid cross-entropy computed in pure Python.

    Evaluates -z * log(sigmoid(x)) - (1 - z) * log(1 - sigmoid(x)) in the
    algebraically equivalent form z * softplus(-x) + (1 - z) * softplus(x).
    This form never overflows `exp` for large-magnitude logits and needs no
    probability clipping, so the reference stays exact for saturated logits.

    Args:
      logits: sequence of logit values x.
      targets: sequence of target values z, same length as `logits`.

    Returns:
      A list with one loss value per (x, z) pair.
    """
    assert len(logits) == len(targets)

    def softplus(v):
      # log(1 + exp(v)) without overflow: exp(-|v|) is always <= 1.
      return max(v, 0) + math.log1p(exp(-abs(v)))

    return [z * softplus(-x) + (1 - z) * softplus(x)
            for x, z in zip(logits, targets)]

  def _Inputs(self, x=None, y=None, dtype=dtypes.float64, sizes=None):
    """Builds logits/targets constants plus the reference losses.

    Args:
      x: logits; defaults to a fixed 8-element list that includes negative,
        zero, positive and large-magnitude (+/-100) values.
      y: targets, same length as `x`; defaults to a fixed 8-element list.
      dtype: dtype of the created constants.
      sizes: shape of the constants and reference losses; defaults to
        [len(x)].

    Returns:
      (logits, targets, losses), where losses is a numpy array of shape
      `sizes`.
    """
    x = [-100, -2, -2, 0, 2, 2, 2, 100] if x is None else x
    y = [0, 0, 1, 0, 0, 1, 0.5, 1] if y is None else y
    assert len(x) == len(y)
    sizes = sizes if sizes else [len(x)]
    logits = constant_op.constant(x, shape=sizes, dtype=dtype, name="logits")
    targets = constant_op.constant(y, shape=sizes, dtype=dtype, name="targets")
    losses = np.array(self._SigmoidCrossEntropyWithLogits(x, y)).reshape(*sizes)
    return logits, targets, losses

  def testConstructionNamed(self):
    """The `name` argument becomes the name of the resulting op."""
    with self.test_session():
      logits, targets, _ = self._Inputs()
      loss = nn_impl.sigmoid_cross_entropy_with_logits(
          labels=targets, logits=logits, name="mylogistic")
    self.assertEqual("mylogistic", loss.op.name)

  def testLogisticOutput(self):
    """1-D output matches the reference for float32/float16, use_gpu on/off."""
    for use_gpu in [True, False]:
      for dtype in [dtypes.float32, dtypes.float16]:
        with self.test_session(use_gpu=use_gpu):
          logits, targets, losses = self._Inputs(dtype=dtype)
          loss = nn_impl.sigmoid_cross_entropy_with_logits(
              labels=targets, logits=logits)
          np_loss = np.array(losses).astype(np.float32)
          tf_loss = loss.eval()
        self.assertAllClose(np_loss, tf_loss, atol=0.001)

  def testLogisticOutputMultiDim(self):
    """Same as testLogisticOutput but with [2, 2, 2]-shaped inputs."""
    for use_gpu in [True, False]:
      for dtype in [dtypes.float32, dtypes.float16]:
        with self.test_session(use_gpu=use_gpu):
          logits, targets, losses = self._Inputs(dtype=dtype, sizes=[2, 2, 2])
          loss = nn_impl.sigmoid_cross_entropy_with_logits(
              labels=targets, logits=logits)
          np_loss = np.array(losses).astype(np.float32)
          tf_loss = loss.eval()
        self.assertAllClose(np_loss, tf_loss, atol=0.001)

  def testGradient(self):
    """Gradient error from gradient_checker stays below 1e-7 (float64)."""
    sizes = [4, 2]
    with self.test_session():
      logits, targets, _ = self._Inputs(sizes=sizes)
      loss = nn_impl.sigmoid_cross_entropy_with_logits(
          labels=targets, logits=logits)
      err = gradient_checker.compute_gradient_error(logits, sizes, loss, sizes)
    print("logistic loss gradient err = ", err)
    self.assertLess(err, 1e-7)

  def testGradientAtZero(self):
    """At logit 0 the gradient is sigmoid(0) - z, i.e. 0.5 - z."""
    with self.test_session():
      logits = constant_op.constant([0.0, 0.0], dtype=dtypes.float64)
      targets = constant_op.constant([0.0, 1.0], dtype=dtypes.float64)
      loss = nn_impl.sigmoid_cross_entropy_with_logits(
          labels=targets, logits=logits)
      grads = gradients_impl.gradients(loss, logits)[0].eval()
    self.assertAllClose(grads, [0.5, -0.5])

  def testShapeError(self):
    """Mismatched labels/logits shapes raise a ValueError."""
    with self.assertRaisesRegexp(ValueError, "must have the same shape"):
      nn_impl.sigmoid_cross_entropy_with_logits(labels=[1, 2, 3],
                                                logits=[[2, 1]])
108
109
class WeightedCrossEntropyTest(test.TestCase):
  """Tests for nn_impl.weighted_cross_entropy_with_logits."""

  def _WeightedCrossEntropy(self, logits, targets, pos_coeff):
    """Reference per-element weighted sigmoid cross-entropy in pure Python.

    Evaluates -z * q * log(sigmoid(x)) - (1 - z) * log(1 - sigmoid(x)), with
    q = `pos_coeff`, in the algebraically equivalent form
    z * q * softplus(-x) + (1 - z) * softplus(x). This form never overflows
    `exp` for large-magnitude logits and needs no probability clipping, so the
    reference stays exact for saturated logits.

    Args:
      logits: sequence of logit values x.
      targets: sequence of target values z, same length as `logits`.
      pos_coeff: weight q applied to the positive-target term.

    Returns:
      A list with one loss value per (x, z) pair.
    """
    assert len(logits) == len(targets)

    def softplus(v):
      # log(1 + exp(v)) without overflow: exp(-|v|) is always <= 1.
      return max(v, 0) + math.log1p(exp(-abs(v)))

    return [
        z * pos_coeff * softplus(-x) + (1 - z) * softplus(x)
        for x, z in zip(logits, targets)
    ]

  def _Inputs(self, x=None, y=None, q=3.0, dtype=dtypes.float64, sizes=None):
    """Builds logits/targets constants plus the reference losses.

    Args:
      x: logits; defaults to a fixed 8-element list that includes negative,
        zero, positive and large-magnitude (+/-100) values.
      y: targets, same length as `x`; defaults to a fixed 8-element list.
      q: positive-target weight; passed to the reference and returned as-is.
      dtype: dtype of the created constants.
      sizes: shape of the constants and reference losses; defaults to
        [len(x)].

    Returns:
      (logits, targets, q, losses), where losses is a numpy array of shape
      `sizes`.
    """
    x = [-100, -2, -2, 0, 2, 2, 2, 100] if x is None else x
    y = [0, 0, 1, 0, 0, 1, 0.5, 1] if y is None else y
    assert len(x) == len(y)
    sizes = sizes if sizes else [len(x)]
    logits = constant_op.constant(x, shape=sizes, dtype=dtype, name="logits")
    targets = constant_op.constant(y, shape=sizes, dtype=dtype, name="targets")
    losses = np.array(self._WeightedCrossEntropy(x, y, q)).reshape(*sizes)
    return logits, targets, q, losses

  def testConstructionNamed(self):
    """The `name` argument becomes the name of the resulting op."""
    with self.test_session():
      logits, targets, pos_weight, _ = self._Inputs()
      loss = nn_impl.weighted_cross_entropy_with_logits(
          targets=targets, logits=logits, pos_weight=pos_weight, name="mybce")
    self.assertEqual("mybce", loss.op.name)

  def testOutput(self):
    """1-D float32 output matches the reference with use_gpu on and off."""
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu):
        logits, targets, pos_weight, losses = self._Inputs(dtype=dtypes.float32)
        loss = nn_impl.weighted_cross_entropy_with_logits(
            targets=targets, logits=logits, pos_weight=pos_weight)
        np_loss = np.array(losses).astype(np.float32)
        tf_loss = loss.eval()
      self.assertAllClose(np_loss, tf_loss, atol=0.001)

  def testOutputMultiDim(self):
    """Same as testOutput but with [2, 2, 2]-shaped inputs."""
    for use_gpu in [True, False]:
      with self.test_session(use_gpu=use_gpu):
        logits, targets, pos_weight, losses = self._Inputs(
            dtype=dtypes.float32, sizes=[2, 2, 2])
        loss = nn_impl.weighted_cross_entropy_with_logits(
            targets=targets, logits=logits, pos_weight=pos_weight)
        np_loss = np.array(losses).astype(np.float32)
        tf_loss = loss.eval()
      self.assertAllClose(np_loss, tf_loss, atol=0.001)

  def testGradient(self):
    """Gradient error from gradient_checker stays below 1e-7 (float64)."""
    sizes = [4, 2]
    with self.test_session():
      logits, targets, pos_weight, _ = self._Inputs(sizes=sizes)
      loss = nn_impl.weighted_cross_entropy_with_logits(
          targets=targets, logits=logits, pos_weight=pos_weight)
      err = gradient_checker.compute_gradient_error(logits, sizes, loss, sizes)
    print("logistic loss gradient err = ", err)
    self.assertLess(err, 1e-7)

  def testShapeError(self):
    """Mismatched targets/logits shapes raise a ValueError."""
    with self.assertRaisesRegexp(ValueError, "must have the same shape"):
      nn_impl.weighted_cross_entropy_with_logits(
          targets=[1, 2, 3], logits=[[2, 1]], pos_weight=2.0)
174
175
if __name__ == "__main__":
  # Only when this file is executed directly (not on import): hand control to
  # the runner from tensorflow.python.platform.test.
  test.main()
178