# Listing metadata (from the source viewer): 28 lines, 796 B, Python
import gym
|
||
from stable_baselines3 import PPO
|
||
from src.environment.chengdu_majiang_env import MahjongEnv
|
||
|
||
def train_model(
    total_timesteps: int = 100000,
    model_path: str = "ppo_mahjong_model",
    tensorboard_log: str = "../logs/ppo_mahjong_tensorboard/",
) -> None:
    """Train a PPO agent on ``MahjongEnv``, save it, then play one episode.

    Steps performed:
      1. Create a ``MahjongEnv`` instance.
      2. Build a PPO model with the ``"MlpPolicy"`` policy and train it.
      3. Save the trained model to ``model_path``.
      4. Run a single evaluation episode, rendering after every step.

    The environment is always closed on exit, including when training or
    the evaluation rollout raises.

    Args:
        total_timesteps: Number of environment steps passed to
            ``model.learn``. Defaults to 100000.
        model_path: Path passed to ``model.save``.
        tensorboard_log: Directory passed to PPO as ``tensorboard_log``.
            NOTE(review): the default is a relative path, so it resolves
            against the current working directory, not this file.
    """
    # Create the MahjongEnv environment instance.
    env = MahjongEnv()
    try:
        # Build the PPO model on top of the environment.
        model = PPO("MlpPolicy", env, verbose=1, tensorboard_log=tensorboard_log)

        # Train for the requested number of timesteps.
        model.learn(total_timesteps=total_timesteps)

        # Persist the trained model.
        model.save(model_path)

        # Evaluate: play one episode with the trained model.
        # Newer Gym/Gymnasium versions return ``(obs, info)`` from reset();
        # older Gym returns the bare observation. Accept both.
        # Assumes the observation itself is not a tuple -- TODO confirm
        # against MahjongEnv's observation space.
        reset_result = env.reset()
        if isinstance(reset_result, tuple):
            obs = reset_result[0]
        else:
            obs = reset_result

        done = False
        while not done:
            # Let the trained model choose the action.
            action, _states = model.predict(obs)

            # Apply the action. A 5-tuple result is the newer
            # (obs, reward, terminated, truncated, info) form; a 4-tuple is
            # the legacy (obs, reward, done, info) form.
            step_result = env.step(action)
            if len(step_result) == 5:
                obs, _reward, terminated, truncated, _info = step_result
                done = bool(terminated or truncated)
            else:
                obs, _reward, done, _info = step_result

            # Print/draw the current environment state.
            env.render()
    finally:
        # Release environment resources even if training or the rollout
        # failed. Assumes MahjongEnv exposes the standard Gym ``close()``
        # method (it is handed to PPO as an environment above).
        env.close()
|
||
|
||
# Script entry point: run the training pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    train_model()
|