我对强化学习很陌生,无法理解它。我无法使用 PPO 更新批量数据的配置。
我正在使用自定义的 GYM 环境,并希望使用 PPO 和我以 torch DataLoader 形式加载的外部数据来训练它。 我正在使用 Python 3.11 和 Ray 2.40.0。以下是相关代码:
import ray
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.registry import register_env
from torch.utils.data import DataLoader
train_dataset = MultimodalDataset(
csv_file=config.TRAIN_CSV_PATH, max_images=config.MAX_IMAGES_RL
)
train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True)
# Define PPO configuration
ppo_config = (
PPOConfig()
.training(gamma=0.9, lr=0.01)
.environment(env="MultimodalSummarizationEnv", env_config=default_env_config)
.framework("torch")
.resources(num_gpus=0, num_cpus_per_worker=1)
)
# Create PPO trainer
trainer = ppo_config.build()
# Function to update worker environments
def update_env_config_and_reset(worker, new_env_config):
worker.foreach_env(lambda env: env.reset(env_config=new_env_config))
# Training loop
for batch_idx, batch in enumerate(train_loader):
# Prepare batch-specific env_config
new_env_config = {
# new data for the batch_idx
}
# Update and reset environments for all workers
trainer.workers.foreach_worker(
lambda worker: update_env_config_and_reset(worker, new_env_config)
)
# Train PPO
result = trainer.train()
ray.shutdown()
但是,在运行代码时,我在 foreach_worker 上收到如下错误:
‘function’对象没有属性‘foreach_worker’
请帮我找出哪里错了。
编辑:这是 MWE。
import ray
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.registry import register_env
from torch.utils.data import DataLoader
train_dataset = MultimodalDataset(
csv_file=config.TRAIN_CSV_PATH, max_images=config.MAX_IMAGES_RL
)
train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True)
# Define PPO configuration
ppo_config = (
PPOConfig()
.training(gamma=0.9, lr=0.01)
.environment(env="MultimodalSummarizationEnv", env_config=default_env_config)
.framework("torch")
.resources(num_gpus=0, num_cpus_per_worker=1)
)
# Create PPO trainer
trainer = ppo_config.build()
# Function to update worker environments
def update_env_config_and_reset(worker, new_env_config):
worker.foreach_env(lambda env: env.reset(env_config=new_env_config))
# Training loop
for batch_idx, batch in enumerate(train_loader):
# Prepare batch-specific env_config
new_env_config = {
# new data for the batch_idx
}
# Update and reset environments for all workers
trainer.workers.foreach_worker(
lambda worker: update_env_config_and_reset(worker, new_env_config)
)
# Train PPO
result = trainer.train()
ray.shutdown()
config.build()
创建一个 Algorithm
。使用 ray.tune.Tuner
或 ray.train.Trainer
创建工人。
foreach_worker
仅在您生成多个工人时才有效。
对于非分布式培训,您可以使用 algo.env_runner
访问您的本地工作人员。
在您的情况下,您只有一名这样的工作人员修改您的代码,如下所示:
algo = ppo_config.build()
for batch_idx, batch in enumerate(train_loader):
# Prepare batch-specific env_config
new_env_config = {
# new data for the batch_idx
}
algo.env_runner.env.reset(env_config=new_env_config)
或者对于分布式训练(我认为也适用于本地训练)使用
algo.env_runner_group.foreach_env(lambda env: env.reset(env_config=new_env_config))
。