# windowing_utils.py
|
|
from __future__ import annotations
|
|
from dataclasses import dataclass, field
|
|
from typing import List, Dict, Callable, Optional, Tuple, Awaitable
|
|
import hashlib
|
|
import os
|
|
import time
|
|
|
|
# ---------- Token counting (vervang door echte tokenizer indien je wilt)
|
|
def approx_token_count(text: str) -> int:
    """Estimate the token count of *text* using the ~4-chars-per-token rule.

    A crude but stable heuristic; swap in a real tokenizer if exact counts
    are needed. Always returns at least 1, even for an empty string.
    """
    estimated = len(text) // 4
    return estimated if estimated >= 1 else 1
|
|
|
|
def count_message_tokens(messages: List[Dict], tok_len: Callable[[str], int]) -> int:
    """Sum the token counts of all message contents.

    Args:
        messages: chat messages; each may carry a "content" string. A missing
            or None content (e.g. assistant tool-call messages) counts as the
            token length of the empty string.
        tok_len: token-counting function applied to each content string.

    Returns:
        Total token count across all messages.
    """
    # `or ""` guards against content=None, which m.get("content", "") does
    # not catch (the key exists, its value is None) and which would crash tok_len.
    return sum(tok_len(m.get("content") or "") for m in messages)
|
|
|
|
# ---------- Thread ID + summary store
|
|
def derive_thread_id(body: Dict) -> str:
    """Derive a stable thread identifier for a chat request body.

    Prefers an explicit id field when present and truthy; otherwise builds a
    deterministic fingerprint from the model name and the first two messages.

    Args:
        body: request payload (OpenAI-style chat completion body).

    Returns:
        The explicit id as a string, or a 16-hex-character SHA-256 prefix of
        "model||role||content[...]" for the first two messages.
    """
    for key in ("conversation_id", "thread_id", "chat_id", "session_id", "room_id"):
        if body.get(key):
            return str(body[key])
    parts = [str(body.get("model", ""))]
    for m in body.get("messages", [])[:2]:
        # str(... or "") guards against role/content being None or non-string
        # (e.g. tool-call messages, multimodal content lists), which would
        # otherwise crash slicing or "||".join below.
        parts.append(str(m.get("role") or ""))
        parts.append(str(m.get("content") or "")[:256])
    raw = "||".join(parts)
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()[:16]
|
|
|
|
class RunningSummaryStore:
    """In-memory store mapping thread ids to their running summaries."""

    def __init__(self) -> None:
        # thread_id -> latest running summary text
        self._mem: dict[str, str] = {}

    def get(self, thread_id: str) -> str:
        """Return the stored summary for *thread_id*, or "" if none exists."""
        try:
            return self._mem[thread_id]
        except KeyError:
            return ""

    def update(self, thread_id: str, new_summary: str) -> None:
        """Replace the summary stored for *thread_id*."""
        self._mem[thread_id] = new_summary
|
|
|
|
# Module-level singleton: one shared, process-local summary store.
SUMMARY_STORE = RunningSummaryStore()
|
|
|
|
# ---------- Sliding window + running summary
|
|
@dataclass
class ConversationWindow:
    """Sliding conversation window with an optional running summary.

    Holds the chat history and, when the token budget would be exceeded,
    trims the oldest turns and/or folds them into ``running_summary``,
    which is re-injected as a system message.
    """

    max_ctx_tokens: int  # total context window of the target model, in tokens
    response_reserve: int = 2048  # tokens held back for the model's reply
    tok_len: Callable[[str], int] = approx_token_count  # token-counting function
    running_summary: str = ""  # summary of turns evicted so far
    summary_header: str = "Samenvatting tot nu toe"  # label prepended to the summary message (Dutch: "Summary so far")
    history: List[Dict] = field(default_factory=list)  # chat turns as {"role", "content"} dicts

    def add(self, role: str, content: str):
        """Append one chat turn to the history."""
        self.history.append({"role": role, "content": content})

    def _base_messages(self, system_prompt: Optional[str]) -> List[Dict]:
        """Build the fixed message prefix: system prompt, then running summary.

        Either part is omitted when empty/None.
        """
        msgs: List[Dict] = []
        if system_prompt:
            msgs.append({"role": "system", "content": system_prompt})
        if self.running_summary:
            msgs.append({"role": "system", "content": f"{self.summary_header}:\n{self.running_summary}"})
        return msgs

    async def build_within_budget(
        self,
        system_prompt: Optional[str],
        summarizer: Optional[Callable[[str, List[Dict]], Awaitable[str]]] = None
    ) -> List[Dict]:
        """Return prefix + history messages fitting the token budget.

        Strategy: if everything already fits, return it unchanged. Otherwise
        drop the oldest turns; when a *summarizer* is supplied, restart from
        the full history and fold evicted turns (in batches) into
        ``running_summary`` so their content survives eviction. Mutates
        ``self.history`` and possibly ``self.running_summary``.

        Args:
            system_prompt: optional system message placed first.
            summarizer: async callable (old_summary, evicted_turns) -> new summary.

        Returns:
            Message list targeting at most ``max_ctx_tokens - response_reserve``
            tokens (best effort — a large summary may still overshoot).
        """
        budget = self.max_ctx_tokens - max(1, self.response_reserve)
        working = self.history[:]
        candidate = self._base_messages(system_prompt) + working
        if count_message_tokens(candidate, self.tok_len) <= budget:
            return candidate

        # 1) Trim the oldest turns until the window fits.
        while working and count_message_tokens(self._base_messages(system_prompt) + working, self.tok_len) > budget:
            working.pop(0)
        candidate = self._base_messages(system_prompt) + working
        if count_message_tokens(candidate, self.tok_len) <= budget:
            self.history = working
            return candidate

        # 2) Summarize if possible; without a summarizer, trimming is all we can do.
        if summarizer is None:
            while working and count_message_tokens(self._base_messages(system_prompt) + working, self.tok_len) > budget:
                working.pop(0)
            self.history = working
            return self._base_messages(system_prompt) + working

        # Summarize in batches: restart from the full history and evict turns
        # into a buffer that is periodically folded into the running summary.
        working = self.history[:]
        chunk_buf: List[Dict] = []

        async def build_candidate(_summary: str, _working: List[Dict]) -> List[Dict]:
            # Like _base_messages, but against an explicit in-progress summary.
            base = []
            if system_prompt:
                base.append({"role": "system", "content": system_prompt})
            if _summary:
                base.append({"role": "system", "content": f"{self.summary_header}:\n{_summary}"})
            return base + _working

        while working and count_message_tokens(await build_candidate(self.running_summary, working), self.tok_len) > budget:
            chunk_buf.append(working.pop(0))
            # Summarize once the buffer holds roughly ~1500 tokens (crude
            # estimate via str()), or when the history has been exhausted.
            if count_message_tokens([{"role":"system","content":str(chunk_buf)}], self.tok_len) > 1500 or not working:
                self.running_summary = await summarizer(self.running_summary, chunk_buf)
                chunk_buf = []

        # Flush any leftover buffer so no evicted turns are silently dropped.
        if chunk_buf:
            self.running_summary = await summarizer(self.running_summary, chunk_buf)
            chunk_buf = []

        self.history = working
        return await build_candidate(self.running_summary, working)
|
|
|
|
# ---------- Repo chunking
|
|
from typing import Iterable
|
|
def split_text_tokens(
    text: str,
    tok_len: Callable[[str], int],
    max_tokens: int,
    overlap_tokens: int = 60
) -> List[str]:
    """Split *text* into chunks of at most *max_tokens* tokens each.

    Adjacent chunks overlap by roughly *overlap_tokens* tokens so content
    spanning a chunk boundary is not lost.

    Args:
        text: the text to split.
        tok_len: token-counting function.
        max_tokens: per-chunk token budget.
        overlap_tokens: approximate token overlap between adjacent chunks
            (0 disables overlap).

    Returns:
        List of chunk strings; ``[text]`` when the whole text already fits.
    """
    if tok_len(text) <= max_tokens:
        return [text]
    # Initial character step sized from the text's observed chars-per-token ratio.
    approx_ratio = max_tokens / max(1, tok_len(text))
    step = max(1000, int(len(text) * approx_ratio))
    chunks: List[str] = []
    i = 0
    while i < len(text):
        ch = text[i:i+step]
        # Shrink the candidate until it fits the token budget; the 200-char
        # floor prevents degenerate slivers when tok_len is pessimistic.
        while tok_len(ch) > max_tokens and len(ch) > 200:
            ch = ch[:-200]
        chunks.append(ch)
        end = i + len(ch)
        # Once the remainder has been consumed, stop — otherwise the overlap
        # logic would keep emitting shrinking duplicate tail chunks.
        if end >= len(text):
            break
        if overlap_tokens > 0:
            # ~4 chars per token, but never overlap more than half the chunk:
            # with the previous unclamped overlap, a 200-char trimmed chunk
            # advanced the cursor by only 1 char per iteration, producing
            # thousands of near-duplicate chunks.
            ov_chars = min(max(100, overlap_tokens * 4), len(ch) // 2)
            i = max(i + 1, end - ov_chars)
        else:
            i = end
    return chunks
|
|
|
|
def fit_context_under_budget(
    items: List[Tuple[str,str]], tok_len: Callable[[str], int], budget_tokens: int
) -> List[Tuple[str,str]]:
    """Greedily take a prefix of (title, text) items fitting the token budget.

    Items are considered in order; selection stops at the first item whose
    text would push the running total past *budget_tokens*.

    Args:
        items: candidate (title, text) pairs, highest priority first.
        tok_len: token-counting function applied to each text.
        budget_tokens: total token budget for the selection.

    Returns:
        The selected prefix of *items* (possibly empty).
    """
    selected: List[Tuple[str, str]] = []
    remaining = budget_tokens
    for entry in items:
        cost = tok_len(entry[1])
        if cost > remaining:
            break
        selected.append(entry)
        remaining -= cost
    return selected
|
|
|
|
def build_repo_context(
    files_ranked: List[Tuple[str, str, float]],
    per_chunk_tokens: int = 1200,
    overlap_tokens: int = 60,
    ctx_budget_tokens: int = 4000,
    tok_len: Callable[[str], int] = approx_token_count
) -> str:
    """Assemble a repository-context string from ranked files under a budget.

    Each file's content is split into overlapping chunks; chunks are then
    taken greedily in rank order until the overall context budget is hit.

    Args:
        files_ranked: (path, content, score) tuples, best first; score unused here.
        per_chunk_tokens: token budget per individual chunk.
        overlap_tokens: overlap between adjacent chunks of one file.
        ctx_budget_tokens: total token budget for the assembled context.
        tok_len: token-counting function.

    Returns:
        Concatenated "=== path#chunkN ===" sections, stripped; "" if nothing fits.
    """
    titled_chunks: List[Tuple[str, str]] = [
        (f"{path}#chunk{idx + 1}", piece)
        for path, content, _ in files_ranked
        for idx, piece in enumerate(split_text_tokens(content, tok_len, per_chunk_tokens, overlap_tokens))
    ]
    chosen = fit_context_under_budget(titled_chunks, tok_len, ctx_budget_tokens)
    sections = [f"\n\n=== {title} ===\n{piece}" for title, piece in chosen]
    return "".join(sections).strip()
|