Compare commits
7 Commits
e54a8c5f00
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
| 5e492e4d8f | |||
| f1836172d6 | |||
| adeb153c8a | |||
| 5336990f54 | |||
| deff3cb921 | |||
| 96f0fbdcd7 | |||
| 9aec69bb46 |
Binary file not shown.
@@ -23,7 +23,7 @@ def train_dizhu_model():
|
||||
|
||||
# 训练模型,设定总训练步数
|
||||
logger.info("开始训练斗地主模型...")
|
||||
model.learn(total_timesteps=100000) # 总训练步数
|
||||
model.learn(total_timesteps=10000000000000000) # 总训练步数
|
||||
logger.info("斗地主模型训练完成!")
|
||||
|
||||
# 保存训练后的模型
|
||||
|
||||
@@ -2,6 +2,8 @@ import numpy as np
|
||||
from loguru import logger
|
||||
from src.engine.dizhu.player_state import PlayerState
|
||||
from src.engine.dizhu.deck import Deck
|
||||
from src.engine.dizhu.utils import card_to_string, detect_card_type
|
||||
|
||||
|
||||
class DiZhuEngine:
|
||||
def __init__(self):
|
||||
@@ -10,6 +12,8 @@ class DiZhuEngine:
|
||||
self.landlord_index = -1 # 地主索引
|
||||
self.current_player_index = 0 # 当前玩家索引
|
||||
self.landlord_cards = [] # 地主牌
|
||||
self.last_player = None # 最后出牌的玩家索引
|
||||
self.current_pile = None # 当前牌面上的牌
|
||||
self.game_over = False # 是否游戏结束
|
||||
|
||||
def reset(self):
|
||||
@@ -32,16 +36,13 @@ class DiZhuEngine:
|
||||
|
||||
# 日志输出
|
||||
logger.info("游戏初始化完成")
|
||||
logger.info(f"地主牌: {self.landlord_cards}")
|
||||
logger.info(f"地主牌: {[card_to_string(card) for card in self.landlord_cards]}")
|
||||
for i, player in enumerate(self.players):
|
||||
logger.info(f"玩家 {i + 1} ({player.role}) 手牌: {player.hand_cards}")
|
||||
logger.info(f"玩家 {i + 1} ({player.role}) 手牌: {player.get_hand_cards_as_strings()}")
|
||||
|
||||
def get_current_player(self):
|
||||
"""
|
||||
获取当前玩家对象
|
||||
"""
|
||||
"""获取当前玩家对象"""
|
||||
current_player = self.players[self.current_player_index]
|
||||
logger.info(f"当前玩家: 玩家 {self.current_player_index + 1} ({current_player.role})")
|
||||
return current_player
|
||||
|
||||
def step(self, action):
|
||||
@@ -51,26 +52,35 @@ class DiZhuEngine:
|
||||
"""
|
||||
current_player = self.get_current_player()
|
||||
|
||||
logger.info(f"玩家 {self.current_player_index + 1} 的动作: {action}")
|
||||
|
||||
if action == "pass":
|
||||
# 过牌作为单独的历史记录
|
||||
current_player.history.append([])
|
||||
logger.info(f"玩家 {self.current_player_index + 1} 选择过牌")
|
||||
self.pass_count += 1
|
||||
|
||||
# 如果所有其他玩家都过牌,允许最后出牌玩家再次出牌
|
||||
if self.pass_count == 2: # 两名玩家连续过牌
|
||||
self.current_player_index = self.last_player
|
||||
self.pass_count = 0 # 重置过牌计数
|
||||
self.current_pile = None # 清空当前牌面
|
||||
else:
|
||||
# 确保动作始终为列表
|
||||
# 出牌逻辑
|
||||
if not isinstance(action, list):
|
||||
action = [action]
|
||||
# 检查动作合法性
|
||||
|
||||
if not all(card in current_player.hand_cards for card in action):
|
||||
logger.error(f"玩家 {self.current_player_index + 1} 的动作非法: {action}")
|
||||
raise ValueError(f"玩家手牌不足以完成此次出牌: {action}")
|
||||
# 移除出牌
|
||||
|
||||
if self.current_pile and not self._can_beat(self.current_pile, action):
|
||||
raise ValueError(f"出牌无法打过当前牌面: {action}")
|
||||
|
||||
# 出牌成功
|
||||
self.current_pile = action # 更新当前牌面
|
||||
self.pass_count = 0 # 出牌后重置过牌计数
|
||||
self.last_player = self.current_player_index # 更新最后出牌的玩家
|
||||
|
||||
|
||||
# 从手牌中移除
|
||||
for card in action:
|
||||
current_player.hand_cards.remove(card)
|
||||
# 记录动作
|
||||
current_player.history.append(action)
|
||||
logger.info(f"玩家 {self.current_player_index + 1} 出牌: {action}")
|
||||
|
||||
# 检查游戏是否结束
|
||||
if not current_player.hand_cards:
|
||||
@@ -80,7 +90,61 @@ class DiZhuEngine:
|
||||
|
||||
# 切换到下一个玩家
|
||||
self.current_player_index = (self.current_player_index + 1) % 3
|
||||
logger.info(f"切换到玩家 {self.current_player_index + 1}")
|
||||
|
||||
|
||||
def get_action_space(self):
|
||||
"""
|
||||
动态生成当前动作空间。
|
||||
:return: 合法动作的列表
|
||||
"""
|
||||
valid_actions = ["pass"]
|
||||
current_player = self.get_current_player()
|
||||
|
||||
# 遍历玩家手牌,生成所有可能的组合
|
||||
hand_cards = current_player.hand_cards
|
||||
valid_actions.extend(self._generate_valid_combinations(hand_cards))
|
||||
|
||||
return valid_actions
|
||||
|
||||
def _generate_valid_combinations(self, cards):
|
||||
"""
|
||||
根据手牌生成所有合法牌型组合
|
||||
:param cards: 当前玩家的手牌
|
||||
:return: 合法牌型的列表
|
||||
"""
|
||||
# 示例:生成单牌、对子和三张的合法组合
|
||||
from itertools import combinations
|
||||
valid_combinations = []
|
||||
for i in range(1, len(cards) + 1):
|
||||
for combo in combinations(cards, i):
|
||||
if detect_card_type(list(combo)): # 检查是否为合法牌型
|
||||
valid_combinations.append(list(combo))
|
||||
return valid_combinations
|
||||
|
||||
def _can_beat(self, current_pile, action):
|
||||
"""
|
||||
检查当前动作是否能打过牌面上的牌。
|
||||
:param current_pile: 当前牌面上的牌(列表)
|
||||
:param action: 当前玩家要出的牌(列表)
|
||||
:return: True 如果可以打过,否则 False
|
||||
"""
|
||||
current_type = detect_card_type(current_pile)
|
||||
action_type = detect_card_type(action)
|
||||
|
||||
if not current_type or not action_type:
|
||||
return False # 非法牌型
|
||||
|
||||
# 火箭可以压任何牌
|
||||
if action_type == "火箭":
|
||||
return True
|
||||
# 炸弹可以压非炸弹的牌型
|
||||
if action_type == "炸弹" and current_type != "炸弹":
|
||||
return True
|
||||
# 同牌型比较大小
|
||||
if current_type == action_type:
|
||||
return max(action) > max(current_pile)
|
||||
|
||||
return False # 其他情况不合法
|
||||
|
||||
def get_game_state(self):
|
||||
"""
|
||||
@@ -100,10 +164,32 @@ class DiZhuEngine:
|
||||
"game_over": self.game_over,
|
||||
}
|
||||
logger.info("当前游戏状态: ")
|
||||
logger.info(f"地主牌: {self.landlord_cards}")
|
||||
for i, player in enumerate(self.players):
|
||||
logger.info(f"玩家 {i + 1} ({player.role}) 手牌: {player.hand_cards}")
|
||||
logger.info(f"玩家 {i + 1} 出牌历史: {player.history}")
|
||||
logger.info(f"当前玩家索引: {self.current_player_index}")
|
||||
logger.info(f"游戏是否结束: {self.game_over}")
|
||||
return state
|
||||
|
||||
def is_valid_play(self, cards):
|
||||
"""
|
||||
检查给定的牌是否为合法的斗地主牌型。
|
||||
:param cards: 玩家出的牌(列表)
|
||||
:return: True 如果是合法牌型,否则 False
|
||||
"""
|
||||
if len(cards) == 1:
|
||||
return True # 单牌
|
||||
if len(cards) == 2 and cards[0] == cards[1]:
|
||||
return True # 对子
|
||||
if len(cards) == 3 and cards[0] == cards[1] == cards[2]:
|
||||
return True # 三张
|
||||
if len(cards) == 4:
|
||||
counts = {card: cards.count(card) for card in set(cards)}
|
||||
if 3 in counts.values():
|
||||
return True # 三带一
|
||||
if all(count == 4 for count in counts.values()):
|
||||
return True # 炸弹
|
||||
if len(cards) >= 5:
|
||||
# 顺子
|
||||
sorted_cards = sorted(cards)
|
||||
if all(sorted_cards[i] + 1 == sorted_cards[i + 1] for i in range(len(sorted_cards) - 1)):
|
||||
return True
|
||||
# TODO: 扩展支持更多牌型(如连对、飞机等)
|
||||
return False
|
||||
|
||||
|
||||
@@ -1,7 +1,22 @@
|
||||
from collections import deque
|
||||
from src.engine.dizhu.utils import card_to_string
|
||||
|
||||
class PlayerState:
|
||||
def __init__(self, hand_cards, role):
|
||||
self.hand_cards = hand_cards # 玩家手牌
|
||||
self.role = role # "地主" 或 "农民"
|
||||
self.history = deque() # 出牌历史,使用 deque
|
||||
|
||||
def get_hand_cards_as_strings(self):
|
||||
"""
|
||||
获取玩家手牌的具体牌型字符串
|
||||
:return: 手牌字符串列表
|
||||
"""
|
||||
return [card_to_string(card) for card in self.hand_cards]
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
返回玩家的字符串表示,包括手牌和角色
|
||||
"""
|
||||
hand_cards_str = ", ".join(self.get_hand_cards_as_strings())
|
||||
return f"玩家角色: {self.role}, 手牌: [{hand_cards_str}]"
|
||||
|
||||
32
src/engine/dizhu/scroing.py
Normal file
32
src/engine/dizhu/scroing.py
Normal file
@@ -0,0 +1,32 @@
|
||||
class DouDiZhuScoring:
|
||||
def __init__(self, base_score=1):
|
||||
self.base_score = base_score # 底分
|
||||
self.multiplier = 1 # 倍数
|
||||
self.landlord_win = False # 地主是否胜利
|
||||
|
||||
def apply_event(self, event):
|
||||
"""
|
||||
根据游戏事件调整倍数。
|
||||
:param event: 事件类型,如 "炸弹", "火箭", "春天", "反春天"
|
||||
"""
|
||||
if event in ["炸弹", "火箭", "春天", "反春天"]:
|
||||
self.multiplier *= 2
|
||||
elif event == "抢地主":
|
||||
self.multiplier += 1
|
||||
|
||||
def calculate_score(self, landlord_win):
|
||||
"""
|
||||
计算最终分数。
|
||||
:param landlord_win: 地主是否胜利
|
||||
:return: 地主分数,农民分数
|
||||
"""
|
||||
self.landlord_win = landlord_win
|
||||
if landlord_win:
|
||||
landlord_score = 2 * self.base_score * self.multiplier
|
||||
farmer_score = -self.base_score * self.multiplier
|
||||
else:
|
||||
landlord_score = -2 * self.base_score * self.multiplier
|
||||
farmer_score = self.base_score * self.multiplier
|
||||
|
||||
return landlord_score, farmer_score
|
||||
|
||||
48
src/engine/dizhu/utils.py
Normal file
48
src/engine/dizhu/utils.py
Normal file
@@ -0,0 +1,48 @@
|
||||
|
||||
|
||||
def card_to_string(card_index):
|
||||
"""
|
||||
将牌的索引转换为具体牌型的字符串表示
|
||||
:param card_index: 牌的索引(0-53)
|
||||
:return: 具体牌型字符串
|
||||
"""
|
||||
suits = ['♠️', '♥️', '♦️', '♣️'] # 花色
|
||||
values = ['3', '4', '5', '6', '7', '8', '9', '10', 'J', 'Q', 'K', 'A', '2']
|
||||
|
||||
if card_index < 52:
|
||||
# 普通牌:计算花色和牌面值
|
||||
value = values[card_index // 4]
|
||||
suit = suits[card_index % 4]
|
||||
return f"{suit}{value}"
|
||||
elif card_index == 52:
|
||||
return "小王"
|
||||
elif card_index == 53:
|
||||
return "大王"
|
||||
else:
|
||||
raise ValueError(f"无效的牌索引: {card_index}")
|
||||
|
||||
|
||||
def detect_card_type(cards):
|
||||
"""
|
||||
检测牌型
|
||||
:param cards: 玩家出的牌(列表)
|
||||
:return: 牌型字符串或 None(非法牌型)
|
||||
"""
|
||||
if len(cards) == 1:
|
||||
return "单牌"
|
||||
if len(cards) == 2 and cards[0] == cards[1]:
|
||||
return "对子"
|
||||
if len(cards) == 3 and cards[0] == cards[1] == cards[2]:
|
||||
return "三张"
|
||||
if len(cards) == 4 and cards[0] == cards[1] == cards[2] == cards[3]:
|
||||
return "炸弹"
|
||||
# 三带一
|
||||
if len(cards) == 4:
|
||||
counts = {card: cards.count(card) for card in set(cards)}
|
||||
if 3 in counts.values():
|
||||
return "三带一"
|
||||
# 顺子
|
||||
if len(cards) >= 5 and all(cards[i] + 1 == cards[i + 1] for i in range(len(cards) - 1)):
|
||||
return "顺子"
|
||||
# TODO: 实现其他牌型判断(如连对、飞机等)
|
||||
return None
|
||||
@@ -60,7 +60,7 @@ class ChengduMahjongEnv(gym.Env):
|
||||
# **执行动作**
|
||||
if action < max_hand_actions: # 打牌动作
|
||||
tile = hand[action]
|
||||
logger.info(f"玩家 {current_player} 选择打牌: {tile}")
|
||||
# logger.info(f"玩家 {current_player} 选择打牌: {tile}")
|
||||
self.engine.check_other_players(tile)
|
||||
elif action == max_hand_actions: # 碰
|
||||
tile_to_peng = self._get_tile_for_special_action("peng")
|
||||
|
||||
@@ -1,44 +1,70 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
from src.engine.dizhu.dizhu_engine import DiZhuEngine # 引入地主引擎
|
||||
from src.engine.dizhu.dizhu_engine import DiZhuEngine # 引入斗地主引擎
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class DouDiZhuEnv(gym.Env):
|
||||
def __init__(self):
|
||||
super(DouDiZhuEnv, self).__init__()
|
||||
self.engine = DiZhuEngine() # 初始化斗地主引擎
|
||||
self.action_space = spaces.Discrete(55) # 假设最大动作空间为 55,表示可能的出牌和过牌
|
||||
|
||||
# 定义初始动作空间和观察空间
|
||||
self.action_space = spaces.Discrete(55) # 假设最大动作空间为 55
|
||||
self.observation_space = spaces.Dict({
|
||||
"hand_cards": spaces.Box(low=0, high=1, shape=(54,), dtype=np.int32), # 玩家手牌(独热编码)
|
||||
"history": spaces.Box(low=0, high=1, shape=(54,), dtype=np.int32), # 出牌历史
|
||||
"current_pile": spaces.Box(low=0, high=1, shape=(54,), dtype=np.int32), # 当前牌面上的牌
|
||||
"current_player": spaces.Discrete(3), # 当前玩家索引
|
||||
})
|
||||
|
||||
def reset(self):
|
||||
"""重置游戏环境"""
|
||||
self.engine.reset()
|
||||
logger.info("斗地主环境已重置")
|
||||
return self._get_observation()
|
||||
|
||||
def step(self, action):
|
||||
"""执行动作并更新环境"""
|
||||
try:
|
||||
# 根据动作索引解析出具体的出牌动作
|
||||
if action == 0:
|
||||
self.engine.step("pass")
|
||||
else:
|
||||
card_index = action - 1 # 动作索引 1-54 对应 54 张牌
|
||||
self.engine.step([card_index])
|
||||
reward = 0 # 初始化奖励
|
||||
current_player = self.engine.get_current_player()
|
||||
|
||||
# 更新游戏状态
|
||||
|
||||
if action == 0: # 过牌
|
||||
self.engine.step("pass")
|
||||
reward -= 0.5 # 对频繁过牌给予轻微惩罚
|
||||
else:
|
||||
# 玩家选择出牌
|
||||
action_cards = self._decode_action(action) # 解码动作为具体的牌型
|
||||
|
||||
|
||||
# 出牌前的手牌数量
|
||||
previous_hand_count = len(current_player.hand_cards)
|
||||
|
||||
# 尝试执行动作
|
||||
self.engine.step(action_cards)
|
||||
|
||||
# 出牌后的手牌数量
|
||||
current_hand_count = len(current_player.hand_cards)
|
||||
|
||||
# 根据减少的手牌数量计算奖励
|
||||
reward += (previous_hand_count - current_hand_count) * 1.0
|
||||
|
||||
# 检查游戏是否结束
|
||||
done = self.engine.game_over
|
||||
reward = 1 if done else 0 # 简单奖励:胜利得 1 分,其他情况得 0
|
||||
if done:
|
||||
reward += 10 # 胜利时给予较大的奖励
|
||||
logger.info(f"游戏结束!胜利玩家: {self.engine.current_player_index + 1}")
|
||||
|
||||
return self._get_observation(), reward, done, {}
|
||||
|
||||
except ValueError as e:
|
||||
# 如果玩家执行了无效动作,给予惩罚
|
||||
return self._get_observation(), -1, False, {"error": str(e)}
|
||||
return self._get_observation(), -5, False, {"error": str(e)}
|
||||
|
||||
def _get_observation(self):
|
||||
"""获取当前玩家的状态"""
|
||||
"""获取当前玩家的观察空间"""
|
||||
current_player = self.engine.get_current_player()
|
||||
|
||||
# 手牌的独热编码
|
||||
@@ -54,12 +80,65 @@ class DouDiZhuEnv(gym.Env):
|
||||
history[card] = 1
|
||||
elif isinstance(play, int): # 如果是单个牌
|
||||
history[play] = 1
|
||||
else:
|
||||
raise ValueError(f"历史记录格式错误: {play}")
|
||||
|
||||
return {"hand_cards": hand_cards, "history": history}
|
||||
# 当前牌面的独热编码
|
||||
current_pile = np.zeros(54, dtype=np.int32)
|
||||
if self.engine.current_pile:
|
||||
for card in self.engine.current_pile:
|
||||
current_pile[card] = 1
|
||||
|
||||
return {
|
||||
"hand_cards": hand_cards,
|
||||
"history": history,
|
||||
"current_pile": current_pile,
|
||||
"current_player": self.engine.current_player_index,
|
||||
}
|
||||
|
||||
def _decode_action(self, action):
|
||||
"""
|
||||
解码动作为具体的牌型。
|
||||
:param action: 动作索引
|
||||
:return: 解码后的牌型
|
||||
"""
|
||||
valid_actions = self.get_action_space()
|
||||
if action < len(valid_actions):
|
||||
return valid_actions[action]
|
||||
else:
|
||||
raise ValueError(f"非法动作索引: {action}")
|
||||
|
||||
def render(self, mode="human"):
|
||||
"""打印当前游戏状态"""
|
||||
state = self.engine.get_game_state()
|
||||
print(state)
|
||||
|
||||
def get_action_space(self):
|
||||
"""
|
||||
动态生成当前合法的动作空间。
|
||||
返回一个合法动作的列表,其中:
|
||||
- 索引 0 表示过牌。
|
||||
- 其余索引表示具体的出牌动作。
|
||||
"""
|
||||
current_player = self.engine.get_current_player()
|
||||
hand_cards = current_player.hand_cards
|
||||
|
||||
# 所有合法出牌组合
|
||||
valid_actions = [["pass"]] # 索引 0 表示过牌
|
||||
valid_actions.extend(self._generate_valid_combinations(hand_cards))
|
||||
return valid_actions
|
||||
|
||||
def _generate_valid_combinations(self, hand_cards):
|
||||
"""
|
||||
根据手牌生成所有合法牌型组合。
|
||||
:param hand_cards: 当前玩家的手牌
|
||||
:return: 所有合法的牌型组合
|
||||
"""
|
||||
from itertools import combinations
|
||||
valid_combinations = []
|
||||
|
||||
# 示例:生成单牌、对子和三张的合法组合
|
||||
for i in range(1, len(hand_cards) + 1):
|
||||
for combo in combinations(hand_cards, i):
|
||||
if self.engine.is_valid_play(list(combo)): # 检查是否为合法牌型
|
||||
valid_combinations.append(list(combo))
|
||||
|
||||
return valid_combinations
|
||||
|
||||
35
test_dizhu.py
Normal file
35
test_dizhu.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from stable_baselines3 import PPO
|
||||
from src.environment.dizhu_env import DouDiZhuEnv # 导入斗地主环境
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def test_dizhu_model(): # 确保函数名以 test_ 开头
|
||||
# 创建斗地主环境
|
||||
env = DouDiZhuEnv()
|
||||
|
||||
# 加载已训练的模型
|
||||
model_path = "./models/ppo_doudizhu_model.zip" # 确保路径正确
|
||||
logger.info(f"加载模型: {model_path}")
|
||||
try:
|
||||
model = PPO.load(model_path)
|
||||
except Exception as e:
|
||||
logger.error(f"加载模型失败: {e}")
|
||||
return
|
||||
|
||||
# 测试模型
|
||||
obs = env.reset()
|
||||
done = False
|
||||
total_reward = 0
|
||||
logger.info("开始测试斗地主模型...")
|
||||
|
||||
max_steps = 1000 # 设置最大步数
|
||||
step_count = 0
|
||||
|
||||
while not done and step_count < max_steps:
|
||||
action, _ = model.predict(obs, deterministic=True)
|
||||
obs, reward, done, info = env.step(action)
|
||||
total_reward += reward
|
||||
step_count += 1
|
||||
logger.info(f"动作: {action}, 奖励: {reward}, 是否结束: {done}, 信息: {info}")
|
||||
|
||||
logger.info(f"测试完成,总奖励: {total_reward}")
|
||||
0
tests/mahjong/__init__.py
Normal file
0
tests/mahjong/__init__.py
Normal file
0
tests/models/__init__.py
Normal file
0
tests/models/__init__.py
Normal file
Reference in New Issue
Block a user