1# encoding: utf-8 2# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Categorical tests.""" 17 18# limitations under the License. 19from __future__ import absolute_import 20from __future__ import division 21from __future__ import print_function 22 23import numpy as np 24 25from tensorflow.contrib.learn.python.learn.learn_io import HAS_PANDAS 26from tensorflow.contrib.learn.python.learn.preprocessing import categorical 27from tensorflow.python.platform import test 28 29 30class CategoricalTest(test.TestCase): 31 """Categorical tests.""" 32 33 def testSingleCategoricalProcessor(self): 34 cat_processor = categorical.CategoricalProcessor(min_frequency=1) 35 x = cat_processor.fit_transform([["0"], [1], [float("nan")], ["C"], ["C"], 36 [1], ["0"], [np.nan], [3]]) 37 self.assertAllEqual(list(x), [[2], [1], [0], [3], [3], [1], [2], [0], [0]]) 38 39 def testSingleCategoricalProcessorPandasSingleDF(self): 40 if HAS_PANDAS: 41 import pandas as pd # pylint: disable=g-import-not-at-top 42 cat_processor = categorical.CategoricalProcessor() 43 data = pd.DataFrame({"Gender": ["Male", "Female", "Male"]}) 44 x = list(cat_processor.fit_transform(data)) 45 self.assertAllEqual(list(x), [[1], [2], [1]]) 46 47 def testMultiCategoricalProcessor(self): 48 cat_processor = categorical.CategoricalProcessor( 49 min_frequency=0, share=False) 50 x = cat_processor.fit_transform([["0", "Male"], [1, "Female"], 51 ["3", "Male"]]) 52 self.assertAllEqual(list(x), [[1, 1], [2, 2], [3, 1]]) 53 54 55if __name__ == "__main__": 56 test.main() 57