wsy182 2024-12-01 22:14:23 +08:00
parent 5eef2384cf
commit 3e65e02704
29 changed files with 87 additions and 39 deletions

View File

@ -1,6 +1,5 @@
import gym
from stable_baselines3 import PPO
from src.environment.chengdu_mahjong_env import ChengduMahjongEnv
from src import ChengduMahjongEnv
import torch
from configs.log_config import setup_logging
from loguru import logger # 添加 logger

View File

10
src/engine/dizhu/deck.py Normal file
View File

@ -0,0 +1,10 @@
import numpy as np
class Deck:
def __init__(self):
self.cards = [i for i in range(54)] # 0-53 表示54张牌
np.random.shuffle(self.cards)
def deal(self):
# 返回三位玩家的手牌和地主牌
return self.cards[:17], self.cards[17:34], self.cards[34:51], self.cards[51:]

View File

View File

@ -0,0 +1,5 @@
class PlayerState:
def __init__(self, hand_cards, role):
self.hand_cards = hand_cards # 玩家手牌
self.role = role # "地主" 或 "农民"
self.history = [] # 出牌历史

View File

View File

@ -2,9 +2,9 @@ import random as random_module
from loguru import logger
from src.engine.calculate_fan import calculate_fan
from src.engine.mahjong_tile import MahjongTile
from src.engine.meld import Meld
from src.engine.mahjong.calculate_fan import calculate_fan
from src.engine.mahjong.mahjong_tile import MahjongTile
from src.engine.mahjong.meld import Meld
def draw_tile(engine):

View File

@ -1,4 +1,4 @@
from src.engine.fan_type import is_terminal_fan,is_cleared,is_full_request,is_seven_pairs,is_basic_win,is_dragon_seven_pairs
from src.engine.mahjong.fan_type import is_terminal_fan,is_cleared,is_full_request,is_seven_pairs,is_basic_win,is_dragon_seven_pairs
from loguru import logger
def calculate_fan(hand, melds, is_self_draw, winning_tile, conditions):

View File

@ -3,9 +3,9 @@ import random
from loguru import logger
from configs.log_config import setup_logging
from src.engine.actions import draw_tile, should_gang, random_choice, handle_win, handle_gang, handle_peng
from src.engine.actions import set_missing_suit, check_other_players
from src.engine.chengdu_mahjong_state import ChengduMahjongState
from src.engine.mahjong.actions import draw_tile, random_choice, handle_win, handle_gang, handle_peng
from src.engine.mahjong.actions import set_missing_suit
from src.engine.mahjong.chengdu_mahjong_state import ChengduMahjongState
class ChengduMahjongEngine:

View File

@ -1,7 +1,7 @@
from collections import Counter
from src.engine.hand import Hand
from src.engine.mahjong_tile import MahjongTile
from src.engine.meld import Meld
from src.engine.mahjong.hand import Hand
from src.engine.mahjong.mahjong_tile import MahjongTile
from src.engine.mahjong.meld import Meld
from loguru import logger

View File

@ -1,4 +1,4 @@
from src.engine.utils import try_win,is_terminal_tile
from src.engine.mahjong.utils import try_win
from collections import Counter
def is_basic_win(hand):

View File

@ -1,4 +1,4 @@
from src.engine.mahjong_tile import MahjongTile
from src.engine.mahjong.mahjong_tile import MahjongTile
from collections import defaultdict
class Hand:

View File

@ -1,4 +1,4 @@
from src.engine.mahjong_tile import MahjongTile
from src.engine.mahjong.mahjong_tile import MahjongTile
class Meld:
def __init__(self, tile, type: str):

View File

View File

@ -2,8 +2,8 @@ import gym
from gym import spaces
import numpy as np
from src.engine.actions import handle_peng, handle_gang, handle_win
from src.engine.chengdu_mahjong_engine import ChengduMahjongEngine
from src import handle_peng, handle_gang, handle_win
from src import ChengduMahjongEngine
from loguru import logger

View File

@ -0,0 +1,35 @@
from gym import spaces
from src.engine.dizhu.player_state import PlayerState
from src.engine.dizhu.deck import Deck
class DouDiZhuEnv:
def __init__(self):
self.deck = Deck()
self.players = [] # 初始化玩家
self.landlord = None
self.current_player_index = 0
self.action_space = spaces.Discrete(54) # 动作空间,出一张牌或“过牌”
self.observation_space = spaces.Box(low=0, high=1, shape=(54,)) # 牌局状态表示
def reset(self):
p1_hand, p2_hand, p3_hand, landlord_cards = self.deck.deal()
self.players = [
PlayerState(p1_hand, "农民"),
PlayerState(p2_hand, "农民"),
PlayerState(p3_hand, "地主"),
]
self.landlord = self.players[2]
self.current_player_index = 0
return self._get_observation()
def _get_observation(self):
# 返回当前玩家的状态,具体实现根据模型需求定制
return {
"hand": self.players[self.current_player_index].hand_cards,
"history": self.players[self.current_player_index].history,
}
def step(self, action):
# 执行动作,更新状态
pass

View File

@ -1,8 +1,8 @@
from src.engine.chengdu_mahjong_state import ChengduMahjongState
from src.engine.hand import Hand
from src import ChengduMahjongState
from src import Hand
from src.engine.mahjong_tile import MahjongTile
from src.engine.meld import Meld
from src import MahjongTile
from src import Meld
hand = Hand()

View File

@ -1,8 +1,7 @@
import pytest
from src.engine.calculate_fan import calculate_fan, is_seven_pairs, is_cleared, is_big_pairs
from src import calculate_fan, is_seven_pairs, is_cleared, is_big_pairs
from src.engine.hand import Hand
from src.engine.mahjong_tile import MahjongTile
from src import Hand
from src import MahjongTile
# 测试用例

View File

@ -1,4 +1,4 @@
from src.engine.chengdu_mahjong_engine import ChengduMahjongEngine
from src import ChengduMahjongEngine
from loguru import logger
def test_mahjong_engine():

View File

@ -1,7 +1,7 @@
from src.engine.chengdu_mahjong_state import ChengduMahjongState
from src.engine.hand import Hand
from src.engine.mahjong_tile import MahjongTile
from src.engine.meld import Meld
from src import ChengduMahjongState
from src import Hand
from src import MahjongTile
from src import Meld
def test_set_missing_suit():

View File

@ -1,7 +1,7 @@
from src.engine.hand import Hand
from src.engine.mahjong_tile import MahjongTile
from src.engine.fan_type import is_basic_win,is_cleared,is_terminal_fan,is_seven_pairs,is_full_request,is_dragon_seven_pairs
from src.engine.meld import Meld
from src import Hand
from src import MahjongTile
from src import is_basic_win,is_cleared,is_terminal_fan,is_seven_pairs,is_full_request,is_dragon_seven_pairs
from src import Meld
def test_is_basic_win():
"""

View File

@ -1,5 +1,5 @@
from src.engine.hand import Hand
from src.engine.mahjong_tile import MahjongTile
from src import Hand
from src import MahjongTile
def test_add_tile():

View File

@ -1,4 +1,4 @@
from src.engine.mahjong_tile import MahjongTile
from src import MahjongTile
def test_mahjong_tile():
# 测试合法的牌

View File

@ -1,5 +1,5 @@
import pytest
from src.engine.scoring import calculate_score
from src import calculate_score
@pytest.mark.parametrize("fan, is_self_draw, base_score, expected_scores", [
# 测试用例 1: 自摸,总番数 3

View File

@ -1,4 +1,4 @@
from src.engine.utils import get_suit,get_tile_name
from src import get_suit,get_tile_name
def test_get_suit():
# 测试条花色0-35