1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for feature_column."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import copy
23
24import numpy as np
25
26from tensorflow.core.example import example_pb2
27from tensorflow.core.example import feature_pb2
28from tensorflow.python.client import session
29from tensorflow.python.eager import backprop
30from tensorflow.python.eager import context
31from tensorflow.python.estimator.inputs import numpy_io
32from tensorflow.python.feature_column import feature_column_lib as fc
33from tensorflow.python.feature_column.feature_column import _CategoricalColumn
34from tensorflow.python.feature_column.feature_column import _DenseColumn
35from tensorflow.python.feature_column.feature_column import _FeatureColumn
36from tensorflow.python.feature_column.feature_column import _LazyBuilder
37from tensorflow.python.feature_column.feature_column import _transform_features
38from tensorflow.python.feature_column.feature_column import InputLayer
39from tensorflow.python.framework import constant_op
40from tensorflow.python.framework import dtypes
41from tensorflow.python.framework import errors
42from tensorflow.python.framework import ops
43from tensorflow.python.framework import sparse_tensor
44from tensorflow.python.framework import test_util
45from tensorflow.python.ops import array_ops
46from tensorflow.python.ops import lookup_ops
47from tensorflow.python.ops import parsing_ops
48from tensorflow.python.ops import partitioned_variables
49from tensorflow.python.ops import variable_scope
50from tensorflow.python.ops import variables as variables_lib
51from tensorflow.python.platform import test
52from tensorflow.python.training import coordinator
53from tensorflow.python.training import queue_runner_impl
54
55
def _initialized_session():
  """Creates a session and runs the variable and table initializers in it.

  The global variables initializer is run first, followed by the lookup
  tables initializer.

  Returns:
    The open session. The caller is responsible for closing it, for example
    by using it as a context manager.

  Raises:
    Exception: Any error raised while running the initializers is re-raised
      after the session has been closed.
  """
  sess = session.Session()
  try:
    sess.run(variables_lib.global_variables_initializer())
    sess.run(lookup_ops.tables_initializer())
  except Exception:
    # The caller never receives the session when initialization fails, so it
    # has to be closed here to avoid leaking it.
    sess.close()
    raise
  return sess
61
62
class LazyColumnTest(test.TestCase):
  """Tests for `_LazyBuilder.get` with feature-column and string keys."""

  def test_transormations_called_once(self):
    """Fetching the same column twice transforms it only once."""

    class TransformCounter(_FeatureColumn):
      """Column that tallies the calls made to `_transform_feature`."""

      def __init__(self):
        self.num_transform = 0

      @property
      def name(self):
        return 'TransformCounter'

      def _transform_feature(self, cache):
        # Tally every invocation so that the test can spot repeated calls.
        self.num_transform = self.num_transform + 1
        return cache.get('a')

      @property
      def _parse_example_spec(self):
        return None

    lazy_builder = _LazyBuilder(features={'a': [[2], [3.]]})
    counter = TransformCounter()
    self.assertEqual(0, counter.num_transform)
    # The tally reaches one on the first lookup and must stay there on the
    # second one.
    for _ in range(2):
      lazy_builder.get(counter)
      self.assertEqual(1, counter.num_transform)

  def test_returns_transform_output(self):
    """`get` returns whatever the column's `_transform_feature` returned."""

    class Transformer(_FeatureColumn):
      """Column whose transformation is the constant string 'Output'."""

      @property
      def name(self):
        return 'Transformer'

      def _transform_feature(self, cache):
        return 'Output'

      @property
      def _parse_example_spec(self):
        return None

    lazy_builder = _LazyBuilder(features={'a': [[2], [3.]]})
    transformer = Transformer()
    # Both the first and the repeated lookup must yield the same output.
    for _ in range(2):
      self.assertEqual('Output', lazy_builder.get(transformer))

  def test_does_not_pollute_given_features_dict(self):
    """Looking up a column leaves the caller's features dict unchanged."""

    class Transformer(_FeatureColumn):
      """Column whose transformation is the constant string 'Output'."""

      @property
      def name(self):
        return 'Transformer'

      def _transform_feature(self, cache):
        return 'Output'

      @property
      def _parse_example_spec(self):
        return None

    raw_features = {'a': [[2], [3.]]}
    lazy_builder = _LazyBuilder(features=raw_features)
    lazy_builder.get(Transformer())
    # Only the key that was passed in may be present afterwards.
    self.assertEqual(['a'], list(raw_features))

  def test_error_if_feature_is_not_found(self):
    """A string key missing from the features raises a `ValueError`."""
    lazy_builder = _LazyBuilder(features={'a': [[2], [3.]]})
    expected_message = 'bbb is not in features dictionary'
    with self.assertRaisesRegexp(ValueError, expected_message):
      lazy_builder.get('bbb')

  def test_not_supported_feature_column(self):
    """A column whose transformation yields `None` raises a `ValueError`."""

    class NotAProperColumn(_FeatureColumn):
      """Column whose `_transform_feature` returns `None`."""

      @property
      def name(self):
        return 'NotAProperColumn'

      def _transform_feature(self, cache):
        # A proper column would return something other than None here.
        return None

      @property
      def _parse_example_spec(self):
        return None

    lazy_builder = _LazyBuilder(features={'a': [[2], [3.]]})
    expected_message = 'NotAProperColumn is not supported'
    with self.assertRaisesRegexp(ValueError, expected_message):
      lazy_builder.get(NotAProperColumn())

  def test_key_should_be_string_or_feature_colum(self):
    """A key that is neither a string nor a column raises a `TypeError`."""

    class NotAFeatureColumn(object):
      """Plain object that does not derive from `_FeatureColumn`."""

    lazy_builder = _LazyBuilder(features={'a': [[2], [3.]]})
    expected_message = '"key" must be either a "str" or "_FeatureColumn".'
    with self.assertRaisesRegexp(TypeError, expected_message):
      lazy_builder.get(NotAFeatureColumn())
168
169
class NumericColumnTest(test.TestCase):
  """Tests for columns created with `fc.numeric_column`."""

  def test_defaults(self):
    """Checks the attributes of a column built from a key alone."""
    column = fc.numeric_column('aaa')
    self.assertEqual('aaa', column.key)
    self.assertEqual('aaa', column.name)
    self.assertEqual('aaa', column._var_scope_name)
    self.assertEqual((1,), column.shape)
    self.assertIsNone(column.default_value)
    self.assertEqual(dtypes.float32, column.dtype)
    self.assertIsNone(column.normalizer_fn)

  def test_shape_saved_as_tuple(self):
    """A shape given as a list is exposed as a tuple."""
    column = fc.numeric_column(
        'aaa', shape=[1, 2], default_value=[[3, 2.]])
    self.assertEqual((1, 2), column.shape)

  def test_default_value_saved_as_tuple(self):
    """Scalar and nested-list default values are exposed as tuples."""
    scalar_default = fc.numeric_column('aaa', default_value=4.)
    self.assertEqual((4.,), scalar_default.default_value)
    nested_default = fc.numeric_column(
        'aaa', shape=[1, 2], default_value=[[3, 2.]])
    self.assertEqual(((3., 2.),), nested_default.default_value)

  def test_shape_and_default_value_compatibility(self):
    """A default value whose shape differs from `shape` is rejected."""
    shape_error = 'The shape of default_value'

    # One-dimensional shape: matching length passes, extra element fails.
    fc.numeric_column('aaa', shape=[2], default_value=[1, 2.])
    with self.assertRaisesRegexp(ValueError, shape_error):
      fc.numeric_column('aaa', shape=[2], default_value=[1, 2, 3.])

    # Two-dimensional shape: only [3, 2] matches the 3x2 default value.
    fc.numeric_column(
        'aaa', shape=[3, 2], default_value=[[2, 3], [1, 2], [2, 3.]])
    with self.assertRaisesRegexp(ValueError, shape_error):
      fc.numeric_column(
          'aaa', shape=[3, 1], default_value=[[2, 3], [1, 2], [2, 3.]])
    with self.assertRaisesRegexp(ValueError, shape_error):
      fc.numeric_column(
          'aaa', shape=[3, 3], default_value=[[2, 3], [1, 2], [2, 3.]])

  def test_default_value_type_check(self):
    """A default value incompatible with the dtype raises a `TypeError`."""
    # Compatible combinations are accepted.
    fc.numeric_column(
        'aaa', shape=[2], default_value=[1, 2.], dtype=dtypes.float32)
    fc.numeric_column(
        'aaa', shape=[2], default_value=[1, 2], dtype=dtypes.int32)

    # A float entry cannot be used with an int32 column.
    with self.assertRaisesRegexp(TypeError, 'must be compatible with dtype'):
      fc.numeric_column(
          'aaa', shape=[2], default_value=[1, 2.], dtype=dtypes.int32)

    # A string entry cannot be used with the default dtype.
    string_error = 'default_value must be compatible with dtype'
    with self.assertRaisesRegexp(TypeError, string_error):
      fc.numeric_column('aaa', default_value=['string'])

  def test_shape_must_be_positive_integer(self):
    """Non-integer and zero shape dimensions are rejected."""
    with self.assertRaisesRegexp(TypeError, 'shape dimensions must be integer'):
      fc.numeric_column('aaa', shape=[1.0])

    zero_error = 'shape dimensions must be greater than 0'
    with self.assertRaisesRegexp(ValueError, zero_error):
      fc.numeric_column('aaa', shape=[0])

  def test_dtype_is_convertible_to_float(self):
    """A string dtype raises a `ValueError`."""
    dtype_error = 'dtype must be convertible to float'
    with self.assertRaisesRegexp(ValueError, dtype_error):
      fc.numeric_column('aaa', dtype=dtypes.string)

  def test_scalar_default_value_fills_the_shape(self):
    """A scalar default value is repeated to fill the whole shape."""
    column = fc.numeric_column('aaa', shape=[2, 3], default_value=2.)
    expected_default = ((2., 2., 2.), (2., 2., 2.))
    self.assertEqual(expected_default, column.default_value)

  def test_parse_spec(self):
    """The parse spec is a `FixedLenFeature` with the column's shape/dtype."""
    column = fc.numeric_column('aaa', shape=[2, 3], dtype=dtypes.int32)
    expected_spec = {
        'aaa': parsing_ops.FixedLenFeature((2, 3), dtype=dtypes.int32)
    }
    self.assertEqual(expected_spec, column._parse_example_spec)

  def test_parse_example_no_default_value(self):
    """A serialized example holding 'price' parses to its float values."""
    price = fc.numeric_column('price', shape=[2])
    price_feature = feature_pb2.Feature(
        float_list=feature_pb2.FloatList(value=[20., 110.]))
    example = example_pb2.Example(
        features=feature_pb2.Features(feature={'price': price_feature}))
    parsed = parsing_ops.parse_example(
        serialized=[example.SerializeToString()],
        features=fc.make_parse_example_spec([price]))
    self.assertIn('price', parsed)
    with self.test_session():
      self.assertAllEqual([[20., 110.]], parsed['price'].eval())

  def test_parse_example_with_default_value(self):
    """An example lacking 'price' parses to the column's default value."""
    price = fc.numeric_column('price', shape=[2], default_value=11.)
    # The first example carries the 'price' feature.
    price_feature = feature_pb2.Feature(
        float_list=feature_pb2.FloatList(value=[20., 110.]))
    with_price = example_pb2.Example(
        features=feature_pb2.Features(feature={'price': price_feature}))
    # The second example only carries an unrelated feature.
    other_feature = feature_pb2.Feature(
        float_list=feature_pb2.FloatList(value=[20., 110.]))
    without_price = example_pb2.Example(
        features=feature_pb2.Features(
            feature={'something_else': other_feature}))
    parsed = parsing_ops.parse_example(
        serialized=[
            with_price.SerializeToString(),
            without_price.SerializeToString(),
        ],
        features=fc.make_parse_example_spec([price]))
    self.assertIn('price', parsed)
    with self.test_session():
      self.assertAllEqual([[20., 110.], [11., 11.]], parsed['price'].eval())

  def test_normalizer_fn_must_be_callable(self):
    """A non-callable `normalizer_fn` raises a `TypeError`."""
    with self.assertRaisesRegexp(TypeError, 'must be a callable'):
      fc.numeric_column('price', normalizer_fn='NotACallable')

  def test_normalizer_fn_transform_feature(self):
    """The transformed feature is the input passed through normalizer_fn."""

    def _increment_two(input_tensor):
      """Adds two to the given input."""
      return input_tensor + 2.

    price = fc.numeric_column('price', shape=[2], normalizer_fn=_increment_two)
    raw_features = {'price': [[1., 2.], [5., 6.]]}
    transformed = _transform_features(raw_features, [price])
    with self.test_session():
      self.assertAllEqual([[3., 4.], [7., 8.]], transformed[price].eval())

  def test_get_dense_tensor(self):
    """`_get_dense_tensor` equals what the builder returns for the column."""

    def _increment_two(input_tensor):
      """Adds two to the given input."""
      return input_tensor + 2.

    price = fc.numeric_column('price', shape=[2], normalizer_fn=_increment_two)
    lazy_builder = _LazyBuilder({'price': [[1., 2.], [5., 6.]]})
    from_builder = lazy_builder.get(price)
    from_column = price._get_dense_tensor(lazy_builder)
    self.assertEqual(from_builder, from_column)

  def test_sparse_tensor_not_supported(self):
    """Transforming a `SparseTensor` input raises a `ValueError`."""
    price = fc.numeric_column('price')
    sparse_price = sparse_tensor.SparseTensor(
        indices=[[0, 0]], values=[0.3], dense_shape=[1, 1])
    lazy_builder = _LazyBuilder({'price': sparse_price})
    with self.assertRaisesRegexp(ValueError, 'must be a Tensor'):
      price._transform_feature(lazy_builder)

  def test_deep_copy(self):
    """A deep copy keeps the name, shape and default value."""
    original = fc.numeric_column(
        'aaa', shape=[1, 2], default_value=[[3., 2.]])
    duplicate = copy.deepcopy(original)
    self.assertEqual(duplicate.name, 'aaa')
    self.assertEqual(duplicate.shape, (1, 2))
    self.assertEqual(duplicate.default_value, ((3., 2.),))

  def test_numpy_default_value(self):
    """A numpy array default value is exposed as nested tuples."""
    numpy_default = np.array([[3., 2.]])
    column = fc.numeric_column(
        'aaa', shape=[1, 2], default_value=numpy_default)
    self.assertEqual(column.default_value, ((3., 2.),))

  def test_linear_model(self):
    """`linear_model` predictions follow the column's weight variable."""
    price = fc.numeric_column('price')
    with ops.Graph().as_default():
      raw_features = {'price': [[1.], [5.]]}
      logits = fc.linear_model(raw_features, [price])
      bias_var = get_linear_model_bias()
      weight_var = get_linear_model_column_var(price)
      with _initialized_session() as sess:
        # Bias, weight and therefore predictions all start at zero.
        self.assertAllClose([0.], bias_var.eval())
        self.assertAllClose([[0.]], weight_var.eval())
        self.assertAllClose([[0.], [0.]], logits.eval())
        # With a weight of 10 each prediction is ten times the price.
        sess.run(weight_var.assign([[10.]]))
        self.assertAllClose([[10.], [50.]], logits.eval())
341
342
class BucketizedColumnTest(test.TestCase):
  """Tests for columns created with `fc.bucketized_column`."""

  def test_invalid_source_column_type(self):
    """A hashed categorical source column is rejected."""
    hashed = fc.categorical_column_with_hash_bucket('aaa', hash_bucket_size=10)
    expected_message = (
        'source_column must be a column generated with numeric_column')
    with self.assertRaisesRegexp(ValueError, expected_message):
      fc.bucketized_column(hashed, boundaries=[0, 1])

  def test_invalid_source_column_shape(self):
    """A source column with a two-dimensional shape is rejected."""
    matrix_column = fc.numeric_column('aaa', shape=[2, 3])
    expected_message = 'source_column must be one-dimensional column'
    with self.assertRaisesRegexp(ValueError, expected_message):
      fc.bucketized_column(matrix_column, boundaries=[0, 1])

  def test_invalid_boundaries(self):
    """Missing, scalar, unsorted and duplicated boundaries are rejected."""
    source = fc.numeric_column('aaa')
    expected_message = 'boundaries must be a sorted list'
    for bad_boundaries in (None, 1., [1, 0], [1, 1]):
      with self.assertRaisesRegexp(ValueError, expected_message):
        fc.bucketized_column(source, boundaries=bad_boundaries)

  def test_name(self):
    """The name is the source column's name plus '_bucketized'."""
    source = fc.numeric_column('aaa', dtype=dtypes.int32)
    bucketized = fc.bucketized_column(source, boundaries=[0, 1])
    self.assertEqual('aaa_bucketized', bucketized.name)

  def test_var_scope_name(self):
    """The variable scope name is the source name plus '_bucketized'."""
    source = fc.numeric_column('aaa', dtype=dtypes.int32)
    bucketized = fc.bucketized_column(source, boundaries=[0, 1])
    self.assertEqual('aaa_bucketized', bucketized._var_scope_name)

  def test_parse_spec(self):
    """The parse spec is keyed by the source column's key."""
    source = fc.numeric_column('aaa', shape=[2], dtype=dtypes.int32)
    bucketized = fc.bucketized_column(source, boundaries=[0, 1])
    expected_spec = {
        'aaa': parsing_ops.FixedLenFeature((2,), dtype=dtypes.int32)
    }
    self.assertEqual(expected_spec, bucketized._parse_example_spec)

  def test_variable_shape(self):
    """The variable shape is the source shape with a bucket dimension."""
    source = fc.numeric_column('aaa', shape=[2], dtype=dtypes.int32)
    bucketized = fc.bucketized_column(source, boundaries=[0, 1])
    # Column 'aaa' has shape [2] times three buckets -> variable_shape=[2, 3].
    self.assertAllEqual((2, 3), bucketized._variable_shape)

  def test_num_buckets(self):
    """The bucket count covers every element of the source column."""
    source = fc.numeric_column('aaa', shape=[2], dtype=dtypes.int32)
    bucketized = fc.bucketized_column(source, boundaries=[0, 1])
    # Column 'aaa' has shape [2] times three buckets -> num_buckets=6.
    self.assertEqual(6, bucketized._num_buckets)

  def test_parse_example(self):
    """Parsing with the bucketized column's spec yields the raw prices."""
    price = fc.numeric_column('price', shape=[2])
    bucketized_price = fc.bucketized_column(price, boundaries=[0, 50])
    price_feature = feature_pb2.Feature(
        float_list=feature_pb2.FloatList(value=[20., 110.]))
    example = example_pb2.Example(
        features=feature_pb2.Features(feature={'price': price_feature}))
    parsed = parsing_ops.parse_example(
        serialized=[example.SerializeToString()],
        features=fc.make_parse_example_spec([bucketized_price]))
    self.assertIn('price', parsed)
    with self.test_session():
      self.assertAllEqual([[20., 110.]], parsed['price'].eval())

  def test_transform_feature(self):
    """Transforming maps each price to the index of its bucket."""
    price = fc.numeric_column('price', shape=[2])
    bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
    with ops.Graph().as_default():
      raw_features = {'price': [[-1., 1.], [5., 6.]]}
      transformed = _transform_features(raw_features, [bucketized_price])
      with _initialized_session():
        bucket_indices = transformed[bucketized_price].eval()
        self.assertAllEqual([[0, 1], [3, 4]], bucket_indices)

  def test_get_dense_tensor_one_input_value(self):
    """Tests _get_dense_tensor() for input with shape=[1]."""
    price = fc.numeric_column('price', shape=[1])
    bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
    with ops.Graph().as_default():
      lazy_builder = _LazyBuilder({'price': [[-1.], [1.], [5.], [6.]]})
      with _initialized_session():
        dense = bucketized_price._get_dense_tensor(lazy_builder)
        # One-hot tensor.
        expected_one_hot = [
            [[1., 0., 0., 0., 0.]],
            [[0., 1., 0., 0., 0.]],
            [[0., 0., 0., 1., 0.]],
            [[0., 0., 0., 0., 1.]],
        ]
        self.assertAllClose(expected_one_hot, dense.eval())

  def test_get_dense_tensor_two_input_values(self):
    """Tests _get_dense_tensor() for input with shape=[2]."""
    price = fc.numeric_column('price', shape=[2])
    bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
    with ops.Graph().as_default():
      lazy_builder = _LazyBuilder({'price': [[-1., 1.], [5., 6.]]})
      with _initialized_session():
        dense = bucketized_price._get_dense_tensor(lazy_builder)
        # One-hot tensor.
        expected_one_hot = [
            [[1., 0., 0., 0., 0.], [0., 1., 0., 0., 0.]],
            [[0., 0., 0., 1., 0.], [0., 0., 0., 0., 1.]],
        ]
        self.assertAllClose(expected_one_hot, dense.eval())

  def test_get_sparse_tensors_one_input_value(self):
    """Tests _get_sparse_tensors() for input with shape=[1]."""
    price = fc.numeric_column('price', shape=[1])
    bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
    with ops.Graph().as_default():
      lazy_builder = _LazyBuilder({'price': [[-1.], [1.], [5.], [6.]]})
      with _initialized_session() as sess:
        sparse_pair = bucketized_price._get_sparse_tensors(lazy_builder)
        self.assertIsNone(sparse_pair.weight_tensor)
        ids = sess.run(sparse_pair.id_tensor)
        expected_indices = [[0, 0], [1, 0], [2, 0], [3, 0]]
        self.assertAllEqual(expected_indices, ids.indices)
        self.assertAllEqual([0, 1, 3, 4], ids.values)
        self.assertAllEqual([4, 1], ids.dense_shape)

  def test_get_sparse_tensors_two_input_values(self):
    """Tests _get_sparse_tensors() for input with shape=[2]."""
    price = fc.numeric_column('price', shape=[2])
    bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
    with ops.Graph().as_default():
      lazy_builder = _LazyBuilder({'price': [[-1., 1.], [5., 6.]]})
      with _initialized_session() as sess:
        sparse_pair = bucketized_price._get_sparse_tensors(lazy_builder)
        self.assertIsNone(sparse_pair.weight_tensor)
        ids = sess.run(sparse_pair.id_tensor)
        expected_indices = [[0, 0], [0, 1], [1, 0], [1, 1]]
        self.assertAllEqual(expected_indices, ids.indices)
        # Values 0-4 correspond to the first column of the input price.
        # Values 5-9 correspond to the second column of the input price.
        self.assertAllEqual([0, 6, 3, 9], ids.values)
        self.assertAllEqual([2, 2], ids.dense_shape)

  def test_sparse_tensor_input_not_supported(self):
    """Transforming a `SparseTensor` input raises a `ValueError`."""
    price = fc.numeric_column('price')
    bucketized_price = fc.bucketized_column(price, boundaries=[0, 1])
    sparse_price = sparse_tensor.SparseTensor(
        indices=[[0, 0]], values=[0.3], dense_shape=[1, 1])
    lazy_builder = _LazyBuilder({'price': sparse_price})
    with self.assertRaisesRegexp(ValueError, 'must be a Tensor'):
      bucketized_price._transform_feature(lazy_builder)

  def test_deep_copy(self):
    """A deep copy keeps the name, variable shape and boundaries."""
    source = fc.numeric_column('aaa', shape=[2])
    bucketized = fc.bucketized_column(source, boundaries=[0, 1])
    duplicate = copy.deepcopy(bucketized)
    self.assertEqual(duplicate.name, 'aaa_bucketized')
    self.assertAllEqual(duplicate._variable_shape, (2, 3))
    self.assertEqual(duplicate.boundaries, (0, 1))

  def test_linear_model_one_input_value(self):
    """Tests linear_model() for input with shape=[1]."""
    price = fc.numeric_column('price', shape=[1])
    bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
    with ops.Graph().as_default():
      raw_features = {'price': [[-1.], [1.], [5.], [6.]]}
      logits = fc.linear_model(raw_features, [bucketized_price])
      bias_var = get_linear_model_bias()
      weight_var = get_linear_model_column_var(bucketized_price)
      with _initialized_session() as sess:
        self.assertAllClose([0.], bias_var.eval())
        # One weight variable per bucket, all initialized to zero.
        self.assertAllClose([[0.]] * 5, weight_var.eval())
        self.assertAllClose([[0.]] * 4, logits.eval())
        bucket_weights = [[10.], [20.], [30.], [40.], [50.]]
        sess.run(weight_var.assign(bucket_weights))
        # price -1. is in the 0th bucket, whose weight is 10.
        # price 1. is in the 1st bucket, whose weight is 20.
        # price 5. is in the 3rd bucket, whose weight is 40.
        # price 6. is in the 4th bucket, whose weight is 50.
        self.assertAllClose([[10.], [20.], [40.], [50.]], logits.eval())
        # The bias is added on top of every prediction.
        sess.run(bias_var.assign([1.]))
        self.assertAllClose([[11.], [21.], [41.], [51.]], logits.eval())

  def test_linear_model_two_input_values(self):
    """Tests linear_model() for input with shape=[2]."""
    price = fc.numeric_column('price', shape=[2])
    bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
    with ops.Graph().as_default():
      raw_features = {'price': [[-1., 1.], [5., 6.]]}
      logits = fc.linear_model(raw_features, [bucketized_price])
      bias_var = get_linear_model_bias()
      weight_var = get_linear_model_column_var(bucketized_price)
      with _initialized_session() as sess:
        self.assertAllClose([0.], bias_var.eval())
        # One weight per bucket per input column, all initialized to zero.
        self.assertAllClose([[0.]] * 10, weight_var.eval())
        self.assertAllClose([[0.], [0.]], logits.eval())
        bucket_weights = [
            [10.], [20.], [30.], [40.], [50.],
            [60.], [70.], [80.], [90.], [100.],
        ]
        sess.run(weight_var.assign(bucket_weights))
        # 1st example:
        #   price -1. is in the 0th bucket, whose weight is 10.
        #   price 1. is in the 6th bucket, whose weight is 70.
        # 2nd example:
        #   price 5. is in the 3rd bucket, whose weight is 40.
        #   price 6. is in the 9th bucket, whose weight is 100.
        self.assertAllClose([[80.], [140.]], logits.eval())
        # The bias is added on top of every prediction.
        sess.run(bias_var.assign([1.]))
        self.assertAllClose([[81.], [141.]], logits.eval())
563
564
class HashedCategoricalColumnTest(test.TestCase):
  """Tests for `fc.categorical_column_with_hash_bucket` columns."""

  def test_defaults(self):
    """Checks the attributes of a column built from key and bucket size."""
    column = fc.categorical_column_with_hash_bucket('aaa', 10)
    self.assertEqual('aaa', column.name)
    self.assertEqual('aaa', column._var_scope_name)
    self.assertEqual('aaa', column.key)
    self.assertEqual(10, column.hash_bucket_size)
    self.assertEqual(dtypes.string, column.dtype)

  def test_bucket_size_should_be_given(self):
    """A `None` bucket size raises a `ValueError`."""
    expected_message = 'hash_bucket_size must be set.'
    with self.assertRaisesRegexp(ValueError, expected_message):
      fc.categorical_column_with_hash_bucket('aaa', None)

  def test_bucket_size_should_be_positive(self):
    """A bucket size of zero raises a `ValueError`."""
    expected_message = 'hash_bucket_size must be at least 1'
    with self.assertRaisesRegexp(ValueError, expected_message):
      fc.categorical_column_with_hash_bucket('aaa', 0)

  def test_dtype_should_be_string_or_integer(self):
    """String and int32 dtypes are accepted; float32 is rejected."""
    for accepted_dtype in (dtypes.string, dtypes.int32):
      fc.categorical_column_with_hash_bucket('aaa', 10, dtype=accepted_dtype)
    with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
      fc.categorical_column_with_hash_bucket('aaa', 10, dtype=dtypes.float32)

  def test_deep_copy(self):
    """The original and its deep copy expose the same attributes."""
    original = fc.categorical_column_with_hash_bucket('aaa', 10)
    duplicate = copy.deepcopy(original)
    for candidate in (original, duplicate):
      self.assertEqual('aaa', candidate.name)
      self.assertEqual(10, candidate.hash_bucket_size)
      self.assertEqual(10, candidate._num_buckets)
      self.assertEqual(dtypes.string, candidate.dtype)

  def test_parse_spec_string(self):
    """The default column parses as a string `VarLenFeature`."""
    column = fc.categorical_column_with_hash_bucket('aaa', 10)
    expected_spec = {'aaa': parsing_ops.VarLenFeature(dtypes.string)}
    self.assertEqual(expected_spec, column._parse_example_spec)

  def test_parse_spec_int(self):
    """An int32 column parses as an int32 `VarLenFeature`."""
    column = fc.categorical_column_with_hash_bucket(
        'aaa', 10, dtype=dtypes.int32)
    expected_spec = {'aaa': parsing_ops.VarLenFeature(dtypes.int32)}
    self.assertEqual(expected_spec, column._parse_example_spec)

  def test_parse_example(self):
    """A serialized bytes feature parses to a sparse tensor of strings."""
    column = fc.categorical_column_with_hash_bucket('aaa', 10)
    bytes_feature = feature_pb2.Feature(
        bytes_list=feature_pb2.BytesList(value=[b'omar', b'stringer']))
    example = example_pb2.Example(
        features=feature_pb2.Features(feature={'aaa': bytes_feature}))
    parsed = parsing_ops.parse_example(
        serialized=[example.SerializeToString()],
        features=fc.make_parse_example_spec([column]))
    self.assertIn('aaa', parsed)
    with self.test_session():
      expected_value = sparse_tensor.SparseTensorValue(
          indices=[[0, 0], [0, 1]],
          values=np.array([b'omar', b'stringer'], dtype=np.object_),
          dense_shape=[1, 2])
      _assert_sparse_tensor_value(self, expected_value, parsed['aaa'].eval())

  def test_strings_should_be_hashed(self):
    """String values become int64 bucket ids; indices and shape are kept."""
    hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
    wire_tensor = sparse_tensor.SparseTensor(
        values=['omar', 'stringer', 'marlo'],
        indices=[[0, 0], [1, 0], [1, 1]],
        dense_shape=[2, 2])
    transformed = _transform_features({'wire': wire_tensor}, [hashed_sparse])
    hashed = transformed[hashed_sparse]
    # Check exact hashed output. If hashing changes this test will break.
    expected_ids = [6, 4, 1]
    with self.test_session():
      self.assertEqual(dtypes.int64, hashed.values.dtype)
      self.assertAllEqual(expected_ids, hashed.values.eval())
      self.assertAllEqual(wire_tensor.indices.eval(), hashed.indices.eval())
      self.assertAllEqual(
          wire_tensor.dense_shape.eval(), hashed.dense_shape.eval())

  def test_tensor_dtype_should_be_string_or_integer(self):
    """A float input tensor is rejected when the column is transformed."""
    string_fc = fc.categorical_column_with_hash_bucket(
        'a_string', 10, dtype=dtypes.string)
    int_fc = fc.categorical_column_with_hash_bucket(
        'a_int', 10, dtype=dtypes.int32)
    # This column is declared with a string dtype; only the tensor fed to it
    # below holds floats.
    float_fc = fc.categorical_column_with_hash_bucket(
        'a_float', 10, dtype=dtypes.string)
    int_tensor = sparse_tensor.SparseTensor(
        values=[101], indices=[[0, 0]], dense_shape=[1, 1])
    string_tensor = sparse_tensor.SparseTensor(
        values=['101'], indices=[[0, 0]], dense_shape=[1, 1])
    float_tensor = sparse_tensor.SparseTensor(
        values=[101.], indices=[[0, 0]], dense_shape=[1, 1])
    lazy_builder = _LazyBuilder({
        'a_int': int_tensor,
        'a_string': string_tensor,
        'a_float': float_tensor,
    })
    # String and integer inputs are transformed without error.
    lazy_builder.get(string_fc)
    lazy_builder.get(int_fc)
    with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
      lazy_builder.get(float_fc)

  def test_dtype_should_match_with_tensor(self):
    """An int64 column fed a string tensor raises a `ValueError`."""
    hashed_sparse = fc.categorical_column_with_hash_bucket(
        'wire', 10, dtype=dtypes.int64)
    wire_tensor = sparse_tensor.SparseTensor(
        values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
    lazy_builder = _LazyBuilder({'wire': wire_tensor})
    with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
      lazy_builder.get(hashed_sparse)

  def test_ints_should_be_hashed(self):
    """Integer values are mapped to the expected bucket ids."""
    hashed_sparse = fc.categorical_column_with_hash_bucket(
        'wire', 10, dtype=dtypes.int64)
    wire_tensor = sparse_tensor.SparseTensor(
        values=[101, 201, 301],
        indices=[[0, 0], [1, 0], [1, 1]],
        dense_shape=[2, 2])
    lazy_builder = _LazyBuilder({'wire': wire_tensor})
    hashed = lazy_builder.get(hashed_sparse)
    # Check exact hashed output. If hashing changes this test will break.
    expected_ids = [3, 7, 5]
    with self.test_session():
      self.assertAllEqual(expected_ids, hashed.values.eval())

  def test_int32_64_is_compatible(self):
    """An int64 column accepts int32 input values."""
    hashed_sparse = fc.categorical_column_with_hash_bucket(
        'wire', 10, dtype=dtypes.int64)
    int32_values = constant_op.constant([101, 201, 301], dtype=dtypes.int32)
    wire_tensor = sparse_tensor.SparseTensor(
        values=int32_values,
        indices=[[0, 0], [1, 0], [1, 1]],
        dense_shape=[2, 2])
    lazy_builder = _LazyBuilder({'wire': wire_tensor})
    hashed = lazy_builder.get(hashed_sparse)
    # Check exact hashed output. If hashing changes this test will break.
    expected_ids = [3, 7, 5]
    with self.test_session():
      self.assertAllEqual(expected_ids, hashed.values.eval())

  def test_get_sparse_tensors(self):
    """`_get_sparse_tensors` returns the builder's tensor and no weights."""
    hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
    wire_tensor = sparse_tensor.SparseTensor(
        values=['omar', 'stringer', 'marlo'],
        indices=[[0, 0], [1, 0], [1, 1]],
        dense_shape=[2, 2])
    lazy_builder = _LazyBuilder({'wire': wire_tensor})
    sparse_pair = hashed_sparse._get_sparse_tensors(lazy_builder)
    self.assertIsNone(sparse_pair.weight_tensor)
    self.assertEqual(lazy_builder.get(hashed_sparse), sparse_pair.id_tensor)

  def test_get_sparse_tensors_weight_collections(self):
    """`_get_sparse_tensors` adds nothing to the checked collections."""
    column = fc.categorical_column_with_hash_bucket('aaa', 10)
    inputs = sparse_tensor.SparseTensor(
        values=['omar', 'stringer', 'marlo'],
        indices=[[0, 0], [1, 0], [1, 1]],
        dense_shape=[2, 2])
    lazy_builder = _LazyBuilder({'aaa': inputs})
    column._get_sparse_tensors(
        lazy_builder, weight_collections=('my_weights',))

    # Neither the global variables nor the custom collection may be populated.
    global_variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    self.assertItemsEqual([], global_variables)
    self.assertItemsEqual([], ops.get_collection('my_weights'))

  def test_get_sparse_tensors_dense_input(self):
    """Dense (nested tuple) input is accepted by `_get_sparse_tensors`."""
    hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
    dense_wire = (('omar', ''), ('stringer', 'marlo'))
    lazy_builder = _LazyBuilder({'wire': dense_wire})
    sparse_pair = hashed_sparse._get_sparse_tensors(lazy_builder)
    self.assertIsNone(sparse_pair.weight_tensor)
    self.assertEqual(lazy_builder.get(hashed_sparse), sparse_pair.id_tensor)

  def test_linear_model(self):
    """`linear_model` sums the weights of the hashed buckets per example."""
    wire_column = fc.categorical_column_with_hash_bucket('wire', 4)
    self.assertEqual(4, wire_column._num_buckets)
    with ops.Graph().as_default():
      raw_features = {
          wire_column.name: sparse_tensor.SparseTensorValue(
              indices=((0, 0), (1, 0), (1, 1)),
              values=('marlo', 'skywalker', 'omar'),
              dense_shape=(2, 2))
      }
      logits = fc.linear_model(raw_features, (wire_column,))
      bias_var = get_linear_model_bias()
      weight_var = get_linear_model_column_var(wire_column)
      with _initialized_session():
        # Bias, weights and therefore predictions all start at zero.
        self.assertAllClose((0.,), bias_var.eval())
        self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), weight_var.eval())
        self.assertAllClose(((0.,), (0.,)), logits.eval())
        assign_weights = weight_var.assign(((1.,), (2.,), (3.,), (4.,)))
        assign_weights.eval()
        # 'marlo' -> 3: wire_var[3] = 4
        # 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6
        self.assertAllClose(((4.,), (6.,)), logits.eval())
769
770
class CrossedColumnTest(test.TestCase):
  """Tests `fc.crossed_column`: validation, naming, parsing and crossing."""

  def test_keys_empty(self):
    """An empty `keys` list is rejected."""
    expected_error = 'keys must be a list with length > 1'
    with self.assertRaisesRegexp(ValueError, expected_error):
      fc.crossed_column([], 10)

  def test_keys_length_one(self):
    """A single key is rejected as well."""
    expected_error = 'keys must be a list with length > 1'
    with self.assertRaisesRegexp(ValueError, expected_error):
      fc.crossed_column(['a'], 10)

  def test_key_type_unsupported(self):
    """Numeric and hash-bucket columns are not accepted as cross keys."""
    with self.assertRaisesRegexp(ValueError, 'Unsupported key type'):
      fc.crossed_column(['a', fc.numeric_column('c')], 10)

    hash_bucket_error = 'categorical_column_with_hash_bucket is not supported'
    with self.assertRaisesRegexp(ValueError, hash_bucket_error):
      fc.crossed_column(
          ['a', fc.categorical_column_with_hash_bucket('c', 10)], 10)

  def test_hash_bucket_size_negative(self):
    """A negative hash_bucket_size is rejected."""
    expected_error = 'hash_bucket_size must be > 1'
    with self.assertRaisesRegexp(ValueError, expected_error):
      fc.crossed_column(['a', 'c'], -1)

  def test_hash_bucket_size_zero(self):
    """A zero hash_bucket_size is rejected."""
    expected_error = 'hash_bucket_size must be > 1'
    with self.assertRaisesRegexp(ValueError, expected_error):
      fc.crossed_column(['a', 'c'], 0)

  def test_hash_bucket_size_none(self):
    """hash_bucket_size=None is rejected."""
    expected_error = 'hash_bucket_size must be > 1'
    with self.assertRaisesRegexp(ValueError, expected_error):
      fc.crossed_column(['a', 'c'], None)

  def test_name(self):
    """The cross name joins the leaf key names with '_X_'."""
    numeric_a = fc.numeric_column('a', dtype=dtypes.int32)
    bucketized_a = fc.bucketized_column(numeric_a, boundaries=[0, 1])
    inner_cross = fc.crossed_column(['d1', 'd2'], 10)

    outer_cross = fc.crossed_column([bucketized_a, 'c', inner_cross], 10)
    self.assertEqual('a_bucketized_X_c_X_d1_X_d2', outer_cross.name)

  def test_name_ordered_alphabetically(self):
    """Tests that the name does not depend on the order of given columns."""
    numeric_a = fc.numeric_column('a', dtype=dtypes.int32)
    bucketized_a = fc.bucketized_column(numeric_a, boundaries=[0, 1])
    inner_cross = fc.crossed_column(['d1', 'd2'], 10)

    outer_cross = fc.crossed_column([inner_cross, 'c', bucketized_a], 10)
    self.assertEqual('a_bucketized_X_c_X_d1_X_d2', outer_cross.name)

  def test_name_leaf_keys_ordered_alphabetically(self):
    """Tests that leaf keys of a nested cross are also sorted in the name."""
    numeric_a = fc.numeric_column('a', dtype=dtypes.int32)
    bucketized_a = fc.bucketized_column(numeric_a, boundaries=[0, 1])
    inner_cross = fc.crossed_column(['d2', 'c'], 10)

    outer_cross = fc.crossed_column([inner_cross, 'd1', bucketized_a], 10)
    self.assertEqual('a_bucketized_X_c_X_d1_X_d2', outer_cross.name)

  def test_var_scope_name(self):
    """_var_scope_name has the same value as the column name here."""
    numeric_a = fc.numeric_column('a', dtype=dtypes.int32)
    bucketized_a = fc.bucketized_column(numeric_a, boundaries=[0, 1])
    inner_cross = fc.crossed_column(['d1', 'd2'], 10)

    outer_cross = fc.crossed_column([bucketized_a, 'c', inner_cross], 10)
    self.assertEqual(
        'a_bucketized_X_c_X_d1_X_d2', outer_cross._var_scope_name)

  def test_parse_spec(self):
    """The parse spec covers the raw features behind every key."""
    numeric_a = fc.numeric_column('a', shape=[2], dtype=dtypes.int32)
    bucketized_a = fc.bucketized_column(numeric_a, boundaries=[0, 1])
    crossed = fc.crossed_column([bucketized_a, 'c'], 10)
    expected_spec = {
        'a': parsing_ops.FixedLenFeature((2,), dtype=dtypes.int32),
        'c': parsing_ops.VarLenFeature(dtypes.string),
    }
    self.assertEqual(expected_spec, crossed._parse_example_spec)

  def test_num_buckets(self):
    """_num_buckets reports the hash_bucket_size passed at construction."""
    numeric_a = fc.numeric_column('a', shape=[2], dtype=dtypes.int32)
    bucketized_a = fc.bucketized_column(numeric_a, boundaries=[0, 1])
    crossed = fc.crossed_column([bucketized_a, 'c'], 15)
    self.assertEqual(15, crossed._num_buckets)

  def test_deep_copy(self):
    """copy.deepcopy keeps name, hash_bucket_size and hash_key."""
    numeric_a = fc.numeric_column('a', dtype=dtypes.int32)
    bucketized_a = fc.bucketized_column(numeric_a, boundaries=[0, 1])
    inner_cross = fc.crossed_column(['d1', 'd2'], 10)
    outer_cross = fc.crossed_column(
        [bucketized_a, 'c', inner_cross], 15, hash_key=5)
    cloned = copy.deepcopy(outer_cross)
    self.assertEqual('a_bucketized_X_c_X_d1_X_d2', cloned.name)
    self.assertEqual(15, cloned.hash_bucket_size)
    self.assertEqual(5, cloned.hash_key)

  def test_parse_example(self):
    """Parsing with the generated spec returns the raw 'price' and 'wire'."""
    price = fc.numeric_column('price', shape=[2])
    bucketized_price = fc.bucketized_column(price, boundaries=[0, 50])
    price_cross_wire = fc.crossed_column([bucketized_price, 'wire'], 10)
    price_feature = feature_pb2.Feature(
        float_list=feature_pb2.FloatList(value=[20., 110.]))
    wire_feature = feature_pb2.Feature(
        bytes_list=feature_pb2.BytesList(value=[b'omar', b'stringer']))
    example = example_pb2.Example(
        features=feature_pb2.Features(
            feature={'price': price_feature, 'wire': wire_feature}))
    parsed = parsing_ops.parse_example(
        serialized=[example.SerializeToString()],
        features=fc.make_parse_example_spec([price_cross_wire]))
    for feature_key in ('price', 'wire'):
      self.assertIn(feature_key, parsed)
    with self.test_session():
      self.assertAllEqual([[20., 110.]], parsed['price'].eval())
      parsed_wire = parsed['wire']
      self.assertAllEqual([[0, 0], [0, 1]], parsed_wire.indices.eval())
      # Use byte constants to pass the open-source test.
      self.assertAllEqual([b'omar', b'stringer'], parsed_wire.values.eval())
      self.assertAllEqual([1, 2], parsed_wire.dense_shape.eval())

  def test_transform_feature(self):
    """Crossed ids have the expected layout and fall inside the bucket range."""
    price = fc.numeric_column('price', shape=[2])
    bucketized_price = fc.bucketized_column(price, boundaries=[0, 50])
    hash_bucket_size = 10
    price_cross_wire = fc.crossed_column(
        [bucketized_price, 'wire'], hash_bucket_size)
    wire_input = sparse_tensor.SparseTensor(
        values=['omar', 'stringer', 'marlo'],
        indices=[[0, 0], [1, 0], [1, 1]],
        dense_shape=[2, 2])
    features = {
        'price': constant_op.constant([[1., 2.], [5., 6.]]),
        'wire': wire_input,
    }
    transformed = _transform_features(features, [price_cross_wire])
    crossed_tensor = transformed[price_cross_wire]
    with self.test_session() as sess:
      crossed_value = sess.run(crossed_tensor)
      self.assertAllEqual(
          [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]],
          crossed_value.indices)
      valid_ids = list(range(hash_bucket_size))
      for crossed_id in crossed_value.values:
        self.assertIn(crossed_id, valid_ids)
      self.assertAllEqual([2, 4], crossed_value.dense_shape)

  def test_get_sparse_tensors(self):
    """Checks exact ids for a cross of a bucketized, a raw and a nested key."""
    numeric_a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
    bucketized_a = fc.bucketized_column(numeric_a, boundaries=(0, 1))
    inner_cross = fc.crossed_column(['d1', 'd2'], 10)
    outer_cross = fc.crossed_column(
        [bucketized_a, 'c', inner_cross], 15, hash_key=5)
    with ops.Graph().as_default():

      def _three_value_sparse(values):
        # All string features share the same 2x2 layout: row 0 holds one
        # value, row 1 holds two.
        return sparse_tensor.SparseTensor(
            indices=((0, 0), (1, 0), (1, 1)),
            values=values,
            dense_shape=(2, 2))

      builder = _LazyBuilder({
          'a': constant_op.constant(((-1., .5), (.5, 1.))),
          'c': _three_value_sparse(['cA', 'cB', 'cC']),
          'd1': _three_value_sparse(['d1A', 'd1B', 'd1C']),
          'd2': _three_value_sparse(['d2A', 'd2B', 'd2C']),
      })
      id_weight_pair = outer_cross._get_sparse_tensors(builder)
      with _initialized_session() as sess:
        id_value = sess.run(id_weight_pair.id_tensor)
        # Row 0 produces 2 crosses, row 1 produces 16.
        expected_indices = (
            ((0, 0), (0, 1)) + tuple((1, col) for col in range(16)))
        self.assertAllEqual(expected_indices, id_value.indices)
        # Check exact hashed output. If hashing changes this test will break.
        # All values are within [0, hash_bucket_size).
        expected_values = (
            6, 14, 0, 13, 8, 8, 10, 12, 2, 0, 1, 9, 8, 12, 2, 0, 10, 11)
        self.assertAllEqual(expected_values, id_value.values)
        self.assertAllEqual((2, 16), id_value.dense_shape)

  def test_get_sparse_tensors_simple(self):
    """Same as test_get_sparse_tensors, but with simpler values."""
    numeric_a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
    bucketized_a = fc.bucketized_column(numeric_a, boundaries=(0, 1))
    crossed = fc.crossed_column(
        [bucketized_a, 'c'], hash_bucket_size=5, hash_key=5)
    with ops.Graph().as_default():
      c_input = sparse_tensor.SparseTensor(
          indices=((0, 0), (1, 0), (1, 1)),
          values=['cA', 'cB', 'cC'],
          dense_shape=(2, 2))
      builder = _LazyBuilder({
          'a': constant_op.constant(((-1., .5), (.5, 1.))),
          'c': c_input,
      })
      id_weight_pair = crossed._get_sparse_tensors(builder)
      with _initialized_session() as sess:
        id_value = sess.run(id_weight_pair.id_tensor)
        self.assertAllEqual(
            ((0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (1, 3)),
            id_value.indices)
        # Check exact hashed output. If hashing changes this test will break.
        # All values are within [0, hash_bucket_size).
        self.assertAllEqual((1, 0, 1, 3, 4, 2), id_value.values)
        self.assertAllEqual((2, 4), id_value.dense_shape)

  def test_linear_model(self):
    """Tests linear_model.

    Uses data from test_get_sparse_tensors_simple.
    """
    numeric_a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
    bucketized_a = fc.bucketized_column(numeric_a, boundaries=(0, 1))
    crossed = fc.crossed_column(
        [bucketized_a, 'c'], hash_bucket_size=5, hash_key=5)
    with ops.Graph().as_default():
      features = {
          'a': constant_op.constant(((-1., .5), (.5, 1.))),
          'c': sparse_tensor.SparseTensor(
              indices=((0, 0), (1, 0), (1, 1)),
              values=['cA', 'cB', 'cC'],
              dense_shape=(2, 2)),
      }
      predictions = fc.linear_model(features, (crossed,))
      bias = get_linear_model_bias()
      crossed_var = get_linear_model_column_var(crossed)
      with _initialized_session() as sess:
        self.assertAllClose((0.,), sess.run(bias))
        self.assertAllClose(
            ((0.,), (0.,), (0.,), (0.,), (0.,)), sess.run(crossed_var))
        self.assertAllClose(((0.,), (0.,)), sess.run(predictions))
        sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,))))
        # Expected ids after cross = (1, 0, 1, 3, 4, 2)
        # Row 0: 2 + 1 = 3; row 1: 2 + 4 + 5 + 3 = 14.
        self.assertAllClose(((3.,), (14.,)), sess.run(predictions))
        sess.run(bias.assign((.1,)))
        self.assertAllClose(((3.1,), (14.1,)), sess.run(predictions))

  def test_linear_model_with_weights(self):
    """Crossing a key column that yields a weight_tensor raises ValueError."""

    class _TestColumnWithWeights(_CategoricalColumn):
      """Produces sparse IDs and sparse weights."""

      @property
      def name(self):
        return 'test_column'

      @property
      def _num_buckets(self):
        return 5

      @property
      def _parse_example_spec(self):
        ids_key = self.name
        weights_key = '{}_weights'.format(self.name)
        return {
            ids_key: parsing_ops.VarLenFeature(dtypes.int32),
            weights_key: parsing_ops.VarLenFeature(dtypes.float32),
        }

      def _transform_feature(self, inputs):
        ids = inputs.get(self.name)
        weights = inputs.get('{}_weights'.format(self.name))
        return (ids, weights)

      def _get_sparse_tensors(self, inputs, weight_collections=None,
                              trainable=None):
        """Populates both id_tensor and weight_tensor."""
        transformed = inputs.get(self)
        return _CategoricalColumn.IdWeightPair(
            id_tensor=transformed[0], weight_tensor=transformed[1])

    weighted_column = _TestColumnWithWeights()
    crossed = fc.crossed_column(
        [weighted_column, 'c'], hash_bucket_size=5, hash_key=5)
    expected_error = (
        'crossed_column does not support weight_tensor.*{}'.format(
            weighted_column.name))
    sparse_layout = ((0, 0), (1, 0), (1, 1))
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(ValueError, expected_error):
        fc.linear_model({
            weighted_column.name: sparse_tensor.SparseTensor(
                indices=sparse_layout,
                values=[0, 1, 2],
                dense_shape=(2, 2)),
            '{}_weights'.format(weighted_column.name):
                sparse_tensor.SparseTensor(
                    indices=sparse_layout,
                    values=[1., 10., 2.],
                    dense_shape=(2, 2)),
            'c': sparse_tensor.SparseTensor(
                indices=sparse_layout,
                values=['cA', 'cB', 'cC'],
                dense_shape=(2, 2)),
        }, (crossed,))
1062
1063
def get_linear_model_bias():
  """Returns the 'bias_weights' variable from the 'linear_model' scope.

  The scope is opened with reuse=True before the variable is fetched.
  """
  reuse_scope = variable_scope.variable_scope('linear_model', reuse=True)
  with reuse_scope:
    bias_var = variable_scope.get_variable('bias_weights')
  return bias_var
1067
1068
def get_linear_model_column_var(column):
  """Returns the first global variable scoped by 'linear_model/<name>'.

  Args:
    column: Object whose `name` attribute is appended to 'linear_model/' to
      build the scope passed to `ops.get_collection`.

  Returns:
    The first entry of the filtered GLOBAL_VARIABLES collection.
  """
  scope = 'linear_model/' + column.name
  matching_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope)
  return matching_vars[0]
1072
1073
1074@test_util.with_c_api
1075class LinearModelTest(test.TestCase):
1076
1077  def test_raises_if_empty_feature_columns(self):
1078    with self.assertRaisesRegexp(ValueError,
1079                                 'feature_columns must not be empty'):
1080      fc.linear_model(features={}, feature_columns=[])
1081
1082  def test_should_be_feature_column(self):
1083    with self.assertRaisesRegexp(ValueError, 'must be a _FeatureColumn'):
1084      fc.linear_model(features={'a': [[0]]}, feature_columns='NotSupported')
1085
1086  def test_should_be_dense_or_categorical_column(self):
1087
1088    class NotSupportedColumn(_FeatureColumn):
1089
1090      @property
1091      def name(self):
1092        return 'NotSupportedColumn'
1093
1094      def _transform_feature(self, cache):
1095        pass
1096
1097      @property
1098      def _parse_example_spec(self):
1099        pass
1100
1101    with self.assertRaisesRegexp(
1102        ValueError, 'must be either a _DenseColumn or _CategoricalColumn'):
1103      fc.linear_model(
1104          features={'a': [[0]]}, feature_columns=[NotSupportedColumn()])
1105
1106  def test_does_not_support_dict_columns(self):
1107    with self.assertRaisesRegexp(
1108        ValueError, 'Expected feature_columns to be iterable, found dict.'):
1109      fc.linear_model(
1110          features={'a': [[0]]}, feature_columns={'a': fc.numeric_column('a')})
1111
1112  def test_raises_if_duplicate_name(self):
1113    with self.assertRaisesRegexp(
1114        ValueError, 'Duplicate feature column name found for columns'):
1115      fc.linear_model(
1116          features={'a': [[0]]},
1117          feature_columns=[fc.numeric_column('a'),
1118                           fc.numeric_column('a')])
1119
1120  def test_dense_bias(self):
1121    price = fc.numeric_column('price')
1122    with ops.Graph().as_default():
1123      features = {'price': [[1.], [5.]]}
1124      predictions = fc.linear_model(features, [price])
1125      bias = get_linear_model_bias()
1126      price_var = get_linear_model_column_var(price)
1127      with _initialized_session() as sess:
1128        self.assertAllClose([0.], bias.eval())
1129        sess.run(price_var.assign([[10.]]))
1130        sess.run(bias.assign([5.]))
1131        self.assertAllClose([[15.], [55.]], predictions.eval())
1132
1133  def test_sparse_bias(self):
1134    wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
1135    with ops.Graph().as_default():
1136      wire_tensor = sparse_tensor.SparseTensor(
1137          values=['omar', 'stringer', 'marlo'],  # hashed to = [2, 0, 3]
1138          indices=[[0, 0], [1, 0], [1, 1]],
1139          dense_shape=[2, 2])
1140      features = {'wire_cast': wire_tensor}
1141      predictions = fc.linear_model(features, [wire_cast])
1142      bias = get_linear_model_bias()
1143      wire_cast_var = get_linear_model_column_var(wire_cast)
1144      with _initialized_session() as sess:
1145        self.assertAllClose([0.], bias.eval())
1146        self.assertAllClose([[0.], [0.], [0.], [0.]], wire_cast_var.eval())
1147        sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
1148        sess.run(bias.assign([5.]))
1149        self.assertAllClose([[1005.], [10015.]], predictions.eval())
1150
1151  def test_dense_and_sparse_bias(self):
1152    wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
1153    price = fc.numeric_column('price')
1154    with ops.Graph().as_default():
1155      wire_tensor = sparse_tensor.SparseTensor(
1156          values=['omar', 'stringer', 'marlo'],  # hashed to = [2, 0, 3]
1157          indices=[[0, 0], [1, 0], [1, 1]],
1158          dense_shape=[2, 2])
1159      features = {'wire_cast': wire_tensor, 'price': [[1.], [5.]]}
1160      predictions = fc.linear_model(features, [wire_cast, price])
1161      bias = get_linear_model_bias()
1162      wire_cast_var = get_linear_model_column_var(wire_cast)
1163      price_var = get_linear_model_column_var(price)
1164      with _initialized_session() as sess:
1165        sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
1166        sess.run(bias.assign([5.]))
1167        sess.run(price_var.assign([[10.]]))
1168        self.assertAllClose([[1015.], [10065.]], predictions.eval())
1169
1170  def test_dense_and_sparse_column(self):
1171    """When the column is both dense and sparse, uses sparse tensors."""
1172
1173    class _DenseAndSparseColumn(_DenseColumn, _CategoricalColumn):
1174
1175      @property
1176      def name(self):
1177        return 'dense_and_sparse_column'
1178
1179      @property
1180      def _parse_example_spec(self):
1181        return {self.name: parsing_ops.VarLenFeature(self.dtype)}
1182
1183      def _transform_feature(self, inputs):
1184        return inputs.get(self.name)
1185
1186      @property
1187      def _variable_shape(self):
1188        raise ValueError('Should not use this method.')
1189
1190      def _get_dense_tensor(self, inputs, weight_collections=None,
1191                            trainable=None):
1192        raise ValueError('Should not use this method.')
1193
1194      @property
1195      def _num_buckets(self):
1196        return 4
1197
1198      def _get_sparse_tensors(self, inputs, weight_collections=None,
1199                              trainable=None):
1200        sp_tensor = sparse_tensor.SparseTensor(
1201            indices=[[0, 0], [1, 0], [1, 1]],
1202            values=[2, 0, 3],
1203            dense_shape=[2, 2])
1204        return _CategoricalColumn.IdWeightPair(sp_tensor, None)
1205
1206    dense_and_sparse_column = _DenseAndSparseColumn()
1207    with ops.Graph().as_default():
1208      sp_tensor = sparse_tensor.SparseTensor(
1209          values=['omar', 'stringer', 'marlo'],
1210          indices=[[0, 0], [1, 0], [1, 1]],
1211          dense_shape=[2, 2])
1212      features = {dense_and_sparse_column.name: sp_tensor}
1213      predictions = fc.linear_model(features, [dense_and_sparse_column])
1214      bias = get_linear_model_bias()
1215      dense_and_sparse_column_var = get_linear_model_column_var(
1216          dense_and_sparse_column)
1217      with _initialized_session() as sess:
1218        sess.run(dense_and_sparse_column_var.assign(
1219            [[10.], [100.], [1000.], [10000.]]))
1220        sess.run(bias.assign([5.]))
1221        self.assertAllClose([[1005.], [10015.]], predictions.eval())
1222
1223  def test_dense_multi_output(self):
1224    price = fc.numeric_column('price')
1225    with ops.Graph().as_default():
1226      features = {'price': [[1.], [5.]]}
1227      predictions = fc.linear_model(features, [price], units=3)
1228      bias = get_linear_model_bias()
1229      price_var = get_linear_model_column_var(price)
1230      with _initialized_session() as sess:
1231        self.assertAllClose(np.zeros((3,)), bias.eval())
1232        self.assertAllClose(np.zeros((1, 3)), price_var.eval())
1233        sess.run(price_var.assign([[10., 100., 1000.]]))
1234        sess.run(bias.assign([5., 6., 7.]))
1235        self.assertAllClose([[15., 106., 1007.], [55., 506., 5007.]],
1236                            predictions.eval())
1237
1238  def test_sparse_multi_output(self):
1239    wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
1240    with ops.Graph().as_default():
1241      wire_tensor = sparse_tensor.SparseTensor(
1242          values=['omar', 'stringer', 'marlo'],  # hashed to = [2, 0, 3]
1243          indices=[[0, 0], [1, 0], [1, 1]],
1244          dense_shape=[2, 2])
1245      features = {'wire_cast': wire_tensor}
1246      predictions = fc.linear_model(features, [wire_cast], units=3)
1247      bias = get_linear_model_bias()
1248      wire_cast_var = get_linear_model_column_var(wire_cast)
1249      with _initialized_session() as sess:
1250        self.assertAllClose(np.zeros((3,)), bias.eval())
1251        self.assertAllClose(np.zeros((4, 3)), wire_cast_var.eval())
1252        sess.run(
1253            wire_cast_var.assign([[10., 11., 12.], [100., 110., 120.], [
1254                1000., 1100., 1200.
1255            ], [10000., 11000., 12000.]]))
1256        sess.run(bias.assign([5., 6., 7.]))
1257        self.assertAllClose([[1005., 1106., 1207.], [10015., 11017., 12019.]],
1258                            predictions.eval())
1259
1260  def test_dense_multi_dimension(self):
1261    price = fc.numeric_column('price', shape=2)
1262    with ops.Graph().as_default():
1263      features = {'price': [[1., 2.], [5., 6.]]}
1264      predictions = fc.linear_model(features, [price])
1265      price_var = get_linear_model_column_var(price)
1266      with _initialized_session() as sess:
1267        self.assertAllClose([[0.], [0.]], price_var.eval())
1268        sess.run(price_var.assign([[10.], [100.]]))
1269        self.assertAllClose([[210.], [650.]], predictions.eval())
1270
1271  def test_sparse_multi_rank(self):
1272    wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
1273    with ops.Graph().as_default():
1274      wire_tensor = array_ops.sparse_placeholder(dtypes.string)
1275      wire_value = sparse_tensor.SparseTensorValue(
1276          values=['omar', 'stringer', 'marlo', 'omar'],  # hashed = [2, 0, 3, 2]
1277          indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 1]],
1278          dense_shape=[2, 2, 2])
1279      features = {'wire_cast': wire_tensor}
1280      predictions = fc.linear_model(features, [wire_cast])
1281      wire_cast_var = get_linear_model_column_var(wire_cast)
1282      with _initialized_session() as sess:
1283        self.assertAllClose(np.zeros((4, 1)), wire_cast_var.eval())
1284        self.assertAllClose(
1285            np.zeros((2, 1)),
1286            predictions.eval(feed_dict={wire_tensor: wire_value}))
1287        sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
1288        self.assertAllClose(
1289            [[1010.], [11000.]],
1290            predictions.eval(feed_dict={wire_tensor: wire_value}))
1291
1292  def test_sparse_combiner(self):
1293    wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
1294    with ops.Graph().as_default():
1295      wire_tensor = sparse_tensor.SparseTensor(
1296          values=['omar', 'stringer', 'marlo'],  # hashed to = [2, 0, 3]
1297          indices=[[0, 0], [1, 0], [1, 1]],
1298          dense_shape=[2, 2])
1299      features = {'wire_cast': wire_tensor}
1300      predictions = fc.linear_model(
1301          features, [wire_cast], sparse_combiner='mean')
1302      bias = get_linear_model_bias()
1303      wire_cast_var = get_linear_model_column_var(wire_cast)
1304      with _initialized_session() as sess:
1305        sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
1306        sess.run(bias.assign([5.]))
1307        self.assertAllClose([[1005.], [5010.]], predictions.eval())
1308
1309  def test_dense_multi_dimension_multi_output(self):
1310    price = fc.numeric_column('price', shape=2)
1311    with ops.Graph().as_default():
1312      features = {'price': [[1., 2.], [5., 6.]]}
1313      predictions = fc.linear_model(features, [price], units=3)
1314      bias = get_linear_model_bias()
1315      price_var = get_linear_model_column_var(price)
1316      with _initialized_session() as sess:
1317        self.assertAllClose(np.zeros((3,)), bias.eval())
1318        self.assertAllClose(np.zeros((2, 3)), price_var.eval())
1319        sess.run(price_var.assign([[1., 2., 3.], [10., 100., 1000.]]))
1320        sess.run(bias.assign([2., 3., 4.]))
1321        self.assertAllClose([[23., 205., 2007.], [67., 613., 6019.]],
1322                            predictions.eval())
1323
1324  def test_raises_if_shape_mismatch(self):
1325    price = fc.numeric_column('price', shape=2)
1326    with ops.Graph().as_default():
1327      features = {'price': [[1.], [5.]]}
1328      if ops._USE_C_API:
1329        with self.assertRaisesRegexp(
1330            Exception,
1331            r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
1332          predictions = fc.linear_model(features, [price])
1333      else:
1334        predictions = fc.linear_model(features, [price])
1335        with _initialized_session():
1336          with self.assertRaisesRegexp(Exception, 'requested shape has 4'):
1337            predictions.eval()
1338
1339  def test_dense_reshaping(self):
1340    price = fc.numeric_column('price', shape=[1, 2])
1341    with ops.Graph().as_default():
1342      features = {'price': [[[1., 2.]], [[5., 6.]]]}
1343      predictions = fc.linear_model(features, [price])
1344      bias = get_linear_model_bias()
1345      price_var = get_linear_model_column_var(price)
1346      with _initialized_session() as sess:
1347        self.assertAllClose([0.], bias.eval())
1348        self.assertAllClose([[0.], [0.]], price_var.eval())
1349        self.assertAllClose([[0.], [0.]], predictions.eval())
1350        sess.run(price_var.assign([[10.], [100.]]))
1351        self.assertAllClose([[210.], [650.]], predictions.eval())
1352
1353  def test_dense_multi_column(self):
1354    price1 = fc.numeric_column('price1', shape=2)
1355    price2 = fc.numeric_column('price2')
1356    with ops.Graph().as_default():
1357      features = {
1358          'price1': [[1., 2.], [5., 6.]],
1359          'price2': [[3.], [4.]]
1360      }
1361      predictions = fc.linear_model(features, [price1, price2])
1362      bias = get_linear_model_bias()
1363      price1_var = get_linear_model_column_var(price1)
1364      price2_var = get_linear_model_column_var(price2)
1365      with _initialized_session() as sess:
1366        self.assertAllClose([0.], bias.eval())
1367        self.assertAllClose([[0.], [0.]], price1_var.eval())
1368        self.assertAllClose([[0.]], price2_var.eval())
1369        self.assertAllClose([[0.], [0.]], predictions.eval())
1370        sess.run(price1_var.assign([[10.], [100.]]))
1371        sess.run(price2_var.assign([[1000.]]))
1372        sess.run(bias.assign([7.]))
1373        self.assertAllClose([[3217.], [4657.]], predictions.eval())
1374
1375  def test_fills_cols_to_vars(self):
1376    price1 = fc.numeric_column('price1', shape=2)
1377    price2 = fc.numeric_column('price2')
1378    with ops.Graph().as_default():
1379      features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
1380      cols_to_vars = {}
1381      fc.linear_model(features, [price1, price2], cols_to_vars=cols_to_vars)
1382      bias = get_linear_model_bias()
1383      price1_var = get_linear_model_column_var(price1)
1384      price2_var = get_linear_model_column_var(price2)
1385      self.assertAllEqual(cols_to_vars['bias'], [bias])
1386      self.assertAllEqual(cols_to_vars[price1], [price1_var])
1387      self.assertAllEqual(cols_to_vars[price2], [price2_var])
1388
1389  def test_fills_cols_to_vars_partitioned_variables(self):
1390    price1 = fc.numeric_column('price1', shape=2)
1391    price2 = fc.numeric_column('price2', shape=3)
1392    with ops.Graph().as_default():
1393      features = {
1394          'price1': [[1., 2.], [6., 7.]],
1395          'price2': [[3., 4., 5.], [8., 9., 10.]]
1396      }
1397      cols_to_vars = {}
1398      with variable_scope.variable_scope(
1399          'linear',
1400          partitioner=partitioned_variables.fixed_size_partitioner(2, axis=0)):
1401        fc.linear_model(features, [price1, price2], cols_to_vars=cols_to_vars)
1402      with _initialized_session():
1403        self.assertEqual([0.], cols_to_vars['bias'][0].eval())
1404        # Partitioning shards the [2, 1] price1 var into 2 [1, 1] Variables.
1405        self.assertAllEqual([[0.]], cols_to_vars[price1][0].eval())
1406        self.assertAllEqual([[0.]], cols_to_vars[price1][1].eval())
1407        # Partitioning shards the [3, 1] price2 var into a [2, 1] Variable and
1408        # a [1, 1] Variable.
1409        self.assertAllEqual([[0.], [0.]], cols_to_vars[price2][0].eval())
1410        self.assertAllEqual([[0.]], cols_to_vars[price2][1].eval())
1411
1412  def test_dense_collection(self):
1413    price = fc.numeric_column('price')
1414    with ops.Graph().as_default() as g:
1415      features = {'price': [[1.], [5.]]}
1416      fc.linear_model(features, [price], weight_collections=['my-vars'])
1417      my_vars = g.get_collection('my-vars')
1418      bias = get_linear_model_bias()
1419      price_var = get_linear_model_column_var(price)
1420      self.assertIn(bias, my_vars)
1421      self.assertIn(price_var, my_vars)
1422
1423  def test_sparse_collection(self):
1424    wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
1425    with ops.Graph().as_default() as g:
1426      wire_tensor = sparse_tensor.SparseTensor(
1427          values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
1428      features = {'wire_cast': wire_tensor}
1429      fc.linear_model(
1430          features, [wire_cast], weight_collections=['my-vars'])
1431      my_vars = g.get_collection('my-vars')
1432      bias = get_linear_model_bias()
1433      wire_cast_var = get_linear_model_column_var(wire_cast)
1434      self.assertIn(bias, my_vars)
1435      self.assertIn(wire_cast_var, my_vars)
1436
1437  def test_dense_trainable_default(self):
1438    price = fc.numeric_column('price')
1439    with ops.Graph().as_default() as g:
1440      features = {'price': [[1.], [5.]]}
1441      fc.linear_model(features, [price])
1442      bias = get_linear_model_bias()
1443      price_var = get_linear_model_column_var(price)
1444      trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
1445      self.assertIn(bias, trainable_vars)
1446      self.assertIn(price_var, trainable_vars)
1447
1448  def test_sparse_trainable_default(self):
1449    wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
1450    with ops.Graph().as_default() as g:
1451      wire_tensor = sparse_tensor.SparseTensor(
1452          values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
1453      features = {'wire_cast': wire_tensor}
1454      fc.linear_model(features, [wire_cast])
1455      trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
1456      bias = get_linear_model_bias()
1457      wire_cast_var = get_linear_model_column_var(wire_cast)
1458      self.assertIn(bias, trainable_vars)
1459      self.assertIn(wire_cast_var, trainable_vars)
1460
1461  def test_dense_trainable_false(self):
1462    price = fc.numeric_column('price')
1463    with ops.Graph().as_default() as g:
1464      features = {'price': [[1.], [5.]]}
1465      fc.linear_model(features, [price], trainable=False)
1466      trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
1467      self.assertEqual([], trainable_vars)
1468
1469  def test_sparse_trainable_false(self):
1470    wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
1471    with ops.Graph().as_default() as g:
1472      wire_tensor = sparse_tensor.SparseTensor(
1473          values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
1474      features = {'wire_cast': wire_tensor}
1475      fc.linear_model(features, [wire_cast], trainable=False)
1476      trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
1477      self.assertEqual([], trainable_vars)
1478
1479  def test_column_order(self):
1480    price_a = fc.numeric_column('price_a')
1481    price_b = fc.numeric_column('price_b')
1482    wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
1483    with ops.Graph().as_default() as g:
1484      features = {
1485          'price_a': [[1.]],
1486          'price_b': [[3.]],
1487          'wire_cast':
1488              sparse_tensor.SparseTensor(
1489                  values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
1490      }
1491      fc.linear_model(
1492          features, [price_a, wire_cast, price_b],
1493          weight_collections=['my-vars'])
1494      my_vars = g.get_collection('my-vars')
1495      self.assertIn('price_a', my_vars[0].name)
1496      self.assertIn('price_b', my_vars[1].name)
1497      self.assertIn('wire_cast', my_vars[2].name)
1498
1499    with ops.Graph().as_default() as g:
1500      features = {
1501          'price_a': [[1.]],
1502          'price_b': [[3.]],
1503          'wire_cast':
1504              sparse_tensor.SparseTensor(
1505                  values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
1506      }
1507      fc.linear_model(
1508          features, [wire_cast, price_b, price_a],
1509          weight_collections=['my-vars'])
1510      my_vars = g.get_collection('my-vars')
1511      self.assertIn('price_a', my_vars[0].name)
1512      self.assertIn('price_b', my_vars[1].name)
1513      self.assertIn('wire_cast', my_vars[2].name)
1514
1515  def test_static_batch_size_mismatch(self):
1516    price1 = fc.numeric_column('price1')
1517    price2 = fc.numeric_column('price2')
1518    with ops.Graph().as_default():
1519      features = {
1520          'price1': [[1.], [5.], [7.]],  # batchsize = 3
1521          'price2': [[3.], [4.]]  # batchsize = 2
1522      }
1523    with self.assertRaisesRegexp(
1524        ValueError,
1525        'Batch size \(first dimension\) of each feature must be same.'):  # pylint: disable=anomalous-backslash-in-string
1526      fc.linear_model(features, [price1, price2])
1527
1528  def test_subset_of_static_batch_size_mismatch(self):
1529    price1 = fc.numeric_column('price1')
1530    price2 = fc.numeric_column('price2')
1531    price3 = fc.numeric_column('price3')
1532    with ops.Graph().as_default():
1533      features = {
1534          'price1': array_ops.placeholder(dtype=dtypes.int64),  # batchsize = 3
1535          'price2': [[3.], [4.]],  # batchsize = 2
1536          'price3': [[3.], [4.], [5.]]  # batchsize = 3
1537      }
1538      with self.assertRaisesRegexp(
1539          ValueError,
1540          'Batch size \(first dimension\) of each feature must be same.'):  # pylint: disable=anomalous-backslash-in-string
1541        fc.linear_model(features, [price1, price2, price3])
1542
1543  def test_runtime_batch_size_mismatch(self):
1544    price1 = fc.numeric_column('price1')
1545    price2 = fc.numeric_column('price2')
1546    with ops.Graph().as_default():
1547      features = {
1548          'price1': array_ops.placeholder(dtype=dtypes.int64),  # batchsize = 3
1549          'price2': [[3.], [4.]]  # batchsize = 2
1550      }
1551      predictions = fc.linear_model(features, [price1, price2])
1552      with _initialized_session() as sess:
1553        with self.assertRaisesRegexp(errors.OpError,
1554                                     'must have the same size and shape'):
1555          sess.run(
1556              predictions, feed_dict={features['price1']: [[1.], [5.], [7.]]})
1557
1558  def test_runtime_batch_size_matches(self):
1559    price1 = fc.numeric_column('price1')
1560    price2 = fc.numeric_column('price2')
1561    with ops.Graph().as_default():
1562      features = {
1563          'price1': array_ops.placeholder(dtype=dtypes.int64),  # batchsize = 2
1564          'price2': array_ops.placeholder(dtype=dtypes.int64),  # batchsize = 2
1565      }
1566      predictions = fc.linear_model(features, [price1, price2])
1567      with _initialized_session() as sess:
1568        sess.run(
1569            predictions,
1570            feed_dict={
1571                features['price1']: [[1.], [5.]],
1572                features['price2']: [[1.], [5.]],
1573            })
1574
1575  def test_with_numpy_input_fn(self):
1576    price = fc.numeric_column('price')
1577    price_buckets = fc.bucketized_column(price, boundaries=[0., 10., 100.,])
1578    body_style = fc.categorical_column_with_vocabulary_list(
1579        'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
1580
1581    input_fn = numpy_io.numpy_input_fn(
1582        x={
1583            'price': np.array([-1., 2., 13., 104.]),
1584            'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),
1585        },
1586        batch_size=2,
1587        shuffle=False)
1588    features = input_fn()
1589    net = fc.linear_model(features, [price_buckets, body_style])
1590    # self.assertEqual(1 + 3 + 5, net.shape[1])
1591    with _initialized_session() as sess:
1592      coord = coordinator.Coordinator()
1593      threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
1594
1595      bias = get_linear_model_bias()
1596      price_buckets_var = get_linear_model_column_var(price_buckets)
1597      body_style_var = get_linear_model_column_var(body_style)
1598
1599      sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
1600      sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
1601      sess.run(bias.assign([5.]))
1602
1603      self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net))
1604
1605      coord.request_stop()
1606      coord.join(threads)
1607
1608  def test_with_1d_sparse_tensor(self):
1609    price = fc.numeric_column('price')
1610    price_buckets = fc.bucketized_column(price, boundaries=[0., 10., 100.,])
1611    body_style = fc.categorical_column_with_vocabulary_list(
1612        'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
1613
1614    # Provides 1-dim tensor and dense tensor.
1615    features = {
1616        'price': constant_op.constant([-1., 12.,]),
1617        'body-style': sparse_tensor.SparseTensor(
1618            indices=((0,), (1,)),
1619            values=('sedan', 'hardtop'),
1620            dense_shape=(2,)),
1621    }
1622    self.assertEqual(1, features['price'].shape.ndims)
1623    self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
1624
1625    net = fc.linear_model(features, [price_buckets, body_style])
1626    with _initialized_session() as sess:
1627      bias = get_linear_model_bias()
1628      price_buckets_var = get_linear_model_column_var(price_buckets)
1629      body_style_var = get_linear_model_column_var(body_style)
1630
1631      sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
1632      sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
1633      sess.run(bias.assign([5.]))
1634
1635      self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
1636
1637  def test_with_1d_unknown_shape_sparse_tensor(self):
1638    price = fc.numeric_column('price')
1639    price_buckets = fc.bucketized_column(price, boundaries=[0., 10., 100.,])
1640    body_style = fc.categorical_column_with_vocabulary_list(
1641        'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
1642    country = fc.categorical_column_with_vocabulary_list(
1643        'country', vocabulary_list=['US', 'JP', 'CA'])
1644
1645    # Provides 1-dim tensor and dense tensor.
1646    features = {
1647        'price': array_ops.placeholder(dtypes.float32),
1648        'body-style': array_ops.sparse_placeholder(dtypes.string),
1649        'country': array_ops.placeholder(dtypes.string),
1650    }
1651    self.assertIsNone(features['price'].shape.ndims)
1652    self.assertIsNone(features['body-style'].get_shape().ndims)
1653
1654    price_data = np.array([-1., 12.])
1655    body_style_data = sparse_tensor.SparseTensorValue(
1656        indices=((0,), (1,)),
1657        values=('sedan', 'hardtop'),
1658        dense_shape=(2,))
1659    country_data = np.array(['US', 'CA'])
1660
1661    net = fc.linear_model(features, [price_buckets, body_style, country])
1662    bias = get_linear_model_bias()
1663    price_buckets_var = get_linear_model_column_var(price_buckets)
1664    body_style_var = get_linear_model_column_var(body_style)
1665    with _initialized_session() as sess:
1666      sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
1667      sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
1668      sess.run(bias.assign([5.]))
1669
1670      self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
1671                          sess.run(
1672                              net,
1673                              feed_dict={
1674                                  features['price']: price_data,
1675                                  features['body-style']: body_style_data,
1676                                  features['country']: country_data
1677                              }))
1678
1679  def test_with_rank_0_feature(self):
1680    price = fc.numeric_column('price')
1681    features = {
1682        'price': constant_op.constant(0),
1683    }
1684    self.assertEqual(0, features['price'].shape.ndims)
1685
1686    # Static rank 0 should fail
1687    with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
1688      fc.linear_model(features, [price])
1689
1690    # Dynamic rank 0 should fail
1691    features = {
1692        'price': array_ops.placeholder(dtypes.float32),
1693    }
1694    net = fc.linear_model(features, [price])
1695    self.assertEqual(1, net.shape[1])
1696    with _initialized_session() as sess:
1697      with self.assertRaisesOpError('Feature .* cannot have rank 0'):
1698        sess.run(net, feed_dict={features['price']: np.array(1)})
1699
1700
class InputLayerTest(test.TestCase):
  """Tests for the object-oriented `InputLayer` wrapper."""

  @test_util.run_in_graph_and_eager_modes()
  def test_retrieving_input(self):
    """Checks a numeric feature passes through the layer as a [1, 1] row."""
    features = {'a': [0.]}
    input_layer = InputLayer(fc.numeric_column('a'))
    inputs = self.evaluate(input_layer(features))
    self.assertAllClose([[0.]], inputs)

  def test_reuses_variables(self):
    """Checks that calling the layer twice does not create new variables."""
    with context.eager_mode():
      sparse_input = sparse_tensor.SparseTensor(
          indices=((0, 0), (1, 0), (2, 0)),
          values=(0, 1, 2),
          dense_shape=(3, 3))

      # Create feature columns (categorical and embedding).
      categorical_column = fc.categorical_column_with_identity(key='a',
                                                               num_buckets=3)
      embedding_dimension = 2
      def _embedding_column_initializer(shape, dtype, partition_info):
        del shape  # unused
        del dtype  # unused
        del partition_info  # unused
        embedding_values = (
            (1, 0),  # id 0
            (0, 1),  # id 1
            (1, 1))  # id 2
        return embedding_values
      embedding_column = fc.embedding_column(
          categorical_column,
          dimension=embedding_dimension,
          initializer=_embedding_column_initializer)

      input_layer = InputLayer([embedding_column])
      features = {'a': sparse_input}

      inputs = input_layer(features)
      variables = input_layer.variables

      # Sanity check: test that the inputs are correct.
      self.assertAllEqual([[1, 0], [0, 1], [1, 1]], inputs)

      # Check that only one variable was created.
      self.assertEqual(1, len(variables))

      # Check that invoking input_layer on the same features does not create
      # additional variables. The variable list is re-read from the layer
      # after the second call: measuring the list captured before the call
      # again could not reveal a newly created variable if the property
      # returns a fresh list on each access.
      _ = input_layer(features)
      variables_after_second_call = input_layer.variables
      self.assertEqual(1, len(variables_after_second_call))
      self.assertEqual(variables[0], variables_after_second_call[0])

  def test_feature_column_input_layer_gradient(self):
    """Checks the gradient of a scaled layer output w.r.t. the embeddings."""
    with context.eager_mode():
      sparse_input = sparse_tensor.SparseTensor(
          indices=((0, 0), (1, 0), (2, 0)),
          values=(0, 1, 2),
          dense_shape=(3, 3))

      # Create feature columns (categorical and embedding).
      categorical_column = fc.categorical_column_with_identity(key='a',
                                                               num_buckets=3)
      embedding_dimension = 2

      def _embedding_column_initializer(shape, dtype, partition_info):
        del shape  # unused
        del dtype  # unused
        del partition_info  # unused
        embedding_values = (
            (1, 0),  # id 0
            (0, 1),  # id 1
            (1, 1))  # id 2
        return embedding_values

      embedding_column = fc.embedding_column(
          categorical_column,
          dimension=embedding_dimension,
          initializer=_embedding_column_initializer)

      input_layer = InputLayer([embedding_column])
      features = {'a': sparse_input}

      def scale_matrix():
        matrix = input_layer(features)
        return 2 * matrix

      # Sanity check: Verify that scale_matrix returns the correct output.
      self.assertAllEqual([[2, 0], [0, 2], [2, 2]], scale_matrix())

      # Check that the returned gradient is correct. The first element of the
      # first (gradient, variable) pair is read through its `indices` and
      # `values` attributes.
      grad_function = backprop.implicit_grad(scale_matrix)
      grads_and_vars = grad_function()
      indexed_slice = grads_and_vars[0][0]
      gradient = indexed_slice.values

      self.assertAllEqual([0, 1, 2], indexed_slice.indices)
      self.assertAllEqual([[2, 2], [2, 2], [2, 2]], gradient)
1798
1799
1800@test_util.with_c_api
1801class FunctionalInputLayerTest(test.TestCase):
1802
1803  def test_raises_if_empty_feature_columns(self):
1804    with self.assertRaisesRegexp(ValueError,
1805                                 'feature_columns must not be empty'):
1806      fc.input_layer(features={}, feature_columns=[])
1807
1808  def test_should_be_dense_column(self):
1809    with self.assertRaisesRegexp(ValueError, 'must be a _DenseColumn'):
1810      fc.input_layer(
1811          features={'a': [[0]]},
1812          feature_columns=[
1813              fc.categorical_column_with_hash_bucket('wire_cast', 4)
1814          ])
1815
1816  def test_does_not_support_dict_columns(self):
1817    with self.assertRaisesRegexp(
1818        ValueError, 'Expected feature_columns to be iterable, found dict.'):
1819      fc.input_layer(
1820          features={'a': [[0]]}, feature_columns={'a': fc.numeric_column('a')})
1821
1822  def test_bare_column(self):
1823    with ops.Graph().as_default():
1824      features = features = {'a': [0.]}
1825      net = fc.input_layer(features, fc.numeric_column('a'))
1826      with _initialized_session():
1827        self.assertAllClose([[0.]], net.eval())
1828
1829  def test_column_generator(self):
1830    with ops.Graph().as_default():
1831      features = features = {'a': [0.], 'b': [1.]}
1832      columns = (fc.numeric_column(key) for key in features)
1833      net = fc.input_layer(features, columns)
1834      with _initialized_session():
1835        self.assertAllClose([[0., 1.]], net.eval())
1836
1837  def test_raises_if_duplicate_name(self):
1838    with self.assertRaisesRegexp(
1839        ValueError, 'Duplicate feature column name found for columns'):
1840      fc.input_layer(
1841          features={'a': [[0]]},
1842          feature_columns=[fc.numeric_column('a'),
1843                           fc.numeric_column('a')])
1844
1845  def test_one_column(self):
1846    price = fc.numeric_column('price')
1847    with ops.Graph().as_default():
1848      features = {'price': [[1.], [5.]]}
1849      net = fc.input_layer(features, [price])
1850      with _initialized_session():
1851        self.assertAllClose([[1.], [5.]], net.eval())
1852
1853  def test_multi_dimension(self):
1854    price = fc.numeric_column('price', shape=2)
1855    with ops.Graph().as_default():
1856      features = {'price': [[1., 2.], [5., 6.]]}
1857      net = fc.input_layer(features, [price])
1858      with _initialized_session():
1859        self.assertAllClose([[1., 2.], [5., 6.]], net.eval())
1860
1861  def test_raises_if_shape_mismatch(self):
1862    price = fc.numeric_column('price', shape=2)
1863    with ops.Graph().as_default():
1864      features = {'price': [[1.], [5.]]}
1865      if ops._USE_C_API:
1866        with self.assertRaisesRegexp(
1867            Exception,
1868            r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
1869          net = fc.input_layer(features, [price])
1870      else:
1871        net = fc.input_layer(features, [price])
1872        with _initialized_session():
1873          with self.assertRaisesRegexp(Exception, 'requested shape has 4'):
1874            net.eval()
1875
1876  def test_reshaping(self):
1877    price = fc.numeric_column('price', shape=[1, 2])
1878    with ops.Graph().as_default():
1879      features = {'price': [[[1., 2.]], [[5., 6.]]]}
1880      net = fc.input_layer(features, [price])
1881      with _initialized_session():
1882        self.assertAllClose([[1., 2.], [5., 6.]], net.eval())
1883
1884  def test_multi_column(self):
1885    price1 = fc.numeric_column('price1', shape=2)
1886    price2 = fc.numeric_column('price2')
1887    with ops.Graph().as_default():
1888      features = {
1889          'price1': [[1., 2.], [5., 6.]],
1890          'price2': [[3.], [4.]]
1891      }
1892      net = fc.input_layer(features, [price1, price2])
1893      with _initialized_session():
1894        self.assertAllClose([[1., 2., 3.], [5., 6., 4.]], net.eval())
1895
1896  def test_fills_cols_to_vars(self):
1897    # Provide three _DenseColumn's to input_layer: a _NumericColumn, a
1898    # _BucketizedColumn, and an _EmbeddingColumn.  Only the _EmbeddingColumn
1899    # creates a Variable.
1900    price1 = fc.numeric_column('price1')
1901    dense_feature = fc.numeric_column('dense_feature')
1902    dense_feature_bucketized = fc.bucketized_column(
1903        dense_feature, boundaries=[0.])
1904    some_sparse_column = fc.categorical_column_with_hash_bucket(
1905        'sparse_feature', hash_bucket_size=5)
1906    some_embedding_column = fc.embedding_column(
1907        some_sparse_column, dimension=10)
1908    with ops.Graph().as_default():
1909      features = {
1910          'price1': [[3.], [4.]],
1911          'dense_feature': [[-1.], [4.]],
1912          'sparse_feature': [['a'], ['x']],
1913      }
1914      cols_to_vars = {}
1915      all_cols = [price1, dense_feature_bucketized, some_embedding_column]
1916      fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
1917      self.assertItemsEqual(list(cols_to_vars.keys()), all_cols)
1918      self.assertEqual(0, len(cols_to_vars[price1]))
1919      self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
1920      self.assertEqual(1, len(cols_to_vars[some_embedding_column]))
1921      self.assertIsInstance(cols_to_vars[some_embedding_column][0],
1922                            variables_lib.Variable)
1923      self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [5, 10])
1924
1925  def test_fills_cols_to_vars_partitioned_variables(self):
1926    price1 = fc.numeric_column('price1')
1927    dense_feature = fc.numeric_column('dense_feature')
1928    dense_feature_bucketized = fc.bucketized_column(
1929        dense_feature, boundaries=[0.])
1930    some_sparse_column = fc.categorical_column_with_hash_bucket(
1931        'sparse_feature', hash_bucket_size=5)
1932    some_embedding_column = fc.embedding_column(
1933        some_sparse_column, dimension=10)
1934    with ops.Graph().as_default():
1935      features = {
1936          'price1': [[3.], [4.]],
1937          'dense_feature': [[-1.], [4.]],
1938          'sparse_feature': [['a'], ['x']],
1939      }
1940      cols_to_vars = {}
1941      all_cols = [price1, dense_feature_bucketized, some_embedding_column]
1942      with variable_scope.variable_scope(
1943          'input_from_feature_columns',
1944          partitioner=partitioned_variables.fixed_size_partitioner(3, axis=0)):
1945        fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
1946      self.assertItemsEqual(list(cols_to_vars.keys()), all_cols)
1947      self.assertEqual(0, len(cols_to_vars[price1]))
1948      self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
1949      self.assertEqual(3, len(cols_to_vars[some_embedding_column]))
1950      self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [2, 10])
1951      self.assertAllEqual(cols_to_vars[some_embedding_column][1].shape, [2, 10])
1952      self.assertAllEqual(cols_to_vars[some_embedding_column][2].shape, [1, 10])
1953
1954  def test_column_order(self):
1955    price_a = fc.numeric_column('price_a')
1956    price_b = fc.numeric_column('price_b')
1957    with ops.Graph().as_default():
1958      features = {
1959          'price_a': [[1.]],
1960          'price_b': [[3.]],
1961      }
1962      net1 = fc.input_layer(features, [price_a, price_b])
1963      net2 = fc.input_layer(features, [price_b, price_a])
1964      with _initialized_session():
1965        self.assertAllClose([[1., 3.]], net1.eval())
1966        self.assertAllClose([[1., 3.]], net2.eval())
1967
1968  def test_fails_for_categorical_column(self):
1969    animal = fc.categorical_column_with_identity('animal', num_buckets=4)
1970    with ops.Graph().as_default():
1971      features = {
1972          'animal':
1973              sparse_tensor.SparseTensor(
1974                  indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
1975      }
1976      with self.assertRaisesRegexp(Exception, 'must be a _DenseColumn'):
1977        fc.input_layer(features, [animal])
1978
1979  def test_static_batch_size_mismatch(self):
1980    price1 = fc.numeric_column('price1')
1981    price2 = fc.numeric_column('price2')
1982    with ops.Graph().as_default():
1983      features = {
1984          'price1': [[1.], [5.], [7.]],  # batchsize = 3
1985          'price2': [[3.], [4.]]  # batchsize = 2
1986      }
1987      with self.assertRaisesRegexp(
1988          ValueError,
1989          'Batch size \(first dimension\) of each feature must be same.'):  # pylint: disable=anomalous-backslash-in-string
1990        fc.input_layer(features, [price1, price2])
1991
1992  def test_subset_of_static_batch_size_mismatch(self):
1993    price1 = fc.numeric_column('price1')
1994    price2 = fc.numeric_column('price2')
1995    price3 = fc.numeric_column('price3')
1996    with ops.Graph().as_default():
1997      features = {
1998          'price1': array_ops.placeholder(dtype=dtypes.int64),  # batchsize = 3
1999          'price2': [[3.], [4.]],  # batchsize = 2
2000          'price3': [[3.], [4.], [5.]]  # batchsize = 3
2001      }
2002      with self.assertRaisesRegexp(
2003          ValueError,
2004          'Batch size \(first dimension\) of each feature must be same.'):  # pylint: disable=anomalous-backslash-in-string
2005        fc.input_layer(features, [price1, price2, price3])
2006
2007  def test_runtime_batch_size_mismatch(self):
2008    price1 = fc.numeric_column('price1')
2009    price2 = fc.numeric_column('price2')
2010    with ops.Graph().as_default():
2011      features = {
2012          'price1': array_ops.placeholder(dtype=dtypes.int64),  # batchsize = 3
2013          'price2': [[3.], [4.]]  # batchsize = 2
2014      }
2015      net = fc.input_layer(features, [price1, price2])
2016      with _initialized_session() as sess:
2017        with self.assertRaisesRegexp(errors.OpError,
2018                                     'Dimensions of inputs should match'):
2019          sess.run(net, feed_dict={features['price1']: [[1.], [5.], [7.]]})
2020
2021  def test_runtime_batch_size_matches(self):
2022    price1 = fc.numeric_column('price1')
2023    price2 = fc.numeric_column('price2')
2024    with ops.Graph().as_default():
2025      features = {
2026          'price1': array_ops.placeholder(dtype=dtypes.int64),  # batchsize = 2
2027          'price2': array_ops.placeholder(dtype=dtypes.int64),  # batchsize = 2
2028      }
2029      net = fc.input_layer(features, [price1, price2])
2030      with _initialized_session() as sess:
2031        sess.run(
2032            net,
2033            feed_dict={
2034                features['price1']: [[1.], [5.]],
2035                features['price2']: [[1.], [5.]],
2036            })
2037
2038  def test_with_numpy_input_fn(self):
2039    embedding_values = (
2040        (1., 2., 3., 4., 5.),  # id 0
2041        (6., 7., 8., 9., 10.),  # id 1
2042        (11., 12., 13., 14., 15.)  # id 2
2043    )
2044    def _initializer(shape, dtype, partition_info):
2045      del shape, dtype, partition_info
2046      return embedding_values
2047
2048    # price has 1 dimension in input_layer
2049    price = fc.numeric_column('price')
2050    body_style = fc.categorical_column_with_vocabulary_list(
2051        'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
2052    # one_hot_body_style has 3 dims in input_layer.
2053    one_hot_body_style = fc.indicator_column(body_style)
2054    # embedded_body_style has 5 dims in input_layer.
2055    embedded_body_style = fc.embedding_column(body_style, dimension=5,
2056                                              initializer=_initializer)
2057
2058    input_fn = numpy_io.numpy_input_fn(
2059        x={
2060            'price': np.array([11., 12., 13., 14.]),
2061            'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),
2062        },
2063        batch_size=2,
2064        shuffle=False)
2065    features = input_fn()
2066    net = fc.input_layer(features,
2067                         [price, one_hot_body_style, embedded_body_style])
2068    self.assertEqual(1 + 3 + 5, net.shape[1])
2069    with _initialized_session() as sess:
2070      coord = coordinator.Coordinator()
2071      threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
2072
2073      # Each row is formed by concatenating `embedded_body_style`,
2074      # `one_hot_body_style`, and `price` in order.
2075      self.assertAllEqual(
2076          [[11., 12., 13., 14., 15., 0., 0., 1., 11.],
2077           [1., 2., 3., 4., 5., 1., 0., 0., 12]],
2078          sess.run(net))
2079
2080      coord.request_stop()
2081      coord.join(threads)
2082
2083  def test_with_1d_sparse_tensor(self):
2084    embedding_values = (
2085        (1., 2., 3., 4., 5.),  # id 0
2086        (6., 7., 8., 9., 10.),  # id 1
2087        (11., 12., 13., 14., 15.)  # id 2
2088    )
2089    def _initializer(shape, dtype, partition_info):
2090      del shape, dtype, partition_info
2091      return embedding_values
2092
2093    # price has 1 dimension in input_layer
2094    price = fc.numeric_column('price')
2095
2096    # one_hot_body_style has 3 dims in input_layer.
2097    body_style = fc.categorical_column_with_vocabulary_list(
2098        'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
2099    one_hot_body_style = fc.indicator_column(body_style)
2100
2101    # embedded_body_style has 5 dims in input_layer.
2102    country = fc.categorical_column_with_vocabulary_list(
2103        'country', vocabulary_list=['US', 'JP', 'CA'])
2104    embedded_country = fc.embedding_column(country, dimension=5,
2105                                           initializer=_initializer)
2106
2107    # Provides 1-dim tensor and dense tensor.
2108    features = {
2109        'price': constant_op.constant([11., 12.,]),
2110        'body-style': sparse_tensor.SparseTensor(
2111            indices=((0,), (1,)),
2112            values=('sedan', 'hardtop'),
2113            dense_shape=(2,)),
2114        # This is dense tensor for the categorical_column.
2115        'country': constant_op.constant(['CA', 'US']),
2116    }
2117    self.assertEqual(1, features['price'].shape.ndims)
2118    self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
2119    self.assertEqual(1, features['country'].shape.ndims)
2120
2121    net = fc.input_layer(features,
2122                         [price, one_hot_body_style, embedded_country])
2123    self.assertEqual(1 + 3 + 5, net.shape[1])
2124    with _initialized_session() as sess:
2125
2126      # Each row is formed by concatenating `embedded_body_style`,
2127      # `one_hot_body_style`, and `price` in order.
2128      self.assertAllEqual(
2129          [[0., 0., 1., 11., 12., 13., 14., 15., 11.],
2130           [1., 0., 0., 1., 2., 3., 4., 5., 12.]],
2131          sess.run(net))
2132
2133  def test_with_1d_unknown_shape_sparse_tensor(self):
2134    embedding_values = (
2135        (1., 2.),  # id 0
2136        (6., 7.),  # id 1
2137        (11., 12.)  # id 2
2138    )
2139    def _initializer(shape, dtype, partition_info):
2140      del shape, dtype, partition_info
2141      return embedding_values
2142
2143    # price has 1 dimension in input_layer
2144    price = fc.numeric_column('price')
2145
2146    # one_hot_body_style has 3 dims in input_layer.
2147    body_style = fc.categorical_column_with_vocabulary_list(
2148        'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
2149    one_hot_body_style = fc.indicator_column(body_style)
2150
2151    # embedded_body_style has 5 dims in input_layer.
2152    country = fc.categorical_column_with_vocabulary_list(
2153        'country', vocabulary_list=['US', 'JP', 'CA'])
2154    embedded_country = fc.embedding_column(
2155        country, dimension=2, initializer=_initializer)
2156
2157    # Provides 1-dim tensor and dense tensor.
2158    features = {
2159        'price': array_ops.placeholder(dtypes.float32),
2160        'body-style': array_ops.sparse_placeholder(dtypes.string),
2161        # This is dense tensor for the categorical_column.
2162        'country': array_ops.placeholder(dtypes.string),
2163    }
2164    self.assertIsNone(features['price'].shape.ndims)
2165    self.assertIsNone(features['body-style'].get_shape().ndims)
2166    self.assertIsNone(features['country'].shape.ndims)
2167
2168    price_data = np.array([11., 12.])
2169    body_style_data = sparse_tensor.SparseTensorValue(
2170        indices=((0,), (1,)),
2171        values=('sedan', 'hardtop'),
2172        dense_shape=(2,))
2173    country_data = np.array([['US'], ['CA']])
2174
2175    net = fc.input_layer(features,
2176                         [price, one_hot_body_style, embedded_country])
2177    self.assertEqual(1 + 3 + 2, net.shape[1])
2178    with _initialized_session() as sess:
2179
2180      # Each row is formed by concatenating `embedded_body_style`,
2181      # `one_hot_body_style`, and `price` in order.
2182      self.assertAllEqual(
2183          [[0., 0., 1., 1., 2., 11.], [1., 0., 0., 11., 12., 12.]],
2184          sess.run(
2185              net,
2186              feed_dict={
2187                  features['price']: price_data,
2188                  features['body-style']: body_style_data,
2189                  features['country']: country_data
2190              }))
2191
2192  def test_with_rank_0_feature(self):
2193    # price has 1 dimension in input_layer
2194    price = fc.numeric_column('price')
2195    features = {
2196        'price': constant_op.constant(0),
2197    }
2198    self.assertEqual(0, features['price'].shape.ndims)
2199
2200    # Static rank 0 should fail
2201    with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
2202      fc.input_layer(features, [price])
2203
2204    # Dynamic rank 0 should fail
2205    features = {
2206        'price': array_ops.placeholder(dtypes.float32),
2207    }
2208    net = fc.input_layer(features, [price])
2209    self.assertEqual(1, net.shape[1])
2210    with _initialized_session() as sess:
2211      with self.assertRaisesOpError('Feature .* cannot have rank 0'):
2212        sess.run(net, feed_dict={features['price']: np.array(1)})
2213
2214
class MakeParseExampleSpecTest(test.TestCase):
  """Tests for `fc.make_parse_example_spec`."""

  class _TestFeatureColumn(_FeatureColumn,
                           collections.namedtuple('_TestFeatureColumn',
                                                  ['parse_spec'])):
    """Stub column that reports the dict it was built with as its spec."""

    @property
    def _parse_example_spec(self):
      return self.parse_spec

  def test_no_feature_columns(self):
    """An empty list of columns yields an empty spec."""
    self.assertDictEqual({}, fc.make_parse_example_spec([]))

  def test_invalid_type(self):
    """A non-column entry (a plain string here) raises ValueError."""
    float_spec = parsing_ops.FixedLenFeature(
        shape=(2,), dtype=dtypes.float32, default_value=0.)
    with self.assertRaisesRegexp(
        ValueError,
        'All feature_columns must be _FeatureColumn instances.*invalid_column'):
      fc.make_parse_example_spec((
          self._TestFeatureColumn({'key1': float_spec}),
          'invalid_column',
      ))

  def test_one_feature_column(self):
    """A single column's spec is returned unchanged."""
    float_spec = parsing_ops.FixedLenFeature(
        shape=(2,), dtype=dtypes.float32, default_value=0.)
    only_column = self._TestFeatureColumn({'key1': float_spec})
    result = fc.make_parse_example_spec((only_column,))
    self.assertDictEqual({'key1': float_spec}, result)

  def test_two_feature_columns(self):
    """Specs of two columns with distinct keys are merged."""
    float_spec = parsing_ops.FixedLenFeature(
        shape=(2,), dtype=dtypes.float32, default_value=0.)
    string_spec = parsing_ops.VarLenFeature(dtype=dtypes.string)
    columns = (
        self._TestFeatureColumn({'key1': float_spec}),
        self._TestFeatureColumn({'key2': string_spec}),
    )
    result = fc.make_parse_example_spec(columns)
    self.assertDictEqual({'key1': float_spec, 'key2': string_spec}, result)

  def test_equal_keys_different_parse_spec(self):
    """Two columns mapping one key to different specs raise ValueError."""
    float_spec = parsing_ops.FixedLenFeature(
        shape=(2,), dtype=dtypes.float32, default_value=0.)
    string_spec = parsing_ops.VarLenFeature(dtype=dtypes.string)
    with self.assertRaisesRegexp(
        ValueError,
        'feature_columns contain different parse_spec for key key1'):
      fc.make_parse_example_spec((
          self._TestFeatureColumn({'key1': float_spec}),
          self._TestFeatureColumn({'key1': string_spec}),
      ))

  def test_equal_keys_equal_parse_spec(self):
    """Two columns mapping one key to the same spec are accepted."""
    float_spec = parsing_ops.FixedLenFeature(
        shape=(2,), dtype=dtypes.float32, default_value=0.)
    columns = (
        self._TestFeatureColumn({'key1': float_spec}),
        self._TestFeatureColumn({'key1': float_spec}),
    )
    result = fc.make_parse_example_spec(columns)
    self.assertDictEqual({'key1': float_spec}, result)

  def test_multiple_features_dict(self):
    """parse_spec for one column is a dict with length > 1."""
    float_spec = parsing_ops.FixedLenFeature(
        shape=(2,), dtype=dtypes.float32, default_value=0.)
    string_spec = parsing_ops.VarLenFeature(dtype=dtypes.string)
    int_spec = parsing_ops.VarLenFeature(dtype=dtypes.int32)
    columns = (
        self._TestFeatureColumn({'key1': float_spec}),
        self._TestFeatureColumn({'key2': string_spec, 'key3': int_spec}),
    )
    result = fc.make_parse_example_spec(columns)
    expected = {'key1': float_spec, 'key2': string_spec, 'key3': int_spec}
    self.assertDictEqual(expected, result)
2293
2294
def _assert_sparse_tensor_value(test_case, expected, actual):
  """Asserts that `actual` matches the `expected` sparse tensor value.

  Args:
    test_case: Object providing `assertEqual` and `assertAllEqual`.
    expected: Object exposing `indices`, `values` and `dense_shape`.
    actual: Object exposing the same three attributes, checked against
      `expected`.
  """
  # Indices: must be int64 and equal to the expected indices.
  indices_dtype = np.array(actual.indices).dtype
  test_case.assertEqual(np.int64, indices_dtype)
  test_case.assertAllEqual(expected.indices, actual.indices)

  # Values: dtype must agree with the expected values, then the contents.
  expected_values_dtype = np.array(expected.values).dtype
  actual_values_dtype = np.array(actual.values).dtype
  test_case.assertEqual(expected_values_dtype, actual_values_dtype)
  test_case.assertAllEqual(expected.values, actual.values)

  # Dense shape: must be int64 and equal to the expected shape.
  shape_dtype = np.array(actual.dense_shape).dtype
  test_case.assertEqual(np.int64, shape_dtype)
  test_case.assertAllEqual(expected.dense_shape, actual.dense_shape)
2305
2306
class VocabularyFileCategoricalColumnTest(test.TestCase):
  """Tests for `fc.categorical_column_with_vocabulary_file`."""

  def setUp(self):
    """Records the paths and entry counts of the two vocabulary fixtures."""
    super(VocabularyFileCategoricalColumnTest, self).setUp()

    # Contains ints, Golden State Warriors jersey numbers: 30, 35, 11, 23, 22
    self._warriors_vocabulary_file_name = test.test_src_dir_path(
        'python/feature_column/testdata/warriors_vocabulary.txt')
    self._warriors_vocabulary_size = 5

    # Contains strings, character names from 'The Wire': omar, stringer, marlo
    self._wire_vocabulary_file_name = test.test_src_dir_path(
        'python/feature_column/testdata/wire_vocabulary.txt')
    self._wire_vocabulary_size = 3

  def test_defaults(self):
    """Checks name, key, scope name, bucket count and string parse spec."""
    column = fc.categorical_column_with_vocabulary_file(
        key='aaa', vocabulary_file='path_to_file', vocabulary_size=3)
    self.assertEqual('aaa', column.name)
    self.assertEqual('aaa', column._var_scope_name)
    self.assertEqual('aaa', column.key)
    self.assertEqual(3, column._num_buckets)
    self.assertEqual({
        'aaa': parsing_ops.VarLenFeature(dtypes.string)
    }, column._parse_example_spec)

  def test_all_constructor_args(self):
    """Checks that OOV buckets add to the bucket count and dtype is kept."""
    column = fc.categorical_column_with_vocabulary_file(
        key='aaa', vocabulary_file='path_to_file', vocabulary_size=3,
        num_oov_buckets=4, dtype=dtypes.int32)
    # 3 vocabulary entries + 4 out-of-vocabulary buckets.
    self.assertEqual(7, column._num_buckets)
    self.assertEqual({
        'aaa': parsing_ops.VarLenFeature(dtypes.int32)
    }, column._parse_example_spec)

  def test_deep_copy(self):
    """Checks that a deep copy reports the same properties as the original."""
    original = fc.categorical_column_with_vocabulary_file(
        key='aaa', vocabulary_file='path_to_file', vocabulary_size=3,
        num_oov_buckets=4, dtype=dtypes.int32)
    for column in (original, copy.deepcopy(original)):
      self.assertEqual('aaa', column.name)
      self.assertEqual(7, column._num_buckets)
      self.assertEqual({
          'aaa': parsing_ops.VarLenFeature(dtypes.int32)
      }, column._parse_example_spec)

  def test_vocabulary_file_none(self):
    """A `None` vocabulary_file is rejected at construction."""
    with self.assertRaisesRegexp(ValueError, 'Missing vocabulary_file'):
      fc.categorical_column_with_vocabulary_file(
          key='aaa', vocabulary_file=None, vocabulary_size=3)

  def test_vocabulary_file_empty_string(self):
    """An empty-string vocabulary_file is rejected at construction."""
    with self.assertRaisesRegexp(ValueError, 'Missing vocabulary_file'):
      fc.categorical_column_with_vocabulary_file(
          key='aaa', vocabulary_file='', vocabulary_size=3)

  def test_invalid_vocabulary_file(self):
    """A nonexistent file is only reported when tables are initialized."""
    column = fc.categorical_column_with_vocabulary_file(
        key='aaa', vocabulary_file='file_does_not_exist', vocabulary_size=10)
    inputs = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=('marlo', 'skywalker', 'omar'),
        dense_shape=(2, 2))
    # Building the lookup is expected to succeed; the failure is expected
    # from the table initializer below.
    column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
    with self.assertRaisesRegexp(errors.OpError, 'file_does_not_exist'):
      with self.test_session():
        lookup_ops.tables_initializer().run()

  def test_invalid_vocabulary_size(self):
    """Negative and zero vocabulary sizes are rejected at construction."""
    with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'):
      fc.categorical_column_with_vocabulary_file(
          key='aaa', vocabulary_file=self._wire_vocabulary_file_name,
          vocabulary_size=-1)
    with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'):
      fc.categorical_column_with_vocabulary_file(
          key='aaa', vocabulary_file=self._wire_vocabulary_file_name,
          vocabulary_size=0)

  def test_too_large_vocabulary_size(self):
    """A size above the file's entry count fails at table initialization."""
    column = fc.categorical_column_with_vocabulary_file(
        key='aaa',
        vocabulary_file=self._wire_vocabulary_file_name,
        vocabulary_size=self._wire_vocabulary_size + 1)
    inputs = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=('marlo', 'skywalker', 'omar'),
        dense_shape=(2, 2))
    column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
    with self.assertRaisesRegexp(errors.OpError, 'Invalid vocab_size'):
      with self.test_session():
        lookup_ops.tables_initializer().run()

  def test_invalid_num_oov_buckets(self):
    """A negative num_oov_buckets is rejected at construction."""
    with self.assertRaisesRegexp(ValueError, 'Invalid num_oov_buckets'):
      fc.categorical_column_with_vocabulary_file(
          key='aaa', vocabulary_file='path', vocabulary_size=3,
          num_oov_buckets=-1)

  def test_invalid_dtype(self):
    """A float dtype is rejected at construction."""
    with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
      fc.categorical_column_with_vocabulary_file(
          key='aaa', vocabulary_file='path', vocabulary_size=3,
          dtype=dtypes.float64)

  def test_invalid_buckets_and_default_value(self):
    """Passing both num_oov_buckets and default_value is rejected."""
    with self.assertRaisesRegexp(
        ValueError, 'both num_oov_buckets and default_value'):
      fc.categorical_column_with_vocabulary_file(
          key='aaa',
          vocabulary_file=self._wire_vocabulary_file_name,
          vocabulary_size=self._wire_vocabulary_size,
          num_oov_buckets=100,
          default_value=2)

  def test_invalid_input_dtype_int32(self):
    """Integer inputs fed to a string-dtype column raise ValueError."""
    column = fc.categorical_column_with_vocabulary_file(
        key='aaa',
        vocabulary_file=self._wire_vocabulary_file_name,
        vocabulary_size=self._wire_vocabulary_size,
        dtype=dtypes.string)
    inputs = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=(12, 24, 36),
        dense_shape=(2, 2))
    with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
      column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))

  def test_invalid_input_dtype_string(self):
    """String inputs fed to an int32-dtype column raise ValueError."""
    column = fc.categorical_column_with_vocabulary_file(
        key='aaa',
        vocabulary_file=self._warriors_vocabulary_file_name,
        vocabulary_size=self._warriors_vocabulary_size,
        dtype=dtypes.int32)
    inputs = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=('omar', 'stringer', 'marlo'),
        dense_shape=(2, 2))
    with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
      column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))

  def test_parse_example(self):
    """Parses a serialized Example with the spec derived from the column."""
    a = fc.categorical_column_with_vocabulary_file(
        key='aaa', vocabulary_file='path_to_file', vocabulary_size=3)
    data = example_pb2.Example(features=feature_pb2.Features(
        feature={
            'aaa':
                feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
                    value=[b'omar', b'stringer']))
        }))
    features = parsing_ops.parse_example(
        serialized=[data.SerializeToString()],
        features=fc.make_parse_example_spec([a]))
    self.assertIn('aaa', features)
    with self.test_session():
      _assert_sparse_tensor_value(
          self,
          sparse_tensor.SparseTensorValue(
              indices=[[0, 0], [0, 1]],
              values=np.array([b'omar', b'stringer'], dtype=np.object_),
              dense_shape=[1, 2]),
          features['aaa'].eval())

  def test_get_sparse_tensors(self):
    """Checks the ids produced for in-vocabulary and unknown strings."""
    column = fc.categorical_column_with_vocabulary_file(
        key='aaa',
        vocabulary_file=self._wire_vocabulary_file_name,
        vocabulary_size=self._wire_vocabulary_size)
    inputs = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=('marlo', 'skywalker', 'omar'),
        dense_shape=(2, 2))
    id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
    self.assertIsNone(id_weight_pair.weight_tensor)
    with _initialized_session():
      # Expected ids: 'marlo' -> 2, 'skywalker' (not in file) -> -1,
      # 'omar' -> 0.
      _assert_sparse_tensor_value(
          self,
          sparse_tensor.SparseTensorValue(
              indices=inputs.indices,
              values=np.array((2, -1, 0), dtype=np.int64),
              dense_shape=inputs.dense_shape),
          id_weight_pair.id_tensor.eval())

  def test_get_sparse_tensors_none_vocabulary_size(self):
    """Same lookup as above, but with vocabulary_size left unspecified."""
    column = fc.categorical_column_with_vocabulary_file(
        key='aaa', vocabulary_file=self._wire_vocabulary_file_name)
    inputs = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=('marlo', 'skywalker', 'omar'),
        dense_shape=(2, 2))
    id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
    self.assertIsNone(id_weight_pair.weight_tensor)
    with _initialized_session():
      _assert_sparse_tensor_value(self,
                                  sparse_tensor.SparseTensorValue(
                                      indices=inputs.indices,
                                      values=np.array(
                                          (2, -1, 0), dtype=np.int64),
                                      dense_shape=inputs.dense_shape),
                                  id_weight_pair.id_tensor.eval())

  def test_transform_feature(self):
    """Checks the id tensor obtained through `_transform_features`."""
    column = fc.categorical_column_with_vocabulary_file(
        key='aaa',
        vocabulary_file=self._wire_vocabulary_file_name,
        vocabulary_size=self._wire_vocabulary_size)
    inputs = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=('marlo', 'skywalker', 'omar'),
        dense_shape=(2, 2))
    id_tensor = _transform_features({'aaa': inputs}, [column])[column]
    with _initialized_session():
      _assert_sparse_tensor_value(self,
                                  sparse_tensor.SparseTensorValue(
                                      indices=inputs.indices,
                                      values=np.array(
                                          (2, -1, 0), dtype=np.int64),
                                      dense_shape=inputs.dense_shape),
                                  id_tensor.eval())

  def test_get_sparse_tensors_weight_collections(self):
    """Checks that the lookup adds no variables to any collection."""
    column = fc.categorical_column_with_vocabulary_file(
        key='aaa',
        vocabulary_file=self._wire_vocabulary_file_name,
        vocabulary_size=self._wire_vocabulary_size)
    inputs = sparse_tensor.SparseTensor(
        values=['omar', 'stringer', 'marlo'],
        indices=[[0, 0], [1, 0], [1, 1]],
        dense_shape=[2, 2])
    column._get_sparse_tensors(
        _LazyBuilder({
            'aaa': inputs
        }), weight_collections=('my_weights',))

    # Neither the global variables nor the requested collection gain entries.
    self.assertItemsEqual(
        [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
    self.assertItemsEqual([], ops.get_collection('my_weights'))

  def test_get_sparse_tensors_dense_input(self):
    """Checks the sparse ids produced from a dense string input."""
    column = fc.categorical_column_with_vocabulary_file(
        key='aaa',
        vocabulary_file=self._wire_vocabulary_file_name,
        vocabulary_size=self._wire_vocabulary_size)
    id_weight_pair = column._get_sparse_tensors(
        _LazyBuilder({
            'aaa': (('marlo', ''), ('skywalker', 'omar'))
        }))
    self.assertIsNone(id_weight_pair.weight_tensor)
    with _initialized_session():
      # The '' entry at position (0, 1) has no counterpart in the expected
      # sparse result.
      _assert_sparse_tensor_value(
          self,
          sparse_tensor.SparseTensorValue(
              indices=((0, 0), (1, 0), (1, 1)),
              values=np.array((2, -1, 0), dtype=np.int64),
              dense_shape=(2, 2)),
          id_weight_pair.id_tensor.eval())

  def test_get_sparse_tensors_default_value_in_vocabulary(self):
    """Checks that unknown strings map to the configured default_value."""
    column = fc.categorical_column_with_vocabulary_file(
        key='aaa',
        vocabulary_file=self._wire_vocabulary_file_name,
        vocabulary_size=self._wire_vocabulary_size,
        default_value=2)
    inputs = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=('marlo', 'skywalker', 'omar'),
        dense_shape=(2, 2))
    id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
    self.assertIsNone(id_weight_pair.weight_tensor)
    with _initialized_session():
      # 'skywalker' is expected to receive the default id 2.
      _assert_sparse_tensor_value(
          self,
          sparse_tensor.SparseTensorValue(
              indices=inputs.indices,
              values=np.array((2, 2, 0), dtype=np.int64),
              dense_shape=inputs.dense_shape),
          id_weight_pair.id_tensor.eval())

  def test_get_sparse_tensors_with_oov_buckets(self):
    """Checks the ids assigned to unknown strings with 100 OOV buckets."""
    column = fc.categorical_column_with_vocabulary_file(
        key='aaa',
        vocabulary_file=self._wire_vocabulary_file_name,
        vocabulary_size=self._wire_vocabulary_size,
        num_oov_buckets=100)
    inputs = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1), (1, 2)),
        values=('marlo', 'skywalker', 'omar', 'heisenberg'),
        dense_shape=(2, 3))
    id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
    self.assertIsNone(id_weight_pair.weight_tensor)
    with _initialized_session():
      # 'skywalker' and 'heisenberg' are expected at ids 33 and 62.
      _assert_sparse_tensor_value(
          self,
          sparse_tensor.SparseTensorValue(
              indices=inputs.indices,
              values=np.array((2, 33, 0, 62), dtype=np.int64),
              dense_shape=inputs.dense_shape),
          id_weight_pair.id_tensor.eval())

  def test_get_sparse_tensors_small_vocabulary_size(self):
    """Checks that entries beyond vocabulary_size are treated as unknown."""
    # 'marlo' is the last entry in our vocabulary file, so by setting
    # `vocabulary_size` to 1 less than number of entries in file, we take
    # 'marlo' out of the vocabulary.
    column = fc.categorical_column_with_vocabulary_file(
        key='aaa',
        vocabulary_file=self._wire_vocabulary_file_name,
        vocabulary_size=self._wire_vocabulary_size - 1)
    inputs = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=('marlo', 'skywalker', 'omar'),
        dense_shape=(2, 2))
    id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
    self.assertIsNone(id_weight_pair.weight_tensor)
    with _initialized_session():
      _assert_sparse_tensor_value(
          self,
          sparse_tensor.SparseTensorValue(
              indices=inputs.indices,
              values=np.array((-1, -1, 0), dtype=np.int64),
              dense_shape=inputs.dense_shape),
          id_weight_pair.id_tensor.eval())

  def test_get_sparse_tensors_int32(self):
    """Checks the ids produced for integer inputs with an int32 column."""
    column = fc.categorical_column_with_vocabulary_file(
        key='aaa',
        vocabulary_file=self._warriors_vocabulary_file_name,
        vocabulary_size=self._warriors_vocabulary_size,
        dtype=dtypes.int32)
    inputs = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1), (2, 2)),
        values=(11, 100, 30, 22),
        dense_shape=(3, 3))
    id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
    self.assertIsNone(id_weight_pair.weight_tensor)
    with _initialized_session():
      # Expected ids: 11 -> 2, 100 (not in file) -> -1, 30 -> 0, 22 -> 4.
      _assert_sparse_tensor_value(
          self,
          sparse_tensor.SparseTensorValue(
              indices=inputs.indices,
              values=np.array((2, -1, 0, 4), dtype=np.int64),
              dense_shape=inputs.dense_shape),
          id_weight_pair.id_tensor.eval())

  def test_get_sparse_tensors_int32_dense_input(self):
    """Checks the sparse ids produced from a dense integer input."""
    default_value = -100
    column = fc.categorical_column_with_vocabulary_file(
        key='aaa',
        vocabulary_file=self._warriors_vocabulary_file_name,
        vocabulary_size=self._warriors_vocabulary_size,
        dtype=dtypes.int32,
        default_value=default_value)
    id_weight_pair = column._get_sparse_tensors(
        _LazyBuilder({
            'aaa': ((11, -1, -1), (100, 30, -1), (-1, -1, 22))
        }))
    self.assertIsNone(id_weight_pair.weight_tensor)
    with _initialized_session():
      # The -1 entries of the dense input have no counterpart in the
      # expected sparse result; 100 is expected to map to default_value.
      _assert_sparse_tensor_value(
          self,
          sparse_tensor.SparseTensorValue(
              indices=((0, 0), (1, 0), (1, 1), (2, 2)),
              values=np.array((2, default_value, 0, 4), dtype=np.int64),
              dense_shape=(3, 3)),
          id_weight_pair.id_tensor.eval())

  def test_get_sparse_tensors_int32_with_oov_buckets(self):
    """Checks the id assigned to an unknown integer with 100 OOV buckets."""
    column = fc.categorical_column_with_vocabulary_file(
        key='aaa',
        vocabulary_file=self._warriors_vocabulary_file_name,
        vocabulary_size=self._warriors_vocabulary_size,
        dtype=dtypes.int32,
        num_oov_buckets=100)
    inputs = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1), (2, 2)),
        values=(11, 100, 30, 22),
        dense_shape=(3, 3))
    id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
    self.assertIsNone(id_weight_pair.weight_tensor)
    with _initialized_session():
      # The unknown value 100 is expected at id 60.
      _assert_sparse_tensor_value(
          self,
          sparse_tensor.SparseTensorValue(
              indices=inputs.indices,
              values=np.array((2, 60, 0, 4), dtype=np.int64),
              dense_shape=inputs.dense_shape),
          id_weight_pair.id_tensor.eval())

  def test_linear_model(self):
    """Checks linear_model predictions before and after assigning weights."""
    wire_column = fc.categorical_column_with_vocabulary_file(
        key='wire',
        vocabulary_file=self._wire_vocabulary_file_name,
        vocabulary_size=self._wire_vocabulary_size,
        num_oov_buckets=1)
    # 3 vocabulary entries + 1 out-of-vocabulary bucket.
    self.assertEqual(4, wire_column._num_buckets)
    with ops.Graph().as_default():
      predictions = fc.linear_model({
          wire_column.name: sparse_tensor.SparseTensorValue(
              indices=((0, 0), (1, 0), (1, 1)),
              values=('marlo', 'skywalker', 'omar'),
              dense_shape=(2, 2))
      }, (wire_column,))
      bias = get_linear_model_bias()
      wire_var = get_linear_model_column_var(wire_column)
      with _initialized_session():
        # Bias and weights are expected to start at zero, so predictions
        # are zero as well.
        self.assertAllClose((0.,), bias.eval())
        self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
        self.assertAllClose(((0.,), (0.,)), predictions.eval())
        wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
        # 'marlo' -> 2: wire_var[2] = 3
        # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
        self.assertAllClose(((3.,), (5.,)), predictions.eval())
2717
2718
2719class VocabularyListCategoricalColumnTest(test.TestCase):
2720
2721  def test_defaults_string(self):
2722    column = fc.categorical_column_with_vocabulary_list(
2723        key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
2724    self.assertEqual('aaa', column.name)
2725    self.assertEqual('aaa', column.key)
2726    self.assertEqual('aaa', column._var_scope_name)
2727    self.assertEqual(3, column._num_buckets)
2728    self.assertEqual({
2729        'aaa': parsing_ops.VarLenFeature(dtypes.string)
2730    }, column._parse_example_spec)
2731
2732  def test_defaults_int(self):
2733    column = fc.categorical_column_with_vocabulary_list(
2734        key='aaa', vocabulary_list=(12, 24, 36))
2735    self.assertEqual('aaa', column.name)
2736    self.assertEqual('aaa', column.key)
2737    self.assertEqual('aaa', column._var_scope_name)
2738    self.assertEqual(3, column._num_buckets)
2739    self.assertEqual({
2740        'aaa': parsing_ops.VarLenFeature(dtypes.int64)
2741    }, column._parse_example_spec)
2742
2743  def test_all_constructor_args(self):
2744    column = fc.categorical_column_with_vocabulary_list(
2745        key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32,
2746        default_value=-99)
2747    self.assertEqual(3, column._num_buckets)
2748    self.assertEqual({
2749        'aaa': parsing_ops.VarLenFeature(dtypes.int32)
2750    }, column._parse_example_spec)
2751
2752  def test_deep_copy(self):
2753    original = fc.categorical_column_with_vocabulary_list(
2754        key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.int32)
2755    for column in (original, copy.deepcopy(original)):
2756      self.assertEqual('aaa', column.name)
2757      self.assertEqual(3, column._num_buckets)
2758      self.assertEqual({
2759          'aaa': parsing_ops.VarLenFeature(dtypes.int32)
2760      }, column._parse_example_spec)
2761
2762  def test_invalid_dtype(self):
2763    with self.assertRaisesRegexp(ValueError, 'dtype must be string or integer'):
2764      fc.categorical_column_with_vocabulary_list(
2765          key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'),
2766          dtype=dtypes.float32)
2767
2768  def test_invalid_mapping_dtype(self):
2769    with self.assertRaisesRegexp(
2770        ValueError, r'vocabulary dtype must be string or integer'):
2771      fc.categorical_column_with_vocabulary_list(
2772          key='aaa', vocabulary_list=(12., 24., 36.))
2773
2774  def test_mismatched_int_dtype(self):
2775    with self.assertRaisesRegexp(
2776        ValueError, r'dtype.*and vocabulary dtype.*do not match'):
2777      fc.categorical_column_with_vocabulary_list(
2778          key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'),
2779          dtype=dtypes.int32)
2780
2781  def test_mismatched_string_dtype(self):
2782    with self.assertRaisesRegexp(
2783        ValueError, r'dtype.*and vocabulary dtype.*do not match'):
2784      fc.categorical_column_with_vocabulary_list(
2785          key='aaa', vocabulary_list=(12, 24, 36), dtype=dtypes.string)
2786
2787  def test_none_mapping(self):
2788    with self.assertRaisesRegexp(
2789        ValueError, r'vocabulary_list.*must be non-empty'):
2790      fc.categorical_column_with_vocabulary_list(
2791          key='aaa', vocabulary_list=None)
2792
2793  def test_empty_mapping(self):
2794    with self.assertRaisesRegexp(
2795        ValueError, r'vocabulary_list.*must be non-empty'):
2796      fc.categorical_column_with_vocabulary_list(
2797          key='aaa', vocabulary_list=tuple([]))
2798
2799  def test_duplicate_mapping(self):
2800    with self.assertRaisesRegexp(ValueError, 'Duplicate keys'):
2801      fc.categorical_column_with_vocabulary_list(
2802          key='aaa', vocabulary_list=(12, 24, 12))
2803
2804  def test_invalid_num_oov_buckets(self):
2805    with self.assertRaisesRegexp(ValueError, 'Invalid num_oov_buckets'):
2806      fc.categorical_column_with_vocabulary_list(
2807          key='aaa', vocabulary_list=(12, 24, 36),
2808          num_oov_buckets=-1)
2809
2810  def test_invalid_buckets_and_default_value(self):
2811    with self.assertRaisesRegexp(
2812        ValueError, 'both num_oov_buckets and default_value'):
2813      fc.categorical_column_with_vocabulary_list(
2814          key='aaa',
2815          vocabulary_list=(12, 24, 36),
2816          num_oov_buckets=100,
2817          default_value=2)
2818
2819  def test_invalid_input_dtype_int32(self):
2820    column = fc.categorical_column_with_vocabulary_list(
2821        key='aaa',
2822        vocabulary_list=('omar', 'stringer', 'marlo'))
2823    inputs = sparse_tensor.SparseTensorValue(
2824        indices=((0, 0), (1, 0), (1, 1)),
2825        values=(12, 24, 36),
2826        dense_shape=(2, 2))
2827    with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
2828      column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
2829
2830  def test_invalid_input_dtype_string(self):
2831    column = fc.categorical_column_with_vocabulary_list(
2832        key='aaa',
2833        vocabulary_list=(12, 24, 36))
2834    inputs = sparse_tensor.SparseTensorValue(
2835        indices=((0, 0), (1, 0), (1, 1)),
2836        values=('omar', 'stringer', 'marlo'),
2837        dense_shape=(2, 2))
2838    with self.assertRaisesRegexp(ValueError, 'dtype must be compatible'):
2839      column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
2840
2841  def test_parse_example_string(self):
2842    a = fc.categorical_column_with_vocabulary_list(
2843        key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
2844    data = example_pb2.Example(features=feature_pb2.Features(
2845        feature={
2846            'aaa':
2847                feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
2848                    value=[b'omar', b'stringer']))
2849        }))
2850    features = parsing_ops.parse_example(
2851        serialized=[data.SerializeToString()],
2852        features=fc.make_parse_example_spec([a]))
2853    self.assertIn('aaa', features)
2854    with self.test_session():
2855      _assert_sparse_tensor_value(
2856          self,
2857          sparse_tensor.SparseTensorValue(
2858              indices=[[0, 0], [0, 1]],
2859              values=np.array([b'omar', b'stringer'], dtype=np.object_),
2860              dense_shape=[1, 2]),
2861          features['aaa'].eval())
2862
2863  def test_parse_example_int(self):
2864    a = fc.categorical_column_with_vocabulary_list(
2865        key='aaa', vocabulary_list=(11, 21, 31))
2866    data = example_pb2.Example(features=feature_pb2.Features(
2867        feature={
2868            'aaa':
2869                feature_pb2.Feature(int64_list=feature_pb2.Int64List(
2870                    value=[11, 21]))
2871        }))
2872    features = parsing_ops.parse_example(
2873        serialized=[data.SerializeToString()],
2874        features=fc.make_parse_example_spec([a]))
2875    self.assertIn('aaa', features)
2876    with self.test_session():
2877      _assert_sparse_tensor_value(
2878          self,
2879          sparse_tensor.SparseTensorValue(
2880              indices=[[0, 0], [0, 1]],
2881              values=[11, 21],
2882              dense_shape=[1, 2]),
2883          features['aaa'].eval())
2884
2885  def test_get_sparse_tensors(self):
2886    column = fc.categorical_column_with_vocabulary_list(
2887        key='aaa',
2888        vocabulary_list=('omar', 'stringer', 'marlo'))
2889    inputs = sparse_tensor.SparseTensorValue(
2890        indices=((0, 0), (1, 0), (1, 1)),
2891        values=('marlo', 'skywalker', 'omar'),
2892        dense_shape=(2, 2))
2893    id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
2894    self.assertIsNone(id_weight_pair.weight_tensor)
2895    with _initialized_session():
2896      _assert_sparse_tensor_value(
2897          self,
2898          sparse_tensor.SparseTensorValue(
2899              indices=inputs.indices,
2900              values=np.array((2, -1, 0), dtype=np.int64),
2901              dense_shape=inputs.dense_shape),
2902          id_weight_pair.id_tensor.eval())
2903
2904  def test_transform_feature(self):
2905    column = fc.categorical_column_with_vocabulary_list(
2906        key='aaa',
2907        vocabulary_list=('omar', 'stringer', 'marlo'))
2908    inputs = sparse_tensor.SparseTensorValue(
2909        indices=((0, 0), (1, 0), (1, 1)),
2910        values=('marlo', 'skywalker', 'omar'),
2911        dense_shape=(2, 2))
2912    id_tensor = _transform_features({'aaa': inputs}, [column])[column]
2913    with _initialized_session():
2914      _assert_sparse_tensor_value(
2915          self,
2916          sparse_tensor.SparseTensorValue(
2917              indices=inputs.indices,
2918              values=np.array((2, -1, 0), dtype=np.int64),
2919              dense_shape=inputs.dense_shape),
2920          id_tensor.eval())
2921
2922  def test_get_sparse_tensors_weight_collections(self):
2923    column = fc.categorical_column_with_vocabulary_list(
2924        key='aaa',
2925        vocabulary_list=('omar', 'stringer', 'marlo'))
2926    inputs = sparse_tensor.SparseTensor(
2927        values=['omar', 'stringer', 'marlo'],
2928        indices=[[0, 0], [1, 0], [1, 1]],
2929        dense_shape=[2, 2])
2930    column._get_sparse_tensors(
2931        _LazyBuilder({
2932            'aaa': inputs
2933        }), weight_collections=('my_weights',))
2934
2935    self.assertItemsEqual(
2936        [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
2937    self.assertItemsEqual([], ops.get_collection('my_weights'))
2938
2939  def test_get_sparse_tensors_dense_input(self):
2940    column = fc.categorical_column_with_vocabulary_list(
2941        key='aaa',
2942        vocabulary_list=('omar', 'stringer', 'marlo'))
2943    id_weight_pair = column._get_sparse_tensors(
2944        _LazyBuilder({
2945            'aaa': (('marlo', ''), ('skywalker', 'omar'))
2946        }))
2947    self.assertIsNone(id_weight_pair.weight_tensor)
2948    with _initialized_session():
2949      _assert_sparse_tensor_value(
2950          self,
2951          sparse_tensor.SparseTensorValue(
2952              indices=((0, 0), (1, 0), (1, 1)),
2953              values=np.array((2, -1, 0), dtype=np.int64),
2954              dense_shape=(2, 2)),
2955          id_weight_pair.id_tensor.eval())
2956
2957  def test_get_sparse_tensors_default_value_in_vocabulary(self):
2958    column = fc.categorical_column_with_vocabulary_list(
2959        key='aaa',
2960        vocabulary_list=('omar', 'stringer', 'marlo'),
2961        default_value=2)
2962    inputs = sparse_tensor.SparseTensorValue(
2963        indices=((0, 0), (1, 0), (1, 1)),
2964        values=('marlo', 'skywalker', 'omar'),
2965        dense_shape=(2, 2))
2966    id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
2967    self.assertIsNone(id_weight_pair.weight_tensor)
2968    with _initialized_session():
2969      _assert_sparse_tensor_value(
2970          self,
2971          sparse_tensor.SparseTensorValue(
2972              indices=inputs.indices,
2973              values=np.array((2, 2, 0), dtype=np.int64),
2974              dense_shape=inputs.dense_shape),
2975          id_weight_pair.id_tensor.eval())
2976
2977  def test_get_sparse_tensors_with_oov_buckets(self):
2978    column = fc.categorical_column_with_vocabulary_list(
2979        key='aaa',
2980        vocabulary_list=('omar', 'stringer', 'marlo'),
2981        num_oov_buckets=100)
2982    inputs = sparse_tensor.SparseTensorValue(
2983        indices=((0, 0), (1, 0), (1, 1), (1, 2)),
2984        values=('marlo', 'skywalker', 'omar', 'heisenberg'),
2985        dense_shape=(2, 3))
2986    id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
2987    self.assertIsNone(id_weight_pair.weight_tensor)
2988    with _initialized_session():
2989      _assert_sparse_tensor_value(
2990          self,
2991          sparse_tensor.SparseTensorValue(
2992              indices=inputs.indices,
2993              values=np.array((2, 33, 0, 62), dtype=np.int64),
2994              dense_shape=inputs.dense_shape),
2995          id_weight_pair.id_tensor.eval())
2996
2997  def test_get_sparse_tensors_int32(self):
2998    column = fc.categorical_column_with_vocabulary_list(
2999        key='aaa',
3000        vocabulary_list=np.array((30, 35, 11, 23, 22), dtype=np.int32),
3001        dtype=dtypes.int32)
3002    inputs = sparse_tensor.SparseTensorValue(
3003        indices=((0, 0), (1, 0), (1, 1), (2, 2)),
3004        values=np.array((11, 100, 30, 22), dtype=np.int32),
3005        dense_shape=(3, 3))
3006    id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
3007    self.assertIsNone(id_weight_pair.weight_tensor)
3008    with _initialized_session():
3009      _assert_sparse_tensor_value(
3010          self,
3011          sparse_tensor.SparseTensorValue(
3012              indices=inputs.indices,
3013              values=np.array((2, -1, 0, 4), dtype=np.int64),
3014              dense_shape=inputs.dense_shape),
3015          id_weight_pair.id_tensor.eval())
3016
3017  def test_get_sparse_tensors_int32_dense_input(self):
3018    default_value = -100
3019    column = fc.categorical_column_with_vocabulary_list(
3020        key='aaa',
3021        vocabulary_list=np.array((30, 35, 11, 23, 22), dtype=np.int32),
3022        dtype=dtypes.int32,
3023        default_value=default_value)
3024    id_weight_pair = column._get_sparse_tensors(
3025        _LazyBuilder({
3026            'aaa':
3027                np.array(
3028                    ((11, -1, -1), (100, 30, -1), (-1, -1, 22)), dtype=np.int32)
3029        }))
3030    self.assertIsNone(id_weight_pair.weight_tensor)
3031    with _initialized_session():
3032      _assert_sparse_tensor_value(
3033          self,
3034          sparse_tensor.SparseTensorValue(
3035              indices=((0, 0), (1, 0), (1, 1), (2, 2)),
3036              values=np.array((2, default_value, 0, 4), dtype=np.int64),
3037              dense_shape=(3, 3)),
3038          id_weight_pair.id_tensor.eval())
3039
3040  def test_get_sparse_tensors_int32_with_oov_buckets(self):
3041    column = fc.categorical_column_with_vocabulary_list(
3042        key='aaa',
3043        vocabulary_list=np.array((30, 35, 11, 23, 22), dtype=np.int32),
3044        dtype=dtypes.int32,
3045        num_oov_buckets=100)
3046    inputs = sparse_tensor.SparseTensorValue(
3047        indices=((0, 0), (1, 0), (1, 1), (2, 2)),
3048        values=(11, 100, 30, 22),
3049        dense_shape=(3, 3))
3050    id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
3051    self.assertIsNone(id_weight_pair.weight_tensor)
3052    with _initialized_session():
3053      _assert_sparse_tensor_value(
3054          self,
3055          sparse_tensor.SparseTensorValue(
3056              indices=inputs.indices,
3057              values=np.array((2, 60, 0, 4), dtype=np.int64),
3058              dense_shape=inputs.dense_shape),
3059          id_weight_pair.id_tensor.eval())
3060
3061  def test_linear_model(self):
3062    wire_column = fc.categorical_column_with_vocabulary_list(
3063        key='aaa',
3064        vocabulary_list=('omar', 'stringer', 'marlo'),
3065        num_oov_buckets=1)
3066    self.assertEqual(4, wire_column._num_buckets)
3067    with ops.Graph().as_default():
3068      predictions = fc.linear_model({
3069          wire_column.name: sparse_tensor.SparseTensorValue(
3070              indices=((0, 0), (1, 0), (1, 1)),
3071              values=('marlo', 'skywalker', 'omar'),
3072              dense_shape=(2, 2))
3073      }, (wire_column,))
3074      bias = get_linear_model_bias()
3075      wire_var = get_linear_model_column_var(wire_column)
3076      with _initialized_session():
3077        self.assertAllClose((0.,), bias.eval())
3078        self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
3079        self.assertAllClose(((0.,), (0.,)), predictions.eval())
3080        wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
3081        # 'marlo' -> 2: wire_var[2] = 3
3082        # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
3083        self.assertAllClose(((3.,), (5.,)), predictions.eval())
3084
3085
class IdentityCategoricalColumnTest(test.TestCase):
  """Tests for columns built by `fc.categorical_column_with_identity`."""

  def test_constructor(self):
    """Checks name, key, scope name, bucket count and parse spec."""
    identity_column = fc.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    self.assertEqual('aaa', identity_column.name)
    self.assertEqual('aaa', identity_column.key)
    self.assertEqual('aaa', identity_column._var_scope_name)
    self.assertEqual(3, identity_column._num_buckets)
    expected_spec = {'aaa': parsing_ops.VarLenFeature(dtypes.int64)}
    self.assertEqual(expected_spec, identity_column._parse_example_spec)

  def test_deep_copy(self):
    """The original and its deep copy expose the same properties."""
    original = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
    candidates = (original, copy.deepcopy(original))
    for candidate in candidates:
      self.assertEqual('aaa', candidate.name)
      self.assertEqual(3, candidate._num_buckets)
      expected_spec = {'aaa': parsing_ops.VarLenFeature(dtypes.int64)}
      self.assertEqual(expected_spec, candidate._parse_example_spec)

  def test_invalid_num_buckets_zero(self):
    """num_buckets=0 raises ValueError."""
    with self.assertRaisesRegexp(ValueError, 'num_buckets 0 < 1'):
      fc.categorical_column_with_identity(key='aaa', num_buckets=0)

  def test_invalid_num_buckets_negative(self):
    """num_buckets=-1 raises ValueError."""
    with self.assertRaisesRegexp(ValueError, 'num_buckets -1 < 1'):
      fc.categorical_column_with_identity(key='aaa', num_buckets=-1)

  def test_invalid_default_value_too_small(self):
    """A negative default_value raises ValueError."""
    with self.assertRaisesRegexp(ValueError, 'default_value -1 not in range'):
      fc.categorical_column_with_identity(
          key='aaa', num_buckets=3, default_value=-1)

  def test_invalid_default_value_too_big(self):
    """A default_value equal to num_buckets raises ValueError."""
    with self.assertRaisesRegexp(ValueError, 'default_value 3 not in range'):
      fc.categorical_column_with_identity(
          key='aaa', num_buckets=3, default_value=3)

  def test_invalid_input_dtype(self):
    """String inputs raise ValueError."""
    identity_column = fc.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    string_input = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=('omar', 'stringer', 'marlo'),
        dense_shape=(2, 2))
    with self.assertRaisesRegexp(ValueError, 'Invalid input, not integer'):
      identity_column._get_sparse_tensors(
          _LazyBuilder({'aaa': string_input}))

  def test_parse_example(self):
    """Parses a serialized Example that stores int64 values under 'aaa'."""
    identity_column = fc.categorical_column_with_identity(
        key='aaa', num_buckets=30)
    int_feature = feature_pb2.Feature(
        int64_list=feature_pb2.Int64List(value=[11, 21]))
    example = example_pb2.Example(
        features=feature_pb2.Features(feature={'aaa': int_feature}))
    parsed = parsing_ops.parse_example(
        serialized=[example.SerializeToString()],
        features=fc.make_parse_example_spec([identity_column]))
    self.assertIn('aaa', parsed)
    expected = sparse_tensor.SparseTensorValue(
        indices=[[0, 0], [0, 1]],
        values=np.array([11, 21], dtype=np.int64),
        dense_shape=[1, 2])
    with self.test_session():
      _assert_sparse_tensor_value(self, expected, parsed['aaa'].eval())

  def test_get_sparse_tensors(self):
    """In-range integer inputs come back unchanged as int64 ids."""
    identity_column = fc.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    sparse_input = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=(0, 1, 0),
        dense_shape=(2, 2))
    result = identity_column._get_sparse_tensors(
        _LazyBuilder({'aaa': sparse_input}))
    self.assertIsNone(result.weight_tensor)
    expected_ids = sparse_tensor.SparseTensorValue(
        indices=sparse_input.indices,
        values=np.array((0, 1, 0), dtype=np.int64),
        dense_shape=sparse_input.dense_shape)
    with _initialized_session():
      _assert_sparse_tensor_value(self, expected_ids, result.id_tensor.eval())

  def test_transform_feature(self):
    """`_transform_features` yields the same ids for the column."""
    identity_column = fc.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    sparse_input = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=(0, 1, 0),
        dense_shape=(2, 2))
    transformed = _transform_features(
        {'aaa': sparse_input}, [identity_column])
    ids = transformed[identity_column]
    expected_ids = sparse_tensor.SparseTensorValue(
        indices=sparse_input.indices,
        values=np.array((0, 1, 0), dtype=np.int64),
        dense_shape=sparse_input.dense_shape)
    with _initialized_session():
      _assert_sparse_tensor_value(self, expected_ids, ids.eval())

  def test_get_sparse_tensors_weight_collections(self):
    """Passing weight_collections leaves both checked collections empty."""
    identity_column = fc.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    sparse_input = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=(0, 1, 0),
        dense_shape=(2, 2))
    builder = _LazyBuilder({'aaa': sparse_input})
    identity_column._get_sparse_tensors(
        builder, weight_collections=('my_weights',))

    # Neither the global variables nor the requested collection gets entries.
    for collection in (ops.GraphKeys.GLOBAL_VARIABLES, 'my_weights'):
      self.assertItemsEqual([], ops.get_collection(collection))

  def test_get_sparse_tensors_dense_input(self):
    """Dense integer input is converted to sparse ids."""
    identity_column = fc.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    dense_input = ((0, -1), (1, 0))
    result = identity_column._get_sparse_tensors(
        _LazyBuilder({'aaa': dense_input}))
    self.assertIsNone(result.weight_tensor)
    # The -1 entry at position (0, 1) has no counterpart in the expected
    # sparse output.
    expected_ids = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=np.array((0, 1, 0), dtype=np.int64),
        dense_shape=(2, 2))
    with _initialized_session():
      _assert_sparse_tensor_value(self, expected_ids, result.id_tensor.eval())

  def test_get_sparse_tensors_with_inputs_too_small(self):
    """A negative input fails at evaluation time when no default is set."""
    identity_column = fc.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    sparse_input = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=(1, -1, 0),
        dense_shape=(2, 2))
    result = identity_column._get_sparse_tensors(
        _LazyBuilder({'aaa': sparse_input}))
    self.assertIsNone(result.weight_tensor)
    with _initialized_session(), self.assertRaisesRegexp(
        errors.OpError, 'assert_greater_or_equal_0'):
      result.id_tensor.eval()

  def test_get_sparse_tensors_with_inputs_too_big(self):
    """An input >= num_buckets fails at evaluation time without a default."""
    identity_column = fc.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    sparse_input = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=(1, 99, 0),
        dense_shape=(2, 2))
    result = identity_column._get_sparse_tensors(
        _LazyBuilder({'aaa': sparse_input}))
    self.assertIsNone(result.weight_tensor)
    with _initialized_session(), self.assertRaisesRegexp(
        errors.OpError, 'assert_less_than_num_buckets'):
      result.id_tensor.eval()

  def test_get_sparse_tensors_with_default_value(self):
    """Out-of-range inputs are replaced by `default_value`."""
    identity_column = fc.categorical_column_with_identity(
        key='aaa', num_buckets=4, default_value=3)
    sparse_input = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=(1, -1, 99),
        dense_shape=(2, 2))
    result = identity_column._get_sparse_tensors(
        _LazyBuilder({'aaa': sparse_input}))
    self.assertIsNone(result.weight_tensor)
    # 1 is in range and kept; -1 and 99 are both expected to become 3.
    expected_ids = sparse_tensor.SparseTensorValue(
        indices=sparse_input.indices,
        values=np.array((1, 3, 3), dtype=np.int64),
        dense_shape=sparse_input.dense_shape)
    with _initialized_session():
      _assert_sparse_tensor_value(self, expected_ids, result.id_tensor.eval())

  def test_get_sparse_tensors_with_default_value_and_placeholder_inputs(self):
    """Same default-value mapping when the inputs are fed via placeholders."""
    identity_column = fc.categorical_column_with_identity(
        key='aaa', num_buckets=4, default_value=3)
    indices_ph = array_ops.placeholder(dtype=dtypes.int64)
    values_ph = array_ops.placeholder(dtype=dtypes.int32)
    shape_ph = array_ops.placeholder(dtype=dtypes.int64)
    sparse_input = sparse_tensor.SparseTensorValue(
        indices=indices_ph, values=values_ph, dense_shape=shape_ph)
    result = identity_column._get_sparse_tensors(
        _LazyBuilder({'aaa': sparse_input}))
    self.assertIsNone(result.weight_tensor)
    expected_ids = sparse_tensor.SparseTensorValue(
        indices=np.array(((0, 0), (1, 0), (1, 1)), dtype=np.int64),
        values=np.array((1, 3, 3), dtype=np.int64),
        dense_shape=np.array((2, 2), dtype=np.int64))
    feeds = {
        indices_ph: ((0, 0), (1, 0), (1, 1)),
        values_ph: (1, -1, 99),
        shape_ph: (2, 2),
    }
    with _initialized_session():
      _assert_sparse_tensor_value(
          self, expected_ids, result.id_tensor.eval(feed_dict=feeds))

  def test_linear_model(self):
    """Checks linear_model weights and predictions for an identity column."""
    identity_column = fc.categorical_column_with_identity(
        key='aaa', num_buckets=3)
    self.assertEqual(3, identity_column._num_buckets)
    with ops.Graph().as_default():
      sparse_input = sparse_tensor.SparseTensorValue(
          indices=((0, 0), (1, 0), (1, 1)),
          values=(0, 2, 1),
          dense_shape=(2, 2))
      predictions = fc.linear_model(
          {identity_column.name: sparse_input}, (identity_column,))
      bias = get_linear_model_bias()
      weight_var = get_linear_model_column_var(identity_column)
      with _initialized_session():
        # Bias, weights and hence predictions all start at zero.
        self.assertAllClose((0.,), bias.eval())
        self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
        self.assertAllClose(((0.,), (0.,)), predictions.eval())
        weight_var.assign(((1.,), (2.,), (3.,))).eval()
        # Row 0: weight_var[0] = 1.
        # Row 1: weight_var[2] + weight_var[1] = 3 + 2 = 5.
        self.assertAllClose(((1.,), (5.,)), predictions.eval())
3308
3309
class TransformFeaturesTest(test.TestCase):
  """Multi-column and ordering checks for `_transform_features`."""

  # Per-column transform behavior is covered in each column's own test class;
  # only the multi-column case and tensor naming are exercised here.
  # NOTE(review): unlike every other case in this class, this method name has
  # no `test_` prefix, so standard unittest discovery presumably never runs
  # it -- confirm whether that is intentional before renaming.
  def transform_multi_column(self):
    """Transforms a bucketized and a hashed column in a single call."""
    price_buckets = fc.bucketized_column(
        fc.numeric_column('price'), boundaries=[0, 2, 4, 6])
    wire_hashed = fc.categorical_column_with_hash_bucket('wire', 10)
    with ops.Graph().as_default():
      wire_input = sparse_tensor.SparseTensor(
          values=['omar', 'stringer', 'marlo'],
          indices=[[0, 0], [1, 0], [1, 1]],
          dense_shape=[2, 2])
      features = {'price': [[-1.], [5.]], 'wire': wire_input}
      transformed = _transform_features(
          features, [price_buckets, wire_hashed])
      with _initialized_session():
        # Each output's name is expected to contain its column's name.
        self.assertIn(price_buckets.name, transformed[price_buckets].name)
        self.assertAllEqual([[0], [3]], transformed[price_buckets].eval())
        self.assertIn(wire_hashed.name, transformed[wire_hashed].name)
        self.assertAllEqual([6, 4, 1], transformed[wire_hashed].values.eval())

  def test_column_order(self):
    """Columns are transformed in the same order whatever the list order."""

    class _LoggerColumn(_FeatureColumn):
      """Stub column that records the position of its transform call."""

      def __init__(self, name):
        self._name = name

      @property
      def name(self):
        return self._name

      def _transform_feature(self, inputs):
        del inputs
        # `call_counter` is bound in the enclosing test before each call.
        position = call_counter['count']
        self.call_order = position
        call_counter['count'] = position + 1
        return 'Anything'

      @property
      def _parse_example_spec(self):
        return None

    with ops.Graph().as_default():
      first = _LoggerColumn('1')
      second = _LoggerColumn('2')
      # Column '1' must be transformed before column '2' for both orderings.
      for ordering in ([first, second], [second, first]):
        call_counter = {'count': 0}
        _transform_features({}, ordering)
        self.assertEqual(0, first.call_order)
        self.assertEqual(1, second.call_order)
3369
3370
class IndicatorColumnTest(test.TestCase):
  """Tests for columns built by `fc.indicator_column`."""

  def test_indicator_column(self):
    """Checks names and variable shapes for two hash-bucket sizes."""
    a = fc.categorical_column_with_hash_bucket('a', 4)
    indicator_a = fc.indicator_column(a)
    self.assertEqual(indicator_a.categorical_column.name, 'a')
    self.assertEqual(indicator_a.name, 'a_indicator')
    self.assertEqual(indicator_a._var_scope_name, 'a_indicator')
    self.assertEqual(indicator_a._variable_shape, [1, 4])

    b = fc.categorical_column_with_hash_bucket('b', hash_bucket_size=100)
    indicator_b = fc.indicator_column(b)
    self.assertEqual(indicator_b.categorical_column.name, 'b')
    self.assertEqual(indicator_b.name, 'b_indicator')
    self.assertEqual(indicator_b._var_scope_name, 'b_indicator')
    self.assertEqual(indicator_b._variable_shape, [1, 100])

  def test_1D_shape_succeeds(self):
    """A 1-D list of strings yields one one-hot row per element."""
    animal = fc.indicator_column(
        fc.categorical_column_with_hash_bucket('animal', 4))
    builder = _LazyBuilder({'animal': ['fox', 'fox']})
    output = builder.get(animal)
    with self.test_session():
      self.assertAllEqual([[0., 0., 1., 0.], [0., 0., 1., 0.]], output.eval())

  def test_2D_shape_succeeds(self):
    """A sparse input of dense shape [2, 1] yields one one-hot row per row."""
    # TODO(ispir/cassandrax): Switch to categorical_column_with_keys when ready.
    animal = fc.indicator_column(
        fc.categorical_column_with_hash_bucket('animal', 4))
    builder = _LazyBuilder({
        'animal':
            sparse_tensor.SparseTensor(
                indices=[[0, 0], [1, 0]],
                values=['fox', 'fox'],
                dense_shape=[2, 1])
    })
    output = builder.get(animal)
    with self.test_session():
      self.assertAllEqual([[0., 0., 1., 0.], [0., 0., 1., 0.]], output.eval())

  def test_multi_hot(self):
    """The same id appearing twice in a row is counted twice."""
    animal = fc.indicator_column(
        fc.categorical_column_with_identity('animal', num_buckets=4))

    builder = _LazyBuilder({
        'animal':
            sparse_tensor.SparseTensor(
                indices=[[0, 0], [0, 1]], values=[1, 1], dense_shape=[1, 2])
    })
    output = builder.get(animal)
    with self.test_session():
      self.assertAllEqual([[0., 2., 0., 0.]], output.eval())

  def test_multi_hot2(self):
    """Two different ids in a row each set their own position to 1."""
    animal = fc.indicator_column(
        fc.categorical_column_with_identity('animal', num_buckets=4))
    builder = _LazyBuilder({
        'animal':
            sparse_tensor.SparseTensor(
                indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
    })
    output = builder.get(animal)
    with self.test_session():
      self.assertAllEqual([[0., 1., 1., 0.]], output.eval())

  def test_deep_copy(self):
    """The original and its deep copy expose the same properties."""
    a = fc.categorical_column_with_hash_bucket('a', 4)
    original = fc.indicator_column(a)
    # Every assertion runs against the deep copy as well as the original, so
    # that a copy losing its name or shape is actually detected.
    for column in (original, copy.deepcopy(original)):
      self.assertEqual(column.categorical_column.name, 'a')
      self.assertEqual(column.name, 'a_indicator')
      self.assertEqual(column._variable_shape, [1, 4])

  def test_parse_example(self):
    """The parse spec of an indicator column parses the wrapped key 'aaa'."""
    a = fc.categorical_column_with_vocabulary_list(
        key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
    a_indicator = fc.indicator_column(a)
    data = example_pb2.Example(features=feature_pb2.Features(
        feature={
            'aaa':
                feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
                    value=[b'omar', b'stringer']))
        }))
    features = parsing_ops.parse_example(
        serialized=[data.SerializeToString()],
        features=fc.make_parse_example_spec([a_indicator]))
    self.assertIn('aaa', features)
    with self.test_session():
      _assert_sparse_tensor_value(
          self,
          sparse_tensor.SparseTensorValue(
              indices=[[0, 0], [0, 1]],
              values=np.array([b'omar', b'stringer'], dtype=np.object_),
              dense_shape=[1, 2]),
          features['aaa'].eval())

  def test_transform(self):
    """Transforms sparse string input into per-row indicator vectors."""
    a = fc.categorical_column_with_vocabulary_list(
        key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
    a_indicator = fc.indicator_column(a)
    features = {
        'aaa': sparse_tensor.SparseTensorValue(
            indices=((0, 0), (1, 0), (1, 1)),
            values=('marlo', 'skywalker', 'omar'),
            dense_shape=(2, 2))
    }
    indicator_tensor = _transform_features(features, [a_indicator])[a_indicator]
    with _initialized_session():
      # Row 0 holds 'marlo' (id 2). Row 1 holds 'omar' (id 0) plus the
      # out-of-vocabulary 'skywalker', which sets no position.
      self.assertAllEqual([[0, 0, 1], [1, 0, 0]], indicator_tensor.eval())

  def test_transform_with_weighted_column(self):
    """Indicator values carry the weights of a weighted column."""
    # Github issue 12557
    ids = fc.categorical_column_with_vocabulary_list(
        key='ids', vocabulary_list=('a', 'b', 'c'))
    weights = fc.weighted_categorical_column(ids, 'weights')
    indicator = fc.indicator_column(weights)
    features = {
        'ids': constant_op.constant([['c', 'b', 'a']]),
        'weights': constant_op.constant([[2., 4., 6.]])
    }
    indicator_tensor = _transform_features(features, [indicator])[indicator]
    with _initialized_session():
      self.assertAllEqual([[6., 4., 2.]], indicator_tensor.eval())

  def test_transform_with_missing_value_in_weighted_column(self):
    """An out-of-vocabulary id contributes no weight to the output."""
    # Github issue 12583
    ids = fc.categorical_column_with_vocabulary_list(
        key='ids', vocabulary_list=('a', 'b', 'c'))
    weights = fc.weighted_categorical_column(ids, 'weights')
    indicator = fc.indicator_column(weights)
    features = {
        'ids': constant_op.constant([['c', 'b', 'unknown']]),
        'weights': constant_op.constant([[2., 4., 6.]])
    }
    indicator_tensor = _transform_features(features, [indicator])[indicator]
    with _initialized_session():
      self.assertAllEqual([[0., 4., 2.]], indicator_tensor.eval())

  def test_transform_with_missing_value_in_categorical_column(self):
    """An out-of-vocabulary id sets no position in the output."""
    # Github issue 12583
    ids = fc.categorical_column_with_vocabulary_list(
        key='ids', vocabulary_list=('a', 'b', 'c'))
    indicator = fc.indicator_column(ids)
    features = {
        'ids': constant_op.constant([['c', 'b', 'unknown']]),
    }
    indicator_tensor = _transform_features(features, [indicator])[indicator]
    with _initialized_session():
      self.assertAllEqual([[0., 1., 1.]], indicator_tensor.eval())

  def test_linear_model(self):
    """Checks linear_model weights and predictions for an indicator column."""
    animal = fc.indicator_column(
        fc.categorical_column_with_identity('animal', num_buckets=4))
    with ops.Graph().as_default():
      features = {
          'animal':
              sparse_tensor.SparseTensor(
                  indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
      }

      predictions = fc.linear_model(features, [animal])
      weight_var = get_linear_model_column_var(animal)
      with _initialized_session():
        # All should be zero-initialized.
        self.assertAllClose([[0.], [0.], [0.], [0.]], weight_var.eval())
        self.assertAllClose([[0.]], predictions.eval())
        weight_var.assign([[1.], [2.], [3.], [4.]]).eval()
        # Ids 1 and 2 are active: weight_var[1] + weight_var[2] = 2 + 3.
        self.assertAllClose([[2. + 3.]], predictions.eval())

  def test_input_layer(self):
    """`fc.input_layer` returns the multi-hot encoding of the input row."""
    animal = fc.indicator_column(
        fc.categorical_column_with_identity('animal', num_buckets=4))
    with ops.Graph().as_default():
      features = {
          'animal':
              sparse_tensor.SparseTensor(
                  indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
      }
      net = fc.input_layer(features, [animal])
      with _initialized_session():
        self.assertAllClose([[0., 1., 1., 0.]], net.eval())
3552
3553
3554class EmbeddingColumnTest(test.TestCase):
3555
3556  def test_defaults(self):
3557    categorical_column = fc.categorical_column_with_identity(
3558        key='aaa', num_buckets=3)
3559    embedding_dimension = 2
3560    embedding_column = fc.embedding_column(
3561        categorical_column, dimension=embedding_dimension)
3562    self.assertIs(categorical_column, embedding_column.categorical_column)
3563    self.assertEqual(embedding_dimension, embedding_column.dimension)
3564    self.assertEqual('mean', embedding_column.combiner)
3565    self.assertIsNotNone(embedding_column.initializer)
3566    self.assertIsNone(embedding_column.ckpt_to_load_from)
3567    self.assertIsNone(embedding_column.tensor_name_in_ckpt)
3568    self.assertIsNone(embedding_column.max_norm)
3569    self.assertTrue(embedding_column.trainable)
3570    self.assertEqual('aaa_embedding', embedding_column.name)
3571    self.assertEqual('aaa_embedding', embedding_column._var_scope_name)
3572    self.assertEqual(
3573        (embedding_dimension,), embedding_column._variable_shape)
3574    self.assertEqual({
3575        'aaa': parsing_ops.VarLenFeature(dtypes.int64)
3576    }, embedding_column._parse_example_spec)
3577
3578  def test_all_constructor_args(self):
3579    categorical_column = fc.categorical_column_with_identity(
3580        key='aaa', num_buckets=3)
3581    embedding_dimension = 2
3582    embedding_column = fc.embedding_column(
3583        categorical_column, dimension=embedding_dimension,
3584        combiner='my_combiner', initializer=lambda: 'my_initializer',
3585        ckpt_to_load_from='my_ckpt', tensor_name_in_ckpt='my_ckpt_tensor',
3586        max_norm=42., trainable=False)
3587    self.assertIs(categorical_column, embedding_column.categorical_column)
3588    self.assertEqual(embedding_dimension, embedding_column.dimension)
3589    self.assertEqual('my_combiner', embedding_column.combiner)
3590    self.assertEqual('my_initializer', embedding_column.initializer())
3591    self.assertEqual('my_ckpt', embedding_column.ckpt_to_load_from)
3592    self.assertEqual('my_ckpt_tensor', embedding_column.tensor_name_in_ckpt)
3593    self.assertEqual(42., embedding_column.max_norm)
3594    self.assertFalse(embedding_column.trainable)
3595    self.assertEqual('aaa_embedding', embedding_column.name)
3596    self.assertEqual('aaa_embedding', embedding_column._var_scope_name)
3597    self.assertEqual(
3598        (embedding_dimension,), embedding_column._variable_shape)
3599    self.assertEqual({
3600        'aaa': parsing_ops.VarLenFeature(dtypes.int64)
3601    }, embedding_column._parse_example_spec)
3602
3603  def test_deep_copy(self):
3604    categorical_column = fc.categorical_column_with_identity(
3605        key='aaa', num_buckets=3)
3606    embedding_dimension = 2
3607    original = fc.embedding_column(
3608        categorical_column, dimension=embedding_dimension,
3609        combiner='my_combiner', initializer=lambda: 'my_initializer',
3610        ckpt_to_load_from='my_ckpt', tensor_name_in_ckpt='my_ckpt_tensor',
3611        max_norm=42., trainable=False)
3612    for embedding_column in (original, copy.deepcopy(original)):
3613      self.assertEqual('aaa', embedding_column.categorical_column.name)
3614      self.assertEqual(3, embedding_column.categorical_column._num_buckets)
3615      self.assertEqual({
3616          'aaa': parsing_ops.VarLenFeature(dtypes.int64)
3617      }, embedding_column.categorical_column._parse_example_spec)
3618
3619      self.assertEqual(embedding_dimension, embedding_column.dimension)
3620      self.assertEqual('my_combiner', embedding_column.combiner)
3621      self.assertEqual('my_initializer', embedding_column.initializer())
3622      self.assertEqual('my_ckpt', embedding_column.ckpt_to_load_from)
3623      self.assertEqual('my_ckpt_tensor', embedding_column.tensor_name_in_ckpt)
3624      self.assertEqual(42., embedding_column.max_norm)
3625      self.assertFalse(embedding_column.trainable)
3626      self.assertEqual('aaa_embedding', embedding_column.name)
3627      self.assertEqual(
3628          (embedding_dimension,), embedding_column._variable_shape)
3629      self.assertEqual({
3630          'aaa': parsing_ops.VarLenFeature(dtypes.int64)
3631      }, embedding_column._parse_example_spec)
3632
3633  def test_invalid_initializer(self):
3634    categorical_column = fc.categorical_column_with_identity(
3635        key='aaa', num_buckets=3)
3636    with self.assertRaisesRegexp(ValueError, 'initializer must be callable'):
3637      fc.embedding_column(categorical_column, dimension=2, initializer='not_fn')
3638
3639  def test_parse_example(self):
3640    a = fc.categorical_column_with_vocabulary_list(
3641        key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
3642    a_embedded = fc.embedding_column(a, dimension=2)
3643    data = example_pb2.Example(features=feature_pb2.Features(
3644        feature={
3645            'aaa':
3646                feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
3647                    value=[b'omar', b'stringer']))
3648        }))
3649    features = parsing_ops.parse_example(
3650        serialized=[data.SerializeToString()],
3651        features=fc.make_parse_example_spec([a_embedded]))
3652    self.assertIn('aaa', features)
3653    with self.test_session():
3654      _assert_sparse_tensor_value(
3655          self,
3656          sparse_tensor.SparseTensorValue(
3657              indices=[[0, 0], [0, 1]],
3658              values=np.array([b'omar', b'stringer'], dtype=np.object_),
3659              dense_shape=[1, 2]),
3660          features['aaa'].eval())
3661
3662  def test_transform_feature(self):
3663    a = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
3664    a_embedded = fc.embedding_column(a, dimension=2)
3665    features = {
3666        'aaa': sparse_tensor.SparseTensor(
3667            indices=((0, 0), (1, 0), (1, 1)),
3668            values=(0, 1, 0),
3669            dense_shape=(2, 2))
3670    }
3671    outputs = _transform_features(features, [a, a_embedded])
3672    output_a = outputs[a]
3673    output_embedded = outputs[a_embedded]
3674    with _initialized_session():
3675      _assert_sparse_tensor_value(
3676          self, output_a.eval(), output_embedded.eval())
3677
3678  def test_get_dense_tensor(self):
3679    # Inputs.
3680    vocabulary_size = 3
3681    sparse_input = sparse_tensor.SparseTensorValue(
3682        # example 0, ids [2]
3683        # example 1, ids [0, 1]
3684        # example 2, ids []
3685        # example 3, ids [1]
3686        indices=((0, 0), (1, 0), (1, 4), (3, 0)),
3687        values=(2, 0, 1, 1),
3688        dense_shape=(4, 5))
3689
3690    # Embedding variable.
3691    embedding_dimension = 2
3692    embedding_values = (
3693        (1., 2.),  # id 0
3694        (3., 5.),  # id 1
3695        (7., 11.)  # id 2
3696    )
3697    def _initializer(shape, dtype, partition_info):
3698      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
3699      self.assertEqual(dtypes.float32, dtype)
3700      self.assertIsNone(partition_info)
3701      return embedding_values
3702
3703    # Expected lookup result, using combiner='mean'.
3704    expected_lookups = (
3705        # example 0, ids [2], embedding = [7, 11]
3706        (7., 11.),
3707        # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
3708        (2., 3.5),
3709        # example 2, ids [], embedding = [0, 0]
3710        (0., 0.),
3711        # example 3, ids [1], embedding = [3, 5]
3712        (3., 5.),
3713    )
3714
3715    # Build columns.
3716    categorical_column = fc.categorical_column_with_identity(
3717        key='aaa', num_buckets=vocabulary_size)
3718    embedding_column = fc.embedding_column(
3719        categorical_column, dimension=embedding_dimension,
3720        initializer=_initializer)
3721
3722    # Provide sparse input and get dense result.
3723    embedding_lookup = embedding_column._get_dense_tensor(
3724        _LazyBuilder({
3725            'aaa': sparse_input
3726        }))
3727
3728    # Assert expected embedding variable and lookups.
3729    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
3730    self.assertItemsEqual(
3731        ('embedding_weights:0',), tuple([v.name for v in global_vars]))
3732    with _initialized_session():
3733      self.assertAllEqual(embedding_values, global_vars[0].eval())
3734      self.assertAllEqual(expected_lookups, embedding_lookup.eval())
3735
3736  def test_get_dense_tensor_3d(self):
3737    # Inputs.
3738    vocabulary_size = 4
3739    sparse_input = sparse_tensor.SparseTensorValue(
3740        # example 0, ids [2]
3741        # example 1, ids [0, 1]
3742        # example 2, ids []
3743        # example 3, ids [1]
3744        indices=((0, 0, 0), (1, 1, 0), (1, 1, 4), (3, 0, 0), (3, 1, 2)),
3745        values=(2, 0, 1, 1, 2),
3746        dense_shape=(4, 2, 5))
3747
3748    # Embedding variable.
3749    embedding_dimension = 3
3750    embedding_values = (
3751        (1., 2., 4.),   # id 0
3752        (3., 5., 1.),   # id 1
3753        (7., 11., 2.),  # id 2
3754        (2., 7., 12.)   # id 3
3755    )
3756    def _initializer(shape, dtype, partition_info):
3757      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
3758      self.assertEqual(dtypes.float32, dtype)
3759      self.assertIsNone(partition_info)
3760      return embedding_values
3761
3762    # Expected lookup result, using combiner='mean'.
3763    expected_lookups = (
3764        # example 0, ids [[2], []], embedding = [[7, 11, 2], [0, 0, 0]]
3765        ((7., 11., 2.), (0., 0., 0.)),
3766        # example 1, ids [[], [0, 1]], embedding
3767        # = mean([[], [1, 2, 4] + [3, 5, 1]]) = [[0, 0, 0], [2, 3.5, 2.5]]
3768        ((0., 0., 0.), (2., 3.5, 2.5)),
3769        # example 2, ids [[], []], embedding = [[0, 0, 0], [0, 0, 0]]
3770        ((0., 0., 0.), (0., 0., 0.)),
3771        # example 3, ids [[1], [2]], embedding = [[3, 5, 1], [7, 11, 2]]
3772        ((3., 5., 1.), (7., 11., 2.)),
3773    )
3774
3775    # Build columns.
3776    categorical_column = fc.categorical_column_with_identity(
3777        key='aaa', num_buckets=vocabulary_size)
3778    embedding_column = fc.embedding_column(
3779        categorical_column, dimension=embedding_dimension,
3780        initializer=_initializer)
3781
3782    # Provide sparse input and get dense result.
3783    embedding_lookup = embedding_column._get_dense_tensor(
3784        _LazyBuilder({
3785            'aaa': sparse_input
3786        }))
3787
3788    # Assert expected embedding variable and lookups.
3789    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
3790    self.assertItemsEqual(
3791        ('embedding_weights:0',), tuple([v.name for v in global_vars]))
3792    with _initialized_session():
3793      self.assertAllEqual(embedding_values, global_vars[0].eval())
3794      self.assertAllEqual(expected_lookups, embedding_lookup.eval())
3795
3796  def test_get_dense_tensor_weight_collections(self):
3797    sparse_input = sparse_tensor.SparseTensorValue(
3798        # example 0, ids [2]
3799        # example 1, ids [0, 1]
3800        # example 2, ids []
3801        # example 3, ids [1]
3802        indices=((0, 0), (1, 0), (1, 4), (3, 0)),
3803        values=(2, 0, 1, 1),
3804        dense_shape=(4, 5))
3805
3806    # Build columns.
3807    categorical_column = fc.categorical_column_with_identity(
3808        key='aaa', num_buckets=3)
3809    embedding_column = fc.embedding_column(categorical_column, dimension=2)
3810
3811    # Provide sparse input and get dense result.
3812    embedding_column._get_dense_tensor(
3813        _LazyBuilder({
3814            'aaa': sparse_input
3815        }), weight_collections=('my_vars',))
3816
3817    # Assert expected embedding variable and lookups.
3818    self.assertItemsEqual(
3819        [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
3820    my_vars = ops.get_collection('my_vars')
3821    self.assertItemsEqual(
3822        ('embedding_weights:0',), tuple([v.name for v in my_vars]))
3823
3824  def test_get_dense_tensor_placeholder_inputs(self):
3825    # Inputs.
3826    vocabulary_size = 3
3827    sparse_input = sparse_tensor.SparseTensorValue(
3828        # example 0, ids [2]
3829        # example 1, ids [0, 1]
3830        # example 2, ids []
3831        # example 3, ids [1]
3832        indices=((0, 0), (1, 0), (1, 4), (3, 0)),
3833        values=(2, 0, 1, 1),
3834        dense_shape=(4, 5))
3835
3836    # Embedding variable.
3837    embedding_dimension = 2
3838    embedding_values = (
3839        (1., 2.),  # id 0
3840        (3., 5.),  # id 1
3841        (7., 11.)  # id 2
3842    )
3843    def _initializer(shape, dtype, partition_info):
3844      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
3845      self.assertEqual(dtypes.float32, dtype)
3846      self.assertIsNone(partition_info)
3847      return embedding_values
3848
3849    # Expected lookup result, using combiner='mean'.
3850    expected_lookups = (
3851        # example 0, ids [2], embedding = [7, 11]
3852        (7., 11.),
3853        # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
3854        (2., 3.5),
3855        # example 2, ids [], embedding = [0, 0]
3856        (0., 0.),
3857        # example 3, ids [1], embedding = [3, 5]
3858        (3., 5.),
3859    )
3860
3861    # Build columns.
3862    categorical_column = fc.categorical_column_with_identity(
3863        key='aaa', num_buckets=vocabulary_size)
3864    embedding_column = fc.embedding_column(
3865        categorical_column, dimension=embedding_dimension,
3866        initializer=_initializer)
3867
3868    # Provide sparse input and get dense result.
3869    input_indices = array_ops.placeholder(dtype=dtypes.int64)
3870    input_values = array_ops.placeholder(dtype=dtypes.int64)
3871    input_shape = array_ops.placeholder(dtype=dtypes.int64)
3872    embedding_lookup = embedding_column._get_dense_tensor(
3873        _LazyBuilder({
3874            'aaa':
3875                sparse_tensor.SparseTensorValue(
3876                    indices=input_indices,
3877                    values=input_values,
3878                    dense_shape=input_shape)
3879        }))
3880
3881    # Assert expected embedding variable and lookups.
3882    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
3883    self.assertItemsEqual(
3884        ('embedding_weights:0',), tuple([v.name for v in global_vars]))
3885    with _initialized_session():
3886      self.assertAllEqual(embedding_values, global_vars[0].eval())
3887      self.assertAllEqual(expected_lookups, embedding_lookup.eval(
3888          feed_dict={
3889              input_indices: sparse_input.indices,
3890              input_values: sparse_input.values,
3891              input_shape: sparse_input.dense_shape,
3892          }))
3893
3894  def test_get_dense_tensor_restore_from_ckpt(self):
3895    # Inputs.
3896    vocabulary_size = 3
3897    sparse_input = sparse_tensor.SparseTensorValue(
3898        # example 0, ids [2]
3899        # example 1, ids [0, 1]
3900        # example 2, ids []
3901        # example 3, ids [1]
3902        indices=((0, 0), (1, 0), (1, 4), (3, 0)),
3903        values=(2, 0, 1, 1),
3904        dense_shape=(4, 5))
3905
3906    # Embedding variable. The checkpoint file contains _embedding_values.
3907    embedding_dimension = 2
3908    embedding_values = (
3909        (1., 2.),  # id 0
3910        (3., 5.),  # id 1
3911        (7., 11.)  # id 2
3912    )
3913    ckpt_path = test.test_src_dir_path(
3914        'python/feature_column/testdata/embedding.ckpt')
3915    ckpt_tensor = 'my_embedding'
3916
3917    # Expected lookup result, using combiner='mean'.
3918    expected_lookups = (
3919        # example 0, ids [2], embedding = [7, 11]
3920        (7., 11.),
3921        # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
3922        (2., 3.5),
3923        # example 2, ids [], embedding = [0, 0]
3924        (0., 0.),
3925        # example 3, ids [1], embedding = [3, 5]
3926        (3., 5.),
3927    )
3928
3929    # Build columns.
3930    categorical_column = fc.categorical_column_with_identity(
3931        key='aaa', num_buckets=vocabulary_size)
3932    embedding_column = fc.embedding_column(
3933        categorical_column, dimension=embedding_dimension,
3934        ckpt_to_load_from=ckpt_path,
3935        tensor_name_in_ckpt=ckpt_tensor)
3936
3937    # Provide sparse input and get dense result.
3938    embedding_lookup = embedding_column._get_dense_tensor(
3939        _LazyBuilder({
3940            'aaa': sparse_input
3941        }))
3942
3943    # Assert expected embedding variable and lookups.
3944    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
3945    self.assertItemsEqual(
3946        ('embedding_weights:0',), tuple([v.name for v in global_vars]))
3947    with _initialized_session():
3948      self.assertAllEqual(embedding_values, global_vars[0].eval())
3949      self.assertAllEqual(expected_lookups, embedding_lookup.eval())
3950
3951  def test_linear_model(self):
3952    # Inputs.
3953    batch_size = 4
3954    vocabulary_size = 3
3955    sparse_input = sparse_tensor.SparseTensorValue(
3956        # example 0, ids [2]
3957        # example 1, ids [0, 1]
3958        # example 2, ids []
3959        # example 3, ids [1]
3960        indices=((0, 0), (1, 0), (1, 4), (3, 0)),
3961        values=(2, 0, 1, 1),
3962        dense_shape=(batch_size, 5))
3963
3964    # Embedding variable.
3965    embedding_dimension = 2
3966    embedding_shape = (vocabulary_size, embedding_dimension)
3967    zeros_embedding_values = np.zeros(embedding_shape)
3968    def _initializer(shape, dtype, partition_info):
3969      self.assertAllEqual(embedding_shape, shape)
3970      self.assertEqual(dtypes.float32, dtype)
3971      self.assertIsNone(partition_info)
3972      return zeros_embedding_values
3973
3974    # Build columns.
3975    categorical_column = fc.categorical_column_with_identity(
3976        key='aaa', num_buckets=vocabulary_size)
3977    embedding_column = fc.embedding_column(
3978        categorical_column, dimension=embedding_dimension,
3979        initializer=_initializer)
3980
3981    with ops.Graph().as_default():
3982      predictions = fc.linear_model({
3983          categorical_column.name: sparse_input
3984      }, (embedding_column,))
3985      expected_var_names = (
3986          'linear_model/bias_weights:0',
3987          'linear_model/aaa_embedding/weights:0',
3988          'linear_model/aaa_embedding/embedding_weights:0',
3989      )
3990      self.assertItemsEqual(
3991          expected_var_names,
3992          [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
3993      trainable_vars = {
3994          v.name: v for v in ops.get_collection(
3995              ops.GraphKeys.TRAINABLE_VARIABLES)
3996      }
3997      self.assertItemsEqual(expected_var_names, trainable_vars.keys())
3998      bias = trainable_vars['linear_model/bias_weights:0']
3999      embedding_weights = trainable_vars[
4000          'linear_model/aaa_embedding/embedding_weights:0']
4001      linear_weights = trainable_vars[
4002          'linear_model/aaa_embedding/weights:0']
4003      with _initialized_session():
4004        # Predictions with all zero weights.
4005        self.assertAllClose(np.zeros((1,)), bias.eval())
4006        self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
4007        self.assertAllClose(
4008            np.zeros((embedding_dimension, 1)), linear_weights.eval())
4009        self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
4010
4011        # Predictions with all non-zero weights.
4012        embedding_weights.assign((
4013            (1., 2.),  # id 0
4014            (3., 5.),  # id 1
4015            (7., 11.)  # id 2
4016        )).eval()
4017        linear_weights.assign(((4.,), (6.,))).eval()
4018        # example 0, ids [2], embedding[0] = [7, 11]
4019        # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
4020        # example 2, ids [], embedding[2] = [0, 0]
4021        # example 3, ids [1], embedding[3] = [3, 5]
4022        # sum(embeddings * linear_weights)
4023        # = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
4024        self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
4025
4026  def test_input_layer(self):
4027    # Inputs.
4028    vocabulary_size = 3
4029    sparse_input = sparse_tensor.SparseTensorValue(
4030        # example 0, ids [2]
4031        # example 1, ids [0, 1]
4032        # example 2, ids []
4033        # example 3, ids [1]
4034        indices=((0, 0), (1, 0), (1, 4), (3, 0)),
4035        values=(2, 0, 1, 1),
4036        dense_shape=(4, 5))
4037
4038    # Embedding variable.
4039    embedding_dimension = 2
4040    embedding_values = (
4041        (1., 2.),  # id 0
4042        (3., 5.),  # id 1
4043        (7., 11.)  # id 2
4044    )
4045    def _initializer(shape, dtype, partition_info):
4046      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
4047      self.assertEqual(dtypes.float32, dtype)
4048      self.assertIsNone(partition_info)
4049      return embedding_values
4050
4051    # Expected lookup result, using combiner='mean'.
4052    expected_lookups = (
4053        # example 0, ids [2], embedding = [7, 11]
4054        (7., 11.),
4055        # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
4056        (2., 3.5),
4057        # example 2, ids [], embedding = [0, 0]
4058        (0., 0.),
4059        # example 3, ids [1], embedding = [3, 5]
4060        (3., 5.),
4061    )
4062
4063    # Build columns.
4064    categorical_column = fc.categorical_column_with_identity(
4065        key='aaa', num_buckets=vocabulary_size)
4066    embedding_column = fc.embedding_column(
4067        categorical_column, dimension=embedding_dimension,
4068        initializer=_initializer)
4069
4070    # Provide sparse input and get dense result.
4071    input_layer = fc.input_layer({'aaa': sparse_input}, (embedding_column,))
4072
4073    # Assert expected embedding variable and lookups.
4074    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
4075    self.assertItemsEqual(
4076        ('input_layer/aaa_embedding/embedding_weights:0',),
4077        tuple([v.name for v in global_vars]))
4078    trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
4079    self.assertItemsEqual(
4080        ('input_layer/aaa_embedding/embedding_weights:0',),
4081        tuple([v.name for v in trainable_vars]))
4082    with _initialized_session():
4083      self.assertAllEqual(embedding_values, trainable_vars[0].eval())
4084      self.assertAllEqual(expected_lookups, input_layer.eval())
4085
4086  def test_input_layer_not_trainable(self):
4087    # Inputs.
4088    vocabulary_size = 3
4089    sparse_input = sparse_tensor.SparseTensorValue(
4090        # example 0, ids [2]
4091        # example 1, ids [0, 1]
4092        # example 2, ids []
4093        # example 3, ids [1]
4094        indices=((0, 0), (1, 0), (1, 4), (3, 0)),
4095        values=(2, 0, 1, 1),
4096        dense_shape=(4, 5))
4097
4098    # Embedding variable.
4099    embedding_dimension = 2
4100    embedding_values = (
4101        (1., 2.),  # id 0
4102        (3., 5.),  # id 1
4103        (7., 11.)  # id 2
4104    )
4105    def _initializer(shape, dtype, partition_info):
4106      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
4107      self.assertEqual(dtypes.float32, dtype)
4108      self.assertIsNone(partition_info)
4109      return embedding_values
4110
4111    # Expected lookup result, using combiner='mean'.
4112    expected_lookups = (
4113        # example 0, ids [2], embedding = [7, 11]
4114        (7., 11.),
4115        # example 1, ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
4116        (2., 3.5),
4117        # example 2, ids [], embedding = [0, 0]
4118        (0., 0.),
4119        # example 3, ids [1], embedding = [3, 5]
4120        (3., 5.),
4121    )
4122
4123    # Build columns.
4124    categorical_column = fc.categorical_column_with_identity(
4125        key='aaa', num_buckets=vocabulary_size)
4126    embedding_column = fc.embedding_column(
4127        categorical_column, dimension=embedding_dimension,
4128        initializer=_initializer, trainable=False)
4129
4130    # Provide sparse input and get dense result.
4131    input_layer = fc.input_layer({'aaa': sparse_input}, (embedding_column,))
4132
4133    # Assert expected embedding variable and lookups.
4134    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
4135    self.assertItemsEqual(
4136        ('input_layer/aaa_embedding/embedding_weights:0',),
4137        tuple([v.name for v in global_vars]))
4138    self.assertItemsEqual(
4139        [], ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
4140    with _initialized_session():
4141      self.assertAllEqual(embedding_values, global_vars[0].eval())
4142      self.assertAllEqual(expected_lookups, input_layer.eval())
4143
4144
4145class SharedEmbeddingColumnTest(test.TestCase):
4146
4147  def test_defaults(self):
4148    categorical_column_a = fc.categorical_column_with_identity(
4149        key='aaa', num_buckets=3)
4150    categorical_column_b = fc.categorical_column_with_identity(
4151        key='bbb', num_buckets=3)
4152    embedding_dimension = 2
4153    embedding_column_b, embedding_column_a = fc.shared_embedding_columns(
4154        [categorical_column_b, categorical_column_a],
4155        dimension=embedding_dimension)
4156    self.assertIs(categorical_column_a, embedding_column_a.categorical_column)
4157    self.assertIs(categorical_column_b, embedding_column_b.categorical_column)
4158    self.assertEqual(embedding_dimension, embedding_column_a.dimension)
4159    self.assertEqual(embedding_dimension, embedding_column_b.dimension)
4160    self.assertEqual('mean', embedding_column_a.combiner)
4161    self.assertEqual('mean', embedding_column_b.combiner)
4162    self.assertIsNotNone(embedding_column_a.initializer)
4163    self.assertIsNotNone(embedding_column_b.initializer)
4164    self.assertIsNone(embedding_column_a.ckpt_to_load_from)
4165    self.assertIsNone(embedding_column_b.ckpt_to_load_from)
4166    self.assertEqual('aaa_bbb_shared_embedding',
4167                     embedding_column_a.shared_embedding_collection_name)
4168    self.assertEqual('aaa_bbb_shared_embedding',
4169                     embedding_column_b.shared_embedding_collection_name)
4170    self.assertIsNone(embedding_column_a.tensor_name_in_ckpt)
4171    self.assertIsNone(embedding_column_b.tensor_name_in_ckpt)
4172    self.assertIsNone(embedding_column_a.max_norm)
4173    self.assertIsNone(embedding_column_b.max_norm)
4174    self.assertTrue(embedding_column_a.trainable)
4175    self.assertTrue(embedding_column_b.trainable)
4176    self.assertEqual('aaa_shared_embedding', embedding_column_a.name)
4177    self.assertEqual('bbb_shared_embedding', embedding_column_b.name)
4178    self.assertEqual(
4179        'aaa_bbb_shared_embedding', embedding_column_a._var_scope_name)
4180    self.assertEqual(
4181        'aaa_bbb_shared_embedding', embedding_column_b._var_scope_name)
4182    self.assertEqual(
4183        (embedding_dimension,), embedding_column_a._variable_shape)
4184    self.assertEqual(
4185        (embedding_dimension,), embedding_column_b._variable_shape)
4186    self.assertEqual({
4187        'aaa': parsing_ops.VarLenFeature(dtypes.int64)
4188    }, embedding_column_a._parse_example_spec)
4189    self.assertEqual({
4190        'bbb': parsing_ops.VarLenFeature(dtypes.int64)
4191    }, embedding_column_b._parse_example_spec)
4192
4193  def test_all_constructor_args(self):
4194    categorical_column_a = fc.categorical_column_with_identity(
4195        key='aaa', num_buckets=3)
4196    categorical_column_b = fc.categorical_column_with_identity(
4197        key='bbb', num_buckets=3)
4198    embedding_dimension = 2
4199    embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
4200        [categorical_column_a, categorical_column_b],
4201        dimension=embedding_dimension,
4202        combiner='my_combiner',
4203        initializer=lambda: 'my_initializer',
4204        shared_embedding_collection_name='shared_embedding_collection_name',
4205        ckpt_to_load_from='my_ckpt',
4206        tensor_name_in_ckpt='my_ckpt_tensor',
4207        max_norm=42.,
4208        trainable=False)
4209    self.assertIs(categorical_column_a, embedding_column_a.categorical_column)
4210    self.assertIs(categorical_column_b, embedding_column_b.categorical_column)
4211    self.assertEqual(embedding_dimension, embedding_column_a.dimension)
4212    self.assertEqual(embedding_dimension, embedding_column_b.dimension)
4213    self.assertEqual('my_combiner', embedding_column_a.combiner)
4214    self.assertEqual('my_combiner', embedding_column_b.combiner)
4215    self.assertEqual('my_initializer', embedding_column_a.initializer())
4216    self.assertEqual('my_initializer', embedding_column_b.initializer())
4217    self.assertEqual('shared_embedding_collection_name',
4218                     embedding_column_a.shared_embedding_collection_name)
4219    self.assertEqual('shared_embedding_collection_name',
4220                     embedding_column_b.shared_embedding_collection_name)
4221    self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from)
4222    self.assertEqual('my_ckpt', embedding_column_b.ckpt_to_load_from)
4223    self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt)
4224    self.assertEqual('my_ckpt_tensor', embedding_column_b.tensor_name_in_ckpt)
4225    self.assertEqual(42., embedding_column_a.max_norm)
4226    self.assertEqual(42., embedding_column_b.max_norm)
4227    self.assertFalse(embedding_column_a.trainable)
4228    self.assertFalse(embedding_column_b.trainable)
4229    self.assertEqual('aaa_shared_embedding', embedding_column_a.name)
4230    self.assertEqual('bbb_shared_embedding', embedding_column_b.name)
4231    self.assertEqual(
4232        'shared_embedding_collection_name', embedding_column_a._var_scope_name)
4233    self.assertEqual(
4234        'shared_embedding_collection_name', embedding_column_b._var_scope_name)
4235    self.assertEqual(
4236        (embedding_dimension,), embedding_column_a._variable_shape)
4237    self.assertEqual(
4238        (embedding_dimension,), embedding_column_b._variable_shape)
4239    self.assertEqual({
4240        'aaa': parsing_ops.VarLenFeature(dtypes.int64)
4241    }, embedding_column_a._parse_example_spec)
4242    self.assertEqual({
4243        'bbb': parsing_ops.VarLenFeature(dtypes.int64)
4244    }, embedding_column_b._parse_example_spec)
4245
4246  def test_deep_copy(self):
4247    categorical_column_a = fc.categorical_column_with_identity(
4248        key='aaa', num_buckets=3)
4249    categorical_column_b = fc.categorical_column_with_identity(
4250        key='bbb', num_buckets=3)
4251    embedding_dimension = 2
4252    original_a, _ = fc.shared_embedding_columns(
4253        [categorical_column_a, categorical_column_b],
4254        dimension=embedding_dimension,
4255        combiner='my_combiner',
4256        initializer=lambda: 'my_initializer',
4257        shared_embedding_collection_name='shared_embedding_collection_name',
4258        ckpt_to_load_from='my_ckpt',
4259        tensor_name_in_ckpt='my_ckpt_tensor',
4260        max_norm=42., trainable=False)
4261    for embedding_column_a in (original_a, copy.deepcopy(original_a)):
4262      self.assertEqual('aaa', embedding_column_a.categorical_column.name)
4263      self.assertEqual(3, embedding_column_a.categorical_column._num_buckets)
4264      self.assertEqual({
4265          'aaa': parsing_ops.VarLenFeature(dtypes.int64)
4266      }, embedding_column_a.categorical_column._parse_example_spec)
4267
4268      self.assertEqual(embedding_dimension, embedding_column_a.dimension)
4269      self.assertEqual('my_combiner', embedding_column_a.combiner)
4270      self.assertEqual('my_initializer', embedding_column_a.initializer())
4271      self.assertEqual('shared_embedding_collection_name',
4272                       embedding_column_a.shared_embedding_collection_name)
4273      self.assertEqual('my_ckpt', embedding_column_a.ckpt_to_load_from)
4274      self.assertEqual('my_ckpt_tensor', embedding_column_a.tensor_name_in_ckpt)
4275      self.assertEqual(42., embedding_column_a.max_norm)
4276      self.assertFalse(embedding_column_a.trainable)
4277      self.assertEqual('aaa_shared_embedding', embedding_column_a.name)
4278      self.assertEqual(
4279          (embedding_dimension,), embedding_column_a._variable_shape)
4280      self.assertEqual({
4281          'aaa': parsing_ops.VarLenFeature(dtypes.int64)
4282      }, embedding_column_a._parse_example_spec)
4283
4284  def test_invalid_initializer(self):
4285    categorical_column_a = fc.categorical_column_with_identity(
4286        key='aaa', num_buckets=3)
4287    categorical_column_b = fc.categorical_column_with_identity(
4288        key='bbb', num_buckets=3)
4289    with self.assertRaisesRegexp(ValueError, 'initializer must be callable'):
4290      fc.shared_embedding_columns(
4291          [categorical_column_a, categorical_column_b], dimension=2,
4292          initializer='not_fn')
4293
4294  def test_incompatible_column_type(self):
4295    categorical_column_a = fc.categorical_column_with_identity(
4296        key='aaa', num_buckets=3)
4297    categorical_column_b = fc.categorical_column_with_identity(
4298        key='bbb', num_buckets=3)
4299    categorical_column_c = fc.categorical_column_with_hash_bucket(
4300        key='ccc', hash_bucket_size=3)
4301    with self.assertRaisesRegexp(
4302        ValueError,
4303        'all categorical_columns must have the same type.*'
4304        '_IdentityCategoricalColumn.*_HashedCategoricalColumn'):
4305      fc.shared_embedding_columns(
4306          [categorical_column_a, categorical_column_b, categorical_column_c],
4307          dimension=2)
4308
4309  def test_weighted_categorical_column_ok(self):
4310    categorical_column_a = fc.categorical_column_with_identity(
4311        key='aaa', num_buckets=3)
4312    weighted_categorical_column_a = fc.weighted_categorical_column(
4313        categorical_column_a, weight_feature_key='aaa_weights')
4314    categorical_column_b = fc.categorical_column_with_identity(
4315        key='bbb', num_buckets=3)
4316    weighted_categorical_column_b = fc.weighted_categorical_column(
4317        categorical_column_b, weight_feature_key='bbb_weights')
4318    fc.shared_embedding_columns(
4319        [weighted_categorical_column_a, categorical_column_b], dimension=2)
4320    fc.shared_embedding_columns(
4321        [categorical_column_a, weighted_categorical_column_b], dimension=2)
4322    fc.shared_embedding_columns(
4323        [weighted_categorical_column_a, weighted_categorical_column_b],
4324        dimension=2)
4325
4326  def test_parse_example(self):
4327    a = fc.categorical_column_with_vocabulary_list(
4328        key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
4329    b = fc.categorical_column_with_vocabulary_list(
4330        key='bbb', vocabulary_list=('omar', 'stringer', 'marlo'))
4331    a_embedded, b_embedded = fc.shared_embedding_columns(
4332        [a, b], dimension=2)
4333    data = example_pb2.Example(features=feature_pb2.Features(
4334        feature={
4335            'aaa':
4336                feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
4337                    value=[b'omar', b'stringer'])),
4338            'bbb':
4339                feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
4340                    value=[b'stringer', b'marlo'])),
4341        }))
4342    features = parsing_ops.parse_example(
4343        serialized=[data.SerializeToString()],
4344        features=fc.make_parse_example_spec([a_embedded, b_embedded]))
4345    self.assertIn('aaa', features)
4346    self.assertIn('bbb', features)
4347    with self.test_session():
4348      _assert_sparse_tensor_value(
4349          self,
4350          sparse_tensor.SparseTensorValue(
4351              indices=[[0, 0], [0, 1]],
4352              values=np.array([b'omar', b'stringer'], dtype=np.object_),
4353              dense_shape=[1, 2]),
4354          features['aaa'].eval())
4355      _assert_sparse_tensor_value(
4356          self,
4357          sparse_tensor.SparseTensorValue(
4358              indices=[[0, 0], [0, 1]],
4359              values=np.array([b'stringer', b'marlo'], dtype=np.object_),
4360              dense_shape=[1, 2]),
4361          features['bbb'].eval())
4362
4363  def test_transform_feature(self):
4364    a = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
4365    b = fc.categorical_column_with_identity(key='bbb', num_buckets=3)
4366    a_embedded, b_embedded = fc.shared_embedding_columns(
4367        [a, b], dimension=2)
4368    features = {
4369        'aaa': sparse_tensor.SparseTensor(
4370            indices=((0, 0), (1, 0), (1, 1)),
4371            values=(0, 1, 0),
4372            dense_shape=(2, 2)),
4373        'bbb': sparse_tensor.SparseTensor(
4374            indices=((0, 0), (1, 0), (1, 1)),
4375            values=(1, 2, 1),
4376            dense_shape=(2, 2)),
4377    }
4378    outputs = _transform_features(features, [a, a_embedded, b, b_embedded])
4379    output_a = outputs[a]
4380    output_a_embedded = outputs[a_embedded]
4381    output_b = outputs[b]
4382    output_b_embedded = outputs[b_embedded]
4383    with _initialized_session():
4384      _assert_sparse_tensor_value(
4385          self, output_a.eval(), output_a_embedded.eval())
4386      _assert_sparse_tensor_value(
4387          self, output_b.eval(), output_b_embedded.eval())
4388
4389  def test_get_dense_tensor(self):
4390    # Inputs.
4391    vocabulary_size = 3
4392    # -1 values are ignored.
4393    input_a = np.array(
4394        [[2, -1, -1],  # example 0, ids [2]
4395         [0, 1, -1]])  # example 1, ids [0, 1]
4396    input_b = np.array(
4397        [[0, -1, -1],  # example 0, ids [0]
4398         [-1, -1, -1]])  # example 1, ids []
4399    input_features = {
4400        'aaa': input_a,
4401        'bbb': input_b
4402    }
4403
4404    # Embedding variable.
4405    embedding_dimension = 2
4406    embedding_values = (
4407        (1., 2.),  # id 0
4408        (3., 5.),  # id 1
4409        (7., 11.)  # id 2
4410    )
4411    def _initializer(shape, dtype, partition_info):
4412      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
4413      self.assertEqual(dtypes.float32, dtype)
4414      self.assertIsNone(partition_info)
4415      return embedding_values
4416
4417    # Expected lookup result, using combiner='mean'.
4418    expected_lookups_a = (
4419        # example 0:
4420        (7., 11.),  # ids [2], embedding = [7, 11]
4421        # example 1:
4422        (2., 3.5),  # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
4423    )
4424    expected_lookups_b = (
4425        # example 0:
4426        (1., 2.),  # ids [0], embedding = [1, 2]
4427        # example 1:
4428        (0., 0.),  # ids [], embedding = [0, 0]
4429    )
4430
4431    # Build columns.
4432    categorical_column_a = fc.categorical_column_with_identity(
4433        key='aaa', num_buckets=vocabulary_size)
4434    categorical_column_b = fc.categorical_column_with_identity(
4435        key='bbb', num_buckets=vocabulary_size)
4436    embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
4437        [categorical_column_a, categorical_column_b],
4438        dimension=embedding_dimension, initializer=_initializer)
4439
4440    # Provide sparse input and get dense result.
4441    embedding_lookup_a = embedding_column_a._get_dense_tensor(
4442        _LazyBuilder(input_features))
4443    embedding_lookup_b = embedding_column_b._get_dense_tensor(
4444        _LazyBuilder(input_features))
4445
4446    # Assert expected embedding variable and lookups.
4447    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
4448    self.assertItemsEqual(
4449        ('embedding_weights:0',), tuple([v.name for v in global_vars]))
4450    embedding_var = global_vars[0]
4451    with _initialized_session():
4452      self.assertAllEqual(embedding_values, embedding_var.eval())
4453      self.assertAllEqual(expected_lookups_a, embedding_lookup_a.eval())
4454      self.assertAllEqual(expected_lookups_b, embedding_lookup_b.eval())
4455
4456  def test_get_dense_tensor_placeholder_inputs(self):
4457    # Inputs.
4458    vocabulary_size = 3
4459    # -1 values are ignored.
4460    input_a = np.array(
4461        [[2, -1, -1],  # example 0, ids [2]
4462         [0, 1, -1]])  # example 1, ids [0, 1]
4463    input_b = np.array(
4464        [[0, -1, -1],  # example 0, ids [0]
4465         [-1, -1, -1]])  # example 1, ids []
4466    # Specify shape, because dense input must have rank specified.
4467    input_a_placeholder = array_ops.placeholder(
4468        dtype=dtypes.int64, shape=[None, 3])
4469    input_b_placeholder = array_ops.placeholder(
4470        dtype=dtypes.int64, shape=[None, 3])
4471    input_features = {
4472        'aaa': input_a_placeholder,
4473        'bbb': input_b_placeholder,
4474    }
4475    feed_dict = {
4476        input_a_placeholder: input_a,
4477        input_b_placeholder: input_b,
4478    }
4479
4480    # Embedding variable.
4481    embedding_dimension = 2
4482    embedding_values = (
4483        (1., 2.),  # id 0
4484        (3., 5.),  # id 1
4485        (7., 11.)  # id 2
4486    )
4487    def _initializer(shape, dtype, partition_info):
4488      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
4489      self.assertEqual(dtypes.float32, dtype)
4490      self.assertIsNone(partition_info)
4491      return embedding_values
4492
4493    # Build columns.
4494    categorical_column_a = fc.categorical_column_with_identity(
4495        key='aaa', num_buckets=vocabulary_size)
4496    categorical_column_b = fc.categorical_column_with_identity(
4497        key='bbb', num_buckets=vocabulary_size)
4498    embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
4499        [categorical_column_a, categorical_column_b],
4500        dimension=embedding_dimension, initializer=_initializer)
4501
4502    # Provide sparse input and get dense result.
4503    embedding_lookup_a = embedding_column_a._get_dense_tensor(
4504        _LazyBuilder(input_features))
4505    embedding_lookup_b = embedding_column_b._get_dense_tensor(
4506        _LazyBuilder(input_features))
4507
4508    with _initialized_session() as sess:
4509      sess.run([embedding_lookup_a, embedding_lookup_b], feed_dict=feed_dict)
4510
4511  def test_linear_model(self):
4512    # Inputs.
4513    batch_size = 2
4514    vocabulary_size = 3
4515    # -1 values are ignored.
4516    input_a = np.array(
4517        [[2, -1, -1],  # example 0, ids [2]
4518         [0, 1, -1]])  # example 1, ids [0, 1]
4519    input_b = np.array(
4520        [[0, -1, -1],  # example 0, ids [0]
4521         [-1, -1, -1]])  # example 1, ids []
4522
4523    # Embedding variable.
4524    embedding_dimension = 2
4525    embedding_shape = (vocabulary_size, embedding_dimension)
4526    zeros_embedding_values = np.zeros(embedding_shape)
4527    def _initializer(shape, dtype, partition_info):
4528      self.assertAllEqual(embedding_shape, shape)
4529      self.assertEqual(dtypes.float32, dtype)
4530      self.assertIsNone(partition_info)
4531      return zeros_embedding_values
4532
4533    # Build columns.
4534    categorical_column_a = fc.categorical_column_with_identity(
4535        key='aaa', num_buckets=vocabulary_size)
4536    categorical_column_b = fc.categorical_column_with_identity(
4537        key='bbb', num_buckets=vocabulary_size)
4538    embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
4539        [categorical_column_a, categorical_column_b],
4540        dimension=embedding_dimension, initializer=_initializer)
4541
4542    with ops.Graph().as_default():
4543      predictions = fc.linear_model({
4544          categorical_column_a.name: input_a,
4545          categorical_column_b.name: input_b,
4546      }, (embedding_column_a, embedding_column_b))
4547      # Linear weights do not follow the column name. But this is a rare use
4548      # case, and fixing it would add too much complexity to the code.
4549      expected_var_names = (
4550          'linear_model/bias_weights:0',
4551          'linear_model/aaa_bbb_shared_embedding/weights:0',
4552          'linear_model/aaa_bbb_shared_embedding/embedding_weights:0',
4553          'linear_model/aaa_bbb_shared_embedding_1/weights:0',
4554      )
4555      self.assertItemsEqual(
4556          expected_var_names,
4557          [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
4558      trainable_vars = {
4559          v.name: v for v in ops.get_collection(
4560              ops.GraphKeys.TRAINABLE_VARIABLES)
4561      }
4562      self.assertItemsEqual(expected_var_names, trainable_vars.keys())
4563      bias = trainable_vars['linear_model/bias_weights:0']
4564      embedding_weights = trainable_vars[
4565          'linear_model/aaa_bbb_shared_embedding/embedding_weights:0']
4566      linear_weights_a = trainable_vars[
4567          'linear_model/aaa_bbb_shared_embedding/weights:0']
4568      linear_weights_b = trainable_vars[
4569          'linear_model/aaa_bbb_shared_embedding_1/weights:0']
4570      with _initialized_session():
4571        # Predictions with all zero weights.
4572        self.assertAllClose(np.zeros((1,)), bias.eval())
4573        self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
4574        self.assertAllClose(
4575            np.zeros((embedding_dimension, 1)), linear_weights_a.eval())
4576        self.assertAllClose(
4577            np.zeros((embedding_dimension, 1)), linear_weights_b.eval())
4578        self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
4579
4580        # Predictions with all non-zero weights.
4581        embedding_weights.assign((
4582            (1., 2.),  # id 0
4583            (3., 5.),  # id 1
4584            (7., 11.)  # id 2
4585        )).eval()
4586        linear_weights_a.assign(((4.,), (6.,))).eval()
4587        # example 0, ids [2], embedding[0] = [7, 11]
4588        # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
4589        # sum(embeddings * linear_weights)
4590        # = [4*7 + 6*11, 4*2 + 6*3.5] = [94, 29]
4591        linear_weights_b.assign(((3.,), (5.,))).eval()
4592        # example 0, ids [0], embedding[0] = [1, 2]
4593        # example 1, ids [], embedding[1] = 0, 0]
4594        # sum(embeddings * linear_weights)
4595        # = [3*1 + 5*2, 3*0 +5*0] = [13, 0]
4596        self.assertAllClose([[94. + 13.], [29.]], predictions.eval())
4597
4598  def _test_input_layer(self, trainable=True):
4599    # Inputs.
4600    vocabulary_size = 3
4601    sparse_input_a = sparse_tensor.SparseTensorValue(
4602        # example 0, ids [2]
4603        # example 1, ids [0, 1]
4604        indices=((0, 0), (1, 0), (1, 4)),
4605        values=(2, 0, 1),
4606        dense_shape=(2, 5))
4607    sparse_input_b = sparse_tensor.SparseTensorValue(
4608        # example 0, ids [0]
4609        # example 1, ids []
4610        indices=((0, 0),),
4611        values=(0,),
4612        dense_shape=(2, 5))
4613
4614    # Embedding variable.
4615    embedding_dimension = 2
4616    embedding_values = (
4617        (1., 2.),  # id 0
4618        (3., 5.),  # id 1
4619        (7., 11.)  # id 2
4620    )
4621    def _initializer(shape, dtype, partition_info):
4622      self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
4623      self.assertEqual(dtypes.float32, dtype)
4624      self.assertIsNone(partition_info)
4625      return embedding_values
4626
4627    # Expected lookup result, using combiner='mean'.
4628    expected_lookups = (
4629        # example 0:
4630        # A ids [2], embedding = [7, 11]
4631        # B ids [0], embedding = [1, 2]
4632        (7., 11., 1., 2.),
4633        # example 1:
4634        # A ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
4635        # B ids [], embedding = [0, 0]
4636        (2., 3.5, 0., 0.),
4637    )
4638
4639    # Build columns.
4640    categorical_column_a = fc.categorical_column_with_identity(
4641        key='aaa', num_buckets=vocabulary_size)
4642    categorical_column_b = fc.categorical_column_with_identity(
4643        key='bbb', num_buckets=vocabulary_size)
4644    embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
4645        [categorical_column_a, categorical_column_b],
4646        dimension=embedding_dimension, initializer=_initializer,
4647        trainable=trainable)
4648
4649    # Provide sparse input and get dense result.
4650    input_layer = fc.input_layer(
4651        features={'aaa': sparse_input_a, 'bbb': sparse_input_b},
4652        feature_columns=(embedding_column_b, embedding_column_a))
4653
4654    # Assert expected embedding variable and lookups.
4655    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
4656    self.assertItemsEqual(
4657        ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
4658        tuple([v.name for v in global_vars]))
4659    trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
4660    if trainable:
4661      self.assertItemsEqual(
4662          ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
4663          tuple([v.name for v in trainable_vars]))
4664    else:
4665      self.assertItemsEqual([], tuple([v.name for v in trainable_vars]))
4666    shared_embedding_vars = ops.get_collection('aaa_bbb_shared_embedding')
4667    self.assertItemsEqual(
4668        ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
4669        tuple([v.name for v in shared_embedding_vars]))
4670    with _initialized_session():
4671      self.assertAllEqual(embedding_values, shared_embedding_vars[0].eval())
4672      self.assertAllEqual(expected_lookups, input_layer.eval())
4673
4674  def test_input_layer(self):
4675    self._test_input_layer()
4676
4677  def test_input_layer_no_trainable(self):
4678    self._test_input_layer(trainable=False)
4679
4680
class WeightedCategoricalColumnTest(test.TestCase):
  """Tests for `fc.weighted_categorical_column`."""

  def _ids_weighted_by_values(self):
    """Returns a 3-bucket identity column on 'ids' weighted by 'values'."""
    return fc.weighted_categorical_column(
        categorical_column=fc.categorical_column_with_identity(
            key='ids', num_buckets=3),
        weight_feature_key='values')

  def test_defaults(self):
    column = self._ids_weighted_by_values()
    self.assertEqual('ids_weighted_by_values', column.name)
    self.assertEqual('ids_weighted_by_values', column._var_scope_name)
    self.assertEqual(3, column._num_buckets)
    self.assertEqual({
        'ids': parsing_ops.VarLenFeature(dtypes.int64),
        'values': parsing_ops.VarLenFeature(dtypes.float32)
    }, column._parse_example_spec)

  def test_deep_copy(self):
    """Tests deepcopy of weighted_categorical_column."""
    original = self._ids_weighted_by_values()
    for column in (original, copy.deepcopy(original)):
      self.assertEqual('ids_weighted_by_values', column.name)
      self.assertEqual(3, column._num_buckets)
      self.assertEqual({
          'ids': parsing_ops.VarLenFeature(dtypes.int64),
          'values': parsing_ops.VarLenFeature(dtypes.float32)
      }, column._parse_example_spec)

  def test_invalid_dtype_none(self):
    with self.assertRaisesRegexp(ValueError, 'is not convertible to float'):
      fc.weighted_categorical_column(
          categorical_column=fc.categorical_column_with_identity(
              key='ids', num_buckets=3),
          weight_feature_key='values',
          dtype=None)

  def test_invalid_dtype_string(self):
    with self.assertRaisesRegexp(ValueError, 'is not convertible to float'):
      fc.weighted_categorical_column(
          categorical_column=fc.categorical_column_with_identity(
              key='ids', num_buckets=3),
          weight_feature_key='values',
          dtype=dtypes.string)

  def test_invalid_input_dtype(self):
    column = self._ids_weighted_by_values()
    strings = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=('omar', 'stringer', 'marlo'),
        dense_shape=(2, 2))
    with self.assertRaisesRegexp(ValueError, 'Bad dtype'):
      _transform_features({'ids': strings, 'values': strings}, (column,))

  def test_column_name_collision(self):
    """The weight key must not collide with the categorical column's key."""
    with self.assertRaisesRegexp(ValueError, r'Parse config.*already exists'):
      column = fc.weighted_categorical_column(
          categorical_column=fc.categorical_column_with_identity(
              key='aaa', num_buckets=3),
          weight_feature_key='aaa')
      # `_parse_example_spec` is read as a plain attribute everywhere else in
      # this class (it is compared against a dict), so it is only accessed
      # here rather than called.
      _ = column._parse_example_spec

  def test_missing_weights(self):
    column = self._ids_weighted_by_values()
    inputs = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=('omar', 'stringer', 'marlo'),
        dense_shape=(2, 2))
    with self.assertRaisesRegexp(
        ValueError, 'values is not in features dictionary'):
      _transform_features({'ids': inputs}, (column,))

  def test_parse_example(self):
    a = fc.categorical_column_with_vocabulary_list(
        key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
    a_weighted = fc.weighted_categorical_column(a, weight_feature_key='weights')
    data = example_pb2.Example(features=feature_pb2.Features(
        feature={
            'aaa':
                feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
                    value=[b'omar', b'stringer'])),
            'weights':
                feature_pb2.Feature(float_list=feature_pb2.FloatList(
                    value=[1., 10.]))
        }))
    features = parsing_ops.parse_example(
        serialized=[data.SerializeToString()],
        features=fc.make_parse_example_spec([a_weighted]))
    self.assertIn('aaa', features)
    self.assertIn('weights', features)
    with self.test_session():
      _assert_sparse_tensor_value(
          self,
          sparse_tensor.SparseTensorValue(
              indices=[[0, 0], [0, 1]],
              values=np.array([b'omar', b'stringer'], dtype=np.object_),
              dense_shape=[1, 2]),
          features['aaa'].eval())
      _assert_sparse_tensor_value(
          self,
          sparse_tensor.SparseTensorValue(
              indices=[[0, 0], [0, 1]],
              values=np.array([1., 10.], dtype=np.float32),
              dense_shape=[1, 2]),
          features['weights'].eval())

  def test_transform_features(self):
    column = self._ids_weighted_by_values()
    inputs = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=(0, 1, 0),
        dense_shape=(2, 2))
    weights = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=(0.5, 1.0, 0.1),
        dense_shape=(2, 2))
    id_tensor, weight_tensor = _transform_features({
        'ids': inputs,
        'values': weights,
    }, (column,))[column]
    with _initialized_session():
      _assert_sparse_tensor_value(
          self,
          sparse_tensor.SparseTensorValue(
              indices=inputs.indices,
              values=np.array(inputs.values, dtype=np.int64),
              dense_shape=inputs.dense_shape),
          id_tensor.eval())
      _assert_sparse_tensor_value(
          self,
          sparse_tensor.SparseTensorValue(
              indices=weights.indices,
              values=np.array(weights.values, dtype=np.float32),
              dense_shape=weights.dense_shape),
          weight_tensor.eval())

  def test_transform_features_dense_input(self):
    column = self._ids_weighted_by_values()
    weights = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=(0.5, 1.0, 0.1),
        dense_shape=(2, 2))
    id_tensor, weight_tensor = _transform_features({
        'ids': ((0, -1), (1, 0)),
        'values': weights,
    }, (column,))[column]
    with _initialized_session():
      # The -1 entry of the dense ids is expected to be dropped.
      _assert_sparse_tensor_value(
          self,
          sparse_tensor.SparseTensorValue(
              indices=((0, 0), (1, 0), (1, 1)),
              values=np.array((0, 1, 0), dtype=np.int64),
              dense_shape=(2, 2)),
          id_tensor.eval())
      _assert_sparse_tensor_value(
          self,
          sparse_tensor.SparseTensorValue(
              indices=weights.indices,
              values=np.array(weights.values, dtype=np.float32),
              dense_shape=weights.dense_shape),
          weight_tensor.eval())

  def test_transform_features_dense_weights(self):
    column = self._ids_weighted_by_values()
    inputs = sparse_tensor.SparseTensorValue(
        indices=((0, 0), (1, 0), (1, 1)),
        values=(2, 1, 0),
        dense_shape=(2, 2))
    id_tensor, weight_tensor = _transform_features({
        'ids': inputs,
        'values': ((.5, 0.), (1., .1)),
    }, (column,))[column]
    with _initialized_session():
      _assert_sparse_tensor_value(
          self,
          sparse_tensor.SparseTensorValue(
              indices=inputs.indices,
              values=np.array(inputs.values, dtype=np.int64),
              dense_shape=inputs.dense_shape),
          id_tensor.eval())
      # The 0. entry of the dense weights is expected to be dropped.
      _assert_sparse_tensor_value(
          self,
          sparse_tensor.SparseTensorValue(
              indices=((0, 0), (1, 0), (1, 1)),
              values=np.array((.5, 1., .1), dtype=np.float32),
              dense_shape=(2, 2)),
          weight_tensor.eval())

  def test_linear_model(self):
    column = self._ids_weighted_by_values()
    with ops.Graph().as_default():
      predictions = fc.linear_model({
          'ids': sparse_tensor.SparseTensorValue(
              indices=((0, 0), (1, 0), (1, 1)),
              values=(0, 2, 1),
              dense_shape=(2, 2)),
          'values': sparse_tensor.SparseTensorValue(
              indices=((0, 0), (1, 0), (1, 1)),
              values=(.5, 1., .1),
              dense_shape=(2, 2))
      }, (column,))
      bias = get_linear_model_bias()
      weight_var = get_linear_model_column_var(column)
      with _initialized_session():
        self.assertAllClose((0.,), bias.eval())
        self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
        self.assertAllClose(((0.,), (0.,)), predictions.eval())
        weight_var.assign(((1.,), (2.,), (3.,))).eval()
        # weight_var[0] * weights[0, 0] = 1 * .5 = .5
        # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
        # = 3*1 + 2*.1 = 3+.2 = 3.2
        self.assertAllClose(((.5,), (3.2,)), predictions.eval())

  def test_linear_model_mismatched_shape(self):
    column = self._ids_weighted_by_values()
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(
          ValueError, r'Dimensions.*are not compatible'):
        # Three ids but four weights.
        fc.linear_model({
            'ids': sparse_tensor.SparseTensorValue(
                indices=((0, 0), (1, 0), (1, 1)),
                values=(0, 2, 1),
                dense_shape=(2, 2)),
            'values': sparse_tensor.SparseTensorValue(
                indices=((0, 0), (0, 1), (1, 0), (1, 1)),
                values=(.5, 11., 1., .1),
                dense_shape=(2, 2))
        }, (column,))

  def test_linear_model_mismatched_dense_values(self):
    column = self._ids_weighted_by_values()
    with ops.Graph().as_default():
      # Three sparse ids but only two dense weights; the error is expected
      # when the predictions are evaluated, not when the graph is built.
      predictions = fc.linear_model({
          'ids': sparse_tensor.SparseTensorValue(
              indices=((0, 0), (1, 0), (1, 1)),
              values=(0, 2, 1),
              dense_shape=(2, 2)),
          'values': ((.5,), (1.,))
      }, (column,))
      with _initialized_session():
        with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
          predictions.eval()

  def test_linear_model_mismatched_dense_shape(self):
    column = self._ids_weighted_by_values()
    with ops.Graph().as_default():
      # Dense weights are shaped (3, 1) while the ids have dense_shape
      # (2, 2); both carry three values and the model is expected to work.
      predictions = fc.linear_model({
          'ids': sparse_tensor.SparseTensorValue(
              indices=((0, 0), (1, 0), (1, 1)),
              values=(0, 2, 1),
              dense_shape=(2, 2)),
          'values': ((.5,), (1.,), (.1,))
      }, (column,))
      bias = get_linear_model_bias()
      weight_var = get_linear_model_column_var(column)
      with _initialized_session():
        self.assertAllClose((0.,), bias.eval())
        self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
        self.assertAllClose(((0.,), (0.,)), predictions.eval())
        weight_var.assign(((1.,), (2.,), (3.,))).eval()
        # weight_var[0] * weights[0, 0] = 1 * .5 = .5
        # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
        # = 3*1 + 2*.1 = 3+.2 = 3.2
        self.assertAllClose(((.5,), (3.2,)), predictions.eval())

  # TODO(ptucker): Add test with embedding of weighted categorical.
4973
if __name__ == '__main__':
  # Invoke the test entry point only when this file is executed directly.
  test.main()
4976