# mistral-api/windowing_utils.py
# NOTE: web-view scrape chrome (line/size counts, "Raw Normal View History",
# timestamp) removed so the file parses as Python.
# windowing_utils.py
from __future__ import annotations

import hashlib
import os
import time
from dataclasses import dataclass, field
from typing import Awaitable, Callable, Dict, List, Optional, Tuple
# ---------- Token counting (vervang door echte tokenizer indien je wilt)
def approx_token_count(text: str) -> int:
    """Rough token estimate: ~4 characters per token, never below 1.

    Crude but stable heuristic; swap in a real tokenizer if needed.
    """
    estimate = len(text) // 4
    return estimate if estimate > 0 else 1
def count_message_tokens(messages: List[Dict], tok_len: Callable[[str], int]) -> int:
    """Sum the token estimates of every message's ``content`` field.

    Messages without a ``content`` key count as the empty string.
    """
    return sum(tok_len(msg.get("content", "")) for msg in messages)
# ---------- Thread ID + summary store
def derive_thread_id(body: Dict) -> str:
    """Return a stable thread identifier for a request body.

    Prefers the first truthy explicit id field; otherwise fingerprints the
    model name plus the first two messages with SHA-256 (16 hex chars).
    """
    for key in ("conversation_id", "thread_id", "chat_id", "session_id", "room_id"):
        value = body.get(key)
        if value:
            return str(value)
    # Fallback fingerprint: model + role/content (truncated) of first 2 turns.
    fingerprint = [str(body.get("model", ""))]
    for message in body.get("messages", [])[:2]:
        fingerprint.append(message.get("role", ""))
        fingerprint.append(message.get("content", "")[:256])
    digest = hashlib.sha256("||".join(fingerprint).encode("utf-8")).hexdigest()
    return digest[:16]
class RunningSummaryStore:
    """Trivial in-memory map from thread id to its running summary."""

    def __init__(self) -> None:
        self._mem: dict[str, str] = {}

    def get(self, thread_id: str) -> str:
        """Return the stored summary, or '' when none exists."""
        return self._mem.get(thread_id, "")

    def update(self, thread_id: str, new_summary: str) -> None:
        """Overwrite (or create) the summary for this thread."""
        self._mem[thread_id] = new_summary


# Module-level singleton shared by callers of this module.
SUMMARY_STORE = RunningSummaryStore()
# ---------- Sliding window + running summary
@dataclass
class ConversationWindow:
    """Sliding-window chat history that fits messages into a token budget.

    When a summarizer is supplied, turns evicted from the window are folded
    into ``running_summary`` instead of being discarded.
    """
    max_ctx_tokens: int                              # total model context size
    response_reserve: int = 2048                     # tokens reserved for the reply
    tok_len: Callable[[str], int] = approx_token_count
    running_summary: str = ""
    summary_header: str = "Samenvatting tot nu toe"  # runtime string, kept as-is
    history: List[Dict] = field(default_factory=list)

    def add(self, role: str, content: str) -> None:
        """Append one turn to the history."""
        self.history.append({"role": role, "content": content})

    def _base_messages(self, system_prompt: Optional[str]) -> List[Dict]:
        """System prompt (if any) plus the running-summary system message (if any)."""
        msgs: List[Dict] = []
        if system_prompt:
            msgs.append({"role": "system", "content": system_prompt})
        if self.running_summary:
            msgs.append({"role": "system", "content": f"{self.summary_header}:\n{self.running_summary}"})
        return msgs

    async def build_within_budget(
        self,
        system_prompt: Optional[str],
        summarizer: Optional[Callable[[str, List[Dict]], Awaitable[str]]] = None,
    ) -> List[Dict]:
        """Return base messages + history trimmed to fit the token budget.

        Without a summarizer the oldest turns are dropped (best effort).
        With one, evicted turns are summarized into ``running_summary`` in
        ~1500-token batches. ``self.history`` is updated to the kept turns.

        FIXES vs. the original: (1) the summarizer path was effectively dead
        because plain trimming ran first and returned after silently dropping
        turns; (2) a partial final batch in ``chunk_buf`` was discarded
        without being summarized; (3) annotation ``"awaitable[str]"`` named a
        nonexistent type.
        """
        budget = self.max_ctx_tokens - max(1, self.response_reserve)
        working = self.history[:]

        def candidate() -> List[Dict]:
            # Rebuilt each time: _base_messages reflects running_summary updates.
            return self._base_messages(system_prompt) + working

        if count_message_tokens(candidate(), self.tok_len) <= budget:
            return candidate()

        if summarizer is None:
            # No summarizer available: drop oldest turns until we fit.
            while working and count_message_tokens(candidate(), self.tok_len) > budget:
                working.pop(0)
            self.history = working
            return candidate()

        # Summarizer available: fold evicted oldest turns into the running
        # summary in rough ~1500-token batches instead of discarding them.
        chunk_buf: List[Dict] = []

        def buf_tokens() -> int:
            # Crude size estimate of the pending batch (repr of the dicts).
            return count_message_tokens(
                [{"role": "system", "content": str(chunk_buf)}], self.tok_len
            )

        while working and count_message_tokens(candidate(), self.tok_len) > budget:
            chunk_buf.append(working.pop(0))
            if buf_tokens() > 1500 or not working:
                self.running_summary = await summarizer(self.running_summary, chunk_buf)
                chunk_buf = []
        if chunk_buf:
            # Flush turns evicted after the last full batch (previously lost).
            self.running_summary = await summarizer(self.running_summary, chunk_buf)
        self.history = working
        return candidate()
# ---------- Repo chunking
from typing import Iterable
def split_text_tokens(
    text: str,
    tok_len: Callable[[str], int],
    max_tokens: int,
    overlap_tokens: int = 60
) -> List[str]:
    """Split ``text`` into chunks of at most ``max_tokens`` tokens each.

    Consecutive chunks overlap by roughly ``overlap_tokens`` tokens
    (~4 chars/token, minimum 100 chars) so context isn't cut mid-thought.

    FIX vs. the original: once a chunk reaches the end of the text we stop.
    Previously, when the final chunk was shorter than the overlap width the
    cursor advanced by only 1 char per iteration, emitting up to ~overlap-many
    near-duplicate tail chunks.
    """
    if tok_len(text) <= max_tokens:
        return [text]
    # Estimate a char step size from the token budget vs. total token count.
    approx_ratio = max_tokens / max(1, tok_len(text))
    step = max(1000, int(len(text) * approx_ratio))
    chunks: List[str] = []
    n = len(text)
    i = 0
    while i < n:
        ch = text[i:i + step]
        # Shrink in 200-char bites until under the token cap (keep >= 200 chars).
        while tok_len(ch) > max_tokens and len(ch) > 200:
            ch = ch[:-200]
        chunks.append(ch)
        if i + len(ch) >= n:
            break  # reached the end of the text; nothing left to cover
        if overlap_tokens > 0:
            ov_chars = max(100, overlap_tokens * 4)
            i += max(1, len(ch) - ov_chars)
        else:
            i += len(ch)
    return chunks
def fit_context_under_budget(
    items: List[Tuple[str, str]], tok_len: Callable[[str], int], budget_tokens: int
) -> List[Tuple[str, str]]:
    """Greedily keep a prefix of ``items`` that fits within ``budget_tokens``.

    Items are taken in order; selection stops at the FIRST item that would
    exceed the budget (rank order is preserved, no skip-ahead).
    """
    selected: List[Tuple[str, str]] = []
    remaining = budget_tokens
    for title, text in items:
        cost = tok_len(text)
        if cost > remaining:
            break
        selected.append((title, text))
        remaining -= cost
    return selected
def build_repo_context(
    files_ranked: List[Tuple[str, str, float]],
    per_chunk_tokens: int = 1200,
    overlap_tokens: int = 60,
    ctx_budget_tokens: int = 4000,
    tok_len: Callable[[str], int] = approx_token_count
) -> str:
    """Render ranked repo files as a budget-bounded context string.

    Each file's content is split into overlapping chunks; chunks are kept in
    rank order until the token budget is hit, then rendered as
    ``=== path#chunkN ===`` sections separated by blank lines.
    """
    expanded: List[Tuple[str, str]] = [
        (f"{path}#chunk{idx + 1}", piece)
        for path, content, _score in files_ranked
        for idx, piece in enumerate(
            split_text_tokens(content, tok_len, per_chunk_tokens, overlap_tokens)
        )
    ]
    chosen = fit_context_under_budget(expanded, tok_len, ctx_budget_tokens)
    rendered = "".join(f"\n\n=== {title} ===\n{piece}" for title, piece in chosen)
    return rendered.strip()