"""QueueGate proxy: queue and sticky-route OpenAI-style chat completions
across several LLM upstreams.

Each upstream URL gets one Worker draining two FIFO queues — interactive
user chats first, background agent calls second.  A conversation sticks to
the worker that served it last (header-based affinity with a TTL), so the
upstream's KV cache stays warm.  Tool-call deltas are captured and, in this
proxy-only MVP, suppressed with an explanatory message.
"""

import asyncio
import hashlib
import json
import os
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import httpx
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse

# Finished jobs stay visible to /v1/jobs polling for this long (seconds),
# then get pruned so the in-memory job table cannot grow without bound.
JOB_RETENTION_SEC = 3600


def env_bool(key: str, default: bool = False) -> bool:
    """Read *key* from the environment as a boolean ("1"/"true"/"yes"/"on")."""
    v = os.getenv(key)
    if v is None:
        return default
    return v.strip().lower() in {"1", "true", "yes", "on"}


def env_int(key: str, default: int) -> int:
    """Read *key* from the environment as an int; *default* on missing/bad value."""
    v = os.getenv(key)
    if v is None:
        return default
    try:
        return int(v)
    except ValueError:
        return default


def now_ts() -> float:
    """Current wall-clock time as a UNIX timestamp (seconds, float)."""
    return time.time()


def job_id() -> str:
    """Opaque unique id for a queued job."""
    return uuid.uuid4().hex


def sha1(text: str) -> str:
    """Hex SHA-1 of *text*.

    Non-cryptographic use only: stable hashing of request content to derive a
    fallback thread key when no conversation-id header is present.
    """
    return hashlib.sha1(text.encode("utf-8", errors="ignore")).hexdigest()


@dataclass
class ProxyConfig:
    """Static proxy configuration, loaded once from the environment."""

    upstreams: List[str]  # one Worker per upstream URL
    sticky_header: str = "X-Chat-Id"  # header carrying the conversation id
    affinity_ttl_sec: int = 60  # how long a thread sticks to its worker
    queue_notify_user: str = "auto"  # auto|always|never
    queue_notify_min_ms: int = 1200  # queued this long before notifying the user
    read_timeout_sec: int = 3600  # httpx timeout towards upstreams
    toolserver_url: Optional[str] = None  # for later
    text_toolcall_detect: bool = False  # scrub textual "[TOOL_CALLS]" output


def load_config() -> ProxyConfig:
    """Build a ProxyConfig from environment variables.

    Raises:
        RuntimeError: if LLM_UPSTREAMS is unset or empty.
    """
    ups = (os.getenv("LLM_UPSTREAMS") or "").strip()
    if not ups:
        raise RuntimeError("LLM_UPSTREAMS is required (comma-separated URLs)")
    upstreams = [u.strip() for u in ups.split(",") if u.strip()]
    return ProxyConfig(
        upstreams=upstreams,
        sticky_header=(os.getenv("STICKY_HEADER") or "X-Chat-Id").strip() or "X-Chat-Id",
        affinity_ttl_sec=env_int("AFFINITY_TTL_SEC", 60),
        queue_notify_user=(os.getenv("QUEUE_NOTIFY_USER") or "auto").strip().lower(),
        queue_notify_min_ms=env_int("QUEUE_NOTIFY_MIN_MS", 1200),
        read_timeout_sec=env_int("LLM_READ_TIMEOUT", 3600),
        toolserver_url=(os.getenv("TOOLSERVER_URL") or "").strip() or None,
        text_toolcall_detect=env_bool("TEXT_TOOLCALL_DETECT", False),
    )


@dataclass
class Job:
    """One queued chat-completion request and its delivery channels."""

    job_id: str
    created_ts: float
    kind: str  # user_chat | agent_call
    stream: bool
    body: Dict[str, Any]
    headers: Dict[str, str]
    thread_key: str
    assigned_worker: int
    status: str = "queued"  # queued|running|done|error
    error: Optional[str] = None
    result: Optional[Dict[str, Any]] = None
    # Non-stream callers await done_fut; stream callers drain stream_q until
    # they see the None sentinel.
    done_fut: asyncio.Future = field(default_factory=asyncio.Future)
    stream_q: asyncio.Queue = field(default_factory=asyncio.Queue)
    saw_any_output: bool = False


class Worker:
    """Single consumer loop bound to one upstream URL.

    Drains its user queue with strict priority over its agent queue, running
    exactly one job at a time.
    """

    def __init__(self, idx: int, upstream_url: str, client: httpx.AsyncClient):
        self.idx = idx
        self.upstream_url = upstream_url
        self.client = client
        self.user_q: asyncio.Queue[Job] = asyncio.Queue()
        self.agent_q: asyncio.Queue[Job] = asyncio.Queue()
        self.current_job_id: Optional[str] = None
        self.task: Optional[asyncio.Task] = None

    def pending_count(self) -> int:
        """Queued jobs plus the one currently running (load metric for routing)."""
        return self.user_q.qsize() + self.agent_q.qsize() + (1 if self.current_job_id else 0)

    async def start(self, state: "ProxyState") -> None:
        """Spawn the consumer loop as a background task."""
        self.task = asyncio.create_task(self._run(state), name=f"worker-{self.idx}")

    async def _run(self, state: "ProxyState") -> None:
        """Consume jobs forever; user jobs preempt queued agent jobs.

        NOTE: the empty-check-then-get is safe because this loop is the only
        consumer of both queues.
        """
        while True:
            if not self.user_q.empty():
                job = await self.user_q.get()
            else:
                job = await self.agent_q.get()
            self.current_job_id = job.job_id
            job.status = "running"
            try:
                if job.stream:
                    await state.handle_stream_job(job, self)
                else:
                    await state.handle_non_stream_job(job, self)
            except Exception as e:
                job.status = "error"
                job.error = str(e)
                if not job.done_fut.done():
                    job.done_fut.set_exception(e)
                # Unblock any stream consumer as well.
                await job.stream_q.put(None)
            finally:
                self.current_job_id = None
                try:
                    if job.kind == "user_chat":
                        self.user_q.task_done()
                    else:
                        self.agent_q.task_done()
                except Exception:
                    pass  # task_done() bookkeeping must never kill the loop


def get_header_any(headers: Dict[str, str], names: List[str]) -> Optional[str]:
    """Return the first non-empty header value among *names*, else None."""
    for n in names:
        v = headers.get(n)
        if v:
            return v
    return None


def infer_kind(headers: Dict[str, str]) -> str:
    """Classify a request as background agent traffic or an interactive chat."""
    jk = (headers.get("X-Job-Kind") or "").strip().lower()
    if jk in {"agent", "agent_call", "repo_agent"}:
        return "agent_call"
    return "user_chat"


def infer_thread_key(cfg: ProxyConfig, headers: Dict[str, str], body: Dict[str, Any]) -> str:
    """Derive the sticky-routing key for a request.

    Prefers an explicit conversation-id header; otherwise hashes a stable
    slice of the body (model + first two messages + user) so retries of the
    same conversation still land on the same worker.
    """
    v = get_header_any(
        headers,
        [cfg.sticky_header, "X-OpenWebUI-Chat-Id", "X-Chat-Id", "X-Conversation-Id"],
    )
    if v:
        return v
    seed = {
        "model": (body.get("model") or "").strip(),
        "messages": (body.get("messages") or [])[:2],
        "user": body.get("user"),
    }
    try:
        txt = json.dumps(seed, sort_keys=True, ensure_ascii=False)
    except Exception:
        txt = str(seed)  # non-JSON-serializable body: degrade to repr
    return "h:" + sha1(txt)


def sse_pack(obj: Dict[str, Any]) -> bytes:
    """Encode *obj* as one server-sent-event data frame."""
    return ("data: " + json.dumps(obj, ensure_ascii=False) + "\n\n").encode("utf-8")


def sse_done() -> bytes:
    """The OpenAI-style end-of-stream frame."""
    return b"data: [DONE]\n\n"


def make_chunk(job_id: str, model: str, delta: Dict[str, Any],
               finish_reason: Optional[str] = None) -> Dict[str, Any]:
    """Build a minimal OpenAI chat.completion.chunk payload."""
    return {
        "id": job_id,
        "object": "chat.completion.chunk",
        "created": int(now_ts()),
        "model": model,
        "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
    }


def looks_like_toolcalls_text(txt: str) -> bool:
    """Heuristic: does this content string look like a textual tool-call dump?"""
    s = (txt or "").lstrip()
    return s.startswith("[TOOL_CALLS]") or s.startswith("{\"tool_calls\"")


def strip_toolcalls_text(txt: str) -> str:
    """Replacement text shown instead of a textual tool-call dump."""
    return "(tool-call output suppressed; tools not enabled in proxy-only MVP)"


def capture_tool_calls_from_delta(delta: Dict[str, Any], acc: List[Dict[str, Any]]) -> None:
    """Append any tool_calls entries found in a streaming delta to *acc*."""
    tcs = delta.get("tool_calls")
    if isinstance(tcs, list):
        acc.extend(tcs)


class ProxyState:
    """All mutable proxy state: workers, sticky-affinity map and job table.

    (Previously split into a base class plus a same-named subclass; merged
    into one class — identical behavior, one definition.)
    """

    def __init__(self, cfg: ProxyConfig):
        self.cfg = cfg
        self.http = httpx.AsyncClient(timeout=httpx.Timeout(cfg.read_timeout_sec))
        self.workers: List[Worker] = [Worker(i, u, self.http) for i, u in enumerate(cfg.upstreams)]
        # thread_key -> (worker index, last-used timestamp)
        self.affinity: Dict[str, Tuple[int, float]] = {}
        self.jobs: Dict[str, Job] = {}

    async def start(self) -> None:
        """Start one consumer task per worker."""
        for w in self.workers:
            await w.start(self)

    async def close(self) -> None:
        """Release the shared upstream HTTP client."""
        await self.http.aclose()

    def pick_worker(self, thread_key: str) -> int:
        """Return the worker index for *thread_key*: sticky if fresh, else least loaded."""
        now = now_ts()
        sticky = self.affinity.get(thread_key)
        if sticky:
            widx, last = sticky
            if now - last <= self.cfg.affinity_ttl_sec and 0 <= widx < len(self.workers):
                self.affinity[thread_key] = (widx, now)  # refresh the TTL
                return widx
            # Expired (or index stale after a config change): drop and rebalance.
            del self.affinity[thread_key]
        # Ties resolve to the lowest index, same as the original linear scan.
        best = min(self.workers, key=lambda w: w.pending_count()).idx
        self.affinity[thread_key] = (best, now)
        return best

    def _prune_finished_jobs(self) -> None:
        """Drop done/error jobs past their retention window (prevents unbounded growth)."""
        cutoff = now_ts() - JOB_RETENTION_SEC
        stale = [
            jid for jid, j in self.jobs.items()
            if j.status in {"done", "error"} and j.created_ts < cutoff
        ]
        for jid in stale:
            self.jobs.pop(jid, None)

    def enqueue(self, job: Job) -> None:
        """Register *job* and push it onto its worker's priority-appropriate queue."""
        self._prune_finished_jobs()
        w = self.workers[job.assigned_worker]
        if job.kind == "user_chat":
            w.user_q.put_nowait(job)
        else:
            w.agent_q.put_nowait(job)
        self.jobs[job.job_id] = job

    def queue_position(self, job: Job) -> int:
        """Rough queue-depth estimate for *job* (size of its kind's queue)."""
        w = self.workers[job.assigned_worker]
        if job.kind == "user_chat":
            return w.user_q.qsize()
        return w.agent_q.qsize()

    async def handle_non_stream_job(self, job: Job, worker: Worker) -> None:
        """Forward a non-streaming request upstream and resolve job.done_fut.

        Raises (to the worker loop) on HTTP errors; the worker converts that
        into an exception on done_fut.
        """
        body = dict(job.body)
        body["stream"] = False  # force upstream to answer in one shot
        r = await self.http.post(worker.upstream_url, json=body)
        r.raise_for_status()
        data = r.json()
        if self.cfg.text_toolcall_detect:
            try:
                msg = data.get("choices", [{}])[0].get("message", {})
                content = msg.get("content")
                if isinstance(content, str) and looks_like_toolcalls_text(content):
                    msg["content"] = strip_toolcalls_text(content)
            except Exception:
                pass  # best-effort scrub; malformed payloads pass through as-is
        job.result = data
        job.status = "done"
        if not job.done_fut.done():
            job.done_fut.set_result(data)

    async def handle_stream_job(self, job: Job, worker: Worker) -> None:
        """Proxy a streaming request, relaying SSE frames into job.stream_q.

        Tool-call deltas are captured and suppressed; if the model produced
        ONLY tool calls, a placeholder message is emitted instead.  Always
        terminates the queue with a None sentinel so the HTTP generator in
        chat_completions() can finish (the original only sent the sentinel on
        the error path, leaving the response generator blocked forever).
        """
        body = dict(job.body)
        body["stream"] = True
        model = (body.get("model") or "").strip() or "unknown"
        tool_calls: List[Dict[str, Any]] = []
        async with self.http.stream("POST", worker.upstream_url, json=body) as r:
            r.raise_for_status()
            async for line in r.aiter_lines():
                if not line:
                    continue
                if not line.startswith("data:"):
                    continue
                payload = line[len("data:"):].strip()
                if payload == "[DONE]":
                    break
                try:
                    obj = json.loads(payload)
                except Exception:
                    # non-json chunk; pass through
                    await job.stream_q.put((line + "\n\n").encode("utf-8"))
                    job.saw_any_output = True
                    continue
                choice0 = (obj.get("choices") or [{}])[0]
                delta = choice0.get("delta") or {}
                if "tool_calls" in delta:
                    capture_tool_calls_from_delta(delta, tool_calls)
                    continue  # swallow tool-call deltas in the MVP
                if self.cfg.text_toolcall_detect:
                    c = delta.get("content")
                    if isinstance(c, str) and looks_like_toolcalls_text(c):
                        continue
                await job.stream_q.put((line + "\n\n").encode("utf-8"))
                job.saw_any_output = True
        if tool_calls and not job.saw_any_output:
            msg = "(tool-call requested but tools are not enabled yet in the proxy-only MVP)"
            await job.stream_q.put(sse_pack(make_chunk(job.job_id, model, {"role": "assistant", "content": msg})))
        await job.stream_q.put(sse_pack(make_chunk(job.job_id, model, {}, finish_reason="stop")))
        await job.stream_q.put(sse_done())
        await job.stream_q.put(None)  # sentinel: lets the response generator exit
        job.status = "done"
        if not job.done_fut.done():
            job.done_fut.set_result({"status": "streamed"})


app = FastAPI(title="QueueGate Proxy", version="0.1.0")
CFG: Optional[ProxyConfig] = None
STATE: Optional[ProxyState] = None


@app.on_event("startup")
async def _startup() -> None:
    """Load config and spin up the worker loops."""
    global CFG, STATE
    CFG = load_config()
    STATE = ProxyState(CFG)
    await STATE.start()


@app.on_event("shutdown")
async def _shutdown() -> None:
    """Close the shared upstream HTTP client."""
    global STATE
    if STATE:
        await STATE.close()


@app.get("/healthz")
async def healthz() -> Dict[str, Any]:
    """Liveness/load snapshot: per-worker pending counts and busy flags."""
    if not STATE:
        raise HTTPException(status_code=503, detail="not ready")
    return {
        "ok": True,
        "upstreams": len(STATE.workers),
        "workers": [
            {
                "id": w.idx,
                "pending": w.pending_count(),
                "busy": bool(w.current_job_id),
                "upstream": w.upstream_url,
            }
            for w in STATE.workers
        ],
    }


@app.get("/v1/jobs/{job_id}")
async def job_status(job_id: str) -> Dict[str, Any]:
    """Poll the status of a previously submitted job."""
    if not STATE:
        raise HTTPException(status_code=503, detail="not ready")
    job = STATE.jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="unknown job")
    return {
        "job_id": job.job_id,
        "status": job.status,
        "kind": job.kind,
        "created_ts": job.created_ts,
        "worker": job.assigned_worker,
        "queue_position_est": STATE.queue_position(job) if job.status == "queued" else 0,
        "error": job.error,
    }


def _require_state() -> ProxyState:
    """Return the live ProxyState or fail with 503 while still starting up."""
    if not STATE:
        raise HTTPException(status_code=503, detail="not ready")
    return STATE


def _copy_headers(req: Request) -> Dict[str, str]:
    """Copy only the headers used for sticky routing and job-kind detection.

    Starlette normalises incoming header names to lowercase, so matching must
    be case-insensitive; values are stored under the canonical spellings that
    infer_kind()/infer_thread_key() look up.  (The original compared the raw
    lowercase names against canonical-case ones and never matched.)
    """
    keep = {
        "x-job-kind": "X-Job-Kind",
        "x-chat-id": "X-Chat-Id",
        "x-conversation-id": "X-Conversation-Id",
        "x-openwebui-chat-id": "X-OpenWebUI-Chat-Id",
    }
    if CFG:
        keep.setdefault(CFG.sticky_header.lower(), CFG.sticky_header)
    out: Dict[str, str] = {}
    for k, v in req.headers.items():
        canonical = keep.get(k.lower())
        if canonical:
            out[canonical] = v
    return out


@app.post("/v1/chat/completions")
async def chat_completions(request: Request, body: Dict[str, Any] = Body(...)):
    """OpenAI-compatible entry point: queue the request, then answer or stream.

    Non-stream requests block until the worker resolves the job; stream
    requests relay SSE frames from the job's queue, optionally prefixed with
    a "still queued" notice for interactive users.
    """
    state = _require_state()
    cfg = state.cfg
    stream = bool(body.get("stream", False))
    headers = _copy_headers(request)
    kind = infer_kind(headers)
    thread_key = infer_thread_key(cfg, headers, body)
    worker_idx = state.pick_worker(thread_key)
    jid = job_id()
    job = Job(
        job_id=jid,
        created_ts=now_ts(),
        kind=kind,
        stream=stream,
        body=body,
        headers=headers,
        thread_key=thread_key,
        assigned_worker=worker_idx,
    )
    state.enqueue(job)

    if not stream:
        # wait for result
        try:
            data = await job.done_fut
        except Exception as e:
            raise HTTPException(status_code=502, detail=f"upstream error: {e}")
        return JSONResponse(content=data)

    # streaming
    async def gen() -> Any:
        # Optional: queue notice (user jobs only)
        if job.kind == "user_chat" and cfg.queue_notify_user != "never":
            t0 = now_ts()
            # wait until running or until threshold
            while job.status == "queued" and (now_ts() - t0) * 1000 < cfg.queue_notify_min_ms:
                await asyncio.sleep(0.05)
            if job.status == "queued" and cfg.queue_notify_user in {"auto", "always"}:
                pos = state.queue_position(job)
                model = (body.get("model") or "unknown").strip() or "unknown"
                msg = f"⏳ In wachtrij (positie ~{pos})…"
                yield sse_pack(make_chunk(job.job_id, model, {"role": "assistant", "content": msg}))
        while True:
            item = await job.stream_q.get()
            if item is None:  # sentinel from the worker: stream finished or failed
                break
            yield item

    return StreamingResponse(gen(), media_type="text/event-stream")