From bdc56693a9c710454b2a91596dc2d85812c0109e Mon Sep 17 00:00:00 2001 From: admin Date: Thu, 6 Nov 2025 14:04:06 +0100 Subject: [PATCH] fix stream cap of 1024 tokens --- app.py | 3469 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 3469 insertions(+) create mode 100644 app.py diff --git a/app.py b/app.py new file mode 100644 index 0000000..d44c1ef --- /dev/null +++ b/app.py @@ -0,0 +1,3469 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +FastAPI bridge met: +- RAG (Chroma) index & query (client-side embeddings, http óf local) +- Repo cloning/updating (git) +- LLM-bridge (OpenAI-compatible /v1/chat/completions) +- Repo agent endpoints (injectie van helpers in agent_repo.py) +""" + +from __future__ import annotations +import contextlib +from contextlib import contextmanager +import os, re, json, time, uuid, hashlib, logging, asyncio, fnmatch, threading +from dataclasses import dataclass +from typing import List, Dict, Optional, Union, Any +from pathlib import Path +from io import BytesIO + +import requests +import httpx +import chromadb +import git +import base64 + +from fastapi import FastAPI, APIRouter, UploadFile, File, Form, Request, HTTPException, Body +from fastapi.responses import JSONResponse, StreamingResponse, PlainTextResponse +from fastapi.middleware.cors import CORSMiddleware +from fastapi.openapi.utils import get_openapi +from fastapi.routing import APIRoute +from starlette.concurrency import run_in_threadpool +from pydantic import BaseModel + + + +# Optionele libs voor tekst-extractie +try: + import PyPDF2 +except Exception: + PyPDF2 = None +try: + import docx # python-docx +except Exception: + docx = None +try: + import pandas as pd +except Exception: + pd = None +try: + from pptx import Presentation +except Exception: + Presentation = None + +# (optioneel) BM25 voor hybride retrieval +try: + from rank_bm25 import BM25Okapi +except Exception: + BM25Okapi = None + +# --- BM25 fallback registry (per repo) --- 
+_BM25_BY_REPO: dict[str, tuple[object, list[dict]]] = {} # repo_key -> (bm25, docs) +def _bm25_tok(s: str) -> list[str]: + return re.findall(r"[A-Za-z0-9_]+", s.lower()) + + +# --- Extra optional libs voor audio/vision/images --- +try: + import cairosvg +except Exception: + cairosvg = None + +import tempfile, subprocess # voor audio + +# STT (Whisper-compatible via faster-whisper) — optioneel +_STT_MODEL = None +STT_MODEL_NAME = os.getenv("STT_MODEL", "small") +STT_DEVICE = os.getenv("STT_DEVICE", "auto") # "auto" | "cuda" | "cpu" + +# TTS (piper) — optioneel +PIPER_BIN = os.getenv("PIPER_BIN", "/usr/bin/piper") +PIPER_VOICE = os.getenv("PIPER_VOICE", "") + +RAG_LLM_RERANK = os.getenv("RAG_LLM_RERANK", "0").lower() in ("1","true","yes") + +# Lazy LLM import & autodetect +_openai = None +_mistral = None + + + +# Token utility / conversatie window (bestaat in je project) +try: + from windowing_utils import ( + derive_thread_id, SUMMARY_STORE, ConversationWindow, + approx_token_count, count_message_tokens + ) +except Exception: + # Fallbacks zodat dit bestand standalone blijft werken + SUMMARY_STORE = {} + def approx_token_count(s: str) -> int: + return max(1, len(s) // 4) + def count_message_tokens(messages: List[dict]) -> int: + return sum(approx_token_count(m.get("content","")) for m in messages) + def derive_thread_id(messages: List[dict]) -> str: + payload = (messages[0].get("content","") if messages else "") + "|".join(m.get("role","") for m in messages) + return hashlib.sha256(payload.encode("utf-8", errors="ignore")).hexdigest()[:16] + class ConversationWindow: + def __init__(self, *a, **k): pass + +# Queue helper (optioneel aanwezig in je project) +try: + from queue_helper import QueueManager, start_position_notifier, _Job + from queue_helper import USER_MAX_QUEUE, AGENT_MAX_QUEUE, UPDATE_INTERVAL, WORKER_TIMEOUT +except Exception: + QueueManager = None + +# Smart_rag import +from smart_rag import enrich_intent, expand_queries, hybrid_retrieve, 
assemble_context + +# Repo-agent (wij injecteren functies hieronder) +from agent_repo import initialize_agent, handle_repo_agent, repo_qa_answer, rag_query_internal_fn + +try: + from agent_repo import smart_chunk_text # al aanwezig in jouw agent_repo +except Exception: + def smart_chunk_text(text: str, path_hint: str, target_chars: int = 1800, + hard_max: int = 2600, min_chunk: int = 800): + # simpele fallback + chunks = [] + i, n = 0, len(text) + step = max(1, target_chars - 200) + while i < n: + chunks.append(text[i:i+target_chars]) + i += step + return chunks + + +def _build_tools_system_prompt(tools: list) -> str: + lines = [ + "You can call functions. When a function is needed, answer ONLY with a JSON object:", + '{"tool_calls":[{"name":"","arguments":{...}}]}', + "No prose. Arguments MUST be valid JSON for the function schema." + ] + for t in tools: + fn = t.get("function", {}) + lines.append(f"- {fn.get('name')}: {fn.get('description','')}") + return "\n".join(lines) + +def _extract_tool_calls_from_text(txt: str): + # tolerant: pak eerste JSON object (ook als het in ```json staat) + m = re.search(r"\{[\s\S]*\}", txt or "") + if not m: + return [] + try: + obj = json.loads(m.group(0)) + except Exception: + return [] + tc = obj.get("tool_calls") or [] + out = [] + for c in tc: + name = (c or {}).get("name") + args = (c or {}).get("arguments") or {} + if isinstance(args, str): + try: + args = json.loads(args) + except Exception: + args = {} + if name: + out.append({ + "id": f"call_{uuid.uuid4().hex[:8]}", + "type": "function", + "function": { + "name": name, + "arguments": json.dumps(args, ensure_ascii=False) + } + }) + return out + +# ----------------------------------------------------------------------------- +# App & logging +# ----------------------------------------------------------------------------- +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("app") + +def _unique_id(route: APIRoute): + # unieke operation_id op basis van naam, 
@app.on_event("shutdown")
async def _shutdown():
    """Close both shared httpx clients on application shutdown (best effort)."""
    for attr in ("HTTPX", "HTTPX_PROXY"):
        try:
            await getattr(app.state, attr).aclose()
        except Exception:
            # client may never have been created or is already closed — ignore
            pass
%s", response.status_code) + return response + +# ----------------------------------------------------------------------------- +# Config +# ----------------------------------------------------------------------------- +MISTRAL_MODE = os.getenv("MISTRAL_MODE", "v1").lower() +LLM_URL = os.getenv("LLM_URL", "http://localhost:8000/v1/chat/completions").strip() +RAW_URL = os.getenv("MISTRAL_URL_RAW", "http://host.docker.internal:8000/completion").strip() +LLM_CONNECT_TIMEOUT = float(os.getenv("LLM_CONNECT_TIMEOUT", "10")) +LLM_READ_TIMEOUT = float(os.getenv("LLM_READ_TIMEOUT", "1200")) + +_UPSTREAM_URLS = [u.strip() for u in os.getenv("LLM_UPSTREAMS","").split(",") if u.strip()] + +# ==== Meilisearch (optioneel) ==== +MEILI_URL = os.getenv("MEILI_URL", "http://localhost:7700").rstrip("/") +MEILI_API_KEY = os.getenv("MEILI_API_KEY", "0xipOmfgi_zMgdFplSdv7L8mlx0RPMQCNxVTNJc54lQ") +MEILI_INDEX = os.getenv("MEILI_INDEX", "code_chunks") +MEILI_ENABLED = bool(MEILI_URL) + +# Repo summaries (cache on demand) +_SUMMARY_DIR = os.path.join("/rag_db", "repo_summaries") +os.makedirs(_SUMMARY_DIR, exist_ok=True) + +@dataclass +class _Upstream: + url: str + active: int = 0 + ok: bool = True + +_UPS = [_Upstream(u) for u in _UPSTREAM_URLS] if _UPSTREAM_URLS else [] + +def _pick_upstream_sticky(key: str) -> _Upstream | None: + if not _UPS: + return None + try: + h = int(hashlib.sha1(key.encode("utf-8")).hexdigest(), 16) + except Exception: + h = 0 + idx = h % len(_UPS) + cand = _UPS[idx] + if cand.ok: + cand.active += 1 + return cand + ok_list = [u for u in _UPS if u.ok] + best = min(ok_list or _UPS, key=lambda x: x.active) + best.active += 1 + return best + + +def _pick_upstream() -> _Upstream | None: + if not _UPS: + return None + # kies de minst-belaste die ok is, anders toch de minst-belaste + ok_list = [u for u in _UPS if u.ok] + cand = min(ok_list or _UPS, key=lambda x: x.active) + cand.active += 1 + return cand + +def _release_upstream(u: _Upstream | None, bad: bool = False): + 
if not u: + return + u.active = max(0, u.active - 1) + if bad: + u.ok = False + # simpele cool-off: na 2 sec weer ok + asyncio.create_task(_mark_ok_later(u, 10.0)) + +async def _mark_ok_later(u: _Upstream, delay: float): + await asyncio.sleep(delay) + u.ok = True + +SYSTEM_PROMPT = ( + "Je bent een expert programmeringsassistent. Je geeft accurate, specifieke antwoorden zonder hallucinaties.\n" + "Voor code base analyse:\n" + "1. Geef eerst een samenvatting van de functionaliteit\n" + "2. Identificeer mogelijke problemen of verbeterpunten\n" + "3. Geef concrete aanbevelingen voor correcties\n" + "4. Voor verbeterde versies: geef eerst de toelichting, dan alleen het codeblok" +) + +# ==== Chroma configuratie (local/http) ==== +CHROMA_MODE = os.getenv("CHROMA_MODE", "local").lower() # "local" | "http" +CHROMA_PATH = os.getenv("CHROMA_PATH", "/rag_db") +CHROMA_HOST = os.getenv("CHROMA_HOST", "chroma") +CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8005")) + +# ==== Celery (optioneel, voor async indexing) ==== +CELERY_ENABLED = os.getenv("CELERY_ENABLED", "0").lower() in ("1", "true", "yes") +celery_app = None +if CELERY_ENABLED: + try: + from celery import Celery + celery_app = Celery( + "agent_tasks", + broker=os.getenv("AMQP_URL", "amqp://guest:guest@rabbitmq:5672//"), + backend=os.getenv("CELERY_BACKEND", "redis://redis:6379/0"), + ) + celery_app.conf.update( + task_acks_late=True, + worker_prefetch_multiplier=1, + task_time_limit=60*60, + ) + except Exception as e: + logger.warning("Celery init failed (fallback naar sync): %s", e) + celery_app = None + +# Git / repos +GITEA_URL = os.environ.get("GITEA_URL", "http://localhost:3080").rstrip("/") +REPO_PATH = os.environ.get("REPO_PATH", "/tmp/repos") + +# ----------------------------------------------------------------------------- +# Models (Pydantic) +# ----------------------------------------------------------------------------- +class ChatMessage(BaseModel): + role: str + content: str + +class 
ChatRequest(BaseModel): + messages: List[ChatMessage] + +class RepoQARequest(BaseModel): + repo_hint: str + question: str + branch: str = "main" + n_ctx: int = 8 + +# ----------------------------------------------------------------------------- +# Embeddings (SentenceTransformers / ONNX / Default) +# ----------------------------------------------------------------------------- +from chromadb.api.types import EmbeddingFunction, Documents, Embeddings +try: + from sentence_transformers import SentenceTransformer +except Exception: + SentenceTransformer = None + +@dataclass +class _Embedder: + slug: str + family: str + model: object + device: str = "cpu" + + def _encode(self, texts: list[str]) -> list[list[float]]: + if hasattr(self.model, "encode"): + bs = int(os.getenv("RAG_EMBED_BATCH_SIZE", "64")) + # geen progressbar; grotere batches voor throughput + return self.model.encode( + texts, + normalize_embeddings=True, + batch_size=bs, + show_progress_bar=False + ).tolist() + return self.model(texts) + + def embed_documents(self, docs: list[str]) -> list[list[float]]: + if self.family == "e5": + docs = [f"passage: {t}" for t in docs] + return self._encode(docs) + + def embed_query(self, q: str) -> list[float]: + if self.family == "e5": + q = f"query: {q}" + return self._encode([q])[0] + +def _build_embedder() -> _Embedder: + import inspect + # voorkom tokenizer thread-oversubscription + os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") + choice = os.getenv("RAG_EMBEDDINGS", "gte-multilingual").lower().strip() + if SentenceTransformer: + mapping = { + "gte-multilingual": ("Alibaba-NLP/gte-multilingual-base", "gte", "gte-multilingual"), + "bge-small": ("BAAI/bge-small-en-v1.5", "bge", "bge-small"), + "e5-small": ("intfloat/e5-small-v2", "e5", "e5-small"), + "gte-base-en": ("thenlper/gte-base", "gte", "gte-english"), + } + if choice not in mapping: + model_name, family, slug = mapping["bge-small"] + else: + model_name, family, slug = mapping[choice] + st_kwargs = 
{"device": "cpu"} + try: + if "trust_remote_code" in inspect.signature(SentenceTransformer).parameters: + st_kwargs["trust_remote_code"] = True + model = SentenceTransformer(model_name, **st_kwargs) + # optioneel: CPU thread-telling forceren + try: + thr = int(os.getenv("RAG_TORCH_THREADS", "0")) + if thr > 0: + import torch + torch.set_num_threads(thr) + except Exception: + pass + return _Embedder(slug=slug, family=family, model=model, device="cpu") + except Exception: + pass + + # Fallback via Chroma embedding functions + from chromadb.utils import embedding_functions as ef + try: + onnx = ef.ONNXMiniLM_L6_V2() + slug, family = "onnx-minilm", "minilm" + except Exception: + onnx = ef.DefaultEmbeddingFunction() + slug, family = "default", "minilm" + + class _OnnxWrapper: + def __init__(self, fun): self.fun = fun + def __call__(self, texts): return self.fun(texts) + return _Embedder(slug=slug, family=family, model=_OnnxWrapper(onnx)) + +_EMBEDDER = _build_embedder() + +class _ChromaEF(EmbeddingFunction): + """Alleen gebruikt bij local PersistentClient (niet bij http); naam wordt geborgd.""" + def __init__(self, embedder: _Embedder): + self._embedder = embedder + def __call__(self, input: Documents) -> Embeddings: + return self._embedder.embed_documents(list(input)) + def name(self) -> str: + return f"rag-ef:{self._embedder.family}:{self._embedder.slug}" + +_CHROMA_EF = _ChromaEF(_EMBEDDER) + +# === Chroma client (local of http) === +if CHROMA_MODE == "http": + _CHROMA = chromadb.HttpClient(host=CHROMA_HOST, port=CHROMA_PORT) +else: + _CHROMA = chromadb.PersistentClient(path=CHROMA_PATH) + +def _collection_versioned(base: str) -> str: + ver = os.getenv("RAG_INDEX_VERSION", "3") + return f"{base}__{_EMBEDDER.slug}__v{ver}" + +_COLLECTIONS: dict[str, any] = {} + +def _get_collection(base: str): + """Haalt collection op; bij local voegt embedding_function toe, bij http niet (embedden doen we client-side).""" + name = _collection_versioned(base) + if name not in 
_COLLECTIONS: + if CHROMA_MODE == "http": + _COLLECTIONS[name] = _CHROMA.get_or_create_collection(name=name) + else: + _COLLECTIONS[name] = _CHROMA.get_or_create_collection(name=name, embedding_function=_CHROMA_EF) + return _COLLECTIONS[name] + +def _collection_add(collection, documents: list[str], metadatas: list[dict], ids: list[str]): + """Altijd embeddings client-side meezenden — werkt voor local én http. + Voeg een lichte header toe (pad/taal/symbolen) om retrieval te verbeteren. + """ + def _symbol_hints(txt: str) -> list[str]: + hints = [] + # heel simpele, taalonafhankelijke patterns + for pat in [r"def\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(", + r"class\s+([A-Za-z_][A-Za-z0-9_]*)\b", + r"function\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(", + r"public\s+function\s+([A-Za-z_][A-Za-z0-9_]*)\s*\("]: + try: + hints += re.findall(pat, txt[:4000]) + except Exception: + pass + # uniek en klein houden + out = [] + for h in hints: + if h not in out: + out.append(h) + if len(out) >= 6: + break + return out + + augmented_docs = [] + metadatas_mod = [] + for doc, meta in zip(documents, metadatas): + path = (meta or {}).get("path", "") + ext = (Path(path).suffix.lower().lstrip(".") if path else "") or "txt" + syms = _symbol_hints(doc) + header = f"FILE:{path} | LANG:{ext} | SYMBOLS:{','.join(syms)}\n" + augmented_docs.append(header + (doc or "")) + m = dict(meta or {}) + if syms: + m["symbols"] = ",".join(syms[:8]) + metadatas_mod.append(m) + + embs = _EMBEDDER.embed_documents(augmented_docs) + collection.add(documents=augmented_docs, embeddings=embs, metadatas=metadatas_mod, ids=ids) + +# ----------------------------------------------------------------------------- +# Repository profile + file selectie +# ----------------------------------------------------------------------------- +PROFILE_EXCLUDE_DIRS = { + ".git",".npm","node_modules","vendor","storage","dist","build",".next", + "__pycache__",".venv","venv",".mypy_cache",".pytest_cache", + 
"target","bin","obj","logs","cache","temp",".cache",".idea",".vscode" +} +PROFILE_INCLUDES = { + "generic": ["*.md","*.txt","*.json","*.yml","*.yaml","*.ini","*.cfg","*.toml", + "*.py","*.php","*.js","*.ts","*.jsx","*.tsx","*.css","*.scss", + "*.html","*.htm","*.vue","*.rb","*.go","*.java","*.cs","*.blade.php"], + "laravel": ["*.php","*.blade.php","*.md","*.env","*.json"], + "node": ["*.js","*.ts","*.jsx","*.tsx","*.vue","*.md","*.json","*.css","*.scss","*.html","*.htm"], +} + +def _detect_repo_profile(root: Path) -> str: + if (root / "artisan").exists() or (root / "composer.json").exists(): + return "laravel" + if (root / "package.json").exists(): + return "node" + return "generic" + +# ----------------------------------------------------------------------------- +# Tekst-extractie helpers +# ----------------------------------------------------------------------------- +TEXT_EXTS = { + ".php",".blade.php",".vue",".js",".ts",".jsx",".tsx",".css",".scss", + ".html",".htm",".json",".md",".ini",".cfg",".yml",".yaml",".toml", + ".py",".go",".rb",".java",".cs",".txt",".env",".sh",".bat",".dockerfile", +} +BINARY_SKIP = {".png",".jpg",".jpeg",".webp",".bmp",".gif",".ico",".pdf",".zip",".gz",".tar",".7z",".rar",".woff",".woff2",".ttf",".eot",".otf"} + +def _read_text_file(p: Path) -> str: + ext = p.suffix.lower() + # Skip obvious binaries quickly + if ext in BINARY_SKIP: + # PDF → probeer tekst + if ext == ".pdf" and PyPDF2: + try: + with open(p, "rb") as f: + reader = PyPDF2.PdfReader(f) + out = [] + for page in reader.pages: + try: + out.append(page.extract_text() or "") + except Exception: + pass + return "\n".join(out).strip() + except Exception: + return "" + return "" + # DOCX + if ext == ".docx" and docx: + try: + d = docx.Document(str(p)) + return "\n".join([para.text for para in d.paragraphs]) + except Exception: + return "" + # CSV/XLSX (alleen header + paar regels) + if ext in {".csv",".tsv"} and pd: + try: + df = pd.read_csv(p, nrows=200) + return 
df.to_csv(index=False)[:20000] + except Exception: + pass + if ext in {".xlsx",".xls"} and pd: + try: + df = pd.read_excel(p, nrows=200) + return df.to_csv(index=False)[:20000] + except Exception: + pass + # PPTX + if ext == ".pptx" and Presentation: + try: + prs = Presentation(str(p)) + texts = [] + for slide in prs.slides: + for shape in slide.shapes: + if hasattr(shape, "text"): + texts.append(shape.text) + return "\n".join(texts) + except Exception: + return "" + # Default: lees als tekst + try: + return p.read_text(encoding="utf-8", errors="ignore") + except Exception: + try: + return p.read_text(encoding="latin-1", errors="ignore") + except Exception: + return "" + +def _chunk_text(text: str, chunk_chars: int = 3000, overlap: int = 400) -> List[str]: + if not text: return [] + n = len(text); i = 0; step = max(1, chunk_chars - overlap) + chunks = [] + while i < n: + chunks.append(text[i:i+chunk_chars]) + i += step + return chunks + +# ----------------------------------------------------------------------------- +# Git helper +# ----------------------------------------------------------------------------- +def get_git_repo(repo_url: str, branch: str = "main") -> str: + """ + Clone of update repo in REPO_PATH, checkout branch. + Retourneert pad als string. 
+ """ + os.makedirs(REPO_PATH, exist_ok=True) + # Unieke directory obv owner/repo of hash + name = None + try: + from urllib.parse import urlparse + u = urlparse(repo_url) + parts = [p for p in u.path.split("/") if p] + if parts: + name = parts[-1] + if name.endswith(".git"): name = name[:-4] + except Exception: + pass + if not name: + name = hashlib.sha1(repo_url.encode("utf-8", errors="ignore")).hexdigest()[:12] + local = os.path.join(REPO_PATH, name) + lock_path = f"{local}.lock" + with _file_lock(lock_path, timeout=float(os.getenv("GIT_LOCK_TIMEOUT","60"))): + if not os.path.exists(local): + logger.info("Cloning %s → %s", repo_url, local) + repo = git.Repo.clone_from(repo_url, local, depth=1) + else: + repo = git.Repo(local) + try: + repo.remote().fetch(depth=1, prune=True) + except Exception as e: + logger.warning("Fetch failed: %s", e) + # Checkout + logger.info("Checking out branch %s", branch) + try: + repo.git.checkout(branch) + except Exception: + # probeer origin/branch aan te maken + try: + repo.git.checkout("-B", branch, f"origin/{branch}") + except Exception: + # laatste fallback: default HEAD + logger.warning("Checkout %s failed; fallback to default HEAD", branch) + logger.info("Done with: Checking out branch %s", branch) + return local + +# ----------------------------------------------------------------------------- +# LLM call (OpenAI-compatible) +# ----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- +# LLM call (OpenAI-compatible) +# ----------------------------------------------------------------------------- +# ----------------------------------------------------------------------------- +# LLM call (OpenAI-compatible) — met seriële wachtrij (LLM_SEM + LLM_QUEUE) +# ----------------------------------------------------------------------------- +async def llm_call_openai_compat( + messages: List[dict], + *, + model: Optional[str] = None, + 
stream: bool = False, + temperature: float = 0.2, + top_p: float = 0.9, + max_tokens: int = 1024, + extra: Optional[dict] = None, + stop: Optional[Union[str, list[str]]] = None, + **kwargs +) -> dict | StreamingResponse: + payload: dict = { + "model": model or os.getenv("LLM_MODEL", "mistral-medium"), + "messages": messages, + "temperature": temperature, + "top_p": top_p, + "max_tokens": max_tokens, + "stream": bool(stream) + } + # OpenAI-compat: optionele stop-sequenties doorgeven indien aanwezig + if stop is not None: + payload["stop"] = stop + # Eventuele andere onbekende kwargs negeren (compat met callsites die extra parameters sturen) + # (bewust geen payload.update(kwargs) om upstream niet te breken) + + if extra: + payload.update(extra) + + # kies URL + thread_key = None + try: + # Sticky per conversatie én model (voorkomt verkeerde stickiness en TypeErrors) + _model = model or os.getenv("LLM_MODEL","mistral-medium") + thread_key = f"{_model}:{derive_thread_id(messages)}" + except Exception: + thread_key = (model or os.getenv("LLM_MODEL","mistral-medium")) + ":" + json.dumps( + [m.get("role","")+":"+(m.get("content","")[:64]) for m in messages] + )[:256] + + upstream = _pick_upstream_sticky(thread_key) or _pick_upstream() + #url = (upstream.url if upstream else LLM_URL) + #upstream = _pick_upstream() + url = (upstream.url if upstream else LLM_URL) + + # --- NON-STREAM: wacht keurig op beurt en houd exclusieve lock vast + if not stream: + token = object() + app.state.LLM_QUEUE.append(token) + try: + # Fair: wacht tot je vooraan staat + # Fair: laat de eerste N (LLM_MAX_CONCURRENCY) door zodra er een permit vrij is + while True: + try: + pos = app.state.LLM_QUEUE.index(token) + 1 + except ValueError: + return # token verdwenen (client annuleerde) + free = getattr(app.state.LLM_SEM, "_value", 0) + if pos <= LLM_MAX_CONCURRENCY and free > 0: + await app.state.LLM_SEM.acquire() + break + await asyncio.sleep(0.1) + + + # ⬇️ BELANGRIJK: gebruik non-proxy client 
(zelfde host; vermijd env-proxy timeouts) + client = app.state.HTTPX + r = await client.post(url, json=payload) + try: + r.raise_for_status() + return r.json() + except httpx.HTTPStatusError as e: + _release_upstream(upstream, bad=True) + # geef een OpenAI-achtige fout terug + raise HTTPException(status_code=r.status_code, detail=f"LLM upstream error: {e}") # behoudt statuscode + except ValueError: + _release_upstream(upstream, bad=True) + raise HTTPException(status_code=502, detail="LLM upstream gaf geen geldige JSON.") + finally: + _release_upstream(upstream) + app.state.LLM_SEM.release() + finally: + try: + if app.state.LLM_QUEUE and app.state.LLM_QUEUE[0] is token: + app.state.LLM_QUEUE.popleft() + else: + app.state.LLM_QUEUE.remove(token) + except ValueError: + pass + + # --- STREAM: stuur wachtrij-updates als SSE totdat je aan de beurt bent + async def _aiter(): + token = object() + app.state.LLM_QUEUE.append(token) + try: + # 1) wachtrij-positie uitsturen zolang je niet mag + first = True + while True: + try: + pos = app.state.LLM_QUEUE.index(token) + 1 + except ValueError: + return + free = getattr(app.state.LLM_SEM, "_value", 0) + if pos <= LLM_MAX_CONCURRENCY and free > 0: + await app.state.LLM_SEM.acquire() + break + if first or pos <= LLM_MAX_CONCURRENCY: + data = { + "id": f"queue-info-{int(time.time())}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": payload["model"], + "choices": [{ + "index": 0, + "delta": {"role": "assistant", "content": f"⏳ Plek #{pos} – vrije slots: {free}\n"}, + "finish_reason": None + }] + } + yield ("data: " + json.dumps(data, ensure_ascii=False) + "\n\n").encode("utf-8") + first = False + await asyncio.sleep(1.0) + + + # 2) nu exclusieve toegang → echte upstream streamen + try: + client = app.state.HTTPX + # timeout=None voor onbegrensde streams + async with client.stream("POST", url, json=payload, timeout=None) as r: + r.raise_for_status() + HEARTBEAT = 
float(os.getenv("SSE_HEARTBEAT_SEC","10")) + q: asyncio.Queue[bytes] = asyncio.Queue(maxsize=100) + + async def _reader(): + try: + async for chunk in r.aiter_bytes(): + if chunk: + await q.put(chunk) + except Exception: + pass + finally: + await q.put(b"__EOF__") + + reader_task = asyncio.create_task(_reader()) + try: + while True: + try: + chunk = await asyncio.wait_for(q.get(), timeout=HEARTBEAT) + except asyncio.TimeoutError: + # SSE comment; door UIs genegeerd → houdt verbinding warm + yield b": ping\n\n" + continue + if chunk == b"__EOF__": + break + yield chunk + finally: + reader_task.cancel() + with contextlib.suppress(Exception): + await reader_task + except (asyncio.CancelledError, httpx.RemoteProtocolError) as e: + # Client brak verbinding; upstream niet ‘bad’ markeren + _release_upstream(upstream, bad=False) + logger.info("🔌 stream cancelled/closed by client: %s", e) + return + except Exception as e: + _release_upstream(upstream, bad=True) + logger.info("🔌 stream aborted (compat): %s", e) + return + finally: + _release_upstream(upstream) + app.state.LLM_SEM.release() + + finally: + try: + if app.state.LLM_QUEUE and app.state.LLM_QUEUE[0] is token: + app.state.LLM_QUEUE.popleft() + else: + app.state.LLM_QUEUE.remove(token) + except ValueError: + pass + + return StreamingResponse( + _aiter(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache, no-transform", + "X-Accel-Buffering": "no", + "Connection": "keep-alive", + }, + ) + + +def extract_code_block(s: str) -> str: + """ + Pak de eerste ``` ``` blok-inhoud (zonder fences). Val terug naar volledige tekst als geen blok aangetroffen. 
+ """ + if not s: return "" + m = re.search(r"```[a-zA-Z0-9_+-]*\n([\s\S]*?)```", s) + if m: + return m.group(1).strip() + return s.strip() + +_SIZE_RE = re.compile(r"^\s*(\d+)\s*x\s*(\d+)\s*$") + +def _parse_size(size_str: str) -> tuple[int,int]: + m = _SIZE_RE.match(str(size_str or "512x512")) + if not m: return (512,512) + w = max(64, min(2048, int(m.group(1)))) + h = max(64, min(2048, int(m.group(2)))) + return (w, h) + +def _sanitize_svg(svg: str) -> str: + # strip code fences + if "```" in svg: + svg = extract_code_block(svg) + # remove scripts + on* handlers + svg = re.sub(r"<\s*script\b[^>]*>.*?<\s*/\s*script\s*>", "", svg, flags=re.I|re.S) + svg = re.sub(r"\son\w+\s*=\s*['\"].*?['\"]", "", svg, flags=re.I|re.S) + return svg + +def _svg_wrap_if_needed(svg: str, w: int, h: int, bg: str="white") -> str: + s = (svg or "").strip() + if " + + SVG niet gevonden in modeloutput +''' + if 'width=' not in s or 'height=' not in s: + s = re.sub(r" str: + sys = ("Je bent een SVG-tekenaar. Geef ALLEEN raw SVG 1.1 markup terug; geen uitleg of code fences. 
" + "Geen externe refs of scripts.") + user = (f"Maak een eenvoudige vectorillustratie.\n- Canvas {w}x{h}, achtergrond {background}\n" + f"- Thema: {prompt}\n- Gebruik eenvoudige vormen/paths/tekst.") + resp = await llm_call_openai_compat( + [{"role":"system","content":sys},{"role":"user","content":user}], + stream=False, temperature=0.35, top_p=0.9, max_tokens=1200 + ) + svg = (resp.get("choices",[{}])[0].get("message",{}) or {}).get("content","") + return _svg_wrap_if_needed(_sanitize_svg(svg), w, h, background) + +# --------- kleine helpers: atomic json write & simpele file lock ---------- +def _atomic_json_write(path: str, data: dict): + tmp = f"{path}.tmp.{uuid.uuid4().hex}" + Path(path).parent.mkdir(parents=True, exist_ok=True) + with open(tmp, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False) + os.replace(tmp, path) + +@contextmanager +def _file_lock(lock_path: str, timeout: float = 60.0, poll: float = 0.2): + start = time.time() + fd = None + while True: + try: + fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o644) + break + except FileExistsError: + if time.time() - start > timeout: raise TimeoutError(f"Lock timeout for {lock_path}") + time.sleep(poll) + try: + os.write(fd, str(os.getpid()).encode("utf-8", errors="ignore")) + yield + finally: + with contextlib.suppress(Exception): + os.close(fd) if fd is not None else None + os.unlink(lock_path) + + +# -------- UploadFile → text (generiek), gebruikt je optional libs -------- +async def read_file_content(file: UploadFile) -> str: + name = (file.filename or "").lower() + raw = await run_in_threadpool(file.file.read) + + # Plain-text achtig + if name.endswith((".txt",".py",".php",".js",".ts",".jsx",".tsx",".html",".css",".json", + ".md",".yml",".yaml",".ini",".cfg",".log",".toml",".env",".sh",".bat",".dockerfile")): + return raw.decode("utf-8", errors="ignore") + + if name.endswith(".docx") and docx: + d = docx.Document(BytesIO(raw)) + return "\n".join(p.text for p in 
d.paragraphs) + + if name.endswith(".pdf") and PyPDF2: + reader = PyPDF2.PdfReader(BytesIO(raw)) + return "\n".join([(p.extract_text() or "") for p in reader.pages]) + + if name.endswith(".csv") and pd is not None: + df = pd.read_csv(BytesIO(raw)) + return df.head(200).to_csv(index=False) + + if name.endswith((".xlsx",".xls")) and pd is not None: + try: + sheets = pd.read_excel(BytesIO(raw), sheet_name=None, header=0) + except Exception as e: + return f"(Kon XLSX niet parsen: {e})" + out = [] + for sname, df in sheets.items(): + out.append(f"# Sheet: {sname}\n{df.head(200).to_csv(index=False)}") + return "\n\n".join(out) + + if name.endswith(".pptx") and Presentation: + pres = Presentation(BytesIO(raw)) + texts = [] + for i, slide in enumerate(pres.slides, 1): + buf = [f"# Slide {i}"] + for shape in slide.shapes: + if hasattr(shape, "has_text_frame") and shape.has_text_frame: + buf.append(shape.text.strip()) + texts.append("\n".join([t for t in buf if t])) + return "\n".join(texts) + + # Afbeeldingen of onbekend → probeer raw tekst + try: + return raw.decode("utf-8", errors="ignore") + except Exception: + return "(onbekend/beeld-bestand)" + + +def _client_ip(request: Request) -> str: + for hdr in ("cf-connecting-ip","x-real-ip","x-forwarded-for"): + v = request.headers.get(hdr) + if v: return v.split(",")[0].strip() + ip = request.client.host if request.client else "0.0.0.0" + return ip + +# --- Vision helpers: OpenAI content parts -> plain text + images (b64) --- +_JSON_FENCE_RE = re.compile(r"^```(?:json)?\s*([\s\S]*?)\s*```$", re.I) + +def _normalize_openai_vision_messages(messages: list[dict]) -> tuple[list[dict], list[str]]: + """ + Converteer OpenAI-achtige message parts (text + image_url) + naar (1) platte string met markers en (2) images=[base64,...]. + Alleen data: URI's of al-b64 doorlaten; remote http(s) URL's negeren (veilig default). 
+ """ + imgs: list[str] = [] + out: list[dict] = [] + for m in messages: + c = m.get("content") + role = m.get("role","user") + if isinstance(c, list): + parts = [] + for item in c: + if isinstance(item, dict) and item.get("type") == "text": + parts.append(item.get("text","")) + elif isinstance(item, dict) and item.get("type") == "image_url": + u = item.get("image_url") + if isinstance(u, dict): + u = u.get("url") + if isinstance(u, str): + if u.startswith("data:image") and "," in u: + b64 = u.split(",",1)[1] + imgs.append(b64) + parts.append("") + elif re.match(r"^[A-Za-z0-9+/=]+$", u.strip()): # ruwe base64 + imgs.append(u.strip()) + parts.append("") + else: + # http(s) URL's niet meesturen (veilig), maar ook geen ruis in de prompt + pass + out.append({"role": role, "content": "\n".join([p for p in parts if p]).strip()}) + else: + out.append({"role": role, "content": c or ""}) + return out, imgs + +def _parse_repo_qa_from_messages(messages: list[dict]) -> tuple[Optional[str], str, str, int]: + """Haal repo_hint, question, branch, n_ctx grofweg uit user-tekst.""" + txts = [m.get("content","") for m in messages if m.get("role") == "user"] + full = "\n".join(txts).strip() + repo_hint = None + m = re.search(r"(https?://\S+?)(?:\s|$)", full) + if m: repo_hint = m.group(1).strip() + if not repo_hint: + m = re.search(r"\brepo\s*:\s*([A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+)", full, flags=re.I) + if m: repo_hint = m.group(1).strip() + if not repo_hint: + m = re.search(r"\b([A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+)\b", full) + if m: repo_hint = m.group(1).strip() + branch = "main" + m = re.search(r"\bbranch\s*:\s*([A-Za-z0-9_.\/-]+)", full, flags=re.I) + if m: branch = m.group(1).strip() + n_ctx = 8 + m = re.search(r"\bn_ctx\s*:\s*(\d+)", full, flags=re.I) + if m: + try: n_ctx = max(1, min(16, int(m.group(1)))) + except: pass + question = full + return repo_hint, question, branch, n_ctx + + +# ------------------------------ +# OpenAI-compatible endpoints +# ------------------------------ 
+@app.get("/v1/models") +def list_models(): + base_model = os.getenv("LLM_MODEL", "mistral-medium") + # Toon ook je twee "virtuele" modellen voor de UI + return { + "object": "list", + "data": [ + {"id": base_model, "object": "model", "created": 0, "owned_by": "you"}, + {"id": "repo-agent", "object": "model", "created": 0, "owned_by": "you"}, + {"id": "repo-qa", "object": "model", "created": 0, "owned_by": "you"}, + ], + } + + +def _openai_chat_response(model: str, text: str, messages: list[dict]): + created = int(time.time()) + # heel simpele usage schatting; vermijdt None + try: + prompt_tokens = count_message_tokens(messages) if 'count_message_tokens' in globals() else approx_token_count(json.dumps(messages)) + except Exception: + prompt_tokens = approx_token_count(json.dumps(messages)) + completion_tokens = approx_token_count(text) + return { + "id": f"chatcmpl-{uuid.uuid4().hex[:12]}", + "object": "chat.completion", + "created": created, + "model": model, + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": text}, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens + } + } + +def _get_stt_model(): + global _STT_MODEL + if _STT_MODEL is not None: + return _STT_MODEL + try: + from faster_whisper import WhisperModel + except Exception: + raise HTTPException(status_code=500, detail="STT niet beschikbaar: faster-whisper ontbreekt.") + # device-selectie + if STT_DEVICE == "auto": + try: + _STT_MODEL = WhisperModel(STT_MODEL_NAME, device="cuda", compute_type="float16") + return _STT_MODEL + except Exception: + _STT_MODEL = WhisperModel(STT_MODEL_NAME, device="cpu", compute_type="int8") + return _STT_MODEL + dev, comp = ("cuda","float16") if STT_DEVICE=="cuda" else ("cpu","int8") + _STT_MODEL = WhisperModel(STT_MODEL_NAME, device=dev, compute_type=comp) + return _STT_MODEL + +def _stt_transcribe_path(path: str, lang: str | 
None): + model = _get_stt_model() + segments, info = model.transcribe(path, language=lang or None, vad_filter=True) + text = "".join(seg.text for seg in segments).strip() + return text, getattr(info, "language", None) + +@app.post("/v1/audio/transcriptions") +async def audio_transcriptions( + file: UploadFile = File(...), + model: str = Form("whisper-1"), + language: str | None = Form(None), + prompt: str | None = Form(None) +): + data = await file.read() + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: + tmp.write(data) + tmp_path = tmp.name + try: + text, lang = await run_in_threadpool(_stt_transcribe_path, tmp_path, language) + return {"text": text, "language": lang or "unknown"} + finally: + try: os.unlink(tmp_path) + except Exception: pass + +@app.post("/v1/audio/speech") +async def audio_speech(body: dict = Body(...)): + """ + OpenAI-compat: { "model": "...", "voice": "optional", "input": "tekst" } + Return: audio/wav + """ + if not PIPER_VOICE: + raise HTTPException(status_code=500, detail="PIPER_VOICE env var niet gezet (piper TTS).") + text = (body.get("input") or "").strip() + if not text: + raise HTTPException(status_code=400, detail="Lege input") + + def _synth_to_wav_bytes() -> bytes: + with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as out: + out_path = out.name + try: + cp = subprocess.run( + [PIPER_BIN, "-m", PIPER_VOICE, "-f", out_path, "-q"], + input=text.encode("utf-8"), + stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True + ) + with open(out_path, "rb") as f: + return f.read() + finally: + try: os.unlink(out_path) + except Exception: pass + + audio_bytes = await run_in_threadpool(_synth_to_wav_bytes) + return StreamingResponse(iter([audio_bytes]), media_type="audio/wav") + +@app.post("/v1/images/generations") +async def images_generations(payload: dict = Body(...)): + prompt = (payload.get("prompt") or "").strip() + if not prompt: + raise HTTPException(status_code=400, detail="prompt required") + n = 
max(1, min(8, int(payload.get("n", 1)))) + size = payload.get("size","512x512") + w,h = _parse_size(size) + background = payload.get("background","white") + fmt = (payload.get("format") or "").lower().strip() # non-standard + if fmt not in ("png","svg",""): + fmt = "" + if not fmt: + fmt = "png" if cairosvg is not None else "svg" + + out_items = [] + for _ in range(n): + svg = await _svg_from_prompt(prompt, w, h, background) + if fmt == "png" and cairosvg is not None: + png_bytes = cairosvg.svg2png(bytestring=svg.encode("utf-8"), output_width=w, output_height=h) + b64 = base64.b64encode(png_bytes).decode("ascii") + else: + b64 = base64.b64encode(svg.encode("utf-8")).decode("ascii") + out_items.append({"b64_json": b64}) + return {"created": int(time.time()), "data": out_items} + +@app.get("/v1/images/health") +def images_health(): + return {"svg_to_png": bool(cairosvg is not None)} + +@app.post("/present/make") +async def present_make( + prompt: str = Form(...), + file: UploadFile | None = File(None), + max_slides: int = Form(8) +): + if Presentation is None: + raise HTTPException(500, "python-pptx ontbreekt in de container.") + src_text = "" + if file: + try: + src_text = (await read_file_content(file))[:30000] + except Exception: + src_text = "" + sys = ("Je bent een presentatieschrijver. Geef ALLEEN geldige JSON terug: " + '{"title": str, "slides":[{"title": str, "bullets":[str,...]}...]}.') + user = (f"Doelpresentatie: {prompt}\nBron (optioneel):\n{src_text[:12000]}\n" + f"Max. 
{max_slides} dia's, 3–6 bullets per dia.") + plan = await llm_call_openai_compat( + [{"role":"system","content":sys},{"role":"user","content":user}], + stream=False, temperature=0.3, top_p=0.9, max_tokens=1200 + ) + raw = (plan.get("choices",[{}])[0].get("message",{}) or {}).get("content","{}") + try: + spec = json.loads(raw) + except Exception: + raise HTTPException(500, "Model gaf geen valide JSON voor slides.") + prs = Presentation() + title = spec.get("title") or "Presentatie" + # Titel + slide_layout = prs.slide_layouts[0] + s = prs.slides.add_slide(slide_layout) + s.shapes.title.text = title + # Content + for slide in spec.get("slides", []): + layout = prs.slide_layouts[1] + sl = prs.slides.add_slide(layout) + sl.shapes.title.text = slide.get("title","") + tx = sl.placeholders[1].text_frame + tx.clear() + for i, bullet in enumerate(slide.get("bullets", [])[:10]): + p = tx.add_paragraph() if i>0 else tx.paragraphs[0] + p.text = bullet + p.level = 0 + bio = BytesIO() + prs.save(bio); bio.seek(0) + headers = {"Content-Disposition": f'attachment; filename="deck.pptx"'} + return StreamingResponse(bio, media_type="application/vnd.openxmlformats-officedocument.presentationml.presentation", headers=headers) + +@app.get("/whoami") +def whoami(): + import socket + return {"service": "mistral-api", "host": socket.gethostname(), "LLM_URL": LLM_URL} + +IMG_EXTS = (".png",".jpg",".jpeg",".webp",".bmp",".gif") + +def _is_image_filename(name: str) -> bool: + return (name or "").lower().endswith(IMG_EXTS) + +@app.post("/vision/ask") +async def vision_ask( + file: UploadFile = File(...), + prompt: str = Form("Beschrijf kort wat je ziet."), + stream: bool = Form(False), + temperature: float = Form(0.2), + top_p: float = Form(0.9), + max_tokens: int = Form(512), +): + raw = await run_in_threadpool(file.file.read) + img_b64 = base64.b64encode(raw).decode("utf-8") + messages = [{"role":"user","content": f" {prompt}"}] + if stream: + return await llm_call_openai_compat( + messages, 
+ model=os.getenv("LLM_MODEL","mistral-medium"), + stream=True, + temperature=float(temperature), + top_p=float(top_p), + max_tokens=int(max_tokens), + extra={"images": [img_b64]} + ) + + + return await llm_call_openai_compat( + messages, stream=False, temperature=temperature, top_p=top_p, max_tokens=max_tokens, + extra={"images": [img_b64]} + ) + +@app.post("/file/vision-and-text") +async def vision_and_text( + files: List[UploadFile] = File(...), + prompt: str = Form("Combineer visuele analyse met de tekstcontext. Geef 5 bullets met bevindingen en 3 actiepunten."), + stream: bool = Form(False), + max_images: int = Form(6), + max_chars: int = Form(25000), + temperature: float = Form(0.2), + top_p: float = Form(0.9), + max_tokens: int = Form(1024), +): + images_b64: list[str] = [] + text_chunks: list[str] = [] + for f in files: + name = f.filename or "" + raw = await run_in_threadpool(f.file.read) + if _is_image_filename(name) and len(images_b64) < int(max_images): + images_b64.append(base64.b64encode(raw).decode("utf-8")) + else: + try: + tmp = UploadFile(filename=name, file=BytesIO(raw), headers=f.headers) + text = await read_file_content(tmp) + except Exception: + text = raw.decode("utf-8", errors="ignore") + if text: + text_chunks.append(f"### {name}\n{text.strip()}") + + text_context = ("\n\n".join(text_chunks))[:int(max_chars)].strip() + image_markers = ("\n".join([""] * len(images_b64))).strip() + user_content = (image_markers + ("\n\n" if image_markers else "") + prompt.strip()) + if text_context: + user_content += f"\n\n--- TEKST CONTEXT (ingekort) ---\n{text_context}" + messages = [{"role":"user","content": user_content}] + + if stream: + return await llm_call_openai_compat( + messages, + model=os.getenv("LLM_MODEL","mistral-medium"), + stream=True, + temperature=float(temperature), + top_p=float(top_p), + max_tokens=int(max_tokens), + extra={"images": images_b64} + ) + + + return await llm_call_openai_compat( + messages, stream=False, 
temperature=temperature, top_p=top_p, max_tokens=max_tokens, + extra={"images": images_b64} + ) + +@app.get("/vision/health") +async def vision_health(): + tiny_png = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAOaO5nYAAAAASUVORK5CYII=" + try: + messages = [{"role":"user","content":" Beschrijf dit in één woord."}] + resp = await llm_call_openai_compat(messages, extra={"images":[tiny_png]}, max_tokens=16) + txt = (resp.get("choices",[{}])[0].get("message",{}) or {}).get("content","").strip() + return {"vision": bool(txt), "sample": txt[:60]} + except Exception as e: + return {"vision": False, "error": str(e)} + +# -------- Tool registry (OpenAI-style) -------- +LLM_FUNCTION_CALLING_MODE = os.getenv("LLM_FUNCTION_CALLING_MODE", "auto").lower() # "native" | "shim" | "auto" + +OWUI_BASE_URL='http://localhost:3000' +OWUI_API_TOKEN='sk-f1b7991b054442b5ae388de905019726' +# Aliassen zodat oudere codepaths blijven werken +OWUI_BASE = OWUI_BASE_URL +OWUI_TOKEN = OWUI_API_TOKEN + + +@app.get("/tools", operation_id="get_tools_list_compat") +async def tools_compat(): + return await list_tools() + + + +async def _fetch_openwebui_file(file_id: str, dst_dir: Path) -> Path: + """ + Download raw file content from OpenWebUI and save locally. 
+ """ + if not (OWUI_BASE_URL and OWUI_API_TOKEN): + raise HTTPException(status_code=501, detail="OpenWebUI niet geconfigureerd (OWUI_BASE_URL/OWUI_API_TOKEN).") + url = f"{OWUI_BASE_URL}/api/v1/files/{file_id}/content" + headers = {"Authorization": f"Bearer {OWUI_API_TOKEN}"} + + dst_dir.mkdir(parents=True, exist_ok=True) + out_path = dst_dir / f"{file_id}" + # Gebruik proxy-aware client uit app.state als beschikbaar + client = getattr(app.state, "HTTPX_PROXY", None) + close_after = False + if client is None: + client = httpx.AsyncClient(timeout=None, trust_env=True) + close_after = True + try: + async with client.stream("GET", url, headers=headers, timeout=None) as resp: + if resp.status_code != 200: + # lees body (beperkt) voor foutmelding + body = await resp.aread() + raise HTTPException(status_code=resp.status_code, + detail=f"OpenWebUI file fetch failed: {body[:2048]!r}") + max_mb = int(os.getenv("OWUI_MAX_DOWNLOAD_MB", "64")) + max_bytes = max_mb * 1024 * 1024 + total = 0 + with out_path.open("wb") as f: + async for chunk in resp.aiter_bytes(): + if not chunk: + continue + total += len(chunk) + if total > max_bytes: + try: f.close() + except Exception: pass + try: out_path.unlink(missing_ok=True) + except Exception: pass + raise HTTPException(status_code=413, detail=f"Bestand groter dan {max_mb}MB; download afgebroken.") + f.write(chunk) + return out_path + except HTTPException: + raise + except Exception as e: + logging.exception("download failed") + raise HTTPException(status_code=500, detail=f"download error: {e}") + finally: + if close_after: + await client.aclose() + + +def _normalize_files_arg(args: dict): + # OpenWebUI injecteert files als __files__ = [{ id, name, mime, size, ...}] + files = args.get("__files__") or args.get("files") or [] + if isinstance(files, dict): # enkel bestand + files = [files] + return files + + +@app.get("/openapi.json", include_in_schema=False) +async def openapi_endpoint(): + return get_openapi( + title="Tool Server", + 
version="0.1.0", + routes=app.routes, + ) + +def _openai_tools_from_registry(reg: dict): + out = [] + for name, spec in reg.items(): + out.append({ + "type": "function", + "function": { + "name": name, + "description": spec.get("description",""), + "parameters": spec.get("parameters", {"type":"object","properties":{}}) + } + }) + return out + +def _visible_registry(reg: dict) -> dict: + return {k:v for k,v in reg.items() if not v.get("hidden")} + +VALIDATE_PROMPT = ( + "Je bent een code-reviewer die deze code moet valideren:\n" + "1. Controleer op syntactische fouten\n" + "2. Controleer op logische fouten en onjuiste functionaliteit\n" + "3. Check of alle vereiste functionaliteit aanwezig is\n" + "4. Zoek naar mogelijke bugs of veiligheidsrisico's\n" + "5. Geef specifieke feedback met regelnummers\n\n" + "Geef een lijst in dit formaat:\n" + "- Regels [10-12]: Beschrijving\n" + "- Regels 25: Beschrijving" +) + +def _parse_validation_results(text: str) -> list[str]: + issues = [] + for line in (text or "").splitlines(): + line = line.strip() + if line.startswith('-') and 'Regels' in line and ':' in line: + issues.append(line) + return issues + +async def _execute_tool(name: str, args: dict) -> dict: + if name == "repo_grep": + repo_url = args.get("repo_url","") + branch = args.get("branch","main") + query = (args.get("query") or "").strip() + max_hits = int(args.get("max_hits",200)) + if not repo_url or not query: + raise HTTPException(status_code=400, detail="repo_url en query verplicht.") + repo_path = await _get_git_repo_async(repo_url, branch) + root = Path(repo_path) + hits = [] + qlow = query.lower() + for p in root.rglob("*"): + if p.is_dir(): + continue + if set(p.parts) & PROFILE_EXCLUDE_DIRS: + continue + if p.suffix.lower() in BINARY_SKIP: + continue + try: + txt = _read_text_file(p) + except Exception: + continue + if not txt: + continue + # snelle lijngewijze scan + for ln, line in enumerate(txt.splitlines(), 1): + if qlow in line.lower(): + 
hits.append({ + "path": str(p.relative_to(root)), + "line": ln, + "excerpt": line.strip()[:400] + }) + if len(hits) >= max_hits: + break + if len(hits) >= max_hits: + break + return {"count": len(hits), "hits": hits, "repo": os.path.basename(repo_url), "branch": branch} + # RAG + if name == "rag_index_repo": + out = await run_in_threadpool(_rag_index_repo_sync, **{ + "repo_url": args.get("repo_url",""), + "branch": args.get("branch","main"), + "profile": args.get("profile","auto"), + "include": args.get("include",""), + "exclude_dirs": args.get("exclude_dirs",""), + "chunk_chars": int(args.get("chunk_chars",3000)), + "overlap": int(args.get("overlap",400)), + "collection_name": args.get("collection_name","code_docs"), + "force": bool(args.get("force", False)), + }) + return out + if name == "rag_query": + out = await rag_query_api( + query=args.get("query",""), + n_results=int(args.get("n_results",5)), + collection_name=args.get("collection_name","code_docs"), + repo=args.get("repo"), + path_contains=args.get("path_contains"), + profile=args.get("profile") + ) + return out + + # Tekst tools + if name == "summarize_text": + text = (args.get("text") or "")[:int(args.get("max_chars",16000))] + instruction = args.get("instruction") or "Vat samen in bullets (max 10), met korte inleiding en actiepunten." + resp = await llm_call_openai_compat( + [{"role":"system","content":"Je bent behulpzaam en exact."}, + {"role":"user","content": f"{instruction}\n\n--- BEGIN ---\n{text}\n--- EINDE ---"}], + stream=False, max_tokens=768 + ) + return resp + + if name == "analyze_text": + text = (args.get("text") or "")[:int(args.get("max_chars",20000))] + goal = args.get("goal") or "Licht toe wat dit document doet. Benoem sterke/zwakke punten, risico’s en concrete verbeterpunten." 
+ resp = await llm_call_openai_compat( + [{"role":"system","content":"Wees feitelijk en concreet."}, + {"role":"user","content": f"{goal}\n\n--- BEGIN ---\n{text}\n--- EINDE ---"}], + stream=False, max_tokens=768 + ) + return resp + + if name == "improve_text": + text = (args.get("text") or "")[:int(args.get("max_chars",20000))] + objective = args.get("objective") or "Verbeter dit document. Lever een beknopte toelichting en vervolgens de volledige verbeterde versie." + style = args.get("style") or "Hanteer best practices, behoud inhoudelijke betekenis." + resp = await llm_call_openai_compat( + [{"role":"system","content": SYSTEM_PROMPT}, + {"role":"user","content": f"{objective}\nStijl/criteria: {style}\n" + "Antwoord met eerst een korte toelichting, daarna alleen de verbeterde inhoud tussen een codeblok.\n\n" + f"--- BEGIN ---\n{text}\n--- EINDE ---"}], + stream=False, max_tokens=1024 + ) + return resp + + # Code tools + if name == "validate_code_text": + code = args.get("code","") + resp = await llm_call_openai_compat( + [{"role":"system","content":"Wees strikt en concreet."}, + {"role":"user","content": f"{VALIDATE_PROMPT}\n\nCode om te valideren:\n```\n{code}\n```"}], + stream=False, max_tokens=512 + ) + txt = (resp.get("choices",[{}])[0].get("message",{}) or {}).get("content","") + return {"status":"issues_found" if _parse_validation_results(txt) else "valid", + "issues": _parse_validation_results(txt), "raw": txt} + + if name == "improve_code_text": + code = args.get("code","") + language = args.get("language","auto") + focus = args.get("improvement_focus","best practices") + resp = await llm_call_openai_compat( + [{"role":"system","content": SYSTEM_PROMPT}, + {"role":"user","content": f"Verbeter deze {language}-code met focus op {focus}:\n\n" + f"{code}\n\nGeef eerst een korte toelichting, dan alleen het verbeterde codeblok. 
Behoud functionaliteit."}], + stream=False, max_tokens=1536 + ) + return resp + + if name == "ingest_openwebui_files": + files = _normalize_files_arg(args) + if not files: + raise HTTPException(status_code=400, detail="Geen bestanden ontvangen (__files__ is leeg).") + saved = [] + tmpdir = Path("/tmp/owui_files") + for fmeta in files: + fid = fmeta.get("id") + if not fid: + continue + path = await _fetch_openwebui_file(fid, tmpdir) + saved.append({"id": fid, "path": str(path), "name": fmeta.get("name"), "mime": fmeta.get("mime")}) + + # >>> HIER je eigen pipeline aanroepen, bijv. direct indexeren: + # for s in saved: index_file_into_chroma(s["path"], collection=args.get("target_collection","code_docs"), ...) + + return {"ok": True, "downloaded": saved, "collection": args.get("target_collection","code_docs")} + + if name == "vision_analyze": + image_url = args.get("image_url","") + prompt = args.get("prompt","Beschrijf beknopt wat je ziet en noem de belangrijkste details.") + max_tokens = int(args.get("max_tokens",512)) + b64 = None + # Alleen data: of raw base64 accepteren; http(s) niet, want die worden niet + # ingeladen in de vision-call en zouden stil falen. 
+ if image_url.startswith(("http://","https://")): + raise HTTPException(status_code=400, detail="vision_analyze: image_url moet data: URI of raw base64 zijn.") + if image_url.startswith("data:image") and "," in image_url: + b64 = image_url.split(",",1)[1] + elif re.match(r"^[A-Za-z0-9+/=]+$", image_url.strip()): + b64 = image_url.strip() + messages = [{"role":"user","content": f" {prompt}"}] + return await llm_call_openai_compat(messages, stream=False, max_tokens=max_tokens, + extra={"images": [b64] if b64 else []}) + + raise HTTPException(status_code=400, detail=f"Unknown tool: {name}") + +TOOLS_REGISTRY = { + "repo_grep": { + "description": "Zoek exact(e) tekst in een git repo (fast grep-achtig).", + "parameters": { + "type":"object", + "properties":{ + "repo_url":{"type":"string"}, + "branch":{"type":"string","default":"main"}, + "query":{"type":"string"}, + "max_hits":{"type":"integer","default":200} + }, + "required":["repo_url","query"] + } + }, + "ingest_openwebui_files": { + "description": "Download aangehechte OpenWebUI-bestanden en voer ingestie/embedding uit.", + "parameters": { + "type":"object", + "properties":{ + "target_collection":{"type":"string","default":"code_docs"}, + "profile":{"type":"string","default":"auto"}, + "chunk_chars":{"type":"integer","default":3000}, + "overlap":{"type":"integer","default":400} + } + } + }, + "rag_index_repo": { + "description": "Indexeer een git-repo in ChromaDB (chunken & metadata).", + "parameters": { + "type":"object", + "properties":{ + "repo_url":{"type":"string"}, + "branch":{"type":"string","default":"main"}, + "profile":{"type":"string","default":"auto"}, + "include":{"type":"string","default":""}, + "exclude_dirs":{"type":"string","default":""}, + "chunk_chars":{"type":"integer","default":3000}, + "overlap":{"type":"integer","default":400}, + "collection_name":{"type":"string","default":"code_docs"}, + "force":{"type":"boolean","default":False} + }, + "required":["repo_url"] + } + }, + "rag_query": { + 
"description": "Zoek in de RAG-collectie en geef top-N passages (hybride rerank).", + "parameters": { + "type":"object", + "properties":{ + "query":{"type":"string"}, + "n_results":{"type":"integer","default":8}, + "collection_name":{"type":"string","default":"code_docs"}, + "repo":{"type":["string","null"]}, + "path_contains":{"type":["string","null"]}, + "profile":{"type":["string","null"]} + }, + "required":["query"] + } + }, + "summarize_text": { + "description": "Vat tekst samen in bullets met inleiding en actiepunten.", + "parameters": {"type":"object","properties":{ + "text":{"type":"string"}, + "instruction":{"type":"string","default":"Vat samen in bullets (max 10), met korte inleiding en actiepunten."}, + "max_chars":{"type":"integer","default":16000} + },"required":["text"]} + }, + "analyze_text": { + "description": "Analyseer tekst: sterke/zwakke punten, risico’s en verbeterpunten.", + "parameters": {"type":"object","properties":{ + "text":{"type":"string"}, + "goal":{"type":"string","default":"Licht toe wat dit document doet. Benoem sterke/zwakke punten, risico’s en concrete verbeterpunten."}, + "max_chars":{"type":"integer","default":20000} + },"required":["text"]} + }, + "improve_text": { + "description": "Verbeter tekst: korte toelichting + volledige verbeterde versie.", + "parameters": {"type":"object","properties":{ + "text":{"type":"string"}, + "objective":{"type":"string","default":"Verbeter dit document. 
Lever een beknopte toelichting en vervolgens de volledige verbeterde versie."}, + "style":{"type":"string","default":"Hanteer best practices, behoud inhoudelijke betekenis."}, + "max_chars":{"type":"integer","default":20000} + },"required":["text"]} + }, + "validate_code_text": { + "description": "Valideer code; geef issues (met regelnummers).", + "parameters": {"type":"object","properties":{ + "code":{"type":"string"} + },"required":["code"]} + }, + "improve_code_text": { + "description": "Verbeter code met focus (best practices/perf/security).", + "parameters": {"type":"object","properties":{ + "code":{"type":"string"}, + "language":{"type":"string","default":"auto"}, + "improvement_focus":{"type":"string","default":"best practices"} + },"required":["code"]} + }, + "vision_analyze": { + "description": "Eén afbeelding analyseren via data: URI of base64.", + "parameters":{"type":"object","properties":{ + "image_url":{"type":"string"}, + "prompt":{"type":"string","default":"Beschrijf beknopt wat je ziet en de belangrijkste details."}, + "max_tokens":{"type":"integer","default":512} + },"required":["image_url"]} + } +} + +# Verberg OWUI-afhankelijke tool wanneer niet geconfigureerd +try: + if not (OWUI_BASE_URL and OWUI_API_TOKEN): + # bestaat altijd als key; markeer als hidden zodat /v1/tools hem niet toont + TOOLS_REGISTRY["ingest_openwebui_files"]["hidden"] = True +except Exception: + pass + + +def _tools_list_from_registry(reg: dict): + lst = [] + for name, spec in reg.items(): + lst.append({ + "name": name, + "description": spec.get("description", ""), + "parameters": spec.get("parameters", {"type":"object","properties":{}}) + }) + return {"mode": LLM_FUNCTION_CALLING_MODE if 'LLM_FUNCTION_CALLING_MODE' in globals() else "shim", + "tools": lst} + + +@app.get("/v1/tools", operation_id="get_tools_list") +async def list_tools(format: str = "proxy"): + # negeer 'format' en geef altijd OpenAI tool list terug + return { + "object": "tool.list", + "mode": "function", # 
<- belangrijk voor OWUI 0.6.21 + "data": _openai_tools_from_registry(_visible_registry(TOOLS_REGISTRY)) + } + +# === OpenAPI-compat: 1 endpoint per tool (operationId = toolnaam) === +# Open WebUI kijkt naar /openapi.json, leest operationId’s en maakt daar “tools” van. +def _register_openapi_tool(name: str): + # Opzet: POST /openapi/{name} met body == arguments-dict + # operation_id = name => Open WebUI toolnaam = name + # Belangrijk: injecteer requestBody schema in OpenAPI, zodat OWUI de parameters kent. + schema = (TOOLS_REGISTRY.get(name, {}) or {}).get("parameters", {"type":"object","properties":{}}) + @app.post( + f"/openapi/{name}", + operation_id=name, + summary=f"Tool: {name}", + openapi_extra={ + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": schema + } + } + } + } + ) + async def _tool_entry(payload: dict = Body(...)): + if name not in TOOLS_REGISTRY: + raise HTTPException(status_code=404, detail=f"Unknown tool: {name}") + # payload = {"arg1":..., "arg2":...} + out = await _execute_tool(name, payload or {}) + # Als _execute_tool een OpenAI-chat response teruggeeft, haal de tekst eruit + if isinstance(out, dict) and "choices" in out: + try: + txt = ((out.get("choices") or [{}])[0].get("message") or {}).get("content", "") + return {"result": txt} + except Exception: + pass + return out + +for _t in TOOLS_REGISTRY.keys(): + _register_openapi_tool(_t) + + +@app.post("/v1/tools/call", operation_id="post_tools_call") +async def tools_entry(request: Request): + #if request.method == "GET": + # return {"object":"tool.list","mode":LLM_FUNCTION_CALLING_MODE,"data": _openai_tools_from_registry(TOOLS_REGISTRY)} + data = await request.json() + name = (data.get("name") + or (data.get("tool") or {}).get("name") + or (data.get("function") or {}).get("name")) + raw_args = (data.get("arguments") + or (data.get("tool") or {}).get("arguments") + or (data.get("function") or {}).get("arguments") + or {}) + if isinstance(raw_args, str): + 
try: args = json.loads(raw_args) + except Exception: args = {} + else: + args = raw_args + if not name or name not in TOOLS_REGISTRY or TOOLS_REGISTRY[name].get("hidden"): + raise HTTPException(status_code=404, detail=f"Unknown tool: {name}") + try: + output = await _execute_tool(name, args) + except HTTPException: + raise + except Exception as e: + logging.exception("tool call failed") + raise HTTPException(status_code=500, detail=str(e)) + return {"id": f"toolcall-{uuid.uuid4().hex}","object":"tool.run","created": int(time.time()),"name": name,"arguments": args,"output": output} + +async def _complete_with_autocontinue( + messages: list[dict], + *, + model: str, + temperature: float, + top_p: float, + max_tokens: int, + extra_payload: dict | None, + max_autocont: int +) -> tuple[str, dict]: + """ + Eén of meer non-stream calls met jouw bestaande auto-continue logica. + Retourneert (full_text, last_response_json). + """ + resp = await llm_call_openai_compat( + messages, model=model, stream=False, + temperature=temperature, top_p=top_p, max_tokens=max_tokens, + extra=extra_payload if extra_payload else None + ) + content = (resp.get("choices",[{}])[0].get("message",{}) or {}).get("content","") + finish_reason = (resp.get("choices",[{}])[0] or {}).get("finish_reason") + + continues = 0 + NEAR_MAX = int(os.getenv("LLM_AUTOCONT_NEAR_MAX", "48")) + def _near_cap(txt: str) -> bool: + try: + return approx_token_count(txt) >= max_tokens - NEAR_MAX + except Exception: + return False + while continues < max_autocont and ( + finish_reason in ("length","content_filter") + or content.endswith("…") + or _near_cap(content) + ): + follow = messages + [ + {"role":"assistant","content": content}, + {"role":"user","content": "Ga verder zonder herhaling; alleen het vervolg."} + ] + nxt = await llm_call_openai_compat( + follow, model=model, stream=False, + temperature=temperature, top_p=top_p, max_tokens=max_tokens, + extra=extra_payload if extra_payload else None + ) + more = 
(nxt.get("choices",[{}])[0].get("message",{}) or {}).get("content","") + content = (content + ("\n" if content and more else "") + more).strip() + finish_reason = (nxt.get("choices",[{}])[0] or {}).get("finish_reason") + resp = nxt # laatst gezien resp teruggeven (usage etc.) + continues += 1 + return content, resp + +@app.post("/v1/chat/completions") +async def openai_chat_completions(body: dict = Body(...), request: Request = None): + model = (body.get("model") or os.getenv("LLM_MODEL", "mistral-medium")).strip() + #logging.info(str(body)) + #logging.info(str(request)) + stream = bool(body.get("stream", False)) + raw_messages = body.get("messages") or [] + # normaliseer tool-berichten naar plain tekst voor het LLM + norm_messages = [] + for m in raw_messages: + if m.get("role") == "tool": + nm = m.get("name") or "tool" + norm_messages.append({ + "role": "user", + "content": f"[{nm} RESULT]\n{m.get('content') or ''}" + }) + else: + norm_messages.append(m) + + + # --- minimal tool-calling glue (laat rest van je functie intact) --- + tools = body.get("tools") or [] + # 'tool_choice' hier alleen lezen; later in de native branch wordt opnieuw naar body gekeken + tool_choice_req = body.get("tool_choice") # 'auto' | 'none' | 'required' | {...} + try: + logger.info("🧰 tools_count=%s, tool_choice=%s", len(tools), tool_choice_req) + except Exception: + pass + + # OWUI stuurt vaak "required" als: "er MOET een tool worden gebruikt". + # Als er precies 1 tool is meegegeven, normaliseren we dat naar "force deze tool". 
    # "required" -> concrete tool: force the single supplied tool, or match a
    # tool name mentioned in the last user message.
    if tool_choice_req == "required" and tools:
        names = [ (t.get("function") or {}).get("name") for t in tools if t.get("function") ]
        names = [ n for n in names if n ]
        # 1) exactly one tool -> force it
        if len(set(names)) == 1:
            tool_choice_req = {"type":"function","function":{"name": names[0]}}
        else:
            # 2) several tools -> pick based on the user prompt (tool name mentioned)
            last_user = next((m for m in reversed(norm_messages) if m.get("role")=="user"), {})
            utext = (last_user.get("content") or "").lower()
            mentioned = [n for n in names if n and n.lower() in utext]
            if mentioned:
                tool_choice_req = {"type":"function","function":{"name": mentioned[0]}}
                logger.info("🔧 required->picked tool by mention: %s", mentioned[0])



    # (1) Force: the client demands one specific tool
    if isinstance(tool_choice_req, dict) and (tool_choice_req.get("type") == "function"):
        fname = (tool_choice_req.get("function") or {}).get("name")
        if fname and fname not in TOOLS_REGISTRY:
            # Unknown tool -> let the LLM itself return native tool_calls.
            passthrough = dict(body)
            passthrough["messages"] = norm_messages
            passthrough["stream"] = False
            client = app.state.HTTPX
            r = await client.post(LLM_URL, json=passthrough)
            try:
                return JSONResponse(r.json(), status_code=r.status_code)
            except Exception:
                # Upstream did not return JSON -> relay raw text unchanged.
                return PlainTextResponse(r.text, status_code=r.status_code)

        if fname:
            # Light heuristics: derive arguments for known tools from the last
            # user message (URL, branch name, fenced code block, ...).
            last_user = next((m for m in reversed(norm_messages) if m.get("role")=="user"), {})
            utext = (last_user.get("content") or "")
            args: dict = {}

            if fname == "rag_index_repo":
                m = re.search(r'(https?://\S+)', utext)
                if m: args["repo_url"] = m.group(1)
                mb = re.search(r'\bbranch\s+([A-Za-z0-9._/-]+)', utext, re.I)
                if mb: args["branch"] = mb.group(1)
            elif fname == "rag_query":
                args["query"] = utext.strip()
            elif fname == "summarize_text":
                # Take the text after the first ':' if present; cap the size.
                m = re.search(r':\s*(.+)$', utext, re.S)
                args["text"] = (m.group(1).strip() if m else utext.strip())[:16000]
            elif fname == "analyze_text":
                m = re.search(r':\s*(.+)$', utext, re.S)
                args["text"] = (m.group(1).strip() if m else utext.strip())[:20000]
            elif fname == "improve_text":
                m = re.search(r':\s*(.+)$', utext, re.S)
                args["text"] = (m.group(1).strip() if m else utext.strip())[:20000]
            elif fname == "validate_code_text":
                # Prefer the first fenced code block, else the whole message.
                code = re.search(r"```.*?\n(.*?)```", utext, re.S)
                args["code"] = (code.group(1).strip() if code else utext.strip())
            elif fname == "improve_code_text":
                code = re.search(r"```.*?\n(.*?)```", utext, re.S)
                args["code"] = (code.group(1).strip() if code else utext.strip())
            elif fname == "vision_analyze":
                m = re.search(r'(data:image\/[a-zA-Z]+;base64,[A-Za-z0-9+/=]+|https?://\S+)', utext)
                if m: args["image_url"] = m.group(1)

            # Check required fields; if any is missing -> native passthrough
            # with only this tool attached.
            required = (TOOLS_REGISTRY.get(fname, {}).get("parameters", {}) or {}).get("required", [])
            if not all(k in args and args[k] for k in required):
                passthrough = dict(body)
                passthrough["messages"] = norm_messages
                # Attach only this tool + force this tool
                only = [t for t in (body.get("tools") or []) if (t.get("function") or {}).get("name")==fname]
                if only: passthrough["tools"] = only
                passthrough["tool_choice"] = {"type":"function","function":{"name": fname}}
                passthrough["stream"] = False
                client = app.state.HTTPX
                r = await client.post(LLM_URL, json=passthrough)
                try:
                    return JSONResponse(r.json(), status_code=r.status_code)
                except Exception:
                    return PlainTextResponse(r.text, status_code=r.status_code)


            # Heuristics succeeded -> return tool_calls (OWUI-compatible shape)
            return {
                "id": f"chatcmpl-{uuid.uuid4().hex}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": model,
                "choices": [{
                    "index": 0,
                    "finish_reason": "tool_calls",
                    "message": {
                        "role": "assistant",
                        "tool_calls": [{
                            "id": f"call_{uuid.uuid4().hex[:8]}",
                            "type": "function",
                            "function": {"name": fname, "arguments": json.dumps(args, ensure_ascii=False)}
                        }]
                    }
                }]
            }

    # Fast escape: streaming without an explicit 'required' tool -> force direct streaming
    if stream and tools and tool_choice_req in (None, "auto", "none") and \
       os.getenv("STREAM_PREFER_DIRECT", "1").lower() not in ("0","false","no"):
        tools = []  # bypass the tool glue so we go straight to real streaming

    # (2) Auto: ask the LLM to produce 1+ function calls
    if (tool_choice_req in (None, "auto")) and tools:
        # NOTE(review): `sys` shadows the stdlib module name within this scope.
        sys = _build_tools_system_prompt(tools)
        ask = [{"role": "system", "content": sys}] + norm_messages
        # existing helper; kept as used elsewhere (512-token probe)
        resp = await llm_call_openai_compat(ask, stream=False, max_tokens=512)
        txt = ((resp.get("choices") or [{}])[0].get("message") or {}).get("content", "") or ""
        calls = _extract_tool_calls_from_text(txt)
        if calls:
            return {
                "id": f"chatcmpl-{uuid.uuid4().hex}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": model,
                "choices": [{
                    "index": 0,
                    "finish_reason": "tool_calls",
                    "message": {"role": "assistant", "tool_calls": calls}
                }]
            }
    # --- end minimal tool-calling glue ---

    # Vision normalisation (after tool->text normalisation)
    msgs, images_b64 = _normalize_openai_vision_messages(norm_messages)
    messages = msgs if msgs else norm_messages
    extra_payload = {"images": images_b64} if images_b64 else {}

    # Special models
    if model == "repo-agent":
        if stream:
            async def event_gen():
                import asyncio, time, json, contextlib
                # Send a role-delta immediately to push bytes to the client.
                head = {"id": f"chatcmpl-{uuid.uuid4().hex[:12]}","object":"chat.completion.chunk",
                        "created": int(time.time()),"model": model,
                        "choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason": None}]}
                yield f"data: {json.dumps(head)}\n\n"
                HEARTBEAT = float(os.getenv("SSE_HEARTBEAT_SEC","10"))
                done = asyncio.Event()
                result = {"text": ""}
                # Emit early status in the outer generator (yielding is allowed here).
                for s in ("🔎 Indexeren/zoeken…", "🧠 Analyseren…", "🛠️ Voorstel genereren…"):
                    data = {"id": f"chatcmpl-{uuid.uuid4().hex[:12]}","object":"chat.completion.chunk",
                            "created": int(time.time()),"model": model,
                            "choices":[{"index":0,"delta":{"content": s + "\n"},"finish_reason": None}]}
                    yield f"data: {json.dumps(data)}\n\n"
                    await asyncio.sleep(0)

                async def _work():
                    # Run the agent in the background; `done` signals completion
                    # even if the agent raises.
                    try:
                        result["text"] = await handle_repo_agent(messages, request)
                    finally:
                        done.set()
                worker = asyncio.create_task(_work())
                try:
                    while not done.is_set():
                        try:
                            await asyncio.wait_for(done.wait(), timeout=HEARTBEAT)
                        except asyncio.TimeoutError:
                            yield ": ping\n\n"  # SSE comment keeps the connection alive
                    # done -> send the final content
                    data = {"id": f"chatcmpl-{uuid.uuid4().hex[:12]}","object":"chat.completion.chunk",
                            "created": int(time.time()),"model": model,
                            "choices":[{"index":0,"delta":{"content": result['text']}, "finish_reason": None}]}
                    yield f"data: {json.dumps(data)}\n\n"
                    yield "data: [DONE]\n\n"
                finally:
                    worker.cancel()
                    with contextlib.suppress(Exception):
                        await worker
            return StreamingResponse(
                event_gen(),
                media_type="text/event-stream",
                headers={"Cache-Control":"no-cache, no-transform","X-Accel-Buffering":"no","Connection":"keep-alive"}
            )
        else:
            text = await handle_repo_agent(messages, request)
            return JSONResponse(_openai_chat_response(model, text, messages))


    if model == "repo-qa":
        repo_hint, question, branch, n_ctx = _parse_repo_qa_from_messages(messages)
        if not repo_hint:
            # User-facing Dutch help text: how to supply a repo hint.
            friendly = ("Geef een repo-hint (URL of owner/repo).\nVoorbeeld:\n"
                        "repo: (repo URL eindigend op .git)\n"
                        "Vraag: (de vraag die je hebt over de code in de repo)")
            return JSONResponse(_openai_chat_response("repo-qa", friendly, messages))
        try:
            text = await repo_qa_answer(repo_hint, question, branch=branch, n_ctx=n_ctx)
        except Exception as e:
            logging.exception("repo-qa failed")
            text = f"Repo-QA faalde: {e}"
        if stream:
            async def event_gen():
                # Single-chunk SSE: whole answer in one delta, then [DONE].
                data = {"id": f"chatcmpl-{uuid.uuid4().hex[:12]}","object":"chat.completion.chunk","created": int(time.time()),
                        "model": "repo-qa", "choices":[{"index":0,"delta":{"content": text},"finish_reason": None}]}
                yield f"data: {json.dumps(data)}\n\n"; yield "data: [DONE]\n\n"
            return StreamingResponse(event_gen(), media_type="text/event-stream")
        return JSONResponse(_openai_chat_response("repo-qa", text, messages))

    # --- Tool-calling ---
    # NOTE(review): re-reads tools/tool_choice from body, overwriting the
    # normalised tool_choice_req computed by the glue above — confirm intended.
    tools_payload = body.get("tools")
    tool_choice_req = body.get("tool_choice", "auto")
    if tools_payload:
        # native passthrough (llama.cpp does NOT execute tools itself, but the UI may want this)
        if LLM_FUNCTION_CALLING_MODE in ("native","auto") and stream:
            passthrough = dict(body); passthrough["messages"]=messages
            if images_b64: passthrough["images"]=images_b64
            async def _aiter():
                import asyncio, contextlib
                client = app.state.HTTPX
                async with client.stream("POST", LLM_URL, json=passthrough, timeout=None) as r:
                    r.raise_for_status()
HEARTBEAT = float(os.getenv("SSE_HEARTBEAT_SEC","10")) + q: asyncio.Queue[bytes] = asyncio.Queue(maxsize=100) + async def _reader(): + try: + async for chunk in r.aiter_bytes(): + if chunk: + await q.put(chunk) + except Exception: + pass + finally: + await q.put(b"__EOF__") + reader_task = asyncio.create_task(_reader()) + try: + while True: + try: + chunk = await asyncio.wait_for(q.get(), timeout=HEARTBEAT) + except asyncio.TimeoutError: + yield b": ping\n\n" + continue + if chunk == b"__EOF__": + break + yield chunk + finally: + reader_task.cancel() + with contextlib.suppress(Exception): + await reader_task + return StreamingResponse( + _aiter(), + media_type="text/event-stream", + headers={"Cache-Control":"no-cache, no-transform","X-Accel-Buffering":"no","Connection":"keep-alive"} + ) + + if LLM_FUNCTION_CALLING_MODE in ("native","auto") and not stream: + # Relay-modus: laat LLM tools kiezen, bridge voert uit, daarna 2e run. + relay = os.getenv("LLM_TOOL_RUNNER", "passthrough").lower() == "bridge" + client = app.state.HTTPX + if not relay: + passthrough = dict(body); passthrough["messages"]=messages + if images_b64: passthrough["images"]=images_b64 + r = await client.post(LLM_URL, json=passthrough) + try: + return JSONResponse(r.json(), status_code=r.status_code) + except Exception: + return PlainTextResponse(r.text, status_code=r.status_code) + + # (A) 1e call: vraag de LLM om tool_calls (geen stream) + first_req = dict(body) + first_req["messages"] = messages + first_req["stream"] = False + if images_b64: first_req["images"] = images_b64 + r1 = await client.post(LLM_URL, json=first_req) + try: + data1 = r1.json() + except Exception: + return PlainTextResponse(r1.text, status_code=r1.status_code) + msg1 = ((data1.get("choices") or [{}])[0].get("message") or {}) + tool_calls = msg1.get("tool_calls") or [] + # Geen tool-calls? Geef direct door. 
+ if not tool_calls: + return JSONResponse(data1, status_code=r1.status_code) + + # (B) voer tool_calls lokaal uit + tool_msgs = [] + for tc in tool_calls: + fn = ((tc or {}).get("function") or {}) + tname = fn.get("name") + raw_args = fn.get("arguments") or "{}" + try: + args = json.loads(raw_args) if isinstance(raw_args, str) else (raw_args or {}) + except Exception: + args = {} + if not tname or tname not in TOOLS_REGISTRY: + out = {"error": f"Unknown tool '{tname}'"} + else: + try: + out = await _execute_tool(tname, args) + except Exception as e: + out = {"error": str(e)} + tool_msgs.append({ + "role": "tool", + "tool_call_id": tc.get("id"), + "name": tname or "unknown", + "content": json.dumps(out, ensure_ascii=False) + }) + + # (C) 2e call: geef tool outputs terug aan LLM voor eindantwoord + follow_messages = messages + [ + {"role": "assistant", "tool_calls": tool_calls}, + *tool_msgs + ] + second_req = dict(body) + second_req["messages"] = follow_messages + second_req["stream"] = False + # images opnieuw meesturen is niet nodig, maar kan geen kwaad: + if images_b64: second_req["images"] = images_b64 + r2 = await client.post(LLM_URL, json=second_req) + try: + return JSONResponse(r2.json(), status_code=r2.status_code) + except Exception: + return PlainTextResponse(r2.text, status_code=r2.status_code) + + # shim (non-stream) + if LLM_FUNCTION_CALLING_MODE == "shim" and not stream: + # 1) Laat LLM beslissen WELKE tool + args (strikte JSON) + tool_lines = [] + for tname, t in TOOLS_REGISTRY.items(): + tool_lines.append(f"- {tname}: {t['description']}\n schema: {json.dumps(t['parameters'])}") + sys = ("You can call tools.\n" + "If a tool is needed, reply with ONLY valid JSON:\n" + '{"call_tool":{"name":"","arguments":{...}}}\n' + "Otherwise reply with ONLY: {\"final_answer\":\"...\"}\n\nTools:\n" + "\n".join(tool_lines)) + decide = await llm_call_openai_compat( + [{"role":"system","content":sys}] + messages, + stream=False, 
                temperature=float(body.get("temperature",0.2)),
                top_p=float(body.get("top_p",0.9)), max_tokens=min(512, int(body.get("max_tokens",1024))),
                extra=extra_payload if extra_payload else None
            )
            raw = ((decide.get("choices",[{}])[0].get("message",{}) or {}).get("content","") or "").strip()
            # Strip any ```json fences before parsing
            if raw.startswith("```"):
                raw = extract_code_block(raw)
            try:
                obj = json.loads(raw)
            except Exception:
                # Model returned no JSON -> treat as final_answer
                return JSONResponse(decide)

            if "final_answer" in obj:
                return JSONResponse({
                    "id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
                    "object": "chat.completion",
                    "created": int(time.time()),
                    "model": model,
                    "choices": [{"index":0,"message":{"role":"assistant","content": obj["final_answer"]},"finish_reason":"stop"}],
                    "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
                })

            call = (obj.get("call_tool") or {})
            tname = call.get("name")
            if tname not in TOOLS_REGISTRY:
                return JSONResponse(_openai_chat_response(model, f"Onbekende tool: {tname}", messages))
            args = call.get("arguments") or {}
            tool_out = await _execute_tool(tname, args)

            # 2) Feed the (truncated) tool output back for the final answer.
            follow = messages + [
                {"role":"assistant","content": f"TOOL[{tname}] OUTPUT:\n{json.dumps(tool_out)[:8000]}"},
                {"role":"user","content": "Gebruik bovenstaande tool-output en geef nu het definitieve antwoord."}
            ]
            final = await llm_call_openai_compat(
                follow, stream=False,
                temperature=float(body.get("temperature",0.2)),
                top_p=float(body.get("top_p",0.9)),
                max_tokens=int(body.get("max_tokens",1024)),
                extra=extra_payload if extra_payload else None
            )
            return JSONResponse(final)

    # --- streaming passthrough (no tools) ---
    # --- streaming via our queue + upstreams ---
    # --- streaming (no tools): now WITH windowing trim ---
    # --- streaming (no tools): windowing + server-side auto-continue stream ---
    # NOTE(review): simulate_stream is hard-coded False, so the entire
    # `if simulate_stream:` branch below is dead code; only the real
    # streaming path in the `else:` branch executes.
    simulate_stream=False
    if simulate_stream:
        if stream:
            LLM_WINDOWING_ENABLE = os.getenv("LLM_WINDOWING_ENABLE", "1").lower() not in ("0","false","no")
            MAX_CTX_TOKENS = int(os.getenv("LLM_CONTEXT_TOKENS", "13021"))
            RESP_RESERVE = int(os.getenv("LLM_RESPONSE_RESERVE", "1024"))
            temperature = float(body.get("temperature", 0.2))
            top_p = float(body.get("top_p", 0.9))
            # Respect the env override for the default
            _default_max = int(os.getenv("LLM_DEFAULT_MAX_TOKENS", "1024"))
            max_tokens = int(body.get("max_tokens", _default_max))
            MAX_AUTOCONT = int(os.getenv("LLM_AUTO_CONTINUES", "3"))

            trimmed_stream_msgs = messages
            try:
                if LLM_WINDOWING_ENABLE and 'ConversationWindow' in globals():
                    #thread_id = derive_thread_id({"model": model, "messages": messages}) if 'derive_thread_id' in globals() else uuid.uuid4().hex
                    thread_id = derive_thread_id(messages) if 'derive_thread_id' in globals() else uuid.uuid4().hex
                    running_summary = SUMMARY_STORE.get(thread_id, "") if isinstance(SUMMARY_STORE, dict) else (SUMMARY_STORE.get(thread_id) or "")
                    win = ConversationWindow(max_ctx_tokens=MAX_CTX_TOKENS, response_reserve=RESP_RESERVE,
                                             tok_len=approx_token_count, running_summary=running_summary,
                                             summary_header="Samenvatting tot nu toe")
                    for m in messages:
                        role, content = m.get("role","user"), m.get("content","")
                        if role in ("system","user","assistant"):
                            win.add(role, content)
                    async def _summarizer(old: str, chunk_msgs: list[dict]) -> str:
                        # Merge the previous rolling summary with the evicted
                        # messages into an updated summary (Dutch prompt).
                        chunk_text = ""
                        for m in chunk_msgs:
                            chunk_text += f"\n[{m.get('role','user')}] {m.get('content','')}"
                        prompt = [
                            {"role":"system","content":"Je bent een bondige notulist. Vat samen in max 10 bullets (feiten/besluiten/acties)."},
                            {"role":"user","content": f"Vorige samenvatting:\n{old}\n\nNieuwe geschiedenis:\n{chunk_text}\n\nGeef geüpdatete samenvatting."}
                        ]
                        resp = await llm_call_openai_compat(prompt, stream=False, temperature=0.1, top_p=1.0, max_tokens=300)
                        return (resp.get("choices",[{}])[0].get("message",{}) or {}).get("content", old or "")
                    trimmed_stream_msgs = await win.build_within_budget(system_prompt=None, summarizer=_summarizer)
                    new_summary = getattr(win, "running_summary", running_summary)
                    if isinstance(SUMMARY_STORE, dict):
                        if new_summary and new_summary != running_summary:
                            SUMMARY_STORE[thread_id] = new_summary
                    else:
                        try:
                            if new_summary and new_summary != running_summary:
                                SUMMARY_STORE.update(thread_id, new_summary)  # type: ignore[attr-defined]
                        except Exception:
                            pass
            except Exception:
                # Windowing is best-effort: fall back to the raw history.
                trimmed_stream_msgs = messages

            # 1) fetch the full text with non-stream + auto-continue
            # NOTE(review): `last_resp` is never used afterwards.
            full_text, last_resp = await _complete_with_autocontinue(
                trimmed_stream_msgs,
                model=model, temperature=temperature, top_p=top_p,
                max_tokens=max_tokens, extra_payload=extra_payload,
                max_autocont=MAX_AUTOCONT
            )

            # 2) stream this text as SSE in bites (simulate live tokens)
            async def _emit_sse():
                created = int(time.time())
                model_name = model
                CHUNK = int(os.getenv("LLM_STREAM_CHUNK_CHARS", "800"))
                # optional: send a role-delta first (some UIs appreciate that)
                head = {
                    "id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model_name,
                    "choices": [{"index":0,"delta":{"role":"assistant"},"finish_reason": None}]
                }
                yield ("data: " + json.dumps(head, ensure_ascii=False) + "\n\n").encode("utf-8")
                # content in blocks
                for i in range(0, len(full_text), CHUNK):
                    piece = full_text[i:i+CHUNK]
                    data = {
                        "id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
                        "object": "chat.completion.chunk",
                        "created": int(time.time()),
                        "model": model_name,
                        "choices": [{"index":0,"delta":{"content": piece},"finish_reason": None}]
                    }
                    yield ("data: " + json.dumps(data, ensure_ascii=False) + "\n\n").encode("utf-8")
                    await asyncio.sleep(0)  # let the event loop breathe
                # close the stream
                done = {"id": f"chatcmpl-{uuid.uuid4().hex[:12]}","object":"chat.completion.chunk",
                        "created": int(time.time()),"model": model_name,
                        "choices": [{"index":0,"delta":{},"finish_reason":"stop"}]}
                yield ("data: " + json.dumps(done, ensure_ascii=False) + "\n\n").encode("utf-8")
                yield b"data: [DONE]\n\n"

            return StreamingResponse(
                _emit_sse(),
                media_type="text/event-stream",
                headers={"Cache-Control":"no-cache, no-transform","X-Accel-Buffering":"no","Connection":"keep-alive"}
            )
    else:
        # --- REAL streaming (no tools): direct passthrough with heartbeats ---
        if stream:
            temperature = float(body.get("temperature", 0.2))
            top_p = float(body.get("top_p", 0.9))
            # Env-driven default so the cap is not hard-wired to 1024.
            _default_max = int(os.getenv("LLM_DEFAULT_MAX_TOKENS", "1024"))
            max_tokens = int(body.get("max_tokens", _default_max))
            return await llm_call_openai_compat(
                messages,
                model=model,
                stream=True,
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_tokens,
                extra=extra_payload if extra_payload else None
            )



    # --- non-stream: windowing + auto-continue (as patched earlier) ---
    LLM_WINDOWING_ENABLE = os.getenv("LLM_WINDOWING_ENABLE", "1").lower() not in ("0","false","no")
    MAX_CTX_TOKENS = int(os.getenv("LLM_CONTEXT_TOKENS", "13021"))
    RESP_RESERVE = int(os.getenv("LLM_RESPONSE_RESERVE", "1024"))
    MAX_AUTOCONT = int(os.getenv("LLM_AUTO_CONTINUES", "2"))
    temperature = float(body.get("temperature", 0.2))
    top_p = float(body.get("top_p", 0.9))
    # Let the env pick the default so OWUI is not stuck at a hard 1024
    _default_max = int(os.getenv("LLM_DEFAULT_MAX_TOKENS", "13027"))
    max_tokens = int(body.get("max_tokens", _default_max))

    trimmed = messages
    try:
        # NOTE(review): this windowing block duplicates the simulated-stream
        # branch above almost verbatim; a shared helper would remove the drift risk.
        if LLM_WINDOWING_ENABLE and 'ConversationWindow' in globals():
            #thread_id = derive_thread_id({"model": model, "messages": messages}) if 'derive_thread_id' in globals() else uuid.uuid4().hex
            thread_id = derive_thread_id(messages) if 'derive_thread_id' in globals() else uuid.uuid4().hex
            running_summary = SUMMARY_STORE.get(thread_id, "") if isinstance(SUMMARY_STORE, dict) else (SUMMARY_STORE.get(thread_id) or "")
            win = ConversationWindow(max_ctx_tokens=MAX_CTX_TOKENS, response_reserve=RESP_RESERVE,
                                     tok_len=approx_token_count, running_summary=running_summary,
                                     summary_header="Samenvatting tot nu toe")
            for m in messages:
                role, content = m.get("role","user"), m.get("content","")
                if role in ("system","user","assistant"):
                    win.add(role, content)
            async def _summarizer(old: str, chunk_msgs: list[dict]) -> str:
                chunk_text = ""
                for m in chunk_msgs:
                    chunk_text += f"\n[{m.get('role','user')}] {m.get('content','')}"
                prompt = [
                    {"role":"system","content":"Je bent een bondige notulist. Vat samen in max 10 bullets (feiten/besluiten/acties)."},
                    {"role":"user","content": f"Vorige samenvatting:\n{old}\n\nNieuwe geschiedenis:\n{chunk_text}\n\nGeef geüpdatete samenvatting."}
                ]
                resp = await llm_call_openai_compat(prompt, stream=False, temperature=0.1, top_p=1.0, max_tokens=300)
                return (resp.get("choices",[{}])[0].get("message",{}) or {}).get("content", old or "")
            trimmed = await win.build_within_budget(system_prompt=None, summarizer=_summarizer)
            new_summary = getattr(win, "running_summary", running_summary)
            if isinstance(SUMMARY_STORE, dict):
                if new_summary and new_summary != running_summary:
                    SUMMARY_STORE[thread_id] = new_summary
            else:
                try:
                    if new_summary and new_summary != running_summary:
                        SUMMARY_STORE.update(thread_id, new_summary)  # type: ignore[attr-defined]
                except Exception:
                    pass
    except Exception:
        trimmed = messages

    resp = await llm_call_openai_compat(trimmed, model=model, stream=False,
                                        temperature=temperature, top_p=top_p, max_tokens=max_tokens,
                                        extra=extra_payload if extra_payload else None)
    content = (resp.get("choices",[{}])[0].get("message",{}) or {}).get("content","")
    finish_reason = (resp.get("choices",[{}])[0] or {}).get("finish_reason")

    continues = 0
    # Extra trigger: if we nearly hit max_tokens, keep going.
    # NOTE(review): this inline loop duplicates _complete_with_autocontinue.
    NEAR_MAX = int(os.getenv("LLM_AUTOCONT_NEAR_MAX", "48"))
    def _near_cap(txt: str) -> bool:
        try:
            return approx_token_count(txt) >= max_tokens - NEAR_MAX
        except Exception:
            return False
    while continues < MAX_AUTOCONT and (
        finish_reason in ("length","content_filter")
        or content.endswith("…")
        or _near_cap(content)
    ):
        follow = trimmed + [{"role":"assistant","content": content},
                            {"role":"user","content": "Ga verder zonder herhaling; alleen het vervolg."}]
        nxt = await llm_call_openai_compat(follow, model=model, stream=False,
                                           temperature=temperature, top_p=top_p, max_tokens=max_tokens,
                                           extra=extra_payload if extra_payload else None)
        more = (nxt.get("choices",[{}])[0].get("message",{}) or {}).get("content","")
        content = (content + ("\n" if content and more else "") + more).strip()
        finish_reason = (nxt.get("choices",[{}])[0] or {}).get("finish_reason")
        continues += 1

    if content:
        # Patch the stitched content back into the last raw response and
        # recompute approximate usage numbers.
        resp["choices"][0]["message"]["content"] = content
        resp["choices"][0]["finish_reason"] = finish_reason or resp["choices"][0].get("finish_reason", "stop")
    try:
        prompt_tokens = count_message_tokens(trimmed) if 'count_message_tokens' in globals() else approx_token_count(json.dumps(trimmed))
    except Exception:
        prompt_tokens = approx_token_count(json.dumps(trimmed))
    completion_tokens = approx_token_count(content)
    resp["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens,
                     "total_tokens": prompt_tokens + completion_tokens}
    return JSONResponse(resp)



# -----------------------------------------------------------------------------
# RAG Index (repo → Chroma)
# -----------------------------------------------------------------------------

_REPO_CACHE_PATH = os.path.join(CHROMA_PATH or "/rag_db", "repo_index_cache.json")
_SYMBOL_INDEX_PATH = os.path.join(CHROMA_PATH or "/rag_db", "symbol_index.json")

def _load_symbol_index() -> dict:
    """Load the persisted symbol index; empty dict if missing/corrupt."""
    try:
        with open(_SYMBOL_INDEX_PATH, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception:
        return {}

def _save_symbol_index(data: dict):
    """Best-effort atomic persist of the symbol index (errors swallowed)."""
    try:
        _atomic_json_write(_SYMBOL_INDEX_PATH, data)
    except Exception:
        pass

# In-memory symbol index: symkey -> {symbol_lower: [{"path", "chunk_index"}]}
_SYMBOL_INDEX: dict[str, dict[str, list[dict]]] = _load_symbol_index()

def _symkey(collection_effective: str, repo_name: str) -> str:
    """Key for the symbol index: '<collection>|<repo>'."""
    return f"{collection_effective}|{repo_name}"


def _load_repo_index_cache() -> dict:
    """Load the repo->HEAD index cache; empty dict if missing/corrupt."""
    try:
        with open(_REPO_CACHE_PATH, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception:
        return {}

def _save_repo_index_cache(data: dict):
    """Best-effort atomic persist of the repo index cache."""
    try:
        _atomic_json_write(_REPO_CACHE_PATH, data)
    except Exception:
        pass

def _sha1_text(t: str) -> str:
    """SHA-1 hex digest of a text, used as a cheap content fingerprint."""
    return hashlib.sha1(t.encode("utf-8", errors="ignore")).hexdigest()

def _repo_owner_repo_from_url(u: str) -> str:
    """Derive 'owner/repo' from a git URL, stripping a trailing '.git'."""
    try:
        from urllib.parse import urlparse
        p = urlparse(u).path.strip("/").split("/")
        if len(p) >= 2:
            name = p[-1]
            if name.endswith(".git"): name = name[:-4]
            return f"{p[-2]}/{name}"
    except Exception:
        pass
    # fallback: basename
    return os.path.basename(u).removesuffix(".git")

def _summary_store_path(repo_url: str, branch: str) -> str:
    """Per-repo/branch/embedder path for the file-summary cache."""
    key = _repo_owner_repo_from_url(repo_url).replace("/", "__")
    return os.path.join(_SUMMARY_DIR, f"{key}__{branch}__{_EMBEDDER.slug}.json")

def _load_summary_store(path: str) -> dict:
    """Load a summary store; empty dict if missing/corrupt."""
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception:
        return {}

def _save_summary_store(path: str, data: dict):
    """Best-effort atomic persist of a summary store."""
    try:
        _atomic_json_write(path, data)
    except Exception:
        pass

async def _summarize_files_llm(items: list[tuple[str, str]]) -> dict[str, str]:
    """
    items: list of (path, text) -> returns {path: one-line summary}
    On-demand, tiny prompts; keep it cheap.
    """
    out: dict[str, str] = {}
    for path, text in items:
        snippet = text[:2000]  # cap prompt size per file
        prompt = [
            {"role":"system","content":"Geef één korte, functionele beschrijving (max 20 woorden) van dit bestand. Geen opsomming, geen code, 1 zin."},
            {"role":"user","content": f"Pad: {path}\n\nInhoud (ingekort):\n{snippet}\n\nAntwoord: "}
        ]
        try:
            resp = await llm_call_openai_compat(prompt, stream=False, temperature=0.1, top_p=1.0, max_tokens=64)
            summ = ((resp.get("choices") or [{}])[0].get("message") or {}).get("content","").strip()
        except Exception:
            summ = ""
        out[path] = summ or "Bestand (korte beschrijving niet beschikbaar)"
    return out

async def _repo_summary_get_internal(repo_url: str, branch: str, paths: list[str]) -> dict[str, dict]:
    """
    Returns {path: {"summary": str, "sha": str}}. Caches per file SHA, auto-invalidates when changed.
    """
    repo_path = await _get_git_repo_async(repo_url, branch)
    store_path = _summary_store_path(repo_url, branch)
    store = _load_summary_store(store_path)  # structure: {path: {"sha": "...", "summary": "..."}}
    to_summarize: list[tuple[str, str]] = []
    for rel in paths:
        p = Path(repo_path) / rel
        if not p.exists():
            continue
        text = _read_text_file(p)
        fsha = _sha1_text(text)
        rec = store.get(rel) or {}
        # Re-summarize when content changed or no summary cached yet.
        if rec.get("sha") != fsha or not rec.get("summary"):
            to_summarize.append((rel, text))
    if to_summarize:
        new_summ = await _summarize_files_llm(to_summarize)
        for rel, _ in to_summarize:
            p = Path(repo_path) / rel
            if not p.exists():
                continue
            # NOTE(review): file is re-read here; content may have changed
            # since the summary above was generated.
            text = _read_text_file(p)
            store[rel] = {"sha": _sha1_text(text), "summary": new_summ.get(rel, "")}
        _save_summary_store(store_path, store)
    # pack result
    return {rel: {"summary": (store.get(rel) or {}).get("summary",""), "sha": (store.get(rel) or {}).get("sha","")}
            for rel in paths}

async def _grep_repo_paths(repo_url: str, branch: str, query: str, limit: int = 100) -> list[str]:
    """
    Simple fallback: linear scan over text content (text extensions only) -> unique paths.
    """
    repo_path = await _get_git_repo_async(repo_url, branch)
    root = Path(repo_path)
    qlow = (query or "").lower().strip()
    hits: list[str] = []
    for p in root.rglob("*"):
        if p.is_dir():
            continue
        if set(p.parts) & PROFILE_EXCLUDE_DIRS:
            continue
        if p.suffix.lower() in BINARY_SKIP:
            continue
        try:
            txt = _read_text_file(p)
        except Exception:
            continue
        if not txt:
            continue
        # Case-insensitive substring match on the whole file content.
        if qlow in txt.lower():
            hits.append(str(p.relative_to(root)))
            if len(hits) >= limit:
                break
    # uniq (preserve order)
    seen=set(); out=[]
    for h in hits:
        if h in seen: continue
        seen.add(h); out.append(h)
    return out

async def _meili_search_internal(query: str, *, repo_full: str | None, branch: str | None, limit: int = 50) -> list[dict]:
    """
    Returns list of hits: {"path": str, "chunk_index": int|None, "score": float, "highlights": str}
    """
    if not MEILI_ENABLED:
        return []
    body = {"q": query, "limit": int(limit)}
    filters = []
    if repo_full:
        filters.append(f'repo_full = "{repo_full}"')
    if branch:
        filters.append(f'branch = "{branch}"')
    if filters:
        body["filter"] = " AND ".join(filters)
    headers = {"Content-Type":"application/json"}
    if MEILI_API_KEY:
        headers["Authorization"] = f"Bearer {MEILI_API_KEY}"
        # Meili also often uses 'X-Meili-API-Key'; keep both for compatibility:
        headers["X-Meili-API-Key"] = MEILI_API_KEY
    url = f"{MEILI_URL}/indexes/{MEILI_INDEX}/search"
    # NOTE(review): the fallback AsyncClient is created per call and never
    # closed (connection leak); it should be cached or used via `async with`.
    client = app.state.HTTPX_PROXY if hasattr(app.state, "HTTPX_PROXY") else httpx.AsyncClient()
    try:
        r = await client.post(url, headers=headers, json=body, timeout=httpx.Timeout(30.0, connect=10.0))
        r.raise_for_status()
        data = r.json()
    except Exception:
        return []
    hits = []
    for h in data.get("hits", []):
        hits.append({
            "path": h.get("path"),
            "chunk_index": h.get("chunk_index"),
            # NOTE(review): `_matchesPosition` is a mapping in Meilisearch
            # responses; float() on it would raise outside any try/except.
            "score": float(h.get("_rankingScore", h.get("_matchesPosition") or 0) or 0),
            "highlights": h.get("_formatted") or {}
        })
    # De-dup by path, highest score first
    best: dict[str, dict] = {}
    for h in sorted(hits, key=lambda x: x["score"], reverse=True):
        p = h.get("path")
        if p and p not in best:
            best[p] = h
    return list(best.values())

async def _search_first_candidates(repo_url: str, branch: str, query: str, explicit_paths: list[str] | None = None, limit: int = 50) -> list[str]:
    """
    Preference order: explicit paths -> Meilisearch -> grep -> (last resort) file-level RAG
    """
    if explicit_paths:
        # validate that they exist in the repo
        repo_path = await _get_git_repo_async(repo_url, branch)
        out=[]
        for rel in explicit_paths:
            if (Path(repo_path)/rel).exists():
                out.append(rel)
        return out
    repo_full = _repo_owner_repo_from_url(repo_url)
    meili_hits = await _meili_search_internal(query, repo_full=repo_full, branch=branch, limit=limit)
    if meili_hits:
        return [h["path"] for h in meili_hits]
    # fallback grep
    return await _grep_repo_paths(repo_url, branch, query, limit=limit)


def _match_any(name: str, patterns: list[str]) -> bool:
    """True when `name` matches at least one fnmatch-style pattern."""
    return any(fnmatch.fnmatch(name, pat) for pat in patterns)

def _rag_index_repo_sync(
    *,
    repo_url: str,
    branch: str = "main",
    profile: str = "auto",
    include: str = "",
    exclude_dirs: str = "",
    chunk_chars: int = 3000,
    overlap: int = 400,
    collection_name: str = "code_docs",
    force: bool = False
) -> dict:
    # Index a git repo into Chroma; definition continues past this window.
    repo_path = get_git_repo(repo_url, branch)
    root = Path(repo_path)
    repo = git.Repo(repo_path)
    # HEAD sha
    try:
        head_sha = repo.head.commit.hexsha
    except Exception:
        head_sha = ""
    # Cache key
    cache = _load_repo_index_cache()
    repo_key = f"{os.path.basename(repo_url)}|{branch}|{_collection_versioned(collection_name)}|{_EMBEDDER.slug}"
    cached = cache.get(repo_key)

    # Skip re-indexing when HEAD is unchanged and not forced.
    if not force and cached and cached.get("head_sha") == head_sha:
        return {
            "status": "skipped",
            "reason": "head_unchanged",
+ "collection_effective": _collection_versioned(collection_name), + "detected_profile": profile if profile != "auto" else _detect_repo_profile(root), + "files_indexed": cached.get("files_indexed", 0), + "chunks_added": 0, + "note": f"Skip: HEAD {head_sha} al geïndexeerd (embedder={_EMBEDDER.slug})" + } + + # Profile → include patterns + if profile == "auto": + profile = _detect_repo_profile(root) + include_patterns = PROFILE_INCLUDES.get(profile, PROFILE_INCLUDES["generic"]) + if include.strip(): + include_patterns = [p.strip() for p in include.split(",") if p.strip()] + + exclude_set = set(PROFILE_EXCLUDE_DIRS) + if exclude_dirs.strip(): + exclude_set |= {d.strip() for d in exclude_dirs.split(",") if d.strip()} + + collection = _get_collection(collection_name) + + # --- Slim chunking toggles (env of via profile='smart') --- + use_smart_chunk = ( + (os.getenv("CHROMA_SMART_CHUNK","1").lower() not in ("0","false","no")) + or (str(profile).lower() == "smart") + ) + CH_TGT = int(os.getenv("CHUNK_TARGET_CHARS", "1800")) + CH_MAX = int(os.getenv("CHUNK_HARD_MAX", "2600")) + CH_MIN = int(os.getenv("CHUNK_MIN_CHARS", "800")) + + files_indexed = 0 + chunks_added = 0 + batch_documents: list[str] = [] + batch_metadatas: list[dict] = [] + batch_ids: list[str] = [] + BATCH_SIZE = 64 + + def flush(): + nonlocal chunks_added + if not batch_documents: + return + _collection_add(collection, batch_documents, batch_metadatas, batch_ids) + chunks_added += len(batch_documents) + batch_documents.clear(); batch_metadatas.clear(); batch_ids.clear() + docs_for_bm25: list[dict]=[] + docs_for_meili: list[dict]=[] + # tijdelijke symbol map voor deze run + # structuur: symbol_lower -> [{"path": str, "chunk_index": int}] + run_symbol_map: dict[str, list[dict]] = {} + + def _extract_symbol_hints(txt: str) -> list[str]: + hints = [] + for pat in [ + r"\bclass\s+([A-Za-z_][A-Za-z0-9_]*)\b", + r"\binterface\s+([A-Za-z_][A-Za-z0-9_]*)\b", + r"\btrait\s+([A-Za-z_][A-Za-z0-9_]*)\b", + 
r"\bdef\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(", + r"\bfunction\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(", + r"\bpublic\s+function\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(", + r"\bprotected\s+function\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(", + r"\bprivate\s+function\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(", + ]: + try: + hints += re.findall(pat, txt) + except Exception: + pass + # unique, cap + out = [] + for h in hints: + if h not in out: + out.append(h) + if len(out) >= 16: + break + return out + # precompute owner/repo once + repo_full_pre = _repo_owner_repo_from_url(repo_url) + deleted_paths: set[str] = set() + for p in root.rglob("*"): + if p.is_dir(): + continue + if set(p.parts) & exclude_set: + continue + rel = str(p.relative_to(root)) + if not _match_any(rel, include_patterns): + continue + + #rel = str(p.relative_to(root)) + text = _read_text_file(p) + if not text: + continue + + files_indexed += 1 + # Verwijder bestaande chunks voor dit bestand (voorkomt bloat/duplicaten) + if rel not in deleted_paths: + try: + collection.delete(where={ + "repo_full": repo_full_pre, "branch": branch, "path": rel + }) + except Exception: + pass + deleted_paths.add(rel) + #chunks = _chunk_text(text, chunk_chars=int(chunk_chars), overlap=int(overlap)) + + # Kies slim chunking (taal/structuur-aware) of vaste chunks met overlap + if use_smart_chunk: + chunks = smart_chunk_text( + text, rel, + target_chars=CH_TGT, hard_max=CH_MAX, min_chunk=CH_MIN + ) + else: + chunks = _chunk_text(text, chunk_chars=int(chunk_chars), overlap=int(overlap)) + + for idx, ch in enumerate(chunks): + chunk_hash = _sha1_text(ch) + doc_id = f"{os.path.basename(repo_url)}:{branch}:{rel}:{idx}:{chunk_hash}" + # Zet zowel korte als lange repo-id neer + def _owner_repo_from_url(u: str) -> str: + try: + from urllib.parse import urlparse + p = urlparse(u).path.strip("/").split("/") + if len(p) >= 2: + name = p[-1] + if name.endswith(".git"): name = name[:-4] + return f"{p[-2]}/{name}" + except Exception: + pass + return 
os.path.basename(u).removesuffix(".git") + meta = { + "source": "git-repo", + "repo": os.path.basename(repo_url).removesuffix(".git"), + "repo_full": _owner_repo_from_url(repo_url), + "branch": branch, + "path": rel, + "chunk_index": idx, + "profile": profile, + } + batch_documents.append(ch) + docs_for_bm25.append({"text": ch, "path": rel}) + batch_metadatas.append(meta) + batch_ids.append(doc_id) + if MEILI_ENABLED: + docs_for_meili.append({ + "id": hashlib.sha1(f"{doc_id}".encode("utf-8")).hexdigest(), + "repo": meta["repo"], + "repo_full": meta["repo_full"], + "branch": branch, + "path": rel, + "chunk_index": idx, + "text": ch[:4000], # houd het compact + "head_sha": head_sha, + "ts": int(time.time()) + }) + + # verzamel symbolen per chunk voor symbol-index + if os.getenv("RAG_SYMBOL_INDEX", "1").lower() not in ("0","false","no"): + for sym in _extract_symbol_hints(ch): + run_symbol_map.setdefault(sym.lower(), []).append({"path": rel, "chunk_index": idx}) + if len(batch_documents) >= BATCH_SIZE: + flush() + flush() + if BM25Okapi is not None and docs_for_bm25: + bm = BM25Okapi([_bm25_tok(d["text"]) for d in docs_for_bm25]) + repo_key = f"{os.path.basename(repo_url)}|{branch}|{_collection_versioned(collection_name)}" + _BM25_BY_REPO[repo_key] = (bm, docs_for_bm25) + + cache[repo_key] = {"head_sha": head_sha, "files_indexed": files_indexed, "embedder": _EMBEDDER.slug, "ts": int(time.time())} + _save_repo_index_cache(cache) + + # --- merge & persist symbol index --- + if os.getenv("RAG_SYMBOL_INDEX", "1").lower() not in ("0","false","no"): + col_eff = _collection_versioned(collection_name) + repo_name = os.path.basename(repo_url) + key = _symkey(col_eff, repo_name) + base = _SYMBOL_INDEX.get(key, {}) + # merge met dedupe en cap per symbol (om bloat te vermijden) + CAP = int(os.getenv("RAG_SYMBOL_CAP_PER_SYM", "200")) + for sym, entries in run_symbol_map.items(): + dst = base.get(sym, []) + seen = {(d["path"], d["chunk_index"]) for d in dst} + for e in entries: + 
tup = (e["path"], e["chunk_index"]) + if tup not in seen: + dst.append(e); seen.add(tup) + if len(dst) >= CAP: + break + base[sym] = dst + _SYMBOL_INDEX[key] = base + _save_symbol_index(_SYMBOL_INDEX) + + # === (optioneel) Meilisearch bulk upsert === + # NB: pas NA return plaatsen heeft geen effect; plaats dit blok vóór de return in jouw file. + if MEILI_ENABLED and docs_for_meili: + try: + headers = {"Content-Type":"application/json"} + if MEILI_API_KEY: + headers["Authorization"] = f"Bearer {MEILI_API_KEY}" + headers["X-Meili-API-Key"] = MEILI_API_KEY + url = f"{MEILI_URL}/indexes/{MEILI_INDEX}/documents?primaryKey=id" + # chunk upload om payloads klein te houden + CH = 500 + for i in range(0, len(docs_for_meili), CH): + chunk = docs_for_meili[i:i+CH] + requests.post(url, headers=headers, data=json.dumps(chunk), timeout=30) + except Exception as e: + logger.warning("Meili bulk upsert failed: %s", e) + + return { + "status": "ok", + "collection_effective": _collection_versioned(collection_name), + "detected_profile": profile, + "files_indexed": files_indexed, + "chunks_added": chunks_added, + "note": "Geïndexeerd met embedder=%s; smart_chunk=%s" % (_EMBEDDER.slug, bool(use_smart_chunk)) + } + + + + + +# Celery task +if celery_app: + @celery_app.task(name="task_rag_index_repo", bind=True, autoretry_for=(Exception,), retry_backoff=True, max_retries=5) + def task_rag_index_repo(self, args: dict): + return _rag_index_repo_sync(**args) + +# API endpoint: sync of async enqueue +@app.post("/rag/index-repo") +async def rag_index_repo( + repo_url: str = Form(...), + branch: str = Form("main"), + profile: str = Form("auto"), + include: str = Form(""), + exclude_dirs: str = Form(""), + chunk_chars: int = Form(3000), + overlap: int = Form(400), + collection_name: str = Form("code_docs"), + force: bool = Form(False), + async_enqueue: bool = Form(False) +): + args = dict( + repo_url=repo_url, branch=branch, profile=profile, + include=include, exclude_dirs=exclude_dirs, + 
chunk_chars=int(chunk_chars), overlap=int(overlap), + collection_name=collection_name, force=bool(force) + ) + if async_enqueue and celery_app: + task = task_rag_index_repo.delay(args) # type: ignore[name-defined] + return {"status": "enqueued", "task_id": task.id} + try: + return await run_in_threadpool(_rag_index_repo_sync, **args) + except Exception as e: + logger.exception("RAG index repo failed") + raise HTTPException(status_code=500, detail=str(e)) + +@app.get("/tasks/{task_id}") +def task_status(task_id: str): + if not celery_app: + raise HTTPException(status_code=400, detail="Celery niet geactiveerd") + try: + from celery.result import AsyncResult + res = AsyncResult(task_id, app=celery_app) + out = {"state": res.state} + if res.successful(): + out["result"] = res.result + elif res.failed(): + out["error"] = str(res.result) + return out + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +# ----------------------------------------------------------------------------- +# RAG Query (Chroma) +# ----------------------------------------------------------------------------- +_LARAVEL_ROUTE_FILES = {"routes/web.php", "routes/api.php"} + +def _laravel_guess_view_paths_from_text(txt: str) -> list[str]: + out = [] + for m in re.finditer(r"return\s+view\(\s*['\"]([^'\"]+)['\"]", txt): + v = m.group(1).replace(".", "/") + out.append(f"resources/views/{v}.blade.php") + return out + +def _laravel_pairs_from_route_text(txt: str) -> list[tuple[str, str|None]]: + """Return [(controller_path, method_or_None), ...]""" + pairs = [] + # Route::get('x', [XController::class, 'show']) + for m in re.finditer(r"Route::(?:get|post|put|patch|delete)\([^;]+?\[(.*?)\]\s*\)", txt, re.S): + arr = m.group(1) + mm = re.search(r"([A-Za-z0-9_\\]+)::class\s*,\s*'([A-Za-z0-9_]+)'", arr) + if mm: + ctrl = mm.group(1).split("\\")[-1] + meth = mm.group(2) + pairs.append((f"app/Http/Controllers/{ctrl}.php", meth)) + # Route::resource('users', 'UserController') + for m 
in re.finditer(r"Route::resource\(\s*'([^']+)'\s*,\s*'([^']+)'\s*\)", txt): + res, ctrl = m.group(1), m.group(2).split("\\")[-1] + pairs.append((f"app/Http/Controllers/{ctrl}.php", None)) + return pairs + +async def rag_query_api( + *, + query: str, + n_results: int = 8, + collection_name: str = "code_docs", + repo: Optional[str] = None, + path_contains: Optional[str] = None, + profile: Optional[str] = None +) -> dict: + col = _get_collection(collection_name) + q_emb = _EMBEDDER.embed_query(query) + where = {} + if repo: + # Accepteer zowel 'repo' (basename) als 'repo_full' (owner/repo) + base = repo.rsplit("/", 1)[-1] + where = {"$or": [ + {"repo": {"$eq": base}}, + {"repo_full": {"$eq": repo}} + ]} + if profile: where["profile"] = {"$eq": profile} + + # ---- symbol hit set (repo-scoped) ---- + sym_hit_keys: set[str] = set() + sym_hit_weight = float(os.getenv("RAG_SYMBOL_HIT_W", "0.12")) + if sym_hit_weight > 0.0 and repo: + try: + # index werd opgeslagen onder basename(repo_url); hier dus normaliseren + repo_base = repo.rsplit("/", 1)[-1] + idx_key = _symkey(_collection_versioned(collection_name), repo_base) + symidx = _SYMBOL_INDEX.get(idx_key, {}) + # query-termen + terms = set(re.findall(r"[A-Za-z_][A-Za-z0-9_]*", query)) + # ook 'Foo\\BarController@show' splitsen + for m in re.finditer(r"([A-Za-z_][A-Za-z0-9_]*)(?:@|::|->)([A-Za-z_][A-Za-z0-9_]*)", query): + terms.add(m.group(1)); terms.add(m.group(2)) + for t in {x.lower() for x in terms}: + for e in symidx.get(t, []): + sym_hit_keys.add(f"{repo}::{e['path']}::{e['chunk_index']}") + except Exception: + pass + + # --- padhints uit query halen (expliciet genoemde bestanden/dirs) --- + # Herken ook gequote varianten en slimme quotes. 
+ path_hints: set[str] = set() + PH_PATTERNS = [ + r"[\"“”'](resources\/[A-Za-z0-9_\/\.-]+\.blade\.php)[\"”']", + r"(resources\/[A-Za-z0-9_\/\.-]+\.blade\.php)", + r"[\"“”'](app\/[A-Za-z0-9_\/\.-]+\.php)[\"”']", + r"(app\/[A-Za-z0-9_\/\.-]+\.php)", + r"\b([A-Za-z0-9_\/-]+\.blade\.php)\b", + r"\b([A-Za-z0-9_\/-]+\.php)\b", + ] + for pat in PH_PATTERNS: + for m in re.finditer(pat, query): + path_hints.add(m.group(1).strip()) + + res = col.query( + query_embeddings=[q_emb], + n_results=max(n_results, 30), # iets ruimer voor rerank + where=where or None, + include=["metadatas","documents","distances"] + ) + logger.info("RAG raw hits: %d (repo=%s)", len((res.get("documents") or [[]])[0]), repo or "-") + docs = (res.get("documents") or [[]])[0] + metas = (res.get("metadatas") or [[]])[0] + dists = (res.get("distances") or [[]])[0] + + # Filter path_contains en bouw kandidaten + cands = [] + # Voor frase-boost: haal een simpele “candidate phrase” uit query (>=2 woorden) + def _candidate_phrase(q: str) -> str | None: + q = q.strip().strip('"').strip("'") + words = re.findall(r"[A-Za-zÀ-ÿ0-9_+-]{2,}", q) + if len(words) >= 2: + return " ".join(words[:6]).lower() + return None + phrase = _candidate_phrase(query) + + for doc, meta, dist in zip(docs, metas, dists): + if path_contains and meta and path_contains.lower() not in (meta.get("path","").lower()): + continue + key = f"{(meta or {}).get('repo','')}::{(meta or {}).get('path','')}::{(meta or {}).get('chunk_index','')}" + emb_sim = 1.0 / (1.0 + float(dist or 0.0)) + # symbol-hit boost vóór verdere combinaties + if key in sym_hit_keys: + emb_sim = min(1.0, emb_sim + sym_hit_weight) + c = { + "document": doc or "", + "metadata": meta or {}, + "emb_sim": emb_sim, # afstand -> ~similarity (+symbol boost) + "distance": float(dist or 0.0) # bewaar ook de ruwe distance voor hybride retrieval + } + # --- extra boost: expliciet genoemde paths in query --- + p = (meta or {}).get("path","") + if path_hints: + for hint in path_hints: 
+ if hint and (hint == p or hint.split("/")[-1] == p.split("/")[-1] or hint in p): + c["emb_sim"] = min(1.0, c["emb_sim"] + 0.12) + break + # --- extra boost: exacte frase in document (case-insensitive) --- + if phrase and phrase in (doc or "").lower(): + c["emb_sim"] = min(1.0, c["emb_sim"] + 0.10) + cands.append(c) + # --- Als er expliciete path_hints zijn, probeer ze "hard" toe te voegen bovenaan --- + # Dit helpt o.a. bij 'resources/views/log/edit.blade.php' exact noemen. + if path_hints and repo: + try: + # Probeer lokale repo te openen (snelle clone/update) en lees de bestanden in. + repo_base = repo.rsplit("/", 1)[-1] + # Zet een conservatieve guess in elkaar voor de URL: gebruiker werkt met Gitea + # en roept elders get_git_repo() aan — wij willen geen netwerk doen hier. + # We gebruiken in plaats daarvan de collectie-metadata om te bepalen + # of het pad al aanwezig is (via metas). + known_paths = { (m or {}).get("path","") for m in metas } + injects = [] + for hint in list(path_hints): + # Match ook alleen op bestandsnaam als volledige path niet gevonden is + matches = [p for p in known_paths if p == hint or p.endswith("/"+hint)] + for pth in matches: + # Zoek het eerste document dat bij dit path hoort en neem zijn tekst + for doc, meta in zip(docs, metas): + if (meta or {}).get("path","") == pth: + injects.append({"document": doc, "metadata": meta, "emb_sim": 1.0, "score": 1.0}) + break + if injects: + # Zet injects vooraan, zonder dubbelingen + seen_keys = set() + new_cands = [] + for j in injects + cands: + key = f"{(j.get('metadata') or {}).get('path','')}::{(j.get('metadata') or {}).get('chunk_index','')}" + if key in seen_keys: + continue + seen_keys.add(key) + new_cands.append(j) + cands = new_cands + except Exception: + pass + + if not cands: + fallback = [] + if BM25Okapi is not None and repo: + colname = _collection_versioned(collection_name) + repo_base = repo.rsplit("/", 1)[-1] + for k,(bm,docs) in _BM25_BY_REPO.items(): + # cache-key formaat: 
"||" + if (k.startswith(f"{repo_base}|") or k.startswith(f"{repo}|")) and k.endswith(colname): + scores = bm.get_scores(_bm25_tok(query)) + ranked = sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)[:n_results] + for d,s in ranked: + fallback.append({ + "document": d["text"], + "metadata": {"repo": repo_base, "path": d["path"]}, + "score": float(s) + }) + break + if fallback: + return {"count": len(fallback), "results": fallback} + return {"count": 0, "results": []} + + + # BM25/Jaccard op kandidaten + def _tok(s: str) -> list[str]: + return re.findall(r"[A-Za-z0-9_]+", s.lower()) + bm25_scores = [] + if BM25Okapi: + bm = BM25Okapi([_tok(c["document"]) for c in cands]) + bm25_scores = bm.get_scores(_tok(query)) + else: + qset = set(_tok(query)) + for c in cands: + tset = set(_tok(c["document"])) + inter = len(qset & tset); uni = len(qset | tset) or 1 + bm25_scores.append(inter / uni) + + # --- Laravel cross-file boosts --- + # We hebben per kandidaat: {"document": ..., "metadata": {"repo": "...", "path": "..."}} + # Bouw een pad->index mapping + idx_by_path = {} + for i,c in enumerate(cands): + p = (c.get("metadata") or {}).get("path","") + idx_by_path.setdefault(p, []).append(i) + + # 1) Route files -> boost controllers/views + for c in cands: + meta = c.get("metadata") or {} + path = meta.get("path","") + if path not in _LARAVEL_ROUTE_FILES: + continue + txt = c.get("document","") + # controllers + for ctrl_path, _meth in _laravel_pairs_from_route_text(txt): + if ctrl_path in idx_by_path: + for j in idx_by_path[ctrl_path]: + cands[j]["emb_sim"] = min(1.0, cands[j]["emb_sim"] + 0.05) # kleine push + # views (heel grof: lookups in route-tekst komen zelden direct voor, skip hier) + + # 2) Controllers -> boost views die ze renderen + for c in cands: + meta = c.get("metadata") or {} + path = meta.get("path","") + if not path.startswith("app/Http/Controllers/") or not path.endswith(".php"): + continue + txt = c.get("document","") + for vpath in 
_laravel_guess_view_paths_from_text(txt): + if vpath in idx_by_path: + for j in idx_by_path[vpath]: + cands[j]["emb_sim"] = min(1.0, cands[j]["emb_sim"] + 0.05) + + # symbols-boost: kleine bonus als query-termen in symbols of bestandsnaam voorkomen + q_terms = set(re.findall(r"[A-Za-z0-9_]+", query.lower())) + for c in cands: + meta = (c.get("metadata") or {}) + # --- symbols (bugfix: syms_raw was undefined) --- + syms_raw = meta.get("symbols") + if isinstance(syms_raw, str): + syms = [s.strip() for s in syms_raw.split(",") if s.strip()] + elif isinstance(syms_raw, list): + syms = syms_raw + else: + syms = [] + if syms and (q_terms & {s.lower() for s in syms}): + c["emb_sim"] = min(1.0, c["emb_sim"] + 0.04) + + # --- filename exact-ish match --- + fname_terms = set(re.findall(r"[A-Za-z0-9_]+", meta.get("path","").split("/")[-1].lower())) + if fname_terms and (q_terms & fname_terms): + c["emb_sim"] = min(1.0, c["emb_sim"] + 0.02) + # lichte bonus als path exact één van de hints is + if path_hints and meta.get("path") in path_hints: + c["emb_sim"] = min(1.0, c["emb_sim"] + 0.06) + + + + # normaliseer + if len(bm25_scores) > 0: + mn, mx = min(bm25_scores), max(bm25_scores) + bm25_norm = [(s - mn) / (mx - mn) if mx > mn else 0.0 for s in bm25_scores] + else: + bm25_norm = [0.0] * len(cands) + + alpha = float(os.getenv("RAG_EMB_WEIGHT", "0.6")) + for c, b in zip(cands, bm25_norm): + c["score"] = alpha * c["emb_sim"] + (1.0 - alpha) * b + + ranked = sorted(cands, key=lambda x: x["score"], reverse=True)[:n_results] + ranked_full = sorted(cands, key=lambda x: x["score"], reverse=True) + if RAG_LLM_RERANK: + topK = ranked_full[:max(10, n_results)] + # bouw prompt + prompt = "Rerank the following code passages for the query. 
Return ONLY a JSON array of indices (0-based) in best-to-worst order.\n" + prompt += f"Query: {query}\n" + for i, r in enumerate(topK): + path = (r.get("metadata") or {}).get("path","") + snippet = (r.get("document") or "")[:600] + prompt += f"\n# {i} — {path}\n{snippet}\n" + resp = await llm_call_openai_compat( + [{"role":"system","content":"You are precise and return only valid JSON."}, + {"role":"user","content": prompt+"\n\nOnly JSON array."}], + stream=False, temperature=0.0, top_p=1.0, max_tokens=256 + ) + try: + order = json.loads((resp.get("choices",[{}])[0].get("message",{}) or {}).get("content","[]")) + reranked = [topK[i] for i in order if isinstance(i,int) and 0 <= i < len(topK)] + ranked = reranked[:n_results] if reranked else ranked_full[:n_results] + except Exception: + ranked = ranked_full[:n_results] + else: + ranked = ranked_full[:n_results] + + return { + "count": len(ranked), + "results": [{ + "document": r["document"], + "metadata": r["metadata"], + "file": (r["metadata"] or {}).get("path", ""), # <- expliciete bestandsnaam + "score": round(float(r["score"]), 4), + # houd distance beschikbaar voor hybrid_retrieve (embed-component) + "distance": float(r.get("distance", 0.0)) + } for r in ranked] + } + + +@app.post("/rag/query") +async def rag_query_endpoint( + query: str = Form(...), + n_results: int = Form(8), + collection_name: str = Form("code_docs"), + repo: str = Form(""), + path_contains: str = Form(""), + profile: str = Form("") +): + data = await rag_query_api( + query=query, n_results=n_results, collection_name=collection_name, + repo=(repo or None), path_contains=(path_contains or None), profile=(profile or None) + ) + return JSONResponse(data) + +# ----------------------------------------------------------------------------- +# Repo-agent endpoints +# ----------------------------------------------------------------------------- +@app.post("/agent/repo") +async def agent_repo(messages: ChatRequest, request: Request): + """ + 
    Conversational interface to the repo agent (stateful per session).
    """
    # Convert to list[dict] — the shape handle_repo_agent expects.
    msgs = [{"role": m.role, "content": m.content} for m in messages.messages]
    text = await handle_repo_agent(msgs, request)
    return PlainTextResponse(text)

@app.post("/repo/qa")
async def repo_qa(req: RepoQARequest):
    """
    One-shot question answering over a repo.
    """
    ans = await repo_qa_answer(req.repo_hint, req.question, branch=req.branch, n_ctx=req.n_ctx)
    return PlainTextResponse(ans)

# -----------------------------------------------------------------------------
# Injections for agent_repo
# -----------------------------------------------------------------------------
async def _rag_index_repo_internal(*, repo_url: str, branch: str, profile: str,
                                   include: str, exclude_dirs: str,
                                   chunk_chars: int, overlap: int,
                                   collection_name: str) -> dict:
    """Async wrapper around _rag_index_repo_sync for injection into the agent."""
    # Offload the heavy IO/CPU work to the threadpool so the event loop stays free.
    return await run_in_threadpool(
        _rag_index_repo_sync,
        repo_url=repo_url, branch=branch, profile=profile,
        include=include, exclude_dirs=exclude_dirs,
        chunk_chars=chunk_chars, overlap=overlap,
        collection_name=collection_name, force=False
    )

async def _rag_query_internal(*, query: str, n_results: int,
                              collection_name: str, repo=None, path_contains=None, profile=None) -> dict:
    """Thin pass-through to rag_query_api, matching the agent's injection signature."""
    return await rag_query_api(
        query=query, n_results=n_results,
        collection_name=collection_name,
        repo=repo, path_contains=path_contains, profile=profile
    )

def _read_text_file_wrapper(p: Path | str) -> str:
    """Accept either a str or a Path and delegate to _read_text_file."""
    return _read_text_file(Path(p) if not isinstance(p, Path) else p)

async def _get_git_repo_async(repo_url: str, branch: str = "main") -> str:
    """Clone/update the repo off the event loop; returns the local checkout path."""
    # gitpython does subprocess/IO work -> always run it in the threadpool.
    return await run_in_threadpool(get_git_repo, repo_url, branch)

# Register the injections with the agent module.
initialize_agent(
    app=app,
    get_git_repo_fn=_get_git_repo_async,
    rag_index_repo_internal_fn=_rag_index_repo_internal,
    rag_query_internal_fn=_rag_query_internal,
    llm_call_fn=llm_call_openai_compat,
    extract_code_block_fn=extract_code_block,
    read_text_file_fn=_read_text_file_wrapper,
    client_ip_fn=_client_ip,
    profile_exclude_dirs=PROFILE_EXCLUDE_DIRS,
    chroma_get_collection_fn=_get_collection,
    embed_query_fn=_EMBEDDER.embed_query,
    embed_documents_fn=_EMBEDDER.embed_documents,  # for catalog embeddings
    # search-first + summaries + meili
    search_candidates_fn=_search_first_candidates,
    repo_summary_get_fn=_repo_summary_get_internal,
    meili_search_fn=_meili_search_internal,
)

# -----------------------------------------------------------------------------
# Health
# -----------------------------------------------------------------------------
@app.get("/healthz")
def health():
    """Liveness/status snapshot: embedder, chroma mode, queue depth, upstreams."""
    # Reads the semaphore's private _value (CPython detail) — best-effort only.
    sem_val = getattr(app.state.LLM_SEM, "_value", 0)
    ups = [{"url": u.url, "active": u.active, "ok": u.ok} for u in _UPS]
    return {
        "ok": True,
        "embedder": _EMBEDDER.slug,
        "chroma_mode": CHROMA_MODE,
        "queue_len": len(app.state.LLM_QUEUE),
        "permits_free": sem_val,
        "upstreams": ups,
    }

@app.get("/metrics")
def metrics():
    """Lightweight metrics: queue length, semaphore value, upstream states."""
    ups = [{"url": u.url, "active": u.active, "ok": u.ok} for u in _UPS]
    return {
        "queue_len": len(app.state.LLM_QUEUE),
        "sem_value": getattr(app.state.LLM_SEM, "_value", None),
        "upstreams": ups
    }