Compare commits
2 Commits
96f0fbdcd7
...
5336990f54
| Author | SHA1 | Date |
|---|---|---|
|
|
5336990f54 | |
|
|
deff3cb921 |
Binary file not shown.
|
|
@ -2,7 +2,7 @@ import numpy as np
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from src.engine.dizhu.player_state import PlayerState
|
from src.engine.dizhu.player_state import PlayerState
|
||||||
from src.engine.dizhu.deck import Deck
|
from src.engine.dizhu.deck import Deck
|
||||||
from src.engine.dizhu.utils import card_to_string
|
from src.engine.dizhu.utils import card_to_string, detect_card_type
|
||||||
|
|
||||||
|
|
||||||
class DiZhuEngine:
|
class DiZhuEngine:
|
||||||
|
|
@ -12,6 +12,7 @@ class DiZhuEngine:
|
||||||
self.landlord_index = -1 # 地主索引
|
self.landlord_index = -1 # 地主索引
|
||||||
self.current_player_index = 0 # 当前玩家索引
|
self.current_player_index = 0 # 当前玩家索引
|
||||||
self.landlord_cards = [] # 地主牌
|
self.landlord_cards = [] # 地主牌
|
||||||
|
self.current_pile = None # 当前牌面上的牌
|
||||||
self.game_over = False # 是否游戏结束
|
self.game_over = False # 是否游戏结束
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
|
|
@ -39,39 +40,54 @@ class DiZhuEngine:
|
||||||
logger.info(f"玩家 {i + 1} ({player.role}) 手牌: {player.get_hand_cards_as_strings()}")
|
logger.info(f"玩家 {i + 1} ({player.role}) 手牌: {player.get_hand_cards_as_strings()}")
|
||||||
|
|
||||||
def get_current_player(self):
|
def get_current_player(self):
|
||||||
"""
|
"""获取当前玩家对象"""
|
||||||
获取当前玩家对象
|
|
||||||
"""
|
|
||||||
current_player = self.players[self.current_player_index]
|
current_player = self.players[self.current_player_index]
|
||||||
logger.info(f"当前玩家: 玩家 {self.current_player_index + 1} ({current_player.role})")
|
logger.info(f"当前玩家: 玩家 {self.current_player_index + 1} ({current_player.role})")
|
||||||
logger.info(f"当前玩家手牌: {current_player.get_hand_cards_as_strings()}")
|
logger.info(f"当前玩家手牌: {current_player.get_hand_cards_as_strings()}")
|
||||||
return current_player
|
return current_player
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
"""
|
"""执行动作并更新状态"""
|
||||||
执行动作并更新状态
|
|
||||||
:param action: 当前玩家的动作(可以是 'pass' 或一个动作列表)
|
|
||||||
"""
|
|
||||||
current_player = self.get_current_player()
|
current_player = self.get_current_player()
|
||||||
|
|
||||||
if action == "pass":
|
if action == "pass":
|
||||||
logger.info(f"玩家 {self.current_player_index + 1} 选择过牌")
|
logger.info(f"玩家 {self.current_player_index + 1} 选择过牌")
|
||||||
|
self.pass_count += 1
|
||||||
current_player.history.append([])
|
current_player.history.append([])
|
||||||
|
if self.pass_count == 2: # 所有玩家连续过牌
|
||||||
|
self.current_pile = None
|
||||||
|
logger.info("所有玩家连续过牌,清空牌面")
|
||||||
else:
|
else:
|
||||||
|
# 确保动作是一个列表(可以是多张牌)
|
||||||
if not isinstance(action, list):
|
if not isinstance(action, list):
|
||||||
action = [action]
|
action = [action]
|
||||||
|
|
||||||
|
# 检查出牌是否在手牌中
|
||||||
if not all(card in current_player.hand_cards for card in action):
|
if not all(card in current_player.hand_cards for card in action):
|
||||||
logger.error(f"玩家 {self.current_player_index + 1} 的动作非法: {action}")
|
logger.error(f"玩家 {self.current_player_index + 1} 的动作非法: {action}")
|
||||||
raise ValueError(f"玩家手牌不足以完成此次出牌: {action}")
|
raise ValueError(f"玩家手牌不足以完成此次出牌: {action}")
|
||||||
|
|
||||||
# 移除出牌
|
# 检查牌型是否合法
|
||||||
|
if not detect_card_type(action):
|
||||||
|
logger.error(f"玩家 {self.current_player_index + 1} 出牌不合法: {action}")
|
||||||
|
raise ValueError(f"出牌牌型非法: {action}")
|
||||||
|
|
||||||
|
# 检查是否能打过当前牌面
|
||||||
|
if self.current_pile and not self._can_beat(self.current_pile, action):
|
||||||
|
logger.error(f"玩家 {self.current_player_index + 1} 出牌无法打过当前牌面: {action}")
|
||||||
|
raise ValueError(f"出牌无法打过当前牌面: {action}")
|
||||||
|
|
||||||
|
# 出牌成功
|
||||||
|
self.pass_count = 0
|
||||||
|
self.current_pile = action # 更新牌面
|
||||||
|
logger.info(f"玩家 {self.current_player_index + 1} 出牌: {[card_to_string(card) for card in action]}")
|
||||||
|
|
||||||
|
# 从手牌中移除
|
||||||
for card in action:
|
for card in action:
|
||||||
current_player.hand_cards.remove(card)
|
current_player.hand_cards.remove(card)
|
||||||
current_player.history.append(action)
|
current_player.history.append(action)
|
||||||
|
|
||||||
logger.info(f"玩家 {self.current_player_index + 1} 出牌: {[card_to_string(card) for card in action]}")
|
# 检查游戏是否结束
|
||||||
logger.info(f"玩家 {self.current_player_index + 1} 剩余手牌: {current_player.get_hand_cards_as_strings()}")
|
|
||||||
|
|
||||||
if not current_player.hand_cards:
|
if not current_player.hand_cards:
|
||||||
self.game_over = True
|
self.game_over = True
|
||||||
logger.info(f"游戏结束!玩家 {self.current_player_index + 1} ({current_player.role}) 获胜")
|
logger.info(f"游戏结束!玩家 {self.current_player_index + 1} ({current_player.role}) 获胜")
|
||||||
|
|
@ -81,6 +97,61 @@ class DiZhuEngine:
|
||||||
self.current_player_index = (self.current_player_index + 1) % 3
|
self.current_player_index = (self.current_player_index + 1) % 3
|
||||||
logger.info(f"切换到玩家 {self.current_player_index + 1}")
|
logger.info(f"切换到玩家 {self.current_player_index + 1}")
|
||||||
|
|
||||||
|
def get_action_space(self):
|
||||||
|
"""
|
||||||
|
动态生成当前动作空间。
|
||||||
|
:return: 合法动作的列表
|
||||||
|
"""
|
||||||
|
valid_actions = ["pass"]
|
||||||
|
current_player = self.get_current_player()
|
||||||
|
|
||||||
|
# 遍历玩家手牌,生成所有可能的组合
|
||||||
|
hand_cards = current_player.hand_cards
|
||||||
|
valid_actions.extend(self._generate_valid_combinations(hand_cards))
|
||||||
|
|
||||||
|
return valid_actions
|
||||||
|
|
||||||
|
def _generate_valid_combinations(self, cards):
|
||||||
|
"""
|
||||||
|
根据手牌生成所有合法牌型组合
|
||||||
|
:param cards: 当前玩家的手牌
|
||||||
|
:return: 合法牌型的列表
|
||||||
|
"""
|
||||||
|
# 示例:生成单牌、对子和三张的合法组合
|
||||||
|
from itertools import combinations
|
||||||
|
valid_combinations = []
|
||||||
|
for i in range(1, len(cards) + 1):
|
||||||
|
for combo in combinations(cards, i):
|
||||||
|
if detect_card_type(list(combo)): # 检查是否为合法牌型
|
||||||
|
valid_combinations.append(list(combo))
|
||||||
|
return valid_combinations
|
||||||
|
|
||||||
|
|
||||||
|
def _can_beat(self, current_pile, action):
|
||||||
|
"""
|
||||||
|
检查当前动作是否能打过牌面上的牌。
|
||||||
|
:param current_pile: 当前牌面上的牌(列表)
|
||||||
|
:param action: 当前玩家要出的牌(列表)
|
||||||
|
:return: True 如果可以打过,否则 False
|
||||||
|
"""
|
||||||
|
current_type = detect_card_type(current_pile)
|
||||||
|
action_type = detect_card_type(action)
|
||||||
|
|
||||||
|
if not current_type or not action_type:
|
||||||
|
return False # 非法牌型
|
||||||
|
|
||||||
|
# 火箭可以压任何牌
|
||||||
|
if action_type == "火箭":
|
||||||
|
return True
|
||||||
|
# 炸弹可以压非炸弹的牌型
|
||||||
|
if action_type == "炸弹" and current_type != "炸弹":
|
||||||
|
return True
|
||||||
|
# 同牌型比较大小
|
||||||
|
if current_type == action_type:
|
||||||
|
return max(action) > max(current_pile)
|
||||||
|
|
||||||
|
return False # 其他情况不合法
|
||||||
|
|
||||||
def get_game_state(self):
|
def get_game_state(self):
|
||||||
"""
|
"""
|
||||||
返回当前游戏状态,包括玩家手牌、出牌历史和当前玩家。
|
返回当前游戏状态,包括玩家手牌、出牌历史和当前玩家。
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,32 @@
|
||||||
|
class DouDiZhuScoring:
|
||||||
|
def __init__(self, base_score=1):
|
||||||
|
self.base_score = base_score # 底分
|
||||||
|
self.multiplier = 1 # 倍数
|
||||||
|
self.landlord_win = False # 地主是否胜利
|
||||||
|
|
||||||
|
def apply_event(self, event):
|
||||||
|
"""
|
||||||
|
根据游戏事件调整倍数。
|
||||||
|
:param event: 事件类型,如 "炸弹", "火箭", "春天", "反春天"
|
||||||
|
"""
|
||||||
|
if event in ["炸弹", "火箭", "春天", "反春天"]:
|
||||||
|
self.multiplier *= 2
|
||||||
|
elif event == "抢地主":
|
||||||
|
self.multiplier += 1
|
||||||
|
|
||||||
|
def calculate_score(self, landlord_win):
|
||||||
|
"""
|
||||||
|
计算最终分数。
|
||||||
|
:param landlord_win: 地主是否胜利
|
||||||
|
:return: 地主分数,农民分数
|
||||||
|
"""
|
||||||
|
self.landlord_win = landlord_win
|
||||||
|
if landlord_win:
|
||||||
|
landlord_score = 2 * self.base_score * self.multiplier
|
||||||
|
farmer_score = -self.base_score * self.multiplier
|
||||||
|
else:
|
||||||
|
landlord_score = -2 * self.base_score * self.multiplier
|
||||||
|
farmer_score = self.base_score * self.multiplier
|
||||||
|
|
||||||
|
return landlord_score, farmer_score
|
||||||
|
|
||||||
|
|
@ -20,3 +20,29 @@ def card_to_string(card_index):
|
||||||
return "大王"
|
return "大王"
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"无效的牌索引: {card_index}")
|
raise ValueError(f"无效的牌索引: {card_index}")
|
||||||
|
|
||||||
|
|
||||||
|
def detect_card_type(cards):
|
||||||
|
"""
|
||||||
|
检测牌型
|
||||||
|
:param cards: 玩家出的牌(列表)
|
||||||
|
:return: 牌型字符串或 None(非法牌型)
|
||||||
|
"""
|
||||||
|
if len(cards) == 1:
|
||||||
|
return "单牌"
|
||||||
|
if len(cards) == 2 and cards[0] == cards[1]:
|
||||||
|
return "对子"
|
||||||
|
if len(cards) == 3 and cards[0] == cards[1] == cards[2]:
|
||||||
|
return "三张"
|
||||||
|
if len(cards) == 4 and cards[0] == cards[1] == cards[2] == cards[3]:
|
||||||
|
return "炸弹"
|
||||||
|
# 三带一
|
||||||
|
if len(cards) == 4:
|
||||||
|
counts = {card: cards.count(card) for card in set(cards)}
|
||||||
|
if 3 in counts.values():
|
||||||
|
return "三带一"
|
||||||
|
# 顺子
|
||||||
|
if len(cards) >= 5 and all(cards[i] + 1 == cards[i + 1] for i in range(len(cards) - 1)):
|
||||||
|
return "顺子"
|
||||||
|
# TODO: 实现其他牌型判断(如连对、飞机等)
|
||||||
|
return None
|
||||||
|
|
|
||||||
|
|
@ -21,21 +21,33 @@ class DouDiZhuEnv(gym.Env):
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
"""执行动作并更新环境"""
|
"""执行动作并更新环境"""
|
||||||
try:
|
try:
|
||||||
# 根据动作索引解析出具体的出牌动作
|
reward = 0 # 初始化奖励
|
||||||
if action == 0:
|
if action == 0:
|
||||||
|
# 玩家选择过牌
|
||||||
self.engine.step("pass")
|
self.engine.step("pass")
|
||||||
|
reward -= 0.5 # 对频繁过牌给予轻微惩罚
|
||||||
else:
|
else:
|
||||||
|
# 玩家选择出牌
|
||||||
card_index = action - 1 # 动作索引 1-54 对应 54 张牌
|
card_index = action - 1 # 动作索引 1-54 对应 54 张牌
|
||||||
self.engine.step([card_index])
|
previous_hand_count = len(self.engine.get_current_player().hand_cards) # 出牌前手牌数
|
||||||
|
self.engine.step([card_index]) # 执行动作
|
||||||
|
current_hand_count = len(self.engine.get_current_player().hand_cards) # 出牌后手牌数
|
||||||
|
|
||||||
# 更新游戏状态
|
# 奖励根据减少的手牌数量计算
|
||||||
|
reward += (previous_hand_count - current_hand_count) * 1.0
|
||||||
|
|
||||||
|
# 检查游戏是否结束
|
||||||
done = self.engine.game_over
|
done = self.engine.game_over
|
||||||
reward = 1 if done else 0 # 简单奖励:胜利得 1 分,其他情况得 0
|
if done:
|
||||||
|
# 胜利时给予较大的奖励
|
||||||
|
reward += 10
|
||||||
|
|
||||||
return self._get_observation(), reward, done, {}
|
return self._get_observation(), reward, done, {}
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
# 如果玩家执行了无效动作,给予惩罚
|
# 对无效动作设置较大的负奖励
|
||||||
return self._get_observation(), -1, False, {"error": str(e)}
|
reward -= 5
|
||||||
|
return self._get_observation(), reward, False, {"error": str(e)}
|
||||||
|
|
||||||
def _get_observation(self):
|
def _get_observation(self):
|
||||||
"""获取当前玩家的状态"""
|
"""获取当前玩家的状态"""
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,35 @@
|
||||||
|
from stable_baselines3 import PPO
|
||||||
|
from src.environment.dizhu_env import DouDiZhuEnv # 导入斗地主环境
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
def test_dizhu_model(): # 确保函数名以 test_ 开头
|
||||||
|
# 创建斗地主环境
|
||||||
|
env = DouDiZhuEnv()
|
||||||
|
|
||||||
|
# 加载已训练的模型
|
||||||
|
model_path = "./models/ppo_doudizhu_model.zip" # 确保路径正确
|
||||||
|
logger.info(f"加载模型: {model_path}")
|
||||||
|
try:
|
||||||
|
model = PPO.load(model_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加载模型失败: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 测试模型
|
||||||
|
obs = env.reset()
|
||||||
|
done = False
|
||||||
|
total_reward = 0
|
||||||
|
logger.info("开始测试斗地主模型...")
|
||||||
|
|
||||||
|
max_steps = 1000 # 设置最大步数
|
||||||
|
step_count = 0
|
||||||
|
|
||||||
|
while not done and step_count < max_steps:
|
||||||
|
action, _ = model.predict(obs, deterministic=True)
|
||||||
|
obs, reward, done, info = env.step(action)
|
||||||
|
total_reward += reward
|
||||||
|
step_count += 1
|
||||||
|
logger.info(f"动作: {action}, 奖励: {reward}, 是否结束: {done}, 信息: {info}")
|
||||||
|
|
||||||
|
logger.info(f"测试完成,总奖励: {total_reward}")
|
||||||
Loading…
Reference in New Issue