mnist 和 fashion-mnist 时,我的准确率为 80–90%。但是,在使用 CIFAR-10 进行测试时,准确率下降到约 60%,而我预期应为 70–80%。我使用数据集的完整测试集进行评估。我已经仔细准备了数据,但没有发现任何问题。在联邦学习环境中使用 CIFAR-10 时,有什么需要特别注意的地方吗?
下载和加载数据
def get_data(args):
    """Load the train/test split for ``args.dataset`` into NumPy arrays.

    Supported values of ``args.dataset``:

    * ``'mnist'`` / ``'fashion-mnist'``: read from
      ``{args.data_path}/{args.dataset}.npz``. The archive must contain the
      arrays ``x_train``, ``y_train``, ``x_test`` and ``y_test``. Labels are
      cast to int64.
      - fashion-mnist images are reshaped to ``(-1, 1, 28, 28)``.
      - mnist images get a channel axis inserted at position 1. This assumes
        they are stored as (N, H, W); TODO confirm.
    * ``'cifar10'``: read through ``torchvision.datasets.CIFAR10`` from
      ``{args.data_path}/{args.dataset}/``.
      - ``download`` is not passed, so the files must already be on disk.
      - Images are transposed with axes ``[0, 3, 1, 2]`` (channels-last to
        channels-first).
      - No normalisation is done here; that happens later, in the loader.

    Args:
        args: object with ``dataset`` and ``data_path`` attributes.

    Returns:
        Tuple ``(train_X, train_y, test_X, test_y)``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the names above.
    """
    if args.dataset == 'mnist' or args.dataset == 'fashion-mnist':
        data_file = f"{args.data_path}/{args.dataset}.npz"
        # np.load on an .npz returns a lazily-reading NpzFile that holds the
        # file open; the context manager releases the handle once the four
        # arrays have been read into memory.
        with np.load(data_file) as dataset:  # load the data
            train_X = dataset['x_train']
            train_y = dataset['y_train'].astype(np.int64)
            test_X = dataset['x_test']
            test_y = dataset['y_test'].astype(np.int64)
        if args.dataset == 'fashion-mnist':
            train_X = np.reshape(train_X, (-1, 1, 28, 28))
            test_X = np.reshape(test_X, (-1, 1, 28, 28))
        else:
            train_X = np.expand_dims(train_X, 1)
            test_X = np.expand_dims(test_X, 1)
    elif args.dataset == 'cifar10':
        # Only load data, transformation done later
        trainset = torchvision.datasets.CIFAR10(root=f"{args.data_path}/{args.dataset}/",
                                                train=True)
        # download = True,
        train_X = trainset.data.transpose([0, 3, 1, 2])
        train_y = np.array(trainset.targets)
        testset = torchvision.datasets.CIFAR10(root=f"{args.data_path}/{args.dataset}/",
                                               train=False)
        test_X = testset.data.transpose([0, 3, 1, 2])
        test_y = np.array(testset.targets)
    else:
        raise ValueError("Unknown dataset")
    return train_X, train_y, test_X, test_y
小批量(mini-batch)数据加载器
def data_loader(dataset, inputs, targets, batch_size, is_train=True):
    """Yield mini-batches ``(x, y)`` from in-memory arrays.

    Images are cast to float32 and divided by 255 (pixel values to [0, 1]).
    For ``dataset == 'cifar10'`` they are then standardised with the
    module-level ``CIFAR10_TRAIN_MEAN`` / ``CIFAR10_TRAIN_STD`` constants.
    Other datasets get no further normalisation.

    Training mode (``is_train=True``) uses Poisson sampling:
      - There are ``max(1, n_examples // batch_size)`` rounds.
      - In each round, every example is included independently with
        probability ``batch_size / n_examples``.
      - Batch sizes therefore vary around ``batch_size``; empty draws are
        skipped.

    Evaluation mode (``is_train=False``) walks the data in order, in slices of
    ``batch_size`` plus one shorter slice for any remainder. Every example is
    yielded exactly once.

    Args:
        dataset: dataset name; only ``'cifar10'`` changes behaviour.
        inputs: image array indexed along axis 0.
        targets: label array with the same length as ``inputs``.
        batch_size: expected (train) / exact (eval) batch size.
        is_train: Poisson sampling if True, sequential slicing otherwise.

    Yields:
        ``(x, y)`` where ``x`` is a float32 array and ``y`` the matching labels.
        Nothing is yielded when ``inputs`` is empty.
    """
    def cifar10_norm(x):
        # In-place ops are safe: x is always a fresh float32 copy made by _prep.
        # NOTE(review): assumes the mean/std constants broadcast against
        # (batch, 3, 32, 32), e.g. shape (1, 3, 1, 1) -- confirm at definition.
        x -= CIFAR10_TRAIN_MEAN
        x /= CIFAR10_TRAIN_STD
        return x

    def no_norm(x):
        return x

    norm_func = cifar10_norm if dataset == 'cifar10' else no_norm

    def _prep(x):
        # Cast to float32, scale pixel values to [0, 1], then normalise.
        return norm_func(x.astype(np.float32) / 255.)

    assert inputs.shape[0] == targets.shape[0]
    n_examples = inputs.shape[0]
    if n_examples == 0:
        # Nothing to iterate; also avoids a ZeroDivisionError in sample_rate.
        return
    num_blocks = int(n_examples // batch_size)

    if is_train:
        sample_rate = batch_size / n_examples
        # A client holding fewer than batch_size examples would otherwise get
        # zero rounds and silently skip local training. Run at least one round;
        # sample_rate >= 1 then selects every example.
        for _ in range(max(1, num_blocks)):
            mask = np.random.rand(n_examples) < sample_rate
            if np.any(mask):
                yield _prep(inputs[mask]), targets[mask]
    else:
        for i in range(num_blocks):
            window = slice(i * batch_size, (i + 1) * batch_size)
            yield _prep(inputs[window]), targets[window]
        if num_blocks * batch_size != n_examples:
            tail = slice(num_blocks * batch_size, None)
            yield _prep(inputs[tail]), targets[tail]
您能否也分享一下训练循环和模型的细节?仅凭数据处理部分的代码,很难判断问题出在哪里。