1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Categorical vocabulary classes to map categories to indexes.
17
18Can be used for categorical variables, sparse variables and words.
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25import collections
26import six
27
28
class CategoricalVocabulary(object):
  """Categorical variables vocabulary class.

  Accumulates and provides mapping from classes to indexes.
  Can be easily used for words.

  Id 0 is always reserved for the unknown token. Other categories receive
  consecutive ids (1, 2, ...) in the order they are first seen, until `trim`
  reassigns ids by frequency.
  """

  def __init__(self, unknown_token="<UNK>", support_reverse=True):
    """Creates a vocabulary that contains only the unknown token.

    Args:
      unknown_token: category that is mapped to id 0.
      support_reverse: if True, an id -> category list is maintained so that
        `reverse` can be used.
    """
    self._unknown_token = unknown_token
    self._mapping = {unknown_token: 0}
    self._support_reverse = support_reverse
    if support_reverse:
      self._reverse_mapping = [unknown_token]
    self._freq = collections.defaultdict(int)
    self._freeze = False

  def __len__(self):
    """Returns total count of mappings. Including unknown token."""
    return len(self._mapping)

  def freeze(self, freeze=True):
    """Freezes the vocabulary, after which new words return unknown token id.

    Args:
      freeze: True to freeze, False to unfreeze.
    """
    self._freeze = freeze

  def get(self, category):
    """Returns word's id in the vocabulary.

    If category is new and the vocabulary is not frozen, creates a new id for
    it. If the vocabulary is frozen, new categories map to id 0 (the unknown
    token) and the vocabulary is left unchanged.

    Args:
      category: string or integer to lookup in vocabulary.

    Returns:
      integer, id in the vocabulary.
    """
    if category not in self._mapping:
      if self._freeze:
        return 0
      self._mapping[category] = len(self._mapping)
      if self._support_reverse:
        self._reverse_mapping.append(category)
    return self._mapping[category]

  def add(self, category, count=1):
    """Adds count of the category to the frequency table.

    Categories that resolve to id 0 (the unknown token itself, or any new
    category while the vocabulary is frozen) are not counted.

    Args:
      category: string or integer, category to add frequency to.
      count: optional integer, how many to add.
    """
    category_id = self.get(category)
    if category_id <= 0:
      return
    self._freq[category] += count

  def trim(self, min_frequency, max_frequency=-1):
    """Trims vocabulary by frequency and reassigns ids.

    Remaps ids from 1..n in descending frequency order (ties are broken
    alphabetically), where n - number of elements left. Categories that were
    never passed to `add` have no recorded frequency and are dropped.

    Args:
      min_frequency: exclusive lower bound; only categories with a count
        strictly greater than `min_frequency` are kept.
      max_frequency: optional, exclusive upper bound; when positive,
        categories with a count greater than or equal to it are dropped.
        Useful to remove very frequent categories (like stop words).
    """
    # Sort by alphabet then reversed frequency. Both sorts are stable, so
    # categories with equal counts stay in alphabetical order.
    by_frequency = sorted(
        sorted(
            self._freq.items(),
            key=lambda x: (isinstance(x[0], str), x[0])),
        key=lambda x: x[1],
        reverse=True)
    self._mapping = {self._unknown_token: 0}
    if self._support_reverse:
      self._reverse_mapping = [self._unknown_token]
    # Collect the surviving counts explicitly: entries skipped because of
    # max_frequency sit at the front of the sorted list, so a prefix slice
    # would keep the wrong categories. Keeping a defaultdict also lets `add`
    # register new categories after trimming.
    kept = collections.defaultdict(int)
    for category, count in by_frequency:
      if max_frequency > 0 and count >= max_frequency:
        continue
      if count <= min_frequency:
        # Sorted by descending count, so every remaining entry is too rare.
        break
      self._mapping[category] = len(self._mapping)
      kept[category] = count
      if self._support_reverse:
        self._reverse_mapping.append(category)
    self._freq = kept

  def reverse(self, class_id):
    """Given class id reverse to original class name.

    Args:
      class_id: Id of the class.

    Returns:
      Class name.

    Raises:
      ValueError: if this vocabulary wasn't initialized with support_reverse.
    """
    if not self._support_reverse:
      raise ValueError("This vocabulary wasn't initialized with "
                       "support_reverse to support reverse() function.")
    return self._reverse_mapping[class_id]
137