1
1
This commit is contained in:
27
scripts/train_chengdu_mahjong_model.py
Normal file
27
scripts/train_chengdu_mahjong_model.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
import gym
|
||||||
|
from stable_baselines3 import PPO
|
||||||
|
from src.environment.chengdu_majiang_env import MahjongEnv
|
||||||
|
|
||||||
|
def train_model(
    total_timesteps=100000,
    model_path="ppo_mahjong_model",
    tensorboard_log="./ppo_mahjong_tensorboard/",
):
    """Train a PPO agent on ``MahjongEnv``, save it, then play one episode.

    The defaults reproduce the previously hard-coded configuration, so
    calling ``train_model()`` with no arguments behaves as before.

    Args:
        total_timesteps: Number of environment steps passed to
            ``model.learn``.
        model_path: Destination passed to ``model.save``.
        tensorboard_log: Directory passed to ``PPO`` for TensorBoard logs.
    """
    # Create the Mahjong environment instance.
    env = MahjongEnv()

    # Build the PPO model with an MLP policy and train it.
    model = PPO("MlpPolicy", env, verbose=1, tensorboard_log=tensorboard_log)
    model.learn(total_timesteps=total_timesteps)

    # Persist the trained model.
    model.save(model_path)

    # Play a single episode with the trained model, rendering every step.
    # NOTE(review): MahjongEnv's API flavour is not visible from this file,
    # so both the classic Gym and the Gymnasium return shapes are accepted.
    reset_result = env.reset()
    # Gymnasium's reset() returns (obs, info); classic Gym returns obs alone.
    obs = reset_result[0] if isinstance(reset_result, tuple) else reset_result

    done = False
    while not done:
        # Let the trained policy choose the action.
        action, _states = model.predict(obs)

        step_result = env.step(action)
        if len(step_result) == 5:
            # Gymnasium: (obs, reward, terminated, truncated, info).
            obs, _reward, terminated, truncated, _info = step_result
            done = terminated or truncated
        else:
            # Classic Gym: (obs, reward, done, info).
            obs, _reward, done, _info = step_result

        # Print/display the current environment state.
        env.render()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    train_model()
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
from utils import get_tile_name
|
from src.engine.utils import get_tile_name
|
||||||
|
|
||||||
|
|
||||||
def draw_tile(self):
|
def draw_tile(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user