parent
c9defe78f1
commit
0864295a6e
Binary file not shown.
|
|
@ -3,6 +3,7 @@ from stable_baselines3 import PPO
|
||||||
from src.environment.chengdu_mahjong_env import ChengduMahjongEnv
|
from src.environment.chengdu_mahjong_env import ChengduMahjongEnv
|
||||||
import torch
|
import torch
|
||||||
from configs.log_config import setup_logging
|
from configs.log_config import setup_logging
|
||||||
|
from loguru import logger # 添加 logger
|
||||||
|
|
||||||
def train_model():
|
def train_model():
|
||||||
# 创建 MahjongEnv 环境实例
|
# 创建 MahjongEnv 环境实例
|
||||||
|
|
@ -10,7 +11,7 @@ def train_model():
|
||||||
|
|
||||||
# 检查是否有可用的 GPU
|
# 检查是否有可用的 GPU
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
print(f"使用设备: {device}")
|
logger.info(f"使用设备: {device}") # 替换 print 为 logger.info
|
||||||
|
|
||||||
# 使用 PPO 算法训练模型,切换到 MultiInputPolicy
|
# 使用 PPO 算法训练模型,切换到 MultiInputPolicy
|
||||||
model = PPO(
|
model = PPO(
|
||||||
|
|
@ -22,18 +23,22 @@ def train_model():
|
||||||
)
|
)
|
||||||
|
|
||||||
# 训练模型,训练总步数为 100000
|
# 训练模型,训练总步数为 100000
|
||||||
model.learn(total_timesteps=100000)
|
logger.info("开始训练模型...")
|
||||||
|
model.learn(total_timesteps=100)
|
||||||
|
logger.info("模型训练完成!")
|
||||||
|
|
||||||
# 保存训练后的模型
|
# 保存训练后的模型
|
||||||
model.save("../models/ppo_mahjong_model")
|
model.save("../models/ppo_mahjong_model")
|
||||||
|
logger.info("模型已保存到 '../models/ppo_mahjong_model'")
|
||||||
|
|
||||||
# 测试模型
|
# 测试模型
|
||||||
|
logger.info("开始测试模型...")
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
done = False
|
done = False
|
||||||
while not done:
|
while not done:
|
||||||
action, _states = model.predict(obs) # 使用训练好的模型来选择动作
|
action, _states = model.predict(obs) # 使用训练好的模型来选择动作
|
||||||
obs, reward, done, info = env.step(action) # 执行动作
|
obs, reward, done, info = env.step(action) # 执行动作
|
||||||
print(f"动作: {action}, 奖励: {reward}, 是否结束: {done}, 信息: {info}")
|
logger.info(f"动作: {action}, 奖励: {reward}, 是否结束: {done}, 信息: {info}") # 替换 print 为 logger.info
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 调用配置函数来设置日志
|
# 调用配置函数来设置日志
|
||||||
|
|
|
||||||
|
|
@ -48,57 +48,27 @@ class ChengduMahjongEngine:
|
||||||
def play_turn(self):
|
def play_turn(self):
|
||||||
current_player = self.state.current_player
|
current_player = self.state.current_player
|
||||||
|
|
||||||
if self.state.draw_counts[current_player] == 0:
|
# 玩家摸牌逻辑
|
||||||
# 判断庄家是否天胡
|
draw_tile(self)
|
||||||
tianhu = self.state.can_win(
|
|
||||||
self.state.hands[current_player],
|
|
||||||
self.state.melds[current_player],
|
|
||||||
self.state.missing_suits[current_player],
|
|
||||||
)
|
|
||||||
|
|
||||||
if tianhu:
|
# 玩家选择一张牌打出
|
||||||
# 天胡结算
|
tile = random_choice(self.state.hands[current_player], self.state.missing_suits[current_player])
|
||||||
self.state.winners.append(current_player)
|
logger.info(f"玩家 {current_player} 选择打牌: {tile}")
|
||||||
self.state.print_game_state(current_player)
|
|
||||||
for i in range(4):
|
|
||||||
if i != current_player:
|
|
||||||
self.state.scores[i] -= self.state.bottom_score * 2 ** 5
|
|
||||||
self.state.scores[current_player] += self.state.bottom_score * 2 ** 5
|
|
||||||
else:
|
|
||||||
# 判断是否可以杠
|
|
||||||
if self.state.hands[current_player].can_gang():
|
|
||||||
# 获取杠牌类型
|
|
||||||
gang_type = self.state.hands[current_player].get_gang_type()
|
|
||||||
|
|
||||||
# AI 决策是否杠牌
|
# 检查其他玩家是否可以对该牌进行操作
|
||||||
if should_gang(current_player, self.state, gang_type):
|
actions_taken = self.check_other_players(tile)
|
||||||
self.state.print_game_state(current_player)
|
|
||||||
|
|
||||||
# 计算杠牌得分
|
if not actions_taken:
|
||||||
if gang_type == "暗杠":
|
# 将牌加入弃牌堆
|
||||||
# 暗杠分数结算
|
self.state.discards[current_player].append(tile)
|
||||||
for i in range(4):
|
logger.info(f"玩家 {current_player} 打出的牌 {tile} 没有触发其他玩家的操作")
|
||||||
if i != current_player:
|
|
||||||
self.state.scores[i] -= self.state.bottom_score * 2 ** 2 # 扣分
|
|
||||||
self.state.scores[current_player] += self.state.bottom_score * 2 ** 2 * 3 # 加分
|
|
||||||
elif gang_type == "明杠":
|
|
||||||
# 明杠分数结算
|
|
||||||
for i in range(4):
|
|
||||||
if i != current_player:
|
|
||||||
self.state.scores[i] -= self.state.bottom_score * 2 ** 1 # 扣分
|
|
||||||
self.state.scores[current_player] += self.state.bottom_score * 2 ** 1 * 3 # 加分
|
|
||||||
|
|
||||||
logger.info(f"玩家 {current_player} 杠牌,类型: {gang_type}")
|
# 切换到下一位玩家
|
||||||
else:
|
self.state.current_player = (current_player + 1) % 4
|
||||||
# 当前玩家出牌
|
logger.info(f"轮到玩家 {self.state.current_player} 出牌")
|
||||||
tile = random_choice(self.state.hands[current_player], self.state.missing_suits[current_player])
|
|
||||||
logger.info(f"玩家 {current_player} 打出牌: {tile}")
|
|
||||||
|
|
||||||
# 检查其他玩家的反应
|
# 检查游戏结束条件
|
||||||
if not check_other_players(self,tile):
|
self.check_game_over()
|
||||||
# 没有触发其他玩家操作,移动到下一个玩家
|
|
||||||
self.state.current_player = (current_player + 1) % 4
|
|
||||||
draw_tile(self)
|
|
||||||
|
|
||||||
def check_game_over(self):
|
def check_game_over(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -41,33 +41,35 @@ class ChengduMahjongEnv(gym.Env):
|
||||||
:return: obs, reward, done, info
|
:return: obs, reward, done, info
|
||||||
"""
|
"""
|
||||||
current_player = self.engine.state.current_player
|
current_player = self.engine.state.current_player
|
||||||
|
hand = self.engine.state.hands[current_player].tiles # 当前玩家手牌
|
||||||
|
|
||||||
# **1. 检查动作是否合法并执行**
|
# **1. 执行动作并检查合法性**
|
||||||
if action < 14: # 打牌动作
|
if action < len(hand): # 打牌动作
|
||||||
if action >= len(self.engine.state.hands[current_player].tiles):
|
tile = hand[action]
|
||||||
raise ValueError(f"动作 {action} 超出手牌范围")
|
|
||||||
tile = self.engine.state.hands[current_player].tiles[action]
|
|
||||||
logger.info(f"玩家 {current_player} 选择打牌: {tile}")
|
logger.info(f"玩家 {current_player} 选择打牌: {tile}")
|
||||||
self.engine.check_other_players(tile)
|
self.engine.check_other_players(tile)
|
||||||
elif action == 14: # 碰
|
elif action == 14: # 碰
|
||||||
tile_to_peng = self._get_tile_for_special_action("peng")
|
tile_to_peng = self._get_tile_for_special_action("peng")
|
||||||
if tile_to_peng:
|
if tile_to_peng:
|
||||||
handle_peng(self.engine,current_player, tile_to_peng)
|
handle_peng(self.engine, current_player, tile_to_peng)
|
||||||
|
logger.info(f"玩家 {current_player} 碰了牌: {tile_to_peng}")
|
||||||
else:
|
else:
|
||||||
logger.warning("碰动作无效,未满足条件")
|
logger.warning("碰动作无效,未满足条件")
|
||||||
elif action == 15: # 杠
|
elif action == 15: # 杠
|
||||||
tile_to_gang = self._get_tile_for_special_action("gang")
|
tile_to_gang = self._get_tile_for_special_action("gang")
|
||||||
if tile_to_gang:
|
if tile_to_gang:
|
||||||
handle_gang(current_player, tile_to_gang, mode="an") # 默认暗杠
|
handle_gang(self.engine, current_player, tile_to_gang, mode="an")
|
||||||
|
logger.info(f"玩家 {current_player} 杠了牌: {tile_to_gang}")
|
||||||
else:
|
else:
|
||||||
logger.warning("杠动作无效,未满足条件")
|
logger.warning("杠动作无效,未满足条件")
|
||||||
elif action == 16: # 胡
|
elif action == 16: # 胡
|
||||||
if self.engine.state.can_win(
|
if self.engine.state.can_win(
|
||||||
self.engine.state.hands[current_player],
|
self.engine.state.hands[current_player],
|
||||||
self.engine.state.melds[current_player],
|
self.engine.state.melds[current_player],
|
||||||
self.engine.state.missing_suits[current_player]
|
self.engine.state.missing_suits[current_player]
|
||||||
):
|
):
|
||||||
handle_win(current_player, None, None)
|
handle_win(self.engine, current_player, None, None)
|
||||||
|
logger.info(f"玩家 {current_player} 胡牌!")
|
||||||
else:
|
else:
|
||||||
logger.warning("胡动作无效,未满足条件")
|
logger.warning("胡动作无效,未满足条件")
|
||||||
else:
|
else:
|
||||||
|
|
@ -90,6 +92,8 @@ class ChengduMahjongEnv(gym.Env):
|
||||||
}
|
}
|
||||||
return obs, reward, done, info
|
return obs, reward, done, info
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _get_observation(self):
|
def _get_observation(self):
|
||||||
"""
|
"""
|
||||||
提取当前玩家的观察空间
|
提取当前玩家的观察空间
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue