1
1
This commit is contained in:
@@ -3,6 +3,7 @@ from stable_baselines3 import PPO
|
||||
from src.environment.chengdu_mahjong_env import ChengduMahjongEnv
|
||||
import torch
|
||||
from configs.log_config import setup_logging
|
||||
from loguru import logger # 添加 logger
|
||||
|
||||
def train_model():
|
||||
# 创建 MahjongEnv 环境实例
|
||||
@@ -10,7 +11,7 @@ def train_model():
|
||||
|
||||
# 检查是否有可用的 GPU
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"使用设备: {device}")
|
||||
logger.info(f"使用设备: {device}") # 替换 print 为 logger.info
|
||||
|
||||
# 使用 PPO 算法训练模型,切换到 MultiInputPolicy
|
||||
model = PPO(
|
||||
@@ -22,18 +23,22 @@ def train_model():
|
||||
)
|
||||
|
||||
# 训练模型,训练总步数为 100000
|
||||
model.learn(total_timesteps=100000)
|
||||
logger.info("开始训练模型...")
|
||||
model.learn(total_timesteps=100)
|
||||
logger.info("模型训练完成!")
|
||||
|
||||
# 保存训练后的模型
|
||||
model.save("../models/ppo_mahjong_model")
|
||||
logger.info("模型已保存到 '../models/ppo_mahjong_model'")
|
||||
|
||||
# 测试模型
|
||||
logger.info("开始测试模型...")
|
||||
obs = env.reset()
|
||||
done = False
|
||||
while not done:
|
||||
action, _states = model.predict(obs) # 使用训练好的模型来选择动作
|
||||
obs, reward, done, info = env.step(action) # 执行动作
|
||||
print(f"动作: {action}, 奖励: {reward}, 是否结束: {done}, 信息: {info}")
|
||||
logger.info(f"动作: {action}, 奖励: {reward}, 是否结束: {done}, 信息: {info}") # 替换 print 为 logger.info
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 调用配置函数来设置日志
|
||||
|
||||
Reference in New Issue
Block a user