Federated learning implementation raises RuntimeError: all elements of input should be between 0 and 1


`
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

# Define the deep neural network model
class DNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(DNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.sigmoid(out)
        return out

# Load the breast cancer dataset
data = load_breast_cancer()
X = data.data
y = data.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Define the number of training rounds and the number of clients
num_rounds = 100
num_clients = 2
batch_size = 10

# Split the data into equal chunks for each client
X_splits = np.array_split(X_train, num_clients)
y_splits = np.array_split(y_train, num_clients)

# Define the loss function and optimizer
criterion = nn.BCELoss()

# Perform federated learning
global_model = DNN(X_train.shape[1], 16, 1)
optimizer = optim.SGD(global_model.parameters(), lr=.01)

for i in range(num_rounds):
    local_models = []
    for j in range(num_clients):
        # Create a local model by copying the current global model
        local_model = DNN(X_train.shape[1], 16, 1)
        local_model.load_state_dict(global_model.state_dict())

        # Create a dataloader for the local client's data
        local_X = torch.tensor(X_splits[j], dtype=torch.float32)
        local_y = torch.tensor(y_splits[j], dtype=torch.float32)
        local_dataset = torch.utils.data.TensorDataset(local_X, local_y)
        local_dataloader = DataLoader(local_dataset, batch_size=batch_size, shuffle=True)

        # Train the local model
        local_optimizer = optim.SGD(local_model.parameters(), lr=0.1)
        for inputs, labels in local_dataloader:
            local_optimizer.zero_grad()
            outputs = local_model(inputs)
            loss = criterion(outputs, labels.view(-1, 1))
            loss.backward()
            local_optimizer.step()

        # Add the trained local model to the list of local models
        local_models.append(local_model)

    # Aggregate the local models to create a global model
    with torch.no_grad():
        for global_param, local_params in zip(global_model.parameters(), zip(*[local_model.parameters() for local_model in local_models])):
            global_param.data += torch.stack(local_params).sum(0) / num_clients

    # Evaluate the global model on the train dataset
    global_model.eval()
    with torch.no_grad():
        global_outputs = global_model(torch.tensor(X_train, dtype=torch.float32))
        global_loss = criterion(global_outputs, torch.tensor(y_train, dtype=torch.float32).view(-1, 1))
        global_pred = (global_outputs > 0.5).int().numpy().flatten()
        accuracy = np.mean(global_pred == y_train)
        print(f"Round {i}, train accuracy:{accuracy}")

`

The code runs without problems as long as num_rounds is 96 or less, but when num_rounds is 97 or greater it raises the following error:

` RuntimeError                              Traceback (most recent call last)

()
     79     with torch.no_grad():
     80         global_outputs = global_model(torch.tensor(X_train, dtype=torch.float32))
---> 81         global_loss = criterion(global_outputs, torch.tensor(y_train, dtype=torch.float32).view(-1, 1))
     82         global_pred = (global_outputs > 0.5).int().numpy().flatten()
     83         accuracy = np.mean(global_pred == y_train)

2 frames

/usr/local/lib/python3.9/dist-packages/torch/nn/functional.py in binary_cross_entropy(input, target, weight, size_average, reduce, reduction)
   3093
   3094
-> 3095     return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)
   3096
   3097

RuntimeError: all elements of input should be between 0 and 1 `

deep-learning pytorch runtime-error pytorch-lightning federated-learning
1 Answer

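It seems that your dataloader does not always return labels that fall within the required range. It is reasonably safe to assume that the output of the sigmoid activation does fall within this range, although you can of course double-check that as well. I suggest verifying both with assertions: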

for inputs, labels in local_dataloader:
    ...
    # Labels must lie in [0, 1] for BCELoss
    assert labels.max() <= 1 and labels.min() >= 0, "Labels violate assumed range"
    # Model outputs must lie in [0, 1] as well; a NaN output also fails this check
    assert outputs.max() <= 1 and outputs.min() >= 0, "Outputs violate assumed range"
    loss = criterion(outputs, labels.view(-1, 1))
    ...
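As a complement to the assertions above, here is a minimal sketch that is not part of the original answer. In the traceback, `binary_cross_entropy(input, target, ...)` takes the model output as its `input`, so the message concerns the predictions, and NaN or inf values would also fail a range check. One candidate cause worth ruling out, stated here as an assumption rather than something the question confirms: the aggregation step in the question uses `+=`, which adds the clients' mean to the existing global parameters instead of replacing them, so the weights can grow from round to round until the outputs are no longer finite. The helper names `assert_unit_interval` and `average_into` are hypothetical.

```python
import torch
import torch.nn as nn


def assert_unit_interval(t: torch.Tensor, name: str) -> None:
    """Raise ValueError if `t` holds NaN/inf or any value outside [0, 1]."""
    if not torch.isfinite(t).all():
        raise ValueError(f"{name} contains NaN or inf")
    if t.min() < 0 or t.max() > 1:
        raise ValueError(f"{name} has values outside [0, 1]")


def average_into(global_model: nn.Module, local_models: list) -> None:
    """Overwrite each global parameter with the mean of the clients' copies.

    Note the use of copy_ (replace) rather than += (accumulate).
    """
    with torch.no_grad():
        for global_param, *local_params in zip(
            global_model.parameters(),
            *(m.parameters() for m in local_models),
        ):
            global_param.copy_(torch.stack(local_params).mean(dim=0))


# Hypothetical demo: two clients that are exact copies of the global model,
# so averaging should leave the global weights unchanged.
torch.manual_seed(0)
global_model = nn.Linear(3, 1)
local_models = [nn.Linear(3, 1) for _ in range(2)]
for m in local_models:
    m.load_state_dict(global_model.state_dict())

before = global_model.weight.detach().clone()
average_into(global_model, local_models)
```

In the question's loop, a call such as `assert_unit_interval(global_outputs, "global_outputs")` placed just before the loss computation would show whether the outputs have become non-finite in the failing round. Whether the replacing average removes the error in that setup is untested here.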