我用Pytorch和Ray调超参数但是代码运行出错

问题描述 投票:0回答:0

我想编写关于调整卷积神经网络超参数的python代码。 我根据这个页面写了一个代码(https://docs.ray.io/en/latest/tune/examples/includes/pbt_convnet_function_example.html) 但是,我认为我写的有错误。 我想安装它们,但我不知道我必须在哪里修理。 请帮助我...

我想调整很多超参数,比如配置。

我怎样才能解决这个问题,有没有更简单的方法来写这个?

import argparse
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

from ray import air, tune
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.tune.schedulers import PopulationBasedTraining

class Net(nn.Module):
    def __init__(self, conv1_filter, conv1_kernel, conv2_filter, conv2_kernel, dense1_units):
        super().__init__()
        self.conv1 = nn.Conv2d(10, conv1_filter, kernel_size=conv1_kernel)
        self.conv2 = nn.Conv2d(conv1_filter, conv2_filter, kernel_size=conv2_kernel)
        self.pool = nn.MaxPool2d(2)
        self.act = nn.Tanh()
        self.lin = nn.Linear(conv2_filter * 7 * 7 , dense1_units)
        self.result = nn.Linear(dense1_units, 10)
        
    def forward(self, x):
        out = self.pool(self.act(self.conv1(x)))
        out = self.pool(self.act(self.conv2(out)))
        out = self.act(self.lin(out.view(-1, conv2_filter * 7 * 7)))
        out = self.result(out)
        return out

def training(config):
    step = 0
    train_loader = torch.utils.data.DataLoader(datasets.MNIST('data', train=True, transform=transforms.ToTensor()), batch_size=config["batch_size"], shuffle=True)
    test_loader = torch.utils.data.DataLoader(datasets.MNIST('data', train=False, transform=transforms.ToTensor()), batch_size=config["batch_size"], shuffle=True)
    model = Net(config["conv1_filter"], config["conv1_kernel"], config["conv2_filter"], config["conv2_kernel"], config["dense1_units"])
    optimizer = optim.SGD(
        model.parameters(),
        lr = config.get("learning_rate"),
    )
    
    if session.get_checkpoint():
        print("Loading from checkpoint.")
        loaded_checkpoint = session.get_checkpoint()
        with loaded_checkpoint.as_directory() as loaded_checkpoint_dir:
            path = os.path.join(loaded_checkpoint_dir, "checkpoint.pt")
            checkpoint = torch.load(path)
            model.load_state_dict(checkpoint["model"])
            step = checkpoint["step"]
    
    while True:
        train(model, optimizer, train_loader)
        acc = test(model, test_loader)
        checkpoint = None
        if step % 5 == 0:
            os.makedirs("my_model", exist_ok = True)
            torch.save(
                {
                    "step" : step,
                    "model" : model.state_dict()
                },
                "my_model/checkpoint.pt"
            )
            checkpoint = Checkpoint.from_directory("my_model")
        
        step += 1
        session.report({"mean_accuracy" : acc}, checkpoint = checkpoint)

def test_best_model(results: tune.ResultGrid):
    with results.get_best_result().checkpoint.as_directory() as best_checkpoint_path:
        best_model = Net()
        best_checkpoint = torch.load(
            os.path.join(best_checkpoint_path, "checkpoint.pt")
        )
        best_model.load_state_dict(best_checkpoint["model"])
        test_acc = test(best_model, get_data_loaders()[1])
        print("best model accuracy: ", test_acc)

if __name__ == "__main__":
    
    #하이퍼파라미터 생성
    config = {
        "batch_size" : tune.choice([2, 4, 8, 16]),
        "conv1_filter" : tune.randint(1, 10),
        "conv1_kernel" : tune.randint(3, 5),
        "conv2_filter" : tune.randint(1, 10),
        "conv2_kernel" : tune.randint(3, 5),
        "dense1_units" : tune.qlograndint(5, 8, 2),
        "learning_rate" : tune.loguniform(1e-5, 1e-1)
    }
    
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    args, _ = parser.parse_known_args()

    # __pbt_begin__
    scheduler = PopulationBasedTraining(
        time_attr="training_iteration",
        perturbation_interval=5,
        hyperparam_mutations = config
    )

    # __pbt_end__
    
    # __tune_begin__
    class CustomStopper(tune.Stopper):
        def __init__(self):
            self.should_stop = False

        def __call__(self, trial_id, result):
            max_iter = 5 if args.smoke_test else 100
            if not self.should_stop and result["mean_accuracy"] > 0.96:
                self.should_stop = True
            return self.should_stop or result["training_iteration"] >= max_iter

        def stop_all(self):
            return self.should_stop

    stopper = CustomStopper()

    tuner = tune.Tuner(
        training,
        run_config=air.RunConfig(
            name="pbt_test",
            stop=stopper,
            verbose=1,
            checkpoint_config=air.CheckpointConfig(
                checkpoint_score_attribute="mean_accuracy",
                num_to_keep=4,
            ),
        ),
        tune_config=tune.TuneConfig(
            scheduler=scheduler,
            metric="mean_accuracy",
            mode="max",
            num_samples=4,
        ),
        param_space=config
    )
    results = tuner.fit()
    # __tune_end__

    test_best_model(results)
deep-learning pytorch artificial-intelligence hyperparameters ray
© www.soinside.com 2019 - 2024. All rights reserved.