1
This commit is contained in:
@@ -1,19 +1,24 @@
|
||||
import gym
|
||||
from stable_baselines3 import PPO
|
||||
from src.environment.chengdu_majiang_env import MahjongEnv
|
||||
import torch
|
||||
|
||||
def train_model():
|
||||
# 创建 MahjongEnv 环境实例
|
||||
env = MahjongEnv()
|
||||
|
||||
# 检查是否有可用的GPU
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"使用设备: {device}")
|
||||
|
||||
# 使用 PPO 算法训练模型
|
||||
model = PPO("MlpPolicy", env, verbose=1, tensorboard_log="../logs/ppo_mahjong_tensorboard/")
|
||||
model = PPO("MlpPolicy", env, verbose=1, tensorboard_log="../logs/ppo_mahjong_tensorboard/", device=device)
|
||||
|
||||
# 训练模型,训练总步数为100000
|
||||
model.learn(total_timesteps=100000)
|
||||
model.learn(total_timesteps=100)
|
||||
|
||||
# 保存训练后的模型
|
||||
model.save("ppo_mahjong_model")
|
||||
model.save("../models/ppo_mahjong_model")
|
||||
|
||||
# 测试模型
|
||||
obs = env.reset()
|
||||
|
||||
Reference in New Issue
Block a user