parent
a14984a263
commit
1d6593a8f2
|
|
@ -0,0 +1,46 @@
|
||||||
|
from stable_baselines3 import PPO
|
||||||
|
from src.environment.dizhu_env import DouDiZhuEnv # 导入斗地主环境
|
||||||
|
import torch
|
||||||
|
from configs.log_config import setup_logging
|
||||||
|
from loguru import logger # 添加 logger
|
||||||
|
|
||||||
|
def train_dizhu_model():
|
||||||
|
# 创建 DouDiZhuEnv 环境实例
|
||||||
|
env = DouDiZhuEnv()
|
||||||
|
|
||||||
|
# 检查是否有可用的 GPU
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
logger.info(f"使用设备: {device}") # 使用 logger 记录设备信息
|
||||||
|
|
||||||
|
# 使用 PPO 算法训练模型,设置为 MultiInputPolicy
|
||||||
|
model = PPO(
|
||||||
|
"MultiInputPolicy", # 适用于多输入的策略
|
||||||
|
env,
|
||||||
|
verbose=1,
|
||||||
|
tensorboard_log="../logs/ppo_doudizhu_tensorboard/", # TensorBoard 日志路径
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
# 训练模型,设定总训练步数
|
||||||
|
logger.info("开始训练斗地主模型...")
|
||||||
|
model.learn(total_timesteps=100000) # 总训练步数
|
||||||
|
logger.info("斗地主模型训练完成!")
|
||||||
|
|
||||||
|
# 保存训练后的模型
|
||||||
|
model_path = "../models/ppo_doudizhu_model"
|
||||||
|
model.save(model_path)
|
||||||
|
logger.info(f"模型已保存到 '{model_path}'")
|
||||||
|
|
||||||
|
# 测试模型
|
||||||
|
logger.info("开始测试斗地主模型...")
|
||||||
|
obs = env.reset()
|
||||||
|
done = False
|
||||||
|
while not done:
|
||||||
|
action, _states = model.predict(obs) # 使用训练好的模型来选择动作
|
||||||
|
obs, reward, done, info = env.step(action) # 执行动作
|
||||||
|
logger.info(f"动作: {action}, 奖励: {reward}, 是否结束: {done}, 信息: {info}") # 记录测试过程
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 设置日志
|
||||||
|
setup_logging()
|
||||||
|
train_dizhu_model()
|
||||||
|
|
@ -36,23 +36,28 @@ class DiZhuEngine:
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
"""
|
"""
|
||||||
执行玩家的出牌动作。
|
执行动作并更新状态
|
||||||
:param action: 当前玩家的动作(出牌或过牌)
|
:param action: 当前玩家的动作(可以是 'pass' 或一个动作列表)
|
||||||
"""
|
"""
|
||||||
current_player = self.get_current_player()
|
current_player = self.get_current_player()
|
||||||
|
|
||||||
if action == "pass":
|
if action == "pass":
|
||||||
# 玩家选择过牌
|
# 过牌作为单独的历史记录
|
||||||
current_player.history.append("pass")
|
current_player.history.append([])
|
||||||
else:
|
else:
|
||||||
# 玩家出牌,移除对应牌
|
# 确保动作始终为列表
|
||||||
|
if not isinstance(action, list):
|
||||||
|
action = [action]
|
||||||
|
# 检查动作合法性
|
||||||
if not all(card in current_player.hand_cards for card in action):
|
if not all(card in current_player.hand_cards for card in action):
|
||||||
raise ValueError("玩家手牌不足以完成此次出牌")
|
raise ValueError(f"玩家手牌不足以完成此次出牌: {action}")
|
||||||
|
# 移除出牌
|
||||||
for card in action:
|
for card in action:
|
||||||
current_player.hand_cards.remove(card)
|
current_player.hand_cards.remove(card)
|
||||||
|
# 记录动作
|
||||||
current_player.history.append(action)
|
current_player.history.append(action)
|
||||||
|
|
||||||
# 检查是否游戏结束
|
# 检查游戏是否结束
|
||||||
if not current_player.hand_cards:
|
if not current_player.hand_cards:
|
||||||
self.game_over = True
|
self.game_over = True
|
||||||
return f"{current_player.role} 胜利!"
|
return f"{current_player.role} 胜利!"
|
||||||
|
|
@ -60,6 +65,7 @@ class DiZhuEngine:
|
||||||
# 切换到下一个玩家
|
# 切换到下一个玩家
|
||||||
self.current_player_index = (self.current_player_index + 1) % 3
|
self.current_player_index = (self.current_player_index + 1) % 3
|
||||||
|
|
||||||
|
|
||||||
def get_game_state(self):
|
def get_game_state(self):
|
||||||
"""
|
"""
|
||||||
返回当前游戏状态,包括玩家手牌、出牌历史和当前玩家。
|
返回当前游戏状态,包括玩家手牌、出牌历史和当前玩家。
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
class PlayerState:
|
class PlayerState:
|
||||||
def __init__(self, hand_cards, role):
|
def __init__(self, hand_cards, role):
|
||||||
self.hand_cards = hand_cards # 玩家手牌
|
self.hand_cards = hand_cards # 玩家手牌
|
||||||
self.role = role # "地主" 或 "农民"
|
self.role = role # "地主" 或 "农民"
|
||||||
self.history = [] # 出牌历史
|
self.history = deque() # 出牌历史,使用 deque
|
||||||
|
|
|
||||||
|
|
@ -40,14 +40,22 @@ class DouDiZhuEnv(gym.Env):
|
||||||
def _get_observation(self):
|
def _get_observation(self):
|
||||||
"""获取当前玩家的状态"""
|
"""获取当前玩家的状态"""
|
||||||
current_player = self.engine.get_current_player()
|
current_player = self.engine.get_current_player()
|
||||||
|
|
||||||
|
# 手牌的独热编码
|
||||||
hand_cards = np.zeros(54, dtype=np.int32)
|
hand_cards = np.zeros(54, dtype=np.int32)
|
||||||
for card in current_player.hand_cards:
|
for card in current_player.hand_cards:
|
||||||
hand_cards[card] = 1
|
hand_cards[card] = 1
|
||||||
|
|
||||||
|
# 出牌历史的独热编码
|
||||||
history = np.zeros(54, dtype=np.int32)
|
history = np.zeros(54, dtype=np.int32)
|
||||||
for play in current_player.history:
|
for play in current_player.history:
|
||||||
|
if isinstance(play, list): # 如果是列表
|
||||||
for card in play:
|
for card in play:
|
||||||
history[card] = 1
|
history[card] = 1
|
||||||
|
elif isinstance(play, int): # 如果是单个牌
|
||||||
|
history[play] = 1
|
||||||
|
else:
|
||||||
|
raise ValueError(f"历史记录格式错误: {play}")
|
||||||
|
|
||||||
return {"hand_cards": hand_cards, "history": history}
|
return {"hand_cards": hand_cards, "history": history}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue