import gym from stable_baselines3 import PPO from src.environment.chengdu_mahjong_env import MahjongEnv import torch from configs.log_config import setup_logging def train_model(): # 创建 MahjongEnv 环境实例 env = MahjongEnv() # 检查是否有可用的GPU device = "cuda" if torch.cuda.is_available() else "cpu" print(f"使用设备: {device}") # 使用 PPO 算法训练模型 model = PPO("MlpPolicy", env, verbose=1, tensorboard_log="../logs/ppo_mahjong_tensorboard/", device=device) # 训练模型,训练总步数为100000 model.learn(total_timesteps=100) # 保存训练后的模型 model.save("../models/ppo_mahjong_model") # 测试模型 obs = env.reset() done = False while not done: action, _states = model.predict(obs) # 使用训练好的模型来选择动作 obs, reward, done, info = env.step(action) # 执行动作 env.render() # 打印环境状态 if __name__ == "__main__": # 调用配置函数来设置日志 setup_logging() train_model()