"""Snowflake-style unique ID generator.

An ID is an integer assembled from four fields, from most to least
significant: milliseconds elapsed since a custom epoch, a datacenter id,
a worker id, and a per-millisecond sequence counter.
"""

import threading
import time

# Epoch (milliseconds since the Unix epoch) used by the demo at the bottom of
# this module.
EPOCH_JAVA_COMMON = 1288834974657


class Snowflake:
    """Thread-safe generator of increasing integer IDs.

    Bit layout of a generated ID (high to low):
    ``timestamp - epoch`` | 5-bit datacenter id | 5-bit worker id |
    12-bit sequence.

    Args:
        datacenter_id: Datacenter identifier, ``0 <= id <= max_datacenter_id``.
        worker_id: Worker identifier, ``0 <= id <= max_worker_id``.
        epoch: Custom epoch in milliseconds since the Unix epoch. It is
            subtracted from the current time before shifting, so an epoch
            later than the current time produces negative IDs.

    Raises:
        ValueError: If ``worker_id`` or ``datacenter_id`` is out of range.
    """

    def __init__(self, datacenter_id: int = 0, worker_id: int = 0,
                 epoch: int = 1480166465631):
        # Field widths for the worker / datacenter / sequence parts.
        self.worker_id_bits = 5
        self.datacenter_id_bits = 5
        self.sequence_bits = 12

        # Largest value that fits in each id field (all bits set).
        self.max_worker_id = (1 << self.worker_id_bits) - 1
        self.max_datacenter_id = (1 << self.datacenter_id_bits) - 1

        if worker_id > self.max_worker_id or worker_id < 0:
            raise ValueError(f"worker_id 超出范围 (0 ~ {self.max_worker_id})")
        if datacenter_id > self.max_datacenter_id or datacenter_id < 0:
            raise ValueError(
                f"datacenter_id 超出范围 (0 ~ {self.max_datacenter_id})"
            )

        self.worker_id = worker_id
        self.datacenter_id = datacenter_id
        self.epoch = epoch

        # Generator state: counter within the current millisecond and the
        # millisecond of the most recently issued ID (-1 means none yet).
        self.sequence = 0
        self.last_timestamp = -1

        # Left-shift amounts that place each field in the final ID.
        self.worker_id_shift = self.sequence_bits
        self.datacenter_id_shift = self.sequence_bits + self.worker_id_bits
        self.timestamp_left_shift = (
            self.sequence_bits + self.worker_id_bits + self.datacenter_id_bits
        )
        self.sequence_mask = (1 << self.sequence_bits) - 1

        # Serializes next_id() so sequence/last_timestamp update atomically.
        self.lock = threading.Lock()

    def _timestamp(self) -> int:
        """Return the current wall-clock time in whole milliseconds."""
        return int(time.time() * 1000)

    def _til_next_millis(self, last_timestamp: int) -> int:
        """Busy-wait until the clock passes ``last_timestamp``.

        Returns:
            The first observed timestamp strictly greater than
            ``last_timestamp``.
        """
        timestamp = self._timestamp()
        while timestamp <= last_timestamp:
            timestamp = self._timestamp()
        return timestamp

    def next_id(self) -> int:
        """Generate the next ID.

        Returns:
            The integer combining the elapsed milliseconds since ``epoch``,
            the datacenter id, the worker id and the sequence number.

        Raises:
            RuntimeError: If the clock reads earlier than the timestamp of
                the previously generated ID (clock moved backwards).
                ``RuntimeError`` is a subclass of ``Exception``, so handlers
                catching ``Exception`` still apply.
        """
        with self.lock:
            timestamp = self._timestamp()

            if timestamp < self.last_timestamp:
                raise RuntimeError("时钟回拨,拒绝生成ID")

            if timestamp == self.last_timestamp:
                # Same millisecond: advance the counter, wrapping at the
                # sequence field width.
                self.sequence = (self.sequence + 1) & self.sequence_mask
                if self.sequence == 0:
                    # Counter exhausted for this millisecond; wait for the
                    # next one so IDs stay unique.
                    timestamp = self._til_next_millis(self.last_timestamp)
            else:
                self.sequence = 0

            self.last_timestamp = timestamp

            return (
                ((timestamp - self.epoch) << self.timestamp_left_shift)
                | (self.datacenter_id << self.datacenter_id_shift)
                | (self.worker_id << self.worker_id_shift)
                | self.sequence
            )


if __name__ == "__main__":
    snowflake = Snowflake(datacenter_id=0, worker_id=0, epoch=EPOCH_JAVA_COMMON)
    for _ in range(10):
        print(snowflake.next_id())