flutter 中的 TensorFlowInferenceInterface

问题描述 投票:0回答:1

我有一个 Android 中的对象检测代码,我想将其转换为 Flutter 以在 IOS 上也使用它

public static Classifier create(
            AssetManager assetManager,
            String modelFilename,
            String[] labels,
            int inputSize,
            int imageMean,
            float imageStd,
            String inputName,
            String outputName) {
        final TensorFlowImageClassifier c = new TensorFlowImageClassifier();
        c.inputName = inputName;
        c.outputName = outputName;

        // Read the label names into memory.
        Collections.addAll(c.labels, labels);

        c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);

        // The shape of the output is [N, NUM_CLASSES], where N is the batch size.
        final Operation operation = c.inferenceInterface.graphOperation(outputName);
        final int numClasses = (int) operation.output(0).shape().size(1);

        // Ideally, inputSize could have been retrieved from the shape of the input operation.  Alas,
        // the placeholder node for input in the graphdef typically used does not specify a shape, so it
        // must be passed in as a parameter.
        c.inputSize = inputSize;
        c.imageMean = imageMean;
        c.imageStd = imageStd;

        // Pre-allocate buffers.
        c.outputNames = new String[]{outputName};
        c.intValues = new int[inputSize * inputSize];
        c.floatValues = new float[inputSize * inputSize * 3];
        c.outputs = new float[numClasses];

        return c;
    }

打电话:

create(
                                getAssets(),
                                "file:///android_asset/xxx",
                                getResources().getStringArray(R.array.yyy),
                                INPUT_SIZE,
                                128,
                                128,
                                "input",
                                "InceptionV3/Predictions/Reshape_1")

org.tensorflow.contrib.android.TensorFlowInferenceInterface

flutter中没有找到这个类,我应该用什么来替换它?

android flutter tensorflow machine-learning tflite
1个回答
0
投票

您可以使用 flutter_tflite 包来推断 flutter 中的 tflite 模型。由于您使用它来进行对象检测,因此您需要弄清楚的不仅仅是这些。

  1. 加载模型和标签。这可以通过 flutter_tflite 来完成
  2. 获取图像并为模型准备图像(根据模型输入大小裁剪图像并根据模型要求创建字节数组)
  3. 使用图像运行模型。
  4. 清理输出(如果需要 nms)并获取边界框。

flutter_tflite 包中有一些示例。但这完全取决于您的模型类型。

© www.soinside.com 2019 - 2024. All rights reserved.