36 lines
1.1 KiB
Python
36 lines
1.1 KiB
Python
from stable_baselines3 import PPO
|
|
from src.environment.dizhu_env import DouDiZhuEnv # 导入斗地主环境
|
|
from loguru import logger
|
|
|
|
|
|
def _unpack_reset(reset_result):
    """Return the observation from an ``env.reset()`` result.

    Gymnasium-style environments return ``(obs, info)`` where ``info`` is a
    dict; classic Gym-style environments return the observation alone, which
    is passed through unchanged.
    """
    if (
        isinstance(reset_result, tuple)
        and len(reset_result) == 2
        and isinstance(reset_result[1], dict)
    ):
        return reset_result[0]
    return reset_result


def _unpack_step(step_result):
    """Normalize an ``env.step()`` result to ``(obs, reward, done, info)``.

    Gymnasium-style environments return five values and split "done" into
    ``terminated`` and ``truncated``; classic Gym-style environments return
    four values with a single ``done`` flag (unpacked exactly as before).
    """
    if len(step_result) == 5:
        obs, reward, terminated, truncated, info = step_result
        return obs, reward, terminated or truncated, info
    obs, reward, done, info = step_result
    return obs, reward, done, info


def test_dizhu_model():  # the name must start with test_ so pytest collects it
    """Smoke-test the trained PPO Dou Dizhu model by playing one episode.

    Creates a ``DouDiZhuEnv``, loads the PPO model saved at
    ``./models/ppo_doudizhu_model.zip``, then lets the model act
    deterministically until the episode ends or ``max_steps`` steps have been
    taken. Every transition and the accumulated reward are logged.

    Both the classic Gym step/reset API and the Gymnasium API are accepted.
    The environment is closed afterwards (if it exposes ``close``), even when
    loading or stepping fails.

    Raises:
        Exception: whatever ``PPO.load`` raises when the model cannot be
            loaded. It is logged and re-raised so the test fails instead of
            silently passing without exercising the model.
    """
    # Create the Dou Dizhu environment.
    env = DouDiZhuEnv()
    try:
        # Load the trained model.
        model_path = "./models/ppo_doudizhu_model.zip"  # make sure this path is correct
        logger.info(f"加载模型: {model_path}")
        try:
            model = PPO.load(model_path)
        except Exception as e:
            logger.error(f"加载模型失败: {e}")
            # Re-raise: returning here would make pytest report a pass
            # even though the model was never exercised.
            raise

        # Run the model for one episode.
        obs = _unpack_reset(env.reset())
        done = False
        total_reward = 0
        logger.info("开始测试斗地主模型...")

        max_steps = 1000  # safety cap in case the episode never terminates
        step_count = 0

        while not done and step_count < max_steps:
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, done, info = _unpack_step(env.step(action))
            total_reward += reward
            step_count += 1
            logger.info(f"动作: {action}, 奖励: {reward}, 是否结束: {done}, 信息: {info}")

        if not done:
            # Make a step-cap cutoff distinguishable from a finished episode.
            logger.warning(f"已达到最大步数 {max_steps},对局未结束")

        logger.info(f"测试完成,总奖励: {total_reward}")
    finally:
        # Release environment resources even if loading or stepping failed.
        close = getattr(env, "close", None)
        if callable(close):
            close()