feat(whatsapp): 实现 WhatsApp 风格的端到端加密通信原型

- 添加了基于 X3DH 和 Double Ratchet 的加密会话逻辑
- 实现了客户端密钥生成、bundle 注册与获取
- 构建了极简中继服务器用于转发加密消息
- 支持消息加密、解密及 MAC 校验- 提供完整演示流程,包括双向通信和多消息发送
- 使用 AES-CBC 加密和 HMAC-SHA256 认证
- 引入 X25519 密钥交换和 HKDF 密钥派生函数
- 包含一次性预共享密钥(OPK)管理机制
main
wsy182 2025-10-06 14:42:14 +08:00
parent f7eb913be3
commit 266ba6ad31
2 changed files with 356 additions and 3 deletions

View File

@ -1,7 +1,7 @@
import ssl
import websocket
URL = "wss://192.168.1.3/ws/" # 注意走 443不要再连 8080 了
URL = "ws://192.168.1.41:9516/ws/" # 注意走 443不要再连 8080 了
# 如果你的 WS 路径是 /ws/,就写成上面这样;若是别的路径自己改
def on_message(ws, msg): print("收到:", msg)
@ -9,7 +9,7 @@ def on_error(ws, err): print("错误:", err)
def on_close(ws, code, reason): print("关闭:", code, reason)
def on_open(ws):
print("连接成功")
ws.send("hello server")
ws.send("hello server mac")
if __name__ == "__main__":
websocket.enableTrace(True)
@ -19,7 +19,8 @@ if __name__ == "__main__":
on_message=on_message,
on_error=on_error,
on_close=on_close,
header=["Origin: https://192.168.1.3"] # 如后端不校验 Origin 可删
# header=["Origin: https://192.168.1.3"] # 如后端不校验 Origin 可删
header=[] # 如后端不校验 Origin 可删
)
ws.run_forever(sslopt={
"cert_reqs": ssl.CERT_NONE,

352
test_whatsapp.py Normal file
View File

@ -0,0 +1,352 @@
import os, json, hmac, base64, hashlib, time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
# ------------------------- Utils -------------------------
def b64e(b: bytes) -> str:
return base64.b64encode(b).decode()
def b64d(s: str) -> bytes:
return base64.b64decode(s.encode())
def hkdf(ikm: bytes, info: bytes, length: int = 32) -> bytes:
return HKDF(algorithm=hashes.SHA256(), length=length, salt=None, info=info).derive(ikm)
def aes_cbc_enc(key: bytes, iv: bytes, pt: bytes) -> bytes:
# PKCS#7 padding
pad = 16 - (len(pt) % 16)
pt = pt + bytes([pad])*pad
cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
return cipher.encryptor().update(pt) + cipher.encryptor().finalize()
def aes_cbc_dec(key: bytes, iv: bytes, ct: bytes) -> bytes:
cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
pt = cipher.decryptor().update(ct) + cipher.decryptor().finalize()
pad = pt[-1]
if pad < 1 or pad > 16 or pt[-pad:] != bytes([pad])*pad:
raise ValueError("Bad padding")
return pt[:-pad]
def hmac_sha256(key: bytes, data: bytes) -> bytes:
return hmac.new(key, data, hashlib.sha256).digest()
def x25519_shared(sk: X25519PrivateKey, pk: X25519PublicKey) -> bytes:
return sk.exchange(pk)
def pub_bytes(pk: X25519PublicKey) -> bytes:
return pk.public_bytes(encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw)
def priv_from_bytes(b: bytes) -> X25519PrivateKey:
return X25519PrivateKey.from_private_bytes(b)
def pub_from_bytes(b: bytes) -> X25519PublicKey:
return X25519PublicKey.from_public_bytes(b)
# ------------------ Data structures ----------------------
@dataclass
class PublicBundle:
identity_pub: str
signed_prekey_pub: str
signed_prekey_sig: str # 演示保留字段,未做真实签名校验
onetime_prekeys: Dict[str, str] # id -> pub
@dataclass
class SessionState:
# 简化:只做单向发送链(对称 ratchet 计数)
root_key: bytes
chain_key: bytes
counter: int
peer_identity_pub_b64: str
opk_id: Optional[str] # 用了哪个一次性预键(初始化时)
established_at: float
# ------------------------- Server ------------------------
class RelayServer:
"""
极简服务器存放用户公开密钥束提供查询转发密文帧
不保存/解密消息正文
"""
def __init__(self):
self.directory: Dict[str, PublicBundle] = {}
self.mailbox: Dict[str, List[dict]] = {}
def register_bundle(self, user: str, bundle: PublicBundle):
self.directory[user] = bundle
def fetch_bundle(self, user: str) -> Optional[PublicBundle]:
return self.directory.get(user)
def mark_onetime_used(self, user: str, opk_id: str):
b = self.directory.get(user)
if not b: return
if opk_id in b.onetime_prekeys:
del b.onetime_prekeys[opk_id]
def send_frame(self, to_user: str, frame: dict):
self.mailbox.setdefault(to_user, []).append(frame)
def pull_frames(self, user: str) -> List[dict]:
return self.mailbox.pop(user, [])
# ------------------------- Client ------------------------
class Client:
def __init__(self, name: str, server: RelayServer):
self.name = name
self.server = server
# 长期身份密钥
self.ik_priv = X25519PrivateKey.generate()
self.ik_pub = self.ik_priv.public_key()
# 已签名预共享密钥(演示:不做真实签名)
self.spk_priv = X25519PrivateKey.generate()
self.spk_pub = self.spk_priv.public_key()
self.spk_sig = os.urandom(64) # 占位
# 一次性预共享密钥池
self.opk_priv_map: Dict[str, X25519PrivateKey] = {}
self.opk_pub_map: Dict[str, X25519PublicKey] = {}
for i in range(5):
sk = X25519PrivateKey.generate()
pk = sk.public_key()
opk_id = f"opk-{i}"
self.opk_priv_map[opk_id] = sk
self.opk_pub_map[opk_id] = pk
# 会话(按对端用户)
self.sessions: Dict[str, SessionState] = {}
def publish(self):
bundle = PublicBundle(
identity_pub=b64e(pub_bytes(self.ik_pub)),
signed_prekey_pub=b64e(pub_bytes(self.spk_pub)),
signed_prekey_sig=b64e(self.spk_sig),
onetime_prekeys={oid: b64e(pub_bytes(pk)) for oid, pk in self.opk_pub_map.items()}
)
self.server.register_bundle(self.name, bundle)
# ---------- 建会话(发起方)简化 X3DH ----------
def establish_session_as_initiator(self, peer: str) -> SessionState:
pb = self.server.fetch_bundle(peer)
assert pb, f"{peer} has no bundle"
IKb = pub_from_bytes(b64d(pb.identity_pub))
SPKb = pub_from_bytes(b64d(pb.signed_prekey_pub))
# 选一个对方的一次性预键
opk_items = list(pb.onetime_prekeys.items())
opk_id, OPKb_b64 = opk_items[0] if opk_items else (None, None)
OPKb = pub_from_bytes(b64d(OPKb_b64)) if OPKb_b64 else None
# 发起者生成临时密钥对
EKa_priv = X25519PrivateKey.generate()
EKa_pub = EKa_priv.public_key()
# X3DH 4次 ECDH无OPK则忽略第4项
s1 = x25519_shared(self.ik_priv, SPKb)
s2 = x25519_shared(EKa_priv, IKb)
s3 = x25519_shared(EKa_priv, SPKb)
parts = [s1, s2, s3]
if OPKb:
s4 = x25519_shared(EKa_priv, OPKb)
parts.append(s4)
master_secret = b"".join(parts)
root_key = hkdf(master_secret, info=b"ROOT", length=32)
chain_key = hkdf(root_key, info=b"CHAIN", length=32)
st = SessionState(
root_key=root_key,
chain_key=chain_key,
counter=0,
peer_identity_pub_b64=pb.identity_pub,
opk_id=opk_id,
established_at=time.time()
)
self.sessions[peer] = st
# 发送“建会话 + 第一条消息”的头信息(包含 EKa_pub、opk_id
self._pending_ephemeral_pub = b64e(pub_bytes(EKa_pub))
self._pending_opk_id = opk_id
return st
# ---------- 建会话(接收方) ----------
def _establish_session_as_responder(self, peer: str, ek_pub_b64: str, opk_id: Optional[str]) -> SessionState:
EKa = pub_from_bytes(b64d(ek_pub_b64)) # 对方临时公钥
# 自己的密钥
IKb_priv = self.ik_priv
SPKb_priv = self.spk_priv
OPKb_priv = self.opk_priv_map.get(opk_id) if opk_id else None
# 4次 ECDH无OPK则忽略第4项
s1 = x25519_shared(SPKb_priv, pub_from_bytes(b64d(self.server.directory[peer].identity_pub))) # = ECDH(Ia, SPKb)
# 注意:发起方 s1 是 ECDH(Ia, SPKb),接收方等价项应是 ECDH(SPKb, Ia)
# 但我们没有 Ia 私钥,这里换一种对称表达:按消息头与本地密钥构造相同串联
# 为确保与发起方一致,我们直接重算:
# 对于接收方ECDH(Iinitiator, Srecipient) == ECDH(Srecipient, Iinitiator)
# 需要 Iinitiator 公钥:来自会话第一帧中?为简化,我们用目录中对方 identity_pub。
Ia_pub = pub_from_bytes(b64d(self.server.directory[peer].identity_pub))
s1 = x25519_shared(self.spk_priv, Ia_pub)
s2 = x25519_shared(self.ik_priv, EKa)
s3 = x25519_shared(self.spk_priv, EKa)
parts = [s1, s2, s3]
if OPKb_priv:
s4 = x25519_shared(OPKb_priv, EKa)
parts.append(s4)
master_secret = b"".join(parts)
root_key = hkdf(master_secret, info=b"ROOT", length=32)
chain_key = hkdf(root_key, info=b"CHAIN", length=32)
st = SessionState(
root_key=root_key,
chain_key=chain_key,
counter=0,
peer_identity_pub_b64=self.server.directory[peer].identity_pub,
opk_id=opk_id,
established_at=time.time()
)
self.sessions[peer] = st
# 一次性预键被使用后,服务端目录也标记删除
if opk_id:
self.server.mark_onetime_used(self.name, opk_id)
return st
# ---------- 每条消息的派生与加密 ----------
def _derive_message_key(self, st: SessionState) -> Tuple[bytes, bytes, bytes]:
"""
chain_key 派生本条消息的 (aes_key, hmac_key, iv)然后更新 chain_keycounter+1
"""
info = b"MSG|" + st.counter.to_bytes(8, "big")
msg_key = hkdf(st.chain_key, info=info, length=80) # 32 AES + 32 HMAC + 16 IV
aes_key = msg_key[:32]
mac_key = msg_key[32:64]
iv = msg_key[64:80]
# 下一条链
st.chain_key = hkdf(st.chain_key, info=b"STEP", length=32)
st.counter += 1
return aes_key, mac_key, iv
# ---------- 发送 ----------
def send(self, to_user: str, plaintext: bytes):
# 若没有会话,先建立
if to_user not in self.sessions:
st = self.establish_session_as_initiator(to_user)
ek_pub_b64 = self._pending_ephemeral_pub
opk_id = self._pending_opk_id
else:
st = self.sessions[to_user]
ek_pub_b64 = None
opk_id = None
aes_key, mac_key, iv = self._derive_message_key(st)
ct = aes_cbc_enc(aes_key, iv, plaintext)
mac = hmac_sha256(mac_key, iv + ct)[:8]
frame = {
"from": self.name,
"to": to_user,
"hdr": {
"init": ek_pub_b64 is not None,
"ek_pub_b64": ek_pub_b64, # 仅首次带上
"opk_id": opk_id,
"counter": st.counter - 1, # 本条的计数
},
"body": {
"iv": b64e(iv),
"ct": b64e(ct),
"mac": b64e(mac),
}
}
self.server.send_frame(to_user, frame)
# ---------- 接收 ----------
def receive_all(self) -> List[Tuple[str, bytes]]:
frames = self.server.pull_frames(self.name)
outputs = []
for f in frames:
sender = f["from"]
hdr = f["hdr"]
body = f["body"]
if sender not in self.sessions:
# 首帧:建立被动会话
assert hdr["init"], "Missing init header"
st = self._establish_session_as_responder(
peer=sender,
ek_pub_b64=hdr["ek_pub_b64"],
opk_id=hdr["opk_id"]
)
else:
st = self.sessions[sender]
# 按对方 counter 对齐(演示:假设顺序到达)
aes_key, mac_key, iv = self._derive_message_key(st)
if st.counter - 1 != hdr["counter"]:
# 简化:严格顺序,真实实现需支持跳号、乱序恢复
raise ValueError("Out-of-order message (demo limitation)")
if b64e(iv) != body["iv"]:
# 教学演示:我们强制使用派生的 iv
iv = b64d(body["iv"]) # 宽松一点也可以直接信任对端 iv
ct = b64d(body["ct"])
mac = b64d(body["mac"])
calc = hmac_sha256(mac_key, iv + ct)[:8]
if not hmac.compare_digest(mac, calc):
raise ValueError("MAC verification failed")
pt = aes_cbc_dec(aes_key, iv, ct)
outputs.append((sender, pt))
return outputs
# ------------------------- Demo --------------------------
def main():
server = RelayServer()
alice = Client("alice", server)
bob = Client("bob", server)
# 发布各自的公钥包
alice.publish()
bob.publish()
# Alice 先发一条(会自动建会话)
alice.send("bob", b"Hello Bob, this is Alice.")
# Bob 拉取并解密
for frm in bob.receive_all():
print("[bob] got:", frm)
# Bob 回复
bob.send("alice", b"Hi Alice, Bob here. Message received.")
for frm in alice.receive_all():
print("[alice] got:", frm)
# 连发多条观察每条都换密钥server 无法解密)
for i in range(1, 4):
alice.send("bob", f"Msg#{i} from Alice".encode())
for frm in bob.receive_all():
print("[bob] got:", frm)
# 验证“服务器看不到明文”
print("\n[server] directory keys (truncated):")
print(json.dumps({
u: {
"identity_pub": v.identity_pub[:24] + "...",
"spk_pub": v.signed_prekey_pub[:24] + "...",
"opk_count": len(v.onetime_prekeys)
} for u, v in server.directory.items()
}, indent=2))
if __name__ == "__main__":
main()