I wrote the following code for a VGG block, and I want to display a summary of the block:

import tensorflow as tf
from keras.layers import Conv2D, MaxPool2D, Input

class VggBlock(tf.keras.Model):
  def __init__(self, filters, repetitions):
    super(VggBlock, self).__init__()    
    self.repetitions = repetitions
    
    for i in range(repetitions):
      vars(self)[f'conv2D_{i}'] = Conv2D(filters=filters, kernel_size=(3, 3), padding='same', activation='relu')
    self.max_pool = MaxPool2D(pool_size=(2, 2))
  
  def call(self, inputs):
    x = vars(self)['conv2D_0'](inputs)
    for i in range(1, self.repetitions):
      x = vars(self)[f'conv2D_{i}'](x)
    return self.max_pool(x)


test_block = VggBlock(64, 2)
temp_inputs = Input(shape=(224, 224, 3))
test_block(temp_inputs)
test_block.summary()

This code then outputs the following:

Model: "vgg_block"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 max_pooling2d (MaxPooling2D  multiple                 0         
 )                                                               
                                                                 
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________

I tried inspecting the layers explicitly:

for layer in test_block.layers:
  print(layer)

This prints only one layer:

<keras.layers.pooling.max_pooling2d.MaxPooling2D object at 0x7f6c18377f50>

However, the convolutional layers are clearly present in the instance's attribute dictionary:

print(vars(test_block))
{'_self_setattr_tracking': True, '_is_model_for_instrumentation': True, '_instrumented_keras_api': True, '_instrumented_keras_layer_class': False, '_instrumented_keras_model_class': True, '_trainable': True, '_stateful': False, 'built': True, '_input_spec': None, '_build_input_shape': None, '_saved_model_inputs_spec': TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='input_10'), '_saved_model_arg_spec': ([TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='input_10')], {}), '_supports_masking': False, '_name': 'vgg_block_46', '_activity_regularizer': None, '_trainable_weights': [], '_non_trainable_weights': [], '_updates': [], '_thread_local': <_thread._local object at 0x7fb9084d9ef0>, '_callable_losses': [], '_losses': [], '_metrics': [], '_metrics_lock': <unlocked _thread.lock object at 0x7fb90d88abd0>, '_dtype_policy': <Policy "float32">, '_compute_dtype_object': tf.float32, '_autocast': True, '_self_tracked_trackables': [<keras.layers.pooling.max_pooling2d.MaxPooling2D object at 0x7fb9084e2510>], '_inbound_nodes_value': [<keras.engine.node.Node object at 0x7fb9087146d0>], '_outbound_nodes_value': [], '_expects_training_arg': False, '_default_training_arg': None, '_expects_mask_arg': False, '_dynamic': False, '_initial_weights': None, '_auto_track_sub_layers': True, '_preserve_input_structure_in_config': False, '_name_scope_on_declaration': '', '_captured_weight_regularizer': [], '_is_graph_network': False, 'inputs': None, 'outputs': None, 'input_names': None, 'output_names': None, 'stop_training': False, 'history': None, 'compiled_loss': None, 'compiled_metrics': None, '_compute_output_and_mask_jointly': False, '_is_compiled': False, 'optimizer': None, '_distribution_strategy': None, '_cluster_coordinator': None, '_run_eagerly': None, 'train_function': None, 'test_function': None, 'predict_function': None, 'train_tf_function': None, '_compiled_trainable_state': <WeakKeyDictionary at 0x7fb9084b0790>, '_training_state': None, 
'_self_unconditional_checkpoint_dependencies': [TrackableReference(name=max_pool, ref=<keras.layers.pooling.max_pooling2d.MaxPooling2D object at 0x7fb9084e2510>)], '_self_unconditional_dependency_names': {'max_pool': <keras.layers.pooling.max_pooling2d.MaxPooling2D object at 0x7fb9084e2510>}, '_self_unconditional_deferred_dependencies': {}, '_self_update_uid': -1, '_self_name_based_restores': set(), '_self_saveable_object_factories': {}, '_checkpoint': <tensorflow.python.training.tracking.util.Checkpoint object at 0x7fb9084b0910>, '_steps_per_execution': None, '_train_counter': <tf.Variable 'Variable:0' shape=() dtype=int64, numpy=0>, '_test_counter': <tf.Variable 'Variable:0' shape=() dtype=int64, numpy=0>, '_predict_counter': <tf.Variable 'Variable:0' shape=() dtype=int64, numpy=0>, '_base_model_initialized': True, '_jit_compile': None, '_layout_map': None, '_obj_reference_counts_dict': ObjectIdentityDictionary({<_ObjectIdentityWrapper wrapping 3>: 1, <_ObjectIdentityWrapper wrapping <keras.layers.pooling.max_pooling2d.MaxPooling2D object at 0x7fb9084e2510>>: 1}), 'repetitions': 3, 
'conv2D_0': <keras.layers.convolutional.conv2d.Conv2D object at 0x7fb90852e390>, 'conv2D_1': <keras.layers.convolutional.conv2d.Conv2D object at 0x7fb90852ed90>, 'conv2D_2': <keras.layers.convolutional.conv2d.Conv2D object at 0x7fb9084dac90>, 'max_pool': <keras.layers.pooling.max_pooling2d.MaxPooling2D object at 0x7fb9084e2510>}

Is vars() doing something odd to these layers? How can I display the layers and their parameters correctly?

Recommended answer

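Setting attributes through vars(self) writes straight into the instance's __dict__ and bypasses the model's __setattr__, which is where Keras does its usual bookkeeping (tracking sublayers, their weights, and so on). You can see this in your dump: _self_tracked_trackables contains only the MaxPooling2D layer. So avoid this pattern and do something like this instead: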

class VggBlock(tf.keras.Model):
  def __init__(self, filters, repetitions):
    super(VggBlock, self).__init__()    
    self.repetitions = repetitions
    
    self.conv_layers = [Conv2D(filters=filters, kernel_size=(3, 3), padding='same', activation='relu') for _ in range(repetitions)]
    self.max_pool = MaxPool2D(pool_size=(2, 2))
  
  def call(self, inputs):
    x = inputs
    for layer in self.conv_layers:
      x = layer(x)
    return self.max_pool(x)


test_block = VggBlock(64, 2)
temp_inputs = Input(shape=(224, 224, 3))
test_block(temp_inputs)
test_block.summary()

Here we store the layers in a list, so they can still be created and applied in a loop. This prints:

Model: "vgg_block_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv2d_4 (Conv2D)           multiple                  1792      
                                                                 
 conv2d_5 (Conv2D)           multiple                  36928     
                                                                 
 max_pooling2d_2 (MaxPooling  multiple                 0         
 2D)                                                             
                                                                 
=================================================================
Total params: 38,720
Trainable params: 38,720
Non-trainable params: 0
_________________________________________________________________
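The following toy sketch, in plain Python without TensorFlow, illustrates why the list version works and the vars(self) version does not. ToyLayer and ToyModel are hypothetical stand-ins, and the tracking logic is a deliberate simplification of what Keras does (for example, Keras also wraps list attributes so that later appends are tracked): a model that registers sublayers in __setattr__ never sees anything written through vars(self), but does see layers assigned as attributes, via setattr, or inside a list attribute.

```python
class ToyLayer:
    """Hypothetical stand-in for a Keras layer."""

    def __init__(self, name):
        self.name = name


class ToyModel:
    """Hypothetical stand-in for a Keras model that tracks sublayers in __setattr__."""

    def __init__(self):
        # Create the tracking list without going through our own __setattr__.
        object.__setattr__(self, "_tracked", [])

    def __setattr__(self, key, value):
        # Simplified bookkeeping: track layers assigned directly or inside a list.
        if isinstance(value, ToyLayer):
            self._tracked.append(value)
        elif isinstance(value, list):
            self._tracked.extend(v for v in value if isinstance(v, ToyLayer))
        object.__setattr__(self, key, value)

    @property
    def layers(self):
        return [layer.name for layer in self._tracked]


# Like the question: writing into vars(self) skips __setattr__ entirely.
bypassed = ToyModel()
for i in range(2):
    vars(bypassed)[f"conv2D_{i}"] = ToyLayer(f"conv2D_{i}")
bypassed.max_pool = ToyLayer("max_pool")
print(bypassed.layers)  # ['max_pool']

# Like the answer: assigning a list attribute goes through __setattr__.
listed = ToyModel()
listed.conv_layers = [ToyLayer(f"conv2D_{i}") for i in range(2)]
listed.max_pool = ToyLayer("max_pool")
print(listed.layers)  # ['conv2D_0', 'conv2D_1', 'max_pool']

# setattr() with a computed name also goes through __setattr__.
named = ToyModel()
for i in range(2):
    setattr(named, f"conv2D_{i}", ToyLayer(f"conv2D_{i}"))
named.max_pool = ToyLayer("max_pool")
print(named.layers)  # ['conv2D_0', 'conv2D_1', 'max_pool']
```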

If you don't need each layer listed separately in the summary, you can use a Sequential to simplify the call method:

class VggBlock(tf.keras.Model):
  def __init__(self, filters, repetitions):
    super(VggBlock, self).__init__()    
    self.repetitions = repetitions
    
    self.conv_layers = tf.keras.Sequential([Conv2D(filters=filters, kernel_size=(3, 3), padding='same', activation='relu') for _ in range(repetitions)])
    self.max_pool = MaxPool2D(pool_size=(2, 2))
  
  def call(self, inputs):
    x = self.conv_layers(inputs)
    return self.max_pool(x)


test_block = VggBlock(64, 2)
temp_inputs = Input(shape=(224, 224, 3))
test_block(temp_inputs)
test_block.summary()

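This is functionally the same, but the summary collapses the convolutional layers into a single Sequential entry, which may not be what you want: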

Model: "vgg_block_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 sequential (Sequential)     (None, 224, 224, 64)      38720     
                                                                 
 max_pooling2d_3 (MaxPooling  multiple                 0         
 2D)                                                             
                                                                 
=================================================================
Total params: 38,720
Trainable params: 38,720
Non-trainable params: 0
_________________________________________________________________
