1
This commit is contained in:
2024-12-01 19:45:31 +08:00
parent 27da1caf44
commit c9defe78f1
4 changed files with 27 additions and 14 deletions

View File

@@ -1,22 +1,28 @@
import gym
from stable_baselines3 import PPO
from src.environment.chengdu_mahjong_env import MahjongEnv
from src.environment.chengdu_mahjong_env import ChengduMahjongEnv
import torch
from configs.log_config import setup_logging
def train_model():
# 创建 MahjongEnv 环境实例
env = MahjongEnv()
env = ChengduMahjongEnv()
# 检查是否有可用的GPU
# 检查是否有可用的 GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用设备: {device}")
# 使用 PPO 算法训练模型
model = PPO("MlpPolicy", env, verbose=1, tensorboard_log="../logs/ppo_mahjong_tensorboard/", device=device)
# 使用 PPO 算法训练模型,切换到 MultiInputPolicy
model = PPO(
"MultiInputPolicy", # 更改为 MultiInputPolicy
env,
verbose=1,
tensorboard_log="../logs/ppo_mahjong_tensorboard/",
device=device
)
# 训练模型训练总步数为100000
model.learn(total_timesteps=100)
# 训练模型,训练总步数为 100000
model.learn(total_timesteps=100000)
# 保存训练后的模型
model.save("../models/ppo_mahjong_model")
@@ -27,7 +33,7 @@ def train_model():
while not done:
action, _states = model.predict(obs) # 使用训练好的模型来选择动作
obs, reward, done, info = env.step(action) # 执行动作
env.render() # 打印环境状态
print(f"动作: {action}, 奖励: {reward}, 是否结束: {done}, 信息: {info}")
if __name__ == "__main__":
# 调用配置函数来设置日志