# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utilities related to layer/model functionality.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.utils.conv_utils import convert_kernel
from tensorflow.python.util.tf_export import tf_export


def count_params(weights):
  """Count the total number of scalars composing the weights.

  Duplicate references to the same weight object are counted once, because
  the weights are de-duplicated through a `set` before counting.

  Arguments:
      weights: An iterable containing the weights on which to compute params

  Returns:
      The total number of scalars composing the weights (an `int`; 0 for an
      empty iterable).
  """
  return int(np.sum([K.count_params(p) for p in set(weights)]))


def _get_layer_output_shape(layer):
  """Returns `layer.output_shape`, or a printable placeholder.

  Arguments:
      layer: target layer.

  Returns:
      `layer.output_shape`; the string 'multiple' if accessing it raises
      `AttributeError`; the string '?' if it raises `RuntimeError`.
  """
  try:
    return layer.output_shape
  except AttributeError:
    return 'multiple'
  except RuntimeError:  # output_shape unknown in Eager mode.
    return '?'


def _is_sequential_like(model):
  """Returns whether `model` can be summarized as a plain stack of layers.

  `Sequential` models and subclassed (non graph-network) models are always
  treated as sequential. A graph network is sequential-like only if no depth
  level holds more than one node, no such node has more than one inbound
  layer, and no layer has more than one of its inbound nodes among those
  nodes (i.e. there are no shared layers).

  Arguments:
      model: Keras model instance.

  Returns:
      `True` if the summary can omit the "Connected to" column.
  """
  if model.__class__.__name__ == 'Sequential':
    return True
  if not model._is_graph_network:
    # We treat subclassed models as a simple sequence of layers, for logging
    # purposes.
    return True

  nodes = []
  for v in model._nodes_by_depth.values():
    if (len(v) > 1) or (len(v) == 1 and len(v[0].inbound_layers) > 1):
      # if the model has multiple nodes
      # or if the nodes have multiple inbound_layers
      # the model is no longer sequential
      return False
    nodes += v

  # search for shared layers: a layer with two or more inbound nodes inside
  # the model cannot be rendered as a single row of a flat listing.
  for layer in model.layers:
    seen_one = False
    for node in layer._inbound_nodes:
      if node in nodes:
        if seen_one:
          return False
        seen_one = True
  return True


def print_summary(model, line_length=None, positions=None, print_fn=None):
  """Prints a summary of a model.

  Arguments:
      model: Keras model instance.
      line_length: Total length of printed lines
          (e.g. set this to adapt the display to different
          terminal window sizes). If not provided, defaults to 65 for
          sequential-like models and 98 otherwise.
      positions: Relative or absolute positions of log elements in each line.
          If not provided, defaults to `[.45, .85, 1.]` for sequential-like
          models and `[.33, .55, .67, 1.]` otherwise. When the last value is
          `<= 1`, all values are treated as fractions of `line_length`.
      print_fn: Print function to use.
          It will be called on each line of the summary.
          You can set it to a custom function
          in order to capture the string summary.
          It defaults to `print` (prints to stdout).
  """
  if print_fn is None:
    print_fn = print

  sequential_like = _is_sequential_like(model)

  # Nodes belonging to this network; only used for the "Connected to" column.
  relevant_nodes = []
  if sequential_like:
    line_length = line_length or 65
    positions = positions or [.45, .85, 1.]
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #']
  else:
    line_length = line_length or 98
    positions = positions or [.33, .55, .67, 1.]
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to']
    for v in model._nodes_by_depth.values():
      relevant_nodes += v

  if positions[-1] <= 1:
    # Relative positions: scale them to absolute column offsets.
    positions = [int(line_length * p) for p in positions]

  def print_row(fields, positions):
    """Prints `fields` left-aligned, each truncated/padded to its column."""
    line = ''
    for i, field in enumerate(fields):
      if i > 0:
        # Overwrite the last char with a space so the new column is always
        # separated from the previous one, even if that one was truncated.
        line = line[:-1] + ' '
      line += str(field)
      line = line[:positions[i]]
      line += ' ' * (positions[i] - len(line))
    print_fn(line)

  def print_layer_summary(layer):
    """Prints a summary for a single layer.

    Arguments:
        layer: target layer.
    """
    output_shape = _get_layer_output_shape(layer)
    name = layer.name
    cls_name = layer.__class__.__name__
    fields = [name + ' (' + cls_name + ')', output_shape, layer.count_params()]
    print_row(fields, positions)

  def print_layer_summary_with_connections(layer):
    """Prints a summary for a single layer (including topological connections).

    Arguments:
        layer: target layer.
    """
    output_shape = _get_layer_output_shape(layer)
    connections = []
    for node in layer._inbound_nodes:
      if relevant_nodes and node not in relevant_nodes:
        # node is not part of the current network
        continue
      for i, inbound_layer in enumerate(node.inbound_layers):
        connections.append(inbound_layer.name + '[' +
                           str(node.node_indices[i]) + '][' +
                           str(node.tensor_indices[i]) + ']')

    name = layer.name
    cls_name = layer.__class__.__name__
    first_connection = connections[0] if connections else ''
    fields = [
        name + ' (' + cls_name + ')', output_shape,
        layer.count_params(), first_connection
    ]
    print_row(fields, positions)
    # Remaining connections get their own rows, in the last column only.
    for connection in connections[1:]:
      print_row(['', '', '', connection], positions)

  print_fn('_' * line_length)
  print_row(to_display, positions)
  print_fn('=' * line_length)

  layers = model.layers
  for i, layer in enumerate(layers):
    if sequential_like:
      print_layer_summary(layer)
    else:
      print_layer_summary_with_connections(layer)
    # '=' closes the table after the last layer; '_' separates layers.
    if i == len(layers) - 1:
      print_fn('=' * line_length)
    else:
      print_fn('_' * line_length)

  model._check_trainable_weights_consistency()
  if hasattr(model, '_collected_trainable_weights'):
    trainable_count = count_params(model._collected_trainable_weights)
  else:
    trainable_count = count_params(model.trainable_weights)

  non_trainable_count = count_params(model.non_trainable_weights)

  print_fn('Total params: {:,}'.format(trainable_count + non_trainable_count))
  print_fn('Trainable params: {:,}'.format(trainable_count))
  print_fn('Non-trainable params: {:,}'.format(non_trainable_count))
  print_fn('_' * line_length)


@tf_export('keras.utils.convert_all_kernels_in_model')
def convert_all_kernels_in_model(model):
  """Converts all convolution kernels in a model from Theano to TensorFlow.

  Also works from TensorFlow to Theano.

  Arguments:
      model: target model for the conversion.
  """
  # Note: SeparableConvolution not included
  # since only supported by TF.
  conv_classes = {
      'Conv1D',
      'Conv2D',
      'Conv3D',
      'Conv2DTranspose',
  }
  to_assign = []
  for layer in model.layers:
    if layer.__class__.__name__ in conv_classes:
      original_kernel = K.get_value(layer.kernel)
      converted_kernel = convert_kernel(original_kernel)
      to_assign.append((layer.kernel, converted_kernel))
  # Assign all converted kernels in a single batched call.
  K.batch_set_value(to_assign)


def convert_dense_weights_data_format(dense,
                                      previous_feature_map_shape,
                                      target_data_format='channels_first'):
  """Utility useful when changing a convnet's `data_format`.

  When porting the weights of a convnet from one data format to the other,
  if the convnet includes a `Flatten` layer
  (applied to the last convolutional feature map)
  followed by a `Dense` layer, the weights of that `Dense` layer
  should be updated to reflect the new dimension ordering.

  Arguments:
      dense: The target `Dense` layer.
      previous_feature_map_shape: A shape tuple of 3 integers,
          e.g. `(512, 7, 7)`. The shape of the convolutional
          feature map right before the `Flatten` layer that
          came before the target `Dense` layer.
      target_data_format: One of "channels_last", "channels_first".
          Set it "channels_last"
          if converting a "channels_first" model to "channels_last",
          or reciprocally.
  """
  assert target_data_format in {'channels_last', 'channels_first'}
  kernel, bias = dense.get_weights()

  # The feature-map layout is the same for every output unit, so work out
  # the source shape and the axis permutation once, outside the loop.
  if target_data_format == 'channels_first':
    c, h, w = previous_feature_map_shape
    original_fm_shape = (h, w, c)
    axes = (2, 0, 1)  # last -> first
  else:
    h, w, c = previous_feature_map_shape
    original_fm_shape = (c, h, w)
    axes = (1, 2, 0)  # first -> last
  flat_shape = (np.prod(previous_feature_map_shape),)

  # Each column of the kernel holds one output unit's flattened weights.
  for i in range(kernel.shape[1]):
    ki = kernel[:, i].reshape(original_fm_shape)
    ki = np.transpose(ki, axes)
    kernel[:, i] = np.reshape(ki, flat_shape)
  dense.set_weights([kernel, bias])