diff --git a/tests/__init__.py b/tests/dizhu/__init__.py
similarity index 100%
rename from tests/__init__.py
rename to tests/dizhu/__init__.py
diff --git a/tests/models/__init__.py b/tests/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/models/test_dizhu.py b/tests/models/test_dizhu.py
new file mode 100644
index 0000000..6ebecb1
--- /dev/null
+++ b/tests/models/test_dizhu.py
@@ -0,0 +1,41 @@
+from stable_baselines3 import PPO
+from src.environment.dizhu_env import DouDiZhuEnv  # Dou Dizhu environment under test
+from loguru import logger
+
+
+def test_dizhu_model():
+    """Play one Dou Dizhu episode with the saved PPO model, logging each step.
+
+    The checkpoint path is relative, so it resolves against the current
+    working directory. Nothing is asserted: the run fails only if loading
+    the model, predicting, or stepping the environment raises.
+    """
+    game = DouDiZhuEnv()
+
+    # Restore the trained policy from disk.
+    checkpoint = "../models/ppo_doudizhu_model.zip"
+    logger.info(f"加载模型: {checkpoint}")
+    agent = PPO.load(checkpoint)
+
+    observation = game.reset()
+    episode_return = 0
+    logger.info("开始测试斗地主模型...")
+
+    # Step until the environment reports the episode is over; the body always
+    # runs at least once.
+    while True:
+        # Deterministic prediction; predict()'s second return value is unused.
+        move, _ = agent.predict(observation, deterministic=True)
+        observation, step_reward, finished, details = game.step(move)
+        episode_return += step_reward
+
+        logger.info(f"动作: {move}, 奖励: {step_reward}, 是否结束: {finished}, 信息: {details}")
+
+        if finished:
+            break
+
+    logger.info(f"测试完成,总奖励: {episode_return}")
+
+
+if __name__ == "__main__":
+    test_dizhu_model()