foreach_worker 和 foreach_env 的正确使用方法

问题描述 投票:0回答:1

我对强化学习很陌生,无法理解它。我无法使用 PPO 更新批量数据的配置。

我正在使用自定义的 GYM 环境,并希望使用 PPO 和我以 torch DataLoader 形式加载的外部数据来训练它。 我正在使用 Python 3.11 和 Ray 2.40.0。以下是相关代码:

import ray
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.registry import register_env
from torch.utils.data import DataLoader

train_dataset = MultimodalDataset(
    csv_file=config.TRAIN_CSV_PATH, max_images=config.MAX_IMAGES_RL
)
train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True)

# Define PPO configuration
ppo_config = (
    PPOConfig()
    .training(gamma=0.9, lr=0.01)
    .environment(env="MultimodalSummarizationEnv", env_config=default_env_config)
    .framework("torch")
    .resources(num_gpus=0, num_cpus_per_worker=1)
)

# Create PPO trainer
trainer = ppo_config.build()

# Function to update worker environments
def update_env_config_and_reset(worker, new_env_config):
    worker.foreach_env(lambda env: env.reset(env_config=new_env_config))

# Training loop
for batch_idx, batch in enumerate(train_loader):
    # Prepare batch-specific env_config
    new_env_config = {
# new data for the batch_idx
    }

    # Update and reset environments for all workers
    trainer.workers.foreach_worker(
        lambda worker: update_env_config_and_reset(worker, new_env_config)
    )

    # Train PPO
    result = trainer.train()
ray.shutdown()

但是,在运行代码时,我在 foreach_worker 上收到如下错误:

‘function’对象没有属性‘foreach_worker’

请帮我找出哪里错了。

编辑:这是 MWE。

import ray
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.registry import register_env
from torch.utils.data import DataLoader

train_dataset = MultimodalDataset(
    csv_file=config.TRAIN_CSV_PATH, max_images=config.MAX_IMAGES_RL
)
train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True)

# Define PPO configuration
ppo_config = (
    PPOConfig()
    .training(gamma=0.9, lr=0.01)
    .environment(env="MultimodalSummarizationEnv", env_config=default_env_config)
    .framework("torch")
    .resources(num_gpus=0, num_cpus_per_worker=1)
)

# Create PPO trainer
trainer = ppo_config.build()

# Function to update worker environments
def update_env_config_and_reset(worker, new_env_config):
    worker.foreach_env(lambda env: env.reset(env_config=new_env_config))

# Training loop
for batch_idx, batch in enumerate(train_loader):
    # Prepare batch-specific env_config
    new_env_config = {
# new data for the batch_idx
    }

    # Update and reset environments for all workers
    trainer.workers.foreach_worker(
        lambda worker: update_env_config_and_reset(worker, new_env_config)
    )

    # Train PPO
    result = trainer.train()
ray.shutdown()
pytorch reinforcement-learning ray rllib
1个回答
0
投票

config.build()
创建一个
Algorithm
。使用
ray.tune.Tuner
ray.train.Trainer
创建工人。

foreach_worker
仅在您生成多个工人时才有效。 对于非分布式培训,您可以使用
algo.env_runner
访问您的本地工作人员。

在您的情况下,您只有一名这样的工作人员修改您的代码,如下所示:

algo = ppo_config.build()
for batch_idx, batch in enumerate(train_loader):
    # Prepare batch-specific env_config
    new_env_config = {
         # new data for the batch_idx
    }
    algo.env_runner.env.reset(env_config=new_env_config)

或者对于分布式训练(我认为也适用于本地训练)使用

algo.env_runner_group.foreach_env(lambda env: env.reset(env_config=new_env_config))

© www.soinside.com 2019 - 2024. All rights reserved.