toolserver/queue_helper.py
2026-02-23 15:59:34 +01:00

138 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -------------------------------------------------------------
# queue_helper.py - minimalist thread-based queue manager
# -------------------------------------------------------------
import logging
import queue
import threading
import time
import uuid
from typing import Any, Callable, Dict
# ------------------------------------------------------------------
# Configuration - adjust as needed
# ------------------------------------------------------------------
# Capacity of the user queue (maximum number of waiting user jobs).
USER_MAX_QUEUE: int = 20
# Capacity of the agent queue (maximum number of waiting agent jobs;
# agents wait silently, without position messages).
AGENT_MAX_QUEUE: int = 50
# Seconds between two "position update" messages.
UPDATE_INTERVAL: float = 10.0
# Maximum time in seconds an inference is allowed to take.
WORKER_TIMEOUT: float = 30.0
# ------------------------------------------------------------------
class _Job:
    """A single inference request travelling through a queue.

    Attributes (declared via ``__slots__``):
        job_id: Random UUID4 string identifying the job.
        payload: Request dict supplied by the caller.
        callback: Callable that receives every message for this job.
        created_at: ``time.time()`` timestamp taken at construction.
        event: Set once the job has finished, successfully or not.
        result: Answer dict on success, otherwise ``None``.
        error: Non-empty error text on failure, otherwise ``None``.
    """

    __slots__ = ("job_id", "payload", "callback", "created_at", "event", "result", "error")

    def __init__(self, payload: Dict, callback: Callable[[Dict], None]):
        self.job_id = str(uuid.uuid4())
        self.payload = payload
        self.callback = callback
        self.created_at = time.time()
        self.event = threading.Event()
        self.result = None
        self.error = None

    def set_success(self, answer: Dict):
        """Store ``answer``, pass it to the callback and mark the job finished."""
        self.result = answer
        self._finish(answer)

    def set_error(self, exc: Exception):
        """Store the error text, report ``{"error": ...}`` and mark the job finished."""
        # str() of a message-less exception is "", which callers testing the
        # truthiness of ``error`` would mistake for success; fall back to repr().
        self.error = str(exc) or repr(exc)
        self._finish({"error": self.error})

    def _finish(self, message: Dict):
        """Deliver the final message, then release everyone waiting on ``event``.

        The callback runs *before* the event is set, so whatever the callback
        records is already in place when a waiting thread wakes up. A failing
        callback is logged instead of propagated: the exception must not reach
        the caller's thread, and the event is set regardless.
        """
        try:
            self.callback(message)
        except Exception:
            logging.getLogger(__name__).exception(
                "Callback for job %s raised an exception", self.job_id
            )
        finally:
            self.event.set()
class QueueManager:
    """Serve inference jobs from two bounded queues with one worker thread.

    User jobs are always taken before agent jobs. The daemon worker thread
    is started by the constructor and runs until :meth:`stop` is called.
    """

    def __init__(self, model_infer_fn: Callable[[Dict], Dict]):
        """Create both queues and start the worker thread.

        Args:
            model_infer_fn: Called on the worker thread with a job's payload;
                its return value is delivered as the job's answer.
        """
        self._infer_fn = model_infer_fn
        self._user_q = queue.Queue(maxsize=USER_MAX_QUEUE)
        self._agent_q = queue.Queue(maxsize=AGENT_MAX_QUEUE)
        self._shutdown = threading.Event()
        self._worker = threading.Thread(
            target=self._run_worker,
            daemon=True,
            name="LLMworker",
        )
        self._worker.start()

    # ---------- public API ----------
    def enqueue_user(
        self,
        payload: Dict,
        progress_cb: Callable[[Dict], None],
        *,
        notify_position: bool = False,
    ) -> tuple[str, int]:
        """Queue a user job without blocking.

        Args:
            payload: Request dict passed to the inference function.
            progress_cb: Receives the final answer (or ``{"error": ...}``)
                and, if ``notify_position`` is true, periodic position
                messages.
            notify_position: Start a notifier thread that periodically
                reports the job's position in the user queue.

        Returns:
            ``(job_id, position)``. The position is the user-queue size read
            right after insertion, so it is only approximate when other
            threads touch the queue at the same time.

        Raises:
            RuntimeError: The user queue is full.
        """
        job = _Job(payload, progress_cb)
        try:
            self._user_q.put_nowait(job)
        except queue.Full:
            raise RuntimeError(f"Userqueue vol (≥{USER_MAX_QUEUE})") from None
        position = self._user_q.qsize()
        if notify_position:
            # Separate notifier thread that periodically reports the position.
            start_position_notifier(job, self._user_q)
        return job.job_id, position

    def enqueue_agent(self, payload: Dict, progress_cb: Callable[[Dict], None]) -> str:
        """Queue an agent job without blocking and return its job id.

        Raises:
            RuntimeError: The agent queue is full.
        """
        job = _Job(payload, progress_cb)
        try:
            self._agent_q.put_nowait(job)
        except queue.Full:
            raise RuntimeError(f"Agentqueue vol (≥{AGENT_MAX_QUEUE})") from None
        return job.job_id

    # ---------- sync helper for agents/tools ----------
    def request_agent_sync(self, payload: Dict, timeout: float = WORKER_TIMEOUT) -> Dict:
        """Run an internal (agent/tool) request and block until it is done.

        The job goes into the agent queue (lower priority than users) and no
        queue-position messages are sent for it.

        Args:
            payload: Request dict passed to the inference function.
            timeout: Maximum number of seconds to wait, queueing time included.

        Returns:
            The answer dict, or ``{}`` if the answer is empty/``None``.

        Raises:
            RuntimeError: The agent queue is full, or the inference failed.
            TimeoutError: The job did not finish within ``timeout`` seconds.
        """
        # The outcome is read from the job itself: ``result``/``error`` are
        # assigned before the event is set, whereas a value captured through
        # the callback may not have arrived yet when the wait returns.
        job = _Job(payload, lambda _msg: None)
        try:
            self._agent_q.put_nowait(job)
        except queue.Full:
            raise RuntimeError(f"Agent-queue vol (≥{AGENT_MAX_QUEUE})") from None
        if not job.event.wait(timeout):
            # Nobody is waiting for this answer any more; if the job is still
            # queued, drop it so the worker does not spend time on it.
            self._discard(self._agent_q, job)
            raise TimeoutError(f"LLM-inference duurde langer dan {timeout} seconden.")
        if job.error is not None:
            raise RuntimeError(job.error)
        return job.result or {}

    # ---------- worker ----------
    def _run_worker(self):
        """Worker loop: process user jobs first, then agent jobs, until stopped."""
        while not self._shutdown.is_set():
            job = self._pop_job(self._user_q) or self._pop_job(self._agent_q)
            if job is None:
                # Idle: poll again shortly, but wake up at once on shutdown.
                self._shutdown.wait(0.1)
                continue
            try:
                answer = self._infer_fn(job.payload)
                job.set_success(answer)
            except Exception as exc:
                try:
                    job.set_error(exc)
                except Exception:
                    # A failure while reporting the error must not end the
                    # only worker thread; every later job would hang.
                    logging.getLogger(__name__).exception(
                        "Could not report failure of job %s", job.job_id
                    )

    def _pop_job(self, q: queue.Queue):
        """Return the next job from ``q``, or ``None`` if it is empty."""
        try:
            return q.get_nowait()
        except queue.Empty:
            return None

    def _discard(self, q: queue.Queue, job) -> bool:
        """Remove ``job`` from ``q`` if it is still waiting there.

        Returns ``True`` if the job was removed, ``False`` if it had already
        been taken off the queue.
        """
        with q.mutex:
            try:
                q.queue.remove(job)
            except ValueError:
                return False
            # A slot became free; wake a producer blocked in put(), if any.
            q.not_full.notify()
            return True

    def stop(self):
        """Ask the worker to stop and wait up to 5 seconds for it to exit."""
        self._shutdown.set()
        self._worker.join(timeout=5)
def start_position_notifier(
    job: _Job,
    queue_ref: queue.Queue,
    interval: float = UPDATE_INTERVAL,
):
    """Start a daemon thread that reports the job's queue position.

    Every ``interval`` seconds the thread looks the job up in ``queue_ref``
    and passes an info message with its 1-based position to ``job.callback``.
    It ends as soon as ``job.event`` is set or the job is no longer queued.

    Returns:
        The started ``threading.Thread``.
    """

    def _notifier():
        while True:
            # wait() doubles as the sleep; True means the job has finished.
            if job.event.wait(interval):
                return
            # Copy the pending jobs while holding the queue's own lock.
            with queue_ref.mutex:
                waiting = list(queue_ref.queue)
            if job not in waiting:
                # Already taken off the queue: no further updates needed.
                return
            rank = waiting.index(job) + 1  # 1-based
            job.callback({"info": f"U bent #{rank} in de wachtrij. Even geduld…" })

    notifier_thread = threading.Thread(target=_notifier, daemon=True)
    notifier_thread.start()
    return notifier_thread