How do I access TensorFlow layer attributes?


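I am new to TensorFlow. I have a custom TensorFlow model that uses a variable in several of its layers. I want to access the variable y shown in the image.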

[Image: Model Layer]

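I tried printing the attribute with __getattribute__, but it raises an error because y is not an attribute. I also printed dir() for the layer object, but could not find anything in it that returns the variable.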

import tensorflow as tf

savedModel = tf.keras.models.load_model('./tf_model.h5')

# Find the layer by name and list everything it exposes
for layer in savedModel.layers:
    if layer.name == 'tf.math.multiply_25':
        print(dir(layer))
Output: 
['_TF_MODULE_IGNORED_PROPERTIES', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_activity_regularizer', '_add_trackable', '_add_trackable_child', '_add_variable_with_custom_getter', '_already_warned', '_auto_get_config', '_auto_track_sub_layers', '_autocast', '_autographed_call', '_build_input_shape', '_call_spec', '_call_wrapper', '_callable_losses', '_captured_weight_regularizer', '_cast_single_input', '_check_variables', '_checkpoint_dependencies', '_clear_losses', '_compute_dtype', '_compute_dtype_object', '_dedup_weights', '_deferred_dependencies', '_delete_tracking', '_deserialization_dependencies', '_deserialize_from_proto', '_dtype', '_dtype_policy', '_dynamic', '_eager_losses', '_expects_mask_arg', '_expects_training_arg', '_export_to_saved_model_graph', '_flatten', '_flatten_layers', '_flatten_modules', '_functional_construction_call', '_gather_children_attribute', '_gather_saveables_for_checkpoint', '_get_cell_name', '_get_existing_metric', '_get_input_masks', '_get_node_attribute_at_index', '_get_save_spec', '_get_trainable_state', '_get_unnested_name_scope', '_handle_activity_regularization', '_handle_deferred_dependencies', '_handle_weight_regularization', '_inbound_nodes', '_inbound_nodes_value', '_infer_output_signature', '_init_call_fn_args', '_init_set_name', '_initial_weights', '_input_spec', '_instrument_layer_creation', '_instrumented_keras_api', '_instrumented_keras_layer_class', '_instrumented_keras_model_class', '_is_layer', '_keras_api_names', '_keras_api_names_v1', '_keras_tensor_symbolic_call', '_load_own_variables', '_lookup_dependency', '_losses', '_map_resources', '_maybe_build', 
'_maybe_cast_inputs', '_maybe_create_attribute', '_maybe_initialize_trackable', '_metrics', '_metrics_lock', '_must_restore_from_config', '_name', '_name_based_attribute_restore', '_name_based_restores', '_name_scope', '_name_scope_on_declaration', '_no_dependency', '_non_trainable_weights', '_obj_reference_counts', '_obj_reference_counts_dict', '_object_identifier', '_outbound_nodes', '_outbound_nodes_value', '_preload_simple_restoration', '_preserve_input_structure_in_config', '_restore_from_tensors', '_save_own_variables', '_saved_model_arg_spec', '_saved_model_inputs_spec', '_self_setattr_tracking', '_self_tracked_trackables', '_serialize_to_proto', '_serialize_to_tensors', '_set_connectivity_metadata', '_set_dtype_policy', '_set_mask_keras_history_checked', '_set_mask_metadata', '_set_save_spec', '_set_trainable_state', '_set_training_mode', '_setattr_tracking', '_should_cast_single_input', '_stateful', '_supports_masking', '_tf_api_names', '_tf_api_names_v1', '_thread_local', '_track_trackable', '_track_variable', '_track_variables', '_trackable_children', '_trackable_saved_model_saver', '_tracking_metadata', '_trainable', '_trainable_weights', '_unconditional_checkpoint_dependencies', '_unconditional_dependency_names', '_update_trackables', '_update_uid', '_updates', '_use_input_spec_as_call_signature', '_warn', 'activity_regularizer', 'add_loss', 'add_metric', 'add_update', 'add_variable', 'add_weight', 'build', 'built', 'call', 'compute_dtype', 'compute_mask', 'compute_output_shape', 'compute_output_signature', 'count_params', 'dtype', 'dtype_policy', 'dynamic', 'finalize_state', 'from_config', 'function', 'get_config', 'get_input_at', 'get_input_mask_at', 'get_input_shape_at', 'get_output_at', 'get_output_mask_at', 'get_output_shape_at', 'get_weights', 'inbound_nodes', 'input', 'input_mask', 'input_shape', 'input_spec', 'losses', 'metrics', 'name', 'name_scope', 'non_trainable_variables', 'non_trainable_weights', 'outbound_nodes', 'output', 'output_mask', 
'output_shape', 'set_weights', 'stateful', 'submodules', 'supports_masking', 'symbol', 'trainable', 'trainable_variables', 'trainable_weights', 'updates', 'variable_dtype', 'variables', 'weights', 'with_name_scope']

Can someone guide me on how to access it? I am using Python 3.9 and TensorFlow 2.10.

Thanks in advance.

python tensorflow deep-learning tensorflow2.0
1 Answer

You can do this the Pythonic way, as shown below.

Continuing from the code snippet you attached:
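import tensorflow as tf

savedModel = tf.keras.models.load_model('./tf_model.h5')

# Layers are ordinary Python objects; index into the list to grab one
first_layer = savedModel.layers[0]

# Print each layer's name alongside its output shape
for layer in savedModel.layers:
    print(layer.name, layer.output_shape)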
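The loop above lists output shapes, but a constant such as y that was passed to tf.math.multiply is not stored as a layer attribute. In TF 2.x functional models it is usually recorded with the layer's inbound node and appears in the dictionary returned by model.get_config(). The sketch below is a minimal illustration under that assumption. The helper names collect_call_kwargs and get_call_constant are hypothetical, and sample_config is a hand-written stand-in shaped like that dictionary, not output from the asker's model.

```python
def collect_call_kwargs(node_data):
    """Recursively gather the dicts embedded in serialized inbound nodes."""
    found = []
    if isinstance(node_data, dict):
        found.append(node_data)
    elif isinstance(node_data, (list, tuple)):
        for item in node_data:
            found.extend(collect_call_kwargs(item))
    return found


def get_call_constant(model_config, layer_name, arg_name="y"):
    """Return the call argument `arg_name` recorded for `layer_name`."""
    for layer_cfg in model_config["layers"]:
        if layer_cfg["name"] != layer_name:
            continue
        for kwargs in collect_call_kwargs(layer_cfg.get("inbound_nodes", [])):
            if arg_name in kwargs:
                return kwargs[arg_name]
    raise KeyError(f"{arg_name!r} not recorded for layer {layer_name!r}")


# Hypothetical config fragment (assumed layout): each inbound node is
# [source_layer_name, node_index, tensor_index, call_kwargs].
sample_config = {
    "layers": [
        {"class_name": "InputLayer", "name": "input_1", "inbound_nodes": []},
        {
            "class_name": "TFOpLambda",
            "name": "tf.math.multiply_25",
            "inbound_nodes": [["input_1", 0, 0, {"y": 0.5, "name": None}]],
        },
    ]
}

y_value = get_call_constant(sample_config, "tf.math.multiply_25")
print(y_value)
```

With the real model, the same helper would be called as get_call_constant(savedModel.get_config(), 'tf.math.multiply_25'). If y came from another layer rather than being passed as a constant, the recorded value is likely to be a reference to that layer instead of a number, and other Keras versions may lay out the config differently.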