10bd0bf02aa15a3238b77053a2f0ad6fe373c7d1cFrancois Chollet# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet#
3f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet# Licensed under the Apache License, Version 2.0 (the "License");
4f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet# you may not use this file except in compliance with the License.
5f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet# You may obtain a copy of the License at
6f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet#
7f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet#     http://www.apache.org/licenses/LICENSE-2.0
8f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet#
9f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet# Unless required by applicable law or agreed to in writing, software
10f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet# distributed under the License is distributed on an "AS IS" BASIS,
11f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet# See the License for the specific language governing permissions and
13f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet# limitations under the License.
14f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet# ==============================================================================
150bd0bf02aa15a3238b77053a2f0ad6fe373c7d1cFrancois Chollet# pylint: disable=protected-access
160bd0bf02aa15a3238b77053a2f0ad6fe373c7d1cFrancois Chollet"""Utilities related to layer/model functionality.
17f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet"""
18f49f801276154d0f693c5d57db6977a7eb32f017Francois Cholletfrom __future__ import absolute_import
19f49f801276154d0f693c5d57db6977a7eb32f017Francois Cholletfrom __future__ import division
20f49f801276154d0f693c5d57db6977a7eb32f017Francois Cholletfrom __future__ import print_function
21f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet
22f49f801276154d0f693c5d57db6977a7eb32f017Francois Cholletimport numpy as np
23f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet
24eaaa0b93852054dee086a3ed5373cf8bbe3d2fb3Francois Cholletfrom tensorflow.python.keras._impl.keras import backend as K
25eaaa0b93852054dee086a3ed5373cf8bbe3d2fb3Francois Cholletfrom tensorflow.python.keras._impl.keras.utils.conv_utils import convert_kernel
26e99724b78b9f6834b918ae8a599597f863cba8d4Anna Rfrom tensorflow.python.util.tf_export import tf_export
27f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet
28f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet
def count_params(weights):
  """Counts the scalars held by a collection of weights.

  Duplicate entries are collapsed with `set` before counting, so a weight
  listed more than once only contributes its size once.

  Arguments:
      weights: An iterable containing the weights on which to compute params

  Returns:
      The total number of scalars composing the weights, as a Python `int`.
  """
  unique_weights = set(weights)
  sizes = [K.count_params(weight) for weight in unique_weights]
  total = np.sum(sizes)
  return int(total)
39144eaa8e273da43b7ca881d7dcac98b65f698f11Anjali Sridhar
40144eaa8e273da43b7ca881d7dcac98b65f698f11Anjali Sridhar
def print_summary(model, line_length=None, positions=None, print_fn=None):
  """Prints a summary of a model.

  Arguments:
      model: Keras model instance.
      line_length: Total length of printed lines
          (e.g. set this to adapt the display to different
          terminal window sizes). If not provided, defaults to 65 for
          sequential-like models and to 98 otherwise.
      positions: Relative or absolute positions of log elements in each line.
          They are treated as relative (fractions of `line_length`) when the
          last one is `<= 1`. If not provided, defaults to `[.45, .85, 1.]`
          for sequential-like models and to `[.33, .55, .67, 1.]` otherwise.
      print_fn: Print function to use.
          It will be called on each line of the summary.
          You can set it to a custom function
          in order to capture the string summary.
          It defaults to `print` (prints to stdout).
  """
  if print_fn is None:
    print_fn = print

  if model.__class__.__name__ == 'Sequential':
    sequential_like = True
  elif not model._is_graph_network:
    # We treat subclassed models as a simple sequence of layers, for logging
    # purposes.
    sequential_like = True
  else:
    sequential_like = True
    nodes = []
    for v in model._nodes_by_depth.values():
      if (len(v) > 1) or (len(v) == 1 and len(v[0].inbound_layers) > 1):
        # if the model has multiple nodes
        # or if the nodes have multiple inbound_layers
        # the model is no longer sequential
        sequential_like = False
        break
      nodes += v
    if sequential_like:
      # A layer with more than one inbound node belonging to this model is a
      # shared layer, which also rules out the simple sequential display.
      for layer in model.layers:
        nodes_in_model = [
            node for node in layer._inbound_nodes if node in nodes
        ]
        if len(nodes_in_model) > 1:
          sequential_like = False
          break

  if sequential_like:
    line_length = line_length or 65
    positions = positions or [.45, .85, 1.]
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #']
  else:
    line_length = line_length or 98
    positions = positions or [.33, .55, .67, 1.]
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to']
    relevant_nodes = []
    for v in model._nodes_by_depth.values():
      relevant_nodes += v
  if positions[-1] <= 1:
    # Relative positions: scale them to absolute column offsets.
    positions = [int(line_length * p) for p in positions]

  def print_row(fields, positions):
    """Prints `fields` as one line, each field truncated/padded to its column.

    Arguments:
        fields: values to print, one per column (converted with `str`).
        positions: absolute end column of each field.
    """
    line = ''
    for i, field in enumerate(fields):
      if i > 0:
        # Keep at least one space between adjacent columns.
        line = line[:-1] + ' '
      line += str(field)
      line = line[:positions[i]]
      line += ' ' * (positions[i] - len(line))
    print_fn(line)

  print_fn('_' * line_length)
  print_row(to_display, positions)
  print_fn('=' * line_length)

  def get_output_shape(layer):
    """Returns the printable output shape of `layer`.

    Arguments:
        layer: target layer.

    Returns:
        `layer.output_shape`, `'multiple'` when the layer has no single
        output shape (AttributeError), or `'?'` when the shape is unknown
        (RuntimeError, e.g. in Eager mode).
    """
    try:
      return layer.output_shape
    except AttributeError:
      return 'multiple'
    except RuntimeError:  # output_shape unknown in Eager mode.
      return '?'

  def print_layer_summary(layer):
    """Prints a summary for a single layer.

    Arguments:
        layer: target layer.
    """
    name = layer.name
    cls_name = layer.__class__.__name__
    fields = [
        name + ' (' + cls_name + ')', get_output_shape(layer),
        layer.count_params()
    ]
    print_row(fields, positions)

  def print_layer_summary_with_connections(layer):
    """Prints a summary for a single layer (including topological connections).

    Arguments:
        layer: target layer.
    """
    output_shape = get_output_shape(layer)
    connections = []
    for node in layer._inbound_nodes:
      if relevant_nodes and node not in relevant_nodes:
        # node is not part of the current network
        continue
      for i, inbound_layer in enumerate(node.inbound_layers):
        connections.append(inbound_layer.name + '[' +
                           str(node.node_indices[i]) + '][' +
                           str(node.tensor_indices[i]) + ']')

    name = layer.name
    cls_name = layer.__class__.__name__
    first_connection = connections[0] if connections else ''
    fields = [
        name + ' (' + cls_name + ')', output_shape,
        layer.count_params(), first_connection
    ]
    print_row(fields, positions)
    # Any further inbound connections get their own rows in the last column.
    for connection in connections[1:]:
      print_row(['', '', '', connection], positions)

  layers = model.layers
  for i, layer in enumerate(layers):
    if sequential_like:
      print_layer_summary(layer)
    else:
      print_layer_summary_with_connections(layer)
    # The last layer row is closed by '=' to separate it from the totals.
    if i == len(layers) - 1:
      print_fn('=' * line_length)
    else:
      print_fn('_' * line_length)

  model._check_trainable_weights_consistency()
  if hasattr(model, '_collected_trainable_weights'):
    trainable_count = count_params(model._collected_trainable_weights)
  else:
    trainable_count = count_params(model.trainable_weights)

  non_trainable_count = count_params(model.non_trainable_weights)

  print_fn('Total params: {:,}'.format(trainable_count + non_trainable_count))
  print_fn('Trainable params: {:,}'.format(trainable_count))
  print_fn('Non-trainable params: {:,}'.format(non_trainable_count))
  print_fn('_' * line_length)
203f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet
204f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet
@tf_export('keras.utils.convert_all_kernels_in_model')
def convert_all_kernels_in_model(model):
  """Converts all convolution kernels in a model from Theano to TensorFlow.

  Also works from TensorFlow to Theano.

  Layers are selected by class name (`Conv1D`, `Conv2D`, `Conv3D`,
  `Conv2DTranspose`). Each kernel is read, converted with `convert_kernel`,
  and all new values are written back in a single `K.batch_set_value` call.

  Arguments:
      model: target model for the conversion.
  """
  # Note: SeparableConvolution not included
  # since only supported by TF.
  conv_classes = {'Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose'}
  conv_layers = [
      layer for layer in model.layers
      if layer.__class__.__name__ in conv_classes
  ]
  assignments = [
      (layer.kernel, convert_kernel(K.get_value(layer.kernel)))
      for layer in conv_layers
  ]
  K.batch_set_value(assignments)
229f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet
230f49f801276154d0f693c5d57db6977a7eb32f017Francois Chollet
def convert_dense_weights_data_format(dense,
                                      previous_feature_map_shape,
                                      target_data_format='channels_first'):
  """Utility useful when changing a convnet's `data_format`.

  When porting the weights of a convnet from one data format to the other,
  if the convnet includes a `Flatten` layer
  (applied to the last convolutional feature map)
  followed by a `Dense` layer, the weights of that `Dense` layer
  should be updated to reflect the new dimension ordering.

  Arguments:
      dense: The target `Dense` layer. Its first weight must be the 2D
          kernel of shape `(input_dim, units)`; any further weights (e.g. the
          bias) are written back unchanged, so layers built with
          `use_bias=False` are supported too.
      previous_feature_map_shape: A shape tuple of integers, expressed in the
          *target* data format, e.g. `(512, 7, 7)` for "channels_first".
          The shape of the convolutional
          feature map right before the `Flatten` layer that
          came before the target `Dense` layer. Typically 3 integers
          (2D convnets); other spatial ranks (e.g. `(channels, steps)` for
          1D convnets) are handled the same way.
      target_data_format: One of "channels_last", "channels_first".
          Set it "channels_last"
          if converting a "channels_first" model to "channels_last",
          or reciprocally.

  Raises:
      ValueError: if `target_data_format` is not supported, if
          `previous_feature_map_shape` is empty, or if its number of elements
          does not match the kernel's input dimension.
  """
  if target_data_format not in {'channels_last', 'channels_first'}:
    raise ValueError('`target_data_format` should be "channels_last" or '
                     '"channels_first". Received: ' + str(target_data_format))
  feature_map_shape = tuple(int(dim) for dim in previous_feature_map_shape)
  if not feature_map_shape:
    raise ValueError('`previous_feature_map_shape` must not be empty.')

  weights = dense.get_weights()
  kernel = weights[0]
  input_dim, units = kernel.shape
  if np.prod(feature_map_shape) != input_dim:
    raise ValueError('`previous_feature_map_shape` ' + str(feature_map_shape) +
                     ' does not match the kernel input dimension ' +
                     str(input_dim) + '.')

  spatial_rank = len(feature_map_shape) - 1
  if target_data_format == 'channels_first':
    # Kernel rows currently follow a flattened (spatial..., channels) map;
    # move the channel axis to the front (last -> first).
    channels = feature_map_shape[0]
    source_shape = feature_map_shape[1:] + (channels,)
    axes = (spatial_rank,) + tuple(range(spatial_rank))
  else:
    # Kernel rows currently follow a flattened (channels, spatial...) map;
    # move the channel axis to the back (first -> last).
    channels = feature_map_shape[-1]
    source_shape = (channels,) + feature_map_shape[:-1]
    axes = tuple(range(1, spatial_rank + 1)) + (0,)

  # Reorder the rows of every output column at once: the trailing `units`
  # axis stays in place while the feature-map axes are permuted.
  reordered = np.transpose(
      kernel.reshape(source_shape + (units,)), axes + (spatial_rank + 1,))
  kernel = reordered.reshape((input_dim, units))
  dense.set_weights([kernel] + list(weights[1:]))
268