# smart_rag.py
# Small utility layer for intent extraction, hybrid retrieval and context assembly.
from __future__ import annotations

import hashlib
import json
import math
import os
import re
from collections import defaultdict
from typing import DefaultDict, Dict, List, Optional, Tuple


# --- Coercion / environment helpers -------------------------------------------

_ENV_TRUE = frozenset({"1", "true", "yes", "on"})
_ENV_FALSE = frozenset({"0", "false", "no", "off"})


def _env_flag(name: str, default: bool) -> bool:
    """Read a boolean flag from the environment at call time.

    Recognised values (case-insensitive, whitespace-trimmed) are
    1/true/yes/on and 0/false/no/off; an unset or unrecognised value
    yields ``default``.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    value = raw.strip().lower()
    if value in _ENV_TRUE:
        return True
    if value in _ENV_FALSE:
        return False
    return default


def _env_int(name: str, default: int) -> int:
    """Read an int from the environment; unset or malformed values yield ``default``."""
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        return int(raw.strip())
    except ValueError:
        return default


def _env_float(name: str, default: float) -> float:
    """Read a finite float from the environment; unset, malformed or non-finite values yield ``default``."""
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        value = float(raw.strip())
    except ValueError:
        return default
    return value if math.isfinite(value) else default


def _as_float(value, default: float = 0.0) -> float:
    """Coerce ``value`` to a finite float; ``None``/unparseable/non-finite input yields ``default``."""
    if value is None:
        return default
    try:
        result = float(value)
    except (TypeError, ValueError, OverflowError):
        return default
    return result if math.isfinite(result) else default


def _as_text(value) -> str:
    """Return ``value`` as text: ``None`` becomes ``""``, other non-strings go through ``str()``."""
    if value is None:
        return ""
    return value if isinstance(value, str) else str(value)


# --- Query variants -------------------------------------------------------------

def _decamel(s: str) -> str:
    """Insert spaces at lower->Upper camelCase boundaries, turn ``_`` into spaces, collapse whitespace."""
    s = re.sub(r"([a-z])([A-Z])", r"\1 \2", s)
    s = s.replace("_", " ")
    return re.sub(r"\s+", " ", s).strip()


def _symbol_guess(q: str) -> list[str]:
    """Return up to the two longest identifier-like tokens (3+ chars) of ``q`` as symbol candidates."""
    toks = re.findall(r"[A-Za-z_][A-Za-z0-9_]{2,}", q)
    toks.sort(key=len, reverse=True)
    return toks[:2]


def _simple_variants(q: str, max_k: int = 3) -> list[str]:
    """Build cheap, LLM-free query variants (original first), capped at ``max(1, max_k)`` entries.

    Candidates, in order: the query itself, its lowercase form, its
    de-camelCased form, and for each guessed symbol its space-separated and
    raw spelling. Exact duplicates are skipped.
    """
    base = [q]
    lo = q.lower().strip()
    if lo and lo not in base:
        base.append(lo)
    dec = _decamel(q)
    if dec and dec.lower() != lo and dec not in base:
        base.append(dec)
    for s in _symbol_guess(q):
        spaced = s.replace("_", " ")
        if spaced not in base:
            base.append(spaced)
        if s not in base:  # the raw symbol itself
            base.append(s)
    return base[: max(1, min(len(base), max_k))]


# --- Query routing -----------------------------------------------------------------

# Each rule: (keywords, ((path_contains, boost), ...)). Buckets are emitted in
# rule order; the generic (unfiltered) bucket is always appended last.
_ROUTE_RULES = (
    # Queue/Jobs/Event pipeline (Laravel)
    (("job", "queue", "listener", "event", "dispatch"),
     (("app/Jobs", 1.18), ("app/Listeners", 1.12), ("app/Events", 1.10))),
    # Models
    (("model", "eloquent", "scope", "attribute"),
     (("app/Models", 1.12),)),
    # Migrations
    (("migration", "schema", "table", "column"),
     (("database/migrations", 1.08),)),
    # Blade / UI
    (("blade", "view", "template", "button", "placeholder", "label"),
     (("resources/views", 1.2),)),
    # Routes / controllers
    (("route", "controller", "middleware", "api", "web.php", "controller@"),
     (("routes", 1.15), ("app/Http/Controllers", 1.2))),
    # Config / ENV
    (("env", "config", "database", "queue", "cache"),
     (("config", 1.15), (".env", 1.1))),
    # Docs / README
    (("readme", "install", "setup", "document", "usage"),
     (("README", 1.05), ("docs", 1.05))),
)

# A keyword must not be preceded by a letter, so "rapid" does not trigger
# "api", "review" does not trigger "view" and "prevent" does not trigger
# "event"; plurals and suffixes ("jobs", "views", "router") still match.
_ROUTE_MATCHERS = tuple(
    (re.compile(r"(?<![a-z])(?:" + "|".join(map(re.escape, keywords)) + ")"), targets)
    for keywords, targets in _ROUTE_RULES
)

_ACRONYM_RE = re.compile(r"([A-Z]+)([A-Z][a-z])")


def _route_haystack(q) -> str:
    """Lowercased query with camelCase/acronym/underscore boundaries turned into spaces."""
    text = _ACRONYM_RE.sub(r"\1 \2", q or "")  # "HTTPController" -> "HTTP Controller"
    return _decamel(text).lower()              # "sendJob" -> "send job"


def _route_query_buckets(q: str) -> list[dict]:
    """Lightweight keyword router for Laravel-style repositories.

    Returns a list of ``{"q", "path_contains", "boost"}`` dicts: one per path
    bucket whose keywords occur in the query, followed by a generic bucket
    (``path_contains=None``, ``boost=1.0``). Duplicates on
    ``(q, path_contains)`` are dropped, keeping the first.
    """
    haystack = _route_haystack(q)
    buckets = []
    for matcher, targets in _ROUTE_MATCHERS:
        if matcher.search(haystack):
            buckets.extend({"q": q, "path_contains": pc, "boost": boost} for pc, boost in targets)
    buckets.append({"q": q, "path_contains": None, "boost": 1.0})

    seen = set()
    out = []
    for bucket in buckets:
        key = (bucket["q"], bucket["path_contains"])
        if key not in seen:
            seen.add(key)
            out.append(bucket)
    return out


# --- Rank fusion -------------------------------------------------------------------

def _item_key(it: Dict) -> str:
    """Identity key of a retrieved chunk: ``repo::path::chunk_index`` from its metadata.

    When ``chunk_index`` is missing, a short SHA-1 of the document text is used
    instead. This keeps distinct chunks (or metadata-less results) from being
    merged into one.
    """
    meta = it.get("metadata")
    if not isinstance(meta, dict):
        meta = {}
    chunk = meta.get("chunk_index", "")
    if chunk is None or chunk == "":
        digest = hashlib.sha1(_as_text(it.get("document")).encode("utf-8", "replace")).hexdigest()[:16]
        chunk = f"#{digest}"
    return f"{meta.get('repo', '')}::{meta.get('path', '')}::{chunk}"


def rrf_fuse_ranked_lists(ranked_lists: list[list[dict]], k: int = 60) -> list[dict]:
    """Fuse ranked result lists with Reciprocal Rank Fusion.

    Args:
        ranked_lists: lists of result dicts (``document``/``metadata``/...),
            each sorted best-first. Items are matched across lists by
            ``_item_key``. Non-dict entries are ignored.
        k: RRF constant. Each list adds ``1 / (k + rank)`` (1-based rank) for
            every distinct item it contains. An item repeated within one list
            counts once, at its first (best) rank.

    Returns:
        One shallow copy per distinct item, with an added ``score_fused``
        field, sorted by it descending. Each copy is taken from the last list
        that contained the item.
    """
    fused: Dict[str, float] = {}
    ref_item: Dict[str, dict] = {}
    for rl in ranked_lists or []:
        seen_in_list = set()
        for rank, it in enumerate(rl or [], 1):
            if not isinstance(it, dict):
                continue
            key = _item_key(it)
            if key in seen_in_list:
                continue
            seen_in_list.add(key)
            fused[key] = fused.get(key, 0.0) + 1.0 / (k + rank)
            ref_item[key] = it
    out = [{**ref_item[key], "score_fused": score} for key, score in fused.items()]
    out.sort(key=lambda x: x["score_fused"], reverse=True)
    return out


def _rrf_from_ranklists(ranklists: List[List[str]], k: Optional[int] = None) -> Dict[str, float]:
    """Reciprocal Rank Fusion over ordered key lists (best first).

    Returns ``{key: sum(1 / (k + rank))}`` as a ``defaultdict(float)``. A key
    repeated within one list counts once, at its first rank. ``k`` defaults
    to the ``RRF_K`` environment variable (60 when unset or malformed), read
    at call time.
    """
    if k is None:
        k = _env_int("RRF_K", 60)
    acc: DefaultDict[str, float] = defaultdict(float)
    for lst in ranklists:
        seen = set()
        for rank, key in enumerate(lst, 1):
            if key in seen:
                continue
            seen.add(key)
            acc[key] += 1.0 / (k + rank)
    return acc


# --- Priors and lexical scoring ---------------------------------------------------

def _path_prior(path: str) -> float:
    """Lightweight per-path prior clamped to [0, 1].

    Laravel routes/controllers/views get the largest bonus, common source
    directories and code/markup extensions a small one. Test directories,
    lock files, source maps and minified assets are demoted, though the clamp
    stops at 0.0.
    """
    p = _as_text(path).replace("\\", "/").lower()
    bonus = 0.0
    # Laravel priors
    if p.startswith("routes/"):
        bonus += 0.35
    if p.startswith("app/http/controllers/"):
        bonus += 0.30
    if p.startswith("resources/views/"):
        bonus += 0.25
    if p.endswith(".blade.php"):
        bonus += 0.15
    # Generic priors
    if p.startswith(("src/", "app/", "lib/", "pages/", "components/")):
        bonus += 0.12
    if p.endswith((".php", ".ts", ".tsx", ".js", ".jsx", ".py", ".go", ".rb", ".java", ".cs", ".vue", ".html", ".md")):
        bonus += 0.05
    # Demote obvious low-signal files
    if "/tests/" in p or p.startswith(("tests/", "test/")):
        bonus -= 0.10
    if p.endswith((".lock", ".map", ".min.js", ".min.css")):
        bonus -= 0.10
    return max(0.0, min(1.0, bonus))


_FENCE_RE = re.compile(r"```(?:json)?\s*(.*?)```", re.DOTALL | re.IGNORECASE)


def _safe_json_loads(s):
    """Best-effort JSON parse of an LLM reply; returns ``None`` when nothing parses.

    Tries, in order:
    1. the whole stripped text;
    2. the contents of the first Markdown code fence anywhere in the text;
    3. the first JSON object/array embedded in surrounding prose, via
       ``json.JSONDecoder.raw_decode``.

    Non-string or blank input yields ``None``.
    """
    if not isinstance(s, str):
        return None
    text = s.strip()
    if not text:
        return None
    candidates = [text]
    fenced = _FENCE_RE.search(text)
    if fenced:
        candidates.append(fenced.group(1).strip())
    for candidate in candidates:
        try:
            return json.loads(candidate)
        except (ValueError, RecursionError):
            pass
    decoder = json.JSONDecoder()
    for start in sorted(i for i in (text.find("{"), text.find("[")) if i >= 0):
        try:
            return decoder.raw_decode(text, start)[0]
        except (ValueError, RecursionError):
            continue
    return None


_TOKEN_RE = re.compile(r"\w+")


def _tok(s: str) -> List[str]:
    """Lowercased word tokens of ``s``.

    Uses Unicode ``\\w``, so accented words such as "één" stay whole.
    ``None`` or empty input returns ``[]``.
    """
    if not s:
        return []
    return _TOKEN_RE.findall(_as_text(s).lower())


def _jaccard_sets(a: set, b: set) -> float:
    """Jaccard similarity of two token sets (0.0 when either is empty)."""
    if not a or not b:
        return 0.0
    return len(a & b) / len(a | b)


def _jaccard(a: str, b: str) -> float:
    """Token-set Jaccard similarity of two texts.

    No sampling is applied: the set operations are linear in the token count
    already paid for by ``_tok``, and sampling only one side would bias the
    score downward.
    """
    return _jaccard_sets(set(_tok(a)), set(_tok(b)))


def _normalize(xs: List[float]) -> List[float]:
    """Min-max scale to [0, 1]; empty input is returned as-is, constant input maps to all zeros."""
    if not xs:
        return xs
    lo, hi = min(xs), max(xs)
    if hi <= lo:
        return [0.0] * len(xs)
    return [(x - lo) / (hi - lo) for x in xs]


# --- LLM-assisted intent and query expansion -------------------------------------

def _message_text(content) -> str:
    """Flatten a chat message ``content`` value to plain text.

    Accepts a string, ``None`` (-> ``""``) or a list whose parts are strings or
    dicts with a string ``"text"`` entry; other parts are ignored. Anything
    else goes through ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for part in content:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                parts.append(part["text"])
        return "\n".join(parts)
    return str(content)


def _default_spec(user_text: str) -> Dict:
    """Fallback intent plan: the raw user text as task, everything else empty."""
    return {
        "task": user_text,
        "constraints": [],
        "file_hints": [],
        "keywords": [],
        "acceptance": [],
        "ask": None,
    }


def _str_list(value) -> List[str]:
    """Keep the non-empty string/number entries of a list, stripped; non-lists become ``[]``."""
    if not isinstance(value, list):
        return []
    out = []
    for x in value:
        if isinstance(x, bool) or not isinstance(x, (str, int, float)):
            continue
        s = str(x).strip()
        if s:
            out.append(s)
    return out


async def enrich_intent(llm_call_fn, messages: List[Dict]) -> Dict:
    """Restructure the latest user message into a compact plan via the LLM.

    Args:
        llm_call_fn: awaitable called as
            ``llm_call_fn(messages, stream=False, temperature=..., top_p=..., max_tokens=...)``.
            Its reply is read as ``{"choices": [{"message": {"content": ...}}]}``.
        messages: chat history; the last ``role == "user"`` entry is used.

    Returns:
        A dict with the following keys (extra keys from the LLM reply are kept):
        - ``task``: str. Falls back to the user text when missing or blank.
        - ``constraints``, ``file_hints``, ``keywords``, ``acceptance``:
          lists of non-empty strings.
        - ``ask``: str or ``None``.

        When the call fails or the reply is not a JSON object, the defaults
        are used.
    """
    user_text = ""
    for m in reversed(messages or []):
        if isinstance(m, dict) and m.get("role") == "user":
            user_text = _message_text(m.get("content")).strip()
            break

    sys_prompt = ("Je herstructureert een developer-vraag naar JSON. "
                  "Geef ALLEEN JSON, geen toelichting.")
    usr = (
        "Zet de essentie van de vraag om naar dit schema:\n"
        "{"
        "\"task\": str, "
        "\"constraints\": [str,...], "
        "\"file_hints\": [str,...], "
        "\"keywords\": [str,...], "
        "\"acceptance\": [str,...], "
        "\"ask\": str|null "
        "}\n\n"
        f"Vraag:\n{user_text}"
    )
    try:
        resp = await llm_call_fn(
            [{"role": "system", "content": sys_prompt}, {"role": "user", "content": usr}],
            stream=False, temperature=0.1, top_p=1.0, max_tokens=512,
        )
        raw = (resp.get("choices", [{}])[0].get("message", {}) or {}).get("content", "{}")
        parsed = _safe_json_loads(raw)
    except Exception:
        # Best effort: any transport or response-shape error falls back to defaults.
        parsed = None

    # A JSON list/string/number is not a plan; only a JSON object is usable.
    spec = parsed if isinstance(parsed, dict) else _default_spec(user_text)

    for key in ("constraints", "file_hints", "keywords", "acceptance"):
        spec[key] = _str_list(spec.get(key))
    task = spec.get("task")
    if not isinstance(task, str) or not task.strip():
        spec["task"] = user_text
    ask = spec.get("ask")
    spec["ask"] = ask if isinstance(ask, str) else None
    return spec


def _variant_candidates(parsed) -> List[str]:
    """Normalise a parsed expansion reply to a list of non-empty, stripped strings.

    Accepts a JSON array, a single JSON string, or an object whose first list
    value holds the variants. Anything else yields ``[]``.
    """
    if isinstance(parsed, dict):
        parsed = next((v for v in parsed.values() if isinstance(v, list)), [])
    elif isinstance(parsed, str):
        parsed = [parsed]
    elif not isinstance(parsed, list):
        return []
    return [x.strip() for x in parsed if isinstance(x, str) and x.strip()]


async def expand_queries(llm_call_fn, q: str, k: int = 3) -> List[str]:
    """Ask the LLM for short NL/EN search variants of ``q``.

    Returns ``q`` followed by up to ``k`` variants that are distinct
    (case-insensitively) from each other and from ``q``. Returns ``[q]`` when
    ``RAG_EXPAND_QUERIES`` is off or the LLM call fails.
    """
    if not _env_flag("RAG_EXPAND_QUERIES", True):
        return [q]
    sys_prompt = "Geef 3-4 korte NL/EN zoekvarianten als JSON array. Geen toelichting."
    usr = f"Bronvraag:\n{q}\n\nAlleen JSON array."
    try:
        resp = await llm_call_fn(
            [{"role": "system", "content": sys_prompt}, {"role": "user", "content": usr}],
            stream=False, temperature=0.2, top_p=0.9, max_tokens=240,
        )
        raw = (resp.get("choices", [{}])[0].get("message", {}) or {}).get("content", "[]")
        seen = {q.lower()}
        base = [q]
        for v in _variant_candidates(_safe_json_loads(raw)):
            lv = v.lower()
            if lv not in seen:
                base.append(v)
                seen.add(lv)
        return base[: max(1, k + 1)]
    except Exception:
        # Expansion is optional: any failure degrades to the bare query.
        return [q]


# --- Hybrid retrieval --------------------------------------------------------------

def _sim_from_chroma_distance(d: float | None) -> float:
    """Convert a (Chroma-style, lower = closer) distance to a similarity via ``1 / (1 + d)``.

    Returns a value in [0, 1]. ``None``, unparseable, negative or non-finite
    distances map to 0.0, never to a perfect match.
    """
    if d is None:
        return 0.0
    try:
        dv = float(d)
    except (TypeError, ValueError, OverflowError):
        return 0.0
    if not math.isfinite(dv) or dv < 0:
        return 0.0
    return 1.0 / (1.0 + dv)


_INCOMPATIBLE = object()


def _merge_path_filters(caller_pc, bucket_pc):
    """Combine the caller's ``path_contains`` with a routing bucket's filter.

    Both are substring filters, so when one contains the other, the longer
    one is stricter and satisfies both.

    Returns whichever filter applies. When neither filter contains the other,
    returns ``_INCOMPATIBLE``: a single substring cannot express both, and the
    bucket is skipped rather than silently dropping the caller's restriction.
    """
    if not bucket_pc:
        return caller_pc or None
    if not caller_pc:
        return bucket_pc
    if caller_pc in bucket_pc:
        return bucket_pc
    if bucket_pc in caller_pc:
        return caller_pc
    return _INCOMPATIBLE


def _plan_subqueries(qv: str, caller_pc, use_route: bool) -> List[Tuple[str, Optional[str], float]]:
    """``(query, path_contains, boost)`` triples to send to the backend for one query variant."""
    if not use_route:
        return [(qv, caller_pc, 1.0)]
    plans = []
    for bucket in _route_query_buckets(qv):
        pc = _merge_path_filters(caller_pc, bucket.get("path_contains"))
        if pc is _INCOMPATIBLE:
            continue
        plans.append((bucket["q"], pc, float(bucket.get("boost", 1.0))))
    return plans


def _rank_backend_results(res, boost: float) -> List[Dict]:
    """Copy backend results, add ``emb_sim_routed`` and sort best-first.

    ``emb_sim_routed`` is ``similarity(distance) * boost``.
    """
    results = res.get("results") if isinstance(res, dict) else None
    ranked = [
        {**item, "emb_sim_routed": _sim_from_chroma_distance(item.get("distance")) * boost}
        for item in (results or [])
        if isinstance(item, dict)
    ]
    ranked.sort(key=lambda x: x["emb_sim_routed"], reverse=True)
    return ranked


def _symbol_set(raw) -> set:
    """Normalise a ``symbols`` metadata value (comma-separated string or list) to a lowercase set."""
    if isinstance(raw, str):
        items = raw.split(",")
    elif isinstance(raw, list):
        items = [str(s) for s in raw]
    else:
        return set()
    return {s.strip().lower() for s in items if s.strip()}


async def hybrid_retrieve(
    rag_query_internal_fn,
    query: str,
    *,
    repo: str | None = None,
    profile: str | None = None,
    path_contains: str | None = None,
    per_query_k: int = 30,
    n_results: int = 8,
    alpha: float = 0.6,
    collection_name: str = "code_docs",
    llm_call_fn=None,
) -> List[Dict]:
    """Multi-variant hybrid retrieval: embedding similarity + lexical overlap + priors.

    Steps:
    1. Query variants: ``expand_queries`` when ``llm_call_fn`` is given,
       otherwise ``_simple_variants``; just ``[query]`` when
       ``RAG_MULTI_EXPAND`` is off. ``RAG_MULTI_K`` caps the variant count.
    2. Per variant, query the backend once (``RAG_ROUTE`` off) or once per
       routing bucket. Each bucket filter is merged with ``path_contains``.
       Identical (query, filter) pairs are sent only once.
    3. Merge the candidate lists: RRF (``rrf_fuse_ranked_lists`` with
       ``RRF_K``) when ``RAG_RRF`` is on, otherwise order-preserving
       de-duplication. Each chunk keeps its best boosted embedding similarity
       across all lists.
    4. Compute ``score = alpha * emb + (1 - alpha) * lexical``, where lexical
       is the best token Jaccard over the variants, min-max normalised. Then
       add the optional path prior (``RAG_PATH_PRIOR_W``) and symbol boost
       (``RAG_SYM_BOOST``).

    Args:
        rag_query_internal_fn: async callable invoked with the keyword
            arguments ``query, n_results, collection_name, repo,
            path_contains, profile``. Its result is read as a dict whose
            ``"results"`` list holds dicts with ``document``, ``metadata`` and
            optionally ``distance``.
        query: user query; blank input returns ``[]``.
        repo, profile, collection_name: passed through to the backend.
        path_contains: optional substring filter on chunk paths.
        per_query_k: candidates requested and kept per backend call.
        n_results: maximum number of results returned.
        alpha: weight of embedding similarity versus the lexical score.
        llm_call_fn: optional LLM callable used for query expansion.

    Returns:
        Up to ``n_results`` unique result dicts, sorted by ``score``
        descending. Each carries the backend fields plus ``emb_sim_routed``,
        ``score`` and, with RRF, ``score_fused``. ``score_fused`` only decides
        the order among equal scores.
    """
    use_route = _env_flag("RAG_ROUTE", True)
    use_rrf = _env_flag("RAG_RRF", True)
    use_expand = _env_flag("RAG_MULTI_EXPAND", True)
    k_variants = max(1, _env_int("RAG_MULTI_K", 3))
    rrf_k = max(0, _env_int("RRF_K", 60))
    path_prior_w = _env_float("RAG_PATH_PRIOR_W", 0.08)
    sym_w = _env_float("RAG_SYM_BOOST", 0.04)
    per_query_k = max(1, int(per_query_k))
    n_results = max(1, int(n_results))

    if not (query or "").strip():
        return []

    # 1) Query variants (the original query is always first).
    if not use_expand:
        variants = [query]
    elif llm_call_fn is not None:
        variants = await expand_queries(llm_call_fn, query, k=k_variants)
    else:
        variants = _simple_variants(query, max_k=k_variants)

    # 2) One ranked candidate list per distinct backend call.
    ranked_lists: List[List[Dict]] = []
    best_emb: Dict[str, float] = {}
    issued = set()
    for qv in variants:
        for sub_query, pc, boost in _plan_subqueries(qv, path_contains, use_route):
            if (sub_query, pc) in issued:
                continue
            issued.add((sub_query, pc))
            res = await rag_query_internal_fn(
                query=sub_query,
                n_results=per_query_k,
                collection_name=collection_name,
                repo=repo,
                path_contains=pc,
                profile=profile,
            )
            # Keep enough candidates for fusion (do not cut too early).
            ranked = _rank_backend_results(res, boost)[:per_query_k]
            for it in ranked:
                key = _item_key(it)
                best_emb[key] = max(best_emb.get(key, 0.0), it["emb_sim_routed"])
            ranked_lists.append(ranked)

    # 3) Merge into one list of unique chunks.
    if use_rrf:
        items = rrf_fuse_ranked_lists(ranked_lists, k=rrf_k)
    else:
        items = []
        seen_keys = set()
        for rl in ranked_lists:
            for it in rl:
                key = _item_key(it)
                if key not in seen_keys:
                    seen_keys.add(key)
                    items.append(it)
    if not items:
        return []

    # 4) Lexical score: best overlap over all variants (each document tokenised once).
    variant_tokens = [set(_tok(v)) for v in variants] or [set(_tok(query))]
    lexical_raw = []
    for it in items:
        doc_tokens = set(_tok(it.get("document")))
        lexical_raw.append(max(_jaccard_sets(vt, doc_tokens) for vt in variant_tokens))
    lexical_norm = _normalize(lexical_raw)

    query_terms = set(_tok(query))
    out = []
    for it, lexical in zip(items, lexical_norm):
        key = _item_key(it)
        best = best_emb.get(key, 0.0)
        # Fallback chain: best routed similarity -> backend emb_sim -> raw distance.
        emb = best or _as_float(it.get("emb_sim")) or _sim_from_chroma_distance(it.get("distance"))
        score = alpha * emb + (1.0 - alpha) * lexical

        meta = it.get("metadata")
        if not isinstance(meta, dict):
            meta = {}
        if path_prior_w > 0.0:
            score += path_prior_w * _path_prior(_as_text(meta.get("path")))
        if sym_w > 0.0 and query_terms & _symbol_set(meta.get("symbols")):
            score += sym_w
        out.append({**it, "emb_sim_routed": best, "score": float(score)})

    out.sort(key=lambda x: x["score"], reverse=True)
    return out[:n_results]


# --- Context assembly --------------------------------------------------------------

def _piece_order(piece: Dict) -> Tuple:
    """Sort key: pieces with a numeric chunk_index first (ascending), others by score descending."""
    ci = piece["chunk_index"]
    if isinstance(ci, int) or (isinstance(ci, str) and ci.isdecimal()):
        return (0, int(ci))
    return (1, -piece["score"])


def _laravel_stitch_snippets(path: str, first_body: Dict[str, str], remaining: int) -> List[str]:
    """Short "(stitch)" snippets of files referenced by a route file or controller.

    - ``routes/web.php`` / ``routes/api.php`` link to the controllers they
      mention.
    - ``app/Http/Controllers/...`` files link to the views they render.

    Only targets present in ``first_body`` are stitched. Each snippet is at
    most 400 chars and is only added while ``remaining`` exceeds 400.
    """
    if path in ("routes/web.php", "routes/api.php"):
        targets = [ctrl for ctrl, _method in _laravel_pairs_from_route_text(first_body.get(path, ""))]
    elif path.startswith("app/Http/Controllers/"):
        targets = _laravel_guess_view_paths_from_text(first_body.get(path, ""))
    else:
        targets = []
    stitched = []
    for target in targets:
        if target in first_body and remaining > 400:
            snippet = first_body[target][:min(400, remaining)]
            stitched.append(f"\n### {target} (stitch)\n{snippet}")
            remaining -= len(snippet)
    return stitched


def assemble_context(chunks: List[Dict], *, max_chars: int = 24000) -> Tuple[str, float]:
    """Stitch retrieved chunks into one character-budgeted context string.

    - Chunks are grouped by ``metadata["path"]``.
    - Per path, the highest-scoring pieces are kept (``CTX_PIECES_PER_PATH_CAP``,
      default 3). They are rendered in ``chunk_index`` order when available.
    - Paths are rendered best-first. Each gets a character allocation from a
      softmax over its clamped best score (``CTX_ALLOC_BETA``), bounded by
      ``CTX_ALLOC_MIN_PER_PATH`` / ``CTX_ALLOC_MAX_PER_PATH``.
    - Laravel route/controller blocks get short snippets of related files that
      are also among the chunks.
    - Blocks are joined by a blank line. The result never exceeds
      ``max_chars`` characters, and a block that does not fit is truncated
      only if its header line survives intact.

    The input dicts are not modified.

    Returns:
        ``(context_text, top_score)``, where ``top_score`` is the highest chunk
        ``score``, floored at 0.0.
    """
    if not chunks:
        return "", 0.0

    min_chars = _env_int("CTX_ALLOC_MIN_PER_PATH", 1200)
    max_chars_path = _env_int("CTX_ALLOC_MAX_PER_PATH", 6000)
    max_pieces = max(1, _env_int("CTX_PIECES_PER_PATH_CAP", 3))
    beta = _env_float("CTX_ALLOC_BETA", 2.2)

    # 1) Group per path using lightweight views, so caller dicts stay untouched.
    by_path: Dict[str, List[Dict]] = {}
    top_score = 0.0
    for r in chunks:
        meta = r.get("metadata")
        if not isinstance(meta, dict):
            meta = {}
        piece = {
            "chunk_index": meta.get("chunk_index"),
            "score": _as_float(r.get("score")),
            "document": _as_text(r.get("document")),
        }
        top_score = max(top_score, piece["score"])
        by_path.setdefault(_as_text(meta.get("path")), []).append(piece)

    # 2) Per path: keep the best-scoring pieces, then order them by chunk_index.
    path_items = []
    for path, pieces in by_path.items():
        best = sorted(pieces, key=lambda x: x["score"], reverse=True)[:max_pieces]
        path_items.append({
            "path": path,
            "best_score": best[0]["score"],
            "pieces": sorted(best, key=_piece_order),
        })

    # 3) Paths best-first; budget weights = softmax(beta * clamped score),
    #    shifted by the max logit so large beta cannot overflow math.exp.
    path_items.sort(key=lambda t: t["best_score"], reverse=True)
    logits = [beta * min(1.0, max(0.0, t["best_score"])) for t in path_items]
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    weights = [e / total for e in exps]

    # 4) First selected body per path, used as the source for Laravel stitching.
    first_body = {t["path"]: t["pieces"][0]["document"].strip() for t in path_items}

    # 5) Render within the per-path allocation and the global budget.
    sep = "\n\n"
    rendered: List[str] = []
    used = 0
    for entry, weight in zip(path_items, weights):
        path = entry["path"]
        alloc = min(max(min_chars, int(max_chars * weight)), max_chars_path)
        header = f"### {path} (score={entry['best_score']:.3f})"
        parts = [header]
        remaining = max(0, alloc - len(header) - 1)
        for piece in entry["pieces"]:
            if remaining <= 0:
                break
            body = piece["document"].strip()
            if not body:
                continue
            if len(body) > remaining:
                cut = body[:remaining]
                nl = cut.rfind("\n")
                # Prefer ending on a line boundary unless that keeps too little.
                body = cut[:nl] + "\n…" if nl > 300 else cut + "…"
            parts.append(body)
            remaining -= len(body) + 1  # +1 for the newline joining this piece
            if remaining <= 300:  # leave room for the stitches below
                break

        block = "\n".join(parts + _laravel_stitch_snippets(path, first_body, remaining))

        sep_cost = len(sep) if rendered else 0
        room = max_chars - used - sep_cost
        if room <= 0:
            break
        if len(block) > room:
            # Truncate only if the header line (and its newline) survives intact.
            if room - 1 > len(header):
                rendered.append(block[: room - 1] + "…")
            break
        rendered.append(block)
        used += sep_cost + len(block)
        if max_chars - used < 800:  # stop early once the budget is nearly spent
            break

    return sep.join(rendered), float(top_score)


# --- Laravel route/controller/view helpers (lightweight, cycle-safe) --------------

_ROUTE_AT_RE = re.compile(r"['\"]([A-Za-z0-9_\\]+)@([A-Za-z0-9_]+)['\"]")
_ROUTE_CLASS_RE = re.compile(r"([A-Za-z_][A-Za-z0-9_\\]+)\s*::\s*class")
_VIEW_CALL_RE = re.compile(r"\bview\(\s*['\"]([A-Za-z0-9_.\/-]+)['\"]\s*[,)]")
_CONTROLLERS_NS = "app/http/controllers/"


def _controller_guesses(fq: str) -> Tuple[str, str]:
    r"""Map a (possibly fully-qualified) controller class name to two candidate file paths.

    Namespace separators (single or doubled backslashes) become ``/``, and a
    leading ``App\Http\Controllers\`` namespace is stripped. As a result, both
    ``Admin\UserController`` and ``App\Http\Controllers\Admin\UserController``
    give ``app/Http/Controllers/Admin/UserController.php``. The second guess
    is the bare class name directly under ``app/Http/Controllers/``.
    """
    ctrl = fq.replace("\\\\", "/").replace("\\", "/").strip("/")
    if ctrl.lower().startswith(_CONTROLLERS_NS):
        ctrl = ctrl[len(_CONTROLLERS_NS):]
    name = ctrl.split("/")[-1]
    return f"app/Http/Controllers/{ctrl}.php", f"app/Http/Controllers/{name}.php"


def _laravel_pairs_from_route_text(route_text: str):
    r"""Parse routes/web.php|api.php text into ``(controller_path, method)`` guesses.

    Supports:
    - ``'Controller@method'`` strings, which yield the method name;
    - ``Foo\BarController::class`` / FQCN ``::class`` references, which yield
      method ``None``.

    Each reference yields a namespaced guess and a bare-name guess.
    Duplicates are removed, preserving order.
    """
    text = route_text or ""
    out = []
    for m in _ROUTE_AT_RE.finditer(text):
        guess, alt = _controller_guesses(m.group(1))
        out.append((guess, m.group(2)))
        out.append((alt, m.group(2)))
    for m in _ROUTE_CLASS_RE.finditer(text):
        guess, alt = _controller_guesses(m.group(1))
        out.append((guess, None))
        out.append((alt, None))
    return list(dict.fromkeys(out))


def _laravel_guess_view_paths_from_text(controller_text: str):
    """Map ``view("foo.bar")`` / ``view('foo.bar', $data)`` calls to resources/views/foo/bar.blade.php.

    Dotted and slashed view names are both accepted. Duplicates are removed,
    preserving order.
    """
    out = []
    for m in _VIEW_CALL_RE.finditer(controller_text or ""):
        view = m.group(1).strip().strip(".")
        if view:
            out.append(f"resources/views/{view.replace('.', '/')}.blade.php")
    return list(dict.fromkeys(out))


# Public API surface
__all__ = [
    "enrich_intent",
    "expand_queries",
    "hybrid_retrieve",
    "assemble_context",
    "_laravel_pairs_from_route_text",
    "_laravel_guess_view_paths_from_text",
]