1# encoding: utf-8
2# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Categorical tests."""
17
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23import numpy as np
24
25from tensorflow.contrib.learn.python.learn.learn_io import HAS_PANDAS
26from tensorflow.contrib.learn.python.learn.preprocessing import categorical
27from tensorflow.python.platform import test
28
29
class CategoricalTest(test.TestCase):
  """Tests for `categorical.CategoricalProcessor`.

  The expectations below show category ids starting at 1, with missing (NaN)
  inputs mapped to id 0.
  """

  def testSingleCategoricalProcessor(self):
    """Single column of mixed str/int/NaN values with min_frequency=1."""
    cat_processor = categorical.CategoricalProcessor(min_frequency=1)
    # Both float("nan") and np.nan appear so each NaN spelling is exercised.
    x = cat_processor.fit_transform([["0"], [1], [float("nan")], ["C"], ["C"],
                                     [1], ["0"], [np.nan], [3]])
    # NOTE(review): 3 occurs only once and is expected to map to 0 like the
    # NaN rows; presumably values seen no more than `min_frequency` times are
    # dropped — confirm against CategoricalProcessor.
    self.assertAllEqual(list(x), [[2], [1], [0], [3], [3], [1], [2], [0], [0]])

  def testSingleCategoricalProcessorPandasSingleDF(self):
    """A single-column pandas DataFrame is accepted as input."""
    if not HAS_PANDAS:
      # Report a skip instead of letting the test pass with no assertions.
      self.skipTest("pandas is not installed")
    import pandas as pd  # pylint: disable=g-import-not-at-top
    cat_processor = categorical.CategoricalProcessor()
    data = pd.DataFrame({"Gender": ["Male", "Female", "Male"]})
    x = list(cat_processor.fit_transform(data))
    self.assertAllEqual(x, [[1], [2], [1]])

  def testMultiCategoricalProcessor(self):
    """Two columns with share=False; expected ids restart at 1 per column."""
    cat_processor = categorical.CategoricalProcessor(
        min_frequency=0, share=False)
    x = cat_processor.fit_transform([["0", "Male"], [1, "Female"],
                                     ["3", "Male"]])
    self.assertAllEqual(list(x), [[1, 1], [2, 2], [3, 1]])
53
54
if __name__ == "__main__":
  # Run the test cases in this module through TensorFlow's test runner.
  test.main()
57