Source code for radical.edge.plugin_psij

'''
PsiJ plugin for RADICAL Edge — HPC job submission.

Three-class pattern
-------------------
PSIJSession   Edge-side session: holds one PsiJ ``Executor`` per submit call,
              manages job state via callbacks and background polling, streams
              stdout/stderr incrementally.

PSIJClient    Application-side thin HTTP wrapper: delegates to the edge service
              over the bridge (``submit_job``, ``get_job_status``, ``list_jobs``,
              ``cancel_job``, ``submit_tunneled``, ``tunnel_status``).

PluginPSIJ    Registers the plugin with the edge, adds URL routes, and wires
              requests to the correct PSIJSession via ``_forward()``.
'''

import asyncio
import json as _json
import logging
import os
import pathlib
import re
import shutil
import subprocess
import threading
import time

from datetime import timedelta
from typing import Any, Dict

from fastapi import FastAPI, HTTPException, Request
from starlette.responses import JSONResponse

import psij

from .plugin_base import Plugin
from .plugin_session_base import PluginSession
from .client import PluginClient

log = logging.getLogger("radical.edge")

# Default poll interval for job status updates (in seconds)
PSIJ_POLL_INTERVAL = 5.0

# Where reverse-tunnel port rendezvous files are written
_RELAY_BASE = pathlib.Path.home() / '.radical' / 'edge' / 'tunnels'

# Persistent directory for job stdout/stderr capture
_OUTPUT_BASE = pathlib.Path.home() / '.radical' / 'edge' / 'psij' / 'output'

# Maximum age (days) for stale output directories cleaned up on session creation
_OUTPUT_MAX_AGE_DAYS = 7


def _relay_dir() -> pathlib.Path:
    """Return (and create) the relay-file directory."""
    _RELAY_BASE.mkdir(parents=True, exist_ok=True)
    return _RELAY_BASE

# Terminal states that don't need further polling
TERMINAL_STATES = {'COMPLETED', 'FAILED', 'CANCELED'}


def _normalize_state(state) -> str:
    """Normalize a PsiJ JobState to a plain string (strip 'JobState.' prefix)."""
    s = str(state)
    return s[9:] if s.startswith('JobState.') else s


def _read_output_file(job, attr: str, offset: int = 0) -> str:
    """Read stdout or stderr from a job's spec path attribute.

    Args:
        job:    PsiJ job object.
        attr:   Attribute name on job.spec ('stdout_path' or 'stderr_path').
        offset: Byte offset to start reading from (0 = full file).

    Returns:
        Content read from the file starting at offset.
    """
    try:
        path = getattr(job.spec, attr, None)
        if path and os.path.exists(str(path)):
            with open(str(path), 'r') as f:
                if offset > 0:
                    f.seek(offset)
                return f.read()
    except Exception as e:
        log.debug("Failed to read %s for job: %s", attr, e)
    return ""


def _output_file_size(job, attr: str) -> int:
    """Return the byte size of a job's stdout/stderr file, or 0."""
    try:
        path = getattr(job.spec, attr, None)
        if path and os.path.exists(str(path)):
            return os.path.getsize(str(path))
    except Exception:
        pass
    return 0


[docs] class PSIJSession(PluginSession): ''' Session-specific PSIJ state. ''' poll_interval = PSIJ_POLL_INTERVAL def __init__(self, sid: str, **kwargs: Any): super().__init__(sid) self._jobs: Dict[str, Any] = {} # job_id -> psij.Job self._job_meta: Dict[str, dict] = {} # job_id -> submission metadata self._job_states: Dict[str, str] = {} # track last known state per job self._poll_interval = kwargs.get('poll_interval', self.poll_interval) self._poll_task = None # Persistent output directory for this session's job stdout/stderr self._output_dir = _OUTPUT_BASE / sid self._cleanup_stale_output() self._output_dir.mkdir(parents=True, exist_ok=True) def _cleanup_stale_output(self) -> None: """Remove output directories older than _OUTPUT_MAX_AGE_DAYS.""" if not _OUTPUT_BASE.exists(): return cutoff = time.time() - _OUTPUT_MAX_AGE_DAYS * 86400 for entry in _OUTPUT_BASE.iterdir(): if not entry.is_dir() or entry == self._output_dir: continue try: if entry.stat().st_mtime < cutoff: shutil.rmtree(entry) log.info("Cleaned up stale output dir: %s", entry) except Exception as e: log.debug("Failed to clean up %s: %s", entry, e)
[docs] async def submit_job(self, job_spec_dict: Dict[str, Any], executor_name: str = 'local') -> Dict[str, Any]: ''' Submit a job via PSIJ. ''' try: spec = psij.JobSpec() executable = job_spec_dict.get('executable') arguments = job_spec_dict.get('arguments') spec.executable = executable if arguments: spec.arguments = arguments if 'directory' in job_spec_dict: spec.directory = job_spec_dict['directory'] if 'environment' in job_spec_dict: spec.environment = job_spec_dict['environment'] if 'attributes' in job_spec_dict: attribs = job_spec_dict['attributes'] spec.attributes = psij.JobAttributes() duration = attribs.get("duration") if duration: spec.attributes.duration = timedelta(seconds=int(duration)) spec.attributes.queue_name = attribs.get("queue_name") spec.attributes.account = attribs.get("account") spec.attributes.reservation_id = attribs.get("reservation_id") node_count = attribs.get("node_count") if node_count: spec.attributes.resource_count = int(node_count) if 'custom_attributes' in job_spec_dict: spec.attributes.custom_attributes = dict( job_spec_dict['custom_attributes']) job = psij.Job(spec) out_path = str(self._output_dir / f"{job.id}.out") err_path = str(self._output_dir / f"{job.id}.err") spec.stdout_path = out_path spec.stderr_path = err_path ex = psij.JobExecutor.get_instance(executor_name) # Set poll interval for status updates if hasattr(ex, 'poll_interval'): ex.poll_interval = self._poll_interval self._jobs[job.id] = job # Store submission metadata for later retrieval attribs = job_spec_dict.get('attributes', {}) self._job_meta[job.id] = { 'executable': executable, 'arguments': arguments or [], 'executor': executor_name, 'directory': job_spec_dict.get('directory'), 'queue_name': attribs.get('queue_name'), 'account': attribs.get('account'), 'node_count': attribs.get('node_count'), 'duration': attribs.get('duration'), } # Register status callback BEFORE submit so no transitions are missed plugin = self._plugin job_id = job.id last_state = None def _on_status(j, status): nonlocal last_state state_str = _normalize_state(status.state) # Skip if state hasn't changed if state_str == last_state: return last_state = state_str is_terminal = state_str in TERMINAL_STATES stdout_content = "" stderr_content = "" if is_terminal: stdout_content = _read_output_file(j, 'stdout_path') stderr_content = _read_output_file(j, 'stderr_path') if plugin: plugin._dispatch_notify("job_status", { "job_id": job_id, "state": state_str, "exit_code": status.exit_code if is_terminal else None, "stdout": stdout_content, "stderr": stderr_content }) job.set_job_status_callback(_on_status) ex.submit(job) # Start background polling for job status updates self._start_polling() log.info("Submitted job %s to %s", job.id, executor_name) return {"job_id": job.id, "native_id": job.native_id} except Exception as e: log.exception("Job submission failed: %s", e) raise HTTPException(status_code=500, detail=str(e)) from e
[docs] async def get_job_status(self, job_id: str, stdout_offset: int = 0, stderr_offset: int = 0) -> Dict[str, Any]: ''' Get job status with metadata and optional stdout/stderr offset. ''' job = self._jobs.get(job_id) if not job: raise HTTPException(status_code=404, detail=f"Job {job_id} not found") status = job.status state_str = _normalize_state(status.state) stdout_content = _read_output_file(job, 'stdout_path', stdout_offset) stderr_content = _read_output_file(job, 'stderr_path', stderr_offset) meta = self._job_meta.get(job_id, {}) return { "job_id": job_id, "native_id": job.native_id, "state": state_str, "message": status.message, "exit_code": status.exit_code, "time": status.time, "executable": meta.get('executable'), "arguments": meta.get('arguments', []), "executor": meta.get('executor'), "directory": meta.get('directory'), "queue_name": meta.get('queue_name'), "account": meta.get('account'), "node_count": meta.get('node_count'), "duration": meta.get('duration'), "stdout": stdout_content, "stderr": stderr_content, "stdout_offset": _output_file_size(job, 'stdout_path'), "stderr_offset": _output_file_size(job, 'stderr_path'), }
[docs] async def list_jobs(self) -> Dict[str, Any]: ''' List all jobs in this session with current state and metadata. ''' jobs = [] for job_id, job in self._jobs.items(): state_str = _normalize_state(job.status.state) meta = self._job_meta.get(job_id, {}) jobs.append({ "job_id": job_id, "native_id": job.native_id, "state": state_str, "exit_code": job.status.exit_code, "executable": meta.get('executable'), "arguments": meta.get('arguments', []), "executor": meta.get('executor'), "queue_name": meta.get('queue_name'), "account": meta.get('account'), "node_count": meta.get('node_count'), }) return {"jobs": jobs}
[docs] async def cancel_job(self, job_id: str) -> Dict[str, Any]: ''' Cancel a job. ''' job = self._jobs.get(job_id) if not job: raise HTTPException(status_code=404, detail=f"Job {job_id} not found") try: job.cancel() return {"job_id": job_id, "status": "canceled"} except Exception as e: log.exception("Job cancellation failed: %s", e) raise HTTPException(status_code=500, detail=str(e)) from e
[docs] async def close(self) -> dict: ''' Close the session and stop polling. ''' if self._poll_task: self._poll_task.cancel() try: await self._poll_task except asyncio.CancelledError: pass self._poll_task = None # Clean up this session's output directory if self._output_dir.exists(): try: shutil.rmtree(self._output_dir) except Exception as e: log.debug("Failed to remove output dir %s: %s", self._output_dir, e) return await super().close()
def _start_polling(self): ''' Start the background polling task if not already running. ''' if self._poll_task is None or self._poll_task.done(): self._poll_task = asyncio.create_task(self._poll_jobs()) async def _poll_jobs(self): ''' Background task that polls job status and sends notifications. ''' first = True while True: try: if first: # Short delay on first poll to catch fast state transitions await asyncio.sleep(0.5) first = False else: await asyncio.sleep(self._poll_interval) # Check all non-terminal jobs for job_id, job in list(self._jobs.items()): try: status = job.status state_str = _normalize_state(status.state) # Skip if state hasn't changed last_state = self._job_states.get(job_id) if state_str == last_state: continue self._job_states[job_id] = state_str is_terminal = state_str in TERMINAL_STATES stdout_content = "" stderr_content = "" if is_terminal: stdout_content = _read_output_file(job, 'stdout_path') stderr_content = _read_output_file(job, 'stderr_path') if self._plugin: self._plugin._dispatch_notify("job_status", { "job_id": job_id, "state": state_str, "exit_code": status.exit_code if is_terminal else None, "stdout": stdout_content, "stderr": stderr_content }) except Exception as e: log.debug("Error polling job %s: %s", job_id, e) # Check if all jobs are terminal - if so, stop polling if all(self._job_states.get(jid) in TERMINAL_STATES for jid in self._jobs): break except asyncio.CancelledError: break except Exception as e: log.debug("Polling error: %s", e)
[docs] class PSIJClient(PluginClient): """ Client-side interface for the PSIJ plugin. """
[docs] def submit_job(self, job_spec: Dict[str, Any], executor: str = 'local') -> Dict[str, Any]: """ Submit a job. Args: job_spec (dict): The job specification. executor (str): The executor to use. Returns: dict: Job submission result (job_id, native_id). """ self._require_session() url = self._url(f"submit/{self.sid}") payload = {"job_spec": job_spec, "executor": executor} resp = self._http.post(url, json=payload) self._raise(resp, f"psij submit {job_spec.get('executable','?')!r} on {executor!r}") return resp.json()
[docs] def get_job_status(self, job_id: str, stdout_offset: int = 0, stderr_offset: int = 0) -> Dict[str, Any]: """ Get the status of a job. Args: job_id: The job ID to query. stdout_offset: Byte offset for stdout (0 = full). stderr_offset: Byte offset for stderr (0 = full). Returns: Job status info including metadata and stdout/stderr. """ self._require_session() url = self._url(f"status/{self.sid}/{job_id}") params = {} if stdout_offset: params['stdout_offset'] = str(stdout_offset) if stderr_offset: params['stderr_offset'] = str(stderr_offset) resp = self._http.get(url, params=params) self._raise(resp, f"job status {job_id!r}") return resp.json()
[docs] def list_jobs(self) -> Dict[str, Any]: """ List all jobs in this session. Returns: dict with 'jobs' list. """ self._require_session() resp = self._http.get(self._url(f"list_jobs/{self.sid}")) self._raise(resp) return resp.json()
[docs] def cancel_job(self, job_id: str) -> Dict[str, Any]: """ Cancel a job. Args: job_id: The job ID to cancel. Returns: Cancellation result. """ self._require_session() url = self._url(f"cancel/{self.sid}/{job_id}") resp = self._http.post(url) self._raise(resp, f"cancel job {job_id!r}") return resp.json()
[docs] def submit_tunneled(self, job_spec: Dict[str, Any], executor: str = 'local', tunnel: bool = False) -> Dict[str, Any]: """Submit a job that launches a child Edge service on a compute node. The ``job_spec.arguments`` list *must* contain ``-n <edge_name>`` or ``--name <edge_name>`` so the child edge can register under the correct name. When *tunnel* is ``True`` the server automatically appends ``--tunnel`` to the job arguments. The plugin-side watcher then opens a reverse SSH tunnel (login node → compute node) once the job is running and writes the port to ``~/.radical/edge/tunnels/{edge_name}.port``. The child edge service reads that file at startup and rewrites its bridge URL to connect through the tunnel. Args: job_spec: PsiJ job specification dict. ``arguments`` must include ``-n <edge_name>``. executor: PsiJ executor name (default: ``"local"``). tunnel: Whether to set up a reverse SSH tunnel (default: False). Returns: dict with ``job_id``, ``native_id``, and ``edge_name``. Raises: RuntimeError: If the server returns an error response. """ self._require_session() url = self._url(f"submit_tunneled/{self.sid}") payload = {"job_spec": job_spec, "executor": executor, "tunnel": tunnel} resp = self._http.post(url, json=payload) self._raise(resp, f"psij submit_tunneled on {executor!r}") return resp.json()
[docs] def tunnel_status(self, edge_name: str) -> Dict[str, Any]: """Return the current tunnel status for a named edge. This endpoint is session-less (no session required). Args: edge_name: The logical name of the child edge service. Returns: dict with fields: - ``edge_name`` — echoed back. - ``status`` — one of ``"pending"``, ``"active"``, ``"failed"``, ``"done"``, or ``"no_tunnel"``. - ``port`` — assigned tunnel port (int) once active, else null. - ``pid`` — SSH process PID, once spawned, else null. """ resp = self._http.get(self._url(f"tunnel_status/{edge_name}")) self._raise(resp, f"tunnel_status {edge_name!r}") return resp.json()
[docs] class PluginPSIJ(Plugin): ''' PSIJ plugin for Radical Edge. This plugin provides an interface to submit and manage jobs via the `psij-python` library. ''' plugin_name = "psij" session_class = PSIJSession client_class = PSIJClient version = '0.0.1' ui_config = { "icon": "🚀", "title": "PsiJ Jobs", "description": "Submit and monitor HPC batch jobs via PsiJ.", "forms": [{ "id": "submit", "title": "📝 Submit Job", "layout": "grid2", "fields": [ {"name": "exec", "type": "text", "label": "Executable", "default": "radical-edge-wrapper.sh", "css_class": "p-exec", "column": 0}, {"name": "args", "type": "text", "label": "Arguments (space-separated)", "placeholder": "auto-filled with --url and --name", "css_class": "p-args", "column": 0}, {"name": "executor", "type": "select", "label": "Executor", "options": ["local", "slurm", "pbs", "lsf"], "css_class": "p-executor", "column": 0}, {"name": "queue", "type": "text", "label": "Queue / Partition", "placeholder": "optional", "required": False, "css_class": "p-queue", "column": 1}, {"name": "account", "type": "text", "label": "Account / Project", "placeholder": "optional", "required": False, "css_class": "p-account", "column": 1}, {"name": "duration", "type": "text", "label": "Duration (seconds)", "placeholder": "e.g. 600", "required": False, "css_class": "p-duration", "column": 1}, {"name": "node_count", "type": "number", "label": "Number of Nodes", "placeholder": "e.g. 1", "required": False, "css_class": "p-node-count", "column": 1}, {"name": "custom", "type": "custom_attributes", "label": "🔧 Custom Attributes", "required": False, "css_class": "p-custom-attr", "column": 1}, ], "submit": {"label": "🚀 Submit Job", "style": "success"} }], "monitors": [{ "id": "jobs", "title": "📊 Job Monitor", "type": "task_list", "css_class": "psij-output", "empty_text": "No jobs submitted yet." }], "notifications": { "topic": "job_status", "id_field": "job_id", "state_field": "state" } } def __init__(self, app: FastAPI, instance_name: str = "psij"): super().__init__(app, instance_name) # watcher tasks keyed by edge_name (plugin-level, survive session cleanup) self._watchers: dict = {} # SSH tunnel processes keyed by edge_name self._tunnel_procs: dict = {} # Ensure relay directory exists at startup _relay_dir() self._app.router.on_shutdown.append(self._cleanup_tunnels) self.add_route_post('submit/{sid}', self.submit_job) self.add_route_post('submit_tunneled/{sid}', self.submit_tunneled) self.add_route_get('tunnel_status/{edge_name}', self.tunnel_status) self.add_route_get('status/{sid}/{job_id}', self.get_job_status) self.add_route_get('list_jobs/{sid}', self.list_jobs) self.add_route_post('cancel/{sid}/{job_id}', self.cancel_job)
[docs] async def submit_job(self, request: Request) -> JSONResponse: sid = request.path_params['sid'] data = await request.json() job_spec = data.get('job_spec', {}) executor = data.get('executor', 'local') return await self._forward(sid, PSIJSession.submit_job, job_spec_dict=job_spec, executor_name=executor)
[docs] async def get_job_status(self, request: Request) -> JSONResponse: sid = request.path_params['sid'] job_id = request.path_params['job_id'] so = int(request.query_params.get('stdout_offset', '0')) se = int(request.query_params.get('stderr_offset', '0')) return await self._forward(sid, PSIJSession.get_job_status, job_id=job_id, stdout_offset=so, stderr_offset=se)
[docs] async def list_jobs(self, request: Request) -> JSONResponse: sid = request.path_params['sid'] return await self._forward(sid, PSIJSession.list_jobs)
[docs] async def cancel_job(self, request: Request) -> JSONResponse: sid = request.path_params['sid'] job_id = request.path_params['job_id'] return await self._forward(sid, PSIJSession.cancel_job, job_id=job_id)
# ───────────────────────────────────────────────────────────────────────── # Edge-job submission with optional reverse SSH tunnel # ─────────────────────────────────────────────────────────────────────────
[docs] async def submit_tunneled(self, request: Request) -> JSONResponse: """Submit a job that starts a new Edge service on a compute node. The job *must* pass ``-n``/``--name <edge_name>`` in its arguments so the child edge service can register under the correct name. When ``tunnel=true`` the plugin automatically appends ``--tunnel`` to the job's argument list. The child edge service reads this flag at startup and waits for a relay port file at the hardcoded path ``~/.radical/edge/tunnels/{edge_name}.port`` before connecting to the bridge. The watcher on the parent edge writes that file once the reverse SSH tunnel is established. When ``tunnel=true`` the plugin: 1. Cleans up any stale relay file from a previous run. 2. Injects ``--tunnel`` into the job arguments. 3. Spawns an async watcher that waits for the SLURM job to reach RUNNING, then opens a reverse SSH tunnel (login → compute) and writes the allocated port to the relay file. Request body JSON fields: - ``job_spec`` (dict) — PsiJ job specification. - ``executor`` (str) — PsiJ executor name (default: ``"local"``). - ``tunnel`` (bool) — Whether to set up a reverse SSH tunnel (default: ``false``). Returns: JSON with ``job_id``, ``native_id``, and ``edge_name``. Raises: 422 if ``-n``/``--name`` is missing from ``job_spec.arguments``. 409 if a tunnel watcher for the same edge name is already active. """ sid = request.path_params['sid'] data = await request.json() job_spec = data.get('job_spec', {}) executor = data.get('executor', 'local') tunnel = bool(data.get('tunnel', False)) # --- resolve edge name from arguments --- args = list(job_spec.get('arguments') or []) edge_name = None for i, a in enumerate(args[:-1]): if a in ('-n', '--name'): edge_name = args[i + 1] break if not edge_name: raise HTTPException( status_code=422, detail="submit_tunneled requires -n/--name <edge_name> in job_spec.arguments") # --- guard against duplicate watchers --- existing = self._watchers.get(edge_name) if existing and not existing.done(): raise HTTPException( status_code=409, detail=f"Tunnel watcher already active for edge '{edge_name}'") # --- prepare relay file and inject --tunnel / logging flags --- relay_file: pathlib.Path | None = None if tunnel: relay_file = _relay_dir() / f'{edge_name}.port' relay_file.unlink(missing_ok=True) # remove stale file from previous run # Inject flags so the child edge (a) waits for the relay port file # and (b) writes DEBUG logs to a shared-filesystem path visible from # the login node. if '--tunnel' not in args: args.append('--tunnel') job_spec = dict(job_spec) job_spec['arguments'] = args resp = await self._forward(sid, PSIJSession.submit_job, job_spec_dict=job_spec, executor_name=executor) if tunnel and relay_file is not None: result = _json.loads(bytes(resp.body)) native_id = result.get('native_id') log.info("[psij] submit_tunneled: edge=%s job_id=%s native_id=%s — starting tunnel watcher", edge_name, result.get('job_id'), native_id) if native_id is None: log.warning("[psij] native_id is None for edge '%s'; watcher will poll " "without a SLURM job ID (PsiJ may not have assigned one yet)", edge_name) task = asyncio.create_task( self._tunnel_watcher(edge_name, native_id, relay_file)) self._watchers[edge_name] = task # Augment response with edge_name for caller convenience body = _json.loads(bytes(resp.body)) body['edge_name'] = edge_name return JSONResponse(body, status_code=resp.status_code)
[docs] async def tunnel_status(self, request: Request) -> JSONResponse: """Return the current tunnel status for a named edge. Path param: ``edge_name`` Returns a JSON object with fields: - ``edge_name`` — echoed back. - ``status`` — one of ``"pending"``, ``"active"``, ``"failed"``, ``"done"``, or ``"no_tunnel"``. - ``port`` — allocated tunnel port (int) once active, else null. - ``pid`` — SSH process PID, once spawned, else null. """ edge_name = request.path_params['edge_name'] relay_file = _relay_dir() / f'{edge_name}.port' pid_file = _relay_dir() / f'{edge_name}.pid' port = None pid = None if relay_file.exists(): try: port = int(relay_file.read_text().strip()) except (ValueError, OSError): pass if pid_file.exists(): try: pid = int(pid_file.read_text().strip()) except (ValueError, OSError): pass task = self._watchers.get(edge_name) proc = self._tunnel_procs.get(edge_name) alive = proc is not None and proc.poll() is None if task is None: # No watcher was ever started status = 'no_tunnel' elif port is not None: # Relay file exists → tunnel successfully reported a port if alive: status = 'active' elif proc is not None: rc = proc.poll() status = 'done' if rc == 0 else 'failed' else: # Port written but proc not tracked (e.g. migrated session) status = 'active' elif task.done(): # Watcher finished but no port was ever written → failed exc = task.exception() if not task.cancelled() else None status = 'failed' if exc else 'failed' else: # Watcher still running, waiting for job to reach RUNNING state status = 'pending' return JSONResponse({'edge_name': edge_name, 'status': status, 'port': port, 'pid': pid})
# ───────────────────────────────────────────────────────────────────────── # Internal tunnel helpers # ───────────────────────────────────────────────────────────────────────── async def _tunnel_watcher(self, edge_name: str, native_id, relay_file: pathlib.Path) -> None: """Watch a SLURM job and spawn a reverse SSH tunnel once it starts. Polls ``squeue`` until the job is RUNNING, then calls ``_spawn_tunnel`` to open the reverse SSH tunnel and write the assigned port to *relay_file*. Args: edge_name: Logical name of the child edge service. native_id: SLURM job ID string/int. relay_file: Path where the tunnel port will be written. """ from .queue_info import QueueInfoSlurm log.info("[psij] Watcher started for edge '%s' (job %s)", edge_name, native_id) # --- wait for job to reach RUNNING --- last_state = None for attempt in range(120): # up to ~4 min (2s × 120) await asyncio.sleep(2) state = await _get_slurm_state(native_id) log.debug("[psij] watcher edge=%s job=%s state=%r attempt=%d", edge_name, native_id, state, attempt) # Log every 60 s at INFO so state is visible without DEBUG logging if state != last_state or attempt % 30 == 0: log.info("[psij] watcher edge=%s job=%s state=%r (attempt %d/120)", edge_name, native_id, state or '(unknown)', attempt) last_state = state if state in ('FAILED', 'CANCELLED', 'TIMEOUT', 'NODE_FAIL', 'PREEMPTED'): log.warning("[psij] Job %s ended with state %s — aborting tunnel", native_id, state) return if state != 'RUNNING': continue # --- job is RUNNING: find its nodes --- nodes = await asyncio.to_thread( QueueInfoSlurm.get_job_nodes, str(native_id)) if not nodes: log.warning("[psij] Job %s RUNNING but no nodes found yet, retrying", native_id) continue compute_node = nodes[0] log.info("[psij] Job %s running on %s — spawning tunnel", native_id, compute_node) # Retry spawn: pam_slurm_adopt rejects SSH until the job's # processes are live, which can lag a few seconds behind squeue # reporting RUNNING. Retry rapidly for up to 15 s. deadline = asyncio.get_event_loop().time() + 15 spawn_attempt = 0 while True: try: await self._spawn_tunnel(compute_node, relay_file, edge_name) return # success except Exception as e: spawn_attempt += 1 remaining = deadline - asyncio.get_event_loop().time() if remaining <= 0: log.error("[psij] Tunnel spawn failed for edge '%s' " "after %d attempts: %s", edge_name, spawn_attempt, e) return log.warning("[psij] Tunnel spawn attempt %d failed, %.0f s left: %s", spawn_attempt, remaining, e) await asyncio.sleep(1) log.warning("[psij] Watcher for edge '%s' timed out waiting for job %s to start", edge_name, native_id) async def _cleanup_tunnels(self) -> None: """Terminate all SSH tunnel processes on shutdown.""" for name, proc in list(self._tunnel_procs.items()): try: proc.terminate() proc.wait(timeout=5) except Exception: try: proc.kill() except Exception: pass log.info("[psij] Terminated tunnel process for edge '%s'", name) self._tunnel_procs.clear() for name, task in list(self._watchers.items()): task.cancel() self._watchers.clear() async def _spawn_tunnel(self, node: str, relay_file: pathlib.Path, edge_name: str) -> None: """Open a reverse SSH tunnel from *this* login node to *node*. Uses ``ssh -R 0:<bridge_host>:<bridge_port> <node> -N`` so the OS assigns a free port. The assigned port is extracted from SSH stderr and written to *relay_file* so the child edge can read it. The SSH process runs detached (new session) so it outlives this coroutine. A PID file is written alongside the relay file. Args: node: Compute node hostname. relay_file: Path where the assigned port number will be written. edge_name: Used only for log messages and the PID file name. """ from urllib.parse import urlparse as _urlparse # Derive bridge host/port from the edge service (authoritative source), # falling back to the RADICAL_BRIDGE_URL env var if not accessible. edge_svc = getattr(self._app.state, 'edge_service', None) svc_url = getattr(edge_svc, '_bridge_url', '') if edge_svc else '' bridge_url = svc_url or os.environ.get('RADICAL_BRIDGE_URL', '') parsed = _urlparse(bridge_url) if bridge_url else None bridge_host = (parsed.hostname or 'localhost') if parsed else 'localhost' bridge_port = (parsed.port or 8000) if parsed else 8000 log.info("[psij] Tunnel bridge target: %s:%s (from url=%r)", bridge_host, bridge_port, bridge_url) ssh_cmd = [ 'ssh', '-N', '-o', 'StrictHostKeyChecking=no', '-o', 'UserKnownHostsFile=/dev/null', '-o', 'BatchMode=yes', '-o', 'ServerAliveInterval=10', '-o', 'ServerAliveCountMax=3', '-o', 'ExitOnForwardFailure=yes', '-R', f'0:{bridge_host}:{bridge_port}', node, ] log.info("[psij] Spawning reverse tunnel: %s", ' '.join(ssh_cmd)) proc = subprocess.Popen( ssh_cmd, stderr=subprocess.PIPE, start_new_session=True, # detach so it survives edge restart ) # Extract the allocated port from SSH stderr. # With -v, SSH prints: "Allocated port N for remote forward to ..." # NOTE: some OpenSSH versions also print an earlier line like # "remote forward success. listening on port 0" # (where 0 is a placeholder before the real port is announced). # We must skip any match of port 0 and keep reading. # This is blocking I/O; run it in a thread so the event loop stays free. def _read_port() -> tuple[int | None, list[str]]: lines: list[str] = [] assert proc.stderr is not None for raw in proc.stderr: line = raw.decode('utf-8', errors='replace').rstrip() lines.append(line) m = re.search(r'[Aa]llocated port (\d+)', line) if m: port = int(m.group(1)) if port > 0: return port, lines if proc.poll() is not None: break return None, lines port, ssh_lines = await asyncio.get_event_loop().run_in_executor(None, _read_port) # Drain SSH stderr in background so the pipe never fills and blocks SSH. if proc.stderr: threading.Thread(target=proc.stderr.read, daemon=True).start() if port is None: rc = proc.poll() tail = '\n'.join(ssh_lines[-20:]) raise RuntimeError( f"SSH tunnel for edge '{edge_name}' did not report a port " f"(exit={rc})\nSSH output (last 20 lines):\n{tail}") log.info("[psij] SSH allocated port %d for edge '%s' tunnel", port, edge_name) # Write rendezvous files and register proc for live status checks relay_file.write_text(str(port)) pid_file = relay_file.with_suffix('.pid') pid_file.write_text(str(proc.pid)) self._tunnel_procs[edge_name] = proc log.info("[psij] Reverse tunnel for edge '%s' active on port %d (pid=%d)", edge_name, port, proc.pid)
async def _get_slurm_state(native_id) -> str: """Return the SLURM state string for *native_id*, or empty string.""" try: proc = await asyncio.create_subprocess_exec( 'squeue', '--job', str(native_id), '--noheader', '--format=%T', stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE) stdout, _ = await asyncio.wait_for(proc.communicate(), timeout=10) lines = [l.strip() for l in stdout.decode().splitlines() if l.strip()] return lines[0] if lines else '' except Exception: return ''