import base64
import hashlib
import hmac
import json
import os
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF

# ------------------------- Utils -------------------------

AES_BLOCK = 16  # AES block size in bytes (also the PKCS#7 padding modulus)
MAC_LEN = 8  # length of the truncated HMAC-SHA256 tag carried in each frame


def b64e(b: bytes) -> str:
    """Encode *b* as standard Base64 text."""
    return base64.b64encode(b).decode()


def b64d(s: str) -> bytes:
    """Decode standard Base64 text *s* back into bytes."""
    return base64.b64decode(s.encode())


def hkdf(ikm: bytes, info: bytes, length: int = 32) -> bytes:
    """Derive *length* bytes from *ikm* using HKDF-SHA256 with no salt and context *info*."""
    return HKDF(algorithm=hashes.SHA256(), length=length, salt=None, info=info).derive(ikm)


def aes_cbc_enc(key: bytes, iv: bytes, pt: bytes) -> bytes:
    """PKCS#7-pad *pt* and encrypt it with AES-CBC under *key* / *iv*.

    One encryptor context is used for both ``update`` and ``finalize``; calling
    ``cipher.encryptor()`` twice would create two unrelated contexts and leave
    the one that did the work unfinalized.
    """
    pad = AES_BLOCK - (len(pt) % AES_BLOCK)
    padded = pt + bytes([pad]) * pad
    encryptor = Cipher(algorithms.AES(key), modes.CBC(iv)).encryptor()
    return encryptor.update(padded) + encryptor.finalize()


def aes_cbc_dec(key: bytes, iv: bytes, ct: bytes) -> bytes:
    """Decrypt AES-CBC *ct* under *key* / *iv* and strip PKCS#7 padding.

    Raises:
        ValueError: if *ct* is empty or not a whole number of blocks, or if the
            padding is malformed.
    """
    # An empty input would otherwise crash with IndexError on pt[-1].
    if not ct or len(ct) % AES_BLOCK:
        raise ValueError("Bad ciphertext length")
    decryptor = Cipher(algorithms.AES(key), modes.CBC(iv)).decryptor()
    pt = decryptor.update(ct) + decryptor.finalize()
    pad = pt[-1]
    if pad < 1 or pad > AES_BLOCK or not hmac.compare_digest(pt[-pad:], bytes([pad]) * pad):
        raise ValueError("Bad padding")
    return pt[:-pad]


def hmac_sha256(key: bytes, data: bytes) -> bytes:
    """Return the full 32-byte HMAC-SHA256 of *data* under *key*."""
    return hmac.new(key, data, hashlib.sha256).digest()


def x25519_shared(sk: X25519PrivateKey, pk: X25519PublicKey) -> bytes:
    """Return the raw X25519 shared secret between *sk* and *pk*."""
    return sk.exchange(pk)


def pub_bytes(pk: X25519PublicKey) -> bytes:
    """Serialize an X25519 public key to its raw 32-byte form."""
    return pk.public_bytes(encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw)


def priv_from_bytes(b: bytes) -> X25519PrivateKey:
    """Load an X25519 private key from raw bytes."""
    return X25519PrivateKey.from_private_bytes(b)


def pub_from_bytes(b: bytes) -> X25519PublicKey:
    """Load an X25519 public key from raw bytes."""
    return X25519PublicKey.from_public_bytes(b)


def _session_chains(root_key: bytes) -> Tuple[bytes, bytes]:
    """Return the (initiator->responder, responder->initiator) chain keys for *root_key*.

    The initiator->responder chain keeps the original ``b"CHAIN"`` derivation.
    """
    return (
        hkdf(root_key, info=b"CHAIN", length=32),
        hkdf(root_key, info=b"CHAIN|R2I", length=32),
    )


def _message_keys(chain_key: bytes, counter: int) -> Tuple[bytes, bytes, bytes, bytes]:
    """Derive (aes_key, mac_key, iv, next_chain_key) for message *counter* of a chain.

    This is a pure function; callers decide whether and when to advance state.
    """
    info = b"MSG|" + counter.to_bytes(8, "big")
    okm = hkdf(chain_key, info=info, length=80)  # 32 AES + 32 HMAC + 16 IV
    next_chain_key = hkdf(chain_key, info=b"STEP", length=32)
    return okm[:32], okm[32:64], okm[64:80], next_chain_key


# ------------------ Data structures ----------------------

@dataclass
class PublicBundle:
    """Public key material a user publishes to the relay directory (all Base64 text)."""

    identity_pub: str
    signed_prekey_pub: str
    signed_prekey_sig: str  # demo placeholder field; no real signature is made or verified
    onetime_prekeys: Dict[str, str]  # one-time prekey id -> Base64 public key


@dataclass
class SessionState:
    """Per-peer state of a simplified symmetric ratchet (no DH ratchet step).

    ``chain_key`` / ``counter`` form the *sending* chain and
    ``recv_chain_key`` / ``recv_counter`` form the *receiving* chain. Using
    separate chains per direction keeps both parties from deriving the same
    message key and IV when they send concurrently. ``recv_chain_key=None``
    denotes a legacy single chain shared by both directions.
    """

    root_key: bytes
    chain_key: bytes
    counter: int
    peer_identity_pub_b64: str
    opk_id: Optional[str]  # one-time prekey consumed when the session was created
    established_at: float
    recv_chain_key: Optional[bytes] = None
    recv_counter: int = 0


# ------------------------- Server ------------------------

class RelayServer:
    """Minimal relay: stores published key bundles, serves lookups and forwards frames.

    It never stores or decrypts message plaintext.
    """

    def __init__(self):
        self.directory: Dict[str, PublicBundle] = {}
        self.mailbox: Dict[str, List[dict]] = {}

    def register_bundle(self, user: str, bundle: PublicBundle):
        """Publish (or replace) *user*'s key bundle."""
        self.directory[user] = bundle

    def fetch_bundle(self, user: str) -> Optional[PublicBundle]:
        """Return *user*'s bundle, or None if they never published one."""
        return self.directory.get(user)

    def mark_onetime_used(self, user: str, opk_id: str):
        """Remove one-time prekey *opk_id* from *user*'s published bundle, if present."""
        bundle = self.directory.get(user)
        if bundle is not None:
            bundle.onetime_prekeys.pop(opk_id, None)

    def send_frame(self, to_user: str, frame: dict):
        """Queue *frame* in *to_user*'s mailbox."""
        self.mailbox.setdefault(to_user, []).append(frame)

    def pull_frames(self, user: str) -> List[dict]:
        """Remove and return every frame queued for *user* (oldest first)."""
        return self.mailbox.pop(user, [])


# ------------------------- Client ------------------------

class Client:
    """A messaging endpoint holding identity, signed and one-time X25519 prekeys."""

    def __init__(self, name: str, server: RelayServer, num_onetime_prekeys: int = 5):
        self.name = name
        self.server = server
        # Long-term identity key.
        self.ik_priv = X25519PrivateKey.generate()
        self.ik_pub = self.ik_priv.public_key()
        # "Signed" prekey. Demo only: the signature below is random placeholder
        # bytes and is never verified.
        self.spk_priv = X25519PrivateKey.generate()
        self.spk_pub = self.spk_priv.public_key()
        self.spk_sig = os.urandom(64)  # placeholder
        # Pool of one-time prekeys.
        self.opk_priv_map: Dict[str, X25519PrivateKey] = {}
        self.opk_pub_map: Dict[str, X25519PublicKey] = {}
        for i in range(num_onetime_prekeys):
            sk = X25519PrivateKey.generate()
            opk_id = f"opk-{i}"
            self.opk_priv_map[opk_id] = sk
            self.opk_pub_map[opk_id] = sk.public_key()
        # Sessions keyed by peer user name.
        self.sessions: Dict[str, SessionState] = {}
        # Init-header material left by the latest establish_session_as_initiator call.
        self._pending_ephemeral_pub: Optional[str] = None
        self._pending_opk_id: Optional[str] = None

    def publish(self):
        """Register this client's current public bundle with the relay."""
        bundle = PublicBundle(
            identity_pub=b64e(pub_bytes(self.ik_pub)),
            signed_prekey_pub=b64e(pub_bytes(self.spk_pub)),
            signed_prekey_sig=b64e(self.spk_sig),
            onetime_prekeys={oid: b64e(pub_bytes(pk)) for oid, pk in self.opk_pub_map.items()},
        )
        self.server.register_bundle(self.name, bundle)

    # ---------- Session setup (initiator), simplified X3DH ----------
    def establish_session_as_initiator(self, peer: str) -> SessionState:
        """Run simplified X3DH against *peer*'s published bundle and store the session.

        The ephemeral public key and chosen one-time prekey id are left in
        ``self._pending_ephemeral_pub`` / ``self._pending_opk_id`` for the
        first outgoing frame's init header.

        Raises:
            ValueError: if *peer* has not published a bundle.
        """
        pb = self.server.fetch_bundle(peer)
        if pb is None:
            raise ValueError(f"{peer} has no bundle")
        ik_b = pub_from_bytes(b64d(pb.identity_pub))
        spk_b = pub_from_bytes(b64d(pb.signed_prekey_pub))

        # Pick one of the peer's one-time prekeys (if any). Consume it at the relay
        # immediately so no other initiator is handed the same key.
        opk_id: Optional[str] = next(iter(pb.onetime_prekeys), None)
        opk_b: Optional[X25519PublicKey] = None
        if opk_id is not None:
            opk_b = pub_from_bytes(b64d(pb.onetime_prekeys[opk_id]))
            self.server.mark_onetime_used(peer, opk_id)

        ek_priv = X25519PrivateKey.generate()
        # DH1=DH(IKa,SPKb), DH2=DH(EKa,IKb), DH3=DH(EKa,SPKb), DH4=DH(EKa,OPKb) if available.
        parts = [
            x25519_shared(self.ik_priv, spk_b),
            x25519_shared(ek_priv, ik_b),
            x25519_shared(ek_priv, spk_b),
        ]
        if opk_b is not None:
            parts.append(x25519_shared(ek_priv, opk_b))

        root_key = hkdf(b"".join(parts), info=b"ROOT", length=32)
        i2r_chain, r2i_chain = _session_chains(root_key)
        st = SessionState(
            root_key=root_key,
            chain_key=i2r_chain,
            counter=0,
            peer_identity_pub_b64=pb.identity_pub,
            opk_id=opk_id,
            established_at=time.time(),
            recv_chain_key=r2i_chain,
            recv_counter=0,
        )
        self.sessions[peer] = st
        self._pending_ephemeral_pub = b64e(pub_bytes(ek_priv.public_key()))
        self._pending_opk_id = opk_id
        return st

    # ---------- Session setup (responder) ----------
    def _responder_state(
        self,
        peer: str,
        ek_pub_b64: str,
        opk_id: Optional[str],
        ik_pub_b64: Optional[str] = None,
    ) -> SessionState:
        """Compute, without storing, the responder-side session for an init frame.

        The initiator identity key is taken from the init header (*ik_pub_b64*)
        and falls back to the directory. If both are available they must agree.

        Raises:
            ValueError: if the initiator identity key is unknown or inconsistent,
                or *opk_id* names a one-time prekey this client does not hold.
        """
        listed = self.server.fetch_bundle(peer)
        listed_ik = listed.identity_pub if listed is not None else None
        ik_a_b64 = ik_pub_b64 or listed_ik
        if ik_a_b64 is None:
            raise ValueError(f"Unknown identity key for {peer}")
        if listed_ik is not None and ik_a_b64 != listed_ik:
            raise ValueError(f"Identity key mismatch for {peer}")
        ik_a = pub_from_bytes(b64d(ik_a_b64))
        ek_a = pub_from_bytes(b64d(ek_pub_b64))

        opk_priv: Optional[X25519PrivateKey] = None
        if opk_id:
            opk_priv = self.opk_priv_map.get(opk_id)
            # The initiator mixed in DH4; silently skipping it would only surface
            # later as a confusing MAC failure.
            if opk_priv is None:
                raise ValueError(f"Unknown or already used one-time prekey: {opk_id}")

        # Mirror of the initiator's DH terms: DH(x, Y) == DH(y, X).
        parts = [
            x25519_shared(self.spk_priv, ik_a),
            x25519_shared(self.ik_priv, ek_a),
            x25519_shared(self.spk_priv, ek_a),
        ]
        if opk_priv is not None:
            parts.append(x25519_shared(opk_priv, ek_a))

        root_key = hkdf(b"".join(parts), info=b"ROOT", length=32)
        i2r_chain, r2i_chain = _session_chains(root_key)
        return SessionState(
            root_key=root_key,
            chain_key=r2i_chain,
            counter=0,
            peer_identity_pub_b64=ik_a_b64,
            opk_id=opk_id,
            established_at=time.time(),
            recv_chain_key=i2r_chain,
            recv_counter=0,
        )

    def _commit_responder_session(self, peer: str, st: SessionState) -> None:
        """Store *st* and permanently consume its one-time prekey, locally and at the relay."""
        self.sessions[peer] = st
        if st.opk_id:
            # Deleting the private half is what gives one-time prekeys their value.
            self.opk_priv_map.pop(st.opk_id, None)
            self.opk_pub_map.pop(st.opk_id, None)
            self.server.mark_onetime_used(self.name, st.opk_id)

    def _establish_session_as_responder(
        self,
        peer: str,
        ek_pub_b64: str,
        opk_id: Optional[str],
        ik_pub_b64: Optional[str] = None,
    ) -> SessionState:
        """Build the responder session for *peer*, store it and consume its one-time prekey."""
        st = self._responder_state(peer, ek_pub_b64, opk_id, ik_pub_b64)
        self._commit_responder_session(peer, st)
        return st

    # ---------- Per-message key derivation ----------
    def _derive_message_key(self, st: SessionState) -> Tuple[bytes, bytes, bytes]:
        """Derive (aes_key, mac_key, iv) for the next message on *st*'s sending chain.

        Advances ``st.chain_key`` and increments ``st.counter``.
        """
        aes_key, mac_key, iv, next_chain_key = _message_keys(st.chain_key, st.counter)
        st.chain_key = next_chain_key
        st.counter += 1
        return aes_key, mac_key, iv

    # ---------- Sending ----------
    def send(self, to_user: str, plaintext: bytes):
        """Encrypt *plaintext* for *to_user* and hand the frame to the relay.

        The first message to a new peer establishes the session and carries an
        init header with the ephemeral key, the one-time prekey id and this
        client's identity key.
        """
        st = self.sessions.get(to_user)
        if st is None:
            st = self.establish_session_as_initiator(to_user)
            ek_pub_b64 = self._pending_ephemeral_pub
            opk_id = self._pending_opk_id
            ik_pub_b64: Optional[str] = b64e(pub_bytes(self.ik_pub))
        else:
            ek_pub_b64 = opk_id = ik_pub_b64 = None

        counter = st.counter  # counter of this message on the sending chain
        aes_key, mac_key, iv = self._derive_message_key(st)
        ct = aes_cbc_enc(aes_key, iv, plaintext)
        mac = hmac_sha256(mac_key, iv + ct)[:MAC_LEN]

        frame = {
            "from": self.name,
            "to": to_user,
            "hdr": {
                "init": ek_pub_b64 is not None,
                "ek_pub_b64": ek_pub_b64,  # only on the first message
                "opk_id": opk_id,
                "ik_pub_b64": ik_pub_b64,  # only on the first message
                "counter": counter,
            },
            "body": {
                "iv": b64e(iv),
                "ct": b64e(ct),
                "mac": b64e(mac),
            },
        }
        self.server.send_frame(to_user, frame)

    # ---------- Receiving ----------
    @staticmethod
    def _open_frame(st: SessionState, hdr: dict, body: dict) -> bytes:
        """Verify and decrypt one frame on *st*'s receiving chain.

        The chain is advanced only after the MAC and padding check out, so a
        rejected frame leaves the session usable.

        Raises:
            ValueError: on counter, IV, MAC or padding failure.
        """
        legacy = st.recv_chain_key is None  # single chain shared by both directions
        chain_key = st.chain_key if legacy else st.recv_chain_key
        counter = st.counter if legacy else st.recv_counter
        # Strict ordering; real implementations must handle skipped or out-of-order messages.
        if hdr["counter"] != counter:
            raise ValueError("Out-of-order message (demo limitation)")

        aes_key, mac_key, iv, next_chain_key = _message_keys(chain_key, counter)
        # The IV is derived identically on both sides, so a different one on the
        # wire means tampering or desync.
        if b64d(body["iv"]) != iv:
            raise ValueError("IV mismatch")
        ct = b64d(body["ct"])
        calc = hmac_sha256(mac_key, iv + ct)[:MAC_LEN]
        if not hmac.compare_digest(b64d(body["mac"]), calc):
            raise ValueError("MAC verification failed")
        pt = aes_cbc_dec(aes_key, iv, ct)

        if legacy:
            st.chain_key, st.counter = next_chain_key, counter + 1
        else:
            st.recv_chain_key, st.recv_counter = next_chain_key, counter + 1
        return pt

    def receive_all(self) -> List[Tuple[str, bytes]]:
        """Pull, verify and decrypt every pending frame addressed to this client.

        Returns:
            (sender, plaintext) pairs in arrival order.

        Raises:
            ValueError: on a missing init header or any verification failure.
                A new responder session is stored only once its first frame
                verifies.

        NOTE: frames are removed from the relay before processing. If one fails,
        the remaining frames of this batch are not re-queued (demo limitation).
        """
        outputs = []
        for frame in self.server.pull_frames(self.name):
            sender = frame["from"]
            hdr = frame["hdr"]
            st = self.sessions.get(sender)
            is_new = st is None
            if is_new:
                # First frame from this peer: it must carry the init header.
                if not hdr.get("init"):
                    raise ValueError("Missing init header")
                st = self._responder_state(
                    peer=sender,
                    ek_pub_b64=hdr["ek_pub_b64"],
                    opk_id=hdr.get("opk_id"),
                    ik_pub_b64=hdr.get("ik_pub_b64"),
                )
            pt = self._open_frame(st, hdr, frame["body"])
            if is_new:
                self._commit_responder_session(sender, st)
            outputs.append((sender, pt))
        return outputs


# ------------------------- Demo --------------------------

def main():
    """Run a short Alice/Bob exchange through the relay and print the results."""
    server = RelayServer()
    alice = Client("alice", server)
    bob = Client("bob", server)

    # Publish each user's public key bundle.
    alice.publish()
    bob.publish()

    # Alice sends first (this establishes the session automatically).
    alice.send("bob", b"Hello Bob, this is Alice.")

    # Bob pulls and decrypts.
    for frm in bob.receive_all():
        print("[bob] got:", frm)

    # Bob replies.
    bob.send("alice", b"Hi Alice, Bob here. Message received.")
    for frm in alice.receive_all():
        print("[alice] got:", frm)

    # Several messages in a row; every one uses a fresh key (the server cannot decrypt).
    for i in range(1, 4):
        alice.send("bob", f"Msg#{i} from Alice".encode())
    for frm in bob.receive_all():
        print("[bob] got:", frm)

    # Show that the server only holds public key material, never plaintext.
    print("\n[server] directory keys (truncated):")
    print(json.dumps({
        u: {
            "identity_pub": v.identity_pub[:24] + "...",
            "spk_pub": v.signed_prekey_pub[:24] + "...",
            "opk_count": len(v.onetime_prekeys),
        } for u, v in server.directory.items()
    }, indent=2))


if __name__ == "__main__":
    main()