47 lines
1.6 KiB
Python
47 lines
1.6 KiB
Python
from stable_baselines3 import PPO
|
|
from src.environment.dizhu_env import DouDiZhuEnv # 导入斗地主环境
|
|
import torch
|
|
from configs.log_config import setup_logging
|
|
from loguru import logger # 使用日志工具
|
|
|
|
def train_dizhu_model():
|
|
# 创建 DouDiZhuEnv 环境实例
|
|
env = DouDiZhuEnv()
|
|
|
|
# 检查是否有可用的 GPU
|
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
logger.info(f"使用设备: {device}") # 使用 logger 记录设备信息
|
|
|
|
# 使用 PPO 算法训练模型,设置为 MultiInputPolicy
|
|
model = PPO(
|
|
"MultiInputPolicy", # 适用于多输入的策略
|
|
env,
|
|
verbose=1,
|
|
tensorboard_log="../logs/ppo_doudizhu_tensorboard/", # TensorBoard 日志路径
|
|
device=device
|
|
)
|
|
|
|
# 训练模型,设定总训练步数
|
|
logger.info("开始训练斗地主模型...")
|
|
model.learn(total_timesteps=100000) # 总训练步数
|
|
logger.info("斗地主模型训练完成!")
|
|
|
|
# 保存训练后的模型
|
|
model_path = "../models/ppo_doudizhu_model"
|
|
model.save(model_path)
|
|
logger.info(f"模型已保存到 '{model_path}'")
|
|
|
|
# 测试模型
|
|
logger.info("开始测试斗地主模型...")
|
|
obs = env.reset()
|
|
done = False
|
|
while not done:
|
|
action, _states = model.predict(obs) # 使用训练好的模型来选择动作
|
|
obs, reward, done, info = env.step(action) # 执行动作
|
|
logger.info(f"动作: {action}, 奖励: {reward}, 是否结束: {done}, 信息: {info}") # 记录测试过程
|
|
|
|
if __name__ == "__main__":
|
|
# 设置日志
|
|
setup_logging()
|
|
train_dizhu_model()
|