from stable_baselines3 import PPO from src.environment.dizhu_env import DouDiZhuEnv # 导入斗地主环境 import torch from configs.log_config import setup_logging from loguru import logger # 使用日志工具 def train_dizhu_model(): # 创建 DouDiZhuEnv 环境实例 env = DouDiZhuEnv() # 检查是否有可用的 GPU device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"使用设备: {device}") # 使用 logger 记录设备信息 # 使用 PPO 算法训练模型,设置为 MultiInputPolicy model = PPO( "MultiInputPolicy", # 适用于多输入的策略 env, verbose=1, tensorboard_log="../logs/ppo_doudizhu_tensorboard/", # TensorBoard 日志路径 device=device ) # 训练模型,设定总训练步数 logger.info("开始训练斗地主模型...") model.learn(total_timesteps=100000) # 总训练步数 logger.info("斗地主模型训练完成!") # 保存训练后的模型 model_path = "../models/ppo_doudizhu_model" model.save(model_path) logger.info(f"模型已保存到 '{model_path}'") # 测试模型 logger.info("开始测试斗地主模型...") obs = env.reset() done = False while not done: action, _states = model.predict(obs) # 使用训练好的模型来选择动作 obs, reward, done, info = env.step(action) # 执行动作 logger.info(f"动作: {action}, 奖励: {reward}, 是否结束: {done}, 信息: {info}") # 记录测试过程 if __name__ == "__main__": # 设置日志 setup_logging() train_dizhu_model()