This commit is contained in:
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,251 @@
|
||||
"""
|
||||
ResearchOwl Processor
|
||||
Chunking → Quality scoring via Ollama → Embeddings → RAG synthesis
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import math
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
import structlog
|
||||
|
||||
from src.config import settings
|
||||
from src.db.database import ResearchDB
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class OllamaClient:
    """Async client for the Ollama HTTP API.

    The base URL and model name are read from ``settings`` at construction
    time. Every call opens its own short-lived ``httpx.AsyncClient``.
    """

    def __init__(self):
        # Drop trailing slashes so endpoint paths can be appended directly.
        self.base_url = settings.ollama_url.rstrip("/")
        self.model = settings.ollama_model

    async def generate(self, prompt: str, system: Optional[str] = None,
                       timeout: int = 120) -> str:
        """Run a non-streaming completion and return the generated text.

        Args:
            prompt: Prompt text sent to the model.
            system: Optional system prompt; omitted from the request when
                empty or ``None``.
            timeout: Timeout handed to ``httpx.AsyncClient``.

        Returns:
            The ``response`` field of the JSON reply with surrounding
            whitespace stripped, or ``""`` when the field is missing or null.

        Raises:
            Whatever ``httpx`` raises for transport failures, plus the error
            from ``raise_for_status()`` on an unsuccessful HTTP status.
        """
        payload = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,
            "options": {"temperature": 0.1, "num_predict": 512},
        }
        if system:
            payload["system"] = system

        async with httpx.AsyncClient(timeout=timeout) as client:
            resp = await client.post(f"{self.base_url}/api/generate", json=payload)
            resp.raise_for_status()
            # The key may be present with a null value, in which case
            # .get("response", "") would return None and .strip() would fail.
            return (resp.json().get("response") or "").strip()

    async def embed(self, text: str) -> Optional[list[float]]:
        """Get embedding vector for a text.

        Best effort: any failure is logged as a warning and ``None`` is
        returned instead of raising. The request uses ``self.model``, i.e.
        the same model name as ``generate``.
        """
        payload = {"model": self.model, "prompt": text}
        try:
            async with httpx.AsyncClient(timeout=60) as client:
                resp = await client.post(f"{self.base_url}/api/embeddings", json=payload)
                resp.raise_for_status()
                return resp.json().get("embedding")
        except Exception as e:
            logger.warning("Embedding failed", error=str(e))
            return None

    async def is_available(self) -> bool:
        """Return True when ``GET /api/tags`` answers with HTTP 200.

        Any exception (connection refused, timeout, ...) counts as
        "not available" and yields False.
        """
        try:
            async with httpx.AsyncClient(timeout=5) as client:
                resp = await client.get(f"{self.base_url}/api/tags")
                return resp.status_code == 200
        except Exception:
            return False
|
||||
|
||||
|
||||
def _word_windows(words: list[str], size: int, overlap: int) -> list[str]:
    """Cut *words* into space-joined windows of at most *size* words."""
    # Consecutive windows share `overlap` words. An overlap outside (0, size)
    # would stall or skip words, so fall back to back-to-back windows.
    step = size - overlap if 0 < overlap < size else size
    windows = []
    for start in range(0, len(words), step):
        windows.append(" ".join(words[start:start + size]))
        if start + size >= len(words):
            # This window already reaches the end; a further one would only
            # repeat words that are fully covered here.
            break
    return windows


def simple_chunk(text: str, chunk_size: int = 800, overlap: int = 100) -> list[str]:
    """
    Split text into overlapping chunks by approximate word count.

    Paragraphs (separated by blank lines) are packed into chunks of at most
    ``chunk_size`` words. Paragraph boundaries are respected when possible:

    * When ``overlap > 0`` the last paragraph of a finished chunk is repeated
      at the start of the next one, but only if it still fits together with
      the next paragraph. Otherwise it is dropped, so a chunk is never
      duplicated wholesale or pushed past the size budget.
    * A single paragraph longer than ``chunk_size`` is cut into word windows
      of ``chunk_size`` words that share ``overlap`` words. Whitespace inside
      such a paragraph is normalised to single spaces.

    Args:
        text: Raw text to split.
        chunk_size: Maximum number of words per chunk. Values <= 0 disable
            both packing and splitting, so every paragraph is its own chunk.
        overlap: Words shared between windows of an oversized paragraph; any
            value > 0 also enables the paragraph carry-over described above.

    Returns:
        The chunks in document order; empty for blank input.
    """
    # Normalise the input into pieces that each fit the word budget.
    pieces = []
    for raw in text.split("\n\n"):
        para = raw.strip()
        if not para:
            continue
        words = para.split()
        if chunk_size > 0 and len(words) > chunk_size:
            pieces.extend(_word_windows(words, chunk_size, overlap))
        else:
            pieces.append(para)

    chunks = []
    current = []
    current_words = 0

    for piece in pieces:
        piece_words = len(piece.split())
        if current and current_words + piece_words > chunk_size:
            chunks.append("\n\n".join(current))
            carry = current[-1]
            carry_words = len(carry.split())
            # Carry the last paragraph over only if it leaves room for the
            # incoming piece; otherwise the next chunk would exceed the budget.
            if overlap > 0 and carry_words + piece_words <= chunk_size:
                current, current_words = [carry], carry_words
            else:
                current, current_words = [], 0
        current.append(piece)
        current_words += piece_words

    if current:
        chunks.append("\n\n".join(current))

    return chunks
|
||||
|
||||
|
||||
def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Return the cosine of the angle between vectors *a* and *b*.

    Yields 0.0 when either vector is empty, when the lengths differ, or
    when either vector has zero magnitude.
    """
    comparable = bool(a) and bool(b) and len(a) == len(b)
    if not comparable:
        return 0.0

    inner = sum(map(lambda p, q: p * q, a, b))
    magnitude_a = math.sqrt(sum(v * v for v in a))
    magnitude_b = math.sqrt(sum(v * v for v in b))

    # A zero-length vector has no direction; avoid dividing by zero.
    if not magnitude_a or not magnitude_b:
        return 0.0
    return inner / (magnitude_a * magnitude_b)
|
||||
|
||||
|
||||
class ContentProcessor:
    """
    Processes scraped sources:
    1. Chunks content
    2. Scores quality with Ollama
    3. Generates embeddings
    4. Stores high-quality chunks

    Also answers retrieval queries over the stored chunks (``rag_query``).
    """

    def __init__(self, db: ResearchDB, ollama: OllamaClient):
        self.db = db
        self.ollama = ollama

    async def process_session(self, session_id: int, topic: str,
                              progress_callback=None) -> dict:
        """Process all scraped sources for a session.

        Sources whose ``status`` is ``"scraped"`` are processed concurrently
        (at most three at a time). A source that raises is logged and
        skipped; it does not abort the rest of the session.

        Args:
            session_id: Session whose sources are processed.
            topic: Research topic used when scoring chunk relevance.
            progress_callback: Optional awaitable callable invoked once at
                the end with ``total_chunks`` and ``total_words`` keywords.

        Returns:
            ``{"total_chunks": ..., "total_words": ...}``; the same totals
            are written to the session row via ``update_session``.
        """
        sources = await self.db.get_all_sources(session_id)
        scraped = [s for s in sources if s["status"] == "scraped"]

        logger.info("Processing sources", total=len(scraped))

        semaphore = asyncio.Semaphore(3)  # process 3 sources at once

        async def process_one(source):
            async with semaphore:
                return await self._process_source(session_id, topic, source)

        results = await asyncio.gather(*(process_one(s) for s in scraped),
                                       return_exceptions=True)

        total_chunks = 0
        for source, result in zip(scraped, results):
            if isinstance(result, BaseException):
                # return_exceptions=True hands failures back as values;
                # surface them instead of dropping them silently.
                logger.warning("Source processing failed",
                               source_id=source.get("id"), error=str(result))
            elif isinstance(result, int):
                total_chunks += result

        # `or 0` also covers a word_count that is present but None.
        total_words = sum(s.get("word_count", 0) or 0 for s in scraped)
        await self.db.update_session(
            session_id,
            total_chunks=total_chunks,
            total_words=total_words
        )

        if progress_callback:
            await progress_callback(total_chunks=total_chunks, total_words=total_words)

        return {"total_chunks": total_chunks, "total_words": total_words}

    async def _process_source(self, session_id: int, topic: str, source: dict) -> int:
        """Chunk, score, embed and store a single source. Returns chunk count.

        Chunks shorter than 30 words or scoring below
        ``settings.quality_threshold`` are skipped. The embedding is computed
        from the first 1000 characters of the chunk and may be ``None`` when
        embedding fails; the chunk is stored either way.
        """
        source_id = source["id"]

        content = await self.db.get_source_content(source_id)
        if not content:
            return 0

        chunks = simple_chunk(content, settings.chunk_size, settings.chunk_overlap)
        stored = 0

        for i, chunk in enumerate(chunks):
            word_count = len(chunk.split())
            if word_count < 30:
                continue

            quality = await self._score_quality(chunk, topic)
            if quality < settings.quality_threshold:
                continue

            embedding = await self.ollama.embed(chunk[:1000])

            await self.db.add_chunk(
                session_id=session_id,
                source_id=source_id,
                content=chunk,
                # Index within the full chunk list, including skipped chunks.
                chunk_index=i,
                token_count=word_count,
                quality_score=quality,
                embedding=embedding
            )
            stored += 1

        return stored

    async def _score_quality(self, chunk: str, topic: str) -> float:
        """
        Ask Ollama to score relevance and quality of a chunk.

        Only the first 500 characters of the chunk are sent. The first number
        found in the reply is read as a 0-10 score and scaled to 0.0-1.0
        (capped at 1.0). A reply without a number, or any error while calling
        the model, yields the neutral default 0.5.
        """
        prompt = f"""Rate this text chunk on a scale of 0-10 for:
1. Relevance to topic: "{topic}"
2. Information density (facts, data, insights)
3. Credibility (not speculation, not clickbait)

Text:
{chunk[:500]}

Respond with ONLY a single number 0-10. No explanation."""

        try:
            response = await self.ollama.generate(prompt)
        except Exception as e:
            # Scoring is best effort: log the failure, then fall back to the
            # neutral default rather than failing the whole source.
            logger.warning("Quality scoring failed", error=str(e))
            return 0.5  # default on error

        # Extract number from response
        numbers = re.findall(r'\b(\d+(?:\.\d+)?)\b', response)
        if numbers:
            score = float(numbers[0])
            return min(1.0, score / 10.0)
        return 0.5

    async def rag_query(self, session_id: int, query: str, top_k: int = 20) -> str:
        """
        Retrieve the most relevant chunks for a query and build a context string.

        Up to 100 chunks are fetched via ``get_top_chunks`` and, when a query
        embedding is available, re-ranked by
        ``0.7 * cosine similarity + 0.3 * quality_score``. Chunks without a
        usable embedding get a neutral similarity of 0.5. If the query cannot
        be embedded, the first ``top_k`` chunks are used in the order the
        database returned them (presumably by quality; confirm against
        ``get_top_chunks``).

        Returns:
            The selected chunks, each prefixed with a
            ``[SOURCE_TYPE] title:`` label, joined by ``---`` separators.
            Empty string when there are no chunks.
        """
        # Get query embedding
        query_embedding = await self.ollama.embed(query)

        # Get top quality chunks
        chunks = await self.db.get_top_chunks(session_id, limit=100)

        if query_embedding and chunks:
            # Rank by embedding similarity
            scored = []
            for chunk in chunks:
                emb = chunk.get("embedding")
                if emb and isinstance(emb, str):
                    # Embeddings may come back from the DB as JSON text.
                    try:
                        emb = json.loads(emb)
                    except Exception:
                        emb = None
                sim = cosine_similarity(query_embedding, emb) if emb else 0.5
                scored.append((sim * 0.7 + chunk["quality_score"] * 0.3, chunk))

            scored.sort(key=lambda x: x[0], reverse=True)
            top_chunks = [c for _, c in scored[:top_k]]
        else:
            # Fallback: just use quality score
            top_chunks = chunks[:top_k]

        # Build context
        context_parts = []
        for chunk in top_chunks:
            # `or` also covers keys that are present but None, where .get()'s
            # default would not apply and .upper() would fail.
            source_type = chunk.get('source_type') or 'web'
            title = chunk.get('title') or 'Unknown'
            source_label = f"[{source_type.upper()}] {title}"
            context_parts.append(f"{source_label}:\n{chunk['content']}")

        return "\n\n---\n\n".join(context_parts)
|
||||
Reference in New Issue
Block a user