I'm trying to use a custom TFLite object detection model in an Android app written in Java. After adding my model and label map, I get this error:
E/AndroidRuntime: FATAL EXCEPTION: main
Process: com.soumio.inceptiontutorial, PID: 21661
java.lang.IllegalArgumentException: Cannot copy from a TensorFlowLite tensor (StatefulPartitionedCall:3) with shape [1, 10, 4] to a Java object with shape [1, 6].
at org.tensorflow.lite.Tensor.throwIfDstShapeIsIncompatible(Tensor.java:485)
at org.tensorflow.lite.Tensor.copyTo(Tensor.java:255)
at org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:216)
at org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:374)
at org.tensorflow.lite.Interpreter.run(Interpreter.java:332)
at com.soumio.inceptiontutorial.Classify$3.onClick(Classify.java:175)
at android.view.View.performClick(View.java:6659)
at android.view.View.performClickInternal(View.java:6631)
at android.view.View.access$3100(View.java:790)
at android.view.View$PerformClick.run(View.java:26187)
at android.os.Handler.handleCallback(Handler.java:907)
at android.os.Handler.dispatchMessage(Handler.java:105)
at android.os.Looper.loop(Looper.java:216)
at android.app.ActivityThread.main(ActivityThread.java:7625)
at java.lang.reflect.Method.invoke(Native Method)
at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:524)
at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:987)
My classifier class:
public class Classify extends AppCompatActivity {
// presets for rgb conversion
private static final int RESULTS_TO_SHOW = 3;
private static final int IMAGE_MEAN = 128;
private static final float IMAGE_STD = 128.0f;
// options for model interpreter
private final Interpreter.Options tfliteOptions = new Interpreter.Options();
// tflite graph
private Interpreter tflite;
// holds all the possible labels for model
private List<String> labelList;
// holds the selected image data as bytes
private ByteBuffer imgData = null;
// holds the probabilities of each label for non-quantized graphs
private float[][] labelProbArray = null;
// holds the probabilities of each label for quantized graphs
private byte[][] labelProbArrayB = null;
// array that holds the labels with the highest probabilities
private String[] topLables = null;
// array that holds the highest probabilities
private String[] topConfidence = null;
// selected classifier information received from extras
private String chosen;
private boolean quant;
// input image dimensions for the Inception Model
private int DIM_IMG_SIZE_X = 640;
private int DIM_IMG_SIZE_Y = 640;
private int DIM_PIXEL_SIZE = 3;
// int array to hold image data
private int[] intValues;
// activity elements
private ImageView selected_image;
private Button classify_button;
private Button back_button;
private TextView label1;
private TextView label2;
private TextView label3;
private TextView Confidence1;
private TextView Confidence2;
private TextView Confidence3;
// priority queue that will hold the top results from the CNN
private PriorityQueue<Map.Entry<String, Float>> sortedLabels =
new PriorityQueue<>(
RESULTS_TO_SHOW,
new Comparator<Map.Entry<String, Float>>() {
@Override
public int compare(Map.Entry<String, Float> o1, Map.Entry<String, Float> o2) {
return (o1.getValue()).compareTo(o2.getValue());
}
});
@Override
protected void onCreate(Bundle savedInstanceState) {
// get all selected classifier data from classifiers
chosen = (String) getIntent().getStringExtra("chosen");
quant = (boolean) getIntent().getBooleanExtra("quant", false);
// initialize array that holds image data
intValues = new int[DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y];
super.onCreate(savedInstanceState);
// initialize graph and labels
try{
tflite = new Interpreter(loadModelFile(), tfliteOptions);
labelList = loadLabelList();
} catch (Exception ex){
ex.printStackTrace();
}
// initialize byte array. The size depends if the input data needs to be quantized or not
if(quant){
imgData =
ByteBuffer.allocateDirect(
DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);
} else {
imgData =
ByteBuffer.allocateDirect(
4 * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);
}
imgData.order(ByteOrder.nativeOrder());
// initialize probabilities array. The datatypes that array holds depends if the input data needs to be quantized or not
if(quant){
labelProbArrayB= new byte[1][labelList.size()];
} else {
labelProbArray = new float[1][labelList.size()];
}
setContentView(R.layout.activity_classify);
// labels that hold top three results of CNN
label1 = (TextView) findViewById(R.id.label1);
label2 = (TextView) findViewById(R.id.label2);
label3 = (TextView) findViewById(R.id.label3);
// displays the probabilities of top labels
Confidence1 = (TextView) findViewById(R.id.Confidence1);
Confidence2 = (TextView) findViewById(R.id.Confidence2);
Confidence3 = (TextView) findViewById(R.id.Confidence3);
// initialize imageView that displays selected image to the user
selected_image = (ImageView) findViewById(R.id.selected_image);
// initialize array to hold top labels
topLables = new String[RESULTS_TO_SHOW];
// initialize array to hold top probabilities
topConfidence = new String[RESULTS_TO_SHOW];
// allows user to go back to activity to select a different image
back_button = (Button)findViewById(R.id.back_button);
back_button.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View view) {
Intent i = new Intent(Classify.this, ChooseModel.class);
startActivity(i);
}
});
// classify the currently displayed image
classify_button = (Button)findViewById(R.id.classify_image);
classify_button.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View view) {
// get current bitmap from imageView
Bitmap bitmap_orig = ((BitmapDrawable)selected_image.getDrawable()).getBitmap();
// resize the bitmap to the required input size to the CNN
Bitmap bitmap = getResizedBitmap(bitmap_orig, DIM_IMG_SIZE_X, DIM_IMG_SIZE_Y);
// convert bitmap to byte array
convertBitmapToByteBuffer(bitmap);
// pass byte data to the graph
if(quant){
tflite.run(imgData, labelProbArrayB);
} else {
tflite.run(imgData, labelProbArray);
}
// display the results
printTopKLabels();
}
});
// get image from previous activity to show in the imageView
Uri uri = (Uri)getIntent().getParcelableExtra("resID_uri");
try {
Bitmap bitmap = MediaStore.Images.Media.getBitmap(getContentResolver(), uri);
selected_image.setImageBitmap(bitmap);
// not sure why this happens, but without this the image appears on its side
selected_image.setRotation(selected_image.getRotation() + 90);
} catch (IOException e) {
e.printStackTrace();
}
}
// loads the tflite graph from file
private MappedByteBuffer loadModelFile() throws IOException {
AssetFileDescriptor fileDescriptor = this.getAssets().openFd(chosen);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
// converts bitmap to byte array which is passed in the tflite graph
private void convertBitmapToByteBuffer(Bitmap bitmap) {
if (imgData == null) {
return;
}
imgData.rewind();
bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
// loop through all pixels
int pixel = 0;
for (int i = 0; i < DIM_IMG_SIZE_X; ++i) {
for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) {
final int val = intValues[pixel++];
// get rgb values from intValues where each int holds the rgb values for a pixel.
// if quantized, convert each rgb value to a byte, otherwise to a float
if(quant){
imgData.put((byte) ((val >> 16) & 0xFF));
imgData.put((byte) ((val >> 8) & 0xFF));
imgData.put((byte) (val & 0xFF));
} else {
imgData.putFloat((((val >> 16) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
imgData.putFloat((((val >> 8) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
imgData.putFloat((((val) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
}
}
}
}
// loads the labels from the label txt file in assets into a string array
private List<String> loadLabelList() throws IOException {
List<String> labelList = new ArrayList<String>();
BufferedReader reader =
new BufferedReader(new InputStreamReader(this.getAssets().open("labelmap.txt")));
String line;
while ((line = reader.readLine()) != null) {
labelList.add(line);
}
reader.close();
return labelList;
}
// print the top labels and respective confidences
private void printTopKLabels() {
// add all results to priority queue
for (int i = 0; i < labelList.size(); ++i) {
if(quant){
sortedLabels.add(
new AbstractMap.SimpleEntry<>(labelList.get(i), (labelProbArrayB[0][i] & 0xff) / 255.0f));
} else {
sortedLabels.add(
new AbstractMap.SimpleEntry<>(labelList.get(i), labelProbArray[0][i]));
}
if (sortedLabels.size() > RESULTS_TO_SHOW) {
sortedLabels.poll();
}
}
// get top results from priority queue
final int size = sortedLabels.size();
for (int i = 0; i < size; ++i) {
Map.Entry<String, Float> label = sortedLabels.poll();
topLables[i] = label.getKey();
topConfidence[i] = String.format("%.0f%%",label.getValue()*100);
}
// set the corresponding textviews with the results
label1.setText("1. "+topLables[2]);
label2.setText("2. "+topLables[1]);
label3.setText("3. "+topLables[0]);
Confidence1.setText(topConfidence[2]);
Confidence2.setText(topConfidence[1]);
Confidence3.setText(topConfidence[0]);
}
// resizes bitmap to given dimensions
public Bitmap getResizedBitmap(Bitmap bm, int newWidth, int newHeight) {
int width = bm.getWidth();
int height = bm.getHeight();
float scaleWidth = ((float) newWidth) / width;
float scaleHeight = ((float) newHeight) / height;
Matrix matrix = new Matrix();
matrix.postScale(scaleWidth, scaleHeight);
Bitmap resizedBitmap = Bitmap.createBitmap(
bm, 0, 0, width, height, matrix, false);
return resizedBitmap;
}
}
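My model and label map: https://www.pastefile.com/vpg57x https://www.pastefile.com/ncfyht
I tried the solution from another Stack Overflow question, where someone said I need to change the list. So I removed the 4 from this line: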
imgData = ByteBuffer.allocateDirect(4 * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);
But if I do that, I get a BufferOverflowException. Can someone help me fix this?
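The BufferOverflowException is expected and is unrelated to the output-shape error: `putFloat` writes 4 bytes per channel value, so a float32 input of 640 × 640 × 3 needs 4,915,200 bytes, while removing the 4 leaves room for only 1,228,800 bytes (307,200 floats). The factor 4 is correct as long as the model takes float32 input; you can check this with `tflite.getInputTensor(0).dataType()`. The minimal sketch below (the class and method names are hypothetical) just reproduces that arithmetic in plain Java:

```java
import java.nio.BufferOverflowException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Sketch: how many bytes the input ByteBuffer needs for a 640x640 RGB image.
class InputBufferSize {
    static final int DIM_IMG_SIZE_X = 640;
    static final int DIM_IMG_SIZE_Y = 640;
    static final int DIM_PIXEL_SIZE = 3;

    // Bytes per channel value: 1 for a uint8 (quantized) input, 4 for a float32 input.
    static int bytesNeeded(boolean quantized) {
        int bytesPerChannel = quantized ? 1 : Float.BYTES;
        return bytesPerChannel * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE;
    }

    // Returns true if writing one float per channel overflows a buffer of the given capacity.
    static boolean overflowsWithFloats(int capacity) {
        ByteBuffer buf = ByteBuffer.allocateDirect(capacity).order(ByteOrder.nativeOrder());
        try {
            for (int i = 0; i < DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE; i++) {
                buf.putFloat(0f); // each putFloat advances the position by 4 bytes
            }
            return false;
        } catch (BufferOverflowException e) {
            return true;
        }
    }

    public static void main(String[] args) {
        System.out.println(bytesNeeded(false)); // 4915200
        System.out.println(bytesNeeded(true));  // 1228800
        System.out.println(overflowsWithFloats(bytesNeeded(true)));  // true
        System.out.println(overflowsWithFloats(bytesNeeded(false))); // false
    }
}
```

So keep the `4 *` for a float model; the shape error has to be fixed on the output side instead.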
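Your custom model produces an output tensor of shape [1, 10, 4], not [1, 6]. `tflite.run(imgData, labelProbArray)` copies only output 0 of the model into `labelProbArray`, which you allocated as [1, labelList.size()] = [1, 6], but output 0 of this model is the [1, 10, 4] tensor `StatefulPartitionedCall:3`. Object detection models usually have multiple outputs rather than a single output tensor; the typical output signature of a custom object detection model has four outputs, such as this.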
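Make sure you first understand the output signature of your custom object detection model, for example by printing the shape of each output tensor.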
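A minimal sketch of the output side, assuming the model has the usual four detection outputs (boxes [1, 10, 4], classes [1, 10], scores [1, 10], count [1]). From the stack trace, output 0 is the [1, 10, 4] boxes tensor; the indices of the other three are not known here, so print them first with `tflite.getOutputTensorCount()` and `tflite.getOutputTensor(i).shape()` / `.name()`, then pass one array per output to `tflite.runForMultipleInputsOutputs(new Object[] {imgData}, outputs)`. `DetectionDecoder`, `Detection` and `decode` are hypothetical names, the [ymin, xmin, ymax, xmax] box order is the assumed TF Object Detection API convention, and the mapping of class ids to lines of labelmap.txt is an assumption that depends on how the model was exported:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Sketch of post-processing for an SSD-style detector with four outputs.
// ASSUMPTION: the arrays were filled by something like
//   Map<Integer, Object> outputs = new HashMap<>();
//   outputs.put(0, boxes);  // index 0 = boxes per the stack trace; other indices: check your model
//   ...
//   tflite.runForMultipleInputsOutputs(new Object[] {imgData}, outputs);
class DetectionDecoder {

    // One detected object: label, score and box in normalized
    // [ymin, xmin, ymax, xmax] coordinates (assumed convention).
    static final class Detection {
        final String label;
        final float score;
        final float[] box;

        Detection(String label, float score, float[] box) {
            this.label = label;
            this.score = score;
            this.box = box;
        }

        @Override
        public String toString() {
            return String.format("%s %.0f%% %s", label, score * 100, Arrays.toString(box));
        }
    }

    // boxes: [1][N][4], classes: [1][N], scores: [1][N], numDetections: [1]
    static List<Detection> decode(float[][][] boxes, float[][] classes, float[][] scores,
                                  float[] numDetections, List<String> labels, float minScore) {
        // never read past the arrays, even if the count output is larger
        int n = Math.min((int) numDetections[0], scores[0].length);
        List<Detection> result = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            if (scores[0][i] < minScore) {
                continue; // skip low-confidence detections
            }
            // ASSUMPTION: class id 0 corresponds to the first line of labelmap.txt
            int classId = (int) classes[0][i];
            String label = (classId >= 0 && classId < labels.size()) ? labels.get(classId) : "unknown";
            result.add(new Detection(label, scores[0][i], boxes[0][i].clone()));
        }
        return result;
    }

    public static void main(String[] args) {
        // Fake outputs with N = 2 detections instead of the model's N = 10.
        float[][][] boxes = {{{0.1f, 0.2f, 0.5f, 0.6f}, {0.3f, 0.3f, 0.9f, 0.8f}}};
        float[][] classes = {{1f, 0f}};
        float[][] scores = {{0.92f, 0.15f}};
        float[] count = {2f};
        List<String> labels = Arrays.asList("cat", "dog");
        System.out.println(decode(boxes, classes, scores, count, labels, 0.5f));
    }
}
```

Keeping the decoding in plain Java, separate from the `Interpreter` call, makes it easy to test without a device; in the activity, the resulting list would replace the `PriorityQueue` over `labelProbArray`.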
Did you ever solve this? I'm getting the same error (Cannot copy from a TensorFlowLite tensor (StatefulPartitionedCall_1:0) with shape [1, 25200, 7] to a Java object with shape [1, 20, 20, 35].)
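Here the Java array has to match the tensor exactly, i.e. be allocated as `new float[1][25200][7]`. The shape [1, 25200, 7] is consistent with a YOLOv5-style export at 640 × 640 input: 3 anchors per cell on 80 × 80, 40 × 40 and 20 × 20 grids gives 3 × (6400 + 1600 + 400) = 25200 rows, and 7 = 4 box values + 1 objectness + 2 class scores. That layout is an assumption, not something confirmed by the thread. The sketch below (hypothetical names, plain Java) decodes rows under that assumption; non-maximum suppression is still needed afterwards and is not shown:

```java
import java.util.ArrayList;
import java.util.List;

// Sketch: reading a [1, 25200, 7] output, ASSUMING a YOLOv5-style row layout:
// [cx, cy, w, h, objectness, classScore0, classScore1]  (7 = 5 + 2 classes).
class YoloOutputDecoder {

    // Rows produced by a YOLOv5-style head: 3 anchors per cell on grids with strides 8, 16, 32.
    static int expectedRows(int inputSize) {
        int rows = 0;
        for (int stride : new int[] {8, 16, 32}) {
            int cells = inputSize / stride;
            rows += 3 * cells * cells;
        }
        return rows;
    }

    static final class Candidate {
        final int classId;
        final float confidence;
        final float cx, cy, w, h; // normalized or pixel units, depending on the export

        Candidate(int classId, float confidence, float cx, float cy, float w, float h) {
            this.classId = classId;
            this.confidence = confidence;
            this.cx = cx;
            this.cy = cy;
            this.w = w;
            this.h = h;
        }
    }

    // output is the array passed to tflite.run(...), allocated as new float[1][25200][7]
    static List<Candidate> decode(float[][][] output, float confThreshold) {
        List<Candidate> candidates = new ArrayList<>();
        for (float[] row : output[0]) {
            // pick the class with the highest score in columns 5..end
            int bestClass = -1;
            float bestScore = 0f;
            for (int c = 5; c < row.length; c++) {
                if (row[c] > bestScore) {
                    bestScore = row[c];
                    bestClass = c - 5;
                }
            }
            float confidence = row[4] * bestScore; // objectness x class score
            if (bestClass >= 0 && confidence >= confThreshold) {
                candidates.add(new Candidate(bestClass, confidence, row[0], row[1], row[2], row[3]));
            }
        }
        // Overlapping boxes still need non-maximum suppression (not shown here).
        return candidates;
    }

    public static void main(String[] args) {
        System.out.println(expectedRows(640)); // 25200
        float[][][] fake = {{
            {0.5f, 0.5f, 0.2f, 0.2f, 0.9f, 0.1f, 0.8f},
            {0.1f, 0.1f, 0.05f, 0.05f, 0.2f, 0.6f, 0.1f}
        }};
        System.out.println(decode(fake, 0.5f).size()); // 1
    }
}
```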