输入张量和输出张量必须匹配才能成功进行对象检测吗?

问题描述 投票:0回答:1

我在自定义数据上训练了 YOLOv8 模型,当我使用 YOLOv8 检测功能时它运行良好,但我想在 Android(甚至目前是 iOS 设备)上运行它。 我将其导出到 torchscript,这就是我碰壁的地方。 我使用针对移动 torchscript api 进行优化的模型编写了一个 Android 应用程序。 .pt 和 .ptl 文件之间的大小差异可以忽略不计,因此我不确定它是否真的优化了。

相机输入的张量大小为 Tensor([1, 3, 1024, 1024], dtype=torch.float32), 输出张量为 Tensor([1, 31, 21504], dtype=torch.float32)。

这可能是由于形状不匹配错误导致检测无法正常工作的原因吗? 不过,调试控制台确实没有显示任何不匹配错误。 另外我遇到了这个错误 java.lang.IndexOutOfBoundsException: Index: 37822, Size: 27。我的模型中有 27 个类。 我在 youtube 上做了一个教程,所以我不能把所有代码都归功于我,但这里是部分代码片段(我发布了减去导入的代码)。

public class MainActivity extends AppCompatActivity {

    private ListenableFuture<ProcessCameraProvider> cameraProviderFuture;

    PreviewView previewView;

    TextView textView;

    private int REQUEST_CODE_PERMISSION = 101;

    List<String> classes;

    private final String[] REQUIRED_PERMISSIONS = new String[] {"android.permission.CAMERA"};

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        previewView = findViewById(R.id.cameraView);
        textView = findViewById(R.id.result_text);

        if (!checkPermissions()) {
            ActivityCompat.requestPermissions(this, REQUIRED_PERMISSIONS, REQUEST_CODE_PERMISSION);
        }

        classes = LoadClasses("labels.txt");
        LoadTorchModule("model.ptl");
        cameraProviderFuture = ProcessCameraProvider.getInstance(this);
        cameraProviderFuture.addListener(() -> {
            try {
                ProcessCameraProvider cameraProvider = cameraProviderFuture.get();
                startCamera(cameraProvider);
            } catch (ExecutionException | InterruptedException e) {
                // Errors
            }
        }, ContextCompat.getMainExecutor(this));
    }

    private boolean checkPermissions() {
        for (String permission : REQUIRED_PERMISSIONS) {
            if (ContextCompat.checkSelfPermission(this, permission) != PackageManager.PERMISSION_GRANTED) {
                return false;
            }
        }
        return true;
    }

    Executor executor = Executors.newSingleThreadExecutor();

    void startCamera(@NotNull ProcessCameraProvider cameraProvider) {
        Preview preview = new Preview.Builder().build();
        CameraSelector cameraSelector = new CameraSelector.Builder().requireLensFacing(CameraSelector.LENS_FACING_BACK).build();
        preview.setSurfaceProvider(previewView.getSurfaceProvider());

        ImageAnalysis imageAnalysis = new ImageAnalysis.Builder().setTargetResolution(new Size(1024, 1024)).setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST).build();

        imageAnalysis.setAnalyzer(executor, new ImageAnalysis.Analyzer() {
            @Override
            public void analyze(@NonNull ImageProxy image) {
                int rotation = image.getImageInfo().getRotationDegrees();
                analyzeImage(image, rotation);
                image.close();
            }
        });

        Camera camera = cameraProvider.bindToLifecycle((LifecycleOwner) this, cameraSelector, preview, imageAnalysis);
    }

    Module module;

    void LoadTorchModule(String fileName) {
        File modelFile = new File(this.getFilesDir(), fileName);
        try {
            if (!modelFile.exists()) {
                InputStream inputStream = getAssets().open(fileName);
                FileOutputStream outputStream = new FileOutputStream(modelFile);
                byte[] buffer = new byte[2048];
                int bytesRead = -1;
                while ((bytesRead = inputStream.read(buffer)) != -1) {
                    outputStream.write(buffer, 0, bytesRead);
                }
                inputStream.close();
                outputStream.close();
            }
            module = LiteModuleLoader.load(modelFile.getAbsolutePath());
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    @OptIn(markerClass = ExperimentalGetImage.class) void analyzeImage(ImageProxy image, int rotation) {
        try {
            Tensor inputTensor = TensorImageUtils.imageYUV420CenterCropToFloat32Tensor(image.getImage(), rotation, 1024, 1024, TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
            Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();

            Log.v("ZackTensor", inputTensor.toString());
            Log.v("ZackTensor", outputTensor.toString());
            float[] scores = outputTensor.getDataAsFloatArray();
            float maxScore = -Float.MAX_VALUE;
            Log.v("ZackTensor", String.valueOf(maxScore));
            int maxScoreIdx = -1;
            for (int i = 0; i < scores.length; i++) {
                if (scores[i] > maxScore) {
                    maxScore = scores[i];
                    maxScoreIdx = i;
                }
            }
            String classResult = classes.get(maxScoreIdx);

            Log.v("ZackTensor", "Detected - " + classResult);
            runOnUiThread(new Runnable() {
                @Override
                public void run() {
                    textView.setText(classResult);
                }
            });
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    List<String> LoadClasses(String fileName) {
        List<String> classes = new ArrayList<>();
        try {
            BufferedReader br = new BufferedReader(new InputStreamReader(getAssets().open(fileName)));
            String line;
            while ((line = br.readLine()) != null) {
                classes.add(line);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return classes;
    }
}

我尝试重新导出模型,但没有成功。 我尝试重塑传入的图像,但要么我没有正确执行,要么它不起作用。 我在 Google 和 Stackoverflow 上进行了大量搜索,但没有成功。

谢谢!

pytorch object-detection torchvision
1个回答
0
投票

对于我的应用程序,yolo 模型的输出与相机输入张量不同,因为输出包含置信度分数和边界框。它只是图像的“描述”,如果没有图像,你无法在图像上可视化它们,因为 yolo 分割模型的情况并非如此,它的输出可以被解释为查看每个像素(但在 Android 应用程序中缩小到 160x160px) .

对输出的进一步分析与我刚刚回答的另一个类似,因此可以在这里看到:如何解释 YOLOv8 Web 模型的输出张量

© www.soinside.com 2019 - 2024. All rights reserved.