1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A TFGAN-backed GAN Estimator."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22
23from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
24from tensorflow.contrib.gan.python import train as tfgan_train
25from tensorflow.python.estimator import model_fn as model_fn_lib
26from tensorflow.python.estimator.canned import head
27from tensorflow.python.framework import ops
28
29__all__ = [
30    'GANHead',
31    'gan_head',
32]
33
34
def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer,
             discriminator_optimizer, use_loss_summaries=True,
             get_hooks_fn=tfgan_train.get_sequential_train_hooks(),
             name=None):
  """Factory that builds a `GANHead` from losses, optimizers and hooks.

  Every argument is forwarded, by keyword, to the `GANHead` constructor.

  Args:
    generator_loss_fn: A TFGAN loss function for the generator. Takes a
      `GANModel` and returns a scalar.
    discriminator_loss_fn: Same as `generator_loss_fn`, but for the
      discriminator.
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizer: Same as `generator_optimizer`, but for the
      discriminator updates.
    use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
      If `None`, the loss functions are left untouched and use their own
      defaults.
    get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
      list of hooks. The default is the result of calling
      `train.get_sequential_train_hooks()` once, when this function is
      defined.
    name: name of the head, forwarded unchanged to `GANHead`.

  Returns:
    An instance of `GANHead`.
  """
  # Collect the constructor arguments by keyword so the mapping onto
  # `GANHead.__init__` does not depend on parameter order.
  head_kwargs = {
      'generator_loss_fn': generator_loss_fn,
      'discriminator_loss_fn': discriminator_loss_fn,
      'generator_optimizer': generator_optimizer,
      'discriminator_optimizer': discriminator_optimizer,
      'use_loss_summaries': use_loss_summaries,
      'get_hooks_fn': get_hooks_fn,
      'name': name,
  }
  return GANHead(**head_kwargs)
66
67
class GANHead(head._Head):  # pylint: disable=protected-access
  """`Head` for a GAN.

  Bundles the generator/discriminator loss functions and optimizers, and turns
  a `GANModel` (passed through the `logits` argument of the `Head` interface)
  into either a `GANLoss` tuple or an `EstimatorSpec`.
  """

  def __init__(self, generator_loss_fn, discriminator_loss_fn,
               generator_optimizer, discriminator_optimizer,
               use_loss_summaries=True,
               get_hooks_fn=None,
               name=None):
    """`Head` for GAN training.

    Args:
      generator_loss_fn: A TFGAN loss function for the generator. Takes a
        `GANModel` and returns a scalar.
      discriminator_loss_fn: Same as `generator_loss_fn`, but for the
        discriminator.
      generator_optimizer: The optimizer for generator updates.
      discriminator_optimizer: Same as `generator_optimizer`, but for the
        discriminator updates.
      use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
        If `None`, the loss functions are used as given (their own defaults
        apply).
      get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
        of hooks. Defaults to `train.get_sequential_train_hooks()`
      name: name of the head, exposed through the `name` property. Intended to
        suffix summary and metric keys with `"/" + name`; no metrics are
        emitted yet (see the TODO in `create_estimator_spec`).
    """
    if get_hooks_fn is None:
      get_hooks_fn = tfgan_train.get_sequential_train_hooks()
    # TODO(joelshor): Validate inputs.

    # Only bind `add_summaries` when the caller made an explicit choice; with
    # `None` the loss functions are stored unmodified.
    if use_loss_summaries in [True, False]:
      generator_loss_fn = functools.partial(
          generator_loss_fn, add_summaries=use_loss_summaries)
      discriminator_loss_fn = functools.partial(
          discriminator_loss_fn, add_summaries=use_loss_summaries)
    self._generator_loss_fn = generator_loss_fn
    self._discriminator_loss_fn = discriminator_loss_fn
    self._generator_optimizer = generator_optimizer
    self._discriminator_optimizer = discriminator_optimizer
    self._get_hooks_fn = get_hooks_fn
    # The `name` property reads this attribute; without the assignment every
    # access to `head.name` raised `AttributeError`.
    self._name = name

  @property
  def name(self):
    """The head name passed to the constructor (may be `None`)."""
    return self._name

  @property
  def logits_dimension(self):
    """Always `None`: this head consumes a `GANModel`, not a logits tensor."""
    return None

  def create_loss(self, features, mode, logits, labels):
    """Returns a GANLoss tuple from the provided GANModel.

    See `Head` for more details.

    Args:
      features: Input `dict` of `Tensor` objects. Unused.
      mode: Estimator's `ModeKeys`. Unused.
      logits: A GANModel tuple.
      labels: Must be `None`.

    Returns:
      A GANLoss tuple.

    Raises:
      ValueError: If `labels` isn't `None` or `logits` isn't a `GANModel`.
    """
    _validate_logits_and_labels(logits, labels)
    del mode, labels, features  # unused for this head.
    gan_model = logits  # rename variable for clarity
    return tfgan_tuples.GANLoss(
        generator_loss=self._generator_loss_fn(gan_model),
        discriminator_loss=self._discriminator_loss_fn(gan_model))

  def create_estimator_spec(
      self, features, mode, logits, labels=None,
      train_op_fn=tfgan_train.gan_train_ops):
    """Returns `EstimatorSpec` that a model_fn can return.

    See `Head` for more details.

    Args:
      features: Must be `None`.
      mode: Estimator's `ModeKeys`.
      logits: A GANModel tuple.
      labels: Must be `None`.
      train_op_fn: Function that takes a GANModel, GANLoss, generator optimizer,
        and discriminator optimizer, and returns a `GANTrainOps` tuple. For
        example, this function can come from TFGAN's `train.py` library, or can
        be custom.

    Returns:
      `EstimatorSpec`.

    Raises:
      ValueError: If `labels` isn't `None` or `logits` isn't a `GANModel`.
      ValueError: If `features` isn't `None`.
      ValueError: If `train_op_fn` isn't provided in train mode.
      ValueError: If `mode` is not one of PREDICT, EVAL or TRAIN.
    """
    _validate_logits_and_labels(logits, labels)
    if features is not None:
      # Wrap in a 1-tuple so that a tuple-valued `features` is formatted as a
      # single value instead of making `%` raise `TypeError`.
      raise ValueError('`features` should be `None`. Instead, found: %s' %
                       (features,))
    gan_model = logits  # rename variable for clarity
    with ops.name_scope('GANHead'):
      if mode == model_fn_lib.ModeKeys.PREDICT:
        return model_fn_lib.EstimatorSpec(
            mode=model_fn_lib.ModeKeys.PREDICT,
            predictions=gan_model.generated_data)
      elif mode == model_fn_lib.ModeKeys.EVAL:
        gan_loss = self.create_loss(
            features=None, mode=mode, logits=gan_model, labels=None)
        # The reported scalar loss is the sum of both players' losses.
        scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
        return model_fn_lib.EstimatorSpec(
            mode=model_fn_lib.ModeKeys.EVAL,
            predictions=gan_model.generated_data,
            loss=scalar_loss,
            # TODO(joelshor): Add metrics. If head name provided, append it to
            # metric keys.
            eval_metric_ops={})
      elif mode == model_fn_lib.ModeKeys.TRAIN:
        if train_op_fn is None:
          raise ValueError('train_op_fn can not be None.')
        gan_loss = self.create_loss(None, mode, gan_model, None)
        scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
        train_ops = train_op_fn(gan_model, gan_loss, self._generator_optimizer,
                                self._discriminator_optimizer)
        # The hooks receive the full `GANTrainOps`; the spec's own `train_op`
        # is only the global-step increment.
        training_hooks = self._get_hooks_fn(train_ops)
        return model_fn_lib.EstimatorSpec(
            loss=scalar_loss,
            mode=model_fn_lib.ModeKeys.TRAIN,
            train_op=train_ops.global_step_inc_op,
            training_hooks=training_hooks)
      else:
        raise ValueError('Mode not recognized: %s' % (mode,))
198
199
200def _validate_logits_and_labels(logits, labels):
201  if labels is not None:
202    raise ValueError('`GANHead`\'s `create_estimator_spec` input `labels` must '
203                     'be `None`. Instead, found: %s' % labels)
204
205  if not isinstance(logits, tfgan_tuples.GANModel):
206    raise ValueError('`GANHead`\'s `create_estimator_spec` input `logits` must '
207                     'be an instnace of a `GANModel`. Instead, found: %s' %
208                     logits)
209