# mjAi/test_dizhu.py
#
# 36 lines
# 1.1 KiB
# Python

from stable_baselines3 import PPO
from src.environment.dizhu_env import DouDiZhuEnv # 导入斗地主环境
from loguru import logger
def _reset_observation(env):
    """Reset *env* and return only the initial observation.

    Gymnasium-style ``reset()`` returns ``(obs, info)``; the legacy Gym API
    returns the observation alone. Both forms are accepted.
    """
    result = env.reset()
    # NOTE(review): heuristic -- assumes a raw observation is never itself a
    # 2-tuple whose second item is a dict; confirm against DouDiZhuEnv.
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result[0]
    return result


def _step_env(env, action):
    """Advance *env* by one step and return ``(obs, reward, done, info)``.

    Gymnasium-style ``step()`` returns five values
    ``(obs, reward, terminated, truncated, info)``; those are folded into a
    single ``done`` flag so the caller handles both APIs the same way.
    """
    result = env.step(action)
    if len(result) == 5:
        obs, reward, terminated, truncated, info = result
        return obs, reward, terminated or truncated, info
    obs, reward, done, info = result
    return obs, reward, done, info


def test_dizhu_model():  # pytest only collects functions whose names start with test_
    """Smoke-test the trained PPO Dou Dizhu agent by playing one episode.

    Loads the saved PPO model, then plays greedily (``deterministic=True``)
    until the environment reports the episode is over or ``max_steps`` steps
    have been taken, logging every step and the accumulated reward.

    Raises:
        Exception: whatever ``PPO.load`` raises when the model cannot be
            loaded. It is logged and re-raised so the test fails instead of
            silently passing without exercising the model.
    """
    # Resolved against the current working directory, so pytest must be run
    # from the directory that contains ./models (adjust if that differs).
    model_path = "./models/ppo_doudizhu_model.zip"
    logger.info(f"加载模型: {model_path}")
    try:
        model = PPO.load(model_path)
    except Exception as e:
        # Log with traceback, then propagate: returning here would make the
        # test pass without testing anything.
        logger.exception(f"加载模型失败: {e}")
        raise

    # Create the environment only once the model is loaded, so a load
    # failure leaves nothing behind to clean up.
    env = DouDiZhuEnv()
    try:
        obs = _reset_observation(env)
        done = False
        total_reward = 0
        logger.info("开始测试斗地主模型...")
        max_steps = 1000  # upper bound in case the episode never terminates
        step_count = 0
        while not done and step_count < max_steps:
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, done, info = _step_env(env, action)
            total_reward += reward
            step_count += 1
            logger.info(f"动作: {action}, 奖励: {reward}, 是否结束: {done}, 信息: {info}")
        if not done:
            # Make a truncated run visible instead of reporting it as a normal finish.
            logger.warning(f"达到最大步数 {max_steps},对局仍未结束")
        logger.info(f"测试完成,总奖励: {total_reward}")
    finally:
        # Gym/Gymnasium environments expose close(); guarded because
        # DouDiZhuEnv's base class is not visible here -- TODO confirm.
        close = getattr(env, "close", None)
        if callable(close):
            close()