Prediction step fails after importing a Keras model into TensorFlow Java

Problem description   Votes: 0   Answers: 1

My goal is to use a Keras model in a Java program. I used

model.export()
instead of model.save() to export the Keras model, so I got a folder containing the model in .pb format.

Then I used

py .\saved_model_cli.py show --dir '.' --all
to see the inputs and outputs so I could fill in the Java code. This is what I got:

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is:

signature_def['serve']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['keras_tensor'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 6)
        name: serve_keras_tensor:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['output_0'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: StatefulPartitionedCall:0
  Method name is: tensorflow/serving/predict

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['keras_tensor'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 6)
        name: serving_default_keras_tensor:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['output_0'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: StatefulPartitionedCall_1:0
  Method name is: tensorflow/serving/predict
The MetaGraph with tag set ['serve'] contains the following ops: {'ReadVariableOp', 'Select', 'StatefulPartitionedCall', 'RestoreV2', 'NoOp', 'Identity', 'StaticRegexFullMatch', 'StringJoin', 'AssignVariableOp', 'SaveV2', 'MergeV2Checkpoints', 'VarIsInitializedOp', 'AddV2', 'VarHandleOp', 'DisableCopyOnRead', 'Pack', 'Placeholder', 'MatMul', 'Const', 'Relu', 'ShardedFilename'}

2024-11-12 16:47:24.597134: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.

Concrete Functions:

  Function Name: 'serve'
    Option #1
      Callable with:
        Argument #1
          keras_tensor: TensorSpec(shape=(None, 6), dtype=tf.float32, name='keras_tensor')

Finally, this is the Java code that imports the model and makes the prediction:

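import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.types.TFloat32;

public static void importKerasModel() {
    // load the folder written by model.export(), using the 'serve' tag set
    try (SavedModelBundle model = SavedModelBundle.load("PATH\\kerasModel", "serve")) {
        float[] x = {0.48f, 0.48f, 0.48f, 0.48f, 0.48f, 0.48f};
        try (Tensor input = TFloat32.vectorOf(x);         // rank-1 tensor, shape [6]
             Tensor output = model.session()
                     .runner()
                     .feed("serve_keras_tensor", input)    // input name from saved_model_cli ('serve' signature)
                     .fetch("StatefulPartitionedCall")      // output name from saved_model_cli ('serve' signature)
                     .run()
                     .get(0)) {

            float prediction = output.dataType().getNumber();
            System.out.println("prediction = " + prediction);
        }
    }
}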

But I get this error message:

2024-11-12 17:26:01.089591: I tensorflow/cc/saved_model/loader.cc:317] SavedModel load for tags { serve }; Status: success: OK. Took 61548 microseconds.
2024-11-12 17:26:01.317247: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: INVALID_ARGUMENT: In[0] is not a matrix
     [[{{node StatefulPartitionedCall/StatefulPartitionedCall/sequential_1/dense_1/Relu}}]]
Exception in thread "main" org.tensorflow.exceptions.TFInvalidArgumentException: In[0] is not a matrix
     [[{{node StatefulPartitionedCall/StatefulPartitionedCall/sequential_1/dense_1/Relu}}]]
    at org.tensorflow.internal.c_api.AbstractTF_Status.throwExceptionIfNotOK(AbstractTF_Status.java:76)
    at org.tensorflow.Session.run(Session.java:826)
    at org.tensorflow.Session$Runner.runHelper(Session.java:549)
    at org.tensorflow.Session$Runner.run(Session.java:476)
    at com.ptvgroup.platform.truckslogs.converter.HelloTensorFlow.importKerasModel(HelloTensorFlow.java:471)
    at com.ptvgroup.platform.truckslogs.converter.Main.main(Main.java:25)

Can anyone help me? What does "In[0] is not a matrix" mean? Is it because the shapes of my input and output are (-1, 6) and (-1, 1)?

python java tensorflow keras
1 Answer
0
Votes

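The "In[0] is not a matrix" error during prediction comes from the wrong input shape, not from the input values. I created a vector (rank-1) tensor of shape (6) instead of a matrix (rank-2) tensor of shape (1, 6). The -1 in the signature's (-1, 6) is the batch dimension, so the model expects a 2-D input even for a single observation. The exception message says exactly that: the dense layer's matrix multiplication received a first input (In[0]) that is not a matrix, and that value is the input tensor.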
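The following small sketch makes the shape difference visible. It assumes TensorFlow Java 0.4 or later, where the org.tensorflow.ndarray package is available. InputShapes and singleRowBatch are illustrative names. The sketch builds the same six features once as a vector and once as a 1 x 6 matrix, then prints both ranks and shapes. It uses StdArrays.ndCopyOf(float[][]) as a shorter alternative to allocating an NdArray and calling set(...):

```java
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.types.TFloat32;

public class InputShapes {

    // Wraps one observation into a batch of size 1: a float[1][n] array copied into a rank-2 tensor.
    static TFloat32 singleRowBatch(float[] features) {
        return TFloat32.tensorOf(StdArrays.ndCopyOf(new float[][] {features}));
    }

    public static void main(String[] args) {
        float[] x = {0.48f, 0.48f, 0.48f, 0.48f, 0.48f, 0.48f};

        try (TFloat32 vector = TFloat32.vectorOf(x);  // rank 1: what the dense layer's MatMul rejects
             TFloat32 matrix = singleRowBatch(x)) {     // rank 2: matches the signature's (-1, 6)
            System.out.println("vector: rank " + vector.shape().numDimensions() + ", shape " + vector.shape());
            System.out.println("matrix: rank " + matrix.shape().numDimensions() + ", shape " + matrix.shape());
        }
    }
}
```

Adding more rows to the float[][] changes only the first dimension, which is what the -1 in (-1, 6) allows.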

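import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.FloatNdArray;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.types.TFloat32;

public static void importKerasModel() {
    try (SavedModelBundle model = SavedModelBundle.load("PATH", "serve")) {
        float[] x = {200f, 0f, 1.5f, 2f, 2.5f, 0f};
        FloatNdArray matrix = NdArrays.ofFloats(Shape.of(1, 6)); // my model has 6 features per observation
        matrix.set(NdArrays.vectorOf(x), 0);                     // copy the observation into row 0
        try (Tensor input = TFloat32.tensorOf(matrix);           // rank-2 tensor, shape [1, 6]
             Tensor output = model.session()
                     .runner()
                     .feed("serve_keras_tensor_272", input)  // to list inputs and outputs: py .\saved_model_cli.py show --dir '.' --all
                     .fetch("StatefulPartitionedCall")
                     .run()
                     .get(0)) {
            // dataType().getNumber() is the numeric code of the output's DataType (DT_FLOAT = 1), not a value read from the tensor
            float prediction = output.dataType().getNumber();
            System.out.println("prediction = " + prediction);
        }
    }
}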

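This code fixes the failing execution. However, the prediction I get is always 1, whatever the input values. So something is still wrong, but I don't know whether that is a separate issue.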
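About the constant 1: in TensorFlow Java versions with this Tensor API (0.4 and later, as far as I know), output.dataType() returns the DataType protobuf enum. Its getNumber() is the enum's numeric code, which is 1 for DT_FLOAT, so the printed value describes the tensor's type rather than the model's output.

Below is a minimal sketch of reading the actual predictions instead. It assumes the fetched tensor is a TFloat32 of shape (batch, 1), as in the 'serve' signature. ReadPredictions, toArray and predict are illustrative names, and "PATH" stands for your export folder. The feed and fetch names must match what saved_model_cli reports for your own export:

```java
import java.util.Arrays;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.types.TFloat32;

public class ReadPredictions {

    // Copies a (batch, 1) output tensor into a plain float[], one prediction per input row.
    static float[] toArray(TFloat32 output) {
        int batch = (int) output.shape().size(0);
        float[] predictions = new float[batch];
        for (int i = 0; i < batch; i++) {
            predictions[i] = output.getFloat(i, 0); // read the value, not output.dataType()
        }
        return predictions;
    }

    // Runs a batch of observations (rows of 6 features) through the session.
    static float[] predict(SavedModelBundle model, float[][] rows, String inputName, String outputName) {
        try (TFloat32 input = TFloat32.tensorOf(StdArrays.ndCopyOf(rows)); // shape [rows.length, 6]
             Tensor output = model.session().runner()
                     .feed(inputName, input)
                     .fetch(outputName)
                     .run()
                     .get(0)) {
            // The signature declares DT_FLOAT, so the fetched tensor is expected to be a TFloat32.
            return toArray((TFloat32) output);
        }
    }

    public static void main(String[] args) {
        // Hypothetical export path: replace with the folder written by model.export().
        try (SavedModelBundle model = SavedModelBundle.load("PATH", "serve")) {
            float[][] rows = {
                {200f, 0f, 1.5f, 2f, 2.5f, 0f},
                {0.48f, 0.48f, 0.48f, 0.48f, 0.48f, 0.48f}
            };
            float[] predictions = predict(model, rows, "serve_keras_tensor", "StatefulPartitionedCall");
            System.out.println("predictions = " + Arrays.toString(predictions));
        }
    }
}
```

Because the input's first dimension is -1, the same call accepts several rows at once. Each row produces one entry in the returned array.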

© www.soinside.com 2019 - 2024. All rights reserved.