parent
27da1caf44
commit
c9defe78f1
|
|
@ -1,22 +1,28 @@
|
||||||
import gym
|
import gym
|
||||||
from stable_baselines3 import PPO
|
from stable_baselines3 import PPO
|
||||||
from src.environment.chengdu_mahjong_env import MahjongEnv
|
from src.environment.chengdu_mahjong_env import ChengduMahjongEnv
|
||||||
import torch
|
import torch
|
||||||
from configs.log_config import setup_logging
|
from configs.log_config import setup_logging
|
||||||
|
|
||||||
def train_model():
|
def train_model():
|
||||||
# 创建 MahjongEnv 环境实例
|
# 创建 ChengduMahjongEnv 环境实例
|
||||||
env = MahjongEnv()
|
env = ChengduMahjongEnv()
|
||||||
|
|
||||||
# 检查是否有可用的GPU
|
# 检查是否有可用的 GPU
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
print(f"使用设备: {device}")
|
print(f"使用设备: {device}")
|
||||||
|
|
||||||
# 使用 PPO 算法训练模型
|
# 使用 PPO 算法训练模型,切换到 MultiInputPolicy
|
||||||
model = PPO("MlpPolicy", env, verbose=1, tensorboard_log="../logs/ppo_mahjong_tensorboard/", device=device)
|
model = PPO(
|
||||||
|
"MultiInputPolicy", # 更改为 MultiInputPolicy
|
||||||
|
env,
|
||||||
|
verbose=1,
|
||||||
|
tensorboard_log="../logs/ppo_mahjong_tensorboard/",
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
|
||||||
# 训练模型,训练总步数为100000
|
# 训练模型,训练总步数为 100000
|
||||||
model.learn(total_timesteps=100)
|
model.learn(total_timesteps=100000)
|
||||||
|
|
||||||
# 保存训练后的模型
|
# 保存训练后的模型
|
||||||
model.save("../models/ppo_mahjong_model")
|
model.save("../models/ppo_mahjong_model")
|
||||||
|
|
@ -27,7 +33,7 @@ def train_model():
|
||||||
while not done:
|
while not done:
|
||||||
action, _states = model.predict(obs) # 使用训练好的模型来选择动作
|
action, _states = model.predict(obs) # 使用训练好的模型来选择动作
|
||||||
obs, reward, done, info = env.step(action) # 执行动作
|
obs, reward, done, info = env.step(action) # 执行动作
|
||||||
env.render() # 打印环境状态
|
print(f"动作: {action}, 奖励: {reward}, 是否结束: {done}, 信息: {info}")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 调用配置函数来设置日志
|
# 调用配置函数来设置日志
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ from loguru import logger
|
||||||
|
|
||||||
from src.engine.calculate_fan import calculate_fan
|
from src.engine.calculate_fan import calculate_fan
|
||||||
from src.engine.mahjong_tile import MahjongTile
|
from src.engine.mahjong_tile import MahjongTile
|
||||||
|
from src.engine.meld import Meld
|
||||||
|
|
||||||
|
|
||||||
def draw_tile(engine):
|
def draw_tile(engine):
|
||||||
|
|
@ -235,15 +236,16 @@ def handle_peng(self, player, tile):
|
||||||
处理玩家碰牌逻辑并更新出牌顺序。
|
处理玩家碰牌逻辑并更新出牌顺序。
|
||||||
"""
|
"""
|
||||||
if not isinstance(tile, MahjongTile):
|
if not isinstance(tile, MahjongTile):
|
||||||
logger.error(f"玩家 {player} 碰牌的牌无效: {tile}")
|
logger.error(f"tile 必须是 MahjongTile 类型,但收到的是: {type(tile)}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if self.state.hands[player].tile_count[tile] < 2:
|
if self.state.hands[player].tile_count[tile] < 2:
|
||||||
logger.error(f"玩家 {player} 无法碰牌: {tile}")
|
logger.error(f"玩家 {player} 无法碰牌: {tile}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 更新状态
|
# 减少两张牌
|
||||||
self._update_meld(player, tile, "碰", count=2)
|
self.state.hands[player].tile_count[tile] -= 2
|
||||||
|
self.state.melds[player].append(Meld(tile, "碰")) # 确保使用 MahjongTile 对象
|
||||||
|
|
||||||
logger.info(f"玩家 {player} 碰了牌: {tile}。当前明牌: {self.state.melds[player]}")
|
logger.info(f"玩家 {player} 碰了牌: {tile}。当前明牌: {self.state.melds[player]}")
|
||||||
return True
|
return True
|
||||||
|
|
@ -251,6 +253,8 @@ def handle_peng(self, player, tile):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_player_discard_choice(self, player):
|
def get_player_discard_choice(self, player):
|
||||||
"""
|
"""
|
||||||
模拟获取玩家打牌的选择。
|
模拟获取玩家打牌的选择。
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ class MahjongTile:
|
||||||
raise ValueError("Invalid tile")
|
raise ValueError("Invalid tile")
|
||||||
self.suit = suit
|
self.suit = suit
|
||||||
self.value = value
|
self.value = value
|
||||||
|
self.index = ({"条": 0, "筒": 1, "万": 2}[suit]) * 9 + (value - 1)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"{self.value}{self.suit}"
|
return f"{self.value}{self.suit}"
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
import gym
|
import gym
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from src.engine.actions import handle_peng, handle_gang, handle_win
|
||||||
from src.engine.chengdu_mahjong_engine import ChengduMahjongEngine
|
from src.engine.chengdu_mahjong_engine import ChengduMahjongEngine
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
|
@ -50,13 +52,13 @@ class ChengduMahjongEnv(gym.Env):
|
||||||
elif action == 14: # 碰
|
elif action == 14: # 碰
|
||||||
tile_to_peng = self._get_tile_for_special_action("peng")
|
tile_to_peng = self._get_tile_for_special_action("peng")
|
||||||
if tile_to_peng:
|
if tile_to_peng:
|
||||||
self.engine.handle_peng(current_player, tile_to_peng)
|
handle_peng(self.engine,current_player, tile_to_peng)
|
||||||
else:
|
else:
|
||||||
logger.warning("碰动作无效,未满足条件")
|
logger.warning("碰动作无效,未满足条件")
|
||||||
elif action == 15: # 杠
|
elif action == 15: # 杠
|
||||||
tile_to_gang = self._get_tile_for_special_action("gang")
|
tile_to_gang = self._get_tile_for_special_action("gang")
|
||||||
if tile_to_gang:
|
if tile_to_gang:
|
||||||
self.engine.handle_gang(current_player, tile_to_gang, mode="an") # 默认暗杠
|
handle_gang(current_player, tile_to_gang, mode="an") # 默认暗杠
|
||||||
else:
|
else:
|
||||||
logger.warning("杠动作无效,未满足条件")
|
logger.warning("杠动作无效,未满足条件")
|
||||||
elif action == 16: # 胡
|
elif action == 16: # 胡
|
||||||
|
|
@ -65,7 +67,7 @@ class ChengduMahjongEnv(gym.Env):
|
||||||
self.engine.state.melds[current_player],
|
self.engine.state.melds[current_player],
|
||||||
self.engine.state.missing_suits[current_player]
|
self.engine.state.missing_suits[current_player]
|
||||||
):
|
):
|
||||||
self.engine.handle_win(current_player, None, None)
|
handle_win(current_player, None, None)
|
||||||
else:
|
else:
|
||||||
logger.warning("胡动作无效,未满足条件")
|
logger.warning("胡动作无效,未满足条件")
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue