mistral-api/queue_helper.py
2025-11-06 14:49:56 +01:00

97 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -------------------------------------------------------------
# queue_helper.py - minimalist thread-based queue manager
# -------------------------------------------------------------
import logging
import threading, queue, uuid, time
from typing import Callable, Any, Dict
# ------------------------------------------------------------------
# Configuration - adjust as needed
# ------------------------------------------------------------------
USER_MAX_QUEUE = 20 # max. number of waiting user jobs
AGENT_MAX_QUEUE = 50 # max. number of waiting agent jobs ("silent" - presumably no position updates; confirm)
UPDATE_INTERVAL = 10.0 # seconds between "position update" messages
WORKER_TIMEOUT = 30.0 # max. time one inference may take (presumably seconds); NOTE(review): not referenced by the code visible here
# ------------------------------------------------------------------
class _Job:
    """One queued inference request together with its completion state.

    A job finishes exactly once: the first call to ``set_success`` or
    ``set_error`` records the outcome, sets ``event`` and invokes
    ``callback``. Any later completion attempt is ignored, so a job can
    never carry both a result and an error and the callback never receives
    a second, contradictory final message.
    """

    __slots__ = ("job_id", "payload", "callback", "created_at", "event", "result", "error")

    def __init__(self, payload: Dict, callback: Callable[[Dict], None]):
        """Create a pending job.

        Args:
            payload: Request data handed unchanged to the inference function.
            callback: Called with the answer dict on success, or with
                ``{"error": <message>}`` on failure.
        """
        self.job_id = str(uuid.uuid4())   # random unique id returned to the submitter
        self.payload = payload
        self.callback = callback
        self.created_at = time.time()     # wall-clock creation time (epoch seconds)
        self.event = threading.Event()    # set as soon as the job has finished
        self.result = None
        self.error = None

    def set_success(self, answer: Dict):
        """Record ``answer`` as the result and notify the callback.

        Does nothing if the job has already finished. An exception raised by
        the callback propagates to the caller; the job still counts as
        finished because ``event`` is set before the callback runs.
        """
        if self.event.is_set():
            return
        self.result = answer
        self.event.set()
        self.callback(answer)

    def set_error(self, exc: Exception):
        """Record ``str(exc)`` as the error and notify the callback.

        Does nothing if the job has already finished. The callback receives
        ``{"error": <message>}``; an exception it raises propagates to the
        caller.
        """
        if self.event.is_set():
            return
        self.error = str(exc)
        self.event.set()
        self.callback({"error": self.error})
class QueueManager:
    """Manage a user queue and an agent queue served by one worker thread.

    The worker always takes a waiting user job before an agent job and runs
    the inference function on one job at a time. Both queues are bounded;
    enqueueing into a full queue raises ``RuntimeError``.
    """

    def __init__(self, model_infer_fn: Callable[[Dict], Dict]):
        """Create the queues and start the daemon worker thread.

        Args:
            model_infer_fn: Called with a job's payload; its return value is
                delivered to the job's callback.
        """
        self._infer_fn = model_infer_fn
        self._user_q = queue.Queue(maxsize=USER_MAX_QUEUE)
        self._agent_q = queue.Queue(maxsize=AGENT_MAX_QUEUE)
        self._shutdown = threading.Event()
        self._worker = threading.Thread(
            target=self._run_worker,
            daemon=True,
            name="LLMworker",
        )
        self._worker.start()

    # ---------- public API ----------
    def enqueue_user(self, payload: Dict, progress_cb: Callable[[Dict], None]) -> tuple[str, int]:
        """Queue a user job.

        Args:
            payload: Request data for the inference function.
            progress_cb: Callback stored on the job; it receives the final
                answer or an ``{"error": ...}`` dict.

        Returns:
            ``(job_id, position)``. ``position`` is the user-queue size read
            right after insertion, so it is approximate: the worker may
            already have removed jobs in the meantime.

        Raises:
            RuntimeError: The user queue is full.
        """
        job = _Job(payload, progress_cb)
        try:
            self._user_q.put_nowait(job)
        except queue.Full:
            raise RuntimeError(f"Userqueue vol (≥{USER_MAX_QUEUE})") from None
        position = self._user_q.qsize()
        return job.job_id, position

    def enqueue_agent(self, payload: Dict, progress_cb: Callable[[Dict], None]) -> str:
        """Queue an agent job and return its job id.

        Raises:
            RuntimeError: The agent queue is full.
        """
        job = _Job(payload, progress_cb)
        try:
            self._agent_q.put_nowait(job)
        except queue.Full:
            raise RuntimeError(f"Agentqueue vol (≥{AGENT_MAX_QUEUE})") from None
        return job.job_id

    # ---------- worker ----------
    def _run_worker(self):
        """Worker loop: serve jobs until shutdown is requested."""
        while not self._shutdown.is_set():
            # User jobs take priority; the agent queue is only consulted
            # when no user job is waiting.
            job = self._pop_job(self._user_q) or self._pop_job(self._agent_q)
            if job is None:
                time.sleep(0.1)  # idle poll interval
                continue
            # Only the inference call decides between success and error.
            # Delivery happens outside this try so that a failing callback
            # is not misreported to the client as an inference error.
            try:
                answer = self._infer_fn(job.payload)
            except Exception as exc:
                self._deliver(job.set_error, exc)
            else:
                self._deliver(job.set_success, answer)

    @staticmethod
    def _deliver(notify: Callable[[Any], None], value: Any) -> None:
        """Call ``notify(value)``; log instead of raising if it fails.

        A broken client callback must not terminate the single worker
        thread, otherwise no queued job would ever be served again.
        """
        try:
            notify(value)
        except Exception:
            logging.getLogger(__name__).exception(
                "Job callback raised while delivering a result"
            )

    def _pop_job(self, q: queue.Queue):
        """Return the next job from ``q`` without blocking, or ``None`` if empty."""
        try:
            return q.get_nowait()
        except queue.Empty:
            return None

    def stop(self):
        """Request shutdown and wait up to 5 seconds for the worker to exit.

        Jobs still waiting in the queues are left unprocessed.
        """
        self._shutdown.set()
        self._worker.join(timeout=5)
def start_position_notifier(job: _Job,
                            queue_ref: queue.Queue,
                            interval: float = UPDATE_INTERVAL):
    """Start a daemon thread that reports the job's queue position.

    Every ``interval`` seconds the thread sends
    ``{"info": "U bent #<pos> in de wachtrij. ..."}`` to ``job.callback``,
    where ``<pos>`` is the job's 1-based position in ``queue_ref``. The
    thread ends as soon as the job has finished or is no longer in the
    queue (for example because the worker has picked it up).

    Args:
        job: The queued job to report on.
        queue_ref: The queue the job was put into.
        interval: Seconds between two position messages.

    Returns:
        The started ``threading.Thread``.
    """
    def _notifier():
        while not job.event.is_set():
            # Copy the queue's underlying deque while holding the queue's
            # own lock, so the snapshot cannot race with put()/get().
            with queue_ref.mutex:
                waiting = list(queue_ref.queue)
            try:
                pos = waiting.index(job) + 1  # 1-based
            except ValueError:
                # Job is no longer waiting in the queue: nothing to report.
                break
            job.callback({"info": f"U bent #{pos} in de wachtrij. Even geduld…" })
            # Wait on the event instead of sleeping, so the thread exits
            # immediately when the job finishes rather than after a full
            # interval.
            if job.event.wait(interval):
                break
    t = threading.Thread(target=_notifier, daemon=True)
    t.start()
    return t