LOWCIFAR -10精度(60%)在分散的联邦学习(DFL)中 - 寻求改进

问题描述 投票:0回答:1
i在分散的联合学习(DFL)环境中实施了算法。当我尝试使用

mnistfashion-mnist时,我的准确度为80-90%。但是,在使用cifar-10进行测试时,精度下降到约60%,而我预计为70-80%。 我正在使用数据集的整个测试集进行评估。 我已经仔细准备了数据,无法识别任何问题。在联合学习环境中与

cifar-10

一起工作时,我应该考虑任何具体的事情吗? 下载和加载数据

  • def get_data(args): if args.dataset == 'mnist' or args.dataset == 'fashion-mnist': data_file = f"{args.data_path}/{args.dataset}.npz" dataset = np.load(data_file) #데이터 불러오기 train_X, train_y = dataset['x_train'], dataset['y_train'].astype(np.int64) test_X, test_y = dataset['x_test'], dataset['y_test'].astype(np.int64) if args.dataset == 'fashion-mnist': train_X = np.reshape(train_X, (-1, 1, 28, 28)) test_X = np.reshape(test_X, (-1, 1, 28, 28)) else: train_X = np.expand_dims(train_X, 1) test_X = np.expand_dims(test_X, 1) elif args.dataset == 'cifar10': # Only load data, transformation done later trainset = torchvision.datasets.CIFAR10(root=f"{args.data_path}/{args.dataset}/", train=True) # download = True, train_X = trainset.data.transpose([0, 3, 1, 2]) train_y = np.array(trainset.targets) testset = torchvision.datasets.CIFAR10(root=f"{args.data_path}/{args.dataset}/", train=False) test_X = testset.data.transpose([0, 3, 1, 2]) test_y = np.array(testset.targets) else: raise ValueError("Unknown dataset") return train_X, train_y, test_X, test_y
Mini批处理数据加载器
    def data_loader(dataset, inputs, targets, batch_size, is_train=True): def cifar10_norm(x): x -= CIFAR10_TRAIN_MEAN x /= CIFAR10_TRAIN_STD return x def no_norm(x): return x if dataset == 'cifar10': norm_func = cifar10_norm else: norm_func = no_norm assert inputs.shape[0] == targets.shape[0] n_examples = inputs.shape[0] sample_rate = batch_size / n_examples num_blocks = int(n_examples / batch_size) if is_train: for i in range(num_blocks): mask = np.random.rand(n_examples) < sample_rate if np.sum(mask) != 0: yield (norm_func(inputs[mask].astype(np.float32) / 255.), targets[mask]) # 픽셀값을 0 ~ 1로 정규화 else: for i in range(num_blocks): yield (norm_func(inputs[i * batch_size: (i+1) * batch_size].astype(np.float32) / 255.), targets[i * batch_size: (i+1) * batch_size]) if num_blocks * batch_size != n_examples: yield (norm_func(inputs[num_blocks * batch_size:].astype(np.float32) / 255.), targets[num_blocks * batch_size:])
您还可以共享培训循环和模型详细信息吗?纯粹基于数据的很难分辨出来。
    
python distributed federated
1个回答
0
投票
最新问题
© www.soinside.com 2019 - 2025. All rights reserved.