我在用 C++ 读取 MNIST 手写数字数据库（MNIST database of handwritten digits）时遇到了麻烦。
它是二进制格式，我知道如何读取二进制文件，但我不清楚 MNIST 的具体格式。
因此，我想请教读过 MNIST 数据的各位：MNIST 的数据格式是怎样的？对于如何在 C++ 中读取这些数据，你们有什么建议吗？
我最近做过一些与 MNIST 数据相关的工作。以下是我用 Java 编写的代码，应该很容易移植到 C++：
import net.vivin.digit.DigitImage;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * Loads the MNIST handwritten-digit data set (images plus labels) from the IDX files
 * described at http://yann.lecun.com/exdb/mnist/.
 *
 * <p>Both files are looked up as classpath resources with
 * {@link Class#getResourceAsStream(String)} and must be the uncompressed IDX files.
 * Multi-byte header fields are big-endian, which is the default byte order of
 * {@link ByteBuffer}.
 *
 * <p>Originally created with IntelliJ IDEA by vivin on 11/11/11.
 */
public class DigitImageLoadingService {

    private String labelFileName;
    private String imageFileName;

    /** the following constants are defined as per the values described at http://yann.lecun.com/exdb/mnist/ **/
    private static final int MAGIC_OFFSET = 0;
    private static final int OFFSET_SIZE = 4; //in bytes
    private static final int LABEL_MAGIC = 2049;
    private static final int IMAGE_MAGIC = 2051;
    private static final int NUMBER_ITEMS_OFFSET = 4;
    private static final int ITEMS_SIZE = 4;
    private static final int NUMBER_OF_ROWS_OFFSET = 8;
    private static final int ROWS_SIZE = 4;
    public static final int ROWS = 28;
    private static final int NUMBER_OF_COLUMNS_OFFSET = 12;
    private static final int COLUMNS_SIZE = 4;
    public static final int COLUMNS = 28;
    private static final int IMAGE_OFFSET = 16;
    private static final int IMAGE_SIZE = ROWS * COLUMNS;
    /** Byte offset of the first label in the label file (magic number followed by item count). */
    private static final int LABEL_OFFSET = OFFSET_SIZE + ITEMS_SIZE;

    /**
     * @param labelFileName classpath resource name of the IDX1 label file
     * @param imageFileName classpath resource name of the IDX3 image file
     */
    public DigitImageLoadingService(String labelFileName, String imageFileName) {
        this.labelFileName = labelFileName;
        this.imageFileName = imageFileName;
    }

    /**
     * Reads both files and pairs the i-th label with the i-th image.
     *
     * @return one {@code DigitImage} per item, in file order
     * @throws IOException if a resource is missing or unreadable, a magic number is wrong,
     *     the label and image counts differ, the images are not ROWS x COLUMNS, or a file
     *     is shorter than its header claims
     */
    public List<DigitImage> loadDigitImages() throws IOException {
        byte[] labelBytes = readResource(labelFileName);
        byte[] imageBytes = readResource(imageFileName);

        // Both headers must be complete before any field can be decoded.
        requireLength(labelBytes, LABEL_OFFSET, "label");
        requireLength(imageBytes, IMAGE_OFFSET, "image");

        if(readInt(labelBytes, MAGIC_OFFSET) != LABEL_MAGIC) {
            throw new IOException("Bad magic number in label file!");
        }
        if(readInt(imageBytes, MAGIC_OFFSET) != IMAGE_MAGIC) {
            throw new IOException("Bad magic number in image file!");
        }

        int numberOfLabels = readInt(labelBytes, NUMBER_ITEMS_OFFSET);
        int numberOfImages = readInt(imageBytes, NUMBER_ITEMS_OFFSET);
        if(numberOfImages != numberOfLabels) {
            throw new IOException("The number of labels and images do not match!");
        }
        if(numberOfLabels < 0) {
            throw new IOException("Invalid item count: " + numberOfLabels);
        }

        int numRows = readInt(imageBytes, NUMBER_OF_ROWS_OFFSET);
        int numCols = readInt(imageBytes, NUMBER_OF_COLUMNS_OFFSET);
        // Both dimensions must match; the previous check tested numRows twice and ignored numCols.
        if(numRows != ROWS || numCols != COLUMNS) {
            throw new IOException("Bad image. Rows and columns do not equal " + ROWS + "x" + COLUMNS);
        }

        // Reject truncated payloads up front (long arithmetic so a bogus count cannot overflow).
        requireLength(labelBytes, LABEL_OFFSET + (long) numberOfLabels, "label");
        requireLength(imageBytes, IMAGE_OFFSET + (long) numberOfImages * IMAGE_SIZE, "image");

        List<DigitImage> images = new ArrayList<DigitImage>(numberOfLabels);
        for(int i = 0; i < numberOfLabels; i++) {
            // Labels are unsigned bytes; mask so a value >= 128 is not sign-extended.
            int label = labelBytes[LABEL_OFFSET + i] & 0xFF;
            int start = IMAGE_OFFSET + i * IMAGE_SIZE;
            byte[] imageData = Arrays.copyOfRange(imageBytes, start, start + IMAGE_SIZE);
            images.add(new DigitImage(label, imageData));
        }
        return images;
    }

    /** Reads a classpath resource fully into memory and always closes the stream. */
    private byte[] readResource(String name) throws IOException {
        InputStream in = this.getClass().getResourceAsStream(name);
        if(in == null) {
            throw new IOException("Resource not found: " + name);
        }
        try {
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            byte[] buffer = new byte[16384];
            int read;
            while((read = in.read(buffer, 0, buffer.length)) != -1) {
                out.write(buffer, 0, read);
            }
            return out.toByteArray();
        } finally {
            in.close();
        }
    }

    /** Decodes the big-endian 32-bit integer that starts at {@code offset}. */
    private static int readInt(byte[] bytes, int offset) {
        return ByteBuffer.wrap(bytes, offset, OFFSET_SIZE).getInt();
    }

    /** Throws if {@code bytes} is shorter than {@code required}; {@code what} names the file in the message. */
    private static void requireLength(byte[] bytes, long required, String what) throws IOException {
        if(bytes.length < required) {
            throw new IOException("The " + what + " file is truncated: expected at least "
                    + required + " bytes but found " + bytes.length);
        }
    }
}
// Reverses the byte order of a 32-bit int. MNIST (IDX) header fields are stored
// big-endian, so on a little-endian host each field read from the file is passed
// through this function. Assumes int is 32 bits wide, as the callers do.
int reverseInt (int i)
{
    // Work on an unsigned copy: shifting a byte >= 0x80 into bit 31 of a signed int
    // (as the previous ((int)c1 << 24) did) is undefined behaviour before C++20.
    unsigned int u = static_cast<unsigned int>(i);
    unsigned int swapped = ((u & 0xFFu) << 24)
                         | (((u >> 8) & 0xFFu) << 16)
                         | (((u >> 16) & 0xFFu) << 8)
                         | ((u >> 24) & 0xFFu);
    return static_cast<int>(swapped);
}
// Demo reader: opens the MNIST t10k image file, reads and byte-swaps its big-endian
// header, then reads every pixel one byte at a time. The pixels are read but not
// stored; replace the marked line to keep them.
// The file must be the *decompressed* IDX file: the .gz archive is gzip data, so its
// first bytes are not the IDX magic number and it cannot be parsed directly.
// Returns early if the file cannot be opened, has the wrong magic, or is truncated.
void read_mnist(/*string full_path*/)
{
    // Binary mode is required; text mode may translate bytes (e.g. CR/LF on Windows).
    std::ifstream file (/*full_path*/"t10k-images-idx3-ubyte", std::ios::binary);
    if (!file.is_open())
        return;

    int magic_number=0;
    int number_of_images=0;
    int n_rows=0;
    int n_cols=0;

    file.read((char*)&magic_number,sizeof(magic_number));
    magic_number= reverseInt(magic_number);
    // 2051 identifies an IDX image file.
    if (!file || magic_number != 2051)
        return;

    file.read((char*)&number_of_images,sizeof(number_of_images));
    number_of_images= reverseInt(number_of_images);
    file.read((char*)&n_rows,sizeof(n_rows));
    n_rows= reverseInt(n_rows);
    file.read((char*)&n_cols,sizeof(n_cols));
    n_cols= reverseInt(n_cols);
    if (!file)
        return;

    for(int i=0;i<number_of_images;++i)
    {
        for(int r=0;r<n_rows;++r)
        {
            for(int c=0;c<n_cols;++c)
            {
                unsigned char temp=0;
                if (!file.read((char*)&temp,sizeof(temp)))
                    return; // truncated file
                // temp now holds pixel (r, c) of image i; store it here if needed.
            }
        }
    }
}
顺便一提，我改写了 @mrgloom 的代码：
// Reads an MNIST (IDX3) image file into a newly allocated array of images.
//
// full_path        path to the *uncompressed* image file
// number_of_images [out] image count taken from the header
// image_size       [out] bytes per image (rows * cols)
// returns          array of number_of_images pointers, each to image_size pixel bytes;
//                  the caller owns it and must delete[] every row, then the array.
// throws runtime_error if the file cannot be opened, has the wrong magic number,
//        has an invalid header, or is truncated (nothing is leaked in that case).
//
// NOTE(review): the return type names `uchar` before the local typedef below, so a
// namespace-scope `uchar` (e.g. OpenCV's) must already be visible — confirm.
uchar** read_mnist_images(string full_path, int& number_of_images, int& image_size) {
    // IDX headers are big-endian; swap to host order (assumes a little-endian host).
    // Uses unsigned arithmetic: shifting into the sign bit of an int is UB before C++20.
    auto reverseInt = [](int i) {
        unsigned int u = (unsigned int)i;
        u = ((u & 255u) << 24) | (((u >> 8) & 255u) << 16) | (((u >> 16) & 255u) << 8) | ((u >> 24) & 255u);
        return (int)u;
    };

    typedef unsigned char uchar;

    ifstream file(full_path, ios::binary);
    if(!file.is_open()) {
        throw runtime_error("Cannot open file `" + full_path + "`!");
    }

    int magic_number = 0, n_rows = 0, n_cols = 0;
    file.read((char *)&magic_number, sizeof(magic_number));
    magic_number = reverseInt(magic_number);
    if(!file || magic_number != 2051) throw runtime_error("Invalid MNIST image file!");

    file.read((char *)&number_of_images, sizeof(number_of_images)), number_of_images = reverseInt(number_of_images);
    file.read((char *)&n_rows, sizeof(n_rows)), n_rows = reverseInt(n_rows);
    file.read((char *)&n_cols, sizeof(n_cols)), n_cols = reverseInt(n_cols);
    // Reject a short header, negative sizes, and dimensions whose product overflows int.
    if(!file || number_of_images < 0 || n_rows <= 0 || n_cols <= 0 || n_cols > 2147483647 / n_rows) {
        throw runtime_error("Invalid MNIST image file!");
    }

    image_size = n_rows * n_cols;

    // Value-initialise the row pointers so cleanup can delete[] every slot safely.
    uchar** _dataset = new uchar*[number_of_images]();
    try {
        for(int i = 0; i < number_of_images; i++) {
            _dataset[i] = new uchar[image_size];
            if(!file.read((char *)_dataset[i], image_size)) {
                throw runtime_error("Truncated MNIST image file `" + full_path + "`!");
            }
        }
    } catch(...) {
        // Release everything allocated so far before propagating the error.
        for(int i = 0; i < number_of_images; i++) delete[] _dataset[i];
        delete[] _dataset;
        throw;
    }
    return _dataset;
}
// Reads an MNIST (IDX1) label file into a newly allocated byte array.
//
// full_path        path to the *uncompressed* label file
// number_of_labels [out] label count taken from the header
// returns          array of number_of_labels label bytes; the caller must delete[] it.
// throws runtime_error if the file cannot be opened, has the wrong magic number,
//        has an invalid header, or is truncated (nothing is leaked in that case).
//
// NOTE(review): the return type names `uchar` before the local typedef below, so a
// namespace-scope `uchar` (e.g. OpenCV's) must already be visible — confirm.
uchar* read_mnist_labels(string full_path, int& number_of_labels) {
    // IDX headers are big-endian; swap to host order (assumes a little-endian host).
    // Uses unsigned arithmetic: shifting into the sign bit of an int is UB before C++20.
    auto reverseInt = [](int i) {
        unsigned int u = (unsigned int)i;
        u = ((u & 255u) << 24) | (((u >> 8) & 255u) << 16) | (((u >> 16) & 255u) << 8) | ((u >> 24) & 255u);
        return (int)u;
    };

    typedef unsigned char uchar;

    ifstream file(full_path, ios::binary);
    if(!file.is_open()) {
        throw runtime_error("Unable to open file `" + full_path + "`!");
    }

    int magic_number = 0;
    file.read((char *)&magic_number, sizeof(magic_number));
    magic_number = reverseInt(magic_number);
    if(!file || magic_number != 2049) throw runtime_error("Invalid MNIST label file!");

    file.read((char *)&number_of_labels, sizeof(number_of_labels)), number_of_labels = reverseInt(number_of_labels);
    if(!file || number_of_labels < 0) throw runtime_error("Invalid MNIST label file!");

    uchar* _dataset = new uchar[number_of_labels];
    // Labels are stored back to back, one byte each, so a single read fetches them all.
    if(!file.read((char *)_dataset, number_of_labels)) {
        delete[] _dataset;
        throw runtime_error("Truncated MNIST label file `" + full_path + "`!");
    }
    return _dataset;
}
编辑：感谢 @JürgenBrauer 提醒我更正答案——我早已修复了自己的代码，却忘了更新这里的答案。
以下代码来自 Caffe，我做了一些修改，并把读出的图像转换为 cv::Mat：
// Returns val with its four bytes in reverse order, i.e. converts a 32-bit value
// between big-endian (the byte order of the IDX header fields) and little-endian.
uint32_t swap_endian(uint32_t val) {
    const uint32_t lowest = val & 0xFFu;
    const uint32_t lower_mid = (val >> 8) & 0xFFu;
    const uint32_t upper_mid = (val >> 16) & 0xFFu;
    const uint32_t highest = val >> 24;
    return (lowest << 24) | (lower_mid << 16) | (upper_mid << 8) | highest;
}
void read_mnist_cv(const char* image_filename, const char* label_filename){
// Open files
std::ifstream image_file(image_filename, std::ios::in | std::ios::binary);
std::ifstream label_file(label_filename, std::ios::in | std::ios::binary);
// Read the magic and the meta data
uint32_t magic;
uint32_t num_items;
uint32_t num_labels;
uint32_t rows;
uint32_t cols;
image_file.read(reinterpret_cast<char*>(&magic), 4);
magic = swap_endian(magic);
if(magic != 2051){
cout<<"Incorrect image file magic: "<<magic<<endl;
return;
}
label_file.read(reinterpret_cast<char*>(&magic), 4);
magic = swap_endian(magic);
if(magic != 2049){
cout<<"Incorrect image file magic: "<<magic<<endl;
return;
}
image_file.read(reinterpret_cast<char*>(&num_items), 4);
num_items = swap_endian(num_items);
label_file.read(reinterpret_cast<char*>(&num_labels), 4);
num_labels = swap_endian(num_labels);
if(num_items != num_labels){
cout<<"image file nums should equal to label num"<<endl;
return;
}
image_file.read(reinterpret_cast<char*>(&rows), 4);
rows = swap_endian(rows);
image_file.read(reinterpret_cast<char*>(&cols), 4);
cols = swap_endian(cols);
cout<<"image and label num is: "<<num_items<<endl;
cout<<"image rows: "<<rows<<", cols: "<<cols<<endl;
char label;
char* pixels = new char[rows * cols];
for (int item_id = 0; item_id < num_items; ++item_id) {
// read image pixel
image_file.read(pixels, rows * cols);
// read label
label_file.read(&label, 1);
string sLabel = std::to_string(int(label));
cout<<"lable is: "<<sLabel<<endl;
// convert it to cv Mat, and show it
cv::Mat image_tmp(rows,cols,CV_8UC1,pixels);
// resize bigger for showing
cv::resize(image_tmp, image_tmp, cv::Size(100, 100));
cv::imshow(sLabel, image_tmp);
cv::waitKey(0);
}
delete[] pixels;
}
用法（我简化了代码，省略了头文件和命名空间）：
// Example: display every image of the MNIST training set stored under Caffe's
// data/mnist directory (uncompressed IDX files).
const string base_dir("/home/xy/caffe-master/data/mnist/");
const string img_path(base_dir + "train-images-idx3-ubyte");
const string label_path(base_dir + "train-labels-idx1-ubyte");
read_mnist_cv(img_path.c_str(), label_path.c_str());
输出如下:
通过使用下面的 in() 函数，您可以按大端序读取任意字节数的数据（结果存放在 unsigned int 中，因此实际最多保留 4 个字节）。
// Capacity of the global arrays: 60007 items (presumably sized for the 60000-image
// training set plus slack — confirm).
const int MAXN = 60000 + 7;
// Header fields and item count, written by input().
unsigned int num, magic, rows, cols;
// Pixel storage: up to 30x30 pixels per image, one unsigned int per pixel.
unsigned int image[MAXN][30][30];
// One label per item.
unsigned int label[MAXN];
// Reads `size` bytes from icin and combines them big-endian (first byte most
// significant), e.g. in(f, 4) for an IDX header field or in(f, 1) for a single
// pixel or label byte. If size exceeds sizeof(unsigned int), only the last bytes fit.
// Bytes that cannot be read (EOF or stream error) contribute 0.
unsigned int in(ifstream& icin, unsigned int size) {
    unsigned int ans = 0;
    for (unsigned int i = 0; i < size; i++) {
        // Initialised so a failed read yields 0 instead of an indeterminate value.
        unsigned char x = 0;
        icin.read((char*)&x, 1);
        ans = (ans << 8) | x;
    }
    return ans;
}
void input() {
ifstream icin;
icin.open("train-images.idx3-ubyte", ios::binary);
magic = in(icin, 4), num = in(icin, 4), rows = in(icin, 4), cols = in(icin, 4);
for (int i = 0; i < num; i++) {
for (int x = 0; x < rows; x++) {
for (int y = 0; y < cols; y++) {
image[i][x][y] = in(icin, 1);
}
}
}
icin.close();
icin.open("train-labels.idx1-ubyte", ios::binary);
magic = in(icin, 4), num = in(icin, 4);
for (int i = 0; i < num; i++) {
label[i] = in(icin, 1);
}
}