Compare commits

...

2 Commits

Author SHA1 Message Date
7632edd0e3 Create chengdu_majiang_env.py 2024-11-30 17:48:04 +08:00
6c6fdff706 Update actions.py 2024-11-30 17:37:24 +08:00
2 changed files with 35 additions and 1 deletions

View File

@@ -1,5 +1,5 @@
from loguru import logger
from utils import get_tile_name # 确保 get_tile_name 已在 utils.py 中定义并导入
from utils import get_tile_name
def draw_tile(state):

View File

@@ -0,0 +1,34 @@
import gym
from gym import spaces
import numpy as np
from src.engine.chengdu_mahjong_engine import ChengduMahjongEngine
class MahjongEnv(gym.Env):
def __init__(self):
super(MahjongEnv, self).__init__()
self.engine = ChengduMahjongEngine()
self.action_space = spaces.Discrete(108) # 假设108种动作摸牌/打牌)
self.observation_space = spaces.Box(low=0, high=4, shape=(108,), dtype=np.int32)
def reset(self):
self.engine = ChengduMahjongEngine()
return self.engine.state.hands[self.engine.state.current_player]
def step(self, action):
reward = 0
done = False
try:
self.engine.discard_tile(action)
reward = self.calculate_reward() # 根据胡牌等状态定义奖励
except ValueError:
reward = -10 # 非法操作扣分
return self.engine.state.hands[self.engine.state.current_player], reward, done, {}
def calculate_reward(self):
if self.engine.state.can_win(self.engine.state.hands[self.engine.state.current_player]):
return 100 # 胡牌奖励
return -1 # 默认每步小惩罚
def render(self, mode="human"):
print("当前状态:", self.engine.state.hands)