1
pull/1/head
wsy182 2024-11-30 19:27:17 +08:00
parent 3487c805d4
commit b78d6a17a4
2 changed files with 28 additions and 1 deletions

View File

@ -0,0 +1,27 @@
import gym
from stable_baselines3 import PPO
from src.environment.chengdu_majiang_env import MahjongEnv
def train_model():
# 创建 MahjongEnv 环境实例
env = MahjongEnv()
# 使用 PPO 算法训练模型
model = PPO("MlpPolicy", env, verbose=1, tensorboard_log="./ppo_mahjong_tensorboard/")
# 训练模型训练总步数为100000
model.learn(total_timesteps=100000)
# 保存训练后的模型
model.save("ppo_mahjong_model")
# 测试模型
obs = env.reset()
done = False
while not done:
action, _states = model.predict(obs) # 使用训练好的模型来选择动作
obs, reward, done, info = env.step(action) # 执行动作
env.render() # 打印环境状态
if __name__ == "__main__":
train_model()

View File

@ -1,5 +1,5 @@
from loguru import logger
from utils import get_tile_name
from src.engine.utils import get_tile_name
def draw_tile(self):