I am trying to create a small working example of how to use a MultiDiscrete action space together with a Box observation space. One of the problems I have run into is that the dimensions returned by a standard policy do not match the Box dimensions: the base policy returns something of size 25, whereas I need something of shape (5, 5).
I tried to mitigate this by writing a custom "policy" (really just a network) whose final step reshapes the output to (5, 5) instead of 25. This led to a whole series of problems. I tried reading the documentation on how to create a custom policy; however, for the life of me I cannot figure out where the problem lies. Concretely, I have tried:
1. Using policy_kwargs; however, I do not know how to express the reshaping of the network there.
2. Using BaseFeaturesExtractor, also without success.
3. Various combinations of 1 and 2.
I have included some of the error messages I received from these different attempts in the code below. Does anyone know what I am missing? Am I misunderstanding something completely fundamental?
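For reference, my current reading of the stable_baselines3 docs is that a features extractor subclasses BaseFeaturesExtractor, takes a flat integer features_dim (not a shape like (5, 5)), and returns a flat batch of features. A minimal sketch of that understanding, where the name CustomExtractor and features_dim=64 are just my placeholders:

import numpy as np
import gym
import torch as th
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

class CustomExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 64):
        # features_dim must be a flat int: it is the length of the feature vector
        # handed to the policy/value heads, not the observation shape
        super().__init__(observation_space, features_dim)
        n_input = int(np.prod(observation_space.shape))
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n_input, features_dim),
            nn.ReLU(),
        )

    def forward(self, observations: th.Tensor) -> th.Tensor:
        # observations arrive batched, e.g. (batch, 5, 5); output is (batch, features_dim)
        return self.net(observations)

My full attempt follows.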
import numpy as np
import gym
import torch.nn as nn
import torch as th
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor # don't know if this is necessary
# -------- Attempt using BaseFeaturesExtractor
# class CustomPolicy(BaseFeaturesExtractor):  # Don't know if BaseFeaturesExtractor is correct
#     def __init__(self, observation_space, action_space, features_dim: int = 25):  # Features should perhaps be (5, 5)
#         super().__init__(observation_space, features_dim)
# --------
# Define a custom neural network architecture
class CustomPolicy(nn.Module):  # must subclass nn.Module for super().__init__() and forward() to make sense
    def __init__(self, observation_space, action_space):
        super().__init__()
        # Define the layers of the neural network
        self.fc1 = nn.Linear(observation_space.shape[0], 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_space.shape[0])

    def forward(self, x):
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        # Reshape the output to match the Box observation space shape
        x = th.reshape(x, (5, 5))
        return x
# Define the grid world environment
class GridWorldEnv(gym.Env):
    def __init__(self):
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(5, 5), dtype=np.float32)
        self.action_space = gym.spaces.MultiDiscrete([5, 3])  # 5 movement directions, 3 movement distances
        self.state = np.zeros((5, 5))
        self.state[0, 0] = 1  # Start location
        self.goal = (4, 4)  # Goal location
        self.steps = 0
        self.state.flatten()  # NOTE: no-op; flatten() returns a copy that is discarded here

    def reset(self):
        self.state = np.zeros((5, 5))
        self.state[0, 0] = 1  # Start location
        self.goal = (4, 4)  # Goal location
        self.steps = 0
        return self.state.flatten()  # shape (25,), although the Box above is declared as (5, 5)
    def step(self, action):
        direction, distance = action
        reward = -1
        done = False
        # Calculate the movement offset based on the selected direction and distance
        if direction == 0:
            offset = (distance, 0)
        elif direction == 1:
            offset = (-distance, 0)
        elif direction == 2:
            offset = (0, distance)
        elif direction == 3:
            offset = (0, -distance)
        else:
            offset = (0, 0)
        # Calculate the new position based on the current position and movement offset
        current_pos = np.argwhere(self.state == 1)[0]
        new_pos = tuple(np.clip(current_pos + np.array(offset), 0, 4))
        # Update the state with the new position
        self.state[tuple(current_pos)] = 0  # index with a tuple; indexing with the raw array would zero out whole rows
        self.state[new_pos] = 1
        # Check if the agent has reached the goal
        if np.argmax(self.state) == np.ravel_multi_index(self.goal, self.state.shape):
            reward = 10
            done = True
        # Increment step count and check if episode should end
        self.steps += 1
        if self.steps >= 50:
            done = True
        return self.state, reward, done, {}  # shape (5, 5) here, while reset() returns shape (25,)
# Press the green button in the gutter to run the script.
if __name__ == '__main__':
    # Create an instance of the GridWorldEnv environment
    env = GridWorldEnv()
    # Create policy
    policy = CustomPolicy(env.observation_space, env.action_space)
    # Create a PPO agent with the CustomPolicy
    model = PPO(policy=policy, env=env, verbose=1)
    # --------- TypeError: 'CustomPolicy' object is not callable
    # --------- Attempt at using policy_kwargs
    # policy_kwargs = dict(activation_fn=th.nn.ReLU,
    #                      net_arch=dict(pi=[32, 32], vf=[32, 32]))
    # model = PPO("MlpPolicy", env=env, verbose=1, policy_kwargs=policy_kwargs)
    # --------- ValueError: could not broadcast input array from shape (25,) into shape (5,5)
    # --------- Attempt at using policy_kwargs with custom policy
    # policy_kwargs = dict(
    #     features_extractor_class=CustomPolicy,
    #     features_extractor_kwargs=dict(features_dim=25),  # should perhaps be (5, 5)
    # )
    # model = PPO(policy=policy, env=env, verbose=1, policy_kwargs=policy_kwargs)
    # --------- TypeError: CustomPolicy.forward() got an unexpected keyword argument 'use_sde'
    # Train the agent for 1000 steps
    model.learn(total_timesteps=1000)
Why not do the necessary reshaping inside the env's step function instead? For example:

def step(self, action):
    action = action.reshape(5, 5)
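That said, in the code above the ValueError: could not broadcast input array from shape (25,) into shape (5,5) actually comes from the observation side: reset() returns self.state.flatten() (shape (25,)) while the Box is declared with shape (5, 5). If the environment simply returns observations matching the declared shape, the stock MlpPolicy works as-is, because its default features extractor flattens Box observations internally, and PPO supports MultiDiscrete actions out of the box. A minimal sketch under those assumptions (using the old gym 4-tuple step API, as in the question):

import numpy as np
import gym
from stable_baselines3 import PPO

class GridWorldEnv(gym.Env):
    def __init__(self):
        # The observation shape (5, 5) is declared once and returned consistently below
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(5, 5), dtype=np.float32)
        self.action_space = gym.spaces.MultiDiscrete([5, 3])  # 5 directions, 3 distances
        self.reset()

    def reset(self):
        self.state = np.zeros((5, 5), dtype=np.float32)
        self.state[0, 0] = 1  # start location
        self.goal = (4, 4)
        self.steps = 0
        return self.state  # shape (5, 5), matching the declared Box

    def step(self, action):
        direction, distance = action
        offsets = {0: (distance, 0), 1: (-distance, 0), 2: (0, distance), 3: (0, -distance), 4: (0, 0)}
        current_pos = tuple(np.argwhere(self.state == 1)[0])
        new_pos = tuple(np.clip(np.array(current_pos) + np.array(offsets[direction]), 0, 4))
        self.state[current_pos] = 0
        self.state[new_pos] = 1
        self.steps += 1
        done = new_pos == self.goal or self.steps >= 50
        reward = 10 if new_pos == self.goal else -1
        return self.state, reward, done, {}  # again shape (5, 5)

if __name__ == '__main__':
    model = PPO("MlpPolicy", GridWorldEnv(), verbose=1)  # default MlpPolicy flattens the Box internally
    model.learn(total_timesteps=1000)

No custom policy, features extractor, or reshape is needed for this: the policy/value networks see the flattened 25-dimensional vector, and the MultiDiscrete head produces the (direction, distance) pair directly.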