我正在尝试从此处运行代码(此页面上的 Github 链接):https://keras.io/examples/rl/ppo_cartpole/
我在
observation = observation.reshape(1,-1)
的训练部分得到一个属性错误,它说“‘tuple’对象没有属性‘reshape’”。
看起来
observation
目前是env.reset()
这是一个数组(初始观察)和一个空字典(重置状态)的元组。我试过使用 observation[0].reshape(1,-1)
或 env.reset[0]
将其仅应用于数组,但是两行后会抛出“太多值无法解压(预期 4)”错误。有谁知道如何在不弄乱其余代码的情况下解决这个问题?
根据要求提供最小的可重现示例
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import gym
import scipy.signal
env = gym.make("CartPole-v0")
observation_dimensions = env.observation_space.shape[0]
num_actions = env.action_space.n
observation_input = keras.Input(shape=(observation_dimensions,), dtype=tf.float32)
logits = mlp(observation_input, list(hidden_sizes) + [num_actions], tf.tanh, None)
actor = keras.Model(inputs=observation_input, outputs=logits)
observation, episode_return, episode_length = env.reset(), 0, 0
observation = observation.reshape(1, -1)
logits, action = sample_action(observation)
observation_new, reward, done, _ = env.step(action[0].numpy())
episode_return += reward
episode_length += 1
在哪里
def mlp(x, sizes, activation=tf.tanh, output_activation=None):
# Build a feedforward neural network
for size in sizes[:-1]:
x = layers.Dense(units=size, activation=activation)(x)
return layers.Dense(units=sizes[-1], activation=output_activation)(x)
和
@tf.function
def sample_action(observation):
logits = actor(observation)
action = tf.squeeze(tf.random.categorical(logits, 1), axis=1)
return logits, action