mjAi/scripts/train_chengdu_mahjong_model.py

36 lines
1.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import gym
from stable_baselines3 import PPO
from src.environment.chengdu_majiang_env import MahjongEnv
import torch
from configs.log_config import setup_logging
def _run_test_episode(model, env, *, deterministic=False, max_steps=None):
    """Play one episode in *env* with *model*, rendering after every step.

    Uses the legacy gym calling convention, exactly as the original script
    did: ``reset()`` returns the observation and ``step()`` returns
    ``(obs, reward, done, info)``.
    NOTE(review): gym>=0.26 / gymnasium return ``(obs, info)`` from reset and
    a 5-tuple from step — confirm which API MahjongEnv implements.

    Args:
        model: Trained agent exposing ``predict(obs, deterministic=...)``.
        env: Environment to play in.
        deterministic: Forwarded to ``model.predict``.
        max_steps: Optional upper bound on steps; ``None`` means run until
            the environment reports ``done``.
    """
    obs = env.reset()
    done = False
    steps = 0
    while not done:
        # Guard against an episode that never terminates when a cap is given.
        if max_steps is not None and steps >= max_steps:
            break
        # The second value returned by predict is the policy hidden state,
        # which this loop does not need.
        action, _states = model.predict(obs, deterministic=deterministic)
        obs, reward, done, info = env.step(action)
        env.render()  # show the environment state after each action
        steps += 1


def train_model(
    total_timesteps: int = 100,
    *,
    tensorboard_log: str = "../logs/ppo_mahjong_tensorboard/",
    model_path: str = "../models/ppo_mahjong_model",
    deterministic: bool = False,
    max_test_steps: "int | None" = None,
):
    """Train a PPO agent on MahjongEnv, save it, then play one test episode.

    Args:
        total_timesteps: Number of environment steps PPO trains for. The
            default of 100 is only a smoke-test value (an earlier comment
            stated 100000); pass e.g. ``100_000`` for a real training run.
        tensorboard_log: Directory for TensorBoard logs. A relative path is
            resolved against the current working directory, not this file.
        model_path: Destination passed to ``model.save``. A relative path is
            resolved against the current working directory.
        deterministic: Forwarded to ``model.predict`` during the test
            episode; False matches the previous behavior.
        max_test_steps: Optional cap on the test episode length; ``None``
            (default) runs until the environment reports ``done``.

    Returns:
        The trained PPO model.
    """
    env = MahjongEnv()
    try:
        # Use the GPU when one is available, otherwise fall back to the CPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"使用设备: {device}")

        model = PPO(
            "MlpPolicy",
            env,
            verbose=1,
            tensorboard_log=tensorboard_log,
            device=device,
        )
        model.learn(total_timesteps=total_timesteps)
        model.save(model_path)

        # Quick sanity check: play one episode with the freshly trained model.
        _run_test_episode(
            model,
            env,
            deterministic=deterministic,
            max_steps=max_test_steps,
        )
    finally:
        # Release the environment's resources even if training or the test
        # episode raised.
        env.close()
    return model
if __name__ == "__main__":
    # Script entry point: configure logging first, then run training.
    setup_logging()
    train_model()