parent
5eef2384cf
commit
3e65e02704
|
|
@ -1,6 +1,5 @@
|
||||||
import gym
|
|
||||||
from stable_baselines3 import PPO
|
from stable_baselines3 import PPO
|
||||||
from src.environment.chengdu_mahjong_env import ChengduMahjongEnv
|
from src import ChengduMahjongEnv
|
||||||
import torch
|
import torch
|
||||||
from configs.log_config import setup_logging
|
from configs.log_config import setup_logging
|
||||||
from loguru import logger # 添加 logger
|
from loguru import logger # 添加 logger
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class Deck:
|
||||||
|
def __init__(self):
|
||||||
|
self.cards = [i for i in range(54)] # 0-53 表示54张牌
|
||||||
|
np.random.shuffle(self.cards)
|
||||||
|
|
||||||
|
def deal(self):
|
||||||
|
# 返回三位玩家的手牌和地主牌
|
||||||
|
return self.cards[:17], self.cards[17:34], self.cards[34:51], self.cards[51:]
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
class PlayerState:
|
||||||
|
def __init__(self, hand_cards, role):
|
||||||
|
self.hand_cards = hand_cards # 玩家手牌
|
||||||
|
self.role = role # "地主" 或 "农民"
|
||||||
|
self.history = [] # 出牌历史
|
||||||
|
|
@ -2,9 +2,9 @@ import random as random_module
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from src.engine.calculate_fan import calculate_fan
|
from src.engine.mahjong.calculate_fan import calculate_fan
|
||||||
from src.engine.mahjong_tile import MahjongTile
|
from src.engine.mahjong.mahjong_tile import MahjongTile
|
||||||
from src.engine.meld import Meld
|
from src.engine.mahjong.meld import Meld
|
||||||
|
|
||||||
|
|
||||||
def draw_tile(engine):
|
def draw_tile(engine):
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from src.engine.fan_type import is_terminal_fan,is_cleared,is_full_request,is_seven_pairs,is_basic_win,is_dragon_seven_pairs
|
from src.engine.mahjong.fan_type import is_terminal_fan,is_cleared,is_full_request,is_seven_pairs,is_basic_win,is_dragon_seven_pairs
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
def calculate_fan(hand, melds, is_self_draw, winning_tile, conditions):
|
def calculate_fan(hand, melds, is_self_draw, winning_tile, conditions):
|
||||||
|
|
@ -3,9 +3,9 @@ import random
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from configs.log_config import setup_logging
|
from configs.log_config import setup_logging
|
||||||
from src.engine.actions import draw_tile, should_gang, random_choice, handle_win, handle_gang, handle_peng
|
from src.engine.mahjong.actions import draw_tile, random_choice, handle_win, handle_gang, handle_peng
|
||||||
from src.engine.actions import set_missing_suit, check_other_players
|
from src.engine.mahjong.actions import set_missing_suit
|
||||||
from src.engine.chengdu_mahjong_state import ChengduMahjongState
|
from src.engine.mahjong.chengdu_mahjong_state import ChengduMahjongState
|
||||||
|
|
||||||
|
|
||||||
class ChengduMahjongEngine:
|
class ChengduMahjongEngine:
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from src.engine.hand import Hand
|
from src.engine.mahjong.hand import Hand
|
||||||
from src.engine.mahjong_tile import MahjongTile
|
from src.engine.mahjong.mahjong_tile import MahjongTile
|
||||||
from src.engine.meld import Meld
|
from src.engine.mahjong.meld import Meld
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from src.engine.utils import try_win,is_terminal_tile
|
from src.engine.mahjong.utils import try_win
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
|
|
||||||
def is_basic_win(hand):
|
def is_basic_win(hand):
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from src.engine.mahjong_tile import MahjongTile
|
from src.engine.mahjong.mahjong_tile import MahjongTile
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
class Hand:
|
class Hand:
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from src.engine.mahjong_tile import MahjongTile
|
from src.engine.mahjong.mahjong_tile import MahjongTile
|
||||||
|
|
||||||
class Meld:
|
class Meld:
|
||||||
def __init__(self, tile, type: str):
|
def __init__(self, tile, type: str):
|
||||||
|
|
@ -2,8 +2,8 @@ import gym
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from src.engine.actions import handle_peng, handle_gang, handle_win
|
from src import handle_peng, handle_gang, handle_win
|
||||||
from src.engine.chengdu_mahjong_engine import ChengduMahjongEngine
|
from src import ChengduMahjongEngine
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,35 @@
|
||||||
|
from gym import spaces
|
||||||
|
|
||||||
|
from src.engine.dizhu.player_state import PlayerState
|
||||||
|
from src.engine.dizhu.deck import Deck
|
||||||
|
|
||||||
|
class DouDiZhuEnv:
|
||||||
|
def __init__(self):
|
||||||
|
self.deck = Deck()
|
||||||
|
self.players = [] # 初始化玩家
|
||||||
|
self.landlord = None
|
||||||
|
self.current_player_index = 0
|
||||||
|
self.action_space = spaces.Discrete(54) # 动作空间,出一张牌或“过牌”
|
||||||
|
self.observation_space = spaces.Box(low=0, high=1, shape=(54,)) # 牌局状态表示
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
p1_hand, p2_hand, p3_hand, landlord_cards = self.deck.deal()
|
||||||
|
self.players = [
|
||||||
|
PlayerState(p1_hand, "农民"),
|
||||||
|
PlayerState(p2_hand, "农民"),
|
||||||
|
PlayerState(p3_hand, "地主"),
|
||||||
|
]
|
||||||
|
self.landlord = self.players[2]
|
||||||
|
self.current_player_index = 0
|
||||||
|
return self._get_observation()
|
||||||
|
|
||||||
|
def _get_observation(self):
|
||||||
|
# 返回当前玩家的状态,具体实现根据模型需求定制
|
||||||
|
return {
|
||||||
|
"hand": self.players[self.current_player_index].hand_cards,
|
||||||
|
"history": self.players[self.current_player_index].history,
|
||||||
|
}
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
# 执行动作,更新状态
|
||||||
|
pass
|
||||||
8
test.py
8
test.py
|
|
@ -1,8 +1,8 @@
|
||||||
from src.engine.chengdu_mahjong_state import ChengduMahjongState
|
from src import ChengduMahjongState
|
||||||
from src.engine.hand import Hand
|
from src import Hand
|
||||||
|
|
||||||
from src.engine.mahjong_tile import MahjongTile
|
from src import MahjongTile
|
||||||
from src.engine.meld import Meld
|
from src import Meld
|
||||||
|
|
||||||
hand = Hand()
|
hand = Hand()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
import pytest
|
from src import calculate_fan, is_seven_pairs, is_cleared, is_big_pairs
|
||||||
from src.engine.calculate_fan import calculate_fan, is_seven_pairs, is_cleared, is_big_pairs
|
|
||||||
|
|
||||||
from src.engine.hand import Hand
|
from src import Hand
|
||||||
from src.engine.mahjong_tile import MahjongTile
|
from src import MahjongTile
|
||||||
|
|
||||||
# 测试用例
|
# 测试用例
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from src.engine.chengdu_mahjong_engine import ChengduMahjongEngine
|
from src import ChengduMahjongEngine
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
def test_mahjong_engine():
|
def test_mahjong_engine():
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from src.engine.chengdu_mahjong_state import ChengduMahjongState
|
from src import ChengduMahjongState
|
||||||
from src.engine.hand import Hand
|
from src import Hand
|
||||||
from src.engine.mahjong_tile import MahjongTile
|
from src import MahjongTile
|
||||||
from src.engine.meld import Meld
|
from src import Meld
|
||||||
|
|
||||||
|
|
||||||
def test_set_missing_suit():
|
def test_set_missing_suit():
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from src.engine.hand import Hand
|
from src import Hand
|
||||||
from src.engine.mahjong_tile import MahjongTile
|
from src import MahjongTile
|
||||||
from src.engine.fan_type import is_basic_win,is_cleared,is_terminal_fan,is_seven_pairs,is_full_request,is_dragon_seven_pairs
|
from src import is_basic_win,is_cleared,is_terminal_fan,is_seven_pairs,is_full_request,is_dragon_seven_pairs
|
||||||
from src.engine.meld import Meld
|
from src import Meld
|
||||||
|
|
||||||
def test_is_basic_win():
|
def test_is_basic_win():
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from src.engine.hand import Hand
|
from src import Hand
|
||||||
from src.engine.mahjong_tile import MahjongTile
|
from src import MahjongTile
|
||||||
|
|
||||||
|
|
||||||
def test_add_tile():
|
def test_add_tile():
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from src.engine.mahjong_tile import MahjongTile
|
from src import MahjongTile
|
||||||
|
|
||||||
def test_mahjong_tile():
|
def test_mahjong_tile():
|
||||||
# 测试合法的牌
|
# 测试合法的牌
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import pytest
|
import pytest
|
||||||
from src.engine.scoring import calculate_score
|
from src import calculate_score
|
||||||
|
|
||||||
@pytest.mark.parametrize("fan, is_self_draw, base_score, expected_scores", [
|
@pytest.mark.parametrize("fan, is_self_draw, base_score, expected_scores", [
|
||||||
# 测试用例 1: 自摸,总番数 3
|
# 测试用例 1: 自摸,总番数 3
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from src.engine.utils import get_suit,get_tile_name
|
from src import get_suit,get_tile_name
|
||||||
|
|
||||||
def test_get_suit():
|
def test_get_suit():
|
||||||
# 测试条花色(0-35)
|
# 测试条花色(0-35)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue