# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Skip-gram sampling ops from https://arxiv.org/abs/1301.3781."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import csv

from tensorflow.contrib import lookup
from tensorflow.contrib.text.python.ops import gen_skip_gram_ops
from tensorflow.contrib.util import loader
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import resource_loader
from tensorflow.python.training import input as input_ops

_skip_gram_ops_so = loader.load_op_library(
    resource_loader.get_path_to_datafile("_skip_gram_ops.so"))

ops.NotDifferentiable("SkipGramGenerateCandidates")


def skip_gram_sample(input_tensor,
                     min_skips=1,
                     max_skips=5,
                     start=0,
                     limit=-1,
                     emit_self_as_target=False,
                     vocab_freq_table=None,
                     vocab_min_count=None,
                     vocab_subsampling=None,
                     corpus_size=None,
                     batch_size=None,
                     batch_capacity=None,
                     seed=None,
                     name=None):
  """Generates skip-gram token and label paired Tensors from the input tensor.

  Generates skip-gram `("token", "label")` pairs using each element in the
  rank-1 `input_tensor` as a token. The window size used for each token will be
  randomly selected from the range specified by `[min_skips, max_skips]`,
  inclusive. See https://arxiv.org/abs/1301.3781 for more details about
  skip-gram.

  For example, given `input_tensor = ["the", "quick", "brown", "fox", "jumps"]`,
  `min_skips = 1`, `max_skips = 2`, `emit_self_as_target = False`, the output
  `(tokens, labels)` pairs for the token "quick" will be randomly selected from
  either `(tokens=["quick", "quick"], labels=["the", "brown"])` for 1 skip, or
  `(tokens=["quick", "quick", "quick"], labels=["the", "brown", "fox"])` for 2
  skips.

  If `emit_self_as_target = True`, each token will also be emitted as a label
  for itself. From the previous example, the output will be either
  `(tokens=["quick", "quick", "quick"], labels=["the", "quick", "brown"])` for 1
  skip, or `(tokens=["quick", "quick", "quick", "quick"], labels=["the",
  "quick", "brown", "fox"])` for 2 skips.

  The same process is repeated for each element of `input_tensor` and
  concatenated together into the two output rank-1 `Tensors` (one for all the
  tokens, another for all the labels).

  If `vocab_freq_table` is specified, tokens in `input_tensor` that are not
  present in the vocabulary are discarded. Tokens whose frequency counts are
  below `vocab_min_count` are also discarded. Tokens whose frequency proportions
  in the corpus exceed `vocab_subsampling` may be randomly down-sampled. See
  Eq. 5 in http://arxiv.org/abs/1310.4546 for more details about subsampling.

  Because the window size is selected randomly for each token, the lengths of
  the outputs are non-deterministic. Specify `batch_size` to batch the outputs
  so that `Tensors` of length `batch_size` are always returned.
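
  For example, a minimal invocation might look like the following sketch (it
  assumes `import tensorflow as tf` and that this op is exposed as
  `tf.contrib.text.skip_gram_sample`; the values shown are arbitrary):

  ```
  input_tensor = tf.constant(["the", "quick", "brown", "fox", "jumps"])
  tokens, labels = tf.contrib.text.skip_gram_sample(
      input_tensor, min_skips=1, max_skips=2, seed=42)
  # tokens and labels are rank-1 string Tensors of equal (random) length that
  # stay in sync when evaluated together.
  ```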

  Args:
    input_tensor: A rank-1 `Tensor` from which to generate skip-gram candidates.
    min_skips: `int` or scalar `Tensor` specifying the minimum window size to
      randomly use for each token. Must be >= 0 and <= `max_skips`. If
      `min_skips` and `max_skips` are both 0, the only label emitted will be
      the token itself when `emit_self_as_target = True`; otherwise, no output
      is emitted.
    max_skips: `int` or scalar `Tensor` specifying the maximum window size to
      randomly use for each token. Must be >= 0.
    start: `int` or scalar `Tensor` specifying the position in
      `input_tensor` from which to start generating skip-gram candidates.
    limit: `int` or scalar `Tensor` specifying the maximum number of
      elements in `input_tensor` to use in generating skip-gram candidates. -1
      means to use the rest of the `Tensor` after `start`.
    emit_self_as_target: `bool` or scalar `Tensor` specifying whether to emit
      each token as a label for itself.
    vocab_freq_table: (Optional) A lookup table (subclass of
      `lookup.InitializableLookupTableBase`) that maps tokens to their raw
      frequency counts. If specified, any token in `input_tensor` that is not
      found in `vocab_freq_table` will be filtered out before generating
      skip-gram candidates. While this will typically map to integer raw
      frequency counts, it could also map to float frequency proportions.
      `vocab_min_count` and `corpus_size` should be in the same units as this.
    vocab_min_count: (Optional) `int`, `float`, or scalar `Tensor` specifying
      minimum frequency threshold (from `vocab_freq_table`) for a token to be
      kept in `input_tensor`. If this is specified, `vocab_freq_table` must also
      be specified - and they should both be in the same units.
    vocab_subsampling: (Optional) `float` specifying frequency proportion
      threshold for tokens from `input_tensor`. Tokens that occur more
      frequently (based on the ratio of the token's `vocab_freq_table` value to
      the `corpus_size`) will be randomly down-sampled. Reasonable starting
      values may be around 1e-3 or 1e-5. If this is specified, both
      `vocab_freq_table` and `corpus_size` must also be specified. See Eq. 5
      in http://arxiv.org/abs/1310.4546 for more details.
    corpus_size: (Optional) `int`, `float`, or scalar `Tensor` specifying the
      total number of tokens in the corpus (e.g., sum of all the frequency
      counts of `vocab_freq_table`). Used with `vocab_subsampling` for
      down-sampling frequently occurring tokens. If this is specified,
      `vocab_freq_table` and `vocab_subsampling` must also be specified.
    batch_size: (Optional) `int` specifying batch size of returned `Tensors`.
    batch_capacity: (Optional) `int` specifying batch capacity for the queue
      used for batching returned `Tensors`. Only has an effect if
      `batch_size` > 0. Defaults to 100 * `batch_size` if not specified.
    seed: (Optional) `int` used to create a random seed for window size and
      subsampling. See `set_random_seed` docs for behavior.
    name: (Optional) A `string` name or a name scope for the operations.

  Returns:
    A `tuple` containing (token, label) `Tensors`. Each output `Tensor` is of
    rank-1 and has the same type as `input_tensor`. The `Tensors` will be of
    length `batch_size`; if `batch_size` is not specified, they will be of
    random length, though they will be in sync with each other as long as they
    are evaluated together.

  Raises:
    ValueError: If `vocab_freq_table` is not provided, but `vocab_min_count`,
      `vocab_subsampling`, or `corpus_size` is specified. If `vocab_subsampling`
      and `corpus_size` are not both present or both absent.
  """

  if vocab_freq_table is None and (vocab_min_count is not None or
                                   vocab_subsampling is not None or
                                   corpus_size is not None):
    raise ValueError(
        "vocab_freq_table is not provided, but vocab_min_count={}, "
        "vocab_subsampling={}, or corpus_size={} is not None. These settings "
        "are useless without a vocab_freq_table.".format(
            vocab_min_count, vocab_subsampling, corpus_size))

  if (vocab_subsampling is None) != (corpus_size is None):
    raise ValueError(
        "vocab_subsampling is {} while corpus_size is {} - both must be "
        "provided in order for subsampling to work.".format(
            vocab_subsampling, corpus_size))

  with ops.name_scope(
      name,
      "skip_gram_sample",
      values=[input_tensor, min_skips, max_skips, start, limit]):

    input_tensor = _filter_input(
        input_tensor=input_tensor,
        vocab_freq_table=vocab_freq_table,
        vocab_min_count=vocab_min_count,
        vocab_subsampling=vocab_subsampling,
        corpus_size=corpus_size,
        seed=seed)

    seed1, seed2 = random_seed.get_seed(seed)
    tokens, labels = gen_skip_gram_ops.skip_gram_generate_candidates(
        input_tensor=input_tensor,
        min_skips=min_skips,
        max_skips=max_skips,
        start=start,
        limit=limit,
        emit_self_as_target=emit_self_as_target,
        # Note that seed here should be seed1! This is due to
        # GuardedPhiloxRandom's hard-coded attributes of "seed" and "seed2".
        seed=seed1,
        seed2=seed2)

    # TODO(weiho): If the need arises, add support for sparse input_tensor that
    # figures out sentence boundaries, then calls
    # skip_gram_generate_candidates() on each sentence.

    # Batches the (tokens, labels) outputs so that they will be of deterministic
    # batch_size, to facilitate feeding them into the rest of the network.
    if batch_size is not None and batch_size > 0:
      batch_capacity = (batch_capacity
                        if (batch_capacity is not None and batch_capacity > 0)
                        else 100 * batch_size)
      return input_ops.batch(
          [tokens, labels],
          batch_size,
          capacity=batch_capacity,
          enqueue_many=True)

    return tokens, labels


def skip_gram_sample_with_text_vocab(input_tensor,
                                     vocab_freq_file,
                                     vocab_token_index=0,
                                     vocab_token_dtype=dtypes.string,
                                     vocab_freq_index=1,
                                     vocab_freq_dtype=dtypes.float64,
                                     vocab_delimiter=",",
                                     vocab_min_count=0,
                                     vocab_subsampling=None,
                                     corpus_size=None,
                                     min_skips=1,
                                     max_skips=5,
                                     start=0,
                                     limit=-1,
                                     emit_self_as_target=False,
                                     batch_size=None,
                                     batch_capacity=None,
                                     seed=None,
                                     name=None):
  """Skip-gram sampling with a text vocabulary file.

  Wrapper around `skip_gram_sample()` for use with a text vocabulary file. The
  vocabulary file is expected to be a plain-text file, with lines of
  `vocab_delimiter`-separated columns. The `vocab_token_index` column should
  contain the vocabulary term, while the `vocab_freq_index` column should
  contain the number of times that term occurs in the corpus. For example, with
  a text vocabulary file of:

    ```
    bonjour,fr,42
    hello,en,777
    hola,es,99
    ```

  You should set `vocab_delimiter=","`, `vocab_token_index=0`, and
  `vocab_freq_index=2`.

  See `skip_gram_sample()` documentation for more details about the skip-gram
  sampling process.
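
  For example, combining the vocabulary file above with an input corpus might
  look like the following sketch (the file path and thresholds are hypothetical,
  and it assumes the op is exposed as
  `tf.contrib.text.skip_gram_sample_with_text_vocab`):

  ```
  tokens, labels = tf.contrib.text.skip_gram_sample_with_text_vocab(
      input_tensor=tf.constant(["hello", "hola", "bonjour"]),
      vocab_freq_file="/path/to/vocab.csv",
      vocab_token_index=0,
      vocab_freq_index=2,
      vocab_min_count=50,
      vocab_subsampling=1e-3)
  # corpus_size is not supplied here, so it is calculated by summing the
  # frequency column of the vocab file (42 + 777 + 99).
  ```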

  Args:
    input_tensor: A rank-1 `Tensor` from which to generate skip-gram candidates.
    vocab_freq_file: `string` specifying full file path to the text vocab file.
    vocab_token_index: `int` specifying which column in the text vocab file
      contains the tokens.
    vocab_token_dtype: `DType` specifying the format of the tokens in the text
      vocab file.
    vocab_freq_index: `int` specifying which column in the text vocab file
      contains the frequency counts of the tokens.
    vocab_freq_dtype: `DType` specifying the format of the frequency counts in
      the text vocab file.
    vocab_delimiter: `string` specifying the delimiter used in the text vocab
      file.
    vocab_min_count: `int`, `float`, or scalar `Tensor` specifying
      minimum frequency threshold (from `vocab_freq_file`) for a token to be
      kept in `input_tensor`. This should correspond with `vocab_freq_dtype`.
    vocab_subsampling: (Optional) `float` specifying frequency proportion
      threshold for tokens from `input_tensor`. Tokens that occur more
      frequently will be randomly down-sampled. Reasonable starting values may
      be around 1e-3 or 1e-5. See Eq. 5 in http://arxiv.org/abs/1310.4546 for
      more details.
    corpus_size: (Optional) `int`, `float`, or scalar `Tensor` specifying the
      total number of tokens in the corpus (e.g., sum of all the frequency
      counts of `vocab_freq_file`). Used with `vocab_subsampling` for
      down-sampling frequently occurring tokens. If this is specified,
      `vocab_freq_file` and `vocab_subsampling` must also be specified.
      If `corpus_size` is needed but not supplied, it will be calculated by
      summing the frequency counts in `vocab_freq_file`. You might want to
      supply your own value if you have already pruned infrequent tokens
      (where frequency < `vocab_min_count`) from your vocabulary file to save
      memory in the internal token lookup table (otherwise the unused tokens'
      entries would waste memory); in that case the calculated sum would
      undercount the true corpus size. The user-supplied `corpus_size` value
      must be greater than or equal to the sum of all the frequency counts of
      `vocab_freq_file`.
    min_skips: `int` or scalar `Tensor` specifying the minimum window size to
      randomly use for each token. Must be >= 0 and <= `max_skips`. If
      `min_skips` and `max_skips` are both 0, the only label emitted will be
      the token itself when `emit_self_as_target = True`; otherwise, no output
      is emitted.
    max_skips: `int` or scalar `Tensor` specifying the maximum window size to
      randomly use for each token. Must be >= 0.
    start: `int` or scalar `Tensor` specifying the position in `input_tensor`
      from which to start generating skip-gram candidates.
    limit: `int` or scalar `Tensor` specifying the maximum number of elements in
      `input_tensor` to use in generating skip-gram candidates. -1 means to use
      the rest of the `Tensor` after `start`.
    emit_self_as_target: `bool` or scalar `Tensor` specifying whether to emit
      each token as a label for itself.
    batch_size: (Optional) `int` specifying batch size of returned `Tensors`.
    batch_capacity: (Optional) `int` specifying batch capacity for the queue
      used for batching returned `Tensors`. Only has an effect if
      `batch_size` > 0. Defaults to 100 * `batch_size` if not specified.
    seed: (Optional) `int` used to create a random seed for window size and
      subsampling. See
      [`set_random_seed`](../../g3doc/python/constant_op.md#set_random_seed)
      for behavior.
    name: (Optional) A `string` name or a name scope for the operations.

  Returns:
    A `tuple` containing (token, label) `Tensors`. Each output `Tensor` is of
    rank-1 and has the same type as `input_tensor`. The `Tensors` will be of
    length `batch_size`; if `batch_size` is not specified, they will be of
    random length, though they will be in sync with each other as long as they
    are evaluated together.

  Raises:
    ValueError: If `vocab_token_index` or `vocab_freq_index` is less than 0 or
      exceeds the number of columns in `vocab_freq_file`. If `vocab_token_index`
      and `vocab_freq_index` are both set to the same column. If any token in
      `vocab_freq_file` has a negative frequency.
  """

  if vocab_token_index < 0 or vocab_freq_index < 0:
    raise ValueError(
        "vocab_token_index={} and vocab_freq_index={} must both be >= 0.".
        format(vocab_token_index, vocab_freq_index))
  if vocab_token_index == vocab_freq_index:
    raise ValueError(
        "vocab_token_index and vocab_freq_index should be different, but are "
        "both {}.".format(vocab_token_index))

  # Iterates through the vocab file and calculates the number of vocab terms as
  # well as the total corpus size (by summing the frequency counts of all the
  # vocab terms).
  calculated_corpus_size = 0.0
  vocab_size = 0
  with gfile.GFile(vocab_freq_file, mode="r") as f:
    reader = csv.reader(f, delimiter=vocab_delimiter)
    for row in reader:
      if vocab_token_index >= len(row) or vocab_freq_index >= len(row):
        raise ValueError(
            "Row in vocab file only has {} columns, so vocab_token_index={} or "
            "vocab_freq_index={} is out of bounds. Row content: {}".format(
                len(row), vocab_token_index, vocab_freq_index, row))
      vocab_size += 1
      freq = vocab_freq_dtype.as_numpy_dtype(row[vocab_freq_index])
      if freq < 0:
        raise ValueError(
            "Row in vocab file has negative frequency of {}. Row content: {}".
            format(freq, row))
      # Note: tokens whose frequencies are below vocab_min_count will still
      # contribute to the total corpus size used for vocab subsampling.
      calculated_corpus_size += freq

  if not corpus_size:
    corpus_size = calculated_corpus_size
  elif calculated_corpus_size - corpus_size > 1e-6:
    raise ValueError(
        "`corpus_size`={} must be greater than or equal to the sum of all the "
        "frequency counts ({}) of `vocab_freq_file` ({}).".format(
            corpus_size, calculated_corpus_size, vocab_freq_file))

  vocab_freq_table = lookup.HashTable(
      lookup.TextFileInitializer(
          filename=vocab_freq_file,
          key_dtype=vocab_token_dtype,
          key_index=vocab_token_index,
          value_dtype=vocab_freq_dtype,
          value_index=vocab_freq_index,
          vocab_size=vocab_size,
          delimiter=vocab_delimiter),
      # For vocab terms not in vocab file, use a default value of -1.
      default_value=-1)

  return skip_gram_sample(
      input_tensor,
      min_skips=min_skips,
      max_skips=max_skips,
      start=start,
      limit=limit,
      emit_self_as_target=emit_self_as_target,
      vocab_freq_table=vocab_freq_table,
      vocab_min_count=vocab_min_count,
      vocab_subsampling=vocab_subsampling,
      # corpus_size is not used unless vocab_subsampling is specified.
      corpus_size=None if vocab_subsampling is None else corpus_size,
      batch_size=batch_size,
      batch_capacity=batch_capacity,
      seed=seed,
      name=name)


def _filter_input(input_tensor, vocab_freq_table, vocab_min_count,
                  vocab_subsampling, corpus_size, seed):
  """Filters input tensor based on vocab freq, threshold, and subsampling."""
  if vocab_freq_table is None:
    return input_tensor

  if not isinstance(vocab_freq_table, lookup.InitializableLookupTableBase):
    raise ValueError(
        "vocab_freq_table must be a subclass of "
        "InitializableLookupTableBase (such as HashTable) instead of type "
        "{}.".format(type(vocab_freq_table)))

  with ops.name_scope(
      "filter_vocab", values=[vocab_freq_table, input_tensor, vocab_min_count]):
    freq = vocab_freq_table.lookup(input_tensor)
    # Filters out elements in input_tensor that are not found in
    # vocab_freq_table (when an element is not found, the table returns its
    # default_value, e.g. the -1 set in skip_gram_sample_with_text_vocab).
    mask = math_ops.not_equal(freq, vocab_freq_table.default_value)

    # Filters out elements whose vocab frequencies are less than the threshold.
    if vocab_min_count is not None:
      cast_threshold = math_ops.cast(vocab_min_count, freq.dtype)
      mask = math_ops.logical_and(mask,
                                  math_ops.greater_equal(freq, cast_threshold))

    input_tensor = array_ops.boolean_mask(input_tensor, mask)
    freq = array_ops.boolean_mask(freq, mask)

  if not vocab_subsampling:
    return input_tensor

  if vocab_subsampling < 0 or vocab_subsampling > 1:
    raise ValueError(
        "Invalid vocab_subsampling={} - it should be within range [0, 1].".
        format(vocab_subsampling))

  # Subsamples the input tokens based on vocabulary frequency and the
  # vocab_subsampling threshold (i.e., randomly discards commonly appearing
  # tokens).
  with ops.name_scope(
      "subsample_vocab", values=[input_tensor, freq, vocab_subsampling]):
    corpus_size = math_ops.cast(corpus_size, dtypes.float64)
    freq = math_ops.cast(freq, dtypes.float64)
    vocab_subsampling = math_ops.cast(vocab_subsampling, dtypes.float64)

    # From tensorflow_models/tutorials/embedding/word2vec_kernels.cc, which is
    # supposed to correspond to Eq. 5 in http://arxiv.org/abs/1310.4546.
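    # In symbols, with f = freq, t = vocab_subsampling, and N = corpus_size:
    #   keep_prob = (sqrt(f / (t * N)) + 1) * (t * N / f)
    # keep_prob decreases monotonically in f, so the most frequent tokens are
    # the most likely to be dropped; for sufficiently rare tokens keep_prob
    # is >= 1 and they are always kept.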
    keep_prob = ((math_ops.sqrt(freq /
                                (vocab_subsampling * corpus_size)) + 1.0) *
                 (vocab_subsampling * corpus_size / freq))
    random_prob = random_ops.random_uniform(
        array_ops.shape(freq),
        minval=0,
        maxval=1,
        dtype=dtypes.float64,
        seed=seed)

    mask = math_ops.less_equal(random_prob, keep_prob)
    return array_ops.boolean_mask(input_tensor, mask)