3470 lines
143 KiB
Python
3470 lines
143 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
FastAPI bridge met:
|
||
- RAG (Chroma) index & query (client-side embeddings, http óf local)
|
||
- Repo cloning/updating (git)
|
||
- LLM-bridge (OpenAI-compatible /v1/chat/completions)
|
||
- Repo agent endpoints (injectie van helpers in agent_repo.py)
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
import contextlib
|
||
from contextlib import contextmanager
|
||
import os, re, json, time, uuid, hashlib, logging, asyncio, fnmatch, threading
|
||
from dataclasses import dataclass
|
||
from typing import List, Dict, Optional, Union, Any
|
||
from pathlib import Path
|
||
from io import BytesIO
|
||
|
||
import requests
|
||
import httpx
|
||
import chromadb
|
||
import git
|
||
import base64
|
||
|
||
from fastapi import FastAPI, APIRouter, UploadFile, File, Form, Request, HTTPException, Body
|
||
from fastapi.responses import JSONResponse, StreamingResponse, PlainTextResponse
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.openapi.utils import get_openapi
|
||
from fastapi.routing import APIRoute
|
||
from starlette.concurrency import run_in_threadpool
|
||
from pydantic import BaseModel
|
||
|
||
|
||
|
||
# Optionele libs voor tekst-extractie
|
||
try:
|
||
import PyPDF2
|
||
except Exception:
|
||
PyPDF2 = None
|
||
try:
|
||
import docx # python-docx
|
||
except Exception:
|
||
docx = None
|
||
try:
|
||
import pandas as pd
|
||
except Exception:
|
||
pd = None
|
||
try:
|
||
from pptx import Presentation
|
||
except Exception:
|
||
Presentation = None
|
||
|
||
# (optioneel) BM25 voor hybride retrieval
|
||
try:
|
||
from rank_bm25 import BM25Okapi
|
||
except Exception:
|
||
BM25Okapi = None
|
||
|
||
# --- BM25 fallback registry (per repo) ---
|
||
_BM25_BY_REPO: dict[str, tuple[object, list[dict]]] = {} # repo_key -> (bm25, docs)
|
||
def _bm25_tok(s: str) -> list[str]:
|
||
return re.findall(r"[A-Za-z0-9_]+", s.lower())
|
||
|
||
|
||
# --- Extra optional libs voor audio/vision/images ---
|
||
try:
|
||
import cairosvg
|
||
except Exception:
|
||
cairosvg = None
|
||
|
||
import tempfile, subprocess # voor audio
|
||
|
||
# STT (Whisper-compatible via faster-whisper) — optioneel
|
||
_STT_MODEL = None
|
||
STT_MODEL_NAME = os.getenv("STT_MODEL", "small")
|
||
STT_DEVICE = os.getenv("STT_DEVICE", "auto") # "auto" | "cuda" | "cpu"
|
||
|
||
# TTS (piper) — optioneel
|
||
PIPER_BIN = os.getenv("PIPER_BIN", "/usr/bin/piper")
|
||
PIPER_VOICE = os.getenv("PIPER_VOICE", "")
|
||
|
||
RAG_LLM_RERANK = os.getenv("RAG_LLM_RERANK", "0").lower() in ("1","true","yes")
|
||
|
||
# Lazy LLM import & autodetect
|
||
_openai = None
|
||
_mistral = None
|
||
|
||
|
||
|
||
# Token utility / conversatie window (bestaat in je project)
|
||
try:
    from windowing_utils import (
        derive_thread_id, SUMMARY_STORE, ConversationWindow,
        approx_token_count, count_message_tokens
    )
except Exception:
    # Standalone fallbacks so this file keeps working without windowing_utils.
    SUMMARY_STORE = {}

    def approx_token_count(s: str) -> int:
        """Rough token estimate: ~4 characters per token, never below 1."""
        return max(1, len(s) // 4)

    def count_message_tokens(messages: List[dict]) -> int:
        """Sum of per-message token estimates over the content fields."""
        return sum(approx_token_count(m.get("content", "")) for m in messages)

    def derive_thread_id(messages: List[dict]) -> str:
        """Stable conversation id: first message content plus the role sequence, hashed."""
        payload = (messages[0].get("content", "") if messages else "") + "|".join(
            m.get("role", "") for m in messages
        )
        return hashlib.sha256(payload.encode("utf-8", errors="ignore")).hexdigest()[:16]

    class ConversationWindow:
        """No-op stand-in; accepts and ignores any constructor arguments."""
        def __init__(self, *a, **k):
            pass
|
||
|
||
# Queue helper (optioneel aanwezig in je project)
|
||
try:
|
||
from queue_helper import QueueManager, start_position_notifier, _Job
|
||
from queue_helper import USER_MAX_QUEUE, AGENT_MAX_QUEUE, UPDATE_INTERVAL, WORKER_TIMEOUT
|
||
except Exception:
|
||
QueueManager = None
|
||
|
||
# Smart_rag import
|
||
from smart_rag import enrich_intent, expand_queries, hybrid_retrieve, assemble_context
|
||
|
||
# Repo-agent (wij injecteren functies hieronder)
|
||
from agent_repo import initialize_agent, handle_repo_agent, repo_qa_answer, rag_query_internal_fn
|
||
|
||
try:
    from agent_repo import smart_chunk_text  # prefer the project implementation
except Exception:
    def smart_chunk_text(text: str, path_hint: str, target_chars: int = 1800,
                         hard_max: int = 2600, min_chunk: int = 800):
        """Simple fallback chunker: fixed-size windows with a 200-char overlap."""
        pieces = []
        stride = max(1, target_chars - 200)
        pos, total = 0, len(text)
        while pos < total:
            pieces.append(text[pos:pos + target_chars])
            pos += stride
        return pieces
|
||
|
||
|
||
def _build_tools_system_prompt(tools: list) -> str:
|
||
lines = [
|
||
"You can call functions. When a function is needed, answer ONLY with a JSON object:",
|
||
'{"tool_calls":[{"name":"<function_name>","arguments":{...}}]}',
|
||
"No prose. Arguments MUST be valid JSON for the function schema."
|
||
]
|
||
for t in tools:
|
||
fn = t.get("function", {})
|
||
lines.append(f"- {fn.get('name')}: {fn.get('description','')}")
|
||
return "\n".join(lines)
|
||
|
||
def _extract_tool_calls_from_text(txt: str):
|
||
# tolerant: pak eerste JSON object (ook als het in ```json staat)
|
||
m = re.search(r"\{[\s\S]*\}", txt or "")
|
||
if not m:
|
||
return []
|
||
try:
|
||
obj = json.loads(m.group(0))
|
||
except Exception:
|
||
return []
|
||
tc = obj.get("tool_calls") or []
|
||
out = []
|
||
for c in tc:
|
||
name = (c or {}).get("name")
|
||
args = (c or {}).get("arguments") or {}
|
||
if isinstance(args, str):
|
||
try:
|
||
args = json.loads(args)
|
||
except Exception:
|
||
args = {}
|
||
if name:
|
||
out.append({
|
||
"id": f"call_{uuid.uuid4().hex[:8]}",
|
||
"type": "function",
|
||
"function": {
|
||
"name": name,
|
||
"arguments": json.dumps(args, ensure_ascii=False)
|
||
}
|
||
})
|
||
return out
|
||
|
||
# -----------------------------------------------------------------------------
|
||
# App & logging
|
||
# -----------------------------------------------------------------------------
|
||
logging.basicConfig(level=logging.INFO)
|
||
logger = logging.getLogger("app")
|
||
|
||
def _unique_id(route: APIRoute):
|
||
# unieke operation_id op basis van naam, pad en method
|
||
method = list(route.methods)[0].lower() if route.methods else "get"
|
||
return f"{route.name}_{route.path.replace('/', '_')}_{method}"
|
||
|
||
app = FastAPI(title="Mistral Bridge API",generate_unique_id_function=_unique_id)
|
||
app.add_middleware(
|
||
CORSMiddleware,
|
||
allow_credentials=True,
|
||
allow_origins=["*"], # dev only
|
||
allow_methods=["*"],
|
||
allow_headers=["*"],
|
||
)
|
||
|
||
@app.on_event("startup")
|
||
async def _startup():
|
||
# Zorg dat lokale hosts nooit via een proxy gaan
|
||
os.environ.setdefault(
|
||
"NO_PROXY",
|
||
"localhost,127.0.0.1,::1,host.docker.internal"
|
||
)
|
||
app.state.HTTPX = httpx.AsyncClient(
|
||
limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
|
||
timeout=httpx.Timeout(LLM_READ_TIMEOUT, connect=LLM_CONNECT_TIMEOUT),
|
||
trust_env=False, # belangrijk: negeer env-proxy’s voor LLM
|
||
headers={"Connection": "keep-alive"} # houd verbindingen warm
|
||
)
|
||
app.state.HTTPX_PROXY = httpx.AsyncClient(
|
||
limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
|
||
timeout=httpx.Timeout(LLM_READ_TIMEOUT, connect=LLM_CONNECT_TIMEOUT),
|
||
trust_env=True, # belangrijk: negeer env-proxy’s voor LLM
|
||
headers={"Connection": "keep-alive"} # houd verbindingen warm
|
||
)
|
||
|
||
@app.on_event("shutdown")
|
||
async def _shutdown():
|
||
try:
|
||
await app.state.HTTPX.aclose()
|
||
except Exception:
|
||
pass
|
||
try:
|
||
await app.state.HTTPX_PROXY.aclose()
|
||
except Exception:
|
||
pass
|
||
|
||
|
||
# --- Globale LLM-concurrency & wachtrij (serieel by default) ---
|
||
LLM_MAX_CONCURRENCY = int(os.getenv("LLM_MAX_CONCURRENCY", os.getenv("LLM_CONCURRENCY", "1")))
|
||
|
||
if not hasattr(app.state, "LLM_SEM"):
|
||
import asyncio
|
||
app.state.LLM_SEM = asyncio.Semaphore(max(1, LLM_MAX_CONCURRENCY))
|
||
if not hasattr(app.state, "LLM_QUEUE"):
|
||
from collections import deque
|
||
app.state.LLM_QUEUE = deque()
|
||
|
||
|
||
@app.middleware("http")
|
||
async def log_requests(request: Request, call_next):
|
||
logger.info("➡️ %s %s", request.method, request.url.path)
|
||
response = await call_next(request)
|
||
logger.info("⬅️ %s", response.status_code)
|
||
return response
|
||
|
||
# -----------------------------------------------------------------------------
|
||
# Config
|
||
# -----------------------------------------------------------------------------
|
||
MISTRAL_MODE = os.getenv("MISTRAL_MODE", "v1").lower()
|
||
LLM_URL = os.getenv("LLM_URL", "http://localhost:8000/v1/chat/completions").strip()
|
||
RAW_URL = os.getenv("MISTRAL_URL_RAW", "http://host.docker.internal:8000/completion").strip()
|
||
LLM_CONNECT_TIMEOUT = float(os.getenv("LLM_CONNECT_TIMEOUT", "10"))
|
||
LLM_READ_TIMEOUT = float(os.getenv("LLM_READ_TIMEOUT", "1200"))
|
||
|
||
_UPSTREAM_URLS = [u.strip() for u in os.getenv("LLM_UPSTREAMS","").split(",") if u.strip()]
|
||
|
||
# ==== Meilisearch (optioneel) ====
|
||
MEILI_URL = os.getenv("MEILI_URL", "http://localhost:7700").rstrip("/")
|
||
MEILI_API_KEY = os.getenv("MEILI_API_KEY", "0xipOmfgi_zMgdFplSdv7L8mlx0RPMQCNxVTNJc54lQ")
|
||
MEILI_INDEX = os.getenv("MEILI_INDEX", "code_chunks")
|
||
MEILI_ENABLED = bool(MEILI_URL)
|
||
|
||
# Repo summaries (cache on demand)
|
||
_SUMMARY_DIR = os.path.join("/rag_db", "repo_summaries")
|
||
os.makedirs(_SUMMARY_DIR, exist_ok=True)
|
||
|
||
@dataclass
class _Upstream:
    # One load-balanced LLM backend endpoint.
    url: str           # base URL of the upstream
    active: int = 0    # number of in-flight requests
    ok: bool = True    # health flag; cleared on error, restored after a cool-off

# Registry of configured upstreams (empty list when LLM_UPSTREAMS is unset).
_UPS = [_Upstream(u) for u in _UPSTREAM_URLS] if _UPSTREAM_URLS else []
|
||
|
||
def _pick_upstream_sticky(key: str) -> _Upstream | None:
    """Deterministically pick an upstream by hashing *key* (sticky routing).

    Falls back to the least-loaded healthy upstream when the sticky choice
    is unhealthy. Increments the chosen upstream's in-flight counter.
    Returns None when no upstreams are configured.
    """
    if not _UPS:
        return None
    try:
        digest = int(hashlib.sha1(key.encode("utf-8")).hexdigest(), 16)
    except Exception:
        digest = 0
    sticky = _UPS[digest % len(_UPS)]
    if sticky.ok:
        sticky.active += 1
        return sticky
    healthy = [u for u in _UPS if u.ok]
    chosen = min(healthy or _UPS, key=lambda u: u.active)
    chosen.active += 1
    return chosen
|
||
|
||
|
||
def _pick_upstream() -> _Upstream | None:
    """Pick the least-loaded healthy upstream (any upstream if none is healthy).

    Increments the chosen upstream's in-flight counter; None when unconfigured.
    """
    if not _UPS:
        return None
    healthy = [u for u in _UPS if u.ok]
    chosen = min(healthy or _UPS, key=lambda u: u.active)
    chosen.active += 1
    return chosen
|
||
|
||
def _release_upstream(u: _Upstream | None, bad: bool = False):
    # Decrement the in-flight counter for *u*; when *bad*, mark it unhealthy
    # and schedule automatic recovery. Safe to call with None.
    if not u:
        return
    u.active = max(0, u.active - 1)
    if bad:
        u.ok = False
        # simple cool-off: marked healthy again after 10 s
        # (original comment said "2 sec" but the delay passed below is 10.0)
        asyncio.create_task(_mark_ok_later(u, 10.0))
|
||
|
||
async def _mark_ok_later(u: _Upstream, delay: float):
|
||
await asyncio.sleep(delay)
|
||
u.ok = True
|
||
|
||
SYSTEM_PROMPT = (
|
||
"Je bent een expert programmeringsassistent. Je geeft accurate, specifieke antwoorden zonder hallucinaties.\n"
|
||
"Voor code base analyse:\n"
|
||
"1. Geef eerst een samenvatting van de functionaliteit\n"
|
||
"2. Identificeer mogelijke problemen of verbeterpunten\n"
|
||
"3. Geef concrete aanbevelingen voor correcties\n"
|
||
"4. Voor verbeterde versies: geef eerst de toelichting, dan alleen het codeblok"
|
||
)
|
||
|
||
# ==== Chroma configuratie (local/http) ====
|
||
CHROMA_MODE = os.getenv("CHROMA_MODE", "local").lower() # "local" | "http"
|
||
CHROMA_PATH = os.getenv("CHROMA_PATH", "/rag_db")
|
||
CHROMA_HOST = os.getenv("CHROMA_HOST", "chroma")
|
||
CHROMA_PORT = int(os.getenv("CHROMA_PORT", "8005"))
|
||
|
||
# ==== Celery (optioneel, voor async indexing) ====
|
||
CELERY_ENABLED = os.getenv("CELERY_ENABLED", "0").lower() in ("1", "true", "yes")
|
||
celery_app = None
|
||
if CELERY_ENABLED:
|
||
try:
|
||
from celery import Celery
|
||
celery_app = Celery(
|
||
"agent_tasks",
|
||
broker=os.getenv("AMQP_URL", "amqp://guest:guest@rabbitmq:5672//"),
|
||
backend=os.getenv("CELERY_BACKEND", "redis://redis:6379/0"),
|
||
)
|
||
celery_app.conf.update(
|
||
task_acks_late=True,
|
||
worker_prefetch_multiplier=1,
|
||
task_time_limit=60*60,
|
||
)
|
||
except Exception as e:
|
||
logger.warning("Celery init failed (fallback naar sync): %s", e)
|
||
celery_app = None
|
||
|
||
# Git / repos
|
||
GITEA_URL = os.environ.get("GITEA_URL", "http://localhost:3080").rstrip("/")
|
||
REPO_PATH = os.environ.get("REPO_PATH", "/tmp/repos")
|
||
|
||
# -----------------------------------------------------------------------------
|
||
# Models (Pydantic)
|
||
# -----------------------------------------------------------------------------
|
||
class ChatMessage(BaseModel):
    """One chat turn in OpenAI message format."""
    role: str     # presumably "system" | "user" | "assistant" — not validated here
    content: str  # the message text
|
||
|
||
class ChatRequest(BaseModel):
    """Request body carrying a full chat message history."""
    messages: List[ChatMessage]
|
||
|
||
class RepoQARequest(BaseModel):
    """Request body for repository question-answering."""
    repo_hint: str        # presumably a repo URL or name fragment used to locate the repo — verify against caller
    question: str         # the user's question about the repo
    branch: str = "main"  # git branch to query against
    n_ctx: int = 8        # number of context chunks to retrieve
|
||
|
||
# -----------------------------------------------------------------------------
|
||
# Embeddings (SentenceTransformers / ONNX / Default)
|
||
# -----------------------------------------------------------------------------
|
||
from chromadb.api.types import EmbeddingFunction, Documents, Embeddings
|
||
try:
|
||
from sentence_transformers import SentenceTransformer
|
||
except Exception:
|
||
SentenceTransformer = None
|
||
|
||
@dataclass
|
||
class _Embedder:
|
||
slug: str
|
||
family: str
|
||
model: object
|
||
device: str = "cpu"
|
||
|
||
def _encode(self, texts: list[str]) -> list[list[float]]:
|
||
if hasattr(self.model, "encode"):
|
||
bs = int(os.getenv("RAG_EMBED_BATCH_SIZE", "64"))
|
||
# geen progressbar; grotere batches voor throughput
|
||
return self.model.encode(
|
||
texts,
|
||
normalize_embeddings=True,
|
||
batch_size=bs,
|
||
show_progress_bar=False
|
||
).tolist()
|
||
return self.model(texts)
|
||
|
||
def embed_documents(self, docs: list[str]) -> list[list[float]]:
|
||
if self.family == "e5":
|
||
docs = [f"passage: {t}" for t in docs]
|
||
return self._encode(docs)
|
||
|
||
def embed_query(self, q: str) -> list[float]:
|
||
if self.family == "e5":
|
||
q = f"query: {q}"
|
||
return self._encode([q])[0]
|
||
|
||
def _build_embedder() -> _Embedder:
    """Construct the process-wide embedder.

    Preference order:
      1. A SentenceTransformer model selected by RAG_EMBEDDINGS (unknown choices
         fall back to "bge-small").
      2. Chroma's ONNX MiniLM embedding function.
      3. Chroma's default embedding function.
    Any failure in an earlier step silently falls through to the next.
    """
    import inspect
    # Avoid tokenizer thread oversubscription.
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
    choice = os.getenv("RAG_EMBEDDINGS", "gte-multilingual").lower().strip()
    if SentenceTransformer:
        # choice -> (HF model name, family, slug)
        mapping = {
            "gte-multilingual": ("Alibaba-NLP/gte-multilingual-base", "gte", "gte-multilingual"),
            "bge-small": ("BAAI/bge-small-en-v1.5", "bge", "bge-small"),
            "e5-small": ("intfloat/e5-small-v2", "e5", "e5-small"),
            "gte-base-en": ("thenlper/gte-base", "gte", "gte-english"),
        }
        if choice not in mapping:
            model_name, family, slug = mapping["bge-small"]
        else:
            model_name, family, slug = mapping[choice]
        st_kwargs = {"device": "cpu"}
        try:
            # Pass trust_remote_code only when this SentenceTransformer version accepts it.
            if "trust_remote_code" in inspect.signature(SentenceTransformer).parameters:
                st_kwargs["trust_remote_code"] = True
            model = SentenceTransformer(model_name, **st_kwargs)
            # Optionally force the CPU thread count via RAG_TORCH_THREADS.
            try:
                thr = int(os.getenv("RAG_TORCH_THREADS", "0"))
                if thr > 0:
                    import torch
                    torch.set_num_threads(thr)
            except Exception:
                pass
            return _Embedder(slug=slug, family=family, model=model, device="cpu")
        except Exception:
            # Model download/instantiation failed — fall through to ONNX/default.
            pass

    # Fallback via Chroma embedding functions.
    from chromadb.utils import embedding_functions as ef
    try:
        onnx = ef.ONNXMiniLM_L6_V2()
        slug, family = "onnx-minilm", "minilm"
    except Exception:
        onnx = ef.DefaultEmbeddingFunction()
        slug, family = "default", "minilm"

    class _OnnxWrapper:
        # Adapts a Chroma embedding function to the callable protocol _Embedder expects.
        def __init__(self, fun): self.fun = fun
        def __call__(self, texts): return self.fun(texts)

    return _Embedder(slug=slug, family=family, model=_OnnxWrapper(onnx))
|
||
|
||
_EMBEDDER = _build_embedder()
|
||
|
||
class _ChromaEF(EmbeddingFunction):
    """Only used with the local PersistentClient (not http); the name is pinned.

    Adapts our _Embedder to Chroma's EmbeddingFunction protocol so local
    collections can embed documents via the same model as client-side code.
    """
    def __init__(self, embedder: _Embedder):
        self._embedder = embedder
    def __call__(self, input: Documents) -> Embeddings:
        # Chroma hands us a sequence of documents; embed them client-side.
        return self._embedder.embed_documents(list(input))
    def name(self) -> str:
        # Stable identifier recording which embedder produced a collection.
        return f"rag-ef:{self._embedder.family}:{self._embedder.slug}"
|
||
|
||
_CHROMA_EF = _ChromaEF(_EMBEDDER)
|
||
|
||
# === Chroma client (local of http) ===
|
||
if CHROMA_MODE == "http":
|
||
_CHROMA = chromadb.HttpClient(host=CHROMA_HOST, port=CHROMA_PORT)
|
||
else:
|
||
_CHROMA = chromadb.PersistentClient(path=CHROMA_PATH)
|
||
|
||
def _collection_versioned(base: str) -> str:
    """Append the embedder slug and RAG index version to a base collection name."""
    version = os.getenv("RAG_INDEX_VERSION", "3")
    return f"{base}__{_EMBEDDER.slug}__v{version}"
|
||
|
||
_COLLECTIONS: dict[str, any] = {}
|
||
|
||
def _get_collection(base: str):
    """Fetch (and cache) the versioned collection.

    Local mode attaches our embedding function; http mode does not, because
    embeddings are computed client-side there.
    """
    name = _collection_versioned(base)
    cached = _COLLECTIONS.get(name)
    if cached is None:
        if CHROMA_MODE == "http":
            cached = _CHROMA.get_or_create_collection(name=name)
        else:
            cached = _CHROMA.get_or_create_collection(name=name, embedding_function=_CHROMA_EF)
        _COLLECTIONS[name] = cached
    return cached
|
||
|
||
def _collection_add(collection, documents: list[str], metadatas: list[dict], ids: list[str]):
    """Add documents with client-side embeddings — works for local AND http Chroma.

    Each document gets a light header (path / language / symbol hints) prepended
    to improve retrieval quality; symbol hints are also stored in the metadata.
    """

    def _symbol_hints(txt: str) -> list[str]:
        # Very simple, language-agnostic def/class/function patterns.
        patterns = (
            r"def\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(",
            r"class\s+([A-Za-z_][A-Za-z0-9_]*)\b",
            r"function\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(",
            r"public\s+function\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(",
        )
        found = []
        for pat in patterns:
            try:
                found += re.findall(pat, txt[:4000])
            except Exception:
                pass
        # De-duplicate while preserving order; keep at most 6.
        uniq = []
        for sym in found:
            if sym not in uniq:
                uniq.append(sym)
                if len(uniq) >= 6:
                    break
        return uniq

    docs_with_header = []
    metas_out = []
    for doc, meta in zip(documents, metadatas):
        path = (meta or {}).get("path", "")
        ext = (Path(path).suffix.lower().lstrip(".") if path else "") or "txt"
        syms = _symbol_hints(doc)
        header = f"FILE:{path} | LANG:{ext} | SYMBOLS:{','.join(syms)}\n"
        docs_with_header.append(header + (doc or ""))
        meta_copy = dict(meta or {})
        if syms:
            meta_copy["symbols"] = ",".join(syms[:8])
        metas_out.append(meta_copy)

    embeddings = _EMBEDDER.embed_documents(docs_with_header)
    collection.add(documents=docs_with_header, embeddings=embeddings,
                   metadatas=metas_out, ids=ids)
|
||
|
||
# -----------------------------------------------------------------------------
|
||
# Repository profile + file selectie
|
||
# -----------------------------------------------------------------------------
|
||
PROFILE_EXCLUDE_DIRS = {
|
||
".git",".npm","node_modules","vendor","storage","dist","build",".next",
|
||
"__pycache__",".venv","venv",".mypy_cache",".pytest_cache",
|
||
"target","bin","obj","logs","cache","temp",".cache",".idea",".vscode"
|
||
}
|
||
PROFILE_INCLUDES = {
|
||
"generic": ["*.md","*.txt","*.json","*.yml","*.yaml","*.ini","*.cfg","*.toml",
|
||
"*.py","*.php","*.js","*.ts","*.jsx","*.tsx","*.css","*.scss",
|
||
"*.html","*.htm","*.vue","*.rb","*.go","*.java","*.cs","*.blade.php"],
|
||
"laravel": ["*.php","*.blade.php","*.md","*.env","*.json"],
|
||
"node": ["*.js","*.ts","*.jsx","*.tsx","*.vue","*.md","*.json","*.css","*.scss","*.html","*.htm"],
|
||
}
|
||
|
||
def _detect_repo_profile(root: Path) -> str:
|
||
if (root / "artisan").exists() or (root / "composer.json").exists():
|
||
return "laravel"
|
||
if (root / "package.json").exists():
|
||
return "node"
|
||
return "generic"
|
||
|
||
# -----------------------------------------------------------------------------
|
||
# Tekst-extractie helpers
|
||
# -----------------------------------------------------------------------------
|
||
TEXT_EXTS = {
|
||
".php",".blade.php",".vue",".js",".ts",".jsx",".tsx",".css",".scss",
|
||
".html",".htm",".json",".md",".ini",".cfg",".yml",".yaml",".toml",
|
||
".py",".go",".rb",".java",".cs",".txt",".env",".sh",".bat",".dockerfile",
|
||
}
|
||
BINARY_SKIP = {".png",".jpg",".jpeg",".webp",".bmp",".gif",".ico",".pdf",".zip",".gz",".tar",".7z",".rar",".woff",".woff2",".ttf",".eot",".otf"}
|
||
|
||
def _read_text_file(p: Path) -> str:
    """Best-effort text extraction for a file; returns "" when unreadable.

    Dispatches on extension: PDFs via PyPDF2, DOCX via python-docx, tabular
    files via pandas, PPTX via python-pptx; everything else is read as text.
    All extractor libraries are optional — missing ones disable their branch.
    """
    ext = p.suffix.lower()

    # Known binary extensions: only PDFs get a text-extraction attempt.
    if ext in BINARY_SKIP:
        if ext == ".pdf" and PyPDF2:
            try:
                with open(p, "rb") as fh:
                    pages = []
                    for page in PyPDF2.PdfReader(fh).pages:
                        try:
                            pages.append(page.extract_text() or "")
                        except Exception:
                            pass
                    return "\n".join(pages).strip()
            except Exception:
                return ""
        return ""

    # Word documents.
    if ext == ".docx" and docx:
        try:
            return "\n".join(par.text for par in docx.Document(str(p)).paragraphs)
        except Exception:
            return ""

    # Tabular data: header plus up to 200 rows, capped at 20k chars.
    # On failure these fall through to the plain-text reader below.
    if ext in {".csv", ".tsv"} and pd:
        try:
            return pd.read_csv(p, nrows=200).to_csv(index=False)[:20000]
        except Exception:
            pass
    if ext in {".xlsx", ".xls"} and pd:
        try:
            return pd.read_excel(p, nrows=200).to_csv(index=False)[:20000]
        except Exception:
            pass

    # Slide decks: concatenate the text of every shape on every slide.
    if ext == ".pptx" and Presentation:
        try:
            collected = []
            for slide in Presentation(str(p)).slides:
                for shape in slide.shapes:
                    if hasattr(shape, "text"):
                        collected.append(shape.text)
            return "\n".join(collected)
        except Exception:
            return ""

    # Default: read as text with permissive decoding.
    try:
        return p.read_text(encoding="utf-8", errors="ignore")
    except Exception:
        try:
            return p.read_text(encoding="latin-1", errors="ignore")
        except Exception:
            return ""
|
||
|
||
def _chunk_text(text: str, chunk_chars: int = 3000, overlap: int = 400) -> List[str]:
|
||
if not text: return []
|
||
n = len(text); i = 0; step = max(1, chunk_chars - overlap)
|
||
chunks = []
|
||
while i < n:
|
||
chunks.append(text[i:i+chunk_chars])
|
||
i += step
|
||
return chunks
|
||
|
||
# -----------------------------------------------------------------------------
|
||
# Git helper
|
||
# -----------------------------------------------------------------------------
|
||
def get_git_repo(repo_url: str, branch: str = "main") -> str:
    """
    Clone or update a repo under REPO_PATH and check out *branch*.

    The local directory name is taken from the URL's last path segment
    (".git" stripped) or a hash of the URL. A file lock serializes
    concurrent clone/fetch/checkout on the same repo.
    Returns the local path as a string.
    """
    os.makedirs(REPO_PATH, exist_ok=True)
    # Unique directory based on owner/repo name or a URL hash.
    name = None
    try:
        from urllib.parse import urlparse
        u = urlparse(repo_url)
        parts = [p for p in u.path.split("/") if p]
        if parts:
            name = parts[-1]
            if name.endswith(".git"): name = name[:-4]
    except Exception:
        pass
    if not name:
        name = hashlib.sha1(repo_url.encode("utf-8", errors="ignore")).hexdigest()[:12]
    local = os.path.join(REPO_PATH, name)
    lock_path = f"{local}.lock"
    with _file_lock(lock_path, timeout=float(os.getenv("GIT_LOCK_TIMEOUT","60"))):
        if not os.path.exists(local):
            logger.info("Cloning %s → %s", repo_url, local)
            # Shallow clone keeps large repos fast.
            repo = git.Repo.clone_from(repo_url, local, depth=1)
        else:
            repo = git.Repo(local)
            try:
                repo.remote().fetch(depth=1, prune=True)
            except Exception as e:
                # Best-effort update; a stale checkout is still usable.
                logger.warning("Fetch failed: %s", e)
        # Checkout
        logger.info("Checking out branch %s", branch)
        try:
            repo.git.checkout(branch)
        except Exception:
            # Try creating a local branch tracking origin/<branch>.
            try:
                repo.git.checkout("-B", branch, f"origin/{branch}")
            except Exception:
                # Last fallback: stay on the default HEAD.
                logger.warning("Checkout %s failed; fallback to default HEAD", branch)
        logger.info("Done with: Checking out branch %s", branch)
    return local
|
||
|
||
# -----------------------------------------------------------------------------
|
||
# LLM call (OpenAI-compatible)
|
||
# -----------------------------------------------------------------------------
|
||
# -----------------------------------------------------------------------------
|
||
# LLM call (OpenAI-compatible)
|
||
# -----------------------------------------------------------------------------
|
||
# -----------------------------------------------------------------------------
|
||
# LLM call (OpenAI-compatible) — met seriële wachtrij (LLM_SEM + LLM_QUEUE)
|
||
# -----------------------------------------------------------------------------
|
||
async def llm_call_openai_compat(
    messages: List[dict],
    *,
    model: Optional[str] = None,
    stream: bool = False,
    temperature: float = 0.2,
    top_p: float = 0.9,
    max_tokens: int = 1024,
    extra: Optional[dict] = None,
    stop: Optional[Union[str, list[str]]] = None,
    **kwargs
) -> dict | StreamingResponse:
    """Call the OpenAI-compatible LLM upstream with a fair serial queue.

    Non-stream calls return the parsed JSON response; stream calls return a
    StreamingResponse that first emits queue-position chunks as SSE and then
    proxies the upstream stream. Concurrency is bounded by LLM_SEM and
    ordered by LLM_QUEUE (FIFO fairness up to LLM_MAX_CONCURRENCY).
    Raises HTTPException on upstream HTTP errors or invalid JSON (non-stream).
    """
    payload: dict = {
        "model": model or os.getenv("LLM_MODEL", "mistral-medium"),
        "messages": messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": max_tokens,
        "stream": bool(stream)
    }
    # OpenAI-compat: forward optional stop sequences when present.
    if stop is not None:
        payload["stop"] = stop
    # Any other unknown kwargs are deliberately ignored (compat with callsites
    # that send extra parameters) — no payload.update(kwargs) so the upstream
    # contract is not broken.

    if extra:
        payload.update(extra)

    # Choose the upstream URL: sticky per conversation AND model
    # (prevents wrong stickiness and TypeErrors).
    thread_key = None
    try:
        _model = model or os.getenv("LLM_MODEL","mistral-medium")
        thread_key = f"{_model}:{derive_thread_id(messages)}"
    except Exception:
        # Fallback key: model plus a truncated role/content fingerprint.
        thread_key = (model or os.getenv("LLM_MODEL","mistral-medium")) + ":" + json.dumps(
            [m.get("role","")+":"+(m.get("content","")[:64]) for m in messages]
        )[:256]

    upstream = _pick_upstream_sticky(thread_key) or _pick_upstream()
    url = (upstream.url if upstream else LLM_URL)

    # --- NON-STREAM: wait for our turn and hold the exclusive permit.
    if not stream:
        token = object()
        app.state.LLM_QUEUE.append(token)
        try:
            # Fair: let the first N (LLM_MAX_CONCURRENCY) through as soon as a
            # permit is free. NOTE(review): polls the semaphore's private
            # ``_value`` attribute — implementation detail of asyncio.Semaphore.
            while True:
                try:
                    pos = app.state.LLM_QUEUE.index(token) + 1
                except ValueError:
                    # Token disappeared (client cancelled).
                    # NOTE(review): returns None here despite the dict return type.
                    return
                free = getattr(app.state.LLM_SEM, "_value", 0)
                if pos <= LLM_MAX_CONCURRENCY and free > 0:
                    await app.state.LLM_SEM.acquire()
                    break
                await asyncio.sleep(0.1)

            # IMPORTANT: use the non-proxy client (same host; avoids env-proxy timeouts).
            client = app.state.HTTPX
            r = await client.post(url, json=payload)
            try:
                r.raise_for_status()
                return r.json()
            except httpx.HTTPStatusError as e:
                _release_upstream(upstream, bad=True)
                # Return an OpenAI-style error, preserving the status code.
                raise HTTPException(status_code=r.status_code, detail=f"LLM upstream error: {e}")
            except ValueError:
                # r.json() failed to parse.
                _release_upstream(upstream, bad=True)
                raise HTTPException(status_code=502, detail="LLM upstream gaf geen geldige JSON.")
            finally:
                # NOTE(review): _release_upstream also runs after the bad=True
                # branches above, so the active counter is decremented twice
                # on errors — confirm whether that is intended.
                _release_upstream(upstream)
                app.state.LLM_SEM.release()
        finally:
            # Remove our queue token no matter how we exit.
            try:
                if app.state.LLM_QUEUE and app.state.LLM_QUEUE[0] is token:
                    app.state.LLM_QUEUE.popleft()
                else:
                    app.state.LLM_QUEUE.remove(token)
            except ValueError:
                pass

    # --- STREAM: emit queue-position updates as SSE until it is our turn.
    async def _aiter():
        token = object()
        app.state.LLM_QUEUE.append(token)
        try:
            # 1) Emit queue position while we are not yet allowed to run.
            first = True
            while True:
                try:
                    pos = app.state.LLM_QUEUE.index(token) + 1
                except ValueError:
                    return
                free = getattr(app.state.LLM_SEM, "_value", 0)
                if pos <= LLM_MAX_CONCURRENCY and free > 0:
                    await app.state.LLM_SEM.acquire()
                    break
                if first or pos <= LLM_MAX_CONCURRENCY:
                    # Synthetic chat.completion.chunk carrying the queue status.
                    data = {
                        "id": f"queue-info-{int(time.time())}",
                        "object": "chat.completion.chunk",
                        "created": int(time.time()),
                        "model": payload["model"],
                        "choices": [{
                            "index": 0,
                            "delta": {"role": "assistant", "content": f"⏳ Plek #{pos} – vrije slots: {free}\n"},
                            "finish_reason": None
                        }]
                    }
                    yield ("data: " + json.dumps(data, ensure_ascii=False) + "\n\n").encode("utf-8")
                    first = False
                await asyncio.sleep(1.0)

            # 2) Exclusive access acquired → proxy the real upstream stream.
            try:
                client = app.state.HTTPX
                # timeout=None for unbounded streams.
                async with client.stream("POST", url, json=payload, timeout=None) as r:
                    r.raise_for_status()
                    HEARTBEAT = float(os.getenv("SSE_HEARTBEAT_SEC","10"))
                    q: asyncio.Queue[bytes] = asyncio.Queue(maxsize=100)

                    async def _reader():
                        # Pump upstream bytes into the local queue; EOF sentinel on exit.
                        try:
                            async for chunk in r.aiter_bytes():
                                if chunk:
                                    await q.put(chunk)
                        except Exception:
                            pass
                        finally:
                            await q.put(b"__EOF__")

                    reader_task = asyncio.create_task(_reader())
                    try:
                        while True:
                            try:
                                chunk = await asyncio.wait_for(q.get(), timeout=HEARTBEAT)
                            except asyncio.TimeoutError:
                                # SSE comment; ignored by UIs → keeps the connection warm.
                                yield b": ping\n\n"
                                continue
                            if chunk == b"__EOF__":
                                break
                            yield chunk
                    finally:
                        reader_task.cancel()
                        with contextlib.suppress(Exception):
                            await reader_task
            except (asyncio.CancelledError, httpx.RemoteProtocolError) as e:
                # Client broke the connection; do NOT mark the upstream as bad.
                _release_upstream(upstream, bad=False)
                logger.info("🔌 stream cancelled/closed by client: %s", e)
                return
            except Exception as e:
                _release_upstream(upstream, bad=True)
                logger.info("🔌 stream aborted (compat): %s", e)
                return
            finally:
                _release_upstream(upstream)
                app.state.LLM_SEM.release()

        finally:
            # Remove our queue token no matter how we exit.
            try:
                if app.state.LLM_QUEUE and app.state.LLM_QUEUE[0] is token:
                    app.state.LLM_QUEUE.popleft()
                else:
                    app.state.LLM_QUEUE.remove(token)
            except ValueError:
                pass

    return StreamingResponse(
        _aiter(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache, no-transform",
            "X-Accel-Buffering": "no",
            "Connection": "keep-alive",
        },
    )
|
||
|
||
|
||
def extract_code_block(s: str) -> str:
    """Return the contents of the first fenced ``` block (without the fences).

    Falls back to the whole (stripped) text when no fenced block is found.
    """
    if not s:
        return ""
    fence = re.search(r"```[a-zA-Z0-9_+-]*\n([\s\S]*?)```", s)
    return fence.group(1).strip() if fence else s.strip()
|
||
|
||
_SIZE_RE = re.compile(r"^\s*(\d+)\s*x\s*(\d+)\s*$")
|
||
|
||
def _parse_size(size_str: str) -> tuple[int,int]:
|
||
m = _SIZE_RE.match(str(size_str or "512x512"))
|
||
if not m: return (512,512)
|
||
w = max(64, min(2048, int(m.group(1))))
|
||
h = max(64, min(2048, int(m.group(2))))
|
||
return (w, h)
|
||
|
||
def _sanitize_svg(svg: str) -> str:
|
||
# strip code fences
|
||
if "```" in svg:
|
||
svg = extract_code_block(svg)
|
||
# remove scripts + on* handlers
|
||
svg = re.sub(r"<\s*script\b[^>]*>.*?<\s*/\s*script\s*>", "", svg, flags=re.I|re.S)
|
||
svg = re.sub(r"\son\w+\s*=\s*['\"].*?['\"]", "", svg, flags=re.I|re.S)
|
||
return svg
|
||
|
||
def _svg_wrap_if_needed(svg: str, w: int, h: int, bg: str="white") -> str:
    """Guarantee the model output is a usable SVG document.

    Returns a placeholder SVG (solid background + notice text) when no
    <svg> tag is present; otherwise injects explicit width/height into the
    first <svg> tag when either attribute is missing.
    """
    s = (svg or "").strip()
    if "<svg" not in s:
        # Fallback document; the text body is Dutch by design ("SVG not found
        # in model output").
        return f'''<svg xmlns="http://www.w3.org/2000/svg" width="{w}" height="{h}">
<rect width="100%" height="100%" fill="{bg}"/>
<text x="20" y="40" font-size="18" font-family="sans-serif">SVG niet gevonden in modeloutput</text>
</svg>'''
    if 'width=' not in s or 'height=' not in s:
        # Only the first <svg ...> tag gets the explicit canvas size.
        s = re.sub(r"<svg\b", f'<svg width="{w}" height="{h}"', s, count=1)
    return s
|
||
async def _svg_from_prompt(prompt: str, w: int, h: int, background: str="white") -> str:
    """Ask the LLM for raw SVG markup and return a sanitized, correctly sized document.

    The system/user prompts are Dutch by design; the response content is
    sanitized (_sanitize_svg) and wrapped/sized (_svg_wrap_if_needed).
    """
    sys = ("Je bent een SVG-tekenaar. Geef ALLEEN raw SVG 1.1 markup terug; geen uitleg of code fences. "
           "Geen externe <image> refs of scripts.")
    user = (f"Maak een eenvoudige vectorillustratie.\n- Canvas {w}x{h}, achtergrond {background}\n"
            f"- Thema: {prompt}\n- Gebruik eenvoudige vormen/paths/tekst.")
    resp = await llm_call_openai_compat(
        [{"role":"system","content":sys},{"role":"user","content":user}],
        stream=False, temperature=0.35, top_p=0.9, max_tokens=1200
    )
    # Defensive extraction: tolerate missing choices/message keys.
    svg = (resp.get("choices",[{}])[0].get("message",{}) or {}).get("content","")
    return _svg_wrap_if_needed(_sanitize_svg(svg), w, h, background)
|
||
# --------- kleine helpers: atomic json write & simpele file lock ----------
|
||
def _atomic_json_write(path: str, data: dict):
|
||
tmp = f"{path}.tmp.{uuid.uuid4().hex}"
|
||
Path(path).parent.mkdir(parents=True, exist_ok=True)
|
||
with open(tmp, "w", encoding="utf-8") as f:
|
||
json.dump(data, f, ensure_ascii=False)
|
||
os.replace(tmp, path)
|
||
|
||
@contextmanager
|
||
def _file_lock(lock_path: str, timeout: float = 60.0, poll: float = 0.2):
|
||
start = time.time()
|
||
fd = None
|
||
while True:
|
||
try:
|
||
fd = os.open(lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o644)
|
||
break
|
||
except FileExistsError:
|
||
if time.time() - start > timeout: raise TimeoutError(f"Lock timeout for {lock_path}")
|
||
time.sleep(poll)
|
||
try:
|
||
os.write(fd, str(os.getpid()).encode("utf-8", errors="ignore"))
|
||
yield
|
||
finally:
|
||
with contextlib.suppress(Exception):
|
||
os.close(fd) if fd is not None else None
|
||
os.unlink(lock_path)
|
||
|
||
|
||
# -------- UploadFile → text (generiek), gebruikt je optional libs --------
|
||
async def read_file_content(file: UploadFile) -> str:
    """Best-effort text extraction from an uploaded file, dispatched on extension.

    Optional parsers (python-docx, PyPDF2, pandas, python-pptx) are used only
    when importable; anything unrecognized falls back to a lossy UTF-8 decode.
    Tabular formats are truncated to 200 rows to bound prompt size.
    """
    name = (file.filename or "").lower()
    # Read in a worker thread: UploadFile.file is a synchronous file object.
    raw = await run_in_threadpool(file.file.read)

    # Plain-text-like formats: decode directly.
    if name.endswith((".txt",".py",".php",".js",".ts",".jsx",".tsx",".html",".css",".json",
                      ".md",".yml",".yaml",".ini",".cfg",".log",".toml",".env",".sh",".bat",".dockerfile")):
        return raw.decode("utf-8", errors="ignore")

    if name.endswith(".docx") and docx:
        d = docx.Document(BytesIO(raw))
        return "\n".join(p.text for p in d.paragraphs)

    if name.endswith(".pdf") and PyPDF2:
        reader = PyPDF2.PdfReader(BytesIO(raw))
        # extract_text() may return None for a page; coerce to "".
        return "\n".join([(p.extract_text() or "") for p in reader.pages])

    if name.endswith(".csv") and pd is not None:
        df = pd.read_csv(BytesIO(raw))
        return df.head(200).to_csv(index=False)

    if name.endswith((".xlsx",".xls")) and pd is not None:
        try:
            # sheet_name=None loads every sheet as {name: DataFrame}.
            sheets = pd.read_excel(BytesIO(raw), sheet_name=None, header=0)
        except Exception as e:
            return f"(Kon XLSX niet parsen: {e})"
        out = []
        for sname, df in sheets.items():
            out.append(f"# Sheet: {sname}\n{df.head(200).to_csv(index=False)}")
        return "\n\n".join(out)

    if name.endswith(".pptx") and Presentation:
        pres = Presentation(BytesIO(raw))
        texts = []
        for i, slide in enumerate(pres.slides, 1):
            buf = [f"# Slide {i}"]
            for shape in slide.shapes:
                # Not every shape carries text (images, charts, ...).
                if hasattr(shape, "has_text_frame") and shape.has_text_frame:
                    buf.append(shape.text.strip())
            texts.append("\n".join([t for t in buf if t]))
        return "\n".join(texts)

    # Images or unknown formats: attempt a raw text decode.
    try:
        return raw.decode("utf-8", errors="ignore")
    except Exception:
        return "(onbekend/beeld-bestand)"
|
||
|
||
def _client_ip(request: Request) -> str:
|
||
for hdr in ("cf-connecting-ip","x-real-ip","x-forwarded-for"):
|
||
v = request.headers.get(hdr)
|
||
if v: return v.split(",")[0].strip()
|
||
ip = request.client.host if request.client else "0.0.0.0"
|
||
return ip
|
||
|
||
# --- Vision helpers: OpenAI content parts -> plain text + images (b64) ---
|
||
_JSON_FENCE_RE = re.compile(r"^```(?:json)?\s*([\s\S]*?)\s*```$", re.I)
|
||
|
||
def _normalize_openai_vision_messages(messages: list[dict]) -> tuple[list[dict], list[str]]:
|
||
"""
|
||
Converteer OpenAI-achtige message parts (text + image_url)
|
||
naar (1) platte string met <image> markers en (2) images=[base64,...].
|
||
Alleen data: URI's of al-b64 doorlaten; remote http(s) URL's negeren (veilig default).
|
||
"""
|
||
imgs: list[str] = []
|
||
out: list[dict] = []
|
||
for m in messages:
|
||
c = m.get("content")
|
||
role = m.get("role","user")
|
||
if isinstance(c, list):
|
||
parts = []
|
||
for item in c:
|
||
if isinstance(item, dict) and item.get("type") == "text":
|
||
parts.append(item.get("text",""))
|
||
elif isinstance(item, dict) and item.get("type") == "image_url":
|
||
u = item.get("image_url")
|
||
if isinstance(u, dict):
|
||
u = u.get("url")
|
||
if isinstance(u, str):
|
||
if u.startswith("data:image") and "," in u:
|
||
b64 = u.split(",",1)[1]
|
||
imgs.append(b64)
|
||
parts.append("<image>")
|
||
elif re.match(r"^[A-Za-z0-9+/=]+$", u.strip()): # ruwe base64
|
||
imgs.append(u.strip())
|
||
parts.append("<image>")
|
||
else:
|
||
# http(s) URL's niet meesturen (veilig), maar ook geen ruis in de prompt
|
||
pass
|
||
out.append({"role": role, "content": "\n".join([p for p in parts if p]).strip()})
|
||
else:
|
||
out.append({"role": role, "content": c or ""})
|
||
return out, imgs
|
||
|
||
def _parse_repo_qa_from_messages(messages: list[dict]) -> tuple[Optional[str], str, str, int]:
|
||
"""Haal repo_hint, question, branch, n_ctx grofweg uit user-tekst."""
|
||
txts = [m.get("content","") for m in messages if m.get("role") == "user"]
|
||
full = "\n".join(txts).strip()
|
||
repo_hint = None
|
||
m = re.search(r"(https?://\S+?)(?:\s|$)", full)
|
||
if m: repo_hint = m.group(1).strip()
|
||
if not repo_hint:
|
||
m = re.search(r"\brepo\s*:\s*([A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+)", full, flags=re.I)
|
||
if m: repo_hint = m.group(1).strip()
|
||
if not repo_hint:
|
||
m = re.search(r"\b([A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+)\b", full)
|
||
if m: repo_hint = m.group(1).strip()
|
||
branch = "main"
|
||
m = re.search(r"\bbranch\s*:\s*([A-Za-z0-9_.\/-]+)", full, flags=re.I)
|
||
if m: branch = m.group(1).strip()
|
||
n_ctx = 8
|
||
m = re.search(r"\bn_ctx\s*:\s*(\d+)", full, flags=re.I)
|
||
if m:
|
||
try: n_ctx = max(1, min(16, int(m.group(1))))
|
||
except: pass
|
||
question = full
|
||
return repo_hint, question, branch, n_ctx
|
||
|
||
|
||
# ------------------------------
|
||
# OpenAI-compatible endpoints
|
||
# ------------------------------
|
||
@app.get("/v1/models")
|
||
def list_models():
|
||
base_model = os.getenv("LLM_MODEL", "mistral-medium")
|
||
# Toon ook je twee "virtuele" modellen voor de UI
|
||
return {
|
||
"object": "list",
|
||
"data": [
|
||
{"id": base_model, "object": "model", "created": 0, "owned_by": "you"},
|
||
{"id": "repo-agent", "object": "model", "created": 0, "owned_by": "you"},
|
||
{"id": "repo-qa", "object": "model", "created": 0, "owned_by": "you"},
|
||
],
|
||
}
|
||
|
||
|
||
def _openai_chat_response(model: str, text: str, messages: list[dict]):
    """Build a non-streaming OpenAI chat.completion payload around *text*.

    Usage numbers are rough estimates (never None): prompt tokens via
    count_message_tokens when that helper exists, otherwise a character
    based approximation of the serialized messages.
    """
    created = int(time.time())
    try:
        if 'count_message_tokens' in globals():
            prompt_tokens = count_message_tokens(messages)
        else:
            prompt_tokens = approx_token_count(json.dumps(messages))
    except Exception:
        prompt_tokens = approx_token_count(json.dumps(messages))
    completion_tokens = approx_token_count(text)
    choice = {
        "index": 0,
        "message": {"role": "assistant", "content": text},
        "finish_reason": "stop",
    }
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
        "object": "chat.completion",
        "created": created,
        "model": model,
        "choices": [choice],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }
|
||
def _get_stt_model():
    """Lazily create and cache the faster-whisper model singleton.

    STT_DEVICE == "auto" tries CUDA/float16 first and falls back to
    CPU/int8; otherwise the configured device decides. Raises HTTP 500
    when faster-whisper is not installed.
    """
    global _STT_MODEL
    if _STT_MODEL is not None:
        return _STT_MODEL
    try:
        from faster_whisper import WhisperModel
    except Exception:
        raise HTTPException(status_code=500, detail="STT niet beschikbaar: faster-whisper ontbreekt.")
    if STT_DEVICE == "auto":
        try:
            _STT_MODEL = WhisperModel(STT_MODEL_NAME, device="cuda", compute_type="float16")
        except Exception:
            # No usable GPU: quantized CPU inference instead.
            _STT_MODEL = WhisperModel(STT_MODEL_NAME, device="cpu", compute_type="int8")
        return _STT_MODEL
    if STT_DEVICE == "cuda":
        device, compute = "cuda", "float16"
    else:
        device, compute = "cpu", "int8"
    _STT_MODEL = WhisperModel(STT_MODEL_NAME, device=device, compute_type=compute)
    return _STT_MODEL
|
||
def _stt_transcribe_path(path: str, lang: str | None):
    """Transcribe the audio file at *path*; returns (text, detected_language)."""
    model = _get_stt_model()
    segments, info = model.transcribe(path, language=lang or None, vad_filter=True)
    transcript = "".join(segment.text for segment in segments).strip()
    return transcript, getattr(info, "language", None)
|
||
@app.post("/v1/audio/transcriptions")
|
||
async def audio_transcriptions(
|
||
file: UploadFile = File(...),
|
||
model: str = Form("whisper-1"),
|
||
language: str | None = Form(None),
|
||
prompt: str | None = Form(None)
|
||
):
|
||
data = await file.read()
|
||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
||
tmp.write(data)
|
||
tmp_path = tmp.name
|
||
try:
|
||
text, lang = await run_in_threadpool(_stt_transcribe_path, tmp_path, language)
|
||
return {"text": text, "language": lang or "unknown"}
|
||
finally:
|
||
try: os.unlink(tmp_path)
|
||
except Exception: pass
|
||
|
||
@app.post("/v1/audio/speech")
|
||
async def audio_speech(body: dict = Body(...)):
|
||
"""
|
||
OpenAI-compat: { "model": "...", "voice": "optional", "input": "tekst" }
|
||
Return: audio/wav
|
||
"""
|
||
if not PIPER_VOICE:
|
||
raise HTTPException(status_code=500, detail="PIPER_VOICE env var niet gezet (piper TTS).")
|
||
text = (body.get("input") or "").strip()
|
||
if not text:
|
||
raise HTTPException(status_code=400, detail="Lege input")
|
||
|
||
def _synth_to_wav_bytes() -> bytes:
|
||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as out:
|
||
out_path = out.name
|
||
try:
|
||
cp = subprocess.run(
|
||
[PIPER_BIN, "-m", PIPER_VOICE, "-f", out_path, "-q"],
|
||
input=text.encode("utf-8"),
|
||
stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
|
||
)
|
||
with open(out_path, "rb") as f:
|
||
return f.read()
|
||
finally:
|
||
try: os.unlink(out_path)
|
||
except Exception: pass
|
||
|
||
audio_bytes = await run_in_threadpool(_synth_to_wav_bytes)
|
||
return StreamingResponse(iter([audio_bytes]), media_type="audio/wav")
|
||
|
||
@app.post("/v1/images/generations")
|
||
async def images_generations(payload: dict = Body(...)):
|
||
prompt = (payload.get("prompt") or "").strip()
|
||
if not prompt:
|
||
raise HTTPException(status_code=400, detail="prompt required")
|
||
n = max(1, min(8, int(payload.get("n", 1))))
|
||
size = payload.get("size","512x512")
|
||
w,h = _parse_size(size)
|
||
background = payload.get("background","white")
|
||
fmt = (payload.get("format") or "").lower().strip() # non-standard
|
||
if fmt not in ("png","svg",""):
|
||
fmt = ""
|
||
if not fmt:
|
||
fmt = "png" if cairosvg is not None else "svg"
|
||
|
||
out_items = []
|
||
for _ in range(n):
|
||
svg = await _svg_from_prompt(prompt, w, h, background)
|
||
if fmt == "png" and cairosvg is not None:
|
||
png_bytes = cairosvg.svg2png(bytestring=svg.encode("utf-8"), output_width=w, output_height=h)
|
||
b64 = base64.b64encode(png_bytes).decode("ascii")
|
||
else:
|
||
b64 = base64.b64encode(svg.encode("utf-8")).decode("ascii")
|
||
out_items.append({"b64_json": b64})
|
||
return {"created": int(time.time()), "data": out_items}
|
||
|
||
@app.get("/v1/images/health")
|
||
def images_health():
|
||
return {"svg_to_png": bool(cairosvg is not None)}
|
||
|
||
@app.post("/present/make")
|
||
async def present_make(
|
||
prompt: str = Form(...),
|
||
file: UploadFile | None = File(None),
|
||
max_slides: int = Form(8)
|
||
):
|
||
if Presentation is None:
|
||
raise HTTPException(500, "python-pptx ontbreekt in de container.")
|
||
src_text = ""
|
||
if file:
|
||
try:
|
||
src_text = (await read_file_content(file))[:30000]
|
||
except Exception:
|
||
src_text = ""
|
||
sys = ("Je bent een presentatieschrijver. Geef ALLEEN geldige JSON terug: "
|
||
'{"title": str, "slides":[{"title": str, "bullets":[str,...]}...]}.')
|
||
user = (f"Doelpresentatie: {prompt}\nBron (optioneel):\n{src_text[:12000]}\n"
|
||
f"Max. {max_slides} dia's, 3–6 bullets per dia.")
|
||
plan = await llm_call_openai_compat(
|
||
[{"role":"system","content":sys},{"role":"user","content":user}],
|
||
stream=False, temperature=0.3, top_p=0.9, max_tokens=1200
|
||
)
|
||
raw = (plan.get("choices",[{}])[0].get("message",{}) or {}).get("content","{}")
|
||
try:
|
||
spec = json.loads(raw)
|
||
except Exception:
|
||
raise HTTPException(500, "Model gaf geen valide JSON voor slides.")
|
||
prs = Presentation()
|
||
title = spec.get("title") or "Presentatie"
|
||
# Titel
|
||
slide_layout = prs.slide_layouts[0]
|
||
s = prs.slides.add_slide(slide_layout)
|
||
s.shapes.title.text = title
|
||
# Content
|
||
for slide in spec.get("slides", []):
|
||
layout = prs.slide_layouts[1]
|
||
sl = prs.slides.add_slide(layout)
|
||
sl.shapes.title.text = slide.get("title","")
|
||
tx = sl.placeholders[1].text_frame
|
||
tx.clear()
|
||
for i, bullet in enumerate(slide.get("bullets", [])[:10]):
|
||
p = tx.add_paragraph() if i>0 else tx.paragraphs[0]
|
||
p.text = bullet
|
||
p.level = 0
|
||
bio = BytesIO()
|
||
prs.save(bio); bio.seek(0)
|
||
headers = {"Content-Disposition": f'attachment; filename="deck.pptx"'}
|
||
return StreamingResponse(bio, media_type="application/vnd.openxmlformats-officedocument.presentationml.presentation", headers=headers)
|
||
|
||
@app.get("/whoami")
|
||
def whoami():
|
||
import socket
|
||
return {"service": "mistral-api", "host": socket.gethostname(), "LLM_URL": LLM_URL}
|
||
|
||
IMG_EXTS = (".png",".jpg",".jpeg",".webp",".bmp",".gif")
|
||
|
||
def _is_image_filename(name: str) -> bool:
|
||
return (name or "").lower().endswith(IMG_EXTS)
|
||
|
||
@app.post("/vision/ask")
|
||
async def vision_ask(
|
||
file: UploadFile = File(...),
|
||
prompt: str = Form("Beschrijf kort wat je ziet."),
|
||
stream: bool = Form(False),
|
||
temperature: float = Form(0.2),
|
||
top_p: float = Form(0.9),
|
||
max_tokens: int = Form(512),
|
||
):
|
||
raw = await run_in_threadpool(file.file.read)
|
||
img_b64 = base64.b64encode(raw).decode("utf-8")
|
||
messages = [{"role":"user","content": f"<image> {prompt}"}]
|
||
if stream:
|
||
return await llm_call_openai_compat(
|
||
messages,
|
||
model=os.getenv("LLM_MODEL","mistral-medium"),
|
||
stream=True,
|
||
temperature=float(temperature),
|
||
top_p=float(top_p),
|
||
max_tokens=int(max_tokens),
|
||
extra={"images": [img_b64]}
|
||
)
|
||
|
||
|
||
return await llm_call_openai_compat(
|
||
messages, stream=False, temperature=temperature, top_p=top_p, max_tokens=max_tokens,
|
||
extra={"images": [img_b64]}
|
||
)
|
||
|
||
@app.post("/file/vision-and-text")
|
||
async def vision_and_text(
|
||
files: List[UploadFile] = File(...),
|
||
prompt: str = Form("Combineer visuele analyse met de tekstcontext. Geef 5 bullets met bevindingen en 3 actiepunten."),
|
||
stream: bool = Form(False),
|
||
max_images: int = Form(6),
|
||
max_chars: int = Form(25000),
|
||
temperature: float = Form(0.2),
|
||
top_p: float = Form(0.9),
|
||
max_tokens: int = Form(1024),
|
||
):
|
||
images_b64: list[str] = []
|
||
text_chunks: list[str] = []
|
||
for f in files:
|
||
name = f.filename or ""
|
||
raw = await run_in_threadpool(f.file.read)
|
||
if _is_image_filename(name) and len(images_b64) < int(max_images):
|
||
images_b64.append(base64.b64encode(raw).decode("utf-8"))
|
||
else:
|
||
try:
|
||
tmp = UploadFile(filename=name, file=BytesIO(raw), headers=f.headers)
|
||
text = await read_file_content(tmp)
|
||
except Exception:
|
||
text = raw.decode("utf-8", errors="ignore")
|
||
if text:
|
||
text_chunks.append(f"### {name}\n{text.strip()}")
|
||
|
||
text_context = ("\n\n".join(text_chunks))[:int(max_chars)].strip()
|
||
image_markers = ("\n".join(["<image>"] * len(images_b64))).strip()
|
||
user_content = (image_markers + ("\n\n" if image_markers else "") + prompt.strip())
|
||
if text_context:
|
||
user_content += f"\n\n--- TEKST CONTEXT (ingekort) ---\n{text_context}"
|
||
messages = [{"role":"user","content": user_content}]
|
||
|
||
if stream:
|
||
return await llm_call_openai_compat(
|
||
messages,
|
||
model=os.getenv("LLM_MODEL","mistral-medium"),
|
||
stream=True,
|
||
temperature=float(temperature),
|
||
top_p=float(top_p),
|
||
max_tokens=int(max_tokens),
|
||
extra={"images": images_b64}
|
||
)
|
||
|
||
|
||
return await llm_call_openai_compat(
|
||
messages, stream=False, temperature=temperature, top_p=top_p, max_tokens=max_tokens,
|
||
extra={"images": images_b64}
|
||
)
|
||
|
||
@app.get("/vision/health")
|
||
async def vision_health():
|
||
tiny_png = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAOaO5nYAAAAASUVORK5CYII="
|
||
try:
|
||
messages = [{"role":"user","content":"<image> Beschrijf dit in één woord."}]
|
||
resp = await llm_call_openai_compat(messages, extra={"images":[tiny_png]}, max_tokens=16)
|
||
txt = (resp.get("choices",[{}])[0].get("message",{}) or {}).get("content","").strip()
|
||
return {"vision": bool(txt), "sample": txt[:60]}
|
||
except Exception as e:
|
||
return {"vision": False, "error": str(e)}
|
||
|
||
# -------- Tool registry (OpenAI-style) --------
LLM_FUNCTION_CALLING_MODE = os.getenv("LLM_FUNCTION_CALLING_MODE", "auto").lower()  # "native" | "shim" | "auto"

# SECURITY: an OpenWebUI API token was hard-coded here. It is kept only as a
# backward-compatible fallback — rotate this token and supply it via the
# OWUI_API_TOKEN environment variable instead of source code.
OWUI_BASE_URL = os.getenv("OWUI_BASE_URL", "http://localhost:3000")
OWUI_API_TOKEN = os.getenv("OWUI_API_TOKEN", "sk-f1b7991b054442b5ae388de905019726")
# Aliases so older code paths keep working
OWUI_BASE = OWUI_BASE_URL
OWUI_TOKEN = OWUI_API_TOKEN
||
|
||
@app.get("/tools", operation_id="get_tools_list_compat")
|
||
async def tools_compat():
|
||
return await list_tools()
|
||
|
||
|
||
|
||
async def _fetch_openwebui_file(file_id: str, dst_dir: Path) -> Path:
    """
    Download raw file content from OpenWebUI and save locally.

    Streams to dst_dir/<file_id> and aborts with HTTP 413 once the body
    exceeds OWUI_MAX_DOWNLOAD_MB (default 64 MB), deleting the partial
    file. Reuses the proxy-aware client on app.state when available,
    otherwise creates a one-off client and closes it afterwards.
    """
    if not (OWUI_BASE_URL and OWUI_API_TOKEN):
        raise HTTPException(status_code=501, detail="OpenWebUI niet geconfigureerd (OWUI_BASE_URL/OWUI_API_TOKEN).")
    url = f"{OWUI_BASE_URL}/api/v1/files/{file_id}/content"
    headers = {"Authorization": f"Bearer {OWUI_API_TOKEN}"}

    dst_dir.mkdir(parents=True, exist_ok=True)
    out_path = dst_dir / f"{file_id}"
    # Use the proxy-aware client from app.state when available
    client = getattr(app.state, "HTTPX_PROXY", None)
    close_after = False
    if client is None:
        client = httpx.AsyncClient(timeout=None, trust_env=True)
        close_after = True
    try:
        async with client.stream("GET", url, headers=headers, timeout=None) as resp:
            if resp.status_code != 200:
                # read the body (bounded below) for the error message
                body = await resp.aread()
                raise HTTPException(status_code=resp.status_code,
                                    detail=f"OpenWebUI file fetch failed: {body[:2048]!r}")
            max_mb = int(os.getenv("OWUI_MAX_DOWNLOAD_MB", "64"))
            max_bytes = max_mb * 1024 * 1024
            total = 0
            with out_path.open("wb") as f:
                async for chunk in resp.aiter_bytes():
                    if not chunk:
                        continue
                    total += len(chunk)
                    if total > max_bytes:
                        # Over the cap: drop the partial file and abort.
                        try: f.close()
                        except Exception: pass
                        try: out_path.unlink(missing_ok=True)
                        except Exception: pass
                        raise HTTPException(status_code=413, detail=f"Bestand groter dan {max_mb}MB; download afgebroken.")
                    f.write(chunk)
            return out_path
    except HTTPException:
        # Already shaped for the API; don't wrap it again below.
        raise
    except Exception as e:
        logging.exception("download failed")
        raise HTTPException(status_code=500, detail=f"download error: {e}")
    finally:
        if close_after:
            await client.aclose()
|
||
|
||
def _normalize_files_arg(args: dict):
|
||
# OpenWebUI injecteert files als __files__ = [{ id, name, mime, size, ...}]
|
||
files = args.get("__files__") or args.get("files") or []
|
||
if isinstance(files, dict): # enkel bestand
|
||
files = [files]
|
||
return files
|
||
|
||
|
||
@app.get("/openapi.json", include_in_schema=False)
|
||
async def openapi_endpoint():
|
||
return get_openapi(
|
||
title="Tool Server",
|
||
version="0.1.0",
|
||
routes=app.routes,
|
||
)
|
||
|
||
def _openai_tools_from_registry(reg: dict):
|
||
out = []
|
||
for name, spec in reg.items():
|
||
out.append({
|
||
"type": "function",
|
||
"function": {
|
||
"name": name,
|
||
"description": spec.get("description",""),
|
||
"parameters": spec.get("parameters", {"type":"object","properties":{}})
|
||
}
|
||
})
|
||
return out
|
||
|
||
def _visible_registry(reg: dict) -> dict:
|
||
return {k:v for k,v in reg.items() if not v.get("hidden")}
|
||
|
||
# Reviewer prompt (Dutch) for the validate_code_text tool. Its bullet output
# format is parsed by _parse_validation_results, which expects lines shaped
# like "- Regels ...: ..." — keep the two formats in sync.
VALIDATE_PROMPT = (
    "Je bent een code-reviewer die deze code moet valideren:\n"
    "1. Controleer op syntactische fouten\n"
    "2. Controleer op logische fouten en onjuiste functionaliteit\n"
    "3. Check of alle vereiste functionaliteit aanwezig is\n"
    "4. Zoek naar mogelijke bugs of veiligheidsrisico's\n"
    "5. Geef specifieke feedback met regelnummers\n\n"
    "Geef een lijst in dit formaat:\n"
    "- Regels [10-12]: Beschrijving\n"
    "- Regels 25: Beschrijving"
)
||
|
||
def _parse_validation_results(text: str) -> list[str]:
|
||
issues = []
|
||
for line in (text or "").splitlines():
|
||
line = line.strip()
|
||
if line.startswith('-') and 'Regels' in line and ':' in line:
|
||
issues.append(line)
|
||
return issues
|
||
|
||
async def _execute_tool(name: str, args: dict) -> dict:
    """Dispatch a registered tool call by name.

    Returns either a plain result dict or a full OpenAI chat response
    (for the LLM-backed text/code tools). Raises HTTP 400 for missing
    required arguments or an unknown tool name.
    """
    if name == "repo_grep":
        repo_url = args.get("repo_url","")
        branch = args.get("branch","main")
        query = (args.get("query") or "").strip()
        max_hits = int(args.get("max_hits",200))
        if not repo_url or not query:
            raise HTTPException(status_code=400, detail="repo_url en query verplicht.")
        repo_path = await _get_git_repo_async(repo_url, branch)
        root = Path(repo_path)
        hits = []
        qlow = query.lower()
        for p in root.rglob("*"):
            if p.is_dir():
                continue
            # Skip excluded directories anywhere in the path.
            if set(p.parts) & PROFILE_EXCLUDE_DIRS:
                continue
            if p.suffix.lower() in BINARY_SKIP:
                continue
            try:
                txt = _read_text_file(p)
            except Exception:
                continue
            if not txt:
                continue
            # Fast case-insensitive line-by-line scan.
            for ln, line in enumerate(txt.splitlines(), 1):
                if qlow in line.lower():
                    hits.append({
                        "path": str(p.relative_to(root)),
                        "line": ln,
                        "excerpt": line.strip()[:400]
                    })
                    if len(hits) >= max_hits:
                        break
            if len(hits) >= max_hits:
                break
        return {"count": len(hits), "hits": hits, "repo": os.path.basename(repo_url), "branch": branch}
    # RAG
    if name == "rag_index_repo":
        # Indexing is sync/CPU-bound: run off the event loop.
        out = await run_in_threadpool(_rag_index_repo_sync, **{
            "repo_url": args.get("repo_url",""),
            "branch": args.get("branch","main"),
            "profile": args.get("profile","auto"),
            "include": args.get("include",""),
            "exclude_dirs": args.get("exclude_dirs",""),
            "chunk_chars": int(args.get("chunk_chars",3000)),
            "overlap": int(args.get("overlap",400)),
            "collection_name": args.get("collection_name","code_docs"),
            "force": bool(args.get("force", False)),
        })
        return out
    if name == "rag_query":
        out = await rag_query_api(
            query=args.get("query",""),
            n_results=int(args.get("n_results",5)),
            collection_name=args.get("collection_name","code_docs"),
            repo=args.get("repo"),
            path_contains=args.get("path_contains"),
            profile=args.get("profile")
        )
        return out

    # Text tools (each returns the raw OpenAI chat response)
    if name == "summarize_text":
        text = (args.get("text") or "")[:int(args.get("max_chars",16000))]
        instruction = args.get("instruction") or "Vat samen in bullets (max 10), met korte inleiding en actiepunten."
        resp = await llm_call_openai_compat(
            [{"role":"system","content":"Je bent behulpzaam en exact."},
             {"role":"user","content": f"{instruction}\n\n--- BEGIN ---\n{text}\n--- EINDE ---"}],
            stream=False, max_tokens=768
        )
        return resp

    if name == "analyze_text":
        text = (args.get("text") or "")[:int(args.get("max_chars",20000))]
        goal = args.get("goal") or "Licht toe wat dit document doet. Benoem sterke/zwakke punten, risico’s en concrete verbeterpunten."
        resp = await llm_call_openai_compat(
            [{"role":"system","content":"Wees feitelijk en concreet."},
             {"role":"user","content": f"{goal}\n\n--- BEGIN ---\n{text}\n--- EINDE ---"}],
            stream=False, max_tokens=768
        )
        return resp

    if name == "improve_text":
        text = (args.get("text") or "")[:int(args.get("max_chars",20000))]
        objective = args.get("objective") or "Verbeter dit document. Lever een beknopte toelichting en vervolgens de volledige verbeterde versie."
        style = args.get("style") or "Hanteer best practices, behoud inhoudelijke betekenis."
        resp = await llm_call_openai_compat(
            [{"role":"system","content": SYSTEM_PROMPT},
             {"role":"user","content": f"{objective}\nStijl/criteria: {style}\n"
              "Antwoord met eerst een korte toelichting, daarna alleen de verbeterde inhoud tussen een codeblok.\n\n"
              f"--- BEGIN ---\n{text}\n--- EINDE ---"}],
            stream=False, max_tokens=1024
        )
        return resp

    # Code tools
    if name == "validate_code_text":
        code = args.get("code","")
        resp = await llm_call_openai_compat(
            [{"role":"system","content":"Wees strikt en concreet."},
             {"role":"user","content": f"{VALIDATE_PROMPT}\n\nCode om te valideren:\n```\n{code}\n```"}],
            stream=False, max_tokens=512
        )
        txt = (resp.get("choices",[{}])[0].get("message",{}) or {}).get("content","")
        return {"status":"issues_found" if _parse_validation_results(txt) else "valid",
                "issues": _parse_validation_results(txt), "raw": txt}

    if name == "improve_code_text":
        code = args.get("code","")
        language = args.get("language","auto")
        focus = args.get("improvement_focus","best practices")
        resp = await llm_call_openai_compat(
            [{"role":"system","content": SYSTEM_PROMPT},
             {"role":"user","content": f"Verbeter deze {language}-code met focus op {focus}:\n\n"
              f"{code}\n\nGeef eerst een korte toelichting, dan alleen het verbeterde codeblok. Behoud functionaliteit."}],
            stream=False, max_tokens=1536
        )
        return resp

    if name == "ingest_openwebui_files":
        files = _normalize_files_arg(args)
        if not files:
            raise HTTPException(status_code=400, detail="Geen bestanden ontvangen (__files__ is leeg).")
        saved = []
        tmpdir = Path("/tmp/owui_files")
        for fmeta in files:
            fid = fmeta.get("id")
            if not fid:
                continue
            path = await _fetch_openwebui_file(fid, tmpdir)
            saved.append({"id": fid, "path": str(path), "name": fmeta.get("name"), "mime": fmeta.get("mime")})

        # >>> hook your own ingestion pipeline here, e.g. index directly:
        # for s in saved: index_file_into_chroma(s["path"], collection=args.get("target_collection","code_docs"), ...)

        return {"ok": True, "downloaded": saved, "collection": args.get("target_collection","code_docs")}

    if name == "vision_analyze":
        image_url = args.get("image_url","")
        prompt = args.get("prompt","Beschrijf beknopt wat je ziet en noem de belangrijkste details.")
        max_tokens = int(args.get("max_tokens",512))
        b64 = None
        # Only data: URIs or raw base64 are accepted; http(s) URLs are
        # rejected because they would not be loaded into the vision call
        # and would fail silently.
        if image_url.startswith(("http://","https://")):
            raise HTTPException(status_code=400, detail="vision_analyze: image_url moet data: URI of raw base64 zijn.")
        if image_url.startswith("data:image") and "," in image_url:
            b64 = image_url.split(",",1)[1]
        elif re.match(r"^[A-Za-z0-9+/=]+$", image_url.strip()):
            b64 = image_url.strip()
        messages = [{"role":"user","content": f"<image> {prompt}"}]
        return await llm_call_openai_compat(messages, stream=False, max_tokens=max_tokens,
                                            extra={"images": [b64] if b64 else []})

    raise HTTPException(status_code=400, detail=f"Unknown tool: {name}")
||
|
||
# Tool registry consumed by _execute_tool / list_tools / the OpenAPI shims.
# Each entry: description + JSON-Schema "parameters"; an optional "hidden"
# flag excludes the tool from the visible listing.
TOOLS_REGISTRY = {
    "repo_grep": {
        "description": "Zoek exact(e) tekst in een git repo (fast grep-achtig).",
        "parameters": {
            "type":"object",
            "properties":{
                "repo_url":{"type":"string"},
                "branch":{"type":"string","default":"main"},
                "query":{"type":"string"},
                "max_hits":{"type":"integer","default":200}
            },
            "required":["repo_url","query"]
        }
    },
    "ingest_openwebui_files": {
        "description": "Download aangehechte OpenWebUI-bestanden en voer ingestie/embedding uit.",
        "parameters": {
            "type":"object",
            "properties":{
                "target_collection":{"type":"string","default":"code_docs"},
                "profile":{"type":"string","default":"auto"},
                "chunk_chars":{"type":"integer","default":3000},
                "overlap":{"type":"integer","default":400}
            }
        }
    },
    "rag_index_repo": {
        "description": "Indexeer een git-repo in ChromaDB (chunken & metadata).",
        "parameters": {
            "type":"object",
            "properties":{
                "repo_url":{"type":"string"},
                "branch":{"type":"string","default":"main"},
                "profile":{"type":"string","default":"auto"},
                "include":{"type":"string","default":""},
                "exclude_dirs":{"type":"string","default":""},
                "chunk_chars":{"type":"integer","default":3000},
                "overlap":{"type":"integer","default":400},
                "collection_name":{"type":"string","default":"code_docs"},
                "force":{"type":"boolean","default":False}
            },
            "required":["repo_url"]
        }
    },
    "rag_query": {
        "description": "Zoek in de RAG-collectie en geef top-N passages (hybride rerank).",
        "parameters": {
            "type":"object",
            "properties":{
                "query":{"type":"string"},
                "n_results":{"type":"integer","default":8},
                "collection_name":{"type":"string","default":"code_docs"},
                "repo":{"type":["string","null"]},
                "path_contains":{"type":["string","null"]},
                "profile":{"type":["string","null"]}
            },
            "required":["query"]
        }
    },
    "summarize_text": {
        "description": "Vat tekst samen in bullets met inleiding en actiepunten.",
        "parameters": {"type":"object","properties":{
            "text":{"type":"string"},
            "instruction":{"type":"string","default":"Vat samen in bullets (max 10), met korte inleiding en actiepunten."},
            "max_chars":{"type":"integer","default":16000}
        },"required":["text"]}
    },
    "analyze_text": {
        "description": "Analyseer tekst: sterke/zwakke punten, risico’s en verbeterpunten.",
        "parameters": {"type":"object","properties":{
            "text":{"type":"string"},
            "goal":{"type":"string","default":"Licht toe wat dit document doet. Benoem sterke/zwakke punten, risico’s en concrete verbeterpunten."},
            "max_chars":{"type":"integer","default":20000}
        },"required":["text"]}
    },
    "improve_text": {
        "description": "Verbeter tekst: korte toelichting + volledige verbeterde versie.",
        "parameters": {"type":"object","properties":{
            "text":{"type":"string"},
            "objective":{"type":"string","default":"Verbeter dit document. Lever een beknopte toelichting en vervolgens de volledige verbeterde versie."},
            "style":{"type":"string","default":"Hanteer best practices, behoud inhoudelijke betekenis."},
            "max_chars":{"type":"integer","default":20000}
        },"required":["text"]}
    },
    "validate_code_text": {
        "description": "Valideer code; geef issues (met regelnummers).",
        "parameters": {"type":"object","properties":{
            "code":{"type":"string"}
        },"required":["code"]}
    },
    "improve_code_text": {
        "description": "Verbeter code met focus (best practices/perf/security).",
        "parameters": {"type":"object","properties":{
            "code":{"type":"string"},
            "language":{"type":"string","default":"auto"},
            "improvement_focus":{"type":"string","default":"best practices"}
        },"required":["code"]}
    },
    "vision_analyze": {
        "description": "Eén afbeelding analyseren via data: URI of base64.",
        "parameters":{"type":"object","properties":{
            "image_url":{"type":"string"},
            "prompt":{"type":"string","default":"Beschrijf beknopt wat je ziet en de belangrijkste details."},
            "max_tokens":{"type":"integer","default":512}
        },"required":["image_url"]}
    }
}

# Hide the OWUI-dependent tool when OpenWebUI is not configured
try:
    if not (OWUI_BASE_URL and OWUI_API_TOKEN):
        # key always exists; mark it hidden so /v1/tools does not list it
        TOOLS_REGISTRY["ingest_openwebui_files"]["hidden"] = True
except Exception:
    # Best-effort at import time: never let this guard break startup.
    pass
||
|
||
def _tools_list_from_registry(reg: dict):
|
||
lst = []
|
||
for name, spec in reg.items():
|
||
lst.append({
|
||
"name": name,
|
||
"description": spec.get("description", ""),
|
||
"parameters": spec.get("parameters", {"type":"object","properties":{}})
|
||
})
|
||
return {"mode": LLM_FUNCTION_CALLING_MODE if 'LLM_FUNCTION_CALLING_MODE' in globals() else "shim",
|
||
"tools": lst}
|
||
|
||
|
||
@app.get("/v1/tools", operation_id="get_tools_list")
|
||
async def list_tools(format: str = "proxy"):
|
||
# negeer 'format' en geef altijd OpenAI tool list terug
|
||
return {
|
||
"object": "tool.list",
|
||
"mode": "function", # <- belangrijk voor OWUI 0.6.21
|
||
"data": _openai_tools_from_registry(_visible_registry(TOOLS_REGISTRY))
|
||
}
|
||
|
||
# === OpenAPI compat: one endpoint per tool (operationId = tool name) ===
# Open WebUI reads /openapi.json, collects operationIds and turns them into "tools".
def _register_openapi_tool(name: str):
    """Register POST /openapi/{name} whose JSON body is the arguments dict.

    operation_id == name, so the Open WebUI tool name equals the registry name.
    The requestBody schema is injected into the OpenAPI doc so OWUI knows the
    parameters.
    """
    schema = (TOOLS_REGISTRY.get(name, {}) or {}).get("parameters", {"type": "object", "properties": {}})
    request_body_doc = {
        "requestBody": {
            "required": True,
            "content": {"application/json": {"schema": schema}},
        }
    }

    @app.post(
        f"/openapi/{name}",
        operation_id=name,
        summary=f"Tool: {name}",
        openapi_extra=request_body_doc,
    )
    async def _tool_entry(payload: dict = Body(...)):
        if name not in TOOLS_REGISTRY:
            raise HTTPException(status_code=404, detail=f"Unknown tool: {name}")
        # payload = {"arg1": ..., "arg2": ...}
        result = await _execute_tool(name, payload or {})
        # If _execute_tool returned an OpenAI-style chat response, unwrap the text.
        if isinstance(result, dict) and "choices" in result:
            try:
                txt = ((result.get("choices") or [{}])[0].get("message") or {}).get("content", "")
                return {"result": txt}
            except Exception:
                pass
        return result


for _t in TOOLS_REGISTRY:
    _register_openapi_tool(_t)
|
||
@app.post("/v1/tools/call", operation_id="post_tools_call")
|
||
async def tools_entry(request: Request):
|
||
#if request.method == "GET":
|
||
# return {"object":"tool.list","mode":LLM_FUNCTION_CALLING_MODE,"data": _openai_tools_from_registry(TOOLS_REGISTRY)}
|
||
data = await request.json()
|
||
name = (data.get("name")
|
||
or (data.get("tool") or {}).get("name")
|
||
or (data.get("function") or {}).get("name"))
|
||
raw_args = (data.get("arguments")
|
||
or (data.get("tool") or {}).get("arguments")
|
||
or (data.get("function") or {}).get("arguments")
|
||
or {})
|
||
if isinstance(raw_args, str):
|
||
try: args = json.loads(raw_args)
|
||
except Exception: args = {}
|
||
else:
|
||
args = raw_args
|
||
if not name or name not in TOOLS_REGISTRY or TOOLS_REGISTRY[name].get("hidden"):
|
||
raise HTTPException(status_code=404, detail=f"Unknown tool: {name}")
|
||
try:
|
||
output = await _execute_tool(name, args)
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
logging.exception("tool call failed")
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
return {"id": f"toolcall-{uuid.uuid4().hex}","object":"tool.run","created": int(time.time()),"name": name,"arguments": args,"output": output}
|
||
|
||
async def _complete_with_autocontinue(
    messages: list[dict],
    *,
    model: str,
    temperature: float,
    top_p: float,
    max_tokens: int,
    extra_payload: dict | None,
    max_autocont: int
) -> tuple[str, dict]:
    """
    One or more non-stream completion calls with auto-continue logic.

    When the upstream reply looks truncated (finish_reason "length" or
    "content_filter", a trailing ellipsis, or a length near the token cap),
    follow-up requests ask the model to continue, up to max_autocont times.
    Returns (full_text, last_response_json) — the last response carries the
    final usage metadata.
    """
    resp = await llm_call_openai_compat(
        messages, model=model, stream=False,
        temperature=temperature, top_p=top_p, max_tokens=max_tokens,
        extra=extra_payload if extra_payload else None
    )
    content = (resp.get("choices",[{}])[0].get("message",{}) or {}).get("content","")
    finish_reason = (resp.get("choices",[{}])[0] or {}).get("finish_reason")

    continues = 0
    # Margin (in tokens) below max_tokens at which output counts as "near cap".
    NEAR_MAX = int(os.getenv("LLM_AUTOCONT_NEAR_MAX", "48"))
    def _near_cap(txt: str) -> bool:
        # True when the (approximate) token count is close to max_tokens.
        try:
            return approx_token_count(txt) >= max_tokens - NEAR_MAX
        except Exception:
            return False
    while continues < max_autocont and (
        finish_reason in ("length","content_filter")
        or content.endswith("…")
        or _near_cap(content)
    ):
        follow = messages + [
            {"role":"assistant","content": content},
            {"role":"user","content": "Ga verder zonder herhaling; alleen het vervolg."}
        ]
        nxt = await llm_call_openai_compat(
            follow, model=model, stream=False,
            temperature=temperature, top_p=top_p, max_tokens=max_tokens,
            extra=extra_payload if extra_payload else None
        )
        more = (nxt.get("choices",[{}])[0].get("message",{}) or {}).get("content","")
        content = (content + ("\n" if content and more else "") + more).strip()
        finish_reason = (nxt.get("choices",[{}])[0] or {}).get("finish_reason")
        resp = nxt  # return the last response seen (usage etc.)
        continues += 1
    return content, resp
||
@app.post("/v1/chat/completions")
|
||
async def openai_chat_completions(body: dict = Body(...), request: Request = None):
|
||
model = (body.get("model") or os.getenv("LLM_MODEL", "mistral-medium")).strip()
|
||
#logging.info(str(body))
|
||
#logging.info(str(request))
|
||
stream = bool(body.get("stream", False))
|
||
raw_messages = body.get("messages") or []
|
||
# normaliseer tool-berichten naar plain tekst voor het LLM
|
||
norm_messages = []
|
||
for m in raw_messages:
|
||
if m.get("role") == "tool":
|
||
nm = m.get("name") or "tool"
|
||
norm_messages.append({
|
||
"role": "user",
|
||
"content": f"[{nm} RESULT]\n{m.get('content') or ''}"
|
||
})
|
||
else:
|
||
norm_messages.append(m)
|
||
|
||
|
||
# --- minimal tool-calling glue (laat rest van je functie intact) ---
|
||
tools = body.get("tools") or []
|
||
# 'tool_choice' hier alleen lezen; later in de native branch wordt opnieuw naar body gekeken
|
||
tool_choice_req = body.get("tool_choice") # 'auto' | 'none' | 'required' | {...}
|
||
try:
|
||
logger.info("🧰 tools_count=%s, tool_choice=%s", len(tools), tool_choice_req)
|
||
except Exception:
|
||
pass
|
||
|
||
# OWUI stuurt vaak "required" als: "er MOET een tool worden gebruikt".
|
||
# Als er precies 1 tool is meegegeven, normaliseren we dat naar "force deze tool".
|
||
if tool_choice_req == "required" and tools:
|
||
names = [ (t.get("function") or {}).get("name") for t in tools if t.get("function") ]
|
||
names = [ n for n in names if n ]
|
||
# 1) exact 1 → force die
|
||
if len(set(names)) == 1:
|
||
tool_choice_req = {"type":"function","function":{"name": names[0]}}
|
||
else:
|
||
# 2) meerdere → kies op basis van user prompt (noem de toolnaam)
|
||
last_user = next((m for m in reversed(norm_messages) if m.get("role")=="user"), {})
|
||
utext = (last_user.get("content") or "").lower()
|
||
mentioned = [n for n in names if n and n.lower() in utext]
|
||
if mentioned:
|
||
tool_choice_req = {"type":"function","function":{"name": mentioned[0]}}
|
||
logger.info("🔧 required->picked tool by mention: %s", mentioned[0])
|
||
|
||
|
||
|
||
# (1) Force: OWUI dwingt een specifieke tool af
|
||
if isinstance(tool_choice_req, dict) and (tool_choice_req.get("type") == "function"):
|
||
fname = (tool_choice_req.get("function") or {}).get("name")
|
||
if fname and fname not in TOOLS_REGISTRY:
|
||
# Onbekende tool → laat de LLM zelf native tool_calls teruggeven.
|
||
passthrough = dict(body)
|
||
passthrough["messages"] = norm_messages
|
||
passthrough["stream"] = False
|
||
client = app.state.HTTPX
|
||
r = await client.post(LLM_URL, json=passthrough)
|
||
try:
|
||
return JSONResponse(r.json(), status_code=r.status_code)
|
||
except Exception:
|
||
return PlainTextResponse(r.text, status_code=r.status_code)
|
||
|
||
if fname:
|
||
# Probeer lichte heuristiek voor bekende tools
|
||
last_user = next((m for m in reversed(norm_messages) if m.get("role")=="user"), {})
|
||
utext = (last_user.get("content") or "")
|
||
args: dict = {}
|
||
|
||
if fname == "rag_index_repo":
|
||
m = re.search(r'(https?://\S+)', utext)
|
||
if m: args["repo_url"] = m.group(1)
|
||
mb = re.search(r'\bbranch\s+([A-Za-z0-9._/-]+)', utext, re.I)
|
||
if mb: args["branch"] = mb.group(1)
|
||
elif fname == "rag_query":
|
||
args["query"] = utext.strip()
|
||
elif fname == "summarize_text":
|
||
m = re.search(r':\s*(.+)$', utext, re.S)
|
||
args["text"] = (m.group(1).strip() if m else utext.strip())[:16000]
|
||
elif fname == "analyze_text":
|
||
m = re.search(r':\s*(.+)$', utext, re.S)
|
||
args["text"] = (m.group(1).strip() if m else utext.strip())[:20000]
|
||
elif fname == "improve_text":
|
||
m = re.search(r':\s*(.+)$', utext, re.S)
|
||
args["text"] = (m.group(1).strip() if m else utext.strip())[:20000]
|
||
elif fname == "validate_code_text":
|
||
code = re.search(r"```.*?\n(.*?)```", utext, re.S)
|
||
args["code"] = (code.group(1).strip() if code else utext.strip())
|
||
elif fname == "improve_code_text":
|
||
code = re.search(r"```.*?\n(.*?)```", utext, re.S)
|
||
args["code"] = (code.group(1).strip() if code else utext.strip())
|
||
elif fname == "vision_analyze":
|
||
m = re.search(r'(data:image\/[a-zA-Z]+;base64,[A-Za-z0-9+/=]+|https?://\S+)', utext)
|
||
if m: args["image_url"] = m.group(1)
|
||
|
||
# Check verplichte velden; zo niet → native passthrough met alleen deze tool
|
||
required = (TOOLS_REGISTRY.get(fname, {}).get("parameters", {}) or {}).get("required", [])
|
||
if not all(k in args and args[k] for k in required):
|
||
passthrough = dict(body)
|
||
passthrough["messages"] = norm_messages
|
||
# Alleen deze tool meegeven + dwing deze tool af
|
||
only = [t for t in (body.get("tools") or []) if (t.get("function") or {}).get("name")==fname]
|
||
if only: passthrough["tools"] = only
|
||
passthrough["tool_choice"] = {"type":"function","function":{"name": fname}}
|
||
passthrough["stream"] = False
|
||
client = app.state.HTTPX
|
||
r = await client.post(LLM_URL, json=passthrough)
|
||
try:
|
||
return JSONResponse(r.json(), status_code=r.status_code)
|
||
except Exception:
|
||
return PlainTextResponse(r.text, status_code=r.status_code)
|
||
|
||
|
||
# Heuristiek geslaagd → stuur tool_calls terug (compat met OWUI)
|
||
return {
|
||
"id": f"chatcmpl-{uuid.uuid4().hex}",
|
||
"object": "chat.completion",
|
||
"created": int(time.time()),
|
||
"model": model,
|
||
"choices": [{
|
||
"index": 0,
|
||
"finish_reason": "tool_calls",
|
||
"message": {
|
||
"role": "assistant",
|
||
"tool_calls": [{
|
||
"id": f"call_{uuid.uuid4().hex[:8]}",
|
||
"type": "function",
|
||
"function": {"name": fname, "arguments": json.dumps(args, ensure_ascii=False)}
|
||
}]
|
||
}
|
||
}]
|
||
}
|
||
|
||
# Snelle escape: bij streaming en geen expliciete 'required' tool -> forceer directe streaming
|
||
if stream and tools and tool_choice_req in (None, "auto", "none") and \
|
||
os.getenv("STREAM_PREFER_DIRECT", "1").lower() not in ("0","false","no"):
|
||
tools = [] # bypass tool glue zodat we rechtstreeks naar de echte streaming gaan
|
||
|
||
# (2) Auto: vraag de LLM om 1+ function calls te produceren
|
||
if (tool_choice_req in (None, "auto")) and tools:
|
||
sys = _build_tools_system_prompt(tools)
|
||
ask = [{"role": "system", "content": sys}] + norm_messages
|
||
# jouw bestaande helper; hou 'm zoals je al gebruikt
|
||
resp = await llm_call_openai_compat(ask, stream=False, max_tokens=512)
|
||
txt = ((resp.get("choices") or [{}])[0].get("message") or {}).get("content", "") or ""
|
||
calls = _extract_tool_calls_from_text(txt)
|
||
if calls:
|
||
return {
|
||
"id": f"chatcmpl-{uuid.uuid4().hex}",
|
||
"object": "chat.completion",
|
||
"created": int(time.time()),
|
||
"model": model,
|
||
"choices": [{
|
||
"index": 0,
|
||
"finish_reason": "tool_calls",
|
||
"message": {"role": "assistant", "tool_calls": calls}
|
||
}]
|
||
}
|
||
# --- einde minimal tool-calling glue ---
|
||
|
||
# Vision normalisatie (na tool->tekst normalisatie)
|
||
msgs, images_b64 = _normalize_openai_vision_messages(norm_messages)
|
||
messages = msgs if msgs else norm_messages
|
||
extra_payload = {"images": images_b64} if images_b64 else {}
|
||
|
||
# Speciale modellen
|
||
if model == "repo-agent":
|
||
if stream:
|
||
async def event_gen():
|
||
import asyncio, time, json, contextlib
|
||
# Stuur meteen een role-delta om bytes te pushen
|
||
head = {"id": f"chatcmpl-{uuid.uuid4().hex[:12]}","object":"chat.completion.chunk",
|
||
"created": int(time.time()),"model": model,
|
||
"choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason": None}]}
|
||
yield f"data: {json.dumps(head)}\n\n"
|
||
HEARTBEAT = float(os.getenv("SSE_HEARTBEAT_SEC","10"))
|
||
done = asyncio.Event()
|
||
result = {"text": ""}
|
||
# geef wat vroege status in de buitenste generator (hier mag je wel yielden)
|
||
for s in ("🔎 Indexeren/zoeken…", "🧠 Analyseren…", "🛠️ Voorstel genereren…"):
|
||
data = {"id": f"chatcmpl-{uuid.uuid4().hex[:12]}","object":"chat.completion.chunk",
|
||
"created": int(time.time()),"model": model,
|
||
"choices":[{"index":0,"delta":{"content": s + "\n"},"finish_reason": None}]}
|
||
yield f"data: {json.dumps(data)}\n\n"
|
||
await asyncio.sleep(0)
|
||
|
||
async def _work():
|
||
try:
|
||
result["text"] = await handle_repo_agent(messages, request)
|
||
finally:
|
||
done.set()
|
||
worker = asyncio.create_task(_work())
|
||
try:
|
||
while not done.is_set():
|
||
try:
|
||
await asyncio.wait_for(done.wait(), timeout=HEARTBEAT)
|
||
except asyncio.TimeoutError:
|
||
yield ": ping\n\n"
|
||
# klaar → stuur content
|
||
data = {"id": f"chatcmpl-{uuid.uuid4().hex[:12]}","object":"chat.completion.chunk",
|
||
"created": int(time.time()),"model": model,
|
||
"choices":[{"index":0,"delta":{"content": result['text']}, "finish_reason": None}]}
|
||
yield f"data: {json.dumps(data)}\n\n"
|
||
yield "data: [DONE]\n\n"
|
||
finally:
|
||
worker.cancel()
|
||
with contextlib.suppress(Exception):
|
||
await worker
|
||
return StreamingResponse(
|
||
event_gen(),
|
||
media_type="text/event-stream",
|
||
headers={"Cache-Control":"no-cache, no-transform","X-Accel-Buffering":"no","Connection":"keep-alive"}
|
||
)
|
||
else:
|
||
text = await handle_repo_agent(messages, request)
|
||
return JSONResponse(_openai_chat_response(model, text, messages))
|
||
|
||
|
||
if model == "repo-qa":
|
||
repo_hint, question, branch, n_ctx = _parse_repo_qa_from_messages(messages)
|
||
if not repo_hint:
|
||
friendly = ("Geef een repo-hint (URL of owner/repo).\nVoorbeeld:\n"
|
||
"repo: (repo URL eindigend op .git)\n"
|
||
"Vraag: (de vraag die je hebt over de code in de repo)")
|
||
return JSONResponse(_openai_chat_response("repo-qa", friendly, messages))
|
||
try:
|
||
text = await repo_qa_answer(repo_hint, question, branch=branch, n_ctx=n_ctx)
|
||
except Exception as e:
|
||
logging.exception("repo-qa failed")
|
||
text = f"Repo-QA faalde: {e}"
|
||
if stream:
|
||
async def event_gen():
|
||
data = {"id": f"chatcmpl-{uuid.uuid4().hex[:12]}","object":"chat.completion.chunk","created": int(time.time()),
|
||
"model": "repo-qa", "choices":[{"index":0,"delta":{"content": text},"finish_reason": None}]}
|
||
yield f"data: {json.dumps(data)}\n\n"; yield "data: [DONE]\n\n"
|
||
return StreamingResponse(event_gen(), media_type="text/event-stream")
|
||
return JSONResponse(_openai_chat_response("repo-qa", text, messages))
|
||
|
||
# --- Tool-calling ---
|
||
tools_payload = body.get("tools")
|
||
tool_choice_req = body.get("tool_choice", "auto")
|
||
if tools_payload:
|
||
# native passthrough (llama.cpp voert tools NIET zelf uit, maar UI kan dit willen)
|
||
if LLM_FUNCTION_CALLING_MODE in ("native","auto") and stream:
|
||
passthrough = dict(body); passthrough["messages"]=messages
|
||
if images_b64: passthrough["images"]=images_b64
|
||
async def _aiter():
|
||
import asyncio, contextlib
|
||
client = app.state.HTTPX
|
||
async with client.stream("POST", LLM_URL, json=passthrough, timeout=None) as r:
|
||
r.raise_for_status()
|
||
HEARTBEAT = float(os.getenv("SSE_HEARTBEAT_SEC","10"))
|
||
q: asyncio.Queue[bytes] = asyncio.Queue(maxsize=100)
|
||
async def _reader():
|
||
try:
|
||
async for chunk in r.aiter_bytes():
|
||
if chunk:
|
||
await q.put(chunk)
|
||
except Exception:
|
||
pass
|
||
finally:
|
||
await q.put(b"__EOF__")
|
||
reader_task = asyncio.create_task(_reader())
|
||
try:
|
||
while True:
|
||
try:
|
||
chunk = await asyncio.wait_for(q.get(), timeout=HEARTBEAT)
|
||
except asyncio.TimeoutError:
|
||
yield b": ping\n\n"
|
||
continue
|
||
if chunk == b"__EOF__":
|
||
break
|
||
yield chunk
|
||
finally:
|
||
reader_task.cancel()
|
||
with contextlib.suppress(Exception):
|
||
await reader_task
|
||
return StreamingResponse(
|
||
_aiter(),
|
||
media_type="text/event-stream",
|
||
headers={"Cache-Control":"no-cache, no-transform","X-Accel-Buffering":"no","Connection":"keep-alive"}
|
||
)
|
||
|
||
if LLM_FUNCTION_CALLING_MODE in ("native","auto") and not stream:
|
||
# Relay-modus: laat LLM tools kiezen, bridge voert uit, daarna 2e run.
|
||
relay = os.getenv("LLM_TOOL_RUNNER", "passthrough").lower() == "bridge"
|
||
client = app.state.HTTPX
|
||
if not relay:
|
||
passthrough = dict(body); passthrough["messages"]=messages
|
||
if images_b64: passthrough["images"]=images_b64
|
||
r = await client.post(LLM_URL, json=passthrough)
|
||
try:
|
||
return JSONResponse(r.json(), status_code=r.status_code)
|
||
except Exception:
|
||
return PlainTextResponse(r.text, status_code=r.status_code)
|
||
|
||
# (A) 1e call: vraag de LLM om tool_calls (geen stream)
|
||
first_req = dict(body)
|
||
first_req["messages"] = messages
|
||
first_req["stream"] = False
|
||
if images_b64: first_req["images"] = images_b64
|
||
r1 = await client.post(LLM_URL, json=first_req)
|
||
try:
|
||
data1 = r1.json()
|
||
except Exception:
|
||
return PlainTextResponse(r1.text, status_code=r1.status_code)
|
||
msg1 = ((data1.get("choices") or [{}])[0].get("message") or {})
|
||
tool_calls = msg1.get("tool_calls") or []
|
||
# Geen tool-calls? Geef direct door.
|
||
if not tool_calls:
|
||
return JSONResponse(data1, status_code=r1.status_code)
|
||
|
||
# (B) voer tool_calls lokaal uit
|
||
tool_msgs = []
|
||
for tc in tool_calls:
|
||
fn = ((tc or {}).get("function") or {})
|
||
tname = fn.get("name")
|
||
raw_args = fn.get("arguments") or "{}"
|
||
try:
|
||
args = json.loads(raw_args) if isinstance(raw_args, str) else (raw_args or {})
|
||
except Exception:
|
||
args = {}
|
||
if not tname or tname not in TOOLS_REGISTRY:
|
||
out = {"error": f"Unknown tool '{tname}'"}
|
||
else:
|
||
try:
|
||
out = await _execute_tool(tname, args)
|
||
except Exception as e:
|
||
out = {"error": str(e)}
|
||
tool_msgs.append({
|
||
"role": "tool",
|
||
"tool_call_id": tc.get("id"),
|
||
"name": tname or "unknown",
|
||
"content": json.dumps(out, ensure_ascii=False)
|
||
})
|
||
|
||
# (C) 2e call: geef tool outputs terug aan LLM voor eindantwoord
|
||
follow_messages = messages + [
|
||
{"role": "assistant", "tool_calls": tool_calls},
|
||
*tool_msgs
|
||
]
|
||
second_req = dict(body)
|
||
second_req["messages"] = follow_messages
|
||
second_req["stream"] = False
|
||
# images opnieuw meesturen is niet nodig, maar kan geen kwaad:
|
||
if images_b64: second_req["images"] = images_b64
|
||
r2 = await client.post(LLM_URL, json=second_req)
|
||
try:
|
||
return JSONResponse(r2.json(), status_code=r2.status_code)
|
||
except Exception:
|
||
return PlainTextResponse(r2.text, status_code=r2.status_code)
|
||
|
||
# shim (non-stream)
|
||
if LLM_FUNCTION_CALLING_MODE == "shim" and not stream:
|
||
# 1) Laat LLM beslissen WELKE tool + args (strikte JSON)
|
||
tool_lines = []
|
||
for tname, t in TOOLS_REGISTRY.items():
|
||
tool_lines.append(f"- {tname}: {t['description']}\n schema: {json.dumps(t['parameters'])}")
|
||
sys = ("You can call tools.\n"
|
||
"If a tool is needed, reply with ONLY valid JSON:\n"
|
||
'{"call_tool":{"name":"<tool_name>","arguments":{...}}}\n'
|
||
"Otherwise reply with ONLY: {\"final_answer\":\"...\"}\n\nTools:\n" + "\n".join(tool_lines))
|
||
decide = await llm_call_openai_compat(
|
||
[{"role":"system","content":sys}] + messages,
|
||
stream=False, temperature=float(body.get("temperature",0.2)),
|
||
top_p=float(body.get("top_p",0.9)), max_tokens=min(512, int(body.get("max_tokens",1024))),
|
||
extra=extra_payload if extra_payload else None
|
||
)
|
||
raw = ((decide.get("choices",[{}])[0].get("message",{}) or {}).get("content","") or "").strip()
|
||
# haal evt. ```json fences weg
|
||
if raw.startswith("```"):
|
||
raw = extract_code_block(raw)
|
||
try:
|
||
obj = json.loads(raw)
|
||
except Exception:
|
||
# Model gaf geen JSON → behandel als final_answer
|
||
return JSONResponse(decide)
|
||
|
||
if "final_answer" in obj:
|
||
return JSONResponse({
|
||
"id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
|
||
"object": "chat.completion",
|
||
"created": int(time.time()),
|
||
"model": model,
|
||
"choices": [{"index":0,"message":{"role":"assistant","content": obj["final_answer"]},"finish_reason":"stop"}],
|
||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
|
||
})
|
||
|
||
call = (obj.get("call_tool") or {})
|
||
tname = call.get("name")
|
||
if tname not in TOOLS_REGISTRY:
|
||
return JSONResponse(_openai_chat_response(model, f"Onbekende tool: {tname}", messages))
|
||
args = call.get("arguments") or {}
|
||
tool_out = await _execute_tool(tname, args)
|
||
|
||
follow = messages + [
|
||
{"role":"assistant","content": f"TOOL[{tname}] OUTPUT:\n{json.dumps(tool_out)[:8000]}"},
|
||
{"role":"user","content": "Gebruik bovenstaande tool-output en geef nu het definitieve antwoord."}
|
||
]
|
||
final = await llm_call_openai_compat(
|
||
follow, stream=False,
|
||
temperature=float(body.get("temperature",0.2)),
|
||
top_p=float(body.get("top_p",0.9)),
|
||
max_tokens=int(body.get("max_tokens",1024)),
|
||
extra=extra_payload if extra_payload else None
|
||
)
|
||
return JSONResponse(final)
|
||
|
||
# --- streaming passthrough (geen tools) ---
|
||
# --- streaming via onze wachtrij + upstreams ---
|
||
# --- streaming (zonder tools): nu mét windowing trim ---
|
||
# --- streaming (zonder tools): windowing + server-side auto-continue stream ---
|
||
simulate_stream=False
|
||
if simulate_stream:
|
||
if stream:
|
||
LLM_WINDOWING_ENABLE = os.getenv("LLM_WINDOWING_ENABLE", "1").lower() not in ("0","false","no")
|
||
MAX_CTX_TOKENS = int(os.getenv("LLM_CONTEXT_TOKENS", "13021"))
|
||
RESP_RESERVE = int(os.getenv("LLM_RESPONSE_RESERVE", "1024"))
|
||
temperature = float(body.get("temperature", 0.2))
|
||
top_p = float(body.get("top_p", 0.9))
|
||
# respecteer env-override voor default
|
||
_default_max = int(os.getenv("LLM_DEFAULT_MAX_TOKENS", "1024"))
|
||
max_tokens = int(body.get("max_tokens", _default_max))
|
||
MAX_AUTOCONT = int(os.getenv("LLM_AUTO_CONTINUES", "3"))
|
||
|
||
trimmed_stream_msgs = messages
|
||
try:
|
||
if LLM_WINDOWING_ENABLE and 'ConversationWindow' in globals():
|
||
#thread_id = derive_thread_id({"model": model, "messages": messages}) if 'derive_thread_id' in globals() else uuid.uuid4().hex
|
||
thread_id = derive_thread_id(messages) if 'derive_thread_id' in globals() else uuid.uuid4().hex
|
||
running_summary = SUMMARY_STORE.get(thread_id, "") if isinstance(SUMMARY_STORE, dict) else (SUMMARY_STORE.get(thread_id) or "")
|
||
win = ConversationWindow(max_ctx_tokens=MAX_CTX_TOKENS, response_reserve=RESP_RESERVE,
|
||
tok_len=approx_token_count, running_summary=running_summary,
|
||
summary_header="Samenvatting tot nu toe")
|
||
for m in messages:
|
||
role, content = m.get("role","user"), m.get("content","")
|
||
if role in ("system","user","assistant"):
|
||
win.add(role, content)
|
||
async def _summarizer(old: str, chunk_msgs: list[dict]) -> str:
|
||
chunk_text = ""
|
||
for m in chunk_msgs:
|
||
chunk_text += f"\n[{m.get('role','user')}] {m.get('content','')}"
|
||
prompt = [
|
||
{"role":"system","content":"Je bent een bondige notulist. Vat samen in max 10 bullets (feiten/besluiten/acties)."},
|
||
{"role":"user","content": f"Vorige samenvatting:\n{old}\n\nNieuwe geschiedenis:\n{chunk_text}\n\nGeef geüpdatete samenvatting."}
|
||
]
|
||
resp = await llm_call_openai_compat(prompt, stream=False, temperature=0.1, top_p=1.0, max_tokens=300)
|
||
return (resp.get("choices",[{}])[0].get("message",{}) or {}).get("content", old or "")
|
||
trimmed_stream_msgs = await win.build_within_budget(system_prompt=None, summarizer=_summarizer)
|
||
new_summary = getattr(win, "running_summary", running_summary)
|
||
if isinstance(SUMMARY_STORE, dict):
|
||
if new_summary and new_summary != running_summary:
|
||
SUMMARY_STORE[thread_id] = new_summary
|
||
else:
|
||
try:
|
||
if new_summary and new_summary != running_summary:
|
||
SUMMARY_STORE.update(thread_id, new_summary) # type: ignore[attr-defined]
|
||
except Exception:
|
||
pass
|
||
except Exception:
|
||
trimmed_stream_msgs = messages
|
||
|
||
# 1) haal volledige tekst op met non-stream + auto-continue
|
||
full_text, last_resp = await _complete_with_autocontinue(
|
||
trimmed_stream_msgs,
|
||
model=model, temperature=temperature, top_p=top_p,
|
||
max_tokens=max_tokens, extra_payload=extra_payload,
|
||
max_autocont=MAX_AUTOCONT
|
||
)
|
||
|
||
# 2) stream deze tekst als SSE in hapjes (simuleer live tokens)
|
||
async def _emit_sse():
|
||
created = int(time.time())
|
||
model_name = model
|
||
CHUNK = int(os.getenv("LLM_STREAM_CHUNK_CHARS", "800"))
|
||
# optioneel: stuur meteen role-delta (sommige UIs waarderen dat)
|
||
head = {
|
||
"id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
|
||
"object": "chat.completion.chunk",
|
||
"created": created,
|
||
"model": model_name,
|
||
"choices": [{"index":0,"delta":{"role":"assistant"},"finish_reason": None}]
|
||
}
|
||
yield ("data: " + json.dumps(head, ensure_ascii=False) + "\n\n").encode("utf-8")
|
||
# content in blokken
|
||
for i in range(0, len(full_text), CHUNK):
|
||
piece = full_text[i:i+CHUNK]
|
||
data = {
|
||
"id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
|
||
"object": "chat.completion.chunk",
|
||
"created": int(time.time()),
|
||
"model": model_name,
|
||
"choices": [{"index":0,"delta":{"content": piece},"finish_reason": None}]
|
||
}
|
||
yield ("data: " + json.dumps(data, ensure_ascii=False) + "\n\n").encode("utf-8")
|
||
await asyncio.sleep(0) # laat event loop ademen
|
||
# afsluiten
|
||
done = {"id": f"chatcmpl-{uuid.uuid4().hex[:12]}","object":"chat.completion.chunk",
|
||
"created": int(time.time()),"model": model_name,
|
||
"choices": [{"index":0,"delta":{},"finish_reason":"stop"}]}
|
||
yield ("data: " + json.dumps(done, ensure_ascii=False) + "\n\n").encode("utf-8")
|
||
yield b"data: [DONE]\n\n"
|
||
|
||
return StreamingResponse(
|
||
_emit_sse(),
|
||
media_type="text/event-stream",
|
||
headers={"Cache-Control":"no-cache, no-transform","X-Accel-Buffering":"no","Connection":"keep-alive"}
|
||
)
|
||
else:
|
||
# --- ÉCHTE streaming (geen tools): direct passthrough met heartbeats ---
|
||
if stream:
|
||
temperature = float(body.get("temperature", 0.2))
|
||
top_p = float(body.get("top_p", 0.9))
|
||
_default_max = int(os.getenv("LLM_DEFAULT_MAX_TOKENS", "1024"))
|
||
max_tokens = int(body.get("max_tokens", _default_max))
|
||
return await llm_call_openai_compat(
|
||
messages,
|
||
model=model,
|
||
stream=True,
|
||
temperature=temperature,
|
||
top_p=top_p,
|
||
max_tokens=max_tokens,
|
||
extra=extra_payload if extra_payload else None
|
||
)
|
||
|
||
|
||
|
||
# --- non-stream: windowing + auto-continue (zoals eerder gepatcht) ---
|
||
LLM_WINDOWING_ENABLE = os.getenv("LLM_WINDOWING_ENABLE", "1").lower() not in ("0","false","no")
|
||
MAX_CTX_TOKENS = int(os.getenv("LLM_CONTEXT_TOKENS", "13021"))
|
||
RESP_RESERVE = int(os.getenv("LLM_RESPONSE_RESERVE", "1024"))
|
||
MAX_AUTOCONT = int(os.getenv("LLM_AUTO_CONTINUES", "2"))
|
||
temperature = float(body.get("temperature", 0.2))
|
||
top_p = float(body.get("top_p", 0.9))
|
||
# Laat env de default bepalen, zodat OWUI niet hard op 1024 blijft hangen
|
||
_default_max = int(os.getenv("LLM_DEFAULT_MAX_TOKENS", "13027"))
|
||
max_tokens = int(body.get("max_tokens", _default_max))
|
||
|
||
trimmed = messages
|
||
try:
|
||
if LLM_WINDOWING_ENABLE and 'ConversationWindow' in globals():
|
||
#thread_id = derive_thread_id({"model": model, "messages": messages}) if 'derive_thread_id' in globals() else uuid.uuid4().hex
|
||
thread_id = derive_thread_id(messages) if 'derive_thread_id' in globals() else uuid.uuid4().hex
|
||
running_summary = SUMMARY_STORE.get(thread_id, "") if isinstance(SUMMARY_STORE, dict) else (SUMMARY_STORE.get(thread_id) or "")
|
||
win = ConversationWindow(max_ctx_tokens=MAX_CTX_TOKENS, response_reserve=RESP_RESERVE,
|
||
tok_len=approx_token_count, running_summary=running_summary,
|
||
summary_header="Samenvatting tot nu toe")
|
||
for m in messages:
|
||
role, content = m.get("role","user"), m.get("content","")
|
||
if role in ("system","user","assistant"):
|
||
win.add(role, content)
|
||
async def _summarizer(old: str, chunk_msgs: list[dict]) -> str:
|
||
chunk_text = ""
|
||
for m in chunk_msgs:
|
||
chunk_text += f"\n[{m.get('role','user')}] {m.get('content','')}"
|
||
prompt = [
|
||
{"role":"system","content":"Je bent een bondige notulist. Vat samen in max 10 bullets (feiten/besluiten/acties)."},
|
||
{"role":"user","content": f"Vorige samenvatting:\n{old}\n\nNieuwe geschiedenis:\n{chunk_text}\n\nGeef geüpdatete samenvatting."}
|
||
]
|
||
resp = await llm_call_openai_compat(prompt, stream=False, temperature=0.1, top_p=1.0, max_tokens=300)
|
||
return (resp.get("choices",[{}])[0].get("message",{}) or {}).get("content", old or "")
|
||
trimmed = await win.build_within_budget(system_prompt=None, summarizer=_summarizer)
|
||
new_summary = getattr(win, "running_summary", running_summary)
|
||
if isinstance(SUMMARY_STORE, dict):
|
||
if new_summary and new_summary != running_summary:
|
||
SUMMARY_STORE[thread_id] = new_summary
|
||
else:
|
||
try:
|
||
if new_summary and new_summary != running_summary:
|
||
SUMMARY_STORE.update(thread_id, new_summary) # type: ignore[attr-defined]
|
||
except Exception:
|
||
pass
|
||
except Exception:
|
||
trimmed = messages
|
||
|
||
resp = await llm_call_openai_compat(trimmed, model=model, stream=False,
|
||
temperature=temperature, top_p=top_p, max_tokens=max_tokens,
|
||
extra=extra_payload if extra_payload else None)
|
||
content = (resp.get("choices",[{}])[0].get("message",{}) or {}).get("content","")
|
||
finish_reason = (resp.get("choices",[{}])[0] or {}).get("finish_reason")
|
||
|
||
continues = 0
|
||
# extra trigger: als we nagenoeg max_tokens geraakt hebben, ga door
|
||
NEAR_MAX = int(os.getenv("LLM_AUTOCONT_NEAR_MAX", "48"))
|
||
def _near_cap(txt: str) -> bool:
|
||
try:
|
||
return approx_token_count(txt) >= max_tokens - NEAR_MAX
|
||
except Exception:
|
||
return False
|
||
while continues < MAX_AUTOCONT and (
|
||
finish_reason in ("length","content_filter")
|
||
or content.endswith("…")
|
||
or _near_cap(content)
|
||
):
|
||
follow = trimmed + [{"role":"assistant","content": content},
|
||
{"role":"user","content": "Ga verder zonder herhaling; alleen het vervolg."}]
|
||
nxt = await llm_call_openai_compat(follow, model=model, stream=False,
|
||
temperature=temperature, top_p=top_p, max_tokens=max_tokens,
|
||
extra=extra_payload if extra_payload else None)
|
||
more = (nxt.get("choices",[{}])[0].get("message",{}) or {}).get("content","")
|
||
content = (content + ("\n" if content and more else "") + more).strip()
|
||
finish_reason = (nxt.get("choices",[{}])[0] or {}).get("finish_reason")
|
||
continues += 1
|
||
|
||
if content:
|
||
resp["choices"][0]["message"]["content"] = content
|
||
resp["choices"][0]["finish_reason"] = finish_reason or resp["choices"][0].get("finish_reason", "stop")
|
||
try:
|
||
prompt_tokens = count_message_tokens(trimmed) if 'count_message_tokens' in globals() else approx_token_count(json.dumps(trimmed))
|
||
except Exception:
|
||
prompt_tokens = approx_token_count(json.dumps(trimmed))
|
||
completion_tokens = approx_token_count(content)
|
||
resp["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens,
|
||
"total_tokens": prompt_tokens + completion_tokens}
|
||
return JSONResponse(resp)
|
||
|
||
|
||
|
||
# -----------------------------------------------------------------------------
# RAG Index (repo → Chroma)
# -----------------------------------------------------------------------------

# On-disk JSON caches stored next to the Chroma database (fallback /rag_db).
_REPO_CACHE_PATH = os.path.join(CHROMA_PATH or "/rag_db", "repo_index_cache.json")
_SYMBOL_INDEX_PATH = os.path.join(CHROMA_PATH or "/rag_db", "symbol_index.json")
|
||
def _load_symbol_index() -> dict:
|
||
try:
|
||
with open(_SYMBOL_INDEX_PATH, "r", encoding="utf-8") as f:
|
||
return json.load(f)
|
||
except Exception:
|
||
return {}
|
||
|
||
def _save_symbol_index(data: dict):
|
||
try:
|
||
_atomic_json_write(_SYMBOL_INDEX_PATH, data)
|
||
except Exception:
|
||
pass
|
||
|
||
_SYMBOL_INDEX: dict[str, dict[str, list[dict]]] = _load_symbol_index()
|
||
|
||
def _symkey(collection_effective: str, repo_name: str) -> str:
|
||
return f"{collection_effective}|{repo_name}"
|
||
|
||
|
||
def _load_repo_index_cache() -> dict:
|
||
try:
|
||
with open(_REPO_CACHE_PATH, "r", encoding="utf-8") as f:
|
||
return json.load(f)
|
||
except Exception:
|
||
return {}
|
||
|
||
def _save_repo_index_cache(data: dict):
|
||
try:
|
||
_atomic_json_write(_REPO_CACHE_PATH, data)
|
||
except Exception:
|
||
pass
|
||
|
||
def _sha1_text(t: str) -> str:
|
||
return hashlib.sha1(t.encode("utf-8", errors="ignore")).hexdigest()
|
||
|
||
def _repo_owner_repo_from_url(u: str) -> str:
|
||
try:
|
||
from urllib.parse import urlparse
|
||
p = urlparse(u).path.strip("/").split("/")
|
||
if len(p) >= 2:
|
||
name = p[-1]
|
||
if name.endswith(".git"): name = name[:-4]
|
||
return f"{p[-2]}/{name}"
|
||
except Exception:
|
||
pass
|
||
# fallback: basename
|
||
return os.path.basename(u).removesuffix(".git")
|
||
|
||
def _summary_store_path(repo_url: str, branch: str) -> str:
    """Path of the JSON summary store for this repo/branch/embedder combination."""
    safe_repo = _repo_owner_repo_from_url(repo_url).replace("/", "__")
    filename = f"{safe_repo}__{branch}__{_EMBEDDER.slug}.json"
    return os.path.join(_SUMMARY_DIR, filename)
|
||
|
||
def _load_summary_store(path: str) -> dict:
|
||
try:
|
||
with open(path, "r", encoding="utf-8") as f:
|
||
return json.load(f)
|
||
except Exception:
|
||
return {}
|
||
|
||
def _save_summary_store(path: str, data: dict):
|
||
try:
|
||
_atomic_json_write(path, data)
|
||
except Exception:
|
||
pass
|
||
|
||
async def _summarize_files_llm(items: list[tuple[str, str]]) -> dict[str, str]:
    """
    Generate a one-line LLM summary for each (path, text) pair.

    Returns {path: summary}. Prompts are deliberately tiny (2000-char
    snippet, 64 output tokens) to keep this cheap; on failure a Dutch
    placeholder is used so callers always get a non-empty string.
    """
    summaries: dict[str, str] = {}
    for path, text in items:
        messages = [
            {"role":"system","content":"Geef één korte, functionele beschrijving (max 20 woorden) van dit bestand. Geen opsomming, geen code, 1 zin."},
            {"role":"user","content": f"Pad: {path}\n\nInhoud (ingekort):\n{text[:2000]}\n\nAntwoord: "}
        ]
        summary = ""
        try:
            reply = await llm_call_openai_compat(messages, stream=False, temperature=0.1, top_p=1.0, max_tokens=64)
            summary = ((reply.get("choices") or [{}])[0].get("message") or {}).get("content","").strip()
        except Exception:
            summary = ""
        summaries[path] = summary or "Bestand (korte beschrijving niet beschikbaar)"
    return summaries
|
||
|
||
async def _repo_summary_get_internal(repo_url: str, branch: str, paths: list[str]) -> dict[str, dict]:
    """
    Return {path: {"summary": str, "sha": str}} for *paths* in the repo.

    Summaries are cached per-file by content SHA in a JSON store next to the
    other RAG caches; an entry is only regenerated when the file's hash
    changed or no summary exists yet.

    Improvement vs. the previous version: stale files were read from disk and
    hashed twice (once to detect staleness, once after summarization). The
    text and hash are now captured on the first pass and reused.
    """
    repo_path = await _get_git_repo_async(repo_url, branch)
    store_path = _summary_store_path(repo_url, branch)
    # store structure: {path: {"sha": "...", "summary": "..."}}
    store = _load_summary_store(store_path)

    # (rel, text, sha) for every file whose cached summary is missing/stale.
    stale: list[tuple[str, str, str]] = []
    for rel in paths:
        p = Path(repo_path) / rel
        if not p.exists():
            continue
        text = _read_text_file(p)
        fsha = _sha1_text(text)
        rec = store.get(rel) or {}
        if rec.get("sha") != fsha or not rec.get("summary"):
            stale.append((rel, text, fsha))

    if stale:
        new_summ = await _summarize_files_llm([(rel, text) for rel, text, _ in stale])
        for rel, _text, fsha in stale:
            store[rel] = {"sha": fsha, "summary": new_summ.get(rel, "")}
        _save_summary_store(store_path, store)

    # Pack the result for all requested paths (missing files yield empty fields).
    return {rel: {"summary": (store.get(rel) or {}).get("summary",""),
                  "sha": (store.get(rel) or {}).get("sha","")}
            for rel in paths}
|
||
|
||
async def _grep_repo_paths(repo_url: str, branch: str, query: str, limit: int = 100) -> list[str]:
    """
    Naive fallback search: linear, case-insensitive substring scan over all
    readable text files in the repo. Returns unique relative paths in
    discovery order, capped at *limit*.
    """
    repo_path = await _get_git_repo_async(repo_url, branch)
    root = Path(repo_path)
    needle = (query or "").lower().strip()
    found: list[str] = []
    for candidate in root.rglob("*"):
        if candidate.is_dir():
            continue
        # skip excluded directories and known-binary extensions
        if set(candidate.parts) & PROFILE_EXCLUDE_DIRS:
            continue
        if candidate.suffix.lower() in BINARY_SKIP:
            continue
        try:
            body = _read_text_file(candidate)
        except Exception:
            continue
        if not body:
            continue
        if needle in body.lower():
            found.append(str(candidate.relative_to(root)))
            if len(found) >= limit:
                break
    # de-duplicate while preserving order
    seen: set[str] = set()
    unique: list[str] = []
    for rel in found:
        if rel in seen:
            continue
        seen.add(rel)
        unique.append(rel)
    return unique
|
||
|
||
async def _meili_search_internal(query: str, *, repo_full: str | None, branch: str | None, limit: int = 50) -> list[dict]:
    """
    Full-text search against Meilisearch.

    Returns hits of the form {"path": str, "chunk_index": int|None,
    "score": float, "highlights": dict}, de-duplicated per path with the
    highest score kept. Returns [] when Meili is disabled or the request
    fails for any reason (best-effort semantics).

    Bugfix: when no app-wide client is available, the ad-hoc
    httpx.AsyncClient was created but never closed (connection leak per
    call). It is now closed in a finally block.
    """
    if not MEILI_ENABLED:
        return []
    body = {"q": query, "limit": int(limit)}
    filters = []
    if repo_full:
        filters.append(f'repo_full = "{repo_full}"')
    if branch:
        filters.append(f'branch = "{branch}"')
    if filters:
        body["filter"] = " AND ".join(filters)
    headers = {"Content-Type":"application/json"}
    if MEILI_API_KEY:
        headers["Authorization"] = f"Bearer {MEILI_API_KEY}"
        # Meili also commonly accepts 'X-Meili-API-Key'; send both.
        headers["X-Meili-API-Key"] = MEILI_API_KEY
    url = f"{MEILI_URL}/indexes/{MEILI_INDEX}/search"
    # Prefer the shared app-wide client; otherwise create (and close) our own.
    client = getattr(app.state, "HTTPX_PROXY", None)
    own_client = None
    if client is None:
        own_client = httpx.AsyncClient()
        client = own_client
    try:
        r = await client.post(url, headers=headers, json=body, timeout=httpx.Timeout(30.0, connect=10.0))
        r.raise_for_status()
        data = r.json()
    except Exception:
        return []
    finally:
        if own_client is not None:
            with contextlib.suppress(Exception):
                await own_client.aclose()
    hits = []
    for h in data.get("hits", []):
        hits.append({
            "path": h.get("path"),
            "chunk_index": h.get("chunk_index"),
            "score": float(h.get("_rankingScore", h.get("_matchesPosition") or 0) or 0),
            "highlights": h.get("_formatted") or {}
        })
    # De-dup on path, highest score first
    best: dict[str, dict] = {}
    for h in sorted(hits, key=lambda x: x["score"], reverse=True):
        p = h.get("path")
        if p and p not in best:
            best[p] = h
    return list(best.values())
|
||
|
||
async def _search_first_candidates(repo_url: str, branch: str, query: str, explicit_paths: list[str] | None = None, limit: int = 50) -> list[str]:
    """
    Candidate file paths for a query, in order of preference:
    explicit paths (validated against the checkout) -> Meilisearch -> grep.
    """
    if explicit_paths:
        # Validate that the given paths actually exist in the repo checkout.
        repo_path = await _get_git_repo_async(repo_url, branch)
        base = Path(repo_path)
        return [rel for rel in explicit_paths if (base / rel).exists()]
    repo_full = _repo_owner_repo_from_url(repo_url)
    hits = await _meili_search_internal(query, repo_full=repo_full, branch=branch, limit=limit)
    if hits:
        return [h["path"] for h in hits]
    # fallback: linear grep over the repo
    return await _grep_repo_paths(repo_url, branch, query, limit=limit)
|
||
|
||
|
||
def _match_any(name: str, patterns: list[str]) -> bool:
|
||
return any(fnmatch.fnmatch(name, pat) for pat in patterns)
|
||
|
||
def _rag_index_repo_sync(
    *,
    repo_url: str,
    branch: str = "main",
    profile: str = "auto",
    include: str = "",
    exclude_dirs: str = "",
    chunk_chars: int = 3000,
    overlap: int = 400,
    collection_name: str = "code_docs",
    force: bool = False
) -> dict:
    """
    Index a git repo into Chroma, plus optional BM25 / Meilisearch / symbol
    side-indexes. Skips all work when the repo HEAD is unchanged since the
    previous run for the same collection+embedder (unless force=True).

    Returns a status dict with counters ("status", "files_indexed",
    "chunks_added", ...).

    Bugfix: the embedder-aware cache key used for the HEAD skip-check was
    shadowed by the (shorter) BM25 registry key before being written to
    disk, so the "head_unchanged" fast path never matched on the next run
    whenever rank_bm25 was installed. The two keys are now distinct
    variables. Also hoisted per-chunk loop invariants (owner/repo lookup,
    RAG_SYMBOL_INDEX env read, nested helper definition).
    """
    repo_path = get_git_repo(repo_url, branch)
    root = Path(repo_path)
    repo = git.Repo(repo_path)

    # HEAD sha; empty string when unavailable (e.g. fresh/empty repo).
    try:
        head_sha = repo.head.commit.hexsha
    except Exception:
        head_sha = ""

    # Cache key is embedder-aware: switching embedding model forces re-index.
    cache = _load_repo_index_cache()
    cache_key = f"{os.path.basename(repo_url)}|{branch}|{_collection_versioned(collection_name)}|{_EMBEDDER.slug}"
    cached = cache.get(cache_key)

    if not force and cached and cached.get("head_sha") == head_sha:
        return {
            "status": "skipped",
            "reason": "head_unchanged",
            "collection_effective": _collection_versioned(collection_name),
            "detected_profile": profile if profile != "auto" else _detect_repo_profile(root),
            "files_indexed": cached.get("files_indexed", 0),
            "chunks_added": 0,
            "note": f"Skip: HEAD {head_sha} al geïndexeerd (embedder={_EMBEDDER.slug})"
        }

    # Profile -> include patterns (an explicit include list overrides).
    if profile == "auto":
        profile = _detect_repo_profile(root)
    include_patterns = PROFILE_INCLUDES.get(profile, PROFILE_INCLUDES["generic"])
    if include.strip():
        include_patterns = [p.strip() for p in include.split(",") if p.strip()]

    exclude_set = set(PROFILE_EXCLUDE_DIRS)
    if exclude_dirs.strip():
        exclude_set |= {d.strip() for d in exclude_dirs.split(",") if d.strip()}

    collection = _get_collection(collection_name)

    # Smart (structure-aware) chunking toggle: env var or profile='smart'.
    use_smart_chunk = (
        (os.getenv("CHROMA_SMART_CHUNK", "1").lower() not in ("0", "false", "no"))
        or (str(profile).lower() == "smart")
    )
    CH_TGT = int(os.getenv("CHUNK_TARGET_CHARS", "1800"))
    CH_MAX = int(os.getenv("CHUNK_HARD_MAX", "2600"))
    CH_MIN = int(os.getenv("CHUNK_MIN_CHARS", "800"))

    files_indexed = 0
    chunks_added = 0
    batch_documents: list[str] = []
    batch_metadatas: list[dict] = []
    batch_ids: list[str] = []
    BATCH_SIZE = 64

    def flush():
        # Push the current batch into Chroma and reset the buffers.
        nonlocal chunks_added
        if not batch_documents:
            return
        _collection_add(collection, batch_documents, batch_metadatas, batch_ids)
        chunks_added += len(batch_documents)
        batch_documents.clear(); batch_metadatas.clear(); batch_ids.clear()

    docs_for_bm25: list[dict] = []
    docs_for_meili: list[dict] = []
    # Per-run symbol map: symbol_lower -> [{"path": str, "chunk_index": int}]
    run_symbol_map: dict[str, list[dict]] = {}

    def _extract_symbol_hints(txt: str) -> list[str]:
        # Cheap regex scan for class/function-like symbol names (unique, max 16).
        hints = []
        for pat in [
            r"\bclass\s+([A-Za-z_][A-Za-z0-9_]*)\b",
            r"\binterface\s+([A-Za-z_][A-Za-z0-9_]*)\b",
            r"\btrait\s+([A-Za-z_][A-Za-z0-9_]*)\b",
            r"\bdef\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(",
            r"\bfunction\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(",
            r"\bpublic\s+function\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(",
            r"\bprotected\s+function\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(",
            r"\bprivate\s+function\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(",
        ]:
            try:
                hints += re.findall(pat, txt)
            except Exception:
                pass
        out = []
        for h in hints:
            if h not in out:
                out.append(h)
            if len(out) >= 16:
                break
        return out

    # Loop invariants hoisted out of the file/chunk loops.
    repo_full_pre = _repo_owner_repo_from_url(repo_url)
    repo_short = os.path.basename(repo_url).removesuffix(".git")
    symbol_index_enabled = os.getenv("RAG_SYMBOL_INDEX", "1").lower() not in ("0", "false", "no")
    deleted_paths: set[str] = set()

    for p in root.rglob("*"):
        if p.is_dir():
            continue
        if set(p.parts) & exclude_set:
            continue
        rel = str(p.relative_to(root))
        if not _match_any(rel, include_patterns):
            continue

        text = _read_text_file(p)
        if not text:
            continue

        files_indexed += 1
        # Remove existing chunks for this file first (prevents bloat/duplicates).
        if rel not in deleted_paths:
            try:
                collection.delete(where={
                    "repo_full": repo_full_pre, "branch": branch, "path": rel
                })
            except Exception:
                pass
            deleted_paths.add(rel)

        # Structure-aware chunking, or fixed-size chunks with overlap.
        if use_smart_chunk:
            chunks = smart_chunk_text(
                text, rel,
                target_chars=CH_TGT, hard_max=CH_MAX, min_chunk=CH_MIN
            )
        else:
            chunks = _chunk_text(text, chunk_chars=int(chunk_chars), overlap=int(overlap))

        for idx, ch in enumerate(chunks):
            chunk_hash = _sha1_text(ch)
            doc_id = f"{os.path.basename(repo_url)}:{branch}:{rel}:{idx}:{chunk_hash}"
            # Store both the short repo id and the full 'owner/repo' id.
            meta = {
                "source": "git-repo",
                "repo": repo_short,
                "repo_full": repo_full_pre,
                "branch": branch,
                "path": rel,
                "chunk_index": idx,
                "profile": profile,
            }
            batch_documents.append(ch)
            docs_for_bm25.append({"text": ch, "path": rel})
            batch_metadatas.append(meta)
            batch_ids.append(doc_id)
            if MEILI_ENABLED:
                docs_for_meili.append({
                    "id": hashlib.sha1(f"{doc_id}".encode("utf-8")).hexdigest(),
                    "repo": meta["repo"],
                    "repo_full": meta["repo_full"],
                    "branch": branch,
                    "path": rel,
                    "chunk_index": idx,
                    "text": ch[:4000],  # keep the payload compact
                    "head_sha": head_sha,
                    "ts": int(time.time())
                })
            # Collect symbols per chunk for the symbol index.
            if symbol_index_enabled:
                for sym in _extract_symbol_hints(ch):
                    run_symbol_map.setdefault(sym.lower(), []).append({"path": rel, "chunk_index": idx})
            if len(batch_documents) >= BATCH_SIZE:
                flush()
    flush()

    # In-memory BM25 registry. NOTE: this key deliberately lacks the embedder
    # slug — the query-side fallback matches on this shorter format.
    if BM25Okapi is not None and docs_for_bm25:
        bm = BM25Okapi([_bm25_tok(d["text"]) for d in docs_for_bm25])
        bm25_key = f"{os.path.basename(repo_url)}|{branch}|{_collection_versioned(collection_name)}"
        _BM25_BY_REPO[bm25_key] = (bm, docs_for_bm25)

    # Persist under the embedder-aware key so the skip-check above works.
    cache[cache_key] = {"head_sha": head_sha, "files_indexed": files_indexed, "embedder": _EMBEDDER.slug, "ts": int(time.time())}
    _save_repo_index_cache(cache)

    # Merge & persist the symbol index (deduped, capped per symbol to avoid bloat).
    if symbol_index_enabled:
        col_eff = _collection_versioned(collection_name)
        repo_name = os.path.basename(repo_url)
        key = _symkey(col_eff, repo_name)
        base = _SYMBOL_INDEX.get(key, {})
        CAP = int(os.getenv("RAG_SYMBOL_CAP_PER_SYM", "200"))
        for sym, entries in run_symbol_map.items():
            dst = base.get(sym, [])
            seen = {(d["path"], d["chunk_index"]) for d in dst}
            for e in entries:
                tup = (e["path"], e["chunk_index"])
                if tup not in seen:
                    dst.append(e); seen.add(tup)
                if len(dst) >= CAP:
                    break
            base[sym] = dst
        _SYMBOL_INDEX[key] = base
        _save_symbol_index(_SYMBOL_INDEX)

    # Optional Meilisearch bulk upsert (chunked uploads keep payloads small).
    if MEILI_ENABLED and docs_for_meili:
        try:
            headers = {"Content-Type":"application/json"}
            if MEILI_API_KEY:
                headers["Authorization"] = f"Bearer {MEILI_API_KEY}"
                headers["X-Meili-API-Key"] = MEILI_API_KEY
            url = f"{MEILI_URL}/indexes/{MEILI_INDEX}/documents?primaryKey=id"
            CH = 500
            for i in range(0, len(docs_for_meili), CH):
                chunk = docs_for_meili[i:i+CH]
                requests.post(url, headers=headers, data=json.dumps(chunk), timeout=30)
        except Exception as e:
            logger.warning("Meili bulk upsert failed: %s", e)

    return {
        "status": "ok",
        "collection_effective": _collection_versioned(collection_name),
        "detected_profile": profile,
        "files_indexed": files_indexed,
        "chunks_added": chunks_added,
        "note": "Geïndexeerd met embedder=%s; smart_chunk=%s" % (_EMBEDDER.slug, bool(use_smart_chunk))
    }
|
||
|
||
|
||
|
||
|
||
|
||
# Celery task
|
||
# Celery task: registered only when a Celery app is configured.
if celery_app:
    @celery_app.task(name="task_rag_index_repo", bind=True, autoretry_for=(Exception,), retry_backoff=True, max_retries=5)
    def task_rag_index_repo(self, args: dict):
        # Thin wrapper: run the synchronous indexer in a worker, with
        # exponential-backoff auto-retry (up to 5 attempts) on any exception.
        return _rag_index_repo_sync(**args)
|
||
|
||
# API endpoint: sync of async enqueue
|
||
@app.post("/rag/index-repo")
async def rag_index_repo(
    repo_url: str = Form(...),
    branch: str = Form("main"),
    profile: str = Form("auto"),
    include: str = Form(""),
    exclude_dirs: str = Form(""),
    chunk_chars: int = Form(3000),
    overlap: int = Form(400),
    collection_name: str = Form("code_docs"),
    force: bool = Form(False),
    async_enqueue: bool = Form(False)
):
    """
    Index a git repo into Chroma. With async_enqueue=true (and Celery
    configured) the job is queued and a task id is returned; otherwise the
    indexer runs synchronously in a worker thread.
    """
    job = {
        "repo_url": repo_url,
        "branch": branch,
        "profile": profile,
        "include": include,
        "exclude_dirs": exclude_dirs,
        "chunk_chars": int(chunk_chars),
        "overlap": int(overlap),
        "collection_name": collection_name,
        "force": bool(force),
    }
    if async_enqueue and celery_app:
        task = task_rag_index_repo.delay(job)  # type: ignore[name-defined]
        return {"status": "enqueued", "task_id": task.id}
    try:
        return await run_in_threadpool(_rag_index_repo_sync, **job)
    except Exception as e:
        logger.exception("RAG index repo failed")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
@app.get("/tasks/{task_id}")
def task_status(task_id: str):
    """Report a Celery task's state, plus its result or error when finished."""
    if not celery_app:
        raise HTTPException(status_code=400, detail="Celery niet geactiveerd")
    try:
        from celery.result import AsyncResult
        res = AsyncResult(task_id, app=celery_app)
        payload = {"state": res.state}
        if res.successful():
            payload["result"] = res.result
        elif res.failed():
            payload["error"] = str(res.result)
        return payload
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
# -----------------------------------------------------------------------------
|
||
# RAG Query (Chroma)
|
||
# -----------------------------------------------------------------------------
|
||
_LARAVEL_ROUTE_FILES = {"routes/web.php", "routes/api.php"}
|
||
|
||
def _laravel_guess_view_paths_from_text(txt: str) -> list[str]:
|
||
out = []
|
||
for m in re.finditer(r"return\s+view\(\s*['\"]([^'\"]+)['\"]", txt):
|
||
v = m.group(1).replace(".", "/")
|
||
out.append(f"resources/views/{v}.blade.php")
|
||
return out
|
||
|
||
def _laravel_pairs_from_route_text(txt: str) -> list[tuple[str, str|None]]:
|
||
"""Return [(controller_path, method_or_None), ...]"""
|
||
pairs = []
|
||
# Route::get('x', [XController::class, 'show'])
|
||
for m in re.finditer(r"Route::(?:get|post|put|patch|delete)\([^;]+?\[(.*?)\]\s*\)", txt, re.S):
|
||
arr = m.group(1)
|
||
mm = re.search(r"([A-Za-z0-9_\\]+)::class\s*,\s*'([A-Za-z0-9_]+)'", arr)
|
||
if mm:
|
||
ctrl = mm.group(1).split("\\")[-1]
|
||
meth = mm.group(2)
|
||
pairs.append((f"app/Http/Controllers/{ctrl}.php", meth))
|
||
# Route::resource('users', 'UserController')
|
||
for m in re.finditer(r"Route::resource\(\s*'([^']+)'\s*,\s*'([^']+)'\s*\)", txt):
|
||
res, ctrl = m.group(1), m.group(2).split("\\")[-1]
|
||
pairs.append((f"app/Http/Controllers/{ctrl}.php", None))
|
||
return pairs
|
||
|
||
async def rag_query_api(
    *,
    query: str,
    n_results: int = 8,
    collection_name: str = "code_docs",
    repo: Optional[str] = None,
    path_contains: Optional[str] = None,
    profile: Optional[str] = None
) -> dict:
    """
    Hybrid RAG query against Chroma with layered score boosting.

    Pipeline: embed the query -> Chroma vector search (over-fetched to 30+)
    -> per-candidate boosts (symbol-index hits, explicit path hints in the
    query, exact phrase match, Laravel route/controller/view cross-links,
    symbol/filename term overlap) -> BM25 (or Jaccard fallback) blended with
    the embedding similarity via RAG_EMB_WEIGHT -> optional LLM rerank.

    Returns {"count": int, "results": [{"document", "metadata", "file",
    "score", "distance"}, ...]}.
    """
    col = _get_collection(collection_name)
    q_emb = _EMBEDDER.embed_query(query)
    where = {}
    if repo:
        # Accept both 'repo' (basename) and 'repo_full' (owner/repo).
        base = repo.rsplit("/", 1)[-1]
        where = {"$or": [
            {"repo": {"$eq": base}},
            {"repo_full": {"$eq": repo}}
        ]}
    # NOTE(review): mixing "$or" with a sibling top-level key may need an
    # explicit "$and" wrapper on newer Chroma versions — confirm against the
    # deployed chromadb release.
    if profile: where["profile"] = {"$eq": profile}

    # ---- symbol hit set (repo-scoped) ----
    sym_hit_keys: set[str] = set()
    sym_hit_weight = float(os.getenv("RAG_SYMBOL_HIT_W", "0.12"))
    if sym_hit_weight > 0.0 and repo:
        try:
            # The index was stored under basename(repo_url), so normalize here.
            repo_base = repo.rsplit("/", 1)[-1]
            idx_key = _symkey(_collection_versioned(collection_name), repo_base)
            symidx = _SYMBOL_INDEX.get(idx_key, {})
            # Query terms...
            terms = set(re.findall(r"[A-Za-z_][A-Za-z0-9_]*", query))
            # ...plus split forms like 'Foo\\BarController@show'.
            for m in re.finditer(r"([A-Za-z_][A-Za-z0-9_]*)(?:@|::|->)([A-Za-z_][A-Za-z0-9_]*)", query):
                terms.add(m.group(1)); terms.add(m.group(2))
            for t in {x.lower() for x in terms}:
                for e in symidx.get(t, []):
                    sym_hit_keys.add(f"{repo}::{e['path']}::{e['chunk_index']}")
        except Exception:
            pass

    # --- extract path hints from the query (explicitly named files/dirs) ---
    # Also recognizes quoted variants and smart quotes.
    path_hints: set[str] = set()
    PH_PATTERNS = [
        r"[\"“”'](resources\/[A-Za-z0-9_\/\.-]+\.blade\.php)[\"”']",
        r"(resources\/[A-Za-z0-9_\/\.-]+\.blade\.php)",
        r"[\"“”'](app\/[A-Za-z0-9_\/\.-]+\.php)[\"”']",
        r"(app\/[A-Za-z0-9_\/\.-]+\.php)",
        r"\b([A-Za-z0-9_\/-]+\.blade\.php)\b",
        r"\b([A-Za-z0-9_\/-]+\.php)\b",
    ]
    for pat in PH_PATTERNS:
        for m in re.finditer(pat, query):
            path_hints.add(m.group(1).strip())

    res = col.query(
        query_embeddings=[q_emb],
        n_results=max(n_results, 30),  # over-fetch to give the rerank room
        where=where or None,
        include=["metadatas","documents","distances"]
    )
    logger.info("RAG raw hits: %d (repo=%s)", len((res.get("documents") or [[]])[0]), repo or "-")
    docs = (res.get("documents") or [[]])[0]
    metas = (res.get("metadatas") or [[]])[0]
    dists = (res.get("distances") or [[]])[0]

    # Apply the path_contains filter and build candidates.
    cands = []
    # For the phrase boost: derive a simple "candidate phrase" (>= 2 words).
    def _candidate_phrase(q: str) -> str | None:
        q = q.strip().strip('"').strip("'")
        words = re.findall(r"[A-Za-zÀ-ÿ0-9_+-]{2,}", q)
        if len(words) >= 2:
            return " ".join(words[:6]).lower()
        return None
    phrase = _candidate_phrase(query)

    for doc, meta, dist in zip(docs, metas, dists):
        if path_contains and meta and path_contains.lower() not in (meta.get("path","").lower()):
            continue
        key = f"{(meta or {}).get('repo','')}::{(meta or {}).get('path','')}::{(meta or {}).get('chunk_index','')}"
        # Map distance to a ~similarity in (0, 1].
        emb_sim = 1.0 / (1.0 + float(dist or 0.0))
        # Symbol-hit boost before further combinations.
        if key in sym_hit_keys:
            emb_sim = min(1.0, emb_sim + sym_hit_weight)
        c = {
            "document": doc or "",
            "metadata": meta or {},
            "emb_sim": emb_sim,  # distance -> ~similarity (+ symbol boost)
            "distance": float(dist or 0.0)  # raw distance kept for hybrid retrieval
        }
        # --- extra boost: paths explicitly named in the query ---
        p = (meta or {}).get("path","")
        if path_hints:
            for hint in path_hints:
                if hint and (hint == p or hint.split("/")[-1] == p.split("/")[-1] or hint in p):
                    c["emb_sim"] = min(1.0, c["emb_sim"] + 0.12)
                    break
        # --- extra boost: exact phrase present in the document (case-insensitive) ---
        if phrase and phrase in (doc or "").lower():
            c["emb_sim"] = min(1.0, c["emb_sim"] + 0.10)
        cands.append(c)

    # --- When explicit path hints exist, try to "hard"-inject them on top ---
    # Helps when e.g. 'resources/views/log/edit.blade.php' is named verbatim.
    if path_hints and repo:
        try:
            # No network/clone here: use the collection metadata (metas) to
            # decide whether the hinted path is already among the hits.
            repo_base = repo.rsplit("/", 1)[-1]
            known_paths = { (m or {}).get("path","") for m in metas }
            injects = []
            for hint in list(path_hints):
                # Also match on bare filename when the full path is absent.
                matches = [p for p in known_paths if p == hint or p.endswith("/"+hint)]
                for pth in matches:
                    # Take the first document belonging to this path.
                    for doc, meta in zip(docs, metas):
                        if (meta or {}).get("path","") == pth:
                            injects.append({"document": doc, "metadata": meta, "emb_sim": 1.0, "score": 1.0})
                            break
            if injects:
                # Put injects first, without duplicates.
                seen_keys = set()
                new_cands = []
                for j in injects + cands:
                    key = f"{(j.get('metadata') or {}).get('path','')}::{(j.get('metadata') or {}).get('chunk_index','')}"
                    if key in seen_keys:
                        continue
                    seen_keys.add(key)
                    new_cands.append(j)
                cands = new_cands
        except Exception:
            pass

    # No vector candidates at all: try the in-memory BM25 registry, else empty.
    if not cands:
        fallback = []
        if BM25Okapi is not None and repo:
            colname = _collection_versioned(collection_name)
            repo_base = repo.rsplit("/", 1)[-1]
            for k,(bm,docs) in _BM25_BY_REPO.items():
                # Cache-key format: "<basename>|<branch>|<collection_effective>"
                if (k.startswith(f"{repo_base}|") or k.startswith(f"{repo}|")) and k.endswith(colname):
                    scores = bm.get_scores(_bm25_tok(query))
                    ranked = sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)[:n_results]
                    for d,s in ranked:
                        fallback.append({
                            "document": d["text"],
                            "metadata": {"repo": repo_base, "path": d["path"]},
                            "score": float(s)
                        })
                    break
        if fallback:
            return {"count": len(fallback), "results": fallback}
        return {"count": 0, "results": []}

    # BM25 (or Jaccard when rank_bm25 is unavailable) over the candidates.
    def _tok(s: str) -> list[str]:
        return re.findall(r"[A-Za-z0-9_]+", s.lower())
    bm25_scores = []
    if BM25Okapi:
        bm = BM25Okapi([_tok(c["document"]) for c in cands])
        bm25_scores = bm.get_scores(_tok(query))
    else:
        qset = set(_tok(query))
        for c in cands:
            tset = set(_tok(c["document"]))
            inter = len(qset & tset); uni = len(qset | tset) or 1
            bm25_scores.append(inter / uni)

    # --- Laravel cross-file boosts ---
    # Each candidate has {"document": ..., "metadata": {"repo": ..., "path": ...}}.
    # Build a path -> candidate-indices mapping.
    idx_by_path = {}
    for i,c in enumerate(cands):
        p = (c.get("metadata") or {}).get("path","")
        idx_by_path.setdefault(p, []).append(i)

    # 1) Route files -> boost the controllers they reference.
    for c in cands:
        meta = c.get("metadata") or {}
        path = meta.get("path","")
        if path not in _LARAVEL_ROUTE_FILES:
            continue
        txt = c.get("document","")
        # controllers
        for ctrl_path, _meth in _laravel_pairs_from_route_text(txt):
            if ctrl_path in idx_by_path:
                for j in idx_by_path[ctrl_path]:
                    cands[j]["emb_sim"] = min(1.0, cands[j]["emb_sim"] + 0.05)  # small push
        # views: rarely referenced directly from route text, skipped here

    # 2) Controllers -> boost the views they render.
    for c in cands:
        meta = c.get("metadata") or {}
        path = meta.get("path","")
        if not path.startswith("app/Http/Controllers/") or not path.endswith(".php"):
            continue
        txt = c.get("document","")
        for vpath in _laravel_guess_view_paths_from_text(txt):
            if vpath in idx_by_path:
                for j in idx_by_path[vpath]:
                    cands[j]["emb_sim"] = min(1.0, cands[j]["emb_sim"] + 0.05)

    # Symbols boost: small bonus when query terms match symbols or the filename.
    q_terms = set(re.findall(r"[A-Za-z0-9_]+", query.lower()))
    for c in cands:
        meta = (c.get("metadata") or {})
        # --- symbols (bugfix: syms_raw was undefined) ---
        syms_raw = meta.get("symbols")
        if isinstance(syms_raw, str):
            syms = [s.strip() for s in syms_raw.split(",") if s.strip()]
        elif isinstance(syms_raw, list):
            syms = syms_raw
        else:
            syms = []
        if syms and (q_terms & {s.lower() for s in syms}):
            c["emb_sim"] = min(1.0, c["emb_sim"] + 0.04)

        # --- filename exact-ish match ---
        fname_terms = set(re.findall(r"[A-Za-z0-9_]+", meta.get("path","").split("/")[-1].lower()))
        if fname_terms and (q_terms & fname_terms):
            c["emb_sim"] = min(1.0, c["emb_sim"] + 0.02)
        # Light bonus when the path is exactly one of the hints.
        if path_hints and meta.get("path") in path_hints:
            c["emb_sim"] = min(1.0, c["emb_sim"] + 0.06)

    # Normalize the lexical scores to [0, 1].
    if len(bm25_scores) > 0:
        mn, mx = min(bm25_scores), max(bm25_scores)
        bm25_norm = [(s - mn) / (mx - mn) if mx > mn else 0.0 for s in bm25_scores]
    else:
        bm25_norm = [0.0] * len(cands)

    # Blend embedding similarity with the lexical score.
    alpha = float(os.getenv("RAG_EMB_WEIGHT", "0.6"))
    for c, b in zip(cands, bm25_norm):
        c["score"] = alpha * c["emb_sim"] + (1.0 - alpha) * b

    # NOTE(review): this first `ranked` is immediately overwritten in every
    # branch below; kept as-is to preserve behavior.
    ranked = sorted(cands, key=lambda x: x["score"], reverse=True)[:n_results]
    ranked_full = sorted(cands, key=lambda x: x["score"], reverse=True)
    if RAG_LLM_RERANK:
        topK = ranked_full[:max(10, n_results)]
        # Build the rerank prompt.
        prompt = "Rerank the following code passages for the query. Return ONLY a JSON array of indices (0-based) in best-to-worst order.\n"
        prompt += f"Query: {query}\n"
        for i, r in enumerate(topK):
            path = (r.get("metadata") or {}).get("path","")
            snippet = (r.get("document") or "")[:600]
            prompt += f"\n# {i} — {path}\n{snippet}\n"
        resp = await llm_call_openai_compat(
            [{"role":"system","content":"You are precise and return only valid JSON."},
             {"role":"user","content": prompt+"\n\nOnly JSON array."}],
            stream=False, temperature=0.0, top_p=1.0, max_tokens=256
        )
        try:
            order = json.loads((resp.get("choices",[{}])[0].get("message",{}) or {}).get("content","[]"))
            reranked = [topK[i] for i in order if isinstance(i,int) and 0 <= i < len(topK)]
            ranked = reranked[:n_results] if reranked else ranked_full[:n_results]
        except Exception:
            ranked = ranked_full[:n_results]
    else:
        ranked = ranked_full[:n_results]

    return {
        "count": len(ranked),
        "results": [{
            "document": r["document"],
            "metadata": r["metadata"],
            "file": (r["metadata"] or {}).get("path", ""),  # explicit filename
            "score": round(float(r["score"]), 4),
            # Keep the raw distance available for hybrid_retrieve (embed component).
            "distance": float(r.get("distance", 0.0))
        } for r in ranked]
    }
|
||
|
||
|
||
@app.post("/rag/query")
async def rag_query_endpoint(
    query: str = Form(...),
    n_results: int = Form(8),
    collection_name: str = Form("code_docs"),
    repo: str = Form(""),
    path_contains: str = Form(""),
    profile: str = Form("")
):
    """HTTP wrapper around rag_query_api; empty form fields become None."""
    result = await rag_query_api(
        query=query,
        n_results=n_results,
        collection_name=collection_name,
        repo=repo or None,
        path_contains=path_contains or None,
        profile=profile or None,
    )
    return JSONResponse(result)
|
||
|
||
# -----------------------------------------------------------------------------
|
||
# Repo-agent endpoints
|
||
# -----------------------------------------------------------------------------
|
||
@app.post("/agent/repo")
async def agent_repo(messages: ChatRequest, request: Request):
    """
    Conversational interface to the repo agent (stateful per session).
    """
    # The agent handler expects plain list[dict] messages, not pydantic models.
    history = [{"role": m.role, "content": m.content} for m in messages.messages]
    reply = await handle_repo_agent(history, request)
    return PlainTextResponse(reply)
|
||
|
||
@app.post("/repo/qa")
async def repo_qa(req: RepoQARequest):
    """One-shot question/answer over a repository."""
    answer = await repo_qa_answer(
        req.repo_hint,
        req.question,
        branch=req.branch,
        n_ctx=req.n_ctx,
    )
    return PlainTextResponse(answer)
|
||
|
||
# -----------------------------------------------------------------------------
# Injected helpers for agent_repo
# -----------------------------------------------------------------------------
|
||
async def _rag_index_repo_internal(*, repo_url: str, branch: str, profile: str,
                                   include: str, exclude_dirs: str,
                                   chunk_chars: int, overlap: int,
                                   collection_name: str):
    """Run the synchronous repo indexer without blocking the event loop.

    All heavy git/IO/CPU work lives in _rag_index_repo_sync; delegating it to
    the threadpool keeps the asyncio loop responsive.
    """
    sync_kwargs = dict(
        repo_url=repo_url,
        branch=branch,
        profile=profile,
        include=include,
        exclude_dirs=exclude_dirs,
        chunk_chars=chunk_chars,
        overlap=overlap,
        collection_name=collection_name,
        force=False,
    )
    return await run_in_threadpool(_rag_index_repo_sync, **sync_kwargs)
|
||
|
||
async def _rag_query_internal(*, query: str, n_results: int,
                              collection_name: str, repo=None, path_contains=None, profile=None):
    """Thin async passthrough to rag_query_api, used as an agent injection point."""
    result = await rag_query_api(
        query=query,
        n_results=n_results,
        collection_name=collection_name,
        repo=repo,
        path_contains=path_contains,
        profile=profile,
    )
    return result
|
||
|
||
def _read_text_file_wrapper(p: Path | str) -> str:
    """Read a text file via _read_text_file, accepting either str or Path.

    Path() is idempotent (Path(some_path) yields an equal Path), so the
    explicit isinstance branch in the original was redundant.
    """
    return _read_text_file(Path(p))
|
||
|
||
async def _get_git_repo_async(repo_url: str, branch: str = "main") -> str:
    """Clone/update a repo off the event loop.

    gitpython does subprocess and filesystem IO, so the call always goes
    through the threadpool.
    """
    local_path = await run_in_threadpool(get_git_repo, repo_url, branch)
    return local_path
|
||
|
||
# Register the injected helpers with the repo agent.
# agent_repo.py stays decoupled from this module: every capability it needs
# (git access, RAG index/query, LLM calls, file IO, search backends) is
# passed in here as a callable.
initialize_agent(
    app=app,
    get_git_repo_fn=_get_git_repo_async,
    rag_index_repo_internal_fn=_rag_index_repo_internal,
    rag_query_internal_fn=_rag_query_internal,
    llm_call_fn=llm_call_openai_compat,
    extract_code_block_fn=extract_code_block,
    read_text_file_fn=_read_text_file_wrapper,
    client_ip_fn=_client_ip,
    profile_exclude_dirs=PROFILE_EXCLUDE_DIRS,
    chroma_get_collection_fn=_get_collection,
    embed_query_fn=_EMBEDDER.embed_query,
    embed_documents_fn=_EMBEDDER.embed_documents,  # new: for catalog embeddings
    # new: search-first + summaries + meili
    search_candidates_fn=_search_first_candidates,
    repo_summary_get_fn=_repo_summary_get_internal,
    meili_search_fn=_meili_search_internal,
)
|
||
|
||
# -----------------------------------------------------------------------------
# Health
# -----------------------------------------------------------------------------
|
||
@app.get("/healthz")
def health():
    """Health probe: embedder slug, Chroma mode, LLM queue depth and upstream state."""
    upstream_states = [
        {"url": u.url, "active": u.active, "ok": u.ok}
        for u in _UPS
    ]
    return {
        "ok": True,
        "embedder": _EMBEDDER.slug,
        "chroma_mode": CHROMA_MODE,
        "queue_len": len(app.state.LLM_QUEUE),
        # semaphore internals: free permits (0 if the attribute is absent)
        "permits_free": getattr(app.state.LLM_SEM, "_value", 0),
        "upstreams": upstream_states,
    }
|
||
|
||
@app.get("/metrics")
def metrics():
    """Lightweight counters for monitoring: queue depth, semaphore value, upstreams."""
    return {
        "queue_len": len(app.state.LLM_QUEUE),
        "sem_value": getattr(app.state.LLM_SEM, "_value", None),
        "upstreams": [
            {"url": u.url, "active": u.active, "ok": u.ok}
            for u in _UPS
        ],
    }
|
||
|