# Source code for radical.edge.plugin_rhapsody

'''
Rhapsody Plugin for Radical Edge.

Exposes the RHAPSODY Session/Task API so that remote clients can submit
and monitor compute / AI tasks on edge nodes.
'''

import asyncio
import base64
import importlib
import json
import logging
import os
import threading
import time
import uuid

from fastapi import FastAPI, HTTPException, Request

from .plugin_session_base import PluginSession
from .plugin_base import Plugin
from .client import PluginClient

log = logging.getLogger("radical.edge")

TERMINAL_STATES     = {'DONE', 'FAILED', 'CANCELED', 'COMPLETED'}
WATCH_CONCURRENCY   = 64
WS_PAYLOAD_LIMIT    = 8 * 1024 * 1024  # target max per batch (conservative)
NOTIFY_BATCH_SIZE   = 1024             # max tasks per bulk notification
NOTIFY_BATCH_WINDOW = 0.25             # seconds to accumulate before flush

# Guard optional dependencies
try:
    import rhapsody as rh
except ImportError:
    rh = None

import cloudpickle as _cp
import msgpack
from . import _prof as rprof


def _assert_json_serializable(obj, path=""):
    """Recursively verify that *obj* is JSON-serializable.

    Raises ``TypeError`` with the exact key path on the first
    non-serializable value encountered.
    """
    if isinstance(obj, dict):
        for key, val in obj.items():
            _assert_json_serializable(val, f"{path}.{key}" if path else key)
    elif isinstance(obj, (list, tuple)):
        for i, val in enumerate(obj):
            _assert_json_serializable(val, f"{path}[{i}]")
    elif isinstance(obj, (str, int, float, bool, type(None))):
        return
    else:
        raise TypeError(
            f"non-serializable value at '{path}': "
            f"{type(obj).__name__} = {repr(obj)!s:.200}")


# ---------------------------------------------------------------------------
# Edge-side session
# ---------------------------------------------------------------------------

class RhapsodySession(PluginSession):
    """
    Rhapsody session (service-side).

    Wraps a ``rhapsody.Session`` instance, forwarding task submission,
    monitoring, cancellation and statistics queries.
    """

    # Sentinel distinguishing "not cached" from a cached ``None`` payload.
    _CACHE_MISS = object()

    def __init__(self, sid: str, backend_names: list[str] | None = None,
                 allow_pickled_tasks: bool = True,
                 notify_batch_window: float = NOTIFY_BATCH_WINDOW,
                 notify_batch_size: int = NOTIFY_BATCH_SIZE):
        """
        Initialize a RhapsodySession.

        Args:
            sid (str): Unique session identifier.
            backend_names (list[str] | None): Backends to configure.
                Defaults to ``['dragon_v3']``.
            allow_pickled_tasks (bool): Allow cloudpickle-encoded function
                tasks.  Defaults to ``True``.
            notify_batch_window (float): Seconds to accumulate notifications
                before flushing.
            notify_batch_size (int): Max notifications per flush — triggers
                immediate send.

        Raises:
            RuntimeError: The ``rhapsody`` package is not installed.
        """
        super().__init__(sid)

        if rh is None:
            raise RuntimeError("rhapsody package is not installed")

        self.backend_names = backend_names or ['dragon_v3']
        self.allow_pickled_tasks = allow_pickled_tasks
        self._rh_session = None
        self._tasks: dict[str, dict] = {}

        # Async init tracking
        self._init_ready = threading.Event()
        self._init_error: str | None = None

        # Limit concurrent task watchers (created in ``initialize``)
        self._watch_sem: asyncio.Semaphore | None = None

        # Notification batcher: accumulate completions and flush in bulk
        self._notify_buf: list[dict] = []
        self._notify_lock = threading.Lock()
        self._notify_batch_window = notify_batch_window
        self._notify_batch_size = notify_batch_size

        # Cache for deserialized cloudpickle payloads — avoids decoding the
        # same encoded string N times for template-expanded homogeneous
        # batches.  Cleared in ``close``.
        self._pickle_cache: dict[str, object] = {}

        # Profiler — resolved lazily via the injected _plugin reference
        self._prof: rprof.Profiler | None = None

    @property
    def prof(self) -> rprof.Profiler:
        """Profiler for this session, resolved on first access.

        Looks for ``_plugin._app.state.edge_service._prof``; when any link of
        that chain is missing, a dedicated ``'rhapsody'`` profiler is created.
        """
        if self._prof is None:
            svc = getattr(getattr(self._plugin, '_app', None), 'state', None)
            svc = getattr(svc, 'edge_service', None) if svc else None
            self._prof = getattr(svc, '_prof', None) or \
                rprof.Profiler('rhapsody', ns='radical.edge')
        return self._prof

    async def initialize(self) -> None:
        """Asynchronously initialize the session and its backends.

        Sets ``_init_ready`` on both success and failure so that waiters are
        always released; on failure the error text is stored in
        ``_init_error`` and the exception is re-raised.
        """
        try:
            backends = []
            for name in self.backend_names:
                b = rh.get_backend(name)
                # get_backend may hand back an awaitable
                if hasattr(b, '__await__'):
                    b = await b
                backends.append(b)

            self._rh_session = rh.Session(backends=backends, uid=self._sid)

            # Register state-change callbacks for intermediate notifications
            self._notified_states: dict[str, str] = {}
            self._notified_lock = threading.Lock()

            for b in backends:
                if hasattr(b, 'register_callback'):
                    orig = getattr(b, '_callback_func', None)

                    # ``_orig`` is bound as a default to avoid the
                    # late-binding closure pitfall inside the loop.
                    def _on_state(task, state, _orig=orig):
                        self._on_task_state_change(task, state)
                        if _orig:
                            _orig(task, state)

                    b.register_callback(_on_state)

            # Create semaphore now that we're in an event loop
            self._watch_sem = asyncio.Semaphore(WATCH_CONCURRENCY)

            self._init_ready.set()
            log.info("[%s] Session initialization complete", self._sid)

        except Exception as e:
            self._init_error = str(e)
            self._init_ready.set()  # unblock waiters
            log.error("[%s] Session initialization failed: %s", self._sid, e)
            raise

    def _on_task_state_change(self, task, state):
        """Fire notification on intermediate state changes (e.g. RUNNING).

        Called from backend threads — uses lock for _notified_states access.
        Terminal states are recorded but not sent from here; they are
        reported by the task watchers.
        """
        uid = self._get_attr(task, 'uid')
        uid_str = str(uid) if uid else '?'
        state_str = str(state)

        with self._notified_lock:
            # Skip if we already notified this state
            if self._notified_states.get(uid_str) == state_str:
                return
            self._notified_states[uid_str] = state_str

        # Only fire for non-terminal states; terminal is handled by the
        # watcher coroutines
        if state_str.upper() in TERMINAL_STATES:
            return

        if self._plugin:
            self._plugin._dispatch_notify("task_status", {
                "uid": uid_str,
                "state": state_str,
            })

    def _check_initialized(self) -> None:
        """Check that the session is active and fully initialized.

        Raises:
            HTTPException 409: Session is still initializing.
            HTTPException 500: Session initialization failed.
            RuntimeError: Session is closed.
        """
        self._check_active()

        if not self._init_ready.is_set():
            raise HTTPException(status_code=409,
                                detail="session is still initializing")
        if self._init_error:
            raise HTTPException(
                status_code=500,
                detail=f"session init failed: {self._init_error}")

    def _deserialize_task(self, td: dict) -> dict:
        """Deserialize pickled or import-path function fields in a task dict.

        Handles two formats:

        - cloudpickle: ``"function": "cloudpickle::<base64>"`` with
          ``"_pickled_fields": ["function", "args", ...]``
        - import path: ``"function": "module.path:func_name"``

        Both may occur in the same task (e.g. an import-path function with
        pickled ``args``), so the import-path step always runs after the
        cloudpickle step.

        Args:
            td: Task dict; modified in place (``_pickled_fields`` is removed).

        Returns:
            The (possibly modified) task dict.

        Raises:
            HTTPException 400: Pickled tasks are disabled, or the import
                path cannot be resolved.
        """
        prefix = 'cloudpickle::'

        # --- cloudpickle-encoded fields ---
        pickled_fields = td.pop('_pickled_fields', None)
        if pickled_fields:
            if not self.allow_pickled_tasks:
                raise HTTPException(
                    status_code=400,
                    detail="pickled function tasks are disabled")

            for field in pickled_fields:
                val = td.get(field)
                if not (isinstance(val, str) and val.startswith(prefix)):
                    continue
                encoded = val[len(prefix):]
                # A sentinel is used so that a legitimately cached ``None``
                # is not mistaken for a cache miss.
                obj = self._pickle_cache.get(encoded, self._CACHE_MISS)
                if obj is self._CACHE_MISS:
                    # SECURITY: unpickling executes arbitrary code supplied
                    # by the client.  Only enable ``allow_pickled_tasks``
                    # for trusted clients.
                    obj = _cp.loads(base64.b64decode(encoded))
                    self._pickle_cache[encoded] = obj
                td[field] = obj

        # --- import-path string (e.g. "mymodule.sub:func_name") ---
        # Not skipped when pickled fields were present: the function may be
        # an import path even though args/kwargs were pickled.  A function
        # that was itself unpickled is no longer a ``str`` and is ignored.
        fn = td.get('function')
        if isinstance(fn, str) and ':' in fn and not fn.startswith(prefix):
            mod_path, _, attr_name = fn.partition(':')
            try:
                mod = importlib.import_module(mod_path)
                td['function'] = getattr(mod, attr_name)
            except (ImportError, AttributeError) as e:
                raise HTTPException(
                    status_code=400,
                    detail=f"cannot resolve function '{fn}': {e}") from e

        return td

    def _prepare_batch(self, task_dicts: list[dict],
                       pre_expanded: bool = False) -> list:
        """Deserialize and create task objects (runs in worker thread).

        CPU-bound work (cloudpickle, from_dict) is offloaded here so the
        event loop stays responsive for WebSocket keepalive.

        Args:
            task_dicts: Task dicts to convert.
            pre_expanded: Signals that the batch came from template
                expansion.  Currently informational only: identical encoded
                payloads hit the pickle cache regardless of this flag.

        Returns:
            list: Task objects created via ``rh.BaseTask.from_dict``.
        """
        deserialized = [self._deserialize_task(td) for td in task_dicts]
        return [rh.BaseTask.from_dict(td) for td in deserialized]

    def _register_submitted(self, tasks: list, results: list,
                            all_tasks: list, bid: str) -> None:
        """Record a chunk of backend-accepted tasks and append their acks."""
        prof = self.prof
        prof.prof('rh_register', uid=bid)
        for t in tasks:
            uid_str = str(t.uid)
            self._tasks[uid_str] = t
            results.append({"uid": uid_str, "state": str(t.get("state"))})
        all_tasks.extend(tasks)
        prof.prof('rh_register_done', uid=bid)

    async def submit_tasks(self, task_dicts: list[dict],
                           pre_expanded: bool = False) -> list[dict]:
        """
        Submit a list of tasks.

        Each dict is converted to a task object via ``BaseTask.from_dict()``.
        Function fields encoded as cloudpickle blobs or import-path strings
        are deserialized first.

        Uses a pipeline: deserialization of chunk N+1 runs concurrently
        with backend submission of chunk N, so the two dominant costs
        overlap.

        If a later chunk fails, chunks that were already accepted by the
        backend remain registered and are still watched, so their
        completion notifications are not lost; the error is re-raised.

        Args:
            task_dicts: Task specification dicts.
            pre_expanded: Forwarded to ``_prepare_batch``.

        Returns:
            list[dict]: Minimal ack dicts ``{uid, state}``.
        """
        self._check_initialized()

        prof = self.prof
        batch_n = len(task_dicts)
        bid = task_dicts[0].get('uid', '?') if task_dicts else '?'
        prof.prof('rh_submit', uid=bid, msg=str(batch_n))

        _CHUNK = 4096
        results = []
        all_tasks = []  # collect for batch watcher

        # -- pipelined deser / submit ----------------------------------------
        # Kick off deserialization of the first chunk while we have nothing
        # else to do; then overlap deser(N+1) with submit(N).
        chunks = [task_dicts[i:i + _CHUNK]
                  for i in range(0, len(task_dicts), _CHUNK)]

        prev_submit_fut = None  # future for the previous submit
        prev_tasks = None       # tasks from the previous chunk

        try:
            for chunk_dicts in chunks:
                # Offload CPU-bound deserialization to a worker thread
                prof.prof('rh_deser', uid=bid, msg=str(len(chunk_dicts)))
                deser_fut = asyncio.ensure_future(asyncio.to_thread(
                    self._prepare_batch, chunk_dicts, pre_expanded))

                # While deser runs, await the *previous* chunk's submit
                if prev_submit_fut is not None:
                    await prev_submit_fut
                    prev_submit_fut = None
                    prof.prof('rh_backend_submit_done', uid=bid)
                    self._register_submitted(prev_tasks, results,
                                             all_tasks, bid)

                tasks = await deser_fut
                prof.prof('rh_deser_done', uid=bid)

                # Fire off the backend submit (will be awaited next
                # iteration or after the loop)
                prof.prof('rh_backend_submit', uid=bid)
                prev_submit_fut = asyncio.ensure_future(
                    self._rh_session.submit_tasks(tasks))
                prev_tasks = tasks

                # Yield so the event loop can process WS pings
                await asyncio.sleep(0)

            # Await the final chunk's submit
            if prev_submit_fut is not None:
                await prev_submit_fut
                prof.prof('rh_backend_submit_done', uid=bid)
                self._register_submitted(prev_tasks, results, all_tasks, bid)

        finally:
            # Start a single batch watcher instead of per-task watchers.
            # Done in ``finally`` so that tasks registered before a failure
            # are still monitored.
            if self._plugin and all_tasks:
                asyncio.ensure_future(self._watch_batch(all_tasks))

        prof.prof('rh_submit_done', uid=bid, msg=str(batch_n))
        return results

    def _queue_notification(self, payload: dict) -> None:
        """Add a task notification to the batch buffer and ensure a flush
        is scheduled.

        The buffer itself is guarded by ``_notify_lock``.
        """
        uid = payload.get('uid', '?')
        self.prof.prof('notify_queue', uid=uid)

        with self._notify_lock:
            self._notify_buf.append(payload)
            buf_len = len(self._notify_buf)

        # The flush re-acquires the (non-reentrant) lock, so it must be
        # triggered only after the lock has been released.
        if buf_len >= self._notify_batch_size:
            # Buffer full — flush immediately
            self._flush_notifications()
        else:
            # Schedule a delayed flush so tail items aren't stranded.
            # Always schedule — it's cheap and a no-op if buffer is
            # empty when it fires.
            self._schedule_flush(delay=self._notify_batch_window)

    def _schedule_flush(self, delay: float = 0) -> None:
        """Schedule a notification flush on the event loop.

        Uses the running loop when there is one; otherwise falls back to
        the plugin's ``_main_loop`` (if set).  Does nothing without a plugin.
        """
        if not self._plugin:
            return

        async def _do_flush():
            if delay > 0:
                await asyncio.sleep(delay)
            self._flush_notifications()

        try:
            loop = asyncio.get_running_loop()
            loop.create_task(_do_flush())
        except RuntimeError:
            # No running loop in this thread
            if hasattr(self._plugin, '_main_loop') and \
                    self._plugin._main_loop:
                asyncio.run_coroutine_threadsafe(
                    _do_flush(), self._plugin._main_loop)

    def _flush_notifications(self) -> None:
        """Flush the notification buffer as a bulk message.

        A single buffered item is sent as ``task_status``; several are sent
        together as ``task_status_batch``.
        """
        with self._notify_lock:
            if not self._notify_buf:
                return
            batch = list(self._notify_buf)
            self._notify_buf.clear()

        if not self._plugin:
            return

        prof = self.prof
        for item in batch:
            prof.prof('notify_flush', uid=item.get('uid', '?'))

        if len(batch) == 1:
            self._plugin._dispatch_notify("task_status", batch[0])
        else:
            self._plugin._dispatch_notify("task_status_batch",
                                          {"tasks": batch})

    async def _watch_task(self, task):
        """Background watcher for a single task: notify as soon as it
        completes.

        Concurrency is bounded by ``self._watch_sem`` so that thousands of
        simultaneous watchers do not overwhelm the event loop.
        """
        uid = self._get_attr(task, 'uid')
        uid_str = str(uid) if uid else '?'

        # Acquire semaphore — queues here if too many watchers are active
        sem = self._watch_sem or asyncio.Semaphore(WATCH_CONCURRENCY)
        prof = self.prof

        async with sem:
            log.debug("[%s] Watcher started for task %s", self._sid, uid_str)
            prof.prof('rh_task_exec', uid=uid_str)
            try:
                if not self._rh_session:
                    log.warning("[%s] Session closed before task %s completed",
                                self._sid, uid_str)
                    self._queue_notification({
                        "uid": uid_str, "state": "FAILED",
                        "error": "Session closed"})
                    prof.prof('rh_task_done', uid=uid_str, state='FAILED')
                    return

                await self._rh_session.wait_tasks([task])

                state = self._get_attr(task, 'state')
                log.debug("[%s] Task %s completed with state: %s",
                          self._sid, uid_str, state)
                prof.prof('rh_task_done', uid=uid_str, state=str(state))

                d = self._notification_payload(task)
                self._queue_notification(d)

            except Exception as e:
                prof.prof('rh_task_done', uid=uid_str, state='FAILED')
                log.warning("[%s] Rhapsody watch error for task %s: %s",
                            self._sid, uid_str, e)
                self._queue_notification({
                    "uid": uid_str, "state": "FAILED", "error": str(e)})

    async def _watch_batch(self, tasks):
        """Watch a batch of tasks, notifying as each completes.

        Uses ``asyncio.wait(FIRST_COMPLETED)`` to drain completions
        incrementally.  Notifications are queued per-task as soon as
        each finishes — the existing notification buffer
        (``_queue_notification``) batches them opportunistically so
        fast-completing tasks are grouped into single messages
        while slow tasks don't block others.
        """
        prof = self.prof

        # Build uid map and per-task futures
        fut_to_uid: dict[asyncio.Future, str] = {}
        uid_to_task: dict[str, object] = {}

        for t in tasks:
            uid = self._get_attr(t, 'uid')
            uid_str = str(uid) if uid else '?'
            uid_to_task[uid_str] = t
            prof.prof('rh_task_exec', uid=uid_str)

        if not self._rh_session:
            # Session went away before we could watch: report all as failed
            for uid_str in uid_to_task:
                prof.prof('rh_task_done', uid=uid_str, state='FAILED')
                self._queue_notification({
                    "uid": uid_str, "state": "FAILED",
                    "error": "Session closed"})
            return

        # Obtain per-task futures from the session's state manager
        sm = self._rh_session._state_manager
        for uid_str, t in uid_to_task.items():
            fut = sm.get_wait_future(uid_str, t)
            fut_to_uid[fut] = uid_str

        # Drain completions incrementally
        pending = set(fut_to_uid.keys())
        while pending:
            try:
                done, pending = await asyncio.wait(
                    pending, return_when=asyncio.FIRST_COMPLETED)
            except Exception as e:
                log.warning("[%s] Batch watch error: %s", self._sid, e)
                break

            # Notify for every task that just completed
            for fut in done:
                uid_str = fut_to_uid[fut]
                t = uid_to_task[uid_str]
                state = self._get_attr(t, 'state')
                state_str = str(state) if state else 'UNKNOWN'
                prof.prof('rh_task_done', uid=uid_str, state=state_str)

                if state_str.upper() in TERMINAL_STATES:
                    d = self._notification_payload(t)
                    self._queue_notification(d)
                else:
                    # Future resolved but the task is not in a terminal state
                    self._queue_notification({
                        "uid": uid_str, "state": state_str,
                        "error": f"unexpected state: {state_str}"})

    async def wait_tasks(self, uids: list[str],
                         timeout: float | None = None) -> list[dict]:
        """
        Return current task states (non-blocking snapshot).

        This method no longer blocks until tasks complete.  Clients
        should rely on ``task_status`` notifications for real-time
        completion events, and call this endpoint only to fetch the
        current state snapshot.

        Args:
            uids (list[str]): Task UIDs to query.  Unknown UIDs are skipped.
            timeout (float | None): Ignored (kept for API compat).

        Returns:
            list[dict]: Current task state dicts.

        Raises:
            HTTPException 404: None of the requested tasks is known.
        """
        self._check_initialized()

        tasks = [self._tasks[uid] for uid in uids if uid in self._tasks]
        if not tasks:
            raise HTTPException(status_code=404,
                                detail="none of the requested tasks found")

        return [self._sanitize_task(t) for t in tasks]

    def _get_attr(self, obj, attr, default=None):
        """Helper to get attribute from object or dict.

        The attribute is tried first; a dict key lookup is the fallback.
        A value of ``None`` is treated as missing and yields *default*.
        """
        val = getattr(obj, attr, None)
        if val is None and isinstance(obj, dict):
            val = obj.get(attr)
        return val if val is not None else default

    def _sanitize_task(self, t) -> dict:
        """Sanitize a Rhapsody task dict so it's JSON serializable.

        Args:
            t: Task object (with ``to_dict``) or mapping.

        Returns:
            dict: Task data with ``uid``/``state`` stringified, ``future``
            removed, exceptions and callables stringified, bytes
            stdout/stderr decoded, and ``return_value`` made JSON-safe
            (bytes are base64-encoded and flagged via
            ``_return_value_encoding``).
        """
        if hasattr(t, 'to_dict'):
            d = t.to_dict()
        else:
            d = dict(t)

        # Ensure 'uid' is present and a string
        uid = self._get_attr(t, 'uid')
        if uid:
            d['uid'] = str(uid)

        # Ensure 'state' is present and a string
        state = self._get_attr(t, 'state')
        if state:
            d['state'] = str(state)

        d.pop('future', None)

        if 'exception' in d and d['exception'] is not None:
            d['exception'] = str(d['exception'])

        # Stringify callable function fields.  Not every callable carries
        # ``__qualname__`` (e.g. ``functools.partial`` objects), so fall
        # back to ``__name__`` and finally to ``repr``.
        fn = d.get('function')
        if callable(fn):
            fn_name = getattr(fn, '__qualname__', None) or \
                getattr(fn, '__name__', None)
            if fn_name:
                d['function'] = \
                    f"{getattr(fn, '__module__', None)}.{fn_name}"
            else:
                d['function'] = repr(fn)

        # Decode bytes stdout/stderr; join lists (multi-rank)
        for key in ('stdout', 'stderr'):
            val = d.get(key)
            if isinstance(val, bytes):
                d[key] = val.decode('utf-8', errors='replace')
            elif isinstance(val, list):
                d[key] = '\n'.join(str(v) for v in val)

        # Ensure return_value is JSON-serializable
        rv = d.get('return_value')
        if rv is not None:
            if isinstance(rv, bytes):
                d['return_value'] = base64.b64encode(rv).decode('ascii')
                d['_return_value_encoding'] = 'base64'
            else:
                try:
                    json.dumps(rv)
                except (TypeError, ValueError):
                    d['return_value'] = str(rv)

        return d

    # Fields of a sanitized task that are forwarded in notifications
    _NOTIFICATION_KEYS = {'uid', 'state', 'exit_code', 'return_value',
                          '_return_value_encoding', 'error', 'exception',
                          'traceback'}

    def _notification_payload(self, t) -> dict:
        """Build a minimal notification dict for a completed task.

        Only essential fields are included to keep WebSocket/SSE payloads
        small.  Clients needing the full task dict (e.g. stdout/stderr)
        can fetch it via ``GET /task/{sid}/{uid}``.
        """
        full = self._sanitize_task(t)
        return {k: v for k, v in full.items() if k in self._NOTIFICATION_KEYS}

    async def list_tasks(self) -> dict:
        """Return all tasks in this session with current state.

        Returns:
            dict: ``{"tasks": [<sanitized task dict>, ...]}``.
        """
        self._check_initialized()

        tasks = [self._sanitize_task(task) for task in self._tasks.values()]
        return {"tasks": tasks}

    async def get_task(self, uid: str) -> dict:
        """
        Return info for a single cached task.

        Raises:
            HTTPException 404: Unknown task UID.
        """
        self._check_initialized()

        task = self._tasks.get(uid)
        if not task:
            raise HTTPException(status_code=404,
                                detail=f"task {uid} not found")

        return self._sanitize_task(task)

    async def cancel_task(self, uid: str) -> dict:
        """
        Cancel a running task.

        The cancel is forwarded to the backend named in the task's
        ``backend`` field, if that backend is known to the session.
        The ack is returned either way.

        Raises:
            HTTPException 404: Unknown task UID.
        """
        self._check_initialized()

        task = self._tasks.get(uid)
        if not task:
            raise HTTPException(status_code=404,
                                detail=f"task {uid} not found")

        backend_name = task.get("backend")
        if backend_name and backend_name in self._rh_session.backends:
            backend = self._rh_session.backends[backend_name]
            await backend.cancel_task(uid)

        return {"uid": uid, "status": "canceled"}

    async def cancel_all_tasks(self) -> dict:
        """
        Cancel all non-terminal tasks in this session.

        Best-effort: Dragon V3 marks tasks as CANCELED but cannot truly
        abort running work.  Per-task errors are logged at debug level and
        otherwise swallowed.  Cancels are issued concurrently via
        ``asyncio.gather``.

        Returns:
            dict: ``{"canceled": <number of successful cancel calls>}``.
        """
        self._check_initialized()

        # Snapshot the items: the dict may be touched while we await below
        uids = []
        for uid, task in list(self._tasks.items()):
            state = str(self._get_attr(task, 'state', '')).upper()
            if state not in TERMINAL_STATES:
                uids.append(uid)

        if not uids:
            return {"canceled": 0}

        async def _try_cancel(uid):
            try:
                await self.cancel_task(uid)
                return True
            except Exception as e:
                log.debug("[%s] cancel of task %s failed: %s",
                          self._sid, uid, e)
                return False

        results = await asyncio.gather(*[_try_cancel(u) for u in uids])
        return {"canceled": sum(1 for r in results if r)}

    async def close(self) -> dict:
        """
        Shutdown RHAPSODY session and clean up.

        Closes the wrapped session (if any), drops the task registry and
        the cached unpickled payloads, then delegates to the base class.
        """
        if self._rh_session:
            await self._rh_session.close()
            self._rh_session = None

        self._tasks = {}
        self._pickle_cache.clear()
        return await super().close()
# ---------------------------------------------------------------------------
# Application-side client
# ---------------------------------------------------------------------------
[docs] class RhapsodyClient(PluginClient): """ Client-side interface for the Rhapsody plugin. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Session-wide accumulator for terminal task notifications. # Populated by a persistent SSE callback registered after # session init, so no notification is ever lost. self._completed: dict[str, dict] = {} self._completed_lock = threading.Lock() # Waiters: list of (set_of_uids, threading.Event) for wait_tasks self._waiters: list[tuple[set, threading.Event]] = [] self._waiters_lock = threading.Lock() def _on_task_done(self, edge, plugin, topic, data): """Persistent SSE callback: accumulate terminal task states. Handles both single ``task_status`` and bulk ``task_status_batch`` notifications. """ if topic == 'task_status_batch': tasks = data.get('tasks', []) else: tasks = [data] newly_done: list[str] = [] with self._completed_lock: for t in tasks: # Decode base64-encoded return values if t.get('_return_value_encoding') == 'base64': t['return_value'] = base64.b64decode(t['return_value']) del t['_return_value_encoding'] uid = t.get('uid') state = str(t.get('state', '')).upper() if uid and state in TERMINAL_STATES: self._completed[uid] = t newly_done.append(uid) # Wake waiters under completed_lock so no waiter can # register between adding to _completed and checking # pending sets. Lock order: completed → waiters. if newly_done: with self._waiters_lock: for pending, event in self._waiters: for uid in newly_done: pending.discard(uid) if not pending: event.set()
[docs] def register_session(self, backends: list[str] | None = None, init_timeout: float = 120, notify_batch_window: float | None = None, notify_batch_size: int | None = None): """ Register a session, optionally specifying backend names. The edge initializes the session asynchronously. This method blocks until a ``session_status`` SSE notification confirms that the session is ready (or until *init_timeout* seconds). Falls back to polling when no ``BridgeClient`` is available. Args: backends: List of backend names (e.g. ``['dragon_v3']``). Defaults to ``['dragon_v3']`` on the server side. init_timeout: Seconds to wait for session init (default 120). notify_batch_window: Seconds to accumulate notifications before flushing (edge-side). notify_batch_size: Max notifications per flush (edge-side). """ has_sse = (self._bc is not None and self._edge_id is not None and self._plugin_name is not None) # Ensure the SSE listener is connected BEFORE we send the POST, # so we never miss a fast init notification. 
ready = threading.Event() error = [None] if has_sse: self._bc.wait_for_listener(timeout=30) def _on_session_status(edge, plugin, topic, data): st = data.get('status') if st == 'ready': ready.set() elif st == 'failed': error[0] = data.get('error', 'unknown init error') ready.set() self.register_notification_callback(_on_session_status, topic="session_status") payload = {} if backends: payload['backends'] = backends if notify_batch_window is not None: payload['notify_batch_window'] = notify_batch_window if notify_batch_size is not None: payload['notify_batch_size'] = notify_batch_size resp = self._http.post(self._url('register_session'), json=payload) self._raise(resp) data = resp.json() self._sid = data['sid'] status = data.get('status') # Reset the session-wide task completion accumulator with self._completed_lock: self._completed.clear() if status == 'ready': if has_sse: self.unregister_notification_callback( _on_session_status, topic="session_status") self._start_task_listener() return # fast path: init was synchronous if not has_sse: self._poll_session_ready(init_timeout) return # Wait for async init to complete via SSE notification try: ready.wait(timeout=init_timeout) if error[0]: raise RuntimeError( f"Session init failed on edge: {error[0]}") if not ready.is_set(): raise RuntimeError( f"Session init timed out after {init_timeout}s") finally: self.unregister_notification_callback(_on_session_status, topic="session_status") self._start_task_listener()
def _start_task_listener(self): """Register persistent SSE callback that accumulates completions.""" has_sse = (self._bc is not None and self._edge_id is not None and self._plugin_name is not None) if has_sse: self.register_notification_callback(self._on_task_done, topic="task_status") self.register_notification_callback(self._on_task_done, topic="task_status_batch") def _poll_session_ready(self, timeout: float = 120) -> None: """Fallback: poll until the session is ready (no SSE available).""" deadline = time.time() + timeout while time.time() < deadline: try: resp = self._http.get( self._url(f"list_tasks/{self.sid}")) if resp.status_code != 409: return # session is ready (or already errored) except Exception: pass time.sleep(1.0) raise RuntimeError( f"Session init timed out after {timeout}s (poll)") @staticmethod def _serialize_task(td: dict) -> None: """Prepare a task dict for JSON transport (in-place). - Encodes callable ``function``, ``args``, ``kwargs`` via cloudpickle + base64. - Strips non-serializable internal fields (``future``, ``_future``, ``backend``). 
""" pickled_fields = td.get('_pickled_fields', []) # Serialize callable function fn = td.get('function') if callable(fn): encoded = base64.b64encode(_cp.dumps(fn)).decode('ascii') td['function'] = 'cloudpickle::' + encoded if 'function' not in pickled_fields: pickled_fields.append('function') # Serialize args/kwargs if not JSON-safe for field in ('args', 'kwargs'): val = td.get(field) if val is None: continue if isinstance(val, str) and val.startswith('cloudpickle::'): continue try: json.dumps(val) except (TypeError, ValueError): encoded = base64.b64encode(_cp.dumps(val)).decode('ascii') td[field] = 'cloudpickle::' + encoded if field not in pickled_fields: pickled_fields.append(field) if pickled_fields: td['_pickled_fields'] = pickled_fields # Strip non-serializable internal fields td.pop('future', None) td.pop('_future', None) td.pop('backend', None) # Strip None-valued type-discriminator fields so that # BaseTask.from_dict() routes to the correct task class # (it checks key existence, not truthiness). for key in ('prompt', 'executable', 'function'): if key in td and td[key] is None: del td[key]
[docs] def submit_tasks(self, task_dicts: list[dict]) -> list[dict]: """ Submit tasks to the edge. Large batches are automatically split so each payload stays within the WebSocket frame limit. Batches are submitted concurrently via a thread pool so that network round-trips overlap (pipelining). UIDs are assigned client-side (if absent) so the caller can start waiting for SSE notifications immediately. Args: task_dicts: List of task specification dicts. Returns: list[dict]: Submitted task info (uid, state). """ from concurrent.futures import ThreadPoolExecutor, as_completed self._require_session() # --- serialize callables and clean up internal fields --- for td in task_dicts: self._serialize_task(td) # --- assign UIDs client-side so we know them before submit --- for td in task_dicts: if 'uid' not in td: td['uid'] = f"task.{uuid.uuid4().hex[:8]}" # --- try template compression for homogeneous batches --- # If all tasks share the same fields (except uid), send a # template + list of UIDs instead of N full copies. 
url = self._url(f"submit/{self.sid}") if len(task_dicts) > 1: ref = task_dicts[0] ref_keys = set(ref) - {'uid'} homogeneous = all( set(td) - {'uid'} == ref_keys and all(td[k] is ref[k] or td[k] == ref[k] for k in ref_keys) for td in task_dicts[1:]) else: homogeneous = False if homogeneous and len(task_dicts) > 1: first = {k: v for k, v in ref.items() if k != 'uid'} return self._submit_template( url, first, [td['uid'] for td in task_dicts]) # --- split into size-aware batches (byte limit only) --- batches: list[list[dict]] = [] batch: list[dict] = [] batch_bytes = 0 for td in task_dicts: td_size = len(str(td)) + 2 if batch and batch_bytes + td_size > WS_PAYLOAD_LIMIT: batches.append(batch) batch = [] batch_bytes = 0 batch.append(td) batch_bytes += td_size if batch: batches.append(batch) # --- submit batches concurrently (pipelining) --- errors: list[str] = [] def _submit_batch(b): resp = self._http.post( url, data=msgpack.packb({"tasks": b}, use_bin_type=True), headers={"Content-Type": "application/msgpack"}) self._raise(resp, f"submit {len(b)} task(s)") return resp.json() if len(batches) == 1: return _submit_batch(batches[0]) results = [] batch_results: dict[int, list] = {} with ThreadPoolExecutor(max_workers=len(batches)) as pool: futures = {pool.submit(_submit_batch, b): i for i, b in enumerate(batches)} for fut in as_completed(futures): idx = futures[fut] try: batch_results[idx] = fut.result() except Exception as e: errors.append(str(e)) if errors: detail = '; '.join(errors[:3]) raise RuntimeError( f"submit failed for {len(errors)} batch(es): {detail}") for i in sorted(batch_results): results.extend(batch_results[i]) return results
def _submit_template(self, url: str, template: dict, uids: list[str]) -> list[dict]: """Submit homogeneous tasks via template compression. Sends one template + list of UIDs instead of N full task dicts. Falls back to regular submit in WS_PAYLOAD_LIMIT-sized chunks. """ from concurrent.futures import ThreadPoolExecutor, as_completed # Cap chunks by both byte size and count so the server # doesn't block the event loop processing too many at once. max_by_bytes = max( 1, (WS_PAYLOAD_LIMIT - len(str(template))) // 20) uids_per_chunk = min(max_by_bytes, 8192) chunks = [uids[i:i + uids_per_chunk] for i in range(0, len(uids), uids_per_chunk)] def _submit_chunk(uid_chunk): payload = {"template": template, "uids": uid_chunk} resp = self._http.post( url, data=msgpack.packb(payload, use_bin_type=True), headers={"Content-Type": "application/msgpack"}) self._raise(resp, f"submit template {len(uid_chunk)} task(s)") return resp.json() if len(chunks) == 1: return _submit_chunk(chunks[0]) results = [] batch_results: dict[int, list] = {} errors: list[str] = [] with ThreadPoolExecutor(max_workers=len(chunks)) as pool: futures = {pool.submit(_submit_chunk, c): i for i, c in enumerate(chunks)} for fut in as_completed(futures): idx = futures[fut] try: batch_results[idx] = fut.result() except Exception as e: errors.append(str(e)) if errors: detail = '; '.join(errors[:3]) raise RuntimeError( f"submit failed for {len(errors)} chunk(s): {detail}") for i in sorted(batch_results): results.extend(batch_results[i]) return results
[docs] def wait_tasks(self, uids: list[str], timeout: float | None = None) -> list[dict]: """ Wait for tasks to reach terminal state via SSE notifications. Purely client-side: the persistent ``_on_task_done`` callback (registered at session init) accumulates completions into ``self._completed``. This method checks the accumulator and blocks only until every requested UID appears there. Falls back to periodic polling when no ``BridgeClient`` is available (e.g. direct construction in tests). Args: uids: Task UIDs to wait for. timeout: Seconds to wait (None = forever). Returns: list[dict]: Completed task dicts. """ self._require_session() # ------------------------------------------------------------------ # Check if SSE notifications are available # ------------------------------------------------------------------ has_sse = (self._bc is not None and self._edge_id is not None and self._plugin_name is not None) if not has_sse: return self._wait_tasks_poll(uids, timeout) # ------------------------------------------------------------------ # SSE-based wait (preferred path) # ------------------------------------------------------------------ # Build pending set and register waiter atomically so no # completions slip through between the check and the # registration. Lock order matches _on_task_done: # completed_lock → waiters_lock. done = threading.Event() with self._completed_lock: if all(uid in self._completed for uid in uids): return [self._completed[uid] for uid in uids] pending = set(uid for uid in uids if uid not in self._completed) with self._waiters_lock: waiter = (pending, done) self._waiters.append(waiter) try: done.wait(timeout=timeout) finally: with self._waiters_lock: try: self._waiters.remove(waiter) except ValueError: pass with self._completed_lock: return [self._completed.get(uid, {"uid": uid, "state": "UNKNOWN"}) for uid in uids]
def _wait_tasks_poll(self, uids: list[str],
                     timeout: float | None = None) -> list[dict]:
    """Fallback wait via periodic polling (no SSE available).

    Re-posts the ``wait/<sid>`` request until every returned task reports
    a state in ``TERMINAL_STATES`` or the timeout elapses, sleeping at
    most one second between polls.

    Args:
        uids: Task UIDs to wait for.
        timeout: Seconds to wait.  ``None`` waits forever; ``0`` performs
            a single poll and returns.

    Returns:
        list[dict]: The task dicts from the most recent poll.  On timeout
        some of them may still be in a non-terminal state.
    """
    url = self._url(f"wait/{self.sid}")
    payload: dict = {"uids": uids}
    if timeout is not None:
        payload["timeout"] = timeout

    # Use the monotonic clock so wall-clock adjustments cannot stretch or
    # shorten the wait.  Compare against None explicitly: ``timeout=0`` is
    # an (already expired) deadline, not "wait forever".
    deadline = None if timeout is None else time.monotonic() + timeout

    while True:
        resp = self._http.post(url, json=payload)
        self._raise(resp, f"wait {len(uids)} task(s)")
        tasks = resp.json()

        # Check if all are terminal
        if all(str(t.get('state', '')).upper() in TERMINAL_STATES
               for t in tasks):
            return tasks

        if deadline is None:
            time.sleep(1.0)
            continue

        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return tasks  # return whatever we have

        # Never sleep past the deadline; the next iteration performs one
        # final poll and then returns.
        time.sleep(min(1.0, remaining))
def list_tasks(self) -> dict:
    """Fetch the task listing of the current session.

    Returns:
        dict: Decoded JSON body of the ``list_tasks/<sid>`` response.
    """
    self._require_session()

    endpoint = self._url(f"list_tasks/{self.sid}")
    reply = self._http.get(endpoint)
    self._raise(reply)

    return reply.json()
def get_task(self, uid: str) -> dict:
    """
    Fetch the record of one task of the current session.

    Args:
        uid: UID of the task to look up.

    Returns:
        dict: Decoded JSON body of the ``task/<sid>/<uid>`` response.
    """
    self._require_session()

    reply = self._http.get(self._url(f"task/{self.sid}/{uid}"))
    self._raise(reply)

    return reply.json()
def cancel_task(self, uid: str) -> dict:
    """
    Request cancellation of one task of the current session.

    Args:
        uid: UID of the task to cancel.

    Returns:
        dict: Decoded JSON body of the ``cancel/<sid>/<uid>`` response.
    """
    self._require_session()

    reply = self._http.post(self._url(f"cancel/{self.sid}/{uid}"))
    self._raise(reply)

    return reply.json()
def cancel_all_tasks(self) -> dict:
    """
    Request cancellation of every non-terminal task in this session.

    Returns:
        dict: Decoded JSON body of the ``cancel_all/<sid>`` response.
    """
    self._require_session()

    endpoint = self._url(f"cancel_all/{self.sid}")
    reply = self._http.post(endpoint)
    self._raise(reply)

    return reply.json()
# --------------------------------------------------------------------------- # Server-side plugin # ---------------------------------------------------------------------------
class PluginRhapsody(Plugin):
    '''
    Rhapsody plugin for Radical Edge.

    Exposes the RHAPSODY Session / Task API via REST endpoints:

    - POST /rhapsody/register_session    – create session
    - POST /rhapsody/submit/{sid}        – submit tasks
    - POST /rhapsody/wait/{sid}          – query task states
    - GET  /rhapsody/list_tasks/{sid}    – list all tasks
    - GET  /rhapsody/task/{sid}/{uid}    – get single task
    - POST /rhapsody/cancel/{sid}/{uid}  – cancel single task
    - POST /rhapsody/cancel_all/{sid}    – cancel all tasks

    Notification topics: ``session_status``, ``task_status``,
    ``task_status_batch``.
    '''

    plugin_name = "rhapsody"
    session_class = RhapsodySession
    client_class = RhapsodyClient
    version = '0.0.1'

    # Declarative description of the plugin's UI: one submit form, one
    # task monitor, and the notification topic that feeds the monitor.
    ui_config = {
        "icon": "🎼",
        "title": "Rhapsody Tasks",
        "description": "Submit compute tasks, wait for results, view stdout/stderr.",
        "forms": [{
            "id": "submit",
            "title": "📝 Submit Task",
            "layout": "single",
            "fields": [
                {"name": "exec", "type": "text", "label": "Executable",
                 "default": "/bin/echo", "css_class": "rh-exec"},
                {"name": "args", "type": "text",
                 "label": "Arguments (space-separated)",
                 "default": "hello from rhapsody", "css_class": "rh-args"},
                {"name": "backends", "type": "select", "label": "Backend",
                 "options": ["dragon_v3", "concurrent"],
                 "css_class": "rh-backends"},
                {"name": "timeout", "type": "number", "label": "Timeout (s)",
                 "default": "", "css_class": "rh-timeout"},
                {"name": "ranks", "type": "number", "label": "MPI Ranks",
                 "default": "", "css_class": "rh-ranks"},
                {"name": "type", "type": "select", "label": "Task Type",
                 "options": ["", "mpi"], "css_class": "rh-type"},
                {"name": "cwd", "type": "text", "label": "Working Dir",
                 "default": "", "css_class": "rh-cwd"},
            ],
            "submit": {"label": "▶ Submit Task", "style": "success"}
        }],
        "monitors": [{
            "id": "tasks",
            "title": "📊 Task Monitor",
            "type": "task_list",
            "css_class": "rh-output",
            "empty_text": "No tasks submitted yet."
        }],
        "notifications": {
            "topic": "task_status",
            "id_field": "uid",
            "state_field": "state"
        }
    }

    @classmethod
    def is_enabled(cls, app: FastAPI) -> bool:
        """Rhapsody loads on compute nodes only (task execution).

        Returns:
            bool: Whatever ``in_allocation()`` of the detected batch
            system reports.
        """
        from .batch_system import detect_batch_system
        return detect_batch_system().in_allocation()

    def __init__(self, app: FastAPI, instance_name: str = "rhapsody"):
        """Create the plugin and register its task routes.

        Args:
            app: Application the routes are attached to.
            instance_name: Name of this plugin instance.
        """
        super().__init__(app, instance_name)

        # Strong references to in-flight session-initialization tasks;
        # entries remove themselves once the task is done.
        self._session_init_tasks: set = set()

        self.add_route_post('submit/{sid}', self.submit_tasks)
        self.add_route_post('wait/{sid}', self.wait_tasks)
        self.add_route_get('list_tasks/{sid}', self.list_tasks)
        self.add_route_get('task/{sid}/{uid}', self.get_task)
        self.add_route_post('cancel/{sid}/{uid}', self.cancel_task)
        self.add_route_post('cancel_all/{sid}', self.cancel_all_tasks)

    async def register_session(self, request: Request) -> dict:
        """Register a new Rhapsody session.

        Accepts an optional JSON object with the keys ``backends``
        (list of backend names), ``notify_batch_window`` (seconds) and
        ``notify_batch_size`` (count).  A missing, unparsable or
        non-object body is treated as "use defaults".

        Session initialization happens asynchronously in the background.
        The SID is returned immediately.  The client should wait for a
        ``session_status`` SSE notification (``status: "ready"``) before
        submitting tasks, or handle HTTP 409 on early requests.

        Returns:
            dict: ``{"sid": <sid>, "status": "initializing"}``.

        Raises:
            HTTPException: 400 when ``notify_batch_window`` or
                ``notify_batch_size`` cannot be converted to a number.
        """
        try:
            data = await request.json()
        except Exception:
            # The body is optional; anything unparsable means defaults.
            data = {}
        if not isinstance(data, dict):
            # Valid JSON that is not an object (``null``, a list, ...)
            # carries no options; without this check ``data.get`` below
            # would fail with AttributeError.
            data = {}

        backend_names = data.get('backends')

        # Convert client-supplied numbers up front so malformed values are
        # reported as a client error instead of an unhandled exception.
        try:
            notify_batch_window = float(
                data.get('notify_batch_window', NOTIFY_BATCH_WINDOW))
            notify_batch_size = int(
                data.get('notify_batch_size', NOTIFY_BATCH_SIZE))
        except (TypeError, ValueError) as e:
            raise HTTPException(
                status_code=400,
                detail=f"invalid notification batching option: {e}") from e

        # Build session directly to avoid race on shared plugin state
        self._ensure_cleanup_task()
        sid = f"session.{uuid.uuid4().hex[:8]}"
        session = self.session_class(
            sid,
            backend_names=backend_names,
            notify_batch_window=notify_batch_window,
            notify_batch_size=notify_batch_size,
        )
        session._plugin = self
        self._sessions[sid] = session
        self._session_last_access[sid] = time.time()
        log.info("[%s] Registered session %s", self.instance_name, sid)

        # Kick off initialization in the background so the HTTP response
        # (and therefore the WebSocket slot) is released immediately.
        # The event loop keeps only weak references to tasks, so hold a
        # strong one until the task finishes; otherwise it could be
        # garbage-collected before initialization completes.
        init_task = asyncio.create_task(self._init_session(sid, session))
        self._session_init_tasks.add(init_task)
        init_task.add_done_callback(self._session_init_tasks.discard)

        return {"sid": sid, "status": "initializing"}

    async def _init_session(self, sid: str, session) -> None:
        """Background task: initialize a session and notify via SSE.

        Dispatches a ``session_status`` notification with status
        ``ready`` on success, or ``failed`` plus the error text when
        initialization raises.

        Args:
            sid: Identifier of the session being initialized.
            session: Session object; its ``initialize()`` coroutine is
                awaited when present.
        """
        if session._init_ready.is_set():
            return  # already initialized (e.g. by test setup)

        try:
            if hasattr(session, 'initialize'):
                await session.initialize()
            self._dispatch_notify("session_status", {
                "sid": sid,
                "status": "ready",
            })
        except Exception as e:
            # log.exception keeps the traceback; this runs detached from
            # any request, so the log is the only place it can surface.
            log.exception("[%s] Session %s init failed: %s",
                          self.instance_name, sid, e)
            self._dispatch_notify("session_status", {
                "sid": sid,
                "status": "failed",
                "error": str(e),
            })

    # -- route handlers -----------------------------------------------------

    async def submit_tasks(self, request: Request) -> dict:
        """Submit tasks to session ``sid``.

        The body is either ``{"tasks": [...]}`` or the compressed form
        ``{"template": {...}, "uids": [...]}``.  In the latter case the
        template is expanded into one task dict per UID, each with its
        ``uid`` key set, before being forwarded to the session.
        """
        sid = request.path_params['sid']

        # Optional profiler, reachable through the app state when the
        # edge service provides one.
        prof = getattr(getattr(self._app.state, 'edge_service', None),
                       '_prof', None)

        if prof:
            prof.prof('rh_parse_body', msg=sid)
        data = await request.json()
        if prof:
            prof.prof('rh_parse_body_done', msg=sid)

        # Support template compression: {"template": {...}, "uids": [...]}
        template = data.get('template')
        if template is not None:
            uids = data.get('uids', [])
            if prof:
                prof.prof('rh_template_expand', msg='%d tasks' % len(uids))
            task_dicts = [dict(template, uid=uid) for uid in uids]
            if prof:
                prof.prof('rh_template_expand_done')
            pre_expanded = True
        else:
            task_dicts = data.get('tasks', [])
            pre_expanded = False

        return await self._forward(sid, RhapsodySession.submit_tasks,
                                   task_dicts=task_dicts,
                                   pre_expanded=pre_expanded)

    async def wait_tasks(self, request: Request) -> dict:
        """Forward a wait request for session ``sid``.

        Reads ``uids`` (default: empty list) and the optional ``timeout``
        from the JSON body.
        """
        sid = request.path_params['sid']
        data = await request.json()
        uids = data.get('uids', [])
        timeout = data.get('timeout')
        return await self._forward(sid, RhapsodySession.wait_tasks,
                                   uids=uids, timeout=timeout)

    async def list_tasks(self, request: Request) -> dict:
        """Forward a task-listing request for session ``sid``."""
        sid = request.path_params['sid']
        return await self._forward(sid, RhapsodySession.list_tasks)

    async def get_task(self, request: Request) -> dict:
        """Forward a lookup of task ``uid`` in session ``sid``."""
        sid = request.path_params['sid']
        uid = request.path_params['uid']
        return await self._forward(sid, RhapsodySession.get_task, uid=uid)

    async def cancel_task(self, request: Request) -> dict:
        """Forward a cancel request for task ``uid`` in session ``sid``."""
        sid = request.path_params['sid']
        uid = request.path_params['uid']
        return await self._forward(sid, RhapsodySession.cancel_task, uid=uid)

    async def cancel_all_tasks(self, request: Request) -> dict:
        """Forward a cancel-all request for session ``sid``."""
        sid = request.path_params['sid']
        return await self._forward(sid, RhapsodySession.cancel_all_tasks)