# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for text input preprocessing.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import OrderedDict
from hashlib import md5
import string
import sys

import numpy as np
from six.moves import range  # pylint: disable=redefined-builtin
from six.moves import zip  # pylint: disable=redefined-builtin

from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export


if sys.version_info < (3,):
  maketrans = string.maketrans
else:
  maketrans = str.maketrans


@tf_export('keras.preprocessing.text.text_to_word_sequence')
def text_to_word_sequence(text,
                          filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
                          lower=True,
                          split=' '):
  """Converts a text to a sequence of words (or tokens).

  Arguments:
      text: Input text (string).
      filters: Sequence of characters to filter out.
      lower: Whether to convert the input to lowercase.
      split: Separator to use for word splitting (string).

  Returns:
      A list of words (or tokens).
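
  Example (illustrative sketch with the default `filters` and `split`;
  punctuation is replaced by the separator and empty tokens are dropped):

  >>> text_to_word_sequence('Hello, world! Hello?')
  ['hello', 'world', 'hello']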
  """
  if lower:
    text = text.lower()

  if sys.version_info < (3,) and isinstance(text, unicode):
    translate_map = dict((ord(c), unicode(split)) for c in filters)
  else:
    translate_map = maketrans(filters, split * len(filters))

  text = text.translate(translate_map)
  seq = text.split(split)
  return [i for i in seq if i]


@tf_export('keras.preprocessing.text.one_hot')
def one_hot(text,
            n,
            filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
            lower=True,
            split=' '):
  """One-hot encodes a text into a list of word indexes of size n.

  This is a wrapper around the `hashing_trick` function that uses `hash` as
  the hashing function; uniqueness of the word-to-index mapping is not
  guaranteed.

  Arguments:
      text: Input text (string).
      n: Dimension of the hashing space.
      filters: Sequence of characters to filter out.
      lower: Whether to convert the input to lowercase.
      split: Separator to use for word splitting (string).

  Returns:
      A list of integer word indices (uniqueness not guaranteed).
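
  Example (illustrative sketch; the exact indices depend on Python's `hash`,
  which is not stable across runs):

  >>> indices = one_hot('The cat sat on the mat', 50)
  >>> len(indices)
  6
  >>> all(0 < i < 50 for i in indices)
  True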
  """
  return hashing_trick(
      text, n, hash_function=hash, filters=filters, lower=lower, split=split)


def hashing_trick(text,
                  n,
                  hash_function=None,
                  filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
                  lower=True,
                  split=' '):
  """Converts a text to a sequence of indices in a fixed-size hashing space.

  Arguments:
      text: Input text (string).
      n: Dimension of the hashing space.
      hash_function: if `None`, the Python `hash` function is used. Can be
          'md5' or any callable that takes a string as input and returns an
          integer. Note that `hash` is not a stable hashing function, so
          it is not consistent across different runs, while 'md5'
          is a stable hashing function.
      filters: Sequence of characters to filter out.
      lower: Whether to convert the input to lowercase.
      split: Separator to use for word splitting (string).

  Returns:
      A list of integer word indices (uniqueness not guaranteed).

  `0` is a reserved index that won't be assigned to any word.

  Two or more words may be assigned to the same index due to possible
  collisions by the hashing function. The probability of a collision depends
  on the dimension of the hashing space and the number of distinct objects.
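
  Example (illustrative sketch; 'md5' yields indices that are stable across
  runs, whereas the default `hash` does not):

  >>> seq = hashing_trick('The cat sat on the mat', 20, hash_function='md5')
  >>> len(seq)
  6
  >>> all(0 < i < 20 for i in seq)
  True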
  """
  if hash_function is None:
    hash_function = hash
  elif hash_function == 'md5':
    hash_function = lambda w: int(md5(w.encode()).hexdigest(), 16)

  seq = text_to_word_sequence(text, filters=filters, lower=lower, split=split)
  return [(hash_function(w) % (n - 1) + 1) for w in seq]


@tf_export('keras.preprocessing.text.Tokenizer')
class Tokenizer(object):
  """Text tokenization utility class.

  This class allows vectorizing a text corpus, by turning each
  text into either a sequence of integers (each integer being the index
  of a token in a dictionary) or a vector where the coefficient
  for each token could be binary, based on word count, or based on tf-idf.

  Arguments:
      num_words: the maximum number of words to keep, based
          on word frequency. Only the most common `num_words` words will
          be kept.
      filters: a string where each element is a character that will be
          filtered from the texts. The default is all punctuation, plus
          tabs and line breaks, minus the `'` character.
      lower: boolean. Whether to convert the texts to lowercase.
      split: character or string to use for token splitting.
      char_level: if True, every character will be treated as a token.
      oov_token: if given, it will be added to `word_index` and used to
          replace out-of-vocabulary words during `texts_to_sequences` calls.

  By default, all punctuation is removed, turning the texts into
  space-separated sequences of words
  (words may include the `'` character). These sequences are then
  split into lists of tokens. They will then be indexed or vectorized.

  `0` is a reserved index that won't be assigned to any word.
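
  Example (illustrative sketch; indices are assigned by descending word
  frequency, with ties broken by order of first appearance):

  >>> tokenizer = Tokenizer(num_words=10)
  >>> tokenizer.fit_on_texts(['the cat sat', 'the cat sat on the mat'])
  >>> tokenizer.texts_to_sequences(['the cat'])
  [[1, 2]]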
  """

  def __init__(self,
               num_words=None,
               filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
               lower=True,
               split=' ',
               char_level=False,
               oov_token=None,
               **kwargs):
    # Legacy support
    if 'nb_words' in kwargs:
      logging.warning('The `nb_words` argument in `Tokenizer` '
                      'has been renamed `num_words`.')
      num_words = kwargs.pop('nb_words')
    if kwargs:
      raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))

    self.word_counts = OrderedDict()
    self.word_docs = {}
    self.filters = filters
    self.split = split
    self.lower = lower
    self.num_words = num_words
    self.document_count = 0
    self.char_level = char_level
    self.oov_token = oov_token

  def fit_on_texts(self, texts):
    """Updates internal vocabulary based on a list of texts.

    Required before using `texts_to_sequences` or `texts_to_matrix`.

    Arguments:
        texts: can be a list of strings,
            or a generator of strings (for memory-efficiency)
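
    Example (illustrative sketch; `word_index` maps each word to an integer
    rank by descending frequency, starting at 1):

    >>> tokenizer = Tokenizer()
    >>> tokenizer.fit_on_texts(['the cat sat on the mat'])
    >>> tokenizer.word_index['the']
    1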
    """
    self.document_count = 0
    for text in texts:
      self.document_count += 1
      seq = text if self.char_level else text_to_word_sequence(
          text, self.filters, self.lower, self.split)
      for w in seq:
        if w in self.word_counts:
          self.word_counts[w] += 1
        else:
          self.word_counts[w] = 1
      for w in set(seq):
        if w in self.word_docs:
          self.word_docs[w] += 1
        else:
          self.word_docs[w] = 1

    wcounts = list(self.word_counts.items())
    wcounts.sort(key=lambda x: x[1], reverse=True)
    sorted_voc = [wc[0] for wc in wcounts]
    # note that index 0 is reserved, never assigned to an existing word
    self.word_index = dict(
        list(zip(sorted_voc, list(range(1,
                                        len(sorted_voc) + 1)))))

    if self.oov_token is not None:
      i = self.word_index.get(self.oov_token)
      if i is None:
        self.word_index[self.oov_token] = len(self.word_index) + 1

    self.index_docs = {}
    for w, c in list(self.word_docs.items()):
      self.index_docs[self.word_index[w]] = c

  def fit_on_sequences(self, sequences):
    """Updates internal vocabulary based on a list of sequences.

    Required before using `sequences_to_matrix`
    (if `fit_on_texts` was never called).

    Arguments:
        sequences: A list of sequences.
            A "sequence" is a list of integer word indices.
    """
    self.document_count = len(sequences)
    self.index_docs = {}
    for seq in sequences:
      seq = set(seq)
      for i in seq:
        if i not in self.index_docs:
          self.index_docs[i] = 1
        else:
          self.index_docs[i] += 1

  def texts_to_sequences(self, texts):
    """Transforms each text in `texts` to a sequence of integers.

    Only the top `num_words` most frequent words will be taken into account.
    Only words known by the tokenizer will be taken into account.

    Arguments:
        texts: A list of texts (strings).

    Returns:
        A list of sequences.
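
    Example (illustrative sketch; words not seen during fitting are dropped
    unless an `oov_token` was set):

    >>> tokenizer = Tokenizer()
    >>> tokenizer.fit_on_texts(['the cat sat on the mat'])
    >>> tokenizer.texts_to_sequences(['the dog sat'])
    [[1, 3]]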
    """
    res = []
    for vect in self.texts_to_sequences_generator(texts):
      res.append(vect)
    return res

  def texts_to_sequences_generator(self, texts):
    """Transforms each text in `texts` to a sequence of integers.

    Only the top `num_words` most frequent words will be taken into account.
    Only words known by the tokenizer will be taken into account.

    Arguments:
        texts: A list of texts (strings).

    Yields:
        Individual sequences (lists of integer word indices).
    """
    num_words = self.num_words
    for text in texts:
      seq = text if self.char_level else text_to_word_sequence(
          text, self.filters, self.lower, self.split)
      vect = []
      for w in seq:
        i = self.word_index.get(w)
        if i is not None:
          if num_words and i >= num_words:
            continue
          else:
            vect.append(i)
        elif self.oov_token is not None:
          i = self.word_index.get(self.oov_token)
          if i is not None:
            vect.append(i)
      yield vect

  def texts_to_matrix(self, texts, mode='binary'):
    """Converts a list of texts to a Numpy matrix.

    Arguments:
        texts: list of strings.
        mode: one of "binary", "count", "tfidf", "freq".

    Returns:
        A Numpy matrix.
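
    Example (illustrative sketch; requires fitting the tokenizer first so
    that words map to indices):

    >>> tokenizer = Tokenizer(num_words=4)
    >>> tokenizer.fit_on_texts(['the cat sat on the mat'])
    >>> matrix = tokenizer.texts_to_matrix(['the cat'], mode='binary')
    >>> matrix[0].tolist()
    [0.0, 1.0, 1.0, 0.0]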
    """
    sequences = self.texts_to_sequences(texts)
    return self.sequences_to_matrix(sequences, mode=mode)

  def sequences_to_matrix(self, sequences, mode='binary'):
    """Converts a list of sequences into a Numpy matrix.

    Arguments:
        sequences: list of sequences
            (a sequence is a list of integer word indices).
        mode: one of "binary", "count", "tfidf", "freq".

    Returns:
        A Numpy matrix.

    Raises:
        ValueError: In case of invalid `mode` argument,
            or if the Tokenizer has not been fit on any text data.
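
    Example (illustrative sketch; row `i` corresponds to `sequences[i]` and
    column `j` to word index `j`, with column 0 always unused):

    >>> tokenizer = Tokenizer(num_words=4)
    >>> matrix = tokenizer.sequences_to_matrix([[1, 2, 2]], mode='count')
    >>> matrix[0].tolist()
    [0.0, 1.0, 2.0, 0.0]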
    """
    if not self.num_words:
      if self.word_index:
        num_words = len(self.word_index) + 1
      else:
        raise ValueError('Specify a dimension (num_words argument), '
                         'or fit on some text data first.')
    else:
      num_words = self.num_words

    if mode == 'tfidf' and not self.document_count:
      raise ValueError('Fit the Tokenizer on some data '
                       'before using tfidf mode.')

    x = np.zeros((len(sequences), num_words))
    for i, seq in enumerate(sequences):
      if not seq:
        continue
      counts = {}
      for j in seq:
        if j >= num_words:
          continue
        if j not in counts:
          counts[j] = 1.
        else:
          counts[j] += 1
      for j, c in list(counts.items()):
        if mode == 'count':
          x[i][j] = c
        elif mode == 'freq':
          x[i][j] = c / len(seq)
        elif mode == 'binary':
          x[i][j] = 1
        elif mode == 'tfidf':
          # Use weighting scheme 2 in
          # https://en.wikipedia.org/wiki/Tf%E2%80%93idf
          tf = 1 + np.log(c)
          idf = np.log(1 + self.document_count /
                       (1 + self.index_docs.get(j, 0)))
          x[i][j] = tf * idf
        else:
          raise ValueError('Unknown vectorization mode:', mode)
    return x