import gym from stable_baselines3 import PPO from src.environment.chengdu_mahjong_env import ChengduMahjongEnv import torch from configs.log_config import setup_logging from loguru import logger # 添加 logger def train_model(): # 创建 MahjongEnv 环境实例 env = ChengduMahjongEnv() # 检查是否有可用的 GPU device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"使用设备: {device}") # 替换 print 为 logger.info # 使用 PPO 算法训练模型,切换到 MultiInputPolicy model = PPO( "MultiInputPolicy", # 更改为 MultiInputPolicy env, verbose=1, tensorboard_log="../logs/ppo_mahjong_tensorboard/", device=device ) # 训练模型,训练总步数为 100000 logger.info("开始训练模型...") model.learn(total_timesteps=100) logger.info("模型训练完成!") # 保存训练后的模型 model.save("../models/ppo_mahjong_model") logger.info("模型已保存到 '../models/ppo_mahjong_model'") # 测试模型 logger.info("开始测试模型...") obs = env.reset() done = False while not done: action, _states = model.predict(obs) # 使用训练好的模型来选择动作 obs, reward, done, info = env.step(action) # 执行动作 logger.info(f"动作: {action}, 奖励: {reward}, 是否结束: {done}, 信息: {info}") # 替换 print 为 logger.info if __name__ == "__main__": # 调用配置函数来设置日志 setup_logging() train_model()