feat(wits): 添加WITS数据验证和字段规则配置

- 在model.__init__.py中新增REQUIRED_SIMPLIFIED_FIELD_RULES和REQUIRED_TRANSMISSION_CHANNELS导出
- 移除app/mqtt_sender.py中的MQTT发送功能,禁用相关逻辑
- 在model/wits.py中添加WITS字段验证规则和传输通道映射配置
- 实现validate_required_wits_fields函数进行必填字段验证
- 在WitsData类中添加__post_init__方法执行字段验证
- 为wits_sender.py添加传输值验证和数据包验证功能
- 更新随机WITS数据生成逻辑,使用真实钻井参数范围
- 实现数据包解析和验证功能,确保必传字段完整性
This commit is contained in:
2026-03-12 13:58:19 +08:00
parent 0a123ba210
commit dc8aed8156
4 changed files with 135 additions and 80 deletions

View File

@@ -1,11 +1,6 @@
import argparse import argparse
import json
import logging import logging
import random import random
import time
from urllib.parse import urlparse
import paho.mqtt.client as mqtt
from config import build_sender_dependencies from config import build_sender_dependencies
from model import DrillingRealtimeData from model import DrillingRealtimeData
@@ -14,18 +9,6 @@ from model import DrillingRealtimeData
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def parse_broker(broker):
if not broker:
raise ValueError("broker is required")
if "://" not in broker:
broker = "tcp://" + broker
parsed = urlparse(broker)
host = parsed.hostname or "localhost"
port = parsed.port or 1883
scheme = (parsed.scheme or "tcp").lower()
return scheme, host, port
def rand_int(a, b): def rand_int(a, b):
return random.randint(a, b) return random.randint(a, b)
@@ -89,43 +72,14 @@ def build_random_payload(equipment_code):
def run_sender(args, deps): def run_sender(args, deps):
mqtt_config = deps.config.mqtt mqtt_config = deps.config.mqtt
tms_config = deps.config.tms logger.info(
scheme, host, port = parse_broker(mqtt_config.broker) "mqtt_sender is disabled by current requirements; skipping MQTT publish logic for topic=%s client_id=%s interval=%ss count=%s",
mqtt_config.pub_topic,
mqtt_config.sender_client_id,
args.interval,
args.count or "forever",
)
logger.info("MQTT sender config broker=%s://%s:%s client_id=%s pub_topic=%s interval=%ss", scheme, host, port, mqtt_config.sender_client_id, mqtt_config.pub_topic, args.interval)
client = mqtt.Client(client_id=mqtt_config.sender_client_id, clean_session=True)
if mqtt_config.username is not None:
client.username_pw_set(mqtt_config.username, mqtt_config.password)
if scheme in ("ssl", "tls", "mqtts"):
client.tls_set()
def on_disconnect(c, userdata, rc):
logger.info("Disconnected callback rc=%s", rc)
client.on_disconnect = on_disconnect
client.connect(host, port, keepalive=tms_config.keepalive)
client.loop_start()
try:
if not mqtt_config.pub_topic:
logger.warning("pub-topic is empty; nothing to publish")
return
seq = 0
while True:
seq += 1
payload = build_random_payload(tms_config.device_code)
client.publish(mqtt_config.pub_topic, json.dumps(payload, ensure_ascii=True))
logger.info("TX %s #%s", mqtt_config.pub_topic, seq)
if args.count and seq >= args.count:
break
time.sleep(args.interval)
except KeyboardInterrupt:
logger.info("Sender interrupted")
finally:
client.loop_stop()
client.disconnect()
logger.info("Sender stopped")
def add_arguments(parser): def add_arguments(parser):
@@ -134,6 +88,7 @@ def add_arguments(parser):
parser.add_argument("--count", type=int, default=0, help="Publish count (0 = forever)") parser.add_argument("--count", type=int, default=0, help="Publish count (0 = forever)")
def main(argv=None): def main(argv=None):
parser = argparse.ArgumentParser(description="MQTT random data sender") parser = argparse.ArgumentParser(description="MQTT random data sender")
add_arguments(parser) add_arguments(parser)

View File

@@ -3,11 +3,10 @@ import logging
import random import random
import socket import socket
import time import time
from datetime import datetime
from pathlib import Path from pathlib import Path
from config import build_wits_sender_dependencies from config import build_wits_sender_dependencies
from model import WITS_CHANNEL_MAPPING, WitsData from model import REQUIRED_TRANSMISSION_CHANNELS, WITS_CHANNEL_MAPPING, WitsData
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -15,6 +14,16 @@ BEGIN_MARK = "&&\r\n"
END_MARK = "!!\r\n" END_MARK = "!!\r\n"
RECORD_TERMINATOR = "*\r\n" RECORD_TERMINATOR = "*\r\n"
RECONNECT_DELAY = 3 RECONNECT_DELAY = 3
FIELD_RULES = {
"deptbitm": (0.0, 20000.0, float),
"chkp": (0.0, 20000.0, float),
"sppa": (0.0, 20000.0, float),
"rpma": (0, 400, int),
"torqa": (0.0, 100000.0, float),
"hkla": (0.0, 2000.0, float),
"blkpos": (0.0, 1000.0, float),
"woba": (0.0, 2000.0, float),
}
def rand_int(a, b): def rand_int(a, b):
@@ -26,46 +35,50 @@ def rand_float(a, b, digits=6):
def build_random_wits_data(device_code): def build_random_wits_data(device_code):
now = datetime.now()
ts_ms = int(time.time() * 1000) ts_ms = int(time.time() * 1000)
hook_load = rand_float(17.3, 18.8) hook_load = rand_float(17.3, 18.8)
standpipe_pressure = rand_float(990.0, 1012.0) standpipe_pressure = rand_float(990.0, 1012.0)
mud_density = rand_float(1069.8, 1070.1) casing_pressure = rand_float(180.0, 260.0)
rotary_rpm = rand_int(95, 135)
torque = rand_float(8.0, 16.0)
weight_on_bit = rand_float(6.0, 12.0)
bit_depth = rand_float(199.8, 200.3)
block_position = rand_float(5.8, 6.3)
return WitsData( return WitsData(
ts=ts_ms, ts=ts_ms,
wellid="???1", wellid=device_code or "???1",
stknum=0, stknum=0,
recid=1, recid=1,
seqid=rand_int(1600, 9999), seqid=rand_int(1600, 9999),
actual_date=now.strftime("%y%m%d"), actual_date=time.strftime("%y%m%d"),
actual_time=now.strftime("%H%M%S"), actual_time=time.strftime("%H%M%S"),
actual_ts=ts_ms, actual_ts=ts_ms,
actcod=37, actcod=37,
actod_label="AUTO", actod_label="AUTO",
deptbitm=200.0, deptbitm=bit_depth,
deptbitv=198.551422, deptbitv=bit_depth - 1.45,
deptmeas=200.0, deptmeas=bit_depth,
deptvert=198.551422, deptvert=bit_depth - 1.45,
blkpos=6.001850, blkpos=block_position,
ropa=0.0, ropa=rand_float(0.8, 2.5),
hkla=hook_load, hkla=hook_load,
hklx=hook_load, hklx=hook_load,
woba=0.0, woba=weight_on_bit,
wobx=-hook_load, wobx=-weight_on_bit,
torqa=0.0, torqa=torque,
torqx=0.0, torqx=torque,
rpma=0, rpma=rotary_rpm,
sppa=standpipe_pressure, sppa=standpipe_pressure,
chkp=0.0, chkp=casing_pressure,
spm1=0, spm1=rand_int(98, 112),
spm2=0, spm2=0,
spm3=0, spm3=0,
tvolact=0.0, tvolact=rand_float(28.0, 31.0),
tvolcact=0.0, tvolcact=rand_float(28.0, 31.0),
mfop=0, mfop=0,
mfoa=0.0, mfoa=0.0,
mfia=0.0, mfia=0.0,
mdoa=mud_density, mdoa=rand_float(1069.8, 1070.1),
mdia=26.846003, mdia=26.846003,
mtoa=29.113855, mtoa=29.113855,
mtia=346.874634, mtia=346.874634,
@@ -73,7 +86,7 @@ def build_random_wits_data(device_code):
mcia=0.0, mcia=0.0,
stkc=0, stkc=0,
lagstks=0, lagstks=0,
deptretm=200.0, deptretm=bit_depth,
gasa=0.0, gasa=0.0,
space1=0.0, space1=0.0,
space2=0.0, space2=0.0,
@@ -93,9 +106,53 @@ def format_wits_value(value, kind):
return str(value) return str(value)
def validate_transmission_values(values):
for field_name, (minimum, maximum, caster) in FIELD_RULES.items():
raw_value = values.get(field_name)
if raw_value is None or raw_value == "":
raise ValueError(f"WITS field '{field_name}' is required")
try:
value = caster(raw_value)
except (TypeError, ValueError) as exc:
raise ValueError(f"WITS field '{field_name}' must be numeric, got {raw_value!r}") from exc
if value < minimum or value > maximum:
raise ValueError(
f"WITS field '{field_name}' out of range [{minimum}, {maximum}], got {value}"
)
def extract_channel_values(packet):
lines = packet.replace("\r\n", "\n").replace("\r", "\n").split("\n")
values = {}
for raw_line in lines:
line = raw_line.strip()
if not line or line in {"&&", "!!", "*"}:
continue
if len(line) < 5:
raise ValueError(f"Invalid WITS line: {line!r}")
values[line[:4]] = line[4:]
return values
def validate_packet(packet):
channel_values = extract_channel_values(packet)
missing_channels = [channel for channel in REQUIRED_TRANSMISSION_CHANNELS if channel not in channel_values]
if missing_channels:
missing_fields = [REQUIRED_TRANSMISSION_CHANNELS[channel] for channel in missing_channels]
raise ValueError(f"WITS packet missing required fields: {', '.join(missing_fields)}")
field_values = {
field_name: channel_values[channel]
for channel, field_name in REQUIRED_TRANSMISSION_CHANNELS.items()
}
validate_transmission_values(field_values)
def build_wits_packet(data): def build_wits_packet(data):
lines = [f"{channel}{format_wits_value(getattr(data, field_name), kind)}" for channel, field_name, kind in WITS_CHANNEL_MAPPING] lines = [f"{channel}{format_wits_value(getattr(data, field_name), kind)}" for channel, field_name, kind in WITS_CHANNEL_MAPPING]
return BEGIN_MARK + "\r\n".join(lines) + "\r\n" + END_MARK + RECORD_TERMINATOR packet = BEGIN_MARK + "\r\n".join(lines) + "\r\n" + END_MARK + RECORD_TERMINATOR
validate_packet(packet)
return packet
def normalize_packet(text): def normalize_packet(text):
@@ -107,7 +164,9 @@ def normalize_packet(text):
lines = lines[:-1] lines = lines[:-1]
if lines and lines[-1] == "!!": if lines and lines[-1] == "!!":
lines = lines[:-1] lines = lines[:-1]
return BEGIN_MARK + "\r\n".join(lines) + "\r\n" + END_MARK + RECORD_TERMINATOR packet = BEGIN_MARK + "\r\n".join(lines) + "\r\n" + END_MARK + RECORD_TERMINATOR
validate_packet(packet)
return packet
def load_packet_from_file(path): def load_packet_from_file(path):

View File

@@ -1,6 +1,6 @@
from model.config import AppConfig, MqttConfig, TdengineConfig, TmsConfig, WitsConfig from model.config import AppConfig, MqttConfig, TdengineConfig, TmsConfig, WitsConfig
from model.drilling import DrillingRealtimeData from model.drilling import DrillingRealtimeData
from model.wits import WITS_CHANNEL_MAPPING, WitsData from model.wits import REQUIRED_SIMPLIFIED_FIELD_RULES, REQUIRED_TRANSMISSION_CHANNELS, WITS_CHANNEL_MAPPING, WitsData
# Backward-compatible alias for older imports. # Backward-compatible alias for older imports.
WITS_FIELD_MAPPING = WITS_CHANNEL_MAPPING WITS_FIELD_MAPPING = WITS_CHANNEL_MAPPING
@@ -11,6 +11,8 @@ __all__ = [
"MqttConfig", "MqttConfig",
"TdengineConfig", "TdengineConfig",
"TmsConfig", "TmsConfig",
"REQUIRED_SIMPLIFIED_FIELD_RULES",
"REQUIRED_TRANSMISSION_CHANNELS",
"WITS_CHANNEL_MAPPING", "WITS_CHANNEL_MAPPING",
"WITS_FIELD_MAPPING", "WITS_FIELD_MAPPING",
"WitsConfig", "WitsConfig",

View File

@@ -1,6 +1,41 @@
from dataclasses import dataclass from dataclasses import dataclass
REQUIRED_SIMPLIFIED_FIELD_RULES = {
"ts": (1, None),
"deptbitm": (0.0, 20000.0),
"chkp": (0.0, 20000.0),
"sppa": (0.0, 20000.0),
"rpma": (0, 400),
"torqa": (0.0, 100000.0),
"hkla": (0.0, 2000.0),
"blkpos": (0.0, 1000.0),
"woba": (0.0, 2000.0),
}
REQUIRED_TRANSMISSION_CHANNELS = {
"0108": "deptbitm",
"0112": "blkpos",
"0114": "hkla",
"0116": "woba",
"0118": "torqa",
"0120": "rpma",
"0121": "sppa",
"0122": "chkp",
}
def validate_required_wits_fields(data):
for field_name, (minimum, maximum) in REQUIRED_SIMPLIFIED_FIELD_RULES.items():
value = getattr(data, field_name)
if value is None:
raise ValueError(f"WITS field '{field_name}' is required")
if value < minimum:
raise ValueError(f"WITS field '{field_name}' must be >= {minimum}, got {value}")
if maximum is not None and value > maximum:
raise ValueError(f"WITS field '{field_name}' must be <= {maximum}, got {value}")
@dataclass(frozen=True) @dataclass(frozen=True)
class WitsData: class WitsData:
ts: int ts: int
@@ -52,6 +87,9 @@ class WitsData:
space4: float space4: float
space5: float space5: float
def __post_init__(self):
validate_required_wits_fields(self)
WITS_CHANNEL_MAPPING = [ WITS_CHANNEL_MAPPING = [
("0101", "wellid", "string"), ("0101", "wellid", "string"),
@@ -74,6 +112,7 @@ WITS_CHANNEL_MAPPING = [
("0119", "torqx", "float6"), ("0119", "torqx", "float6"),
("0120", "rpma", "int"), ("0120", "rpma", "int"),
("0121", "sppa", "float6"), ("0121", "sppa", "float6"),
("0122", "chkp", "float6"),
("0123", "spm1", "int"), ("0123", "spm1", "int"),
("0124", "spm2", "int"), ("0124", "spm2", "int"),
("0125", "spm3", "int"), ("0125", "spm3", "int"),