Processing YOLO pose-estimation output with the Android TensorFlow Lite API

Question (votes: 0, answers: 1)

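I'm trying to integrate the yolo11n-pose_float16.tflite model into an Android Kotlin project. I can't get correct keypoints from the output, so I must be doing something wrong. All I want is to draw a skeleton on screen, with the keypoints and the connections between body parts. How should I extract the output array into my Person object?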

data class Person(
    var keyPoints: MutableList<KeyPoint>,
    val score: Float
)

enum class BodyPart {
    NOSE, LEFT_EYE, RIGHT_EYE, LEFT_EAR, RIGHT_EAR,
    LEFT_SHOULDER, RIGHT_SHOULDER, LEFT_ELBOW, RIGHT_ELBOW,
    LEFT_WRIST, RIGHT_WRIST, LEFT_HIP, RIGHT_HIP,
    LEFT_KNEE, RIGHT_KNEE, LEFT_ANKLE, RIGHT_ANKLE
}

data class KeyPoint(val bodyPart: BodyPart, var coordinate: PointF, val score: Float)

// Here is my function:

private fun processOutput(output: FloatArray): List<Person> {
    val persons = mutableListOf<Person>()
    // How should I process the output???
    // The output shape is [1, 56, 8400]
    return persons
}
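For orientation, the 56 channels of the [1, 56, 8400] output can be accounted for as 4 box values + 1 person confidence + 17 keypoints × 3 values = 56. For an assumed 640×640 input with the usual YOLO strides of 8, 16 and 32, the anchor count is 80² + 40² + 20² = 8400. The sketch below encodes that layout. The names (`NUM_CHANNELS`, `numAnchors`, `valueAt`, `keypointChannel`) are hypothetical helpers, and the layout is inferred from the indexing the answer uses, not taken from official documentation.

```kotlin
// Hypothetical helpers describing the assumed channel layout of a
// YOLO pose output tensor with shape [1, 56, 8400].
val BOX_CHANNELS = 4          // cx, cy, w, h
val CONF_CHANNEL = 4          // person confidence sits right after the box
val NUM_KEYPOINTS = 17        // COCO keypoints, same order as BodyPart
val VALUES_PER_KEYPOINT = 3   // x, y, visibility score
val NUM_CHANNELS = BOX_CHANNELS + 1 + NUM_KEYPOINTS * VALUES_PER_KEYPOINT // = 56

// Anchor count for a square input, assuming detection strides 8, 16 and 32.
fun numAnchors(inputSize: Int): Int {
    var total = 0
    for (stride in intArrayOf(8, 16, 32)) {
        val grid = inputSize / stride
        total += grid * grid
    }
    return total
}

// The batch dimension is 1, so the flat buffer is channel-major:
// all anchors of channel 0, then all anchors of channel 1, and so on.
fun valueAt(output: FloatArray, channel: Int, anchor: Int, anchorCount: Int): Float =
    output[channel * anchorCount + anchor]

// Channel holding component `component` (0 = x, 1 = y, 2 = score) of keypoint `k`.
fun keypointChannel(k: Int, component: Int): Int =
    CONF_CHANNEL + 1 + k * VALUES_PER_KEYPOINT + component
```

So keypoint `k` of anchor `c` has its x at `output[c + 8400 * (5 + 3 * k)]`, which is exactly the indexing used in the answer's loop.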
Answer (votes: 0)

I figured it out. Here is the post-processing for the [1, 56, 8400] output:

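// Assumes (not shown in the original answer):
//   array: FloatArray          = flattened [1, 56, 8400] output tensor
//   numElements                = 8400 (number of candidate anchors)
//   tensorWidth, tensorHeight  = model input size in pixels, e.g. 640F
//   CONFIDENCE_THRESHOLD       = minimum person confidence, e.g. 0.5F
//   boundingBoxes, BoundingBox = the answerer's own result list and data class
// Requires: import android.graphics.PointF
// Channel-major layout, so value(channel, anchor c) = array[c + numElements * channel]:
//   0..3 = cx, cy, w, h    4 = person confidence    5..55 = 17 x (x, y, score)
for (c in 0 until numElements) {
    val cnf = array[c + numElements * 4]
    if (cnf <= CONFIDENCE_THRESHOLD) continue

    val cx = array[c]
    val cy = array[c + numElements]
    val w = array[c + numElements * 2]
    val h = array[c + numElements * 3]
    val x1 = cx - w / 2F
    val y1 = cy - h / 2F
    val x2 = cx + w / 2F
    val y2 = cy + h / 2F
    // Skip boxes that touch or cross the input border. Depending on how the
    // model was exported, coordinates may already be normalized to [0, 1];
    // in that case compare against 1F and drop the divisions below.
    if (x1 <= 0F || x1 >= tensorWidth) continue
    if (y1 <= 0F || y1 >= tensorHeight) continue
    if (x2 <= 0F || x2 >= tensorWidth) continue
    if (y2 <= 0F || y2 >= tensorHeight) continue

    val keypoints = mutableListOf<KeyPoint>()
    for (k in 0 until 17) {
        val base = 5 + k * 3 // first channel of keypoint k
        val kx = array[c + numElements * base] / tensorWidth
        val ky = array[c + numElements * (base + 1)] / tensorHeight
        val kScore = array[c + numElements * (base + 2)]
        // Match the question's KeyPoint(bodyPart, coordinate, score) constructor;
        // BodyPart is declared in COCO keypoint order, so ordinal k lines up.
        keypoints.add(KeyPoint(BodyPart.values()[k], PointF(kx, ky), kScore))
    }

    boundingBoxes.add(
        BoundingBox(
            x1 = x1, y1 = y1, x2 = x2, y2 = y2,
            cx = cx, cy = cy, w = w, h = h, cnf = cnf,
            keyPoints = keypoints
        )
    )
}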
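The loop above keeps every anchor whose confidence clears the threshold. YOLO heads typically fire on several neighbouring anchors for the same person, so some form of non-maximum suppression (NMS) is usually needed before drawing. The following is a minimal sketch that decodes straight into the question's `Person`/`KeyPoint` types and then applies plain greedy IoU-based NMS. Assumptions: the channel-major layout used in the answer, a square input of `inputSize` pixels (pass `1f` if your export already emits normalized coordinates), and thresholds of 0.5 chosen only as defaults. `Candidate`, `iou` and the local `PointF` stand-in (so the sketch runs off-device) are hypothetical.

```kotlin
// Stand-in for android.graphics.PointF so this runs on a plain JVM.
// In an Android project, delete it and `import android.graphics.PointF`.
data class PointF(var x: Float, var y: Float)

// Types from the question.
enum class BodyPart {
    NOSE, LEFT_EYE, RIGHT_EYE, LEFT_EAR, RIGHT_EAR,
    LEFT_SHOULDER, RIGHT_SHOULDER, LEFT_ELBOW, RIGHT_ELBOW,
    LEFT_WRIST, RIGHT_WRIST, LEFT_HIP, RIGHT_HIP,
    LEFT_KNEE, RIGHT_KNEE, LEFT_ANKLE, RIGHT_ANKLE
}
data class KeyPoint(val bodyPart: BodyPart, var coordinate: PointF, val score: Float)
data class Person(var keyPoints: MutableList<KeyPoint>, val score: Float)

// Hypothetical: a decoded detection plus its box (input pixels), used for NMS.
class Candidate(val x1: Float, val y1: Float, val x2: Float, val y2: Float, val person: Person)

// Intersection over union of two axis-aligned boxes.
fun iou(a: Candidate, b: Candidate): Float {
    val iw = maxOf(0f, minOf(a.x2, b.x2) - maxOf(a.x1, b.x1))
    val ih = maxOf(0f, minOf(a.y2, b.y2) - maxOf(a.y1, b.y1))
    val inter = iw * ih
    val union = (a.x2 - a.x1) * (a.y2 - a.y1) + (b.x2 - b.x1) * (b.y2 - b.y1) - inter
    return if (union <= 0f) 0f else inter / union
}

fun processOutput(
    output: FloatArray,          // flattened [1, 56, numAnchors] tensor
    numAnchors: Int = 8400,
    inputSize: Float = 640f,     // assumed square model input; 1f if already normalized
    confThreshold: Float = 0.5f,
    iouThreshold: Float = 0.5f
): List<Person> {
    val parts = BodyPart.values()
    val candidates = mutableListOf<Candidate>()
    for (a in 0 until numAnchors) {
        // Channel-major: value(channel, anchor) = output[channel * numAnchors + a].
        val conf = output[4 * numAnchors + a]
        if (conf < confThreshold) continue
        val cx = output[a]
        val cy = output[numAnchors + a]
        val w = output[2 * numAnchors + a]
        val h = output[3 * numAnchors + a]
        // 17 keypoints x (x, y, score) start at channel 5.
        val keyPoints = MutableList(parts.size) { k ->
            val ch = 5 + k * 3
            val x = output[ch * numAnchors + a] / inputSize
            val y = output[(ch + 1) * numAnchors + a] / inputSize
            KeyPoint(parts[k], PointF(x, y), output[(ch + 2) * numAnchors + a])
        }
        candidates += Candidate(cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2, Person(keyPoints, conf))
    }
    // Greedy NMS: visit candidates from highest to lowest confidence and keep
    // one only if it does not overlap an already kept box by more than iouThreshold.
    val kept = mutableListOf<Candidate>()
    for (c in candidates.sortedByDescending { it.person.score }) {
        if (kept.none { iou(it, c) > iouThreshold }) kept += c
    }
    return kept.map { it.person }
}
```

Because the pose model has a single class (person), NMS can run over all candidates at once instead of per class.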

And here are the body-part connections (skeleton edges), as pairs of BodyPart indices:

private val edges = listOf(
    Pair(0, 1), Pair(0, 2), // Nose ↔ Left Eye, Nose ↔ Right Eye
    Pair(1, 3), Pair(2, 4), // Left Eye ↔ Left Ear, Right Eye ↔ Right Ear
    Pair(0, 5), Pair(0, 6), // Nose ↔ Left Shoulder, Nose ↔ Right Shoulder
    Pair(5, 7), Pair(7, 9), // Left Shoulder ↔ Left Elbow, Left Elbow ↔ Left Wrist
    Pair(6, 8), Pair(8, 10), // Right Shoulder ↔ Right Elbow, Right Elbow ↔ Right Wrist
    Pair(5, 6), // Left Shoulder ↔ Right Shoulder
    Pair(5, 11), Pair(6, 12), // Left Shoulder ↔ Left Hip, Right Shoulder ↔ Right Hip
    Pair(11, 12), // Left Hip ↔ Right Hip
    Pair(11, 13), Pair(13, 15), // Left Hip ↔ Left Knee, Left Knee ↔ Left Ankle
    Pair(12, 14), Pair(14, 16) // Right Hip ↔ Right Knee, Right Knee ↔ Right Ankle
)
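To draw the skeleton, each edge becomes a line between two keypoints once they are scaled from normalized coordinates to view pixels. The sketch below assumes the keypoints are normalized to [0, 1] (as in the answer's loop) and that the camera frame was stretched to the model input without letterboxing; if it was letterboxed, the padding has to be removed before scaling. `Segment`, `skeletonSegments` and the `minScore` cutoff of 0.5 are hypothetical. The edge list is the one from the answer.

```kotlin
// Hypothetical output type: one line to draw, in view pixels.
data class Segment(val startX: Float, val startY: Float, val endX: Float, val endY: Float)

// Same connections as the answer's edge list (indices are BodyPart ordinals).
val edges = listOf(
    0 to 1, 0 to 2, 1 to 3, 2 to 4, 0 to 5, 0 to 6,
    5 to 7, 7 to 9, 6 to 8, 8 to 10, 5 to 6, 5 to 11,
    6 to 12, 11 to 12, 11 to 13, 13 to 15, 12 to 14, 14 to 16
)

// Scales normalized keypoints (x, y in [0, 1]) to view pixels and returns one
// segment per edge whose two endpoints both have a score of at least minScore.
fun skeletonSegments(
    xs: FloatArray, ys: FloatArray, scores: FloatArray,
    viewWidth: Float, viewHeight: Float,
    minScore: Float = 0.5f
): List<Segment> = edges
    .filter { (a, b) -> scores[a] >= minScore && scores[b] >= minScore }
    .map { (a, b) ->
        Segment(xs[a] * viewWidth, ys[a] * viewHeight, xs[b] * viewWidth, ys[b] * viewHeight)
    }
```

In a custom View's `onDraw`, you would fill `xs`, `ys` and `scores` from `person.keyPoints` (index i = `BodyPart.ordinal`) and call `canvas.drawLine(s.startX, s.startY, s.endX, s.endY, paint)` for each returned segment.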