From deff3cb9218964c71ce823af0198fd4244cf8c96 Mon Sep 17 00:00:00 2001
From: wsy182 <2392948297@qq.com>
Date: Sun, 1 Dec 2024 22:47:09 +0800
Subject: [PATCH] 1

---
 tests/{ => dizhu}/__init__.py |  0
 tests/models/__init__.py      |  0
 tests/models/test_dizhu.py    | 33 +++++++++++++++++++++++++++++++++
 3 files changed, 33 insertions(+)
 rename tests/{ => dizhu}/__init__.py (100%)
 create mode 100644 tests/models/__init__.py
 create mode 100644 tests/models/test_dizhu.py

diff --git a/tests/__init__.py b/tests/dizhu/__init__.py
similarity index 100%
rename from tests/__init__.py
rename to tests/dizhu/__init__.py
diff --git a/tests/models/__init__.py b/tests/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/models/test_dizhu.py b/tests/models/test_dizhu.py
new file mode 100644
index 0000000..6ebecb1
--- /dev/null
+++ b/tests/models/test_dizhu.py
@@ -0,0 +1,33 @@
+from stable_baselines3 import PPO
+from src.environment.dizhu_env import DouDiZhuEnv  # Import the Dou Dizhu environment
+from loguru import logger
+
+
+def test_dizhu_model():
+    """Play DouDiZhuEnv with the saved PPO model until done and log the total reward."""
+    env = DouDiZhuEnv()  # Create the Dou Dizhu environment
+
+    # Load the trained model (relative path, so it depends on the working directory)
+    model_path = "../models/ppo_doudizhu_model.zip"
+    logger.info(f"加载模型: {model_path}")
+    model = PPO.load(model_path)
+
+    # Test the model: start from the initial observation and accumulate the reward
+    obs = env.reset()
+    done = False
+    total_reward = 0
+    logger.info("开始测试斗地主模型...")
+
+    while not done:
+        action, _ = model.predict(obs, deterministic=True)  # Predict an action with the trained model
+        obs, reward, done, info = env.step(action)  # Execute the action (unpacks a 4-tuple from step)
+        total_reward += reward
+
+        # Log the current state
+        logger.info(f"动作: {action}, 奖励: {reward}, 是否结束: {done}, 信息: {info}")
+
+    logger.info(f"测试完成,总奖励: {total_reward}")
+
+
+if __name__ == "__main__":
+    test_dizhu_model()