# mjAi/src/environment/chengdu_majiang_env.py
#
# 63 lines
# 2.2 KiB
# Python

import gym
from gym import spaces
import numpy as np
from src.engine.chengdu_mahjong_state import ChengduMahjongState
class ChengduMahjongEnv(gym.Env):
    """Gym environment that drives a ``ChengduMahjongState`` game engine.

    Actions are the five integers of ``spaces.Discrete(5)`` (see the
    ``ACTION_*`` constants below). Each observation is a dict holding three
    length-108 count vectors (hand, melds, discards) for the player whose
    turn it is, plus a ``"dealer"`` entry.
    """

    # Action encoding shared by ``action_space`` and ``step``.
    ACTION_DISCARD = 0  # discard a tile
    ACTION_PENG = 1     # claim a pong (peng)
    ACTION_KONG = 2     # declare a kong
    ACTION_WIN = 3      # declare a win (hu)
    ACTION_PASS = 4     # pass
    NUM_ACTIONS = 5

    # Length of every per-tile count vector in the observation.
    NUM_TILES = 108

    def __init__(self):
        """Create the game state, declare the spaces and start a fresh game."""
        super().__init__()
        self.state = ChengduMahjongState()
        self.action_space = spaces.Discrete(self.NUM_ACTIONS)
        self.observation_space = spaces.Dict({
            # Tile counts in the hand.
            "hand": spaces.Box(
                low=0, high=4, shape=(self.NUM_TILES,), dtype=np.int32
            ),
            # Tile counts in the exposed melds.
            "melds": spaces.Box(
                low=0, high=4, shape=(self.NUM_TILES,), dtype=np.int32
            ),
            # Tile counts in the discard pile.
            "discard_pile": spaces.Box(
                low=0, high=4, shape=(self.NUM_TILES,), dtype=np.int32
            ),
            # Current dealer.
            "dealer": spaces.Discrete(4),
        })
        self.reset()

    def reset(self):
        """Reset the game state and return the initial observation."""
        self.state.reset()
        return self._get_observation()

    def step(self, action):
        """Apply one action and return ``(observation, reward, done, info)``.

        Only the win action produces a reward and a ``done`` flag of its own
        (both taken from ``state.win()``); every other action yields reward 0.
        ``done`` is additionally set when ``state.is_game_over()`` is true.
        ``info`` is always an empty dict.

        Raises:
            ValueError: if ``action`` is not one of the five action codes.
        """
        reward = 0
        done = False

        if action == self.ACTION_DISCARD:
            self.state.discard()
        elif action == self.ACTION_PENG:
            self.state.peng()
        elif action == self.ACTION_KONG:
            self.state.kong()
        elif action == self.ACTION_WIN:
            reward, done = self.state.win()
        elif action == self.ACTION_PASS:
            self.state.pass_turn()
        else:
            # An unknown code used to fall through as a silent no-op, which
            # hides bugs in the agent or in action wrappers.
            raise ValueError(
                f"invalid action {action!r}; "
                f"expected an integer in [0, {self.NUM_ACTIONS})"
            )

        # Check whether the game has ended.
        done = done or self.state.is_game_over()
        return self._get_observation(), reward, done, {}

    def _get_observation(self):
        """Build the observation dict for the player whose turn it is."""
        player_index = self.state.current_player
        hand = np.zeros(self.NUM_TILES, dtype=np.int32)
        melds = np.zeros(self.NUM_TILES, dtype=np.int32)
        discard_pile = np.zeros(self.NUM_TILES, dtype=np.int32)

        # Fill in the hand, meld and discard information.
        for tile, count in self.state.hands[player_index].tile_count.items():
            hand[tile.index] = count
        for meld in self.state.melds[player_index]:
            melds[meld.tile.index] += meld.count
        for tile in self.state.discards[player_index]:
            discard_pile[tile.index] += 1

        return {
            "hand": hand,
            "melds": melds,
            "discard_pile": discard_pile,
            # NOTE(review): the observation space documents this entry as the
            # dealer, but it is filled with the current player - confirm
            # which of the two is intended.
            "dealer": self.state.current_player,
        }