183 lines
5.5 KiB
Python
183 lines
5.5 KiB
Python
import threading
|
|
from typing import Any, AnyStr, BinaryIO, Callable, Mapping, Optional
|
|
|
|
|
|
class StreamController:
    """Multiplex byte streams over a stanza-based request/notification channel.

    Outgoing stanzas are handed to ``post``; incoming stanzas are fed in
    through :meth:`receive`. A stanza whose name starts with ``"."`` is a
    request (``.create``, ``.write``, ``.finish``) and is answered with a
    ``+result`` or ``+error`` notification carrying the same ``id``.

    The public counters ``streams_opened``, ``bytes_received`` and
    ``bytes_sent`` track activity; ``on_stats_updated`` (if given) is called
    with no arguments after they change.
    """

    def __init__(
        self,
        post: Callable[[Any, Optional[AnyStr]], None],
        on_incoming_stream_request: Optional[Callable[[Any, Any], BinaryIO]] = None,
        on_incoming_stream_closed=None,
        on_stats_updated=None,
    ) -> None:
        """Create a controller.

        Args:
            post: Sends a stanza (a dict with ``id``, ``name`` and ``payload``)
                together with optional binary data to the peer.
            on_incoming_stream_request: Called as ``(label, details)`` when the
                peer opens a stream; must return an object with ``write()``
                and ``close()``. If None, incoming streams are rejected.
            on_incoming_stream_closed: Called as ``(label, details)`` after an
                incoming stream has been finished and its source closed.
            on_stats_updated: Called with no arguments after a counter changes.
        """
        self.streams_opened = 0
        self.bytes_received = 0
        self.bytes_sent = 0

        self._handlers = {
            ".create": self._on_create,
            ".finish": self._on_finish,
            ".write": self._on_write,
        }

        self._post = post
        self._on_incoming_stream_request = on_incoming_stream_request
        self._on_incoming_stream_closed = on_incoming_stream_closed
        self._on_stats_updated = on_stats_updated

        # Incoming streams keyed by the peer's endpoint ID -> (source, label, details).
        self._sources = {}
        self._next_endpoint_id = 1

        # Outstanding outgoing requests keyed by request ID ->
        # [completion Event, result payload, exception to raise].
        self._requests = {}
        self._next_request_id = 1

    def dispose(self) -> None:
        """Fail every outstanding request with ``DisposedException`` and wake its waiter."""
        error = DisposedException("disposed")
        # Take one snapshot so a concurrent receive() popping an entry cannot
        # invalidate the iteration.
        pending = list(self._requests.values())
        for request in pending:
            request[2] = error
        for completed, *_ in pending:
            completed.set()

    def open(self, label, details=None) -> "Sink":
        """Open an outgoing stream and return its ``Sink``.

        Blocks until the peer answers the ``.create`` request.

        Args:
            label: Stream label forwarded to the peer.
            details: Extra metadata forwarded to the peer; defaults to a fresh
                empty dict.

        Raises:
            StreamException: If the peer rejects the stream.
            DisposedException: If the controller is disposed while waiting.
        """
        if details is None:
            # A fresh dict per call: a shared ``{}`` default would leak any
            # mutation made for one stream into every later stream.
            details = {}

        eid = self._next_endpoint_id
        self._next_endpoint_id += 1

        endpoint = {"id": eid, "label": label, "details": details}
        sink = Sink(self, endpoint)

        self.streams_opened += 1
        self._notify_stats_updated()

        return sink

    def receive(self, stanza: Mapping[str, Any], data: Any) -> None:
        """Dispatch one stanza received from the peer.

        Requests (``.``-prefixed) are handled and answered via ``post``;
        notifications (``+``-prefixed) complete a pending outgoing request.

        Raises:
            ValueError: For an unknown stanza prefix, an unknown request name,
                an unknown notification name, or an unknown request ID.
        """
        sid = stanza["id"]
        name = stanza["name"]
        payload = stanza.get("payload", None)

        # Slice rather than index so an empty name is reported as an unknown
        # stanza instead of escaping as an IndexError.
        stype = name[:1]
        if stype == ".":
            self._on_request(sid, name, payload, data)
        elif stype == "+":
            self._on_notification(sid, name, payload)
        else:
            raise ValueError("unknown stanza: " + name)

    def _on_create(self, payload: Mapping[str, Any], data: Any) -> None:
        """Handle ``.create``: ask the application for a source and register it."""
        endpoint = payload["endpoint"]
        eid = endpoint["id"]
        label = endpoint["label"]
        details = endpoint["details"]

        if self._on_incoming_stream_request is None:
            raise ValueError("incoming streams not allowed")
        source = self._on_incoming_stream_request(label, details)

        self._sources[eid] = (source, label, details)

        self.streams_opened += 1
        self._notify_stats_updated()

    def _on_finish(self, payload: Mapping[str, Any], data: Any) -> None:
        """Handle ``.finish``: unregister and close the source, then notify the application."""
        eid = payload["endpoint"]["id"]

        entry = self._sources.pop(eid, None)
        if entry is None:
            raise ValueError("invalid endpoint ID")
        source, label, details = entry

        source.close()

        if self._on_incoming_stream_closed is not None:
            self._on_incoming_stream_closed(label, details)

    def _on_write(self, payload: Mapping[str, Any], data: Any) -> None:
        """Handle ``.write``: append ``data`` to the registered source."""
        entry = self._sources.get(payload["endpoint"]["id"], None)
        if entry is None:
            raise ValueError("invalid endpoint ID")
        source, *_ = entry

        source.write(data)

        self.bytes_received += len(data)
        self._notify_stats_updated()

    def _request(self, name: str, payload: Mapping[Any, Any], data: Optional[AnyStr] = None):
        """Send a request stanza and block until its reply arrives.

        Returns:
            The ``+result`` payload.

        Raises:
            StreamException: If the peer replied with ``+error``.
            DisposedException: If :meth:`dispose` ran while waiting.
            Exception: Whatever ``post`` raises; the request is then forgotten.
        """
        # NOTE(review): the ID counter and _requests are not lock-protected;
        # confirm callers never issue requests from several threads at once.
        # NOTE(review): dispose() sets no flag, so a request issued after
        # dispose() waits until a reply arrives for it.
        rid = self._next_request_id
        self._next_request_id += 1

        completed = threading.Event()
        request = [completed, None, None]
        self._requests[rid] = request

        try:
            self._post({"id": rid, "name": name, "payload": payload}, data)
        except BaseException:
            # The request was not delivered, so no reply will ever claim this
            # entry; drop it instead of leaking it.
            self._requests.pop(rid, None)
            raise

        completed.wait()

        error = request[2]
        if error is not None:
            raise error

        return request[1]

    def _on_request(self, sid, name: str, payload: Mapping[str, Any], data: Any) -> None:
        """Run the handler for request ``name`` and post its result or error."""
        handler = self._handlers.get(name, None)
        if handler is None:
            raise ValueError("invalid request: " + name)

        try:
            result = handler(payload, data)
        except Exception as e:
            self._reject(sid, e)
            return

        self._resolve(sid, result)

    def _resolve(self, sid, value) -> None:
        """Post a ``+result`` reply for request ``sid``."""
        self._post({"id": sid, "name": "+result", "payload": value}, None)

    def _reject(self, sid, error) -> None:
        """Post a ``+error`` reply for request ``sid`` carrying ``str(error)``."""
        self._post({"id": sid, "name": "+error", "payload": {"message": str(error)}}, None)

    def _on_notification(self, sid, name: str, payload) -> None:
        """Complete pending request ``sid`` with a ``+result`` or ``+error`` reply."""
        # Work out the outcome before claiming the request. A malformed name or
        # error payload then raises while the request is still registered, so
        # dispose() can still wake its waiter instead of it blocking forever.
        if name == "+result":
            error = None
        elif name == "+error":
            error = StreamException(payload["message"])
        else:
            raise ValueError("unknown notification: " + name)

        request = self._requests.pop(sid, None)
        if request is None:
            raise ValueError("invalid request ID")

        if error is None:
            request[1] = payload
        else:
            request[2] = error

        completed, *_ = request
        completed.set()

    def _notify_stats_updated(self) -> None:
        """Invoke the ``on_stats_updated`` callback, if one was supplied."""
        if self._on_stats_updated is not None:
            self._on_stats_updated()
|
|
|
|
|
|
class Sink:
    """Writable end of an outgoing stream, as returned by ``StreamController.open``.

    Every operation is a blocking request round trip through the controller.
    Usable as a context manager so ``.finish`` is sent even when the ``with``
    body raises.
    """

    def __init__(self, controller: "StreamController", endpoint) -> None:
        """Bind to ``controller`` and announce ``endpoint`` to the peer with ``.create``."""
        self._controller = controller
        self._endpoint = endpoint

        controller._request(".create", {"endpoint": endpoint})

    def __enter__(self) -> "Sink":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Always finish the stream; returning None lets any exception propagate.
        self.close()

    def close(self) -> None:
        """Finish the stream by sending ``.finish`` to the peer."""
        self._controller._request(".finish", {"endpoint": self._endpoint})

    def write(self, chunk) -> None:
        """Send ``chunk`` to the peer, then account for it in ``bytes_sent``."""
        ctrl = self._controller

        ctrl._request(".write", {"endpoint": self._endpoint}, chunk)

        # Counted only after the peer accepted the write.
        ctrl.bytes_sent += len(chunk)
        ctrl._notify_stats_updated()
|
|
|
|
|
|
class DisposedException(Exception):
    """Raised to callers still waiting on a stream request when the controller is disposed."""
|
|
|
|
|
|
class StreamException(Exception):
    """Raised when the peer answers a stream request with an ``+error`` notification."""
|