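I have a small question about Java and machine learning. I trained a model with Keras, and when I use Python to predict images it works as expected.
The trained model expects input of shape [width, height, RGB].
But when I load an image in Java I get [RGB, width, height], so I tried to change the shape with .reshape(). I obviously messed something up, because afterwards all predictions are wrong: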
ResizeImageTransform rit = new ResizeImageTransform(128, 128);
NativeImageLoader loader = new NativeImageLoader(128, 128, 3, rit);
INDArray features = loader.asMatrix(f); // GIVES ME A SHAPE OF 1, 3, 128, 128
features = features.reshape(1, 128, 128, 3); // GIVES ME THE SHAPE 1, 128, 128, 3 AS NEEDED
INDArray[] prediction = model.output(features); // all predictions wrong
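Why reshape gives wrong predictions: reshape only reinterprets the existing flat buffer with new dimensions, whereas permute reorders the axes. The minimal sketch below uses plain Java arrays instead of ND4J (the class and method names are hypothetical) and assumes row-major (C-order) storage, which is ND4J's default. ND4J's permute returns a view with rearranged strides rather than a copy, but element (0, h, w, c) of the result is element (0, c, h, w) of the input, which is what the copy loop reproduces.

```java
import java.util.Arrays;

/**
 * Minimal sketch (plain Java arrays, no ND4J) of why reshape(1, H, W, C) on a
 * [1, C, H, W] buffer scrambles pixels while permute(0, 2, 3, 1) does not.
 * Both layouts are assumed to be stored row-major (C order).
 */
public class ReshapeVsPermute {

    // Flat position of element (c, h, w) in a row-major [C, H, W] buffer.
    static int nchwIndex(int c, int h, int w, int H, int W) {
        return (c * H + h) * W + w;
    }

    // Flat position of element (h, w, c) in a row-major [H, W, C] buffer.
    static int nhwcIndex(int h, int w, int c, int W, int C) {
        return (h * W + w) * C + c;
    }

    /** Copies each value to where permute(0, 2, 3, 1) would logically put it. */
    static double[] nchwToNhwc(double[] src, int C, int H, int W) {
        double[] dst = new double[src.length];
        for (int c = 0; c < C; c++)
            for (int h = 0; h < H; h++)
                for (int w = 0; w < W; w++)
                    dst[nhwcIndex(h, w, c, W, C)] = src[nchwIndex(c, h, w, H, W)];
        return dst;
    }

    public static void main(String[] args) {
        int C = 3, H = 2, W = 2;
        // Channel plane 0 holds only 1s, plane 1 only 2s, plane 2 only 3s.
        double[] nchw = new double[C * H * W];
        for (int c = 0; c < C; c++)
            for (int i = 0; i < H * W; i++)
                nchw[c * H * W + i] = c + 1;

        // reshape keeps the flat order, so the first "pixel" takes all three
        // of its "channels" from the first plane.
        double[] firstPixelReshaped = Arrays.copyOfRange(nchw, 0, C);
        // permute moves the data, so each pixel gets one value from each plane.
        double[] firstPixelPermuted = Arrays.copyOfRange(nchwToNhwc(nchw, C, H, W), 0, C);

        System.out.println("reshape: " + Arrays.toString(firstPixelReshaped)); // [1.0, 1.0, 1.0]
        System.out.println("permute: " + Arrays.toString(firstPixelPermuted)); // [1.0, 2.0, 3.0]
    }
}
```

With a real image, the reshaped tensor packs neighbouring values of a single color plane into each "pixel", so the network sees a scrambled image.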
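I'm not a Java developer and tried to get along with the documentation, but I'm obviously missing something here. Maybe someone can give me a hint about what I'm doing wrong...
So now at least 136 images of the test set get labeled. The Python version labels 195 images...
So I guess normalization is the problem. I train the model with: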
train = ImageDataGenerator(rotation_range=5, horizontal_flip=True, vertical_flip=True, rescale=1/255)
and in the Python test script I use
X *= 1/255
before the prediction.
In Java I use
features = features.permute(0, 2, 3, 1);
DataNormalization scalar = new ImagePreProcessingScaler(0, 1);
scalar.transform(features);
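But I'm not sure whether normalization is the problem, or whether I got the parameters for .permute() right...
Any suggestions?
This is the complete code the model was trained with: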
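To check whether the scaling differs between the two sides: Keras' rescale=1/255 multiplies every pixel by 1/255. My understanding (an assumption; check the DL4J javadoc for your version) is that ImagePreProcessingScaler(0, 1) with its default 8-bit setting divides by 255 and then maps the result to [min, max], which for (0, 1) is the same x / 255. The hypothetical helpers below compare the two over all 8-bit values; they are not library code.

```java
/** Hypothetical helpers comparing the two scalings; not DL4J or Keras code. */
public class ScalingCheck {

    // Keras: ImageDataGenerator(rescale=1/255) multiplies every pixel by 1/255.
    static double kerasRescale(double pixel) {
        return pixel * (1.0 / 255.0);
    }

    // Assumed behavior of ImagePreProcessingScaler(min, max) on 8-bit input:
    // divide by 255, then stretch to [min, max].
    static double dl4jScale(double pixel, double min, double max) {
        return pixel / 255.0 * (max - min) + min;
    }

    public static void main(String[] args) {
        double worst = 0.0;
        for (int p = 0; p <= 255; p++) {
            worst = Math.max(worst, Math.abs(kerasRescale(p) - dl4jScale(p, 0.0, 1.0)));
        }
        // Any difference here is floating-point rounding only.
        System.out.println("max difference: " + worst);
        // 236 / 255 is about 0.9255, a value that appears in both pixel dumps.
        System.out.println("236 -> " + dl4jScale(236, 0.0, 1.0));
    }
}
```

If that assumption holds, the scaling is not where the two pipelines diverge.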
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications import ResNet50V2
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.applications import ResNet152V2
# GENERAL WIDTH AND HEIGHT FOR THE IMAGES
WIDTH = 128
HEIGHT = 128
train = ImageDataGenerator(rotation_range=5, horizontal_flip=True, vertical_flip=True, rescale=1/255)
valid = ImageDataGenerator(rescale=1/255)
train_set = train.flow_from_directory('images_train/', target_size=(WIDTH,HEIGHT), batch_size=64, class_mode='categorical')
test_set = valid.flow_from_directory('images_test/', target_size=(WIDTH,HEIGHT), batch_size=64, class_mode='categorical')
resnet = ResNet50V2(include_top=False, weights='imagenet', input_shape=(WIDTH,HEIGHT,3))
for layer in resnet.layers:
    layer.trainable = False
x = tf.keras.layers.Flatten()(resnet.output)
x = tf.keras.layers.Dense(512, activation='relu')(x)
n_classes = len(train_set.class_indices)
predictions = tf.keras.layers.Dense(n_classes, activation='softmax')(x)
model = tf.keras.Model(inputs=resnet.input, outputs=predictions)
model.compile(loss='categorical_crossentropy', optimizer="adam", metrics=['accuracy'])
hist = model.fit(train_set, epochs=20, validation_data=test_set)
model.save('resnet50v2.h5')
And this is the code I use to test images in Python:
import os
import numpy as np
from PIL import Image

# assumes `model` and `base_path` are defined earlier in the script
ctr = 0
for root, dirs, files in os.walk(base_path):
    for name in files:
        image_path = os.path.join(root, name)
        tmp = image_path.lower().split(".")
        if tmp[-1] in ["jpg", "jpeg", "png", "bmp"]:
            orig_image = Image.open(image_path)
            if orig_image.mode != "RGB":
                orig_image = orig_image.convert("RGB")
            image = orig_image.resize((128, 128))
            X = []
            X.append(np.array(image.getdata()).reshape((128, 128, 3)))
            X = np.array(X).astype('float64')
            X *= 1/255
            # PREDICT AND WRITE REPORT
            pred = model.predict(X)
            pred = np.rint(pred).astype("int32")
            if pred[0][1] != 1:
                ctr += 1
                print(f"{ctr} :: {image_path} == {pred[0]}")
And this is the code I use to test images in Java:
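import java.io.File;
import java.io.IOException;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.transform.ResizeImageTransform;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;

// assumes `model` (the imported Keras model) and `listOfFiles` are defined elsewhere
int ctr = 0;
for (File f : listOfFiles) {
    if (f.isFile()) {
        ResizeImageTransform rit = new ResizeImageTransform(128, 128);
        NativeImageLoader loader = new NativeImageLoader(128, 128, 3, rit);
        INDArray features = null;
        try {
            features = loader.asMatrix(f); // GIVES ME A SHAPE OF 1, 3, 128, 128
        }
        catch (IOException ex) {
            continue;
        }
        features = features.permute(0, 2, 3, 1);
        DataNormalization scalar = new ImagePreProcessingScaler(0, 1);
        scalar.transform(features);
        INDArray[] prediction = model.output(features);
        // Get Class
        double[] pred = prediction[0].toDoubleVector();
        int predClass = 0;
        for (int i = 0; i < pred.length; i++) {
            predClass = pred[i] > pred[predClass] ? i : predClass;
        }
        if (predClass != 1) {
            ctr++;
            System.out.println(f.getName());
            System.out.println(ctr + ") PORN FOUND :: " + predClass);
        }
    }
}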
Python
DSC_3767.jpg
A B C
[[[[0.9254902 0.88627451 0.87843137]
[0.9254902 0.88627451 0.87843137]
[0.9254902 0.88627451 0.87843137]
...
Java
DSC_3767.jpg
C B A
[[[[ 0.8784, 0.8863, 0.9255],
[ 0.8784, 0.8863, 0.9255],
[ 0.8784, 0.8863, 0.9255],
...
All I need to do is swap C and A and the model works correctly. I just don't understand how to do that.
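A likely explanation for the swapped A and C: NativeImageLoader decodes images through OpenCV, which stores color channels in BGR order, while PIL (used in the Python test script and by flow_from_directory) returns RGB. Reversing the channel axis after permute(0, 2, 3, 1) should therefore make the Java input match what the model was trained on. The sketch below does this on a plain row-major [N, H, W, C] array; BgrToRgb and reverseChannels are hypothetical names, not DL4J API. In ND4J the same idea could be expressed by indexing the last axis in the order 2, 1, 0, and DataVec also has a ColorConversionTransform that, as far as I know, can apply OpenCV's COLOR_BGR2RGB conversion at load time. Verify either option against the docs for your version.

```java
import java.util.Arrays;

/**
 * Minimal sketch with plain arrays: reverse the channel order (BGR -> RGB)
 * of a row-major [N, H, W, C] buffer, i.e. the "swap C and A" fix.
 */
public class BgrToRgb {

    /** Returns a copy of an NHWC buffer with the last (channel) axis reversed. */
    static double[] reverseChannels(double[] nhwc, int channels) {
        if (nhwc.length % channels != 0) {
            throw new IllegalArgumentException("buffer length is not a multiple of channels");
        }
        double[] out = new double[nhwc.length];
        // Channels of one pixel are contiguous in NHWC, so walk pixel by pixel.
        for (int pixel = 0; pixel < nhwc.length; pixel += channels) {
            for (int c = 0; c < channels; c++) {
                out[pixel + c] = nhwc[pixel + channels - 1 - c];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Two pixels in the order the Java loader delivers them (B, G, R),
        // using the values from the Java dump above.
        double[] bgr = {0.8784, 0.8863, 0.9255, 0.8784, 0.8863, 0.9255};
        double[] rgb = reverseChannels(bgr, 3);
        System.out.println(Arrays.toString(rgb)); // [0.9255, 0.8863, 0.8784, 0.9255, 0.8863, 0.8784]
    }
}
```

The result matches the Python dump's A B C order, which is what the model saw during training.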