我想在Android设备上使用TensorFlow。到目前为止,我已经用Python训练了模型,并将其导出为Protobuf .pb 文件。
我在Python中测试了这个 .pb 文件,运行没有任何错误:
......
# Load the frozen graph and list its operation names (useful for picking the
# input/output node names to use from Android).
graph = load_graph("./frozen_model.pb")
for op in graph.get_operations():
    print(op.name)

with tf.Session(graph=graph) as sess:
    # Nodes are looked up under "prefix/" -- presumably the name scope that
    # load_graph imports the graph into (TODO confirm against load_graph).
    tf_predik = graph.get_tensor_by_name("prefix/tf_pred:0")
    tf_data = graph.get_tensor_by_name("prefix/tf_data:0")

    # Convert to 8-bit grayscale ('L'), then invert: on uint8 data np.invert
    # maps v -> 255 - v. The context manager closes the image file.
    with Image.open("7.png") as src:
        img = np.invert(src.convert('L'))
    # A batch of one 28x28 image with a single channel.
    image = img.reshape(1, 28, 28, 1)

    test_pred = sess.run(tf_predik, feed_dict={tf_data: image})
    # Arg-max over axis 1 (presumably the class axis) for each image in the batch.
    temp = np.argmax(test_pred, axis=1)
    print(temp)
我在Xamarin Android上尝试:
using Org.Tensorflow.Contrib.Android;
.....
var assets = Android.App.Application.Context.Assets;
var inferenceInterface = new TensorFlowInferenceInterface(assets, "frozen_model.pb");

// The Python test feeds a (1, 28, 28, 1) batch built from the inverted grayscale image,
// so the same preprocessing is reproduced here.
const int ImageSize = 28;
var input = new float[ImageSize * ImageSize];
using (Stream inputStream = assets.Open("7.png"))
using (var bitmap = Android.Graphics.BitmapFactory.DecodeStream(inputStream))
{
    // Feed needs decoded pixel values as floats. Passing the raw PNG file bytes (byte[])
    // creates a uint8 tensor, which is what caused "expected arg[0] to be float but provided uint8".
    if (bitmap == null)
    {
        throw new System.InvalidOperationException("7.png could not be decoded as an image.");
    }
    if (bitmap.Width != ImageSize || bitmap.Height != ImageSize)
    {
        throw new System.InvalidOperationException("7.png must be 28x28 pixels.");
    }

    var argb = new int[input.Length];
    bitmap.GetPixels(argb, 0, ImageSize, 0, 0, ImageSize, ImageSize);
    for (int i = 0; i < argb.Length; i++)
    {
        int pixel = argb[i];
        int r = (pixel >> 16) & 0xFF;
        int g = (pixel >> 8) & 0xFF;
        int b = pixel & 0xFF;
        // Luma weights from Pillow's documented convert('L') formula
        // (L = R*299/1000 + G*587/1000 + B*114/1000); rounding may differ from Pillow by 1.
        int luma = (r * 299 + g * 587 + b * 114) / 1000;
        // Mirrors np.invert on uint8 (v -> 255 - v); values stay in 0..255 like the Python feed.
        input[i] = 255 - luma;
    }
}

// The trailing arguments of Feed are the tensor shape, not the element count.
inferenceInterface.Feed("tf_data", input, 1, ImageSize, ImageSize, 1);
inferenceInterface.Run(new[] { "tf_pred:0" });
// `predictions` is not declared in this snippet; it must be a float[] at least as long as
// the model's output (presumably one score per class -- confirm).
inferenceInterface.Fetch("tf_pred:0", predictions);
// ... use predictions (the Python test takes the arg-max over the class axis) ...
我收到一个错误:
Java.Lang.IllegalArgumentException: 期望 arg[0] 为 float,但提供的是 uint8
提前感谢!
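出现“期望 arg[0] 为 float,但提供了 uint8”错误的原因是:`TensorFlowInferenceInterface.Feed` 需要一个 float 数组,而代码传入的是图像文件的原始字节(`byte[]`,因此生成的是 uint8 张量)。

需要先把资产中的图像按其文件编码(jpg、png 等)解码为 `Bitmap`,再从像素中生成 float 数组。

另外请注意,`Feed` 的最后几个参数是张量的形状(dims),而不是数组长度。它们必须与模型输入占位符的形状一致,例如 Python 代码中使用的 `1, 28, 28, 1`。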
/// <summary>
/// Scales <paramref name="bitmap"/> to <paramref name="sizeX"/> x <paramref name="sizeY"/> pixels and
/// converts it into a flat float array with three interleaved values per pixel, in B, G, R order,
/// each with a fixed per-channel mean subtracted (B - 104, G - 117, R - 123).
/// </summary>
/// <remarks>
/// The size and channel means are example values; adjust them to match how the model was trained.
/// </remarks>
/// <param name="bitmap">Source image. It is neither modified nor recycled.</param>
/// <param name="sizeX">Target width in pixels (defaults to 255).</param>
/// <param name="sizeY">Target height in pixels (defaults to 255).</param>
/// <returns>An array of <c>sizeX * sizeY * 3</c> floats, row by row, pixel-interleaved.</returns>
/// <exception cref="System.ArgumentNullException"><paramref name="bitmap"/> is null.</exception>
/// <exception cref="System.ArgumentOutOfRangeException">A target size is not positive.</exception>
/// <exception cref="System.InvalidOperationException">The bitmap could not be copied to ARGB_8888.</exception>
public float[] AndroidBitmapToFloatArray(Bitmap bitmap, int sizeX = 255, int sizeY = 255)
{
    if (bitmap == null)
    {
        throw new System.ArgumentNullException(nameof(bitmap));
    }
    if (sizeX <= 0)
    {
        throw new System.ArgumentOutOfRangeException(nameof(sizeX));
    }
    if (sizeY <= 0)
    {
        throw new System.ArgumentOutOfRangeException(nameof(sizeY));
    }

    var pixels = new int[sizeX * sizeY];

    // Only scale when needed. CreateScaledBitmap returns the source bitmap itself when the size
    // already matches, and skipping the call in that case guarantees that `scaled` is only ever a
    // bitmap created here -- the one that must be released below (previously it leaked).
    bool needsScaling = bitmap.Width != sizeX || bitmap.Height != sizeY;
    Bitmap scaled = needsScaling ? Bitmap.CreateScaledBitmap(bitmap, sizeX, sizeY, false) : bitmap;
    try
    {
        using (var sampleImage = scaled.Copy(Bitmap.Config.Argb8888, false))
        {
            // Copy returns null when the conversion is unsupported or allocation fails.
            if (sampleImage == null)
            {
                throw new System.InvalidOperationException("The bitmap could not be copied to ARGB_8888.");
            }
            sampleImage.GetPixels(pixels, 0, sizeX, 0, 0, sizeX, sizeY);
            sampleImage.Recycle();
        }
    }
    finally
    {
        if (needsScaling)
        {
            scaled.Recycle();
            scaled.Dispose();
        }
    }

    var floatArray = new float[pixels.Length * 3];
    for (int i = 0; i < pixels.Length; ++i)
    {
        // GetPixels yields packed ARGB ints; the low byte is blue, so the output order is B, G, R.
        int argb = pixels[i];
        floatArray[i * 3 + 0] = (argb & 0xFF) - 104;          // blue
        floatArray[i * 3 + 1] = ((argb >> 8) & 0xFF) - 117;   // green
        floatArray[i * 3 + 2] = ((argb >> 16) & 0xFF) - 123;  // red
    }
    return floatArray;
}
// Decode the image asset into a Bitmap and convert it to the float input the graph expects.
float[] feedArray;
using (var imageAsset = Assets.Open("someimage"))
using (var bitmap = BitmapFactory.DecodeStream(imageAsset))
{
    // DecodeStream returns null when the data cannot be decoded as an image.
    if (bitmap == null)
    {
        throw new System.InvalidOperationException("The image asset could not be decoded.");
    }
    feedArray = AndroidBitmapToFloatArray(bitmap);
}

// The trailing arguments of Feed are the tensor shape (dims), not the array length: here one
// 255 x 255 image with 3 channels, the output layout of AndroidBitmapToFloatArray with its
// default size. NOTE: this must match the model's input placeholder (the question's model
// is fed 1, 28, 28, 1 in Python) -- adjust both the conversion and this shape together.
inferenceInterface.Feed("tf_data", feedArray, 1, 255, 255, 3);