# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Categorical vocabulary classes to map categories to indexes.

Can be used for categorical variables, sparse variables and words.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import six


class CategoricalVocabulary(object):
    """Categorical variables vocabulary class.

    Accumulates and provides mapping from classes to indexes.
    Can be easily used for words.
    """

    def __init__(self, unknown_token="<UNK>", support_reverse=True):
        """Creates a vocabulary that contains only the unknown token (id 0).

        Args:
          unknown_token: category stored at id 0.
          support_reverse: if True, also maintain an id -> category list so
            that `reverse()` can be used.
        """
        self._unknown_token = unknown_token
        self._mapping = {unknown_token: 0}
        self._support_reverse = support_reverse
        if support_reverse:
            self._reverse_mapping = [unknown_token]
        self._freq = collections.defaultdict(int)
        self._freeze = False

    def __len__(self):
        """Returns total count of mappings. Including unknown token."""
        return len(self._mapping)

    def freeze(self, freeze=True):
        """Freezes the vocabulary, after which new words return unknown token id.

        Args:
          freeze: True to freeze, False to unfreeze.
        """
        self._freeze = freeze

    def get(self, category):
        """Returns word's id in the vocabulary.

        If category is new, creates a new id for it, unless the vocabulary is
        frozen, in which case the unknown token id (0) is returned.

        Args:
          category: string or integer to lookup in vocabulary.

        Returns:
          integer, id in the vocabulary.
        """
        if category not in self._mapping:
            if self._freeze:
                return 0
            self._mapping[category] = len(self._mapping)
            if self._support_reverse:
                self._reverse_mapping.append(category)
        return self._mapping[category]

    def add(self, category, count=1):
        """Adds count of the category to the frequency table.

        The category is looked up with `get()` first, so an unseen category
        receives an id unless the vocabulary is frozen. Nothing is counted
        when the lookup resolves to the unknown token id.

        Args:
          category: string or integer, category to add frequency to.
          count: optional integer, how many to add.
        """
        category_id = self.get(category)
        if category_id <= 0:
            return
        self._freq[category] += count

    def trim(self, min_frequency, max_frequency=-1):
        """Trims vocabulary by frequency.

        Remaps ids from 1..n in descending frequency order (ties are broken
        alphabetically), where n - number of elements left. The frequency
        table is reduced to exactly the categories that were kept.

        Args:
          min_frequency: exclusive lower bound; only categories whose count is
            strictly greater than this value are kept.
          max_frequency: optional exclusive upper bound, applied only when
            positive; categories whose count is greater than or equal to this
            value are dropped.
            Useful to remove very frequent categories (like stop words).
        """
        # Sort by alphabet first; the second (stable) sort then orders by
        # descending frequency while keeping ties alphabetical. The isinstance
        # flag in the key groups non-str categories ahead of str ones, so the
        # two kinds are never compared with each other directly.
        ordered = sorted(
            sorted(
                six.iteritems(self._freq),
                key=lambda x: (isinstance(x[0], str), x[0])),
            key=lambda x: x[1],
            reverse=True)
        self._mapping = {self._unknown_token: 0}
        if self._support_reverse:
            self._reverse_mapping = [self._unknown_token]
        # Collect survivors explicitly: a prefix slice of `ordered` would
        # wrongly include categories skipped for exceeding max_frequency.
        # A defaultdict keeps `add()` working for new categories after a trim.
        kept = collections.defaultdict(int)
        idx = 1
        for category, count in ordered:
            if max_frequency > 0 and count >= max_frequency:
                continue
            if count <= min_frequency:
                # Counts are in descending order, so nothing later qualifies.
                break
            self._mapping[category] = idx
            idx += 1
            kept[category] = count
            if self._support_reverse:
                self._reverse_mapping.append(category)
        self._freq = kept

    def reverse(self, class_id):
        """Given class id reverse to original class name.

        Args:
          class_id: Id of the class.

        Returns:
          Class name.

        Raises:
          ValueError: if this vocabulary wasn't initialized with support_reverse.
        """
        if not self._support_reverse:
            raise ValueError("This vocabulary wasn't initialized with "
                             "support_reverse to support reverse() function.")
        return self._reverse_mapping[class_id]