"""SQLite persistence layer for research sessions, sources, chunks and outputs."""

import aiosqlite
import json
import time
from pathlib import Path
from typing import Optional
from enum import Enum
import structlog
from src.config import settings

logger = structlog.get_logger()


class ResearchStatus(str, Enum):
    """Values stored in ``research_sessions.status``."""

    RUNNING = "running"
    SATURATED = "saturated"
    FINISHED = "finished"
    ERROR = "error"


class OutputType(str, Enum):
    """Kinds of generated output."""

    PODCAST = "podcast"
    BLOG = "blog"
    REPORT = "report"
    THREAD = "thread"
    REPORT_EXTENDED = "report_extended"
    BLOG_EXTENDED = "blog_extended"
    PODCAST_EXTENDED = "podcast_extended"


# Idempotent schema: every statement uses IF NOT EXISTS so it can be run on
# each connection open.
SCHEMA = """
CREATE TABLE IF NOT EXISTS research_sessions (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    topic TEXT NOT NULL,
    status TEXT NOT NULL DEFAULT 'running',
    telegram_chat_id INTEGER NOT NULL,
    telegram_message_id INTEGER,
    created_at REAL NOT NULL,
    updated_at REAL NOT NULL,
    iterations INTEGER DEFAULT 0,
    total_sources INTEGER DEFAULT 0,
    total_chunks INTEGER DEFAULT 0,
    total_words INTEGER DEFAULT 0,
    meta JSON DEFAULT '{}'
);

CREATE TABLE IF NOT EXISTS sources (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id INTEGER NOT NULL REFERENCES research_sessions(id),
    url TEXT NOT NULL,
    title TEXT,
    source_type TEXT,  -- wikipedia, reddit, youtube, pdf, web, rss
    depth INTEGER DEFAULT 0,
    quality_score REAL DEFAULT 0,
    word_count INTEGER DEFAULT 0,
    scraped_at REAL,
    status TEXT DEFAULT 'pending',  -- pending, scraped, failed, skipped
    error TEXT,
    UNIQUE(session_id, url)
);

CREATE TABLE IF NOT EXISTS chunks (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id INTEGER NOT NULL REFERENCES research_sessions(id),
    source_id INTEGER NOT NULL REFERENCES sources(id),
    content TEXT NOT NULL,
    chunk_index INTEGER NOT NULL,
    token_count INTEGER,
    quality_score REAL DEFAULT 0,
    embedding JSON,  -- stored as JSON array for sqlite-vec compat
    created_at REAL NOT NULL
);

CREATE TABLE IF NOT EXISTS outputs (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id INTEGER NOT NULL REFERENCES research_sessions(id),
    output_type TEXT NOT NULL,
    content TEXT NOT NULL,
    created_at REAL NOT NULL
);

CREATE TABLE IF NOT EXISTS source_contents (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    source_id INTEGER NOT NULL UNIQUE REFERENCES sources(id),
    content TEXT NOT NULL,
    created_at REAL NOT NULL
);

CREATE INDEX IF NOT EXISTS idx_sources_session ON sources(session_id);
CREATE INDEX IF NOT EXISTS idx_chunks_session ON chunks(session_id);
CREATE INDEX IF NOT EXISTS idx_chunks_quality ON chunks(session_id, quality_score DESC);
CREATE INDEX IF NOT EXISTS idx_source_contents ON source_contents(source_id);

CREATE TABLE IF NOT EXISTS api_usage (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id INTEGER REFERENCES research_sessions(id),
    call_type TEXT NOT NULL,
    model TEXT NOT NULL,
    input_tokens INTEGER NOT NULL,
    output_tokens INTEGER NOT NULL,
    cost_usd REAL NOT NULL,
    created_at REAL NOT NULL
);

CREATE TABLE IF NOT EXISTS watched_topics (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    topic TEXT NOT NULL,
    chat_id INTEGER NOT NULL,
    interval_hours INTEGER NOT NULL DEFAULT 24,
    next_run_at REAL NOT NULL,
    last_run_at REAL,
    enabled INTEGER NOT NULL DEFAULT 1,
    created_at REAL NOT NULL,
    UNIQUE(topic, chat_id)
);
"""


async def get_db() -> aiosqlite.Connection:
    """Open the database at ``settings.db_path`` and make sure the schema exists.

    The parent directory is created if missing. The connection uses
    ``aiosqlite.Row`` rows, WAL journaling and ``synchronous=NORMAL``.

    Returns:
        An open connection; the caller is responsible for closing it.

    Raises:
        Exception: whatever the PRAGMAs or the schema script raise. The
            connection is closed before the error propagates.
    """
    Path(settings.db_path).parent.mkdir(parents=True, exist_ok=True)
    db = await aiosqlite.connect(settings.db_path)
    try:
        db.row_factory = aiosqlite.Row
        await db.execute("PRAGMA journal_mode=WAL")
        await db.execute("PRAGMA synchronous=NORMAL")
        await db.executescript(SCHEMA)
        await db.commit()
    except Exception:
        # Do not leak the connection (and its worker thread) on a failed setup.
        await db.close()
        raise
    return db


class ResearchDB:
    """Query helpers on top of an open ``aiosqlite`` connection.

    Methods that return ``dict`` rows call ``dict(row)`` and therefore expect
    the connection's ``row_factory`` to be ``aiosqlite.Row`` (as set by
    ``get_db``). Every write method commits before returning.
    """

    def __init__(self, db: aiosqlite.Connection):
        self.db = db

    # --- Internal helpers ---

    async def _fetch_one(self, sql: str, params: tuple = ()):
        """Run a query and return its first row (or None), closing the cursor."""
        async with self.db.execute(sql, params) as cur:
            return await cur.fetchone()

    async def _fetch_all_dicts(self, sql: str, params: tuple = ()) -> list[dict]:
        """Run a query and return every row as a dict, closing the cursor."""
        async with self.db.execute(sql, params) as cur:
            return [dict(row) for row in await cur.fetchall()]

    async def _update_row(self, table: str, row_id: int, fields: dict) -> None:
        """Apply ``fields`` (column -> value) to the row ``row_id`` of ``table``.

        Does nothing when ``fields`` is empty. Column names are interpolated
        into the statement, so anything that is not a plain identifier is
        rejected; values are always bound as parameters.

        Raises:
            ValueError: if a column name is not a valid identifier.
        """
        if not fields:
            return
        invalid = [col for col in fields if not str(col).isidentifier()]
        if invalid:
            raise ValueError(f"Invalid column name(s) for {table}: {invalid!r}")
        assignments = ", ".join(f"{col} = ?" for col in fields)
        params = [*fields.values(), row_id]
        await self.db.execute(
            f"UPDATE {table} SET {assignments} WHERE id = ?", params
        )
        await self.db.commit()

    # --- Sessions ---

    async def create_session(self, topic: str, chat_id: int) -> int:
        """Insert a new session in the ``running`` state and return its id."""
        now = time.time()
        cursor = await self.db.execute(
            """INSERT INTO research_sessions
               (topic, status, telegram_chat_id, created_at, updated_at)
               VALUES (?, ?, ?, ?, ?)""",
            (topic, ResearchStatus.RUNNING.value, chat_id, now, now),
        )
        await self.db.commit()
        return cursor.lastrowid

    async def get_session(self, session_id: int) -> Optional[dict]:
        """Return the session row as a dict, or None if the id is unknown."""
        row = await self._fetch_one(
            "SELECT * FROM research_sessions WHERE id = ?", (session_id,)
        )
        return dict(row) if row else None

    async def get_latest_session(self, chat_id: int) -> Optional[dict]:
        """Return the most recently created session of a chat, or None."""
        row = await self._fetch_one(
            "SELECT * FROM research_sessions WHERE telegram_chat_id = ? "
            "ORDER BY created_at DESC LIMIT 1",
            (chat_id,),
        )
        return dict(row) if row else None

    async def get_session_urls(self, session_id: int) -> set:
        """Return the set of source URLs recorded for a session."""
        async with self.db.execute(
            "SELECT url FROM sources WHERE session_id = ?", (session_id,)
        ) as cur:
            rows = await cur.fetchall()
        return {r[0] for r in rows}

    async def get_previous_session(
        self, chat_id: int, topic: str, exclude_session_id: int
    ) -> Optional[dict]:
        """Return the newest other session of this chat with the same topic.

        Returns:
            A dict with ``id``, ``topic``, ``status`` and ``created_at``, or
            None when no session other than ``exclude_session_id`` matches.
        """
        row = await self._fetch_one(
            """SELECT id, topic, status, created_at
               FROM research_sessions
               WHERE telegram_chat_id = ? AND topic = ? AND id != ?
               ORDER BY created_at DESC LIMIT 1""",
            (chat_id, topic, exclude_session_id),
        )
        if not row:
            return None
        return {
            "id": row[0],
            "topic": row[1],
            "status": row[2],
            "created_at": row[3],
        }

    async def get_active_session(self, chat_id: int) -> Optional[dict]:
        """Return the newest session of a chat whose status is 'running'."""
        row = await self._fetch_one(
            """SELECT * FROM research_sessions
               WHERE telegram_chat_id = ? AND status = 'running'
               ORDER BY created_at DESC LIMIT 1""",
            (chat_id,),
        )
        return dict(row) if row else None

    async def update_session(self, session_id: int, **kwargs):
        """Update columns of a session; ``updated_at`` is always refreshed.

        Keyword names are used as column names and must come from trusted
        code.

        Raises:
            ValueError: if a keyword is not a valid identifier.
        """
        kwargs["updated_at"] = time.time()
        await self._update_row("research_sessions", session_id, kwargs)

    async def get_session_stats(self, session_id: int) -> dict:
        """Count a session's sources in total and per status.

        Returns:
            A dict with ``total``, ``scraped``, ``failed``, ``pending`` and
            ``skipped``. All values are 0 for a session without sources.
        """
        # SUM() over zero rows is NULL in SQLite; COALESCE keeps the counters
        # numeric for sessions that have no sources yet.
        row = await self._fetch_one(
            """SELECT
                   COUNT(*) as total,
                   COALESCE(SUM(CASE WHEN status='scraped' THEN 1 ELSE 0 END), 0) as scraped,
                   COALESCE(SUM(CASE WHEN status='failed' THEN 1 ELSE 0 END), 0) as failed,
                   COALESCE(SUM(CASE WHEN status='pending' THEN 1 ELSE 0 END), 0) as pending,
                   COALESCE(SUM(CASE WHEN status='skipped' THEN 1 ELSE 0 END), 0) as skipped
               FROM sources WHERE session_id = ?""",
            (session_id,),
        )
        return dict(row) if row else {}

    # --- Sources ---

    async def add_source(
        self,
        session_id: int,
        url: str,
        source_type: str,
        depth: int = 0,
        title: Optional[str] = None,
    ) -> Optional[int]:
        """Register a URL for a session.

        Returns:
            The new source id, or None when the (session, url) pair already
            exists or the insert failed. Failures are logged, not raised.
        """
        try:
            cursor = await self.db.execute(
                """INSERT OR IGNORE INTO sources
                   (session_id, url, title, source_type, depth)
                   VALUES (?, ?, ?, ?, ?)""",
                (session_id, url, title, source_type, depth),
            )
            await self.db.commit()
        except Exception as exc:
            # Best-effort by design: callers treat None as "not added".
            # Log so real database errors are not silently hidden.
            logger.warning(
                "add_source_failed",
                session_id=session_id,
                url=url,
                error=str(exc),
            )
            return None
        # rowcount is 0 when OR IGNORE skipped a duplicate.
        return cursor.lastrowid if cursor.rowcount > 0 else None

    async def update_source(self, source_id: int, **kwargs):
        """Update columns of a source; a call without keywords is a no-op.

        Keyword names are used as column names and must come from trusted
        code.

        Raises:
            ValueError: if a keyword is not a valid identifier.
        """
        await self._update_row("sources", source_id, kwargs)

    async def get_pending_sources(self, session_id: int, limit: int = 10) -> list[dict]:
        """Return up to ``limit`` pending sources, shallowest depth first."""
        return await self._fetch_all_dicts(
            """SELECT * FROM sources
               WHERE session_id = ? AND status = 'pending'
               ORDER BY depth ASC, id ASC LIMIT ?""",
            (session_id, limit),
        )

    async def get_all_sources(self, session_id: int) -> list[dict]:
        """Return every source of a session, highest quality score first."""
        return await self._fetch_all_dicts(
            "SELECT * FROM sources WHERE session_id = ? ORDER BY quality_score DESC",
            (session_id,),
        )

    async def source_exists(self, session_id: int, url: str) -> bool:
        """Tell whether the URL is already recorded for the session."""
        row = await self._fetch_one(
            "SELECT 1 FROM sources WHERE session_id = ? AND url = ?",
            (session_id, url),
        )
        return row is not None

    # --- Chunks ---

    async def add_chunk(
        self,
        session_id: int,
        source_id: int,
        content: str,
        chunk_index: int,
        token_count: int,
        quality_score: float,
        embedding: Optional[list] = None,
    ) -> int:
        """Store a chunk and return its id.

        A missing or empty ``embedding`` is stored as NULL; otherwise it is
        serialized as a JSON array.
        """
        encoded_embedding = json.dumps(embedding) if embedding else None
        cursor = await self.db.execute(
            """INSERT INTO chunks
               (session_id, source_id, content, chunk_index, token_count,
                quality_score, embedding, created_at)
               VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
            (
                session_id,
                source_id,
                content,
                chunk_index,
                token_count,
                quality_score,
                encoded_embedding,
                time.time(),
            ),
        )
        await self.db.commit()
        return cursor.lastrowid

    async def get_top_chunks(self, session_id: int, limit: int = 50) -> list[dict]:
        """Return the best chunks at or above ``settings.quality_threshold``.

        Each dict holds the chunk columns plus the source's ``url``,
        ``title`` and ``source_type``.
        """
        return await self._fetch_all_dicts(
            """SELECT c.*, s.url, s.title, s.source_type
               FROM chunks c JOIN sources s ON c.source_id = s.id
               WHERE c.session_id = ? AND c.quality_score >= ?
               ORDER BY c.quality_score DESC LIMIT ?""",
            (session_id, settings.quality_threshold, limit),
        )

    async def get_chunks_count(self, session_id: int) -> int:
        """Return how many chunks a session has."""
        row = await self._fetch_one(
            "SELECT COUNT(*) FROM chunks WHERE session_id = ?", (session_id,)
        )
        return row[0]

    # --- Outputs ---

    async def save_output(self, session_id: int, output_type: str, content: str) -> int:
        """Store a generated output and return its id."""
        cursor = await self.db.execute(
            "INSERT INTO outputs (session_id, output_type, content, created_at) "
            "VALUES (?, ?, ?, ?)",
            (session_id, output_type, content, time.time()),
        )
        await self.db.commit()
        return cursor.lastrowid

    async def save_source_content(self, source_id: int, content: str):
        """Store the raw content of a source, replacing any previous copy."""
        await self.db.execute(
            """INSERT OR REPLACE INTO source_contents
               (source_id, content, created_at) VALUES (?, ?, ?)""",
            (source_id, content, time.time()),
        )
        await self.db.commit()

    async def get_source_content(self, source_id: int) -> Optional[str]:
        """Return the stored raw content of a source, or None."""
        row = await self._fetch_one(
            "SELECT content FROM source_contents WHERE source_id = ?",
            (source_id,),
        )
        return row[0] if row else None

    async def get_cached_content(self, url: str, max_age_days: int = 7) -> Optional[str]:
        """Return the newest stored content for ``url`` across all sessions.

        Only content saved within the last ``max_age_days`` days counts;
        returns None when nothing that recent exists.
        """
        threshold = time.time() - (max_age_days * 86400)
        row = await self._fetch_one(
            """SELECT sc.content FROM source_contents sc
               JOIN sources s ON s.id = sc.source_id
               WHERE s.url = ? AND sc.created_at > ?
               ORDER BY sc.created_at DESC LIMIT 1""",
            (url, threshold),
        )
        return row[0] if row else None

    async def get_outputs(self, session_id: int) -> list[dict]:
        """Return a session's outputs, newest first."""
        return await self._fetch_all_dicts(
            "SELECT * FROM outputs WHERE session_id = ? ORDER BY created_at DESC",
            (session_id,),
        )

    # --- API Usage ---

    async def log_api_call(
        self,
        session_id,
        call_type: str,
        model: str,
        input_tokens: int,
        output_tokens: int,
    ):
        """Record one API call together with its estimated cost in USD."""
        # Claude Haiku pricing (claude-haiku-4-5):
        #   input: $0.80 / 1M tokens   output: $4.00 / 1M tokens
        # NOTE(review): these rates are applied regardless of `model`; confirm
        # they match the models actually logged and their current pricing.
        input_usd_per_million = 0.80
        output_usd_per_million = 4.00
        cost = (
            input_tokens * input_usd_per_million
            + output_tokens * output_usd_per_million
        ) / 1_000_000
        await self.db.execute(
            """INSERT INTO api_usage
               (session_id, call_type, model, input_tokens, output_tokens,
                cost_usd, created_at)
               VALUES (?,?,?,?,?,?,?)""",
            (
                session_id,
                call_type,
                model,
                input_tokens,
                output_tokens,
                cost,
                time.time(),
            ),
        )
        await self.db.commit()

    async def get_usage_stats(self, session_id: int) -> list[dict]:
        """Return per-``call_type`` call counts, token totals and cost."""
        return await self._fetch_all_dicts(
            """SELECT call_type,
                      COUNT(*) as calls,
                      SUM(input_tokens + output_tokens) as total_tokens,
                      SUM(cost_usd) as total_cost
               FROM api_usage WHERE session_id = ?
               GROUP BY call_type""",
            (session_id,),
        )

    async def get_total_usage_stats(self) -> dict:
        """Return the number of distinct sessions with usage and the total cost.

        Returns:
            ``{"sessions": int, "total_cost": number}``; both are 0 when no
            usage has been logged.
        """
        # An aggregate without GROUP BY always yields one row, and SUM() over
        # zero rows is NULL, so the zero default has to come from COALESCE.
        row = await self._fetch_one(
            """SELECT COUNT(DISTINCT session_id) as sessions,
                      COALESCE(SUM(cost_usd), 0) as total_cost
               FROM api_usage"""
        )
        return dict(row) if row else {"sessions": 0, "total_cost": 0}

    # --- Watched Topics ---

    async def add_watch(self, topic: str, chat_id: int, interval_hours: int) -> int:
        """Create or replace the watch for (topic, chat_id) and return its id.

        The first run is scheduled ``interval_hours`` from now.
        """
        now = time.time()
        # INSERT OR REPLACE removes an existing (topic, chat_id) row and
        # inserts a fresh one, so a re-added watch gets a new id and its
        # last_run_at / enabled columns fall back to their defaults.
        cursor = await self.db.execute(
            """INSERT OR REPLACE INTO watched_topics
               (topic, chat_id, interval_hours, next_run_at, created_at)
               VALUES (?, ?, ?, ?, ?)""",
            (topic, chat_id, interval_hours, now + interval_hours * 3600, now),
        )
        await self.db.commit()
        return cursor.lastrowid

    async def remove_watch(self, topic: str, chat_id: int) -> bool:
        """Delete a watch; return True if a row was removed."""
        cursor = await self.db.execute(
            "DELETE FROM watched_topics WHERE topic = ? AND chat_id = ?",
            (topic, chat_id),
        )
        await self.db.commit()
        return cursor.rowcount > 0

    async def list_watches(self, chat_id: int) -> list[dict]:
        """Return the watches of a chat, oldest first."""
        return await self._fetch_all_dicts(
            "SELECT * FROM watched_topics WHERE chat_id = ? ORDER BY created_at ASC",
            (chat_id,),
        )

    async def get_due_watches(self) -> list[dict]:
        """Return enabled watches whose ``next_run_at`` is not in the future."""
        return await self._fetch_all_dicts(
            "SELECT * FROM watched_topics WHERE enabled = 1 AND next_run_at <= ?",
            (time.time(),),
        )

    async def update_watch_run(self, watch_id: int):
        """Mark a watch as just run and schedule its next run.

        Does nothing when the watch id is unknown.
        """
        row = await self._fetch_one(
            "SELECT interval_hours FROM watched_topics WHERE id = ?", (watch_id,)
        )
        if not row:
            return
        now = time.time()
        await self.db.execute(
            "UPDATE watched_topics SET last_run_at = ?, next_run_at = ? WHERE id = ?",
            (now, now + row[0] * 3600, watch_id),
        )
        await self.db.commit()

    # --- Maintenance ---

    async def purge_old_sessions(self, max_age_days: int = 30) -> dict:
        """Delete non-running sessions older than ``max_age_days`` days.

        Removes each session's stored source contents, chunks, outputs and
        sources. ``api_usage`` rows are kept for cost reporting but detached
        from the session (``session_id`` set to NULL).

        Returns:
            Deleted row counts keyed by ``sessions``, ``sources``, ``chunks``
            and ``outputs``.

        Raises:
            Exception: any database error; pending deletes are rolled back
                before it propagates.
        """
        # Enables foreign-key enforcement on this connection; it stays
        # enabled after the purge. (SQLite ignores this pragma while a
        # transaction is open.)
        await self.db.execute("PRAGMA foreign_keys = ON")
        threshold = time.time() - max_age_days * 86400
        async with self.db.execute(
            "SELECT id FROM research_sessions WHERE created_at < ? AND status != 'running'",
            (threshold,),
        ) as cur:
            session_ids = [row[0] for row in await cur.fetchall()]

        counts = {"sessions": 0, "sources": 0, "chunks": 0, "outputs": 0}
        try:
            for sid in session_ids:
                # Children first, so no delete violates a foreign key.
                await self.db.execute(
                    "DELETE FROM source_contents WHERE source_id IN "
                    "(SELECT id FROM sources WHERE session_id = ?)",
                    (sid,),
                )
                cur = await self.db.execute(
                    "DELETE FROM chunks WHERE session_id = ?", (sid,)
                )
                counts["chunks"] += cur.rowcount
                cur = await self.db.execute(
                    "DELETE FROM outputs WHERE session_id = ?", (sid,)
                )
                counts["outputs"] += cur.rowcount
                # api_usage also references research_sessions; with
                # enforcement on, leaving these rows attached makes the
                # session delete fail with "FOREIGN KEY constraint failed".
                await self.db.execute(
                    "UPDATE api_usage SET session_id = NULL WHERE session_id = ?",
                    (sid,),
                )
                cur = await self.db.execute(
                    "DELETE FROM sources WHERE session_id = ?", (sid,)
                )
                counts["sources"] += cur.rowcount
                cur = await self.db.execute(
                    "DELETE FROM research_sessions WHERE id = ?", (sid,)
                )
                counts["sessions"] += cur.rowcount
            await self.db.commit()
        except Exception:
            # Do not leave a half-finished purge pending on the connection.
            await self.db.rollback()
            raise

        logger.info(
            "Purged sessions older than days",
            sessions=counts["sessions"],
            days=max_age_days,
        )
        return counts