352 lines
13 KiB
Python
352 lines
13 KiB
Python
import os, json, hmac, base64, hashlib, time
|
||
from dataclasses import dataclass
|
||
from typing import Dict, List, Optional, Tuple
|
||
from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey, X25519PublicKey
|
||
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
||
from cryptography.hazmat.primitives import hashes, serialization
|
||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||
|
||
# ------------------------- Utils -------------------------
|
||
|
||
def b64e(b: bytes) -> str:
    """Encode raw bytes as a standard Base64 text string."""
    encoded = base64.b64encode(b)
    return encoded.decode()
|
||
|
||
def b64d(s: str) -> bytes:
    """Decode a standard Base64 text string back into raw bytes."""
    raw = s.encode()
    return base64.b64decode(raw)
|
||
|
||
def hkdf(ikm: bytes, info: bytes, length: int = 32) -> bytes:
    """Derive ``length`` bytes of key material from ``ikm`` using HKDF-SHA256.

    No salt is supplied (``salt=None``); ``info`` is passed through as the
    HKDF context string, so different ``info`` values yield independent keys
    from the same input key material.
    """
    kdf = HKDF(algorithm=hashes.SHA256(), length=length, salt=None, info=info)
    return kdf.derive(ikm)
|
||
|
||
def aes_cbc_enc(key: bytes, iv: bytes, pt: bytes) -> bytes:
    """Encrypt ``pt`` with AES-CBC after applying PKCS#7 padding.

    Args:
        key: AES key (the callers in this file pass 32 bytes).
        iv: 16-byte initialisation vector.
        pt: Plaintext of any length, including empty.

    Returns:
        The ciphertext; its length is always a non-zero multiple of 16.
    """
    # PKCS#7: always append 1..16 bytes, each holding the pad length.
    pad = 16 - (len(pt) % 16)
    padded = pt + bytes([pad]) * pad
    # update() and finalize() must be called on the SAME context; calling
    # cipher.encryptor() twice creates two unrelated contexts and leaves the
    # one that processed the data unfinalized.
    encryptor = Cipher(algorithms.AES(key), modes.CBC(iv)).encryptor()
    return encryptor.update(padded) + encryptor.finalize()
|
||
|
||
def aes_cbc_dec(key: bytes, iv: bytes, ct: bytes) -> bytes:
    """Decrypt AES-CBC ciphertext and strip its PKCS#7 padding.

    Args:
        key: AES key used for encryption.
        iv: 16-byte initialisation vector used for encryption.
        ct: Ciphertext; must be a non-zero multiple of 16 bytes.

    Returns:
        The recovered plaintext without padding.

    Raises:
        ValueError: If the ciphertext length is invalid or the padding is
            malformed.
    """
    # Reject bad lengths up front: an empty input would otherwise fail with
    # IndexError on the padding byte, and a trailing partial block must not
    # be silently ignored.
    if not ct or len(ct) % 16:
        raise ValueError("Ciphertext length must be a non-zero multiple of 16")
    # Use ONE decryptor context for both update() and finalize().
    decryptor = Cipher(algorithms.AES(key), modes.CBC(iv)).decryptor()
    padded = decryptor.update(ct) + decryptor.finalize()
    pad = padded[-1]
    if pad < 1 or pad > 16 or padded[-pad:] != bytes([pad]) * pad:
        raise ValueError("Bad padding")
    return padded[:-pad]
|
||
|
||
def hmac_sha256(key: bytes, data: bytes) -> bytes:
    """Return the full 32-byte HMAC-SHA256 tag of ``data`` under ``key``."""
    mac = hmac.new(key, msg=data, digestmod=hashlib.sha256)
    return mac.digest()
|
||
|
||
def x25519_shared(sk: X25519PrivateKey, pk: X25519PublicKey) -> bytes:
    """Return the X25519 Diffie-Hellman shared secret of ``sk`` and ``pk``."""
    shared_secret = sk.exchange(pk)
    return shared_secret
|
||
|
||
def pub_bytes(pk: X25519PublicKey) -> bytes:
    """Serialise an X25519 public key using the Raw encoding and Raw format."""
    raw_encoding = serialization.Encoding.Raw
    raw_format = serialization.PublicFormat.Raw
    return pk.public_bytes(encoding=raw_encoding, format=raw_format)
|
||
|
||
def priv_from_bytes(b: bytes) -> X25519PrivateKey:
    """Load an X25519 private key from its raw byte representation."""
    private_key = X25519PrivateKey.from_private_bytes(b)
    return private_key
|
||
|
||
def pub_from_bytes(b: bytes) -> X25519PublicKey:
    """Load an X25519 public key from its raw byte representation."""
    public_key = X25519PublicKey.from_public_bytes(b)
    return public_key
|
||
|
||
# ------------------ Data structures ----------------------
|
||
|
||
@dataclass
class PublicBundle:
    """Public key bundle that a user registers with the relay server.

    The key fields are text strings; ``Client.publish`` fills them with the
    Base64 encoding of the corresponding raw key bytes.
    """

    # Long-term identity public key.
    identity_pub: str
    # Signed prekey public key.
    signed_prekey_pub: str
    # Placeholder kept for the demo; no real signature check is performed.
    signed_prekey_sig: str
    # One-time prekeys: id -> public key.
    onetime_prekeys: Dict[str, str]
|
||
|
||
@dataclass
class SessionState:
    """Per-peer session state for the simplified symmetric ratchet.

    Simplified: only a single chain with a message counter is kept; there is
    no separate sending/receiving chain.
    """

    # Root key derived from the handshake secret.
    root_key: bytes
    # Current chain key; replaced after every message key derivation.
    chain_key: bytes
    # Number of message keys derived from the chain so far.
    counter: int
    # Identity public key of the peer, as Base64 text.
    peer_identity_pub_b64: str
    # Which one-time prekey was used at initialisation, if any.
    opk_id: Optional[str]
    # Timestamp (``time.time()``) at which the session was created.
    established_at: float
|
||
|
||
# ------------------------- Server ------------------------
|
||
|
||
class RelayServer:
    """Minimal in-memory "server".

    It stores each user's public key bundle, answers bundle lookups and
    relays ciphertext frames. It never stores or decrypts message plaintext.
    """

    def __init__(self):
        # user name -> published public bundle
        self.directory: Dict[str, PublicBundle] = {}
        # user name -> frames waiting to be pulled by that user
        self.mailbox: Dict[str, List[dict]] = {}

    def register_bundle(self, user: str, bundle: PublicBundle):
        """Store (or overwrite) the public bundle published by ``user``."""
        self.directory[user] = bundle

    def fetch_bundle(self, user: str) -> Optional[PublicBundle]:
        """Return the bundle registered for ``user``, or ``None`` if absent."""
        if user in self.directory:
            return self.directory[user]
        return None

    def mark_onetime_used(self, user: str, opk_id: str):
        """Remove a consumed one-time prekey from ``user``'s bundle, if present."""
        bundle = self.directory.get(user)
        if bundle:
            # Unknown ids are ignored, exactly like an unknown user.
            bundle.onetime_prekeys.pop(opk_id, None)

    def send_frame(self, to_user: str, frame: dict):
        """Append ``frame`` to the mailbox of ``to_user``."""
        queue = self.mailbox.setdefault(to_user, [])
        queue.append(frame)

    def pull_frames(self, user: str) -> List[dict]:
        """Remove and return all frames queued for ``user`` (empty list if none)."""
        return self.mailbox.pop(user, [])
|
||
|
||
# ------------------------- Client ------------------------
|
||
|
||
class Client:
    """Demo end-to-end encryption client.

    A client owns an identity key, a "signed" prekey and a pool of one-time
    prekeys, publishes their public halves to a ``RelayServer``, establishes
    sessions with a simplified X3DH handshake and encrypts each message with
    a key taken from a symmetric hash ratchet (AES-CBC + truncated HMAC).

    Demo limitations that are intentionally kept:
      * the signed prekey signature is a random placeholder and is never
        verified;
      * a session has one shared chain for both directions, so both peers
        must process every frame in strict order;
      * the initiator's identity key is looked up in the server directory
        rather than carried in the first frame.
    """

    # Length in bytes of the truncated HMAC tag carried in each frame.
    _MAC_LEN = 8

    def __init__(self, name: str, server: RelayServer, num_onetime_prekeys: int = 5):
        """Generate all key pairs for ``name``.

        Args:
            name: User name under which the client registers and receives.
            server: Relay used for bundle lookup and frame delivery.
            num_onetime_prekeys: Size of the one-time prekey pool.
        """
        self.name = name
        self.server = server

        # Long-term identity key pair.
        self.ik_priv = X25519PrivateKey.generate()
        self.ik_pub = self.ik_priv.public_key()

        # "Signed" prekey (demo: no real signature is produced).
        self.spk_priv = X25519PrivateKey.generate()
        self.spk_pub = self.spk_priv.public_key()
        self.spk_sig = os.urandom(64)  # placeholder

        # Pool of one-time prekeys, keyed by id.
        self.opk_priv_map: Dict[str, X25519PrivateKey] = {}
        self.opk_pub_map: Dict[str, X25519PublicKey] = {}
        for i in range(num_onetime_prekeys):
            sk = X25519PrivateKey.generate()
            opk_id = f"opk-{i}"
            self.opk_priv_map[opk_id] = sk
            self.opk_pub_map[opk_id] = sk.public_key()

        # Sessions, keyed by peer user name.
        self.sessions: Dict[str, SessionState] = {}

        # Handshake header (ephemeral public key, one-time prekey id) that
        # still has to be attached to the first frame sent to each peer.
        # Keeping it per peer means it survives until that frame is sent.
        self._pending_init: Dict[str, Tuple[str, Optional[str]]] = {}

    def publish(self):
        """Register this client's public key bundle with the server."""
        bundle = PublicBundle(
            identity_pub=b64e(pub_bytes(self.ik_pub)),
            signed_prekey_pub=b64e(pub_bytes(self.spk_pub)),
            signed_prekey_sig=b64e(self.spk_sig),
            onetime_prekeys={oid: b64e(pub_bytes(pk)) for oid, pk in self.opk_pub_map.items()},
        )
        self.server.register_bundle(self.name, bundle)

    def _install_session(self, peer: str, dh_outputs: List[bytes],
                         peer_identity_pub_b64: str, opk_id: Optional[str]) -> SessionState:
        """Derive root/chain keys from the DH outputs and store the session."""
        master_secret = b"".join(dh_outputs)
        root_key = hkdf(master_secret, info=b"ROOT", length=32)
        chain_key = hkdf(root_key, info=b"CHAIN", length=32)
        st = SessionState(
            root_key=root_key,
            chain_key=chain_key,
            counter=0,
            peer_identity_pub_b64=peer_identity_pub_b64,
            opk_id=opk_id,
            established_at=time.time(),
        )
        self.sessions[peer] = st
        return st

    # ---------- Session setup (initiator), simplified X3DH ----------
    def establish_session_as_initiator(self, peer: str) -> SessionState:
        """Create a session with ``peer`` from the bundle it published.

        The handshake header is remembered and attached to the next frame
        sent to ``peer`` by ``send``.

        Raises:
            ValueError: If ``peer`` has not published a bundle.
        """
        pb = self.server.fetch_bundle(peer)
        if pb is None:
            # Explicit raise instead of assert, which vanishes under -O.
            raise ValueError(f"{peer} has no bundle")

        peer_ik = pub_from_bytes(b64d(pb.identity_pub))
        peer_spk = pub_from_bytes(b64d(pb.signed_prekey_pub))

        # Pick one of the peer's one-time prekeys, if any are left.
        opk_id, opk_b64 = next(iter(pb.onetime_prekeys.items()), (None, None))
        if opk_b64:
            peer_opk = pub_from_bytes(b64d(opk_b64))
        else:
            # Never advertise an id for a prekey that is not actually used.
            opk_id, peer_opk = None, None

        # Fresh ephemeral key pair for this handshake.
        eph_priv = X25519PrivateKey.generate()

        # X3DH: three DH outputs, plus a fourth when a one-time prekey exists.
        dh_outputs = [
            x25519_shared(self.ik_priv, peer_spk),
            x25519_shared(eph_priv, peer_ik),
            x25519_shared(eph_priv, peer_spk),
        ]
        if peer_opk is not None:
            dh_outputs.append(x25519_shared(eph_priv, peer_opk))

        st = self._install_session(peer, dh_outputs, pb.identity_pub, opk_id)
        self._pending_init[peer] = (b64e(pub_bytes(eph_priv.public_key())), opk_id)
        return st

    # ---------- Session setup (responder) ----------
    def _establish_session_as_responder(self, peer: str, ek_pub_b64: str, opk_id: Optional[str]) -> SessionState:
        """Create the session matching an initiator's handshake header.

        Args:
            peer: Name of the initiating user.
            ek_pub_b64: Initiator's ephemeral public key, Base64 text.
            opk_id: Id of the one-time prekey the initiator used, if any.

        Raises:
            ValueError: If ``peer`` has no published bundle or ``opk_id``
                does not name one of this client's one-time prekeys.
        """
        # Simplification: the initiator's identity key is taken from the
        # server directory instead of the first frame.
        peer_bundle = self.server.fetch_bundle(peer)
        if peer_bundle is None:
            raise ValueError(f"{peer} has no bundle")

        opk_priv = None
        if opk_id:
            opk_priv = self.opk_priv_map.get(opk_id)
            if opk_priv is None:
                # Skipping the fourth DH silently would only surface later
                # as an unexplained MAC failure.
                raise ValueError(f"Unknown one-time prekey id: {opk_id}")

        peer_eph = pub_from_bytes(b64d(ek_pub_b64))
        peer_ik = pub_from_bytes(b64d(peer_bundle.identity_pub))

        # Same DH outputs, in the same order, as the initiator computes.
        dh_outputs = [
            x25519_shared(self.spk_priv, peer_ik),
            x25519_shared(self.ik_priv, peer_eph),
            x25519_shared(self.spk_priv, peer_eph),
        ]
        if opk_priv is not None:
            dh_outputs.append(x25519_shared(opk_priv, peer_eph))

        st = self._install_session(peer, dh_outputs, peer_bundle.identity_pub, opk_id)
        # Once a one-time prekey has been used, drop it from the directory.
        if opk_id:
            self.server.mark_onetime_used(self.name, opk_id)
        return st

    # ---------- Per-message key derivation and encryption ----------
    @staticmethod
    def _message_keys(chain_key: bytes, counter: int) -> Tuple[bytes, bytes, bytes, bytes]:
        """Return ``(aes_key, mac_key, iv, next_chain_key)`` without side effects."""
        info = b"MSG|" + counter.to_bytes(8, "big")
        okm = hkdf(chain_key, info=info, length=80)  # 32 AES + 32 HMAC + 16 IV
        next_chain_key = hkdf(chain_key, info=b"STEP", length=32)
        return okm[:32], okm[32:64], okm[64:80], next_chain_key

    def _derive_message_key(self, st: SessionState) -> Tuple[bytes, bytes, bytes]:
        """Derive ``(aes_key, hmac_key, iv)`` for the current message.

        Also advances ``st.chain_key`` and increments ``st.counter``.
        """
        aes_key, mac_key, iv, next_chain_key = self._message_keys(st.chain_key, st.counter)
        st.chain_key = next_chain_key
        st.counter += 1
        return aes_key, mac_key, iv

    # ---------- Sending ----------
    def send(self, to_user: str, plaintext: bytes):
        """Encrypt ``plaintext`` and queue the frame for ``to_user``.

        A session is established first when none exists. The handshake
        header is attached to the first frame sent after establishment.
        """
        if to_user not in self.sessions:
            self.establish_session_as_initiator(to_user)
        st = self.sessions[to_user]
        # Attached only once; later frames carry None for both fields.
        ek_pub_b64, opk_id = self._pending_init.pop(to_user, (None, None))

        counter = st.counter  # counter value belonging to this message
        aes_key, mac_key, iv = self._derive_message_key(st)
        ct = aes_cbc_enc(aes_key, iv, plaintext)
        mac = hmac_sha256(mac_key, iv + ct)[: self._MAC_LEN]

        frame = {
            "from": self.name,
            "to": to_user,
            "hdr": {
                "init": ek_pub_b64 is not None,
                "ek_pub_b64": ek_pub_b64,
                "opk_id": opk_id,
                "counter": counter,
            },
            "body": {
                "iv": b64e(iv),
                "ct": b64e(ct),
                "mac": b64e(mac),
            },
        }
        self.server.send_frame(to_user, frame)

    # ---------- Receiving ----------
    def _open_frame(self, st: SessionState, hdr: dict, body: dict) -> bytes:
        """Verify and decrypt one frame, advancing the ratchet only on success.

        Raises:
            ValueError: On a counter mismatch, an IV that differs from the
                derived one, a MAC failure or bad padding. ``st`` is left
                unchanged in all of these cases.
        """
        # Demo limitation: strict ordering; no skipped or reordered frames.
        if hdr["counter"] != st.counter:
            raise ValueError("Out-of-order message (demo limitation)")

        aes_key, mac_key, iv, next_chain_key = self._message_keys(st.chain_key, st.counter)

        # The IV is derived, not negotiated: reject anything else.
        if b64d(body["iv"]) != iv:
            raise ValueError("IV mismatch")

        ct = b64d(body["ct"])
        mac = b64d(body["mac"])
        expected = hmac_sha256(mac_key, iv + ct)[: self._MAC_LEN]
        if not hmac.compare_digest(mac, expected):
            raise ValueError("MAC verification failed")

        pt = aes_cbc_dec(aes_key, iv, ct)

        # Commit last, so a forged or corrupt frame cannot desynchronise
        # the chain.
        st.chain_key = next_chain_key
        st.counter += 1
        return pt

    def receive_all(self) -> List[Tuple[str, bytes]]:
        """Pull, verify and decrypt all frames queued for this client.

        Returns:
            A list of ``(sender, plaintext)`` tuples in arrival order.

        Raises:
            ValueError: If the first frame from an unknown sender lacks the
                handshake header, or if a frame fails verification. Frames
                pulled in the same call but not yet processed are not
                re-queued.
        """
        outputs = []
        for f in self.server.pull_frames(self.name):
            sender = f["from"]
            hdr = f["hdr"]
            body = f["body"]

            is_new = sender not in self.sessions
            if is_new:
                # First frame from this sender: set up the session passively.
                if not hdr["init"]:
                    raise ValueError("Missing init header")
                st = self._establish_session_as_responder(
                    peer=sender,
                    ek_pub_b64=hdr["ek_pub_b64"],
                    opk_id=hdr["opk_id"],
                )
            else:
                st = self.sessions[sender]

            try:
                pt = self._open_frame(st, hdr, body)
            except Exception:
                # Do not keep a session whose very first frame failed;
                # otherwise a forged handshake would block the real peer.
                if is_new:
                    self.sessions.pop(sender, None)
                raise
            outputs.append((sender, pt))
        return outputs
|
||
|
||
# ------------------------- Demo --------------------------
|
||
|
||
def main():
    """Run the in-memory demo: two clients exchange encrypted messages."""
    server = RelayServer()
    alice = Client("alice", server)
    bob = Client("bob", server)

    # Each party publishes its public key bundle.
    for client in (alice, bob):
        client.publish()

    # Alice sends first; the session is established automatically.
    alice.send("bob", b"Hello Bob, this is Alice.")
    # Bob pulls and decrypts.
    for frm in bob.receive_all():
        print("[bob] got:", frm)

    # Bob replies.
    bob.send("alice", b"Hi Alice, Bob here. Message received.")
    for frm in alice.receive_all():
        print("[alice] got:", frm)

    # Several messages in a row: each one uses a fresh key, and the server
    # cannot decrypt any of them.
    for i in (1, 2, 3):
        alice.send("bob", f"Msg#{i} from Alice".encode())
    for frm in bob.receive_all():
        print("[bob] got:", frm)

    # Show that the server only holds public key material.
    print("\n[server] directory keys (truncated):")
    summary = {}
    for user, bundle in server.directory.items():
        summary[user] = {
            "identity_pub": bundle.identity_pub[:24] + "...",
            "spk_pub": bundle.signed_prekey_pub[:24] + "...",
            "opk_count": len(bundle.onetime_prekeys),
        }
    print(json.dumps(summary, indent=2))
|
||
|
||
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()