parent
5336990f54
commit
adeb153c8a
|
|
@ -12,6 +12,7 @@ class DiZhuEngine:
|
||||||
self.landlord_index = -1 # 地主索引
|
self.landlord_index = -1 # 地主索引
|
||||||
self.current_player_index = 0 # 当前玩家索引
|
self.current_player_index = 0 # 当前玩家索引
|
||||||
self.landlord_cards = [] # 地主牌
|
self.landlord_cards = [] # 地主牌
|
||||||
|
self.last_player = None # 最后出牌的玩家索引
|
||||||
self.current_pile = None # 当前牌面上的牌
|
self.current_pile = None # 当前牌面上的牌
|
||||||
self.game_over = False # 是否游戏结束
|
self.game_over = False # 是否游戏结束
|
||||||
|
|
||||||
|
|
@ -47,39 +48,40 @@ class DiZhuEngine:
|
||||||
return current_player
|
return current_player
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
"""执行动作并更新状态"""
|
"""
|
||||||
|
执行动作并更新状态
|
||||||
|
:param action: 当前玩家的动作(可以是 'pass' 或一个动作列表)
|
||||||
|
"""
|
||||||
current_player = self.get_current_player()
|
current_player = self.get_current_player()
|
||||||
|
|
||||||
if action == "pass":
|
if action == "pass":
|
||||||
logger.info(f"玩家 {self.current_player_index + 1} 选择过牌")
|
logger.info(f"玩家 {self.current_player_index + 1} 选择过牌")
|
||||||
self.pass_count += 1
|
self.pass_count += 1
|
||||||
current_player.history.append([])
|
|
||||||
if self.pass_count == 2: # 所有玩家连续过牌
|
# 如果所有其他玩家都过牌,允许最后出牌玩家再次出牌
|
||||||
self.current_pile = None
|
if self.pass_count == 2: # 两名玩家连续过牌
|
||||||
logger.info("所有玩家连续过牌,清空牌面")
|
self.current_player_index = self.last_player
|
||||||
|
self.pass_count = 0 # 重置过牌计数
|
||||||
|
self.current_pile = None # 清空当前牌面
|
||||||
|
logger.info(f"所有玩家过牌,玩家 {self.current_player_index + 1} 可以继续出牌")
|
||||||
else:
|
else:
|
||||||
# 确保动作是一个列表(可以是多张牌)
|
# 出牌逻辑
|
||||||
if not isinstance(action, list):
|
if not isinstance(action, list):
|
||||||
action = [action]
|
action = [action]
|
||||||
|
|
||||||
# 检查出牌是否在手牌中
|
|
||||||
if not all(card in current_player.hand_cards for card in action):
|
if not all(card in current_player.hand_cards for card in action):
|
||||||
logger.error(f"玩家 {self.current_player_index + 1} 的动作非法: {action}")
|
logger.error(f"玩家 {self.current_player_index + 1} 的动作非法: {action}")
|
||||||
raise ValueError(f"玩家手牌不足以完成此次出牌: {action}")
|
raise ValueError(f"玩家手牌不足以完成此次出牌: {action}")
|
||||||
|
|
||||||
# 检查牌型是否合法
|
|
||||||
if not detect_card_type(action):
|
|
||||||
logger.error(f"玩家 {self.current_player_index + 1} 出牌不合法: {action}")
|
|
||||||
raise ValueError(f"出牌牌型非法: {action}")
|
|
||||||
|
|
||||||
# 检查是否能打过当前牌面
|
|
||||||
if self.current_pile and not self._can_beat(self.current_pile, action):
|
if self.current_pile and not self._can_beat(self.current_pile, action):
|
||||||
logger.error(f"玩家 {self.current_player_index + 1} 出牌无法打过当前牌面: {action}")
|
logger.error(f"玩家 {self.current_player_index + 1} 出牌非法: {action}")
|
||||||
raise ValueError(f"出牌无法打过当前牌面: {action}")
|
raise ValueError(f"出牌无法打过当前牌面: {action}")
|
||||||
|
|
||||||
# 出牌成功
|
# 出牌成功
|
||||||
self.pass_count = 0
|
self.current_pile = action # 更新当前牌面
|
||||||
self.current_pile = action # 更新牌面
|
self.pass_count = 0 # 出牌后重置过牌计数
|
||||||
|
self.last_player = self.current_player_index # 更新最后出牌的玩家
|
||||||
|
|
||||||
logger.info(f"玩家 {self.current_player_index + 1} 出牌: {[card_to_string(card) for card in action]}")
|
logger.info(f"玩家 {self.current_player_index + 1} 出牌: {[card_to_string(card) for card in action]}")
|
||||||
|
|
||||||
# 从手牌中移除
|
# 从手牌中移除
|
||||||
|
|
@ -97,6 +99,7 @@ class DiZhuEngine:
|
||||||
self.current_player_index = (self.current_player_index + 1) % 3
|
self.current_player_index = (self.current_player_index + 1) % 3
|
||||||
logger.info(f"切换到玩家 {self.current_player_index + 1}")
|
logger.info(f"切换到玩家 {self.current_player_index + 1}")
|
||||||
|
|
||||||
|
|
||||||
def get_action_space(self):
|
def get_action_space(self):
|
||||||
"""
|
"""
|
||||||
动态生成当前动作空间。
|
动态生成当前动作空间。
|
||||||
|
|
@ -126,7 +129,6 @@ class DiZhuEngine:
|
||||||
valid_combinations.append(list(combo))
|
valid_combinations.append(list(combo))
|
||||||
return valid_combinations
|
return valid_combinations
|
||||||
|
|
||||||
|
|
||||||
def _can_beat(self, current_pile, action):
|
def _can_beat(self, current_pile, action):
|
||||||
"""
|
"""
|
||||||
检查当前动作是否能打过牌面上的牌。
|
检查当前动作是否能打过牌面上的牌。
|
||||||
|
|
@ -177,3 +179,30 @@ class DiZhuEngine:
|
||||||
logger.info(f"当前玩家索引: {self.current_player_index}")
|
logger.info(f"当前玩家索引: {self.current_player_index}")
|
||||||
logger.info(f"游戏是否结束: {self.game_over}")
|
logger.info(f"游戏是否结束: {self.game_over}")
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
def is_valid_play(self, cards):
|
||||||
|
"""
|
||||||
|
检查给定的牌是否为合法的斗地主牌型。
|
||||||
|
:param cards: 玩家出的牌(列表)
|
||||||
|
:return: True 如果是合法牌型,否则 False
|
||||||
|
"""
|
||||||
|
if len(cards) == 1:
|
||||||
|
return True # 单牌
|
||||||
|
if len(cards) == 2 and cards[0] == cards[1]:
|
||||||
|
return True # 对子
|
||||||
|
if len(cards) == 3 and cards[0] == cards[1] == cards[2]:
|
||||||
|
return True # 三张
|
||||||
|
if len(cards) == 4:
|
||||||
|
counts = {card: cards.count(card) for card in set(cards)}
|
||||||
|
if 3 in counts.values():
|
||||||
|
return True # 三带一
|
||||||
|
if all(count == 4 for count in counts.values()):
|
||||||
|
return True # 炸弹
|
||||||
|
if len(cards) >= 5:
|
||||||
|
# 顺子
|
||||||
|
sorted_cards = sorted(cards)
|
||||||
|
if all(sorted_cards[i] + 1 == sorted_cards[i + 1] for i in range(len(sorted_cards) - 1)):
|
||||||
|
return True
|
||||||
|
# TODO: 扩展支持更多牌型(如连对、飞机等)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,56 +1,73 @@
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
from src.engine.dizhu.dizhu_engine import DiZhuEngine # 引入地主引擎
|
from src.engine.dizhu.dizhu_engine import DiZhuEngine # 引入斗地主引擎
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
class DouDiZhuEnv(gym.Env):
|
class DouDiZhuEnv(gym.Env):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(DouDiZhuEnv, self).__init__()
|
super(DouDiZhuEnv, self).__init__()
|
||||||
self.engine = DiZhuEngine() # 初始化斗地主引擎
|
self.engine = DiZhuEngine() # 初始化斗地主引擎
|
||||||
self.action_space = spaces.Discrete(55) # 假设最大动作空间为 55,表示可能的出牌和过牌
|
|
||||||
|
# 定义初始动作空间和观察空间
|
||||||
|
self.action_space = spaces.Discrete(55) # 假设最大动作空间为 55
|
||||||
self.observation_space = spaces.Dict({
|
self.observation_space = spaces.Dict({
|
||||||
"hand_cards": spaces.Box(low=0, high=1, shape=(54,), dtype=np.int32), # 玩家手牌(独热编码)
|
"hand_cards": spaces.Box(low=0, high=1, shape=(54,), dtype=np.int32), # 玩家手牌(独热编码)
|
||||||
"history": spaces.Box(low=0, high=1, shape=(54,), dtype=np.int32), # 出牌历史
|
"history": spaces.Box(low=0, high=1, shape=(54,), dtype=np.int32), # 出牌历史
|
||||||
|
"current_pile": spaces.Box(low=0, high=1, shape=(54,), dtype=np.int32), # 当前牌面上的牌
|
||||||
|
"current_player": spaces.Discrete(3), # 当前玩家索引
|
||||||
})
|
})
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""重置游戏环境"""
|
"""重置游戏环境"""
|
||||||
self.engine.reset()
|
self.engine.reset()
|
||||||
|
logger.info("斗地主环境已重置")
|
||||||
return self._get_observation()
|
return self._get_observation()
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
"""执行动作并更新环境"""
|
"""执行动作并更新环境"""
|
||||||
try:
|
try:
|
||||||
reward = 0 # 初始化奖励
|
reward = 0 # 初始化奖励
|
||||||
if action == 0:
|
current_player = self.engine.get_current_player()
|
||||||
# 玩家选择过牌
|
logger.info(f"当前玩家: 玩家 {self.engine.current_player_index + 1} 手牌: {current_player.get_hand_cards_as_strings()}")
|
||||||
|
|
||||||
|
if action == 0: # 过牌
|
||||||
|
logger.info(f"玩家 {self.engine.current_player_index + 1} 选择过牌")
|
||||||
self.engine.step("pass")
|
self.engine.step("pass")
|
||||||
reward -= 0.5 # 对频繁过牌给予轻微惩罚
|
reward -= 0.5 # 对频繁过牌给予轻微惩罚
|
||||||
else:
|
else:
|
||||||
# 玩家选择出牌
|
# 玩家选择出牌
|
||||||
card_index = action - 1 # 动作索引 1-54 对应 54 张牌
|
action_cards = self._decode_action(action) # 解码动作为具体的牌型
|
||||||
previous_hand_count = len(self.engine.get_current_player().hand_cards) # 出牌前手牌数
|
logger.info(f"玩家 {self.engine.current_player_index + 1} 选择出牌: {action_cards}")
|
||||||
self.engine.step([card_index]) # 执行动作
|
|
||||||
current_hand_count = len(self.engine.get_current_player().hand_cards) # 出牌后手牌数
|
|
||||||
|
|
||||||
# 奖励根据减少的手牌数量计算
|
# 出牌前的手牌数量
|
||||||
|
previous_hand_count = len(current_player.hand_cards)
|
||||||
|
|
||||||
|
# 尝试执行动作
|
||||||
|
self.engine.step(action_cards)
|
||||||
|
|
||||||
|
# 出牌后的手牌数量
|
||||||
|
current_hand_count = len(current_player.hand_cards)
|
||||||
|
|
||||||
|
# 根据减少的手牌数量计算奖励
|
||||||
reward += (previous_hand_count - current_hand_count) * 1.0
|
reward += (previous_hand_count - current_hand_count) * 1.0
|
||||||
|
|
||||||
# 检查游戏是否结束
|
# 检查游戏是否结束
|
||||||
done = self.engine.game_over
|
done = self.engine.game_over
|
||||||
if done:
|
if done:
|
||||||
# 胜利时给予较大的奖励
|
reward += 10 # 胜利时给予较大的奖励
|
||||||
reward += 10
|
logger.info(f"游戏结束!胜利玩家: {self.engine.current_player_index + 1}")
|
||||||
|
|
||||||
return self._get_observation(), reward, done, {}
|
return self._get_observation(), reward, done, {}
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
# 对无效动作设置较大的负奖励
|
# 对无效动作设置较大的负奖励
|
||||||
reward -= 5
|
logger.error(f"无效动作: {e}")
|
||||||
return self._get_observation(), reward, False, {"error": str(e)}
|
return self._get_observation(), -5, False, {"error": str(e)}
|
||||||
|
|
||||||
def _get_observation(self):
|
def _get_observation(self):
|
||||||
"""获取当前玩家的状态"""
|
"""获取当前玩家的观察空间"""
|
||||||
current_player = self.engine.get_current_player()
|
current_player = self.engine.get_current_player()
|
||||||
|
|
||||||
# 手牌的独热编码
|
# 手牌的独热编码
|
||||||
|
|
@ -66,12 +83,65 @@ class DouDiZhuEnv(gym.Env):
|
||||||
history[card] = 1
|
history[card] = 1
|
||||||
elif isinstance(play, int): # 如果是单个牌
|
elif isinstance(play, int): # 如果是单个牌
|
||||||
history[play] = 1
|
history[play] = 1
|
||||||
else:
|
|
||||||
raise ValueError(f"历史记录格式错误: {play}")
|
|
||||||
|
|
||||||
return {"hand_cards": hand_cards, "history": history}
|
# 当前牌面的独热编码
|
||||||
|
current_pile = np.zeros(54, dtype=np.int32)
|
||||||
|
if self.engine.current_pile:
|
||||||
|
for card in self.engine.current_pile:
|
||||||
|
current_pile[card] = 1
|
||||||
|
|
||||||
|
return {
|
||||||
|
"hand_cards": hand_cards,
|
||||||
|
"history": history,
|
||||||
|
"current_pile": current_pile,
|
||||||
|
"current_player": self.engine.current_player_index,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _decode_action(self, action):
|
||||||
|
"""
|
||||||
|
解码动作为具体的牌型。
|
||||||
|
:param action: 动作索引
|
||||||
|
:return: 解码后的牌型
|
||||||
|
"""
|
||||||
|
valid_actions = self.get_action_space()
|
||||||
|
if action < len(valid_actions):
|
||||||
|
return valid_actions[action]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"非法动作索引: {action}")
|
||||||
|
|
||||||
def render(self, mode="human"):
|
def render(self, mode="human"):
|
||||||
"""打印当前游戏状态"""
|
"""打印当前游戏状态"""
|
||||||
state = self.engine.get_game_state()
|
state = self.engine.get_game_state()
|
||||||
print(state)
|
print(state)
|
||||||
|
|
||||||
|
def get_action_space(self):
|
||||||
|
"""
|
||||||
|
动态生成当前合法的动作空间。
|
||||||
|
返回一个合法动作的列表,其中:
|
||||||
|
- 索引 0 表示过牌。
|
||||||
|
- 其余索引表示具体的出牌动作。
|
||||||
|
"""
|
||||||
|
current_player = self.engine.get_current_player()
|
||||||
|
hand_cards = current_player.hand_cards
|
||||||
|
|
||||||
|
# 所有合法出牌组合
|
||||||
|
valid_actions = [["pass"]] # 索引 0 表示过牌
|
||||||
|
valid_actions.extend(self._generate_valid_combinations(hand_cards))
|
||||||
|
return valid_actions
|
||||||
|
|
||||||
|
def _generate_valid_combinations(self, hand_cards):
|
||||||
|
"""
|
||||||
|
根据手牌生成所有合法牌型组合。
|
||||||
|
:param hand_cards: 当前玩家的手牌
|
||||||
|
:return: 所有合法的牌型组合
|
||||||
|
"""
|
||||||
|
from itertools import combinations
|
||||||
|
valid_combinations = []
|
||||||
|
|
||||||
|
# 示例:生成单牌、对子和三张的合法组合
|
||||||
|
for i in range(1, len(hand_cards) + 1):
|
||||||
|
for combo in combinations(hand_cards, i):
|
||||||
|
if self.engine.is_valid_play(list(combo)): # 检查是否为合法牌型
|
||||||
|
valid_combinations.append(list(combo))
|
||||||
|
|
||||||
|
return valid_combinations
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue