pull/1/head
parent
7d040f7e40
commit
fd6006b186
|
|
@ -1,4 +1,8 @@
|
||||||
|
import os
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
# 配置日志
|
# 确保 ../logs 目录存在,如果不存在则创建
|
||||||
logger.add("mahjong_ai_{time}.log", rotation="10 MB", level="DEBUG", format="{time} {level} {message}")
|
os.makedirs("../logs", exist_ok=True)
|
||||||
|
|
||||||
|
# 配置日志,记录到 ../logs 目录下
|
||||||
|
logger.add("../logs/chengdu_mj_engine.log", rotation="10 MB", level="DEBUG", format="{time} {level} {message}")
|
||||||
|
|
|
||||||
Binary file not shown.
|
|
@ -1,19 +1,24 @@
|
||||||
import gym
|
import gym
|
||||||
from stable_baselines3 import PPO
|
from stable_baselines3 import PPO
|
||||||
from src.environment.chengdu_majiang_env import MahjongEnv
|
from src.environment.chengdu_majiang_env import MahjongEnv
|
||||||
|
import torch
|
||||||
|
|
||||||
def train_model():
|
def train_model():
|
||||||
# 创建 MahjongEnv 环境实例
|
# 创建 MahjongEnv 环境实例
|
||||||
env = MahjongEnv()
|
env = MahjongEnv()
|
||||||
|
|
||||||
|
# 检查是否有可用的GPU
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
print(f"使用设备: {device}")
|
||||||
|
|
||||||
# 使用 PPO 算法训练模型
|
# 使用 PPO 算法训练模型
|
||||||
model = PPO("MlpPolicy", env, verbose=1, tensorboard_log="../logs/ppo_mahjong_tensorboard/")
|
model = PPO("MlpPolicy", env, verbose=1, tensorboard_log="../logs/ppo_mahjong_tensorboard/", device=device)
|
||||||
|
|
||||||
# 训练模型,训练总步数为100000
|
# 训练模型,训练总步数为100000
|
||||||
model.learn(total_timesteps=100000)
|
model.learn(total_timesteps=100)
|
||||||
|
|
||||||
# 保存训练后的模型
|
# 保存训练后的模型
|
||||||
model.save("ppo_mahjong_model")
|
model.save("../models/ppo_mahjong_model")
|
||||||
|
|
||||||
# 测试模型
|
# 测试模型
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue