wsy182 2024-12-01 19:45:31 +08:00
parent 27da1caf44
commit c9defe78f1
4 changed files with 27 additions and 14 deletions

View File

@ -1,22 +1,28 @@
import gym import gym
from stable_baselines3 import PPO from stable_baselines3 import PPO
from src.environment.chengdu_mahjong_env import MahjongEnv from src.environment.chengdu_mahjong_env import ChengduMahjongEnv
import torch import torch
from configs.log_config import setup_logging from configs.log_config import setup_logging
def train_model(): def train_model():
# 创建 MahjongEnv 环境实例 # 创建 MahjongEnv 环境实例
env = MahjongEnv() env = ChengduMahjongEnv()
# 检查是否有可用的 GPU # 检查是否有可用的 GPU
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用设备: {device}") print(f"使用设备: {device}")
# 使用 PPO 算法训练模型 # 使用 PPO 算法训练模型,切换到 MultiInputPolicy
model = PPO("MlpPolicy", env, verbose=1, tensorboard_log="../logs/ppo_mahjong_tensorboard/", device=device) model = PPO(
"MultiInputPolicy", # 更改为 MultiInputPolicy
env,
verbose=1,
tensorboard_log="../logs/ppo_mahjong_tensorboard/",
device=device
)
# 训练模型,训练总步数为 100000 # 训练模型,训练总步数为 100000
model.learn(total_timesteps=100) model.learn(total_timesteps=100000)
# 保存训练后的模型 # 保存训练后的模型
model.save("../models/ppo_mahjong_model") model.save("../models/ppo_mahjong_model")
@ -27,7 +33,7 @@ def train_model():
while not done: while not done:
action, _states = model.predict(obs) # 使用训练好的模型来选择动作 action, _states = model.predict(obs) # 使用训练好的模型来选择动作
obs, reward, done, info = env.step(action) # 执行动作 obs, reward, done, info = env.step(action) # 执行动作
env.render() # 打印环境状态 print(f"动作: {action}, 奖励: {reward}, 是否结束: {done}, 信息: {info}")
if __name__ == "__main__": if __name__ == "__main__":
# 调用配置函数来设置日志 # 调用配置函数来设置日志

View File

@ -4,6 +4,7 @@ from loguru import logger
from src.engine.calculate_fan import calculate_fan from src.engine.calculate_fan import calculate_fan
from src.engine.mahjong_tile import MahjongTile from src.engine.mahjong_tile import MahjongTile
from src.engine.meld import Meld
def draw_tile(engine): def draw_tile(engine):
@ -235,15 +236,16 @@ def handle_peng(self, player, tile):
处理玩家碰牌逻辑并更新出牌顺序 处理玩家碰牌逻辑并更新出牌顺序
""" """
if not isinstance(tile, MahjongTile): if not isinstance(tile, MahjongTile):
logger.error(f"玩家 {player} 碰牌的牌无效: {tile}") logger.error(f"tile 必须是 MahjongTile 类型,但收到的是: {type(tile)}")
return False return False
if self.state.hands[player].tile_count[tile] < 2: if self.state.hands[player].tile_count[tile] < 2:
logger.error(f"玩家 {player} 无法碰牌: {tile}") logger.error(f"玩家 {player} 无法碰牌: {tile}")
return False return False
# 更新状态 # 减少两张牌
self._update_meld(player, tile, "", count=2) self.state.hands[player].tile_count[tile] -= 2
self.state.melds[player].append(Meld(tile, "")) # 确保使用 MahjongTile 对象
logger.info(f"玩家 {player} 碰了牌: {tile}。当前明牌: {self.state.melds[player]}") logger.info(f"玩家 {player} 碰了牌: {tile}。当前明牌: {self.state.melds[player]}")
return True return True
@ -251,6 +253,8 @@ def handle_peng(self, player, tile):
def get_player_discard_choice(self, player): def get_player_discard_choice(self, player):
""" """
模拟获取玩家打牌的选择 模拟获取玩家打牌的选择

View File

@ -6,6 +6,7 @@ class MahjongTile:
raise ValueError("Invalid tile") raise ValueError("Invalid tile")
self.suit = suit self.suit = suit
self.value = value self.value = value
self.index = ({"": 0, "": 1, "": 2}[suit]) * 9 + (value - 1)
def __repr__(self): def __repr__(self):
return f"{self.value}{self.suit}" return f"{self.value}{self.suit}"

View File

@ -1,6 +1,8 @@
import gym import gym
from gym import spaces from gym import spaces
import numpy as np import numpy as np
from src.engine.actions import handle_peng, handle_gang, handle_win
from src.engine.chengdu_mahjong_engine import ChengduMahjongEngine from src.engine.chengdu_mahjong_engine import ChengduMahjongEngine
from loguru import logger from loguru import logger
@ -50,13 +52,13 @@ class ChengduMahjongEnv(gym.Env):
elif action == 14: # 碰 elif action == 14: # 碰
tile_to_peng = self._get_tile_for_special_action("peng") tile_to_peng = self._get_tile_for_special_action("peng")
if tile_to_peng: if tile_to_peng:
self.engine.handle_peng(current_player, tile_to_peng) handle_peng(self.engine,current_player, tile_to_peng)
else: else:
logger.warning("碰动作无效,未满足条件") logger.warning("碰动作无效,未满足条件")
elif action == 15: # 杠 elif action == 15: # 杠
tile_to_gang = self._get_tile_for_special_action("gang") tile_to_gang = self._get_tile_for_special_action("gang")
if tile_to_gang: if tile_to_gang:
self.engine.handle_gang(current_player, tile_to_gang, mode="an") # 默认暗杠 handle_gang(current_player, tile_to_gang, mode="an") # 默认暗杠
else: else:
logger.warning("杠动作无效,未满足条件") logger.warning("杠动作无效,未满足条件")
elif action == 16: # 胡 elif action == 16: # 胡
@ -65,7 +67,7 @@ class ChengduMahjongEnv(gym.Env):
self.engine.state.melds[current_player], self.engine.state.melds[current_player],
self.engine.state.missing_suits[current_player] self.engine.state.missing_suits[current_player]
): ):
self.engine.handle_win(current_player, None, None) handle_win(current_player, None, None)
else: else:
logger.warning("胡动作无效,未满足条件") logger.warning("胡动作无效,未满足条件")
else: else: