parent
27da1caf44
commit
c9defe78f1
|
|
@ -1,22 +1,28 @@
|
|||
import gym
|
||||
from stable_baselines3 import PPO
|
||||
from src.environment.chengdu_mahjong_env import MahjongEnv
|
||||
from src.environment.chengdu_mahjong_env import ChengduMahjongEnv
|
||||
import torch
|
||||
from configs.log_config import setup_logging
|
||||
|
||||
def train_model():
|
||||
# 创建 MahjongEnv 环境实例
|
||||
env = MahjongEnv()
|
||||
env = ChengduMahjongEnv()
|
||||
|
||||
# 检查是否有可用的GPU
|
||||
# 检查是否有可用的 GPU
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"使用设备: {device}")
|
||||
|
||||
# 使用 PPO 算法训练模型
|
||||
model = PPO("MlpPolicy", env, verbose=1, tensorboard_log="../logs/ppo_mahjong_tensorboard/", device=device)
|
||||
# 使用 PPO 算法训练模型,切换到 MultiInputPolicy
|
||||
model = PPO(
|
||||
"MultiInputPolicy", # 更改为 MultiInputPolicy
|
||||
env,
|
||||
verbose=1,
|
||||
tensorboard_log="../logs/ppo_mahjong_tensorboard/",
|
||||
device=device
|
||||
)
|
||||
|
||||
# 训练模型,训练总步数为100000
|
||||
model.learn(total_timesteps=100)
|
||||
# 训练模型,训练总步数为 100000
|
||||
model.learn(total_timesteps=100000)
|
||||
|
||||
# 保存训练后的模型
|
||||
model.save("../models/ppo_mahjong_model")
|
||||
|
|
@ -27,7 +33,7 @@ def train_model():
|
|||
while not done:
|
||||
action, _states = model.predict(obs) # 使用训练好的模型来选择动作
|
||||
obs, reward, done, info = env.step(action) # 执行动作
|
||||
env.render() # 打印环境状态
|
||||
print(f"动作: {action}, 奖励: {reward}, 是否结束: {done}, 信息: {info}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 调用配置函数来设置日志
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from loguru import logger
|
|||
|
||||
from src.engine.calculate_fan import calculate_fan
|
||||
from src.engine.mahjong_tile import MahjongTile
|
||||
from src.engine.meld import Meld
|
||||
|
||||
|
||||
def draw_tile(engine):
|
||||
|
|
@ -235,15 +236,16 @@ def handle_peng(self, player, tile):
|
|||
处理玩家碰牌逻辑并更新出牌顺序。
|
||||
"""
|
||||
if not isinstance(tile, MahjongTile):
|
||||
logger.error(f"玩家 {player} 碰牌的牌无效: {tile}")
|
||||
logger.error(f"tile 必须是 MahjongTile 类型,但收到的是: {type(tile)}")
|
||||
return False
|
||||
|
||||
if self.state.hands[player].tile_count[tile] < 2:
|
||||
logger.error(f"玩家 {player} 无法碰牌: {tile}")
|
||||
return False
|
||||
|
||||
# 更新状态
|
||||
self._update_meld(player, tile, "碰", count=2)
|
||||
# 减少两张牌
|
||||
self.state.hands[player].tile_count[tile] -= 2
|
||||
self.state.melds[player].append(Meld(tile, "碰")) # 确保使用 MahjongTile 对象
|
||||
|
||||
logger.info(f"玩家 {player} 碰了牌: {tile}。当前明牌: {self.state.melds[player]}")
|
||||
return True
|
||||
|
|
@ -251,6 +253,8 @@ def handle_peng(self, player, tile):
|
|||
|
||||
|
||||
|
||||
|
||||
|
||||
def get_player_discard_choice(self, player):
|
||||
"""
|
||||
模拟获取玩家打牌的选择。
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ class MahjongTile:
|
|||
raise ValueError("Invalid tile")
|
||||
self.suit = suit
|
||||
self.value = value
|
||||
self.index = ({"条": 0, "筒": 1, "万": 2}[suit]) * 9 + (value - 1)
|
||||
|
||||
def __repr__(self):
    """Return the tile's display form: its value immediately followed by its suit (e.g. ``3条``)."""
    return "{}{}".format(self.value, self.suit)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import gym
|
||||
from gym import spaces
|
||||
import numpy as np
|
||||
|
||||
from src.engine.actions import handle_peng, handle_gang, handle_win
|
||||
from src.engine.chengdu_mahjong_engine import ChengduMahjongEngine
|
||||
from loguru import logger
|
||||
|
||||
|
|
@ -50,13 +52,13 @@ class ChengduMahjongEnv(gym.Env):
|
|||
elif action == 14: # 碰
|
||||
tile_to_peng = self._get_tile_for_special_action("peng")
|
||||
if tile_to_peng:
|
||||
self.engine.handle_peng(current_player, tile_to_peng)
|
||||
handle_peng(self.engine,current_player, tile_to_peng)
|
||||
else:
|
||||
logger.warning("碰动作无效,未满足条件")
|
||||
elif action == 15: # 杠
|
||||
tile_to_gang = self._get_tile_for_special_action("gang")
|
||||
if tile_to_gang:
|
||||
self.engine.handle_gang(current_player, tile_to_gang, mode="an") # 默认暗杠
|
||||
handle_gang(current_player, tile_to_gang, mode="an") # 默认暗杠
|
||||
else:
|
||||
logger.warning("杠动作无效,未满足条件")
|
||||
elif action == 16: # 胡
|
||||
|
|
@ -65,7 +67,7 @@ class ChengduMahjongEnv(gym.Env):
|
|||
self.engine.state.melds[current_player],
|
||||
self.engine.state.missing_suits[current_player]
|
||||
):
|
||||
self.engine.handle_win(current_player, None, None)
|
||||
handle_win(current_player, None, None)
|
||||
else:
|
||||
logger.warning("胡动作无效,未满足条件")
|
||||
else:
|
||||
|
|
|
|||
Loading…
Reference in New Issue