diff --git a/models/ppo_doudizhu_model.zip b/models/ppo_doudizhu_model.zip index 9390e81..469bf78 100644 Binary files a/models/ppo_doudizhu_model.zip and b/models/ppo_doudizhu_model.zip differ diff --git a/src/engine/dizhu/dizhu_engine.py b/src/engine/dizhu/dizhu_engine.py index a1dc12c..ec02f1a 100644 --- a/src/engine/dizhu/dizhu_engine.py +++ b/src/engine/dizhu/dizhu_engine.py @@ -2,7 +2,7 @@ import numpy as np from loguru import logger from src.engine.dizhu.player_state import PlayerState from src.engine.dizhu.deck import Deck -from src.engine.dizhu.utils import card_to_string +from src.engine.dizhu.utils import card_to_string, detect_card_type class DiZhuEngine: @@ -12,6 +12,7 @@ class DiZhuEngine: self.landlord_index = -1 # 地主索引 self.current_player_index = 0 # 当前玩家索引 self.landlord_cards = [] # 地主牌 + self.current_pile = None # 当前牌面上的牌 self.game_over = False # 是否游戏结束 def reset(self): @@ -39,39 +40,54 @@ class DiZhuEngine: logger.info(f"玩家 {i + 1} ({player.role}) 手牌: {player.get_hand_cards_as_strings()}") def get_current_player(self): - """ - 获取当前玩家对象 - """ + """获取当前玩家对象""" current_player = self.players[self.current_player_index] logger.info(f"当前玩家: 玩家 {self.current_player_index + 1} ({current_player.role})") logger.info(f"当前玩家手牌: {current_player.get_hand_cards_as_strings()}") return current_player def step(self, action): - """ - 执行动作并更新状态 - :param action: 当前玩家的动作(可以是 'pass' 或一个动作列表) - """ + """执行动作并更新状态""" current_player = self.get_current_player() if action == "pass": logger.info(f"玩家 {self.current_player_index + 1} 选择过牌") + self.pass_count += 1 current_player.history.append([]) + if self.pass_count == 2: # 所有玩家连续过牌 + self.current_pile = None + logger.info("所有玩家连续过牌,清空牌面") else: + # 确保动作是一个列表(可以是多张牌) if not isinstance(action, list): action = [action] + + # 检查出牌是否在手牌中 if not all(card in current_player.hand_cards for card in action): logger.error(f"玩家 {self.current_player_index + 1} 的动作非法: {action}") raise ValueError(f"玩家手牌不足以完成此次出牌: {action}") - # 移除出牌 + # 检查牌型是否合法 + if not detect_card_type(action): + logger.error(f"玩家 {self.current_player_index + 1} 出牌不合法: {action}") + raise ValueError(f"出牌牌型非法: {action}") + + # 检查是否能打过当前牌面 + if self.current_pile and not self._can_beat(self.current_pile, action): + logger.error(f"玩家 {self.current_player_index + 1} 出牌无法打过当前牌面: {action}") + raise ValueError(f"出牌无法打过当前牌面: {action}") + + # 出牌成功 + self.pass_count = 0 + self.current_pile = action # 更新牌面 + logger.info(f"玩家 {self.current_player_index + 1} 出牌: {[card_to_string(card) for card in action]}") + + # 从手牌中移除 for card in action: current_player.hand_cards.remove(card) current_player.history.append(action) - logger.info(f"玩家 {self.current_player_index + 1} 出牌: {[card_to_string(card) for card in action]}") - logger.info(f"玩家 {self.current_player_index + 1} 剩余手牌: {current_player.get_hand_cards_as_strings()}") - + # 检查游戏是否结束 if not current_player.hand_cards: self.game_over = True logger.info(f"游戏结束!玩家 {self.current_player_index + 1} ({current_player.role}) 获胜") @@ -81,6 +97,61 @@ class DiZhuEngine: self.current_player_index = (self.current_player_index + 1) % 3 logger.info(f"切换到玩家 {self.current_player_index + 1}") + def get_action_space(self): + """ + 动态生成当前动作空间。 + :return: 合法动作的列表 + """ + valid_actions = ["pass"] + current_player = self.get_current_player() + + # 遍历玩家手牌,生成所有可能的组合 + hand_cards = current_player.hand_cards + valid_actions.extend(self._generate_valid_combinations(hand_cards)) + + return valid_actions + + def _generate_valid_combinations(self, cards): + """ + 根据手牌生成所有合法牌型组合 + :param cards: 当前玩家的手牌 + :return: 合法牌型的列表 + """ + # 示例:生成单牌、对子和三张的合法组合 + from itertools import combinations + valid_combinations = [] + for i in range(1, len(cards) + 1): + for combo in combinations(cards, i): + if detect_card_type(list(combo)): # 检查是否为合法牌型 + valid_combinations.append(list(combo)) + return valid_combinations + + + def _can_beat(self, current_pile, action): + """ + 检查当前动作是否能打过牌面上的牌。 + :param current_pile: 当前牌面上的牌(列表) + :param action: 当前玩家要出的牌(列表) + :return: True 如果可以打过,否则 False + """ + current_type = detect_card_type(current_pile) + action_type = detect_card_type(action) + + if not current_type or not action_type: + return False # 非法牌型 + + # 火箭可以压任何牌 + if action_type == "火箭": + return True + # 炸弹可以压非炸弹的牌型 + if action_type == "炸弹" and current_type != "炸弹": + return True + # 同牌型比较大小 + if current_type == action_type: + return max(action) > max(current_pile) + + return False # 其他情况不合法 + def get_game_state(self): """ 返回当前游戏状态,包括玩家手牌、出牌历史和当前玩家。 diff --git a/src/engine/dizhu/scroing.py b/src/engine/dizhu/scroing.py new file mode 100644 index 0000000..6d4b0c8 --- /dev/null +++ b/src/engine/dizhu/scroing.py @@ -0,0 +1,32 @@ +class DouDiZhuScoring: + def __init__(self, base_score=1): + self.base_score = base_score # 底分 + self.multiplier = 1 # 倍数 + self.landlord_win = False # 地主是否胜利 + + def apply_event(self, event): + """ + 根据游戏事件调整倍数。 + :param event: 事件类型,如 "炸弹", "火箭", "春天", "反春天" + """ + if event in ["炸弹", "火箭", "春天", "反春天"]: + self.multiplier *= 2 + elif event == "抢地主": + self.multiplier += 1 + + def calculate_score(self, landlord_win): + """ + 计算最终分数。 + :param landlord_win: 地主是否胜利 + :return: 地主分数,农民分数 + """ + self.landlord_win = landlord_win + if landlord_win: + landlord_score = 2 * self.base_score * self.multiplier + farmer_score = -self.base_score * self.multiplier + else: + landlord_score = -2 * self.base_score * self.multiplier + farmer_score = self.base_score * self.multiplier + + return landlord_score, farmer_score + diff --git a/src/engine/dizhu/utils.py b/src/engine/dizhu/utils.py index 32d4172..86231bc 100644 --- a/src/engine/dizhu/utils.py +++ b/src/engine/dizhu/utils.py @@ -20,3 +20,29 @@ def card_to_string(card_index): return "大王" else: raise ValueError(f"无效的牌索引: {card_index}") + + +def detect_card_type(cards): + """ + 检测牌型 + :param cards: 玩家出的牌(列表) + :return: 牌型字符串或 None(非法牌型) + """ + if len(cards) == 1: + return "单牌" + if len(cards) == 2 and cards[0] == cards[1]: + return "对子" + if len(cards) == 3 and cards[0] == cards[1] == cards[2]: + return "三张" + if len(cards) == 4 and cards[0] == cards[1] == cards[2] == cards[3]: + return "炸弹" + # 三带一 + if len(cards) == 4: + counts = {card: cards.count(card) for card in set(cards)} + if 3 in counts.values(): + return "三带一" + # 顺子 + if len(cards) >= 5 and all(cards[i] + 1 == cards[i + 1] for i in range(len(cards) - 1)): + return "顺子" + # TODO: 实现其他牌型判断(如连对、飞机等) + return None diff --git a/src/environment/dizhu_env.py b/src/environment/dizhu_env.py index dc7523c..a8326f0 100644 --- a/src/environment/dizhu_env.py +++ b/src/environment/dizhu_env.py @@ -21,21 +21,33 @@ class DouDiZhuEnv(gym.Env): def step(self, action): """执行动作并更新环境""" try: - # 根据动作索引解析出具体的出牌动作 + reward = 0 # 初始化奖励 if action == 0: + # 玩家选择过牌 self.engine.step("pass") + reward -= 0.5 # 对频繁过牌给予轻微惩罚 else: + # 玩家选择出牌 card_index = action - 1 # 动作索引 1-54 对应 54 张牌 - self.engine.step([card_index]) + previous_hand_count = len(self.engine.get_current_player().hand_cards) # 出牌前手牌数 + self.engine.step([card_index]) # 执行动作 + current_hand_count = len(self.engine.get_current_player().hand_cards) # 出牌后手牌数 - # 更新游戏状态 + # 奖励根据减少的手牌数量计算 + reward += (previous_hand_count - current_hand_count) * 1.0 + + # 检查游戏是否结束 done = self.engine.game_over - reward = 1 if done else 0 # 简单奖励:胜利得 1 分,其他情况得 0 + if done: + # 胜利时给予较大的奖励 + reward += 10 + return self._get_observation(), reward, done, {} except ValueError as e: - # 如果玩家执行了无效动作,给予惩罚 - return self._get_observation(), -1, False, {"error": str(e)} + # 对无效动作设置较大的负奖励 + reward -= 5 + return self._get_observation(), reward, False, {"error": str(e)} def _get_observation(self): """获取当前玩家的状态""" diff --git a/test_dizhu.py b/test_dizhu.py new file mode 100644 index 0000000..678ee94 --- /dev/null +++ b/test_dizhu.py @@ -0,0 +1,35 @@ +from stable_baselines3 import PPO +from src.environment.dizhu_env import DouDiZhuEnv # 导入斗地主环境 +from loguru import logger + + +def test_dizhu_model(): # 确保函数名以 test_ 开头 + # 创建斗地主环境 + env = DouDiZhuEnv() + + # 加载已训练的模型 + model_path = "./models/ppo_doudizhu_model.zip" # 确保路径正确 + logger.info(f"加载模型: {model_path}") + try: + model = PPO.load(model_path) + except Exception as e: + logger.error(f"加载模型失败: {e}") + return + + # 测试模型 + obs = env.reset() + done = False + total_reward = 0 + logger.info("开始测试斗地主模型...") + + max_steps = 1000 # 设置最大步数 + step_count = 0 + + while not done and step_count < max_steps: + action, _ = model.predict(obs, deterministic=True) + obs, reward, done, info = env.step(action) + total_reward += reward + step_count += 1 + logger.info(f"动作: {action}, 奖励: {reward}, 是否结束: {done}, 信息: {info}") + + logger.info(f"测试完成,总奖励: {total_reward}") diff --git a/tests/models/test_dizhu.py b/tests/models/test_dizhu.py deleted file mode 100644 index 6ebecb1..0000000 --- a/tests/models/test_dizhu.py +++ /dev/null @@ -1,33 +0,0 @@ -from stable_baselines3 import PPO -from src.environment.dizhu_env import DouDiZhuEnv # 导入斗地主环境 -from loguru import logger - - -def test_dizhu_model(): - # 创建斗地主环境 - env = DouDiZhuEnv() - - # 加载已训练的模型 - model_path = "../models/ppo_doudizhu_model.zip" - logger.info(f"加载模型: {model_path}") - model = PPO.load(model_path) - - # 测试模型 - obs = env.reset() - done = False - total_reward = 0 - logger.info("开始测试斗地主模型...") - - while not done: - action, _ = model.predict(obs, deterministic=True) # 使用训练好的模型预测动作 - obs, reward, done, info = env.step(action) # 执行动作 - total_reward += reward - - # 输出当前状态 - logger.info(f"动作: {action}, 奖励: {reward}, 是否结束: {done}, 信息: {info}") - - logger.info(f"测试完成,总奖励: {total_reward}") - - -if __name__ == "__main__": - test_dizhu_model()