# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A TFGAN-backed `Head` for use in GAN Estimators."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
from tensorflow.contrib.gan.python import train as tfgan_train
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator.canned import head
from tensorflow.python.framework import ops

__all__ = [
    'GANHead',
    'gan_head',
]


def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer,
             discriminator_optimizer, use_loss_summaries=True,
             get_hooks_fn=tfgan_train.get_sequential_train_hooks(),
             name=None):
  """Creates a `GANHead`.

  Args:
    generator_loss_fn: A TFGAN loss function for the generator. Takes a
      `GANModel` and returns a scalar.
    discriminator_loss_fn: Same as `generator_loss_fn`, but for the
      discriminator.
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizer: Same as `generator_optimizer`, but for the
      discriminator updates.
    use_loss_summaries: If `True` or `False`, passed to both loss functions as
      their `add_summaries` keyword argument. If `None`, the loss functions are
      called without `add_summaries`, so their own default applies.
    get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
      of hooks. Defaults to `train.get_sequential_train_hooks()`.
    name: Name of the head, returned by the head's `name` property. NOTE: the
      visible code does not yet suffix summary or metric keys with it.

  Returns:
    An instance of `GANHead`.

  Raises:
    TypeError: If a loss function or `get_hooks_fn` is not callable.
  """
  return GANHead(generator_loss_fn=generator_loss_fn,
                 discriminator_loss_fn=discriminator_loss_fn,
                 generator_optimizer=generator_optimizer,
                 discriminator_optimizer=discriminator_optimizer,
                 use_loss_summaries=use_loss_summaries,
                 get_hooks_fn=get_hooks_fn,
                 name=name)


class GANHead(head._Head):  # pylint: disable=protected-access
  """`Head` for a GAN."""

  def __init__(self, generator_loss_fn, discriminator_loss_fn,
               generator_optimizer, discriminator_optimizer,
               use_loss_summaries=True,
               get_hooks_fn=None,
               name=None):
    """`Head` for GAN training.

    Args:
      generator_loss_fn: A TFGAN loss function for the generator. Takes a
        `GANModel` and returns a scalar.
      discriminator_loss_fn: Same as `generator_loss_fn`, but for the
        discriminator.
      generator_optimizer: The optimizer for generator updates.
      discriminator_optimizer: Same as `generator_optimizer`, but for the
        discriminator updates.
      use_loss_summaries: If `True` or `False`, passed to both loss functions
        as their `add_summaries` keyword argument. If `None`, the loss
        functions are called without `add_summaries`, so their own default
        applies.
      get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
        of hooks. Defaults to `train.get_sequential_train_hooks()`.
      name: Name of the head, returned by the `name` property. NOTE: summary
        and metric keys are not yet suffixed with it (see the TODO in
        `create_estimator_spec`).

    Raises:
      TypeError: If `generator_loss_fn` or `discriminator_loss_fn` is not
        callable, or if `get_hooks_fn` is neither `None` nor callable.
    """
    # Fail fast on obviously wrong arguments instead of deep inside graph
    # construction.
    if not callable(generator_loss_fn):
      raise TypeError('generator_loss_fn must be callable.')
    if not callable(discriminator_loss_fn):
      raise TypeError('discriminator_loss_fn must be callable.')
    if get_hooks_fn is not None and not callable(get_hooks_fn):
      raise TypeError('get_hooks_fn must be callable.')
    # TODO(joelshor): Validate the remaining inputs (optimizers,
    # `use_loss_summaries`).

    if get_hooks_fn is None:
      get_hooks_fn = tfgan_train.get_sequential_train_hooks()

    # Only bind `add_summaries` when the caller made an explicit choice;
    # `None` leaves the loss functions' own default in effect.
    if use_loss_summaries in [True, False]:
      generator_loss_fn = functools.partial(
          generator_loss_fn, add_summaries=use_loss_summaries)
      discriminator_loss_fn = functools.partial(
          discriminator_loss_fn, add_summaries=use_loss_summaries)
    self._generator_loss_fn = generator_loss_fn
    self._discriminator_loss_fn = discriminator_loss_fn
    self._generator_optimizer = generator_optimizer
    self._discriminator_optimizer = discriminator_optimizer
    self._get_hooks_fn = get_hooks_fn
    # Previously never assigned, so the `name` property raised AttributeError.
    self._name = name

  @property
  def name(self):
    return self._name

  @property
  def logits_dimension(self):
    # A GAN head consumes a `GANModel`, not a logits tensor of fixed width.
    return None

  def create_loss(self, features, mode, logits, labels):
    """Returns a GANLoss tuple from the provided GANModel.

    See `Head` for more details.

    Args:
      features: Input `dict` of `Tensor` objects. Unused.
      mode: Estimator's `ModeKeys`. Unused.
      logits: A GANModel tuple.
      labels: Must be `None`.

    Returns:
      A GANLoss tuple.

    Raises:
      ValueError: If `labels` is not `None` or `logits` is not a `GANModel`.
    """
    _validate_logits_and_labels(logits, labels)
    del mode, labels, features  # unused for this head.
    gan_model = logits  # rename variable for clarity
    return tfgan_tuples.GANLoss(
        generator_loss=self._generator_loss_fn(gan_model),
        discriminator_loss=self._discriminator_loss_fn(gan_model))

  def create_estimator_spec(
      self, features, mode, logits, labels=None,
      train_op_fn=tfgan_train.gan_train_ops):
    """Returns `EstimatorSpec` that a model_fn can return.

    See `Head` for more details.

    Args:
      features: Must be `None`.
      mode: Estimator's `ModeKeys`.
      logits: A GANModel tuple.
      labels: Must be `None`.
      train_op_fn: Function that takes a GANModel, GANLoss, generator optimizer,
        and discriminator optimizer, and returns a `GANTrainOps` tuple. For
        example, this function can come from TFGAN's `train.py` library, or can
        be custom.

    Returns:
      `EstimatorSpec`. In PREDICT mode its predictions are the generated data.
      In EVAL mode it also carries the summed generator and discriminator loss.
      In TRAIN mode it carries that loss, the train op from `train_op_fn` and
      the hooks from `get_hooks_fn`.

    Raises:
      ValueError: If `labels` isn't `None` or `logits` isn't a `GANModel`.
      ValueError: If `features` isn't `None`.
      ValueError: If `train_op_fn` isn't provided in train mode.
      ValueError: If `mode` is not a recognized `ModeKeys` value.
    """
    _validate_logits_and_labels(logits, labels)
    if features is not None:
      # Wrap in a 1-tuple so tuple-valued `features` are formatted, not
      # unpacked by `%`.
      raise ValueError('`features` should be `None`. Instead, found: %s' %
                       (features,))
    gan_model = logits  # rename variable for clarity
    with ops.name_scope('GANHead'):
      if mode == model_fn_lib.ModeKeys.PREDICT:
        return model_fn_lib.EstimatorSpec(
            mode=model_fn_lib.ModeKeys.PREDICT,
            predictions=gan_model.generated_data)
      elif mode == model_fn_lib.ModeKeys.EVAL:
        gan_loss = self.create_loss(
            features=None, mode=mode, logits=gan_model, labels=None)
        scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
        return model_fn_lib.EstimatorSpec(
            mode=model_fn_lib.ModeKeys.EVAL,
            predictions=gan_model.generated_data,
            loss=scalar_loss,
            # TODO(joelshor): Add metrics. If head name provided, append it to
            # metric keys.
            eval_metric_ops={})
      elif mode == model_fn_lib.ModeKeys.TRAIN:
        if train_op_fn is None:
          raise ValueError('train_op_fn can not be None.')
        gan_loss = self.create_loss(
            features=None, mode=mode, logits=gan_model, labels=None)
        scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
        train_ops = train_op_fn(gan_model, gan_loss, self._generator_optimizer,
                                self._discriminator_optimizer)
        training_hooks = self._get_hooks_fn(train_ops)
        return model_fn_lib.EstimatorSpec(
            loss=scalar_loss,
            mode=model_fn_lib.ModeKeys.TRAIN,
            train_op=train_ops.global_step_inc_op,
            training_hooks=training_hooks)
      else:
        raise ValueError('Mode not recognized: %s' % (mode,))


def _validate_logits_and_labels(logits, labels):
  """Checks that `labels` is `None` and `logits` is a `GANModel`.

  Args:
    logits: Value passed as `logits` to a `GANHead` method.
    labels: Value passed as `labels` to a `GANHead` method.

  Raises:
    ValueError: If `labels` is not `None` or `logits` is not a `GANModel`.
  """
  # Values are wrapped in 1-tuples: `GANModel`-like namedtuples and other
  # tuples would otherwise be unpacked by `%`, raising TypeError instead of
  # the intended ValueError.
  if labels is not None:
    raise ValueError('`GANHead`\'s `create_estimator_spec` input `labels` must '
                     'be `None`. Instead, found: %s' % (labels,))

  if not isinstance(logits, tfgan_tuples.GANModel):
    raise ValueError('`GANHead`\'s `create_estimator_spec` input `logits` must '
                     'be an instance of a `GANModel`. Instead, found: %s' %
                     (logits,))