import asyncio
import base64
import json
import logging
import os
import pathlib
import random
import ssl
import socket
import threading
from typing import Any, Dict, Optional
import urllib.parse
import msgpack
import websockets
from websockets import exceptions as ws_exc
from fastapi import FastAPI, HTTPException
from starlette.responses import JSONResponse
from . import _prof as rprof
import radical.edge.logging_config # noqa: F401 # pylint: disable=unused-import
from radical.edge.plugin_base import Plugin
from radical.edge.plugin_host_base import PluginHostBase
from radical.edge.models import (
RequestMessage, PingMessage, ErrorMessage, ShutdownMessage, TopologyMessage,
ResponseMessage, NotificationMessage, RegisterMessage,
parse_bridge_message
)
from radical.edge.ui_schema import ui_config_to_dict
log = logging.getLogger("radical.edge")
# ---------------------------------------------------------------------------
# RequestShim — lightweight stand-in for starlette.requests.Request
# ---------------------------------------------------------------------------
[docs]
class RequestShim:
    """Lightweight adapter standing in for a starlette ``Request``.

    Provides the interfaces that the plugin handlers use: ``path_params``,
    ``query_params``, ``await .body()`` and ``await .json()``.

    Encoding-agnostic: the raw bytes are stored as received and are only
    decoded (lazily, on the first ``json()`` call) based on ``content_type``.
    """

    def __init__(self, path_params : dict,
                       query_params: dict,
                       body_bytes  : bytes,
                       content_type: str = 'application/json'):
        """Store the already-parsed routing data and the raw request body.

        Args:
            path_params: Values captured from the matched route pattern.
            query_params: Parsed query-string parameters.
            body_bytes: Raw request body.
            content_type: Content type of ``body_bytes``; selects the decoder
                used by ``json()``.
        """
        self.path_params  = path_params
        self.query_params = query_params
        self.content_type = content_type
        self._body        = body_bytes
        self._decoded     = None  # cache filled by the first json() call

    async def body(self) -> bytes:
        """Raw body bytes (matches ``Request.body()``)."""
        return self._body

    async def json(self) -> dict:
        """Parse body into a Python object (matches ``Request.json()``).

        Content-type-aware: the body is decoded as msgpack when the content
        type mentions ``msgpack`` and as JSON otherwise.  An empty body
        decodes to ``{}`` for either encoding.  The decoded value is cached.
        """
        if self._decoded is not None:
            return self._decoded

        ct = self.content_type or 'application/json'
        if not self._body:
            # Nothing to decode.  Handing empty input to the msgpack unpacker
            # raises, whereas the JSON branch always returned {} here; both
            # encodings now agree on the empty-body result.
            self._decoded = {}
        elif 'msgpack' in ct:
            self._decoded = msgpack.unpackb(self._body, raw=False)
        else:
            self._decoded = json.loads(self._body)
        return self._decoded
# Re-export for backward compatibility (bridge_plugin_host.py, tests, etc.)
from radical.edge.plugin_host_base import _resolve_plugin_names # noqa: F401
[docs]
class EdgeService(PluginHostBase):
    """
    Embedded Radical Edge Service.

    This class runs the Edge Service logic within an application, supporting
    both asyncio-based applications (``await run()``) and synchronous
    applications (``start_background()``).  It manages the connection to the
    Bridge and hosts the local plugin execution environment.

    The service automatically loads the 'sysinfo' plugin to provide system
    metrics.

    Attributes:
        app (FastAPI): The internal FastAPI application hosting the plugins
            (stored on the instance as ``_app``).
    """
def __init__(self, bridge_url: Optional[str] = None, name: Optional[str] = None,
             plugins: Optional[list] = None, tunnel: bool = False,
             tunnel_via: Optional[str] = None):
    """
    Initialize the Edge Service.

    Args:
        bridge_url: WebSocket URL for the Bridge.  Falls back to the
                    environment variable ``RADICAL_BRIDGE_URL``.
        name:       Edge service name for identification.  Defaults to the
                    hostname.
        plugins:    Plugin filter handed to the plugin loader.  Defaults
                    to ``['all']``.
        tunnel:     When True, open an outbound SSH tunnel to the login
                    host and route the bridge connection through it.  See
                    :mod:`radical.edge.tunnel`.
        tunnel_via: Explicit login host to tunnel through.  If unset,
                    falls back to ``PBS_O_HOST`` (PBSPro) or
                    ``SLURM_SUBMIT_HOST`` (SLURM).

    Raises:
        ValueError: If no bridge URL is given and ``RADICAL_BRIDGE_URL`` is
                    unset or empty.
    """
    url = bridge_url or os.environ.get("RADICAL_BRIDGE_URL", "")

    self._bridge_url: str = url
    self._app: FastAPI = FastAPI(title="Embedded Edge Service")

    state = self._app.state
    state.bridge_url = url

    if not url:
        raise ValueError("Bridge URL missing as argument or RADICAL_BRIDGE_URL")

    # -- identity and plugin bookkeeping
    self._plugins: Dict[str, Plugin] = {}
    self._name: str = name or socket.gethostname()
    self._plugin_filter: list = plugins or ['all']

    state.edge_name    = self._name
    state.edge_service = self
    state.is_bridge    = False

    # -- tunnel settings
    self._tunnel: bool = tunnel
    self._tunnel_via: Optional[str] = tunnel_via
    self._tunnel_proc = None  # subprocess.Popen of active SSH tunnel

    # -- connection and lifecycle state
    self._ws: Optional[websockets.WebSocketClientProtocol] = None
    self._send_lock: asyncio.Lock = asyncio.Lock()
    self._stop_event: asyncio.Event = asyncio.Event()
    self._running_task: Optional[asyncio.Task] = None
    self._thread: Optional[threading.Thread] = None
    self._direct_routes: list = []

    self._prof = rprof.Profiler('edge', ns='radical.edge')

    self._load_plugins_from_filter(self._plugin_filter)

    # Keep a reference to the live route list (not a copy), so routes of
    # plugins registered later on show up without further bookkeeping.
    self._direct_routes = getattr(state, 'direct_routes', [])
@property
def bridge_url(self):
    """Bridge URL currently used by this service (read-only)."""
    current_url = self._bridge_url
    return current_url
# -- direct dispatch ------------------------------------------------------
def _match_route(self, method: str, path: str):
    """Look up the direct-dispatch handler for *method* and *path*.

    Walks the route table in order and picks the first entry whose method
    equals *method* and whose pattern matches *path*.

    Returns:
        ``(handler, path_params)`` for the first matching route, where
        ``path_params`` maps the route's parameter names to the matched
        groups, or ``(None, None)`` if no route matches.
    """
    for verb, regex, names, target in self._direct_routes:
        if verb != method:
            continue
        hit = regex.match(path)
        if not hit:
            continue
        captured = dict(zip(names, hit.groups()))
        return target, captured
    return None, None
@staticmethod
def _error_response(req_id: str, exc: Exception) -> ResponseMessage:
    """Translate an exception into a JSON ``ResponseMessage``.

    An ``HTTPException`` keeps its own status code and detail; every other
    exception is reported as status 502 with an ``edge-invoke-failed``
    error payload carrying ``str(exc)``.
    """
    if not isinstance(exc, HTTPException):
        status  = 502
        payload = {"error": "edge-invoke-failed",
                   "detail": str(exc)}
    else:
        status  = exc.status_code
        payload = {"detail": exc.detail}

    return ResponseMessage(
        req_id=req_id,
        status=status,
        headers={"content-type": "application/json"},
        is_binary=False,
        body=json.dumps(payload))
async def _handle_request(self, msg: RequestMessage) -> None:
    """Dispatch a bridge-forwarded request directly to the plugin handler.

    Bypasses the ASGI/FastAPI stack entirely — route matching and request
    parsing are handled inline via ``_match_route`` and ``RequestShim``.

    The reply is written to the websocket as a ``response`` frame:

    * plain handler results (dicts/lists) and JSON responses are embedded
      verbatim as the frame's ``body`` value;
    * any other response body (non-JSON content type, or empty) is embedded
      as a JSON string, so the frame itself always remains valid JSON.

    Failures are reported to the bridge as an error response; if even that
    cannot be delivered the problem is logged and the task ends quietly.

    Args:
        msg: Request message received from the bridge.
    """
    req_id = msg.req_id
    prof   = self._prof

    async def _send_frame(frame: str) -> None:
        # All writers share one websocket; the lock serialises them.  The
        # frame is dropped when there is no connection.
        async with self._send_lock:
            if self._ws:
                await self._ws.send(frame)

    try:
        prof.prof('edge_recv', uid=req_id,
                  msg='%s %s' % (msg.method, msg.path))
        log.debug("[Edge] [req:%s] Handling %s %s", req_id, msg.method, msg.path)

        # Split query string from path (an absent query parses to {})
        path, _, qs = msg.path.partition('?')
        query_params = dict(urllib.parse.parse_qsl(qs))

        # Match route
        prof.prof('edge_route', uid=req_id)
        handler, path_params = self._match_route(msg.method, path)
        if handler is None:
            log.error("[Edge] [req:%s] No route for %s %s",
                      req_id, msg.method, path)
            response = ResponseMessage(
                req_id=req_id, status=404,
                headers={"content-type": "application/json"},
                is_binary=False,
                body=json.dumps(
                    {"detail": f"No route: {msg.method} {path}"}))
            await _send_frame(response.model_dump_json())
            return

        # Build RequestShim
        prof.prof('edge_shim', uid=req_id)
        if isinstance(msg.body, bytes):
            body_bytes = msg.body                       # binary WS frame
        elif msg.is_binary and msg.body:
            body_bytes = base64.b64decode(msg.body)     # base64 fallback
        elif msg.body:
            body_bytes = msg.body.encode('utf-8')
        else:
            body_bytes = b''
        content_type = (msg.headers or {}).get(
            'content-type', 'application/json')
        shim = RequestShim(path_params, query_params,
                           body_bytes, content_type)

        # Dispatch to handler
        prof.prof('edge_handler', uid=req_id)
        try:
            result = await handler(shim)
        except HTTPException as e:
            result = JSONResponse({"detail": e.detail},
                                  status_code=e.status_code)
        except Exception as e:
            log.exception("[Edge] [req:%s] Handler error", req_id)
            result = JSONResponse(
                {"error": "edge-invoke-failed", "detail": str(e)},
                status_code=500)
        prof.prof('edge_handler_done', uid=req_id)

        # Build response — handlers return plain dicts/lists (fast path)
        # or a response object carrying status_code/headers/body.
        #
        # Fast path: serialize body with json.dumps, then build the WS
        # frame manually so the body JSON is embedded verbatim (avoids
        # Pydantic model_dump_json double-encoding the body string as an
        # escaped JSON value).
        prof.prof('edge_body_ser', uid=req_id)
        if not hasattr(result, 'status_code'):
            resp_body = json.dumps(result)
            status    = 200
            headers   = {"content-type": "application/json"}
        else:
            status    = result.status_code
            headers   = dict(result.headers)
            body_text = result.body.decode('utf-8')
            resp_ct   = next((value for key, value in headers.items()
                              if key.lower() == 'content-type'), '')
            if body_text and 'json' in resp_ct.lower():
                # Already JSON: embed as-is.
                resp_body = body_text
            else:
                # Empty or non-JSON payload: pasting it verbatim would
                # corrupt the frame, so quote it as a JSON string.
                resp_body = json.dumps(body_text)
        prof.prof('edge_body_ser_done', uid=req_id,
                  msg=str(len(resp_body)))

        log.debug("[Edge] [req:%s] Response status=%d",
                  req_id, status)

        # Manual JSON construction — resp_body is a JSON value already,
        # embed it directly to avoid re-serialization.
        prof.prof('edge_resp_ser', uid=req_id)
        hdr_json  = json.dumps(headers)
        resp_text = (
            '{"type":"response"'
            ',"req_id":'  + json.dumps(req_id) +
            ',"status":'  + str(status) +
            ',"headers":' + hdr_json +
            ',"body":'    + resp_body +
            ',"is_binary":false}')
        prof.prof('edge_resp_ser_done', uid=req_id,
                  msg=str(len(resp_text)))

        prof.prof('edge_ws_send', uid=req_id)
        await _send_frame(resp_text)
        prof.prof('edge_ws_sent', uid=req_id, state=str(status))

    except Exception as e:
        log.exception("[Edge] [req:%s] Error handling request", req_id)
        try:
            response = self._error_response(req_id, e)
            await _send_frame(response.model_dump_json())
        except Exception:
            # Typically the connection is gone (which may be what failed
            # in the first place).  Nothing more can be delivered; log it
            # instead of letting the exception escape the task.
            log.exception("[Edge] [req:%s] Could not deliver error response",
                          req_id)
async def _handle_topology(self, msg: TopologyMessage) -> None:
    """
    React to a topology update from the bridge (edge connect/disconnect).

    Every loaded plugin that provides an ``on_topology_change`` hook is
    awaited in turn with the new edge set.  A failing hook is logged and
    does not keep the remaining plugins from being notified.

    Args:
        msg: Validated topology message from bridge.
    """
    edges = msg.edges
    log.debug("[Edge] Topology update: %d edges", len(edges))

    # Fan the change out to all plugins, one after the other
    for plugin_name, instance in self._plugins.items():
        try:
            if not hasattr(instance, 'on_topology_change'):
                continue
            await instance.on_topology_change(edges)
        except Exception as err:
            log.warning("[Edge] Plugin %s topology handler failed: %s", plugin_name, err)
# -- topology announcement (PluginHostBase contract) -----------------------
async def _announce_topology(self) -> None:
    """Publish this edge's current plugin set to the bridge.

    Builds a ``topology`` message describing every loaded plugin (type,
    namespace, version, enabled flag, UI config) and sends it over the
    websocket.  Without a connection only a warning is logged; send errors
    are logged as warnings and not raised.

    Called by ``register_dynamic_plugin`` / ``deregister_dynamic_plugin``
    after plugin set changes at runtime.
    """
    if not self._ws:
        log.warning("[Edge] Cannot announce topology, not connected")
        return

    described = {
        pname: {
            'type'     : pname,
            'namespace': f'/{self._name}{plugin.namespace}',
            'version'  : getattr(plugin, 'version', '0.0.1'),
            'enabled'  : True,
            'ui_config': ui_config_to_dict(
                             getattr(plugin, 'ui_config', None)),
        }
        for pname, plugin in self._plugins.items()
    }

    frame = json.dumps({
        'type' : 'topology',
        'edges': {self._name: {'plugins': described}},
    })

    async with self._send_lock:
        try:
            await self._ws.send(frame)
        except Exception as exc:
            log.warning("[Edge] Failed to send topology: %s", exc)
        else:
            log.info("[Edge] Sent topology (%d plugins)",
                     len(described))
# -- notifications --------------------------------------------------------
[docs]
async def send_notification(self, plugin_name: str, topic: str, data: Dict[str, Any]) -> None:
    """
    Push an unsolicited notification to the bridge, which broadcasts it to
    UI clients.

    Without a connection the notification is dropped with a warning; send
    errors are logged as warnings and not raised.

    Args:
        plugin_name: Name of the plugin sending the notification.
        topic: Notification topic (e.g., "task_status", "job_status").
        data: Notification payload data.
    """
    if not self._ws:
        log.warning("[Edge] Cannot send notification, not connected")
        return

    payload = NotificationMessage(edge=self._name,
                                  plugin=plugin_name,
                                  topic=topic,
                                  data=data)

    async with self._send_lock:
        try:
            await self._ws.send(payload.model_dump_json())
        except Exception as err:
            log.warning("[Edge] Failed to send notification: %s", err)
        else:
            log.debug("[Edge] Sent notification: %s/%s", plugin_name, topic)
[docs]
async def run(self) -> None:
    """
    Main async entry point.

    Optionally opens the outbound SSH tunnel, then connects to the Bridge,
    registers this edge together with all loaded plugins in one message,
    and processes incoming messages until the service is stopped.

    Lost connections (``ConnectionClosed`` / ``OSError``) are retried with
    jittered exponential backoff; any other exception is logged and the
    connection is retried after roughly two seconds.  The method returns
    once the stop event is set, or when the bridge sends an error or a
    shutdown message.
    """
    PING_INTERVAL  = 20
    PING_TIMEOUT   = 30
    MAX_BACKOFF    = 10
    JITTER_FACTOR  = 0.3   # Add up to 30% jitter to prevent thundering herd
    BACKOFF_FACTOR = 1.2
    backoff        = 0.5

    # The event loop only keeps weak references to tasks, so a handler task
    # nobody else references may be garbage-collected before it finishes.
    # Hold a strong reference until each task is done.
    inflight = set()

    def _spawn(coro) -> None:
        task = asyncio.create_task(coro)
        inflight.add(task)
        task.add_done_callback(inflight.discard)

    self._stop_event.clear()
    self._running_task = asyncio.current_task()

    # ── Outbound SSH tunnel (--tunnel flag) ──────────────────────────────
    # When --tunnel is passed we are running on a compute node and the
    # bridge is only reachable from the login node.  We open an outbound
    # ssh -L tunnel to the login host ourselves and rewrite the bridge
    # URL to localhost:<allocated_port>.  Compute→login SSH is permitted
    # on virtually all HPC sites; the reverse direction (login→compute)
    # is blocked on Aurora and others.
    if self._tunnel:
        await self._open_tunnel()
    # ── End tunnel setup ──────────────────────────────────────────────────

    while not self._stop_event.is_set():
        try:
            # For the ws connect, we change http(s) to ws(s)
            if self._bridge_url.startswith("https://"):
                ws_url = "wss://" + self._bridge_url[len("https://"):]
            elif self._bridge_url.startswith("http://"):
                ws_url = "ws://" + self._bridge_url[len("http://"):]
            else:
                ws_url = self._bridge_url

            # remove trailing slashes
            ws_url = ws_url.rstrip("/")
            if not ws_url.endswith("/register"):
                ws_url += "/register"

            # Determine if we need SSL
            ssl_ctx = None
            if ws_url.startswith("wss://"):
                # NOTE(review): with check_hostname off and CERT_NONE the
                # peer certificate is not verified at all, so the bundle
                # loaded from RADICAL_BRIDGE_CERT below has no effect on
                # the handshake.  Kept as-is to preserve behavior; confirm
                # whether verification against that cert was intended.
                ssl_ctx = ssl.create_default_context()
                ssl_ctx.check_hostname = False
                ssl_ctx.verify_mode = ssl.CERT_NONE
                certfile = os.environ.get("RADICAL_BRIDGE_CERT")
                if certfile and os.path.exists(certfile):
                    ssl_ctx.load_verify_locations(certfile)

            async with websockets.connect(ws_url,
                                          ssl=ssl_ctx,
                                          ping_interval=PING_INTERVAL,
                                          ping_timeout=PING_TIMEOUT,
                                          close_timeout=2,
                                          max_size=10 * 1024 * 1024,
                                          compression='deflate',
                                          ) as ws:
                self._ws  = ws
                recv_task = None
                stop_fut  = None
                try:
                    log.info("[Edge] Connected to %s", self._bridge_url)
                    backoff = 0.5  # Reset backoff on success

                    # Register edge + all plugins in a single message
                    async with self._send_lock:
                        plugins_data = {}
                        for pname, plugin in self._plugins.items():
                            ui_module_content = None
                            ui_module_path = getattr(plugin.__class__, 'ui_module', None)
                            if ui_module_path and os.path.isfile(ui_module_path):
                                try:
                                    with open(ui_module_path, encoding='utf-8') as f:
                                        ui_module_content = f.read()
                                except Exception:
                                    log.warning("[Edge] Could not read ui_module for %s: %s",
                                                pname, ui_module_path)
                            plugins_data[pname] = {
                                "type": pname,
                                "namespace": f"/{self._name}{plugin.namespace}",
                                "version": getattr(plugin, 'version', '0.0.1'),
                                "enabled": True,
                                "ui_config": ui_config_to_dict(
                                    getattr(plugin, 'ui_config', None)
                                ),
                                "ui_module": ui_module_content,
                            }
                        reg = RegisterMessage(
                            edge_name=self._name,
                            endpoint={"type": "radical.edge"},
                            plugins=plugins_data,
                        )
                        await ws.send(reg.model_dump_json())

                    # Processing Loop — use asyncio.wait so the loop wakes
                    # immediately on either a new message or stop signal,
                    # eliminating the 1-second idle timeout overhead.
                    recv_task = asyncio.ensure_future(ws.recv())
                    stop_fut  = asyncio.ensure_future(self._stop_event.wait())

                    while not self._stop_event.is_set():
                        done, _ = await asyncio.wait(
                            {recv_task, stop_fut},
                            return_when=asyncio.FIRST_COMPLETED)

                        if stop_fut in done:
                            break

                        # recv_task completed — retrieve result
                        try:
                            raw_msg = recv_task.result()
                        except ws_exc.ConnectionClosed:
                            if self._stop_event.is_set():
                                break
                            log.info("[Edge] Connection closed")
                            raise  # Reconnect

                        # Arm next recv immediately
                        recv_task = asyncio.ensure_future(ws.recv())

                        # Binary WS frame → msgpack; text → JSON
                        self._prof.prof('edge_deser',
                                        msg='%s:%d' % (
                                            'msgpack' if isinstance(raw_msg, bytes)
                                            else 'json',
                                            len(raw_msg)))
                        if isinstance(raw_msg, bytes):
                            data = msgpack.unpackb(raw_msg, raw=False)
                        else:
                            data = json.loads(raw_msg)
                        uid = data.get('req_id', '')
                        self._prof.prof('edge_deser_done', uid=uid)

                        self._prof.prof('edge_parse', uid=uid)
                        try:
                            msg = parse_bridge_message(data)
                        except ValueError as ve:
                            log.warning("[Edge] Invalid message: %s", ve)
                            continue
                        self._prof.prof('edge_parse_done', uid=uid)

                        if isinstance(msg, ErrorMessage):
                            log.error("[Edge] Registration error: %s", msg.message)
                            self._stop_event.set()
                            return  # Fatal error, stop

                        if isinstance(msg, PingMessage):
                            async with self._send_lock:
                                await ws.send('{"type": "pong"}')
                            continue

                        if isinstance(msg, ShutdownMessage):
                            log.info("[Edge] Shutdown requested: %s", msg.reason)
                            self._stop_event.set()
                            return

                        if isinstance(msg, RequestMessage):
                            _spawn(self._handle_request(msg))

                        if isinstance(msg, TopologyMessage):
                            _spawn(self._handle_topology(msg))

                finally:
                    # The socket is finished (closed, failed or stopping):
                    # drop the reference so the `if self._ws` guards of the
                    # senders report "not connected" instead of writing to
                    # a dead connection.
                    self._ws = None

                    # Every exit path (break, return, exception) ends up
                    # here, so the helper tasks are always reaped.
                    leftovers = [t for t in (recv_task, stop_fut)
                                 if t is not None]
                    for t in leftovers:
                        t.cancel()
                    await asyncio.gather(*leftovers, return_exceptions=True)

        except (ws_exc.ConnectionClosed, OSError) as e:
            if self._stop_event.is_set():
                break  # no reconnect

            # Add jitter to backoff to prevent thundering herd
            jitter = backoff * JITTER_FACTOR * random.random()
            sleep_time = backoff + jitter
            log.warning("[Edge] Connection lost: %s. Reconnecting in %.1fs...",
                        e, sleep_time)
            await asyncio.sleep(sleep_time)
            backoff = min(backoff * BACKOFF_FACTOR, MAX_BACKOFF)

        except Exception as e:
            # Fatal errors set the stop event, so check that first
            if self._stop_event.is_set():
                break
            log.exception("[Edge] Unexpected error: %s", e)
            jitter = 2 * JITTER_FACTOR * random.random()
            await asyncio.sleep(2 + jitter)
[docs]
def stop(self):
    """Signal the service to stop.

    Closes the profiler, sets the stop event, cancels the task executing
    ``run()`` (if any) and tears down the SSH tunnel (if one is active).

    May be called from the thread that runs the service's event loop or
    from any other thread (e.g. after ``start_background()``).  asyncio
    events and tasks are not thread-safe, so in the latter case the stop
    request is handed over to the loop via ``call_soon_threadsafe``.
    """
    self._prof.close()

    task = self._running_task
    loop = None
    if task is not None and not task.done():
        loop = task.get_loop()

    try:
        current = asyncio.get_running_loop()
    except RuntimeError:
        current = None  # caller is not inside any running event loop

    if loop is None or loop is current:
        # Not running (any more), or called from inside the service's own
        # loop: touching the asyncio objects directly is safe.
        self._stop_event.set()
        if task is not None:
            task.cancel()
    else:
        try:
            loop.call_soon_threadsafe(self._stop_event.set)
            loop.call_soon_threadsafe(task.cancel)
        except RuntimeError:
            # The loop was closed in the meantime — nobody is waiting on
            # the event any more, so just record the stop request.
            self._stop_event.set()

    if self._tunnel_proc is not None:
        from . import tunnel as _tunnel
        _tunnel.cleanup_tunnel(self._tunnel_proc, self._name)
        self._tunnel_proc = None
async def _open_tunnel(self) -> None:
    """Open an outbound SSH tunnel to the login host (--tunnel mode).

    Derives the login host from ``tunnel_via``, then ``PBS_O_HOST``,
    then ``SLURM_SUBMIT_HOST``.  Spawns the SSH process (in a worker
    thread, so the event loop is not blocked), stores it in
    ``self._tunnel_proc`` and rewrites ``self._bridge_url`` to route
    through ``localhost:<port>``.

    If the bridge URL carries no explicit port, 443 is assumed for the
    TLS schemes (``https`` / ``wss``) and 8000 otherwise.

    Raises:
        RuntimeError: If no login host can be determined.
    """
    from urllib.parse import urlparse, urlunparse

    login_host = (self._tunnel_via
                  or os.environ.get('PBS_O_HOST')
                  or os.environ.get('SLURM_SUBMIT_HOST'))
    if not login_host:
        raise RuntimeError(
            "--tunnel: no login host available. Pass --tunnel-via HOST or "
            "set PBS_O_HOST / SLURM_SUBMIT_HOST.")

    # Only needed once we actually open a tunnel.
    from . import tunnel as _tunnel

    parsed      = urlparse(self._bridge_url)
    bridge_host = parsed.hostname or 'localhost'
    # The bridge URL may use either the http(s) or the ws(s) scheme; both
    # TLS variants default to 443.
    uses_tls    = parsed.scheme in ('https', 'wss')
    bridge_port = parsed.port or (443 if uses_tls else 8000)

    log.info("[Edge] --tunnel: opening ssh -L tunnel via %s to %s:%d",
             login_host, bridge_host, bridge_port)

    proc, port = await asyncio.to_thread(
        _tunnel.spawn_tunnel,
        login_host, bridge_host, bridge_port, self._name)

    self._tunnel_proc = proc
    self._bridge_url  = urlunparse(
        parsed._replace(netloc=f'localhost:{port}'))
    log.info("[Edge] Tunnel active on localhost:%d; bridge URL now %s",
             port, self._bridge_url)
[docs]
def start_background(self):
    """Run the service in its own daemon thread (for synchronous apps).

    Raises:
        RuntimeError: If a background thread of this service is still alive.
    """
    previous = self._thread
    if previous and previous.is_alive():
        raise RuntimeError("Service already running in background")

    worker = threading.Thread(target=self._run_thread, daemon=True)
    self._thread = worker
    worker.start()
def _run_thread(self):
    """Thread target: drive ``run()`` on a fresh event loop until it ends.

    Cancellation and any other exception are logged here and not
    propagated out of the thread.
    """
    try:
        service_loop = self.run()
        asyncio.run(service_loop)
    except asyncio.CancelledError:
        log.info("[Edge] Background service cancelled")
    except Exception as err:
        log.exception("[Edge] Background thread failed: %s", err)