我为VGG块编写了以下代码,我想显示块的摘要:
import tensorflow as tf
from keras.layers import Conv2D, MaxPool2D, Input
class VggBlock(tf.keras.Model):
def __init__(self, filters, repetitions):
super(VggBlock, self).__init__()
self.repetitions = repetitions
for i in range(repetitions):
vars(self)[f'conv2D_{i}'] = Conv2D(filters=filters, kernel_size=(3, 3), padding='same', activation='relu')
self.max_pool = MaxPool2D(pool_size=(2, 2))
def call(self, inputs):
x = vars(self)['conv2D_0'](inputs)
for i in range(1, self.repetitions):
x = vars(self)[f'conv2D_{i}'](x)
return self.max_pool(x)
test_block = VggBlock(64, 2)
temp_inputs = Input(shape=(224, 224, 3))
test_block(temp_inputs)
test_block.summary()
然后,此代码将输出以下内容:
Model: "vgg_block"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
max_pooling2d (MaxPooling2D multiple 0
)
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________
我试着明确地判断了这些层:
for layer in test_block.layers:
print(layer)
此输出仅显示一个层:
<keras.layers.pooling.max_pooling2d.MaxPooling2D object at 0x7f6c18377f50>
然而,转换层在字典形式中存在得很好:
print(vars(test_block))
{'_self_setattr_tracking': True, '_is_model_for_instrumentation': True, '_instrumented_keras_api': True, '_instrumented_keras_layer_class': False, '_instrumented_keras_model_class': True, '_trainable': True, '_stateful': False, 'built': True, '_input_spec': None, '_build_input_shape': None, '_saved_model_inputs_spec': TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='input_10'), '_saved_model_arg_spec': ([TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='input_10')], {}), '_supports_masking': False, '_name': 'vgg_block_46', '_activity_regularizer': None, '_trainable_weights': [], '_non_trainable_weights': [], '_updates': [], '_thread_local': <_thread._local object at 0x7fb9084d9ef0>, '_callable_losses': [], '_losses': [], '_metrics': [], '_metrics_lock': <unlocked _thread.lock object at 0x7fb90d88abd0>, '_dtype_policy': <Policy "float32">, '_compute_dtype_object': tf.float32, '_autocast': True, '_self_tracked_trackables': [<keras.layers.pooling.max_pooling2d.MaxPooling2D object at 0x7fb9084e2510>], '_inbound_nodes_value': [<keras.engine.node.Node object at 0x7fb9087146d0>], '_outbound_nodes_value': [], '_expects_training_arg': False, '_default_training_arg': None, '_expects_mask_arg': False, '_dynamic': False, '_initial_weights': None, '_auto_track_sub_layers': True, '_preserve_input_structure_in_config': False, '_name_scope_on_declaration': '', '_captured_weight_regularizer': [], '_is_graph_network': False, 'inputs': None, 'outputs': None, 'input_names': None, 'output_names': None, 'stop_training': False, 'history': None, 'compiled_loss': None, 'compiled_metrics': None, '_compute_output_and_mask_jointly': False, '_is_compiled': False, 'optimizer': None, '_distribution_strategy': None, '_cluster_coordinator': None, '_run_eagerly': None, 'train_function': None, 'test_function': None, 'predict_function': None, 'train_tf_function': None, '_compiled_trainable_state': <WeakKeyDictionary at 0x7fb9084b0790>, '_training_state': None, '_self_unconditional_checkpoint_dependencies': [TrackableReference(name=max_pool, ref=<keras.layers.pooling.max_pooling2d.MaxPooling2D object at 0x7fb9084e2510>)], '_self_unconditional_dependency_names': {'max_pool': <keras.layers.pooling.max_pooling2d.MaxPooling2D object at 0x7fb9084e2510>}, '_self_unconditional_deferred_dependencies': {}, '_self_update_uid': -1, '_self_name_based_restores': set(), '_self_saveable_object_factories': {}, '_checkpoint': <tensorflow.python.training.tracking.util.Checkpoint object at 0x7fb9084b0910>, '_steps_per_execution': None, '_train_counter': <tf.Variable 'Variable:0' shape=() dtype=int64, numpy=0>, '_test_counter': <tf.Variable 'Variable:0' shape=() dtype=int64, numpy=0>, '_predict_counter': <tf.Variable 'Variable:0' shape=() dtype=int64, numpy=0>, '_base_model_initialized': True, '_jit_compile': None, '_layout_map': None, '_obj_reference_counts_dict': ObjectIdentityDictionary({<_ObjectIdentityWrapper wrapping 3>: 1, <_ObjectIdentityWrapper wrapping <keras.layers.pooling.max_pooling2d.MaxPooling2D object at 0x7fb9084e2510>>: 1}), 'repetitions': 3,
'conv2D_0': <keras.layers.convolutional.conv2d.Conv2D object at 0x7fb90852e390>, 'conv2D_1': <keras.layers.convolutional.conv2d.Conv2D object at 0x7fb90852ed90>, 'conv2D_2': <keras.layers.convolutional.conv2d.Conv2D object at 0x7fb9084dac90>, 'max_pool': <keras.layers.pooling.max_pooling2d.MaxPooling2D object at 0x7fb9084e2510>}
vars()
会不会让这一层显得有些怪异?
如何才能正确显示层或参数?