63 lines
2.2 KiB
Python
63 lines
2.2 KiB
Python
import gym
|
|
from gym import spaces
|
|
import numpy as np
|
|
from src.engine.chengdu_mahjong_state import ChengduMahjongState
|
|
|
|
class ChengduMahjongEnv(gym.Env):
    """Gym environment that drives a ``ChengduMahjongState`` game engine.

    The action space is five discrete codes:
    0 = discard, 1 = pong, 2 = kong, 3 = win, 4 = pass.

    Each observation is a dict of three per-tile count vectors (hand, melds,
    discard pile) for the player whose turn it is, plus that player's index
    under the ``"dealer"`` key.
    """

    # Length of every per-tile count vector. Each ``tile.index`` produced by
    # the engine must be smaller than this.
    # NOTE(review): presumably 108 because a Chengdu Mahjong set has 108
    # tiles -- confirm against the engine's tile indexing.
    _TILE_SLOTS = 108

    def __init__(self):
        """Create the game state, declare the spaces and start a first game."""
        super().__init__()
        self.state = ChengduMahjongState()

        # 0: discard, 1: pong, 2: kong, 3: win, 4: pass
        self.action_space = spaces.Discrete(5)

        slots = (self._TILE_SLOTS,)
        self.observation_space = spaces.Dict({
            # Tile counts in the hand.
            "hand": spaces.Box(low=0, high=4, shape=slots, dtype=np.int32),
            # Tile counts in exposed melds.
            "melds": spaces.Box(low=0, high=4, shape=slots, dtype=np.int32),
            # Tile counts in the discard pile.
            "discard_pile": spaces.Box(low=0, high=4, shape=slots, dtype=np.int32),
            # Current dealer (see the review note in _get_observation).
            "dealer": spaces.Discrete(4),
        })

        self.reset()

    def reset(self):
        """Reset the game state and return the first observation."""
        self.state.reset()  # re-initialise the underlying game state
        return self._get_observation()

    def step(self, action):
        """Apply one action to the game state.

        Args:
            action: One of the five action codes (0 = discard, 1 = pong,
                2 = kong, 3 = win, 4 = pass).

        Returns:
            A ``(observation, reward, done, info)`` tuple. ``reward`` is 0
            unless the action is "win", in which case it is whatever
            ``state.win()`` reports. ``done`` is true when ``state.win()``
            reports it or ``state.is_game_over()`` is true. ``info`` is
            always an empty dict.

        Raises:
            ValueError: If ``action`` is not one of the five known codes.
        """
        reward = 0
        done = False

        if action == 0:  # discard
            self.state.discard()
        elif action == 1:  # pong
            self.state.peng()
        elif action == 2:  # kong
            self.state.kong()
        elif action == 3:  # win
            reward, done = self.state.win()
        elif action == 4:  # pass
            self.state.pass_turn()
        else:
            # Fail loudly rather than returning a no-op transition: an
            # unknown code almost always means a bug in the caller. Nothing
            # has been mutated yet at this point.
            raise ValueError(
                f"Invalid action {action!r}; expected an integer in [0, 4]"
            )

        # The game may also end for reasons other than a win.
        done = done or self.state.is_game_over()
        return self._get_observation(), reward, done, {}

    def _get_observation(self):
        """Build the observation dict for the player whose turn it is."""
        player_index = self.state.current_player

        hand = np.zeros(self._TILE_SLOTS, dtype=np.int32)
        melds = np.zeros(self._TILE_SLOTS, dtype=np.int32)
        discard_pile = np.zeros(self._TILE_SLOTS, dtype=np.int32)

        # Fill in the hand, meld and discard counts, indexed by tile.index.
        for tile, count in self.state.hands[player_index].tile_count.items():
            hand[tile.index] = count
        for meld in self.state.melds[player_index]:
            melds[meld.tile.index] += meld.count
        for tile in self.state.discards[player_index]:
            discard_pile[tile.index] += 1

        return {
            "hand": hand,
            "melds": melds,
            "discard_pile": discard_pile,
            # NOTE(review): the key says "dealer" but the value is the
            # current player's index -- confirm whether the engine exposes a
            # separate dealer that should be reported here instead.
            "dealer": player_index,
        }