1f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet#
3f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet# Licensed under the Apache License, Version 2.0 (the "License");
4f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet# you may not use this file except in compliance with the License.
5f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet# You may obtain a copy of the License at
6f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet#
7f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet#     http://www.apache.org/licenses/LICENSE-2.0
8f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet#
9f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet# Unless required by applicable law or agreed to in writing, software
10f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet# distributed under the License is distributed on an "AS IS" BASIS,
11f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet# See the License for the specific language governing permissions and
13f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet# limitations under the License.
14f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet# ==============================================================================
15f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet"""Utilities for text input preprocessing.
16f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet"""
17f49f801276154d0f693c5d57db6977a7eb32f017Francois Cholletfrom __future__ import absolute_import
18f49f801276154d0f693c5d57db6977a7eb32f017Francois Cholletfrom __future__ import division
19f49f801276154d0f693c5d57db6977a7eb32f017Francois Cholletfrom __future__ import print_function
20f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet
21d21bf7d7502f447e5f967a479282b32b5845ba8bFrancois Cholletfrom collections import OrderedDict
2224101b35f3baebbfff3d8057ac223b325bc415ceFrancois Cholletfrom hashlib import md5
23f49f801276154d0f693c5d57db6977a7eb32f017Francois Cholletimport string
24f49f801276154d0f693c5d57db6977a7eb32f017Francois Cholletimport sys
25f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet
26f49f801276154d0f693c5d57db6977a7eb32f017Francois Cholletimport numpy as np
27f49f801276154d0f693c5d57db6977a7eb32f017Francois Cholletfrom six.moves import range  # pylint: disable=redefined-builtin
28f49f801276154d0f693c5d57db6977a7eb32f017Francois Cholletfrom six.moves import zip  # pylint: disable=redefined-builtin
29f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet
300bd0bf02aa15a3238b77053a2f0ad6fe373c7d1cFrancois Cholletfrom tensorflow.python.platform import tf_logging as logging
31e99724b78b9f6834b918ae8a599597f863cba8d4Anna Rfrom tensorflow.python.util.tf_export import tf_export
320bd0bf02aa15a3238b77053a2f0ad6fe373c7d1cFrancois Chollet
330bd0bf02aa15a3238b77053a2f0ad6fe373c7d1cFrancois Chollet
# Select the translation-table builder for the running interpreter:
# Python 2 exposes it as `string.maketrans`, Python 3 as `str.maketrans`.
# Only the chosen branch of the conditional expression is evaluated, so the
# attribute that does not exist on the current interpreter is never touched.
maketrans = string.maketrans if sys.version_info[0] < 3 else str.maketrans
38f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet
39f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet
40e99724b78b9f6834b918ae8a599597f863cba8d4Anna R@tf_export('keras.preprocessing.text.text_to_word_sequence')
def text_to_word_sequence(text,
                          filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
                          lower=True,
                          split=' '):
  """Converts a text to a sequence of words (or tokens).

  The text is optionally lowercased, every character listed in `filters` is
  replaced by `split`, the result is cut on `split`, and empty tokens are
  dropped.

  Arguments:
      text: Input text (string).
      filters: Sequence of characters to filter out.
      lower: Whether to convert the input to lowercase.
      split: Sentence split marker (string). It may be longer than one
          character.

  Returns:
      A list of words (or tokens).
  """
  if lower:
    text = text.lower()

  if sys.version_info < (3,):
    if isinstance(text, unicode):  # pylint: disable=undefined-variable
      # Python 2 unicode strings translate through an {ordinal: replacement}
      # mapping, which already accepts replacements of any length.
      translate_map = dict(
          (ord(c), unicode(split))  # pylint: disable=undefined-variable
          for c in filters)
      text = text.translate(translate_map)
    elif len(split) == 1:
      text = text.translate(string.maketrans(filters, split * len(filters)))
    else:
      # Python 2 byte-string tables are strictly one-char-to-one-char, so a
      # longer `split` has to be substituted one filtered character at a time.
      for c in filters:
        text = text.replace(c, split)
  else:
    # The two-argument form of `str.maketrans` requires both strings to have
    # the same length, which fails as soon as `split` is not exactly one
    # character. The dict form maps each filtered character to the whole
    # `split` string instead.
    text = text.translate(str.maketrans(dict((c, split) for c in filters)))

  seq = text.split(split)
  return [i for i in seq if i]
67f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet
68f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet
69e99724b78b9f6834b918ae8a599597f863cba8d4Anna R@tf_export('keras.preprocessing.text.one_hot')
def one_hot(text,
            n,
            filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
            lower=True,
            split=' '):
  """Encodes a text as a list of word indexes in a hashing space of size n.

  Thin wrapper around `hashing_trick` that always passes Python's built-in
  `hash` as the hashing function. Different words may end up with the same
  index, so the word-to-index mapping is not guaranteed to be unique.

  Arguments:
      text: Input text (string).
      n: Dimension of the hashing space.
      filters: Sequence of characters to filter out.
      lower: Whether to convert the input to lowercase.
      split: Sentence split marker (string).

  Returns:
      A list of integer word indices (unicity non-guaranteed).
  """
  tokenizer_options = dict(filters=filters, lower=lower, split=split)
  return hashing_trick(text, n, hash_function=hash, **tokenizer_options)
9224101b35f3baebbfff3d8057ac223b325bc415ceFrancois Chollet
9324101b35f3baebbfff3d8057ac223b325bc415ceFrancois Chollet
def hashing_trick(text,
                  n,
                  hash_function=None,
                  filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
                  lower=True,
                  split=' '):
  """Converts a text to a sequence of indexes in a fixed-size hashing space.

  The text is tokenized with `text_to_word_sequence`, then each token `w` is
  mapped to `hash_function(w) % (n - 1) + 1`.

  Arguments:
      text: Input text (string).
      n: Dimension of the hashing space.
      hash_function: if `None` uses python `hash` function, can be 'md5' or
          any function that takes in input a string and returns a int.
          Note that `hash` is not a stable hashing function, so
          it is not consistent across different runs, while 'md5'
          is a stable hashing function.
      filters: Sequence of characters to filter out.
      lower: Whether to convert the input to lowercase.
      split: Sentence split marker (string).

  Returns:
      A list of integer word indices (unicity non-guaranteed).

  `0` is a reserved index that won't be assigned to any word.

  Two or more words may be assigned to the same index, due to possible
  collisions by the hashing function. The probability of a collision is in
  relation to the dimension of the hashing space and the number of distinct
  objects.
  """

  def _md5_as_int(word):
    # Read the hexadecimal md5 digest of the encoded word as one integer.
    return int(md5(word.encode()).hexdigest(), 16)

  if hash_function is None:
    hasher = hash
  elif hash_function == 'md5':
    hasher = _md5_as_int
  else:
    hasher = hash_function

  words = text_to_word_sequence(
      text, filters=filters, lower=lower, split=split)
  indices = []
  for word in words:
    # The `+ 1` shifts the result away from 0, which is the reserved index.
    indices.append(hasher(word) % (n - 1) + 1)
  return indices
133f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet
134f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet
135e99724b78b9f6834b918ae8a599597f863cba8d4Anna R@tf_export('keras.preprocessing.text.Tokenizer')
class Tokenizer(object):
  """Text tokenization utility class.

  This class allows to vectorize a text corpus, by turning each
  text into either a sequence of integers (each integer being the index
  of a token in a dictionary) or into a vector where the coefficient
  for each token could be binary, based on word count, based on tf-idf...

  Arguments:
      num_words: the maximum number of words to keep, based
          on word frequency. Word indices start at 1 and any index
          `>= num_words` is dropped, so at most `num_words - 1` of the
          most common words are kept.
      filters: a string where each element is a character that will be
          filtered from the texts. The default is all punctuation, plus
          tabs and line breaks, minus the `'` character.
      lower: boolean. Whether to convert the texts to lowercase.
      split: character or string to use for token splitting.
      char_level: if True, every character will be treated as a token.
          In that mode each text is used as-is: `filters`, `lower` and
          `split` are not applied.
      oov_token: if given, it will be added to word_index and used to
          replace out-of-vocabulary words during text_to_sequence calls.
          Only words absent from `word_index` are replaced; known words
          cut off by `num_words` are dropped.

  By default, all punctuation is removed, turning the texts into
  space-separated sequences of words
  (words maybe include the `'` character). These sequences are then
  split into lists of tokens. They will then be indexed or vectorized.

  `0` is a reserved index that won't be assigned to any word.
  """

  def __init__(self,
               num_words=None,
               filters='!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n',
               lower=True,
               split=' ',
               char_level=False,
               oov_token=None,
               **kwargs):
    # Legacy support
    if 'nb_words' in kwargs:
      logging.warning('The `nb_words` argument in `Tokenizer` '
                      'has been renamed `num_words`.')
      num_words = kwargs.pop('nb_words')
    if kwargs:
      raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))

    # token -> number of occurrences, in order of first appearance.
    self.word_counts = OrderedDict()
    # token -> number of fitted texts containing it.
    self.word_docs = {}
    self.filters = filters
    self.split = split
    self.lower = lower
    self.num_words = num_words
    self.document_count = 0
    self.char_level = char_level
    self.oov_token = oov_token
    # Created empty up front so that using the tokenizer before fitting
    # reaches the explicit checks (e.g. the ValueError raised by
    # `sequences_to_matrix`) instead of failing with an AttributeError.
    self.word_index = {}
    self.index_docs = {}

  def fit_on_texts(self, texts):
    """Updates internal vocabulary based on a list of texts.

    Required before using `texts_to_sequences` or `texts_to_matrix`.
    Counts accumulate across calls: `word_counts`, `word_docs` and
    `document_count` all keep growing, and `word_index` / `index_docs`
    are rebuilt from the accumulated counts at the end of each call.

    Arguments:
        texts: can be a list of strings,
            or a generator of strings (for memory-efficiency)
    """
    for text in texts:
      # `document_count` is deliberately not reset: it must stay in step
      # with the accumulated `word_docs` for the tf-idf weighting.
      self.document_count += 1
      seq = text if self.char_level else text_to_word_sequence(
          text, self.filters, self.lower, self.split)
      for w in seq:
        self.word_counts[w] = self.word_counts.get(w, 0) + 1
      for w in set(seq):
        self.word_docs[w] = self.word_docs.get(w, 0) + 1

    # The sort is stable, so equally frequent tokens keep their order of
    # first appearance.
    ranked = sorted(
        self.word_counts.items(), key=lambda wc: wc[1], reverse=True)
    # note that index 0 is reserved, never assigned to an existing word
    self.word_index = dict((w, i) for i, (w, _) in enumerate(ranked, 1))

    if self.oov_token is not None and self.oov_token not in self.word_index:
      self.word_index[self.oov_token] = len(self.word_index) + 1

    self.index_docs = dict(
        (self.word_index[w], c) for w, c in self.word_docs.items())

  def fit_on_sequences(self, sequences):
    """Updates internal vocabulary based on a list of sequences.

    Required before using `sequences_to_matrix`
    (if `fit_on_texts` was never called). Unlike `fit_on_texts`, this
    replaces `document_count` and `index_docs` instead of accumulating.

    Arguments:
        sequences: A list of sequence.
            A "sequence" is a list of integer word indices.
    """
    self.document_count = len(sequences)
    self.index_docs = {}
    for seq in sequences:
      # Each index counts once per sequence, however often it repeats.
      for i in set(seq):
        self.index_docs[i] = self.index_docs.get(i, 0) + 1

  def texts_to_sequences(self, texts):
    """Transforms each text in texts in a sequence of integers.

    Only top "num_words" most frequent words will be taken into account.
    Only words known by the tokenizer will be taken into account.

    Arguments:
        texts: A list of texts (strings).

    Returns:
        A list of sequences.
    """
    return list(self.texts_to_sequences_generator(texts))

  def texts_to_sequences_generator(self, texts):
    """Transforms each text in texts in a sequence of integers.

    Only top "num_words" most frequent words will be taken into account.
    Only words known by the tokenizer will be taken into account.

    Arguments:
        texts: A list of texts (strings).

    Yields:
        Yields individual sequences.
    """
    num_words = self.num_words
    for text in texts:
      seq = text if self.char_level else text_to_word_sequence(
          text, self.filters, self.lower, self.split)
      vect = []
      for w in seq:
        i = self.word_index.get(w)
        if i is not None:
          # Known word: keep it unless `num_words` cuts it off.
          if not (num_words and i >= num_words):
            vect.append(i)
        elif self.oov_token is not None:
          # Unknown word: substitute the OOV index when one is registered.
          i = self.word_index.get(self.oov_token)
          if i is not None:
            vect.append(i)
      yield vect

  def texts_to_matrix(self, texts, mode='binary'):
    """Convert a list of texts to a Numpy matrix.

    Arguments:
        texts: list of strings.
        mode: one of "binary", "count", "tfidf", "freq".

    Returns:
        A Numpy matrix.
    """
    sequences = self.texts_to_sequences(texts)
    return self.sequences_to_matrix(sequences, mode=mode)

  def sequences_to_matrix(self, sequences, mode='binary'):
    """Converts a list of sequences into a Numpy matrix.

    The result has one row per sequence and `num_words` columns (or
    `len(word_index) + 1` columns when `num_words` is not set). Indices
    `>= ` the column count are ignored.

    Arguments:
        sequences: list of sequences
            (a sequence is a list of integer word indices).
        mode: one of "binary", "count", "tfidf", "freq"

    Returns:
        A Numpy matrix.

    Raises:
        ValueError: In case of invalid `mode` argument,
            or if the Tokenizer requires to be fit to sample data.
    """
    if not self.num_words:
      if self.word_index:
        num_words = len(self.word_index) + 1
      else:
        raise ValueError('Specify a dimension (num_words argument), '
                         'or fit on some text data first.')
    else:
      num_words = self.num_words

    if mode == 'tfidf' and not self.document_count:
      raise ValueError('Fit the Tokenizer on some data '
                       'before using tfidf mode.')

    x = np.zeros((len(sequences), num_words))
    for i, seq in enumerate(sequences):
      if not seq:
        continue
      counts = {}
      for j in seq:
        if j >= num_words:
          continue
        counts[j] = counts.get(j, 0.) + 1
      for j, c in counts.items():
        if mode == 'count':
          x[i][j] = c
        elif mode == 'freq':
          # Relative to the full sequence length, including any indices
          # that were skipped above.
          x[i][j] = c / len(seq)
        elif mode == 'binary':
          x[i][j] = 1
        elif mode == 'tfidf':
          # Use weighting scheme 2 in
          # https://en.wikipedia.org/wiki/Tf%E2%80%93idf
          tf = 1 + np.log(c)
          idf = np.log(1 + self.document_count /
                       (1 + self.index_docs.get(j, 0)))
          x[i][j] = tf * idf
        else:
          raise ValueError('Unknown vectorization mode:', mode)
    return x
370