pull/1/head
wsy182 2024-11-30 20:43:21 +08:00
parent 7d040f7e40
commit fd6006b186
4 changed files with 15 additions and 5 deletions

View File

@ -1,4 +1,8 @@
import os
from loguru import logger
# 配置日志
logger.add("mahjong_ai_{time}.log", rotation="10 MB", level="DEBUG", format="{time} {level} {message}")
# 确保 ../logs 目录存在,如果不存在则创建
os.makedirs("../logs", exist_ok=True)
# 配置日志,记录到 ../logs 目录下
logger.add("../logs/chengdu_mj_engine.log", rotation="10 MB", level="DEBUG", format="{time} {level} {message}")

Binary file not shown.

View File

@ -1,19 +1,24 @@
import gym
from stable_baselines3 import PPO
from src.environment.chengdu_majiang_env import MahjongEnv
import torch
def train_model():
# 创建 MahjongEnv 环境实例
env = MahjongEnv()
# 检查是否有可用的GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用设备: {device}")
# 使用 PPO 算法训练模型
model = PPO("MlpPolicy", env, verbose=1, tensorboard_log="../logs/ppo_mahjong_tensorboard/")
model = PPO("MlpPolicy", env, verbose=1, tensorboard_log="../logs/ppo_mahjong_tensorboard/", device=device)
# 训练模型训练总步数为100000
model.learn(total_timesteps=100000)
model.learn(total_timesteps=100)
# 保存训练后的模型
model.save("ppo_mahjong_model")
model.save("../models/ppo_mahjong_model")
# 测试模型
obs = env.reset()

View File

@ -1,2 +1,3 @@
import torch
print(torch.cuda.is_available()) # 如果返回True说明可以使用GPU
print(torch.__version__)