dev
parent
96f0fbdcd7
commit
deff3cb921
|
|
@ -0,0 +1,33 @@
|
||||||
|
from stable_baselines3 import PPO
|
||||||
|
from src.environment.dizhu_env import DouDiZhuEnv # 导入斗地主环境
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
def test_dizhu_model():
|
||||||
|
# 创建斗地主环境
|
||||||
|
env = DouDiZhuEnv()
|
||||||
|
|
||||||
|
# 加载已训练的模型
|
||||||
|
model_path = "../models/ppo_doudizhu_model.zip"
|
||||||
|
logger.info(f"加载模型: {model_path}")
|
||||||
|
model = PPO.load(model_path)
|
||||||
|
|
||||||
|
# 测试模型
|
||||||
|
obs = env.reset()
|
||||||
|
done = False
|
||||||
|
total_reward = 0
|
||||||
|
logger.info("开始测试斗地主模型...")
|
||||||
|
|
||||||
|
while not done:
|
||||||
|
action, _ = model.predict(obs, deterministic=True) # 使用训练好的模型预测动作
|
||||||
|
obs, reward, done, info = env.step(action) # 执行动作
|
||||||
|
total_reward += reward
|
||||||
|
|
||||||
|
# 输出当前状态
|
||||||
|
logger.info(f"动作: {action}, 奖励: {reward}, 是否结束: {done}, 信息: {info}")
|
||||||
|
|
||||||
|
logger.info(f"测试完成,总奖励: {total_reward}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_dizhu_model()
|
||||||
Loading…
Reference in New Issue