Stable Baselines3: reason for fixing the seed when loading a model


I am working on reinforcement learning, using a PPO model (Stable Baselines3) to predict intraday VWAP.

My biggest problem is the reproducibility of the model.

When I load the trained model and test it in my test environment, the results are different on every run.

This happens even though my environment has no randomness at all (I fix the seeds of the reset method, the observation space, and the action space; see the sketch below).
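To be concrete, here is a minimal sketch (a toy Gymnasium env, not my actual ExpectVolumeEnv) of what I mean by fixing the seeds of the reset method and the spaces:

import numpy as np
import gymnasium as gym
from gymnasium import spaces

class ToySeededEnv(gym.Env):
    """Toy env whose only randomness is the RNG seeded through reset()."""
    def __init__(self):
        self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
        self.action_space = spaces.Discrete(2)

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)  # sets self.np_random from the given seed
        obs = self.np_random.uniform(-1.0, 1.0, size=3).astype(np.float32)
        return obs, {}

    def step(self, action):
        obs = self.np_random.uniform(-1.0, 1.0, size=3).astype(np.float32)
        return obs, 0.0, False, False, {}

env = ToySeededEnv()
env.reset(seed=42)               # seeds the env's own RNG
env.action_space.seed(42)        # seeds action_space.sample()
env.observation_space.seed(42)   # seeds observation_space.sample()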

After a lot of wasted time, I found two workarounds.

The first is to set a seed when loading the model:

# Load Model
model.load("./logs/ppo_vwap_predict_20240919_20240111.zip", env=env, seed=random_seed)

The second is to set a seed for every module at the beginning of the script:

random_seed = 42
# Set seed for reproducibility
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)  # if using multi-GPU
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(random_seed)
random.seed(random_seed)

I can understand the second workaround, but not the first. Why is a random seed needed when loading the model?

My suspicion is that the exploration and action selection inside the PPO model (using np.choice) is the reason.

I would appreciate any ideas that could confirm or refute this suspicion.
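To illustrate where I think the randomness enters, here is a minimal, self-contained sketch (using CartPole-v1 and an untrained PPO instead of my VWAP setup; set_random_seed is SB3's helper that seeds Python, NumPy and PyTorch in one call):

import numpy as np
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.utils import set_random_seed

env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env, verbose=0)   # untrained weights are enough for this demo

obs, _ = env.reset(seed=0)

# deterministic=False (the default): the action is *sampled* from the policy
# distribution, so repeated calls on the same observation can differ.
a1, _ = model.predict(obs, deterministic=False)
a2, _ = model.predict(obs, deterministic=False)

# deterministic=True: the mode of the distribution is returned, identical every call.
d1, _ = model.predict(obs, deterministic=True)
d2, _ = model.predict(obs, deterministic=True)
assert np.array_equal(d1, d2)

# Re-seeding the global RNGs right before sampling makes the sampled
# actions repeatable as well.
set_random_seed(42)
s1, _ = model.predict(obs, deterministic=False)
set_random_seed(42)
s2, _ = model.predict(obs, deterministic=False)
assert np.array_equal(s1, s2)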

My code (main.py) is below:

import json
import datetime

import pandas as pd
import numpy as np
from numpy.random import SeedSequence, default_rng
import random

import gym
import talib as ta

import torch

from env.ExpectVolumeEnv import ExpectVolumeEnv
from env.ExpectVolumeEnvDiscrete import ExpectVolumeEnvDiscrete

from stable_baselines3 import PPO
from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import StopTrainingOnNoModelImprovement, StopTrainingOnRewardThreshold, EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3 import HerReplayBuffer
from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy

import matplotlib.pyplot as plt


'''
reference
https://github.com/notadamking/Stock-Trading-Environment
'''

'''
Data
20XX-XX-XX KOSPI Intraday Data
'''

random_seed = 42

# # Set seed for reproducibility
# torch.manual_seed(random_seed)
# torch.cuda.manual_seed(random_seed)
# torch.cuda.manual_seed_all(random_seed) # if use multi-GPU
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# np.random.seed(random_seed)
# random.seed(random_seed)

# Load data
df = pd.read_csv("data/raw/kospi_minutes/[지수KOSPI계열]일중 시세정보(1분)(주문번호-2499-1)_20240111.csv", encoding='cp949')

# DataFrame Preprocessing
df = df[df['지수명']=='코스피']     # keep only rows for the KOSPI index ('지수명' = index name)
df = df[df['거래시각'] <= '1530']   # keep bars up to the 15:30 close ('거래시각' = trade time)

data_date = str(df['거래일자'].iloc[0])   # '거래일자' = trade date

df = df[['거래시각', '시가', '고가', '저가', '종가', '거래량']]
df.columns = ['Time', 'Open', 'High', 'Low', 'Close', 'Volume']

print(df)

df = df.astype(float)
df = df.reset_index(drop=False)


# # Create environment
env = ExpectVolumeEnv(df, seed=random_seed)
env.action_space.seed(random_seed)
env.observation_space.seed(random_seed)

# Create model (PPO)
model = PPO("MlpPolicy",
            env,
            learning_rate=0.00025,
            batch_size=128,
            verbose=1,
            )


# print(help(model.load))


# # Total timesteps / Number of steps per episode = Number of episodes
# model.learn(total_timesteps=len(df)*100)

# # # Save model
# model.save(f"./logs/ppo_vwap_predict_{datetime.datetime.now().strftime('%Y%m%d')}_{data_date}.zip")


# Load Model
model.load("./logs/ppo_vwap_predict_20240919_20240111.zip", env=env, seed=random_seed)

# observation, empty = env.reset(seed=random_seed)
observation, empty = env.reset()


print("mean: ", df['Close'].mean())
plt.plot(df['Volume'], label=f'{data_date} Market Volume')
plt.show()

plt.plot(df['Close'], label=f'{data_date} Market Close')
plt.show()

# Render each environment separately
for _ in range(len(df)-1):
    action, _states = model.predict(observation)
    observation, reward, terminated, truncated, info = env.step(action)
    env.render()

market_vwap = env.render_plot(data_date=data_date)

volume_pattern = pd.read_csv('./data/volume.csv')
scaled_mean = volume_pattern['scaled_mean']

proportion = scaled_mean / np.sum(scaled_mean)

static_model_vwap = np.sum(df['Close'] * proportion)
print(f"Static Model VWAP: {static_model_vwap}")
print(f"Static Model VWAP Gap: {market_vwap - static_model_vwap}")
pytorch reinforcement-learning stable-baselines
1 Answer

You need to call model.predict(observation, deterministic=True). By default, predict() runs with deterministic=False, which means the action is sampled from the policy's probability distribution, so the results change between runs unless the random seed is fixed beforehand.
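As a sketch of how that would look in the evaluation loop from main.py (only the deterministic flag is added, everything else is the asker's loop):

# Evaluation loop with deterministic action selection: predict() returns the
# mode of the action distribution instead of sampling, so repeated runs give
# identical trajectories as long as the environment itself is deterministic.
for _ in range(len(df)-1):
    action, _states = model.predict(observation, deterministic=True)
    observation, reward, terminated, truncated, info = env.step(action)
    env.render()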
