462 lines
17 KiB
Python
462 lines
17 KiB
Python
import aiosqlite
|
|
import json
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
from enum import Enum
|
|
|
|
import structlog
|
|
|
|
from src.config import settings
|
|
|
|
logger = structlog.get_logger()
|
|
|
|
|
|
class ResearchStatus(str, Enum):
    """Lifecycle states stored in ``research_sessions.status``.

    Members are also ``str`` instances, so they compare equal to their
    plain-string values.
    """

    RUNNING = "running"
    SATURATED = "saturated"
    FINISHED = "finished"
    ERROR = "error"
|
|
|
|
|
|
class OutputType(str, Enum):
    """Kinds of generated output, identified by their string value.

    Each base format (podcast, blog, report) also has an ``*_extended``
    variant; ``thread`` has none.
    """

    # Base formats
    PODCAST = "podcast"
    BLOG = "blog"
    REPORT = "report"
    THREAD = "thread"

    # Extended variants
    REPORT_EXTENDED = "report_extended"
    BLOG_EXTENDED = "blog_extended"
    PODCAST_EXTENDED = "podcast_extended"
|
|
|
|
|
|
# DDL script executed by get_db() on every connection. Every statement uses
# IF NOT EXISTS, so re-running it against an existing database is harmless.
SCHEMA = """
CREATE TABLE IF NOT EXISTS research_sessions (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    topic TEXT NOT NULL,
    status TEXT NOT NULL DEFAULT 'running',
    telegram_chat_id INTEGER NOT NULL,
    telegram_message_id INTEGER,
    created_at REAL NOT NULL,
    updated_at REAL NOT NULL,
    iterations INTEGER DEFAULT 0,
    total_sources INTEGER DEFAULT 0,
    total_chunks INTEGER DEFAULT 0,
    total_words INTEGER DEFAULT 0,
    meta JSON DEFAULT '{}'
);

CREATE TABLE IF NOT EXISTS sources (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id INTEGER NOT NULL REFERENCES research_sessions(id),
    url TEXT NOT NULL,
    title TEXT,
    source_type TEXT, -- wikipedia, reddit, youtube, pdf, web, rss
    depth INTEGER DEFAULT 0,
    quality_score REAL DEFAULT 0,
    word_count INTEGER DEFAULT 0,
    scraped_at REAL,
    status TEXT DEFAULT 'pending', -- pending, scraped, failed, skipped
    error TEXT,
    UNIQUE(session_id, url)
);

CREATE TABLE IF NOT EXISTS chunks (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id INTEGER NOT NULL REFERENCES research_sessions(id),
    source_id INTEGER NOT NULL REFERENCES sources(id),
    content TEXT NOT NULL,
    chunk_index INTEGER NOT NULL,
    token_count INTEGER,
    quality_score REAL DEFAULT 0,
    embedding JSON, -- stored as JSON array for sqlite-vec compat
    created_at REAL NOT NULL
);

CREATE TABLE IF NOT EXISTS outputs (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id INTEGER NOT NULL REFERENCES research_sessions(id),
    output_type TEXT NOT NULL,
    content TEXT NOT NULL,
    created_at REAL NOT NULL
);

CREATE TABLE IF NOT EXISTS source_contents (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    source_id INTEGER NOT NULL UNIQUE REFERENCES sources(id),
    content TEXT NOT NULL,
    created_at REAL NOT NULL
);

CREATE INDEX IF NOT EXISTS idx_sources_session ON sources(session_id);
CREATE INDEX IF NOT EXISTS idx_chunks_session ON chunks(session_id);
CREATE INDEX IF NOT EXISTS idx_chunks_quality ON chunks(session_id, quality_score DESC);
CREATE INDEX IF NOT EXISTS idx_source_contents ON source_contents(source_id);

CREATE TABLE IF NOT EXISTS api_usage (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id INTEGER REFERENCES research_sessions(id),
    call_type TEXT NOT NULL,
    model TEXT NOT NULL,
    input_tokens INTEGER NOT NULL,
    output_tokens INTEGER NOT NULL,
    cost_usd REAL NOT NULL,
    created_at REAL NOT NULL
);

CREATE TABLE IF NOT EXISTS watched_topics (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    topic TEXT NOT NULL,
    chat_id INTEGER NOT NULL,
    interval_hours INTEGER NOT NULL DEFAULT 24,
    next_run_at REAL NOT NULL,
    last_run_at REAL,
    enabled INTEGER NOT NULL DEFAULT 1,
    created_at REAL NOT NULL,
    UNIQUE(topic, chat_id)
);
"""
|
|
|
|
|
|
async def get_db() -> aiosqlite.Connection:
    """Open the SQLite database at ``settings.db_path`` and ensure the schema.

    Creates the parent directory if needed, sets ``aiosqlite.Row`` as the row
    factory, switches to WAL journaling with ``synchronous=NORMAL``, and runs
    ``SCHEMA``.

    Returns:
        The open connection; the caller is responsible for closing it.

    Raises:
        Whatever the setup statements raise. The connection is closed first
        so a failed setup does not leave it open.
    """
    Path(settings.db_path).parent.mkdir(parents=True, exist_ok=True)
    db = await aiosqlite.connect(settings.db_path)
    try:
        db.row_factory = aiosqlite.Row
        await db.execute("PRAGMA journal_mode=WAL")
        await db.execute("PRAGMA synchronous=NORMAL")
        await db.executescript(SCHEMA)
        await db.commit()
    except BaseException:
        # BaseException so that task cancellation also releases the connection.
        await db.close()
        raise
    return db
|
|
|
|
|
|
class ResearchDB:
    """Async data-access helpers over the research SQLite database.

    Wraps an already-open connection. Methods that return rows convert them
    with ``dict(row)``, so the connection's ``row_factory`` must yield
    mapping-style rows (``get_db`` sets ``aiosqlite.Row``). Every writing
    method commits before returning.
    """

    # USD per one million tokens, used by log_api_call().
    # Claude Haiku prices (claude-haiku-4-5):
    #   input: $0.80 / 1M tokens    output: $4.00 / 1M tokens
    # NOTE(review): these rates are applied regardless of the ``model``
    # argument, and should be confirmed against the current price list for
    # the model named above.
    _INPUT_USD_PER_MTOK = 0.80
    _OUTPUT_USD_PER_MTOK = 4.00

    def __init__(self, db: "aiosqlite.Connection"):
        self.db = db

    @staticmethod
    def _set_clause(columns) -> str:
        """Build the ``col = ?, ...`` part of an UPDATE statement.

        Column names cannot be bound as SQL parameters, so each one is
        required to be a plain identifier before it is interpolated.

        Raises:
            ValueError: if any name is not a valid identifier.
        """
        names = list(columns)
        for name in names:
            if not name.isidentifier():
                raise ValueError(f"Invalid column name: {name!r}")
        return ", ".join(f"{name} = ?" for name in names)

    # --- Sessions ---

    async def create_session(self, topic: str, chat_id: int) -> int:
        """Insert a new session in the ``running`` state and return its id."""
        now = time.time()
        cursor = await self.db.execute(
            """INSERT INTO research_sessions (topic, status, telegram_chat_id, created_at, updated_at)
               VALUES (?, ?, ?, ?, ?)""",
            (topic, ResearchStatus.RUNNING.value, chat_id, now, now),
        )
        await self.db.commit()
        return cursor.lastrowid

    async def get_session(self, session_id: int) -> Optional[dict]:
        """Return the session row with the given id, or ``None``."""
        cursor = await self.db.execute(
            "SELECT * FROM research_sessions WHERE id = ?", (session_id,)
        )
        row = await cursor.fetchone()
        return dict(row) if row else None

    async def get_latest_session(self, chat_id: int) -> Optional[dict]:
        """Return the most recently created session for a chat, or ``None``."""
        cursor = await self.db.execute(
            "SELECT * FROM research_sessions WHERE telegram_chat_id = ? ORDER BY created_at DESC LIMIT 1",
            (chat_id,),
        )
        row = await cursor.fetchone()
        return dict(row) if row else None

    async def get_session_urls(self, session_id: int) -> set:
        """Return the set of source URLs recorded for a session."""
        async with self.db.execute(
            "SELECT url FROM sources WHERE session_id = ?", (session_id,)
        ) as cur:
            rows = await cur.fetchall()
        return {r[0] for r in rows}

    async def get_previous_session(self, chat_id: int, topic: str,
                                   exclude_session_id: int) -> Optional[dict]:
        """Return the newest other session for the same chat and topic.

        Only ``id``, ``topic``, ``status`` and ``created_at`` are returned.
        ``exclude_session_id`` is skipped; ``None`` means no other session
        matched.
        """
        async with self.db.execute(
            """SELECT id, topic, status, created_at FROM research_sessions
               WHERE telegram_chat_id = ? AND topic = ? AND id != ?
               ORDER BY created_at DESC LIMIT 1""",
            (chat_id, topic, exclude_session_id),
        ) as cur:
            row = await cur.fetchone()
        if not row:
            return None
        return {"id": row[0], "topic": row[1],
                "status": row[2], "created_at": row[3]}

    async def get_active_session(self, chat_id: int) -> Optional[dict]:
        """Return the newest session of a chat whose status is ``running``."""
        cursor = await self.db.execute(
            """SELECT * FROM research_sessions
               WHERE telegram_chat_id = ? AND status = 'running'
               ORDER BY created_at DESC LIMIT 1""",
            (chat_id,),
        )
        row = await cursor.fetchone()
        return dict(row) if row else None

    async def update_session(self, session_id: int, **kwargs):
        """Update the given columns of a session and refresh ``updated_at``.

        ``updated_at`` is always set to the current time, overriding any
        value passed in.

        Raises:
            ValueError: if a keyword is not a valid column identifier.
        """
        kwargs["updated_at"] = time.time()
        sets = self._set_clause(kwargs)
        values = list(kwargs.values()) + [session_id]
        await self.db.execute(
            f"UPDATE research_sessions SET {sets} WHERE id = ?", values
        )
        await self.db.commit()

    async def get_session_stats(self, session_id: int) -> dict:
        """Return source counts for a session, broken down by status.

        Keys: ``total``, ``scraped``, ``failed``, ``pending``, ``skipped``.
        All are integers; a session without sources yields zeros.
        """
        # SUM() over zero rows is NULL, hence the COALESCE around each one.
        cursor = await self.db.execute(
            """SELECT
                   COUNT(*) as total,
                   COALESCE(SUM(CASE WHEN status='scraped' THEN 1 ELSE 0 END), 0) as scraped,
                   COALESCE(SUM(CASE WHEN status='failed' THEN 1 ELSE 0 END), 0) as failed,
                   COALESCE(SUM(CASE WHEN status='pending' THEN 1 ELSE 0 END), 0) as pending,
                   COALESCE(SUM(CASE WHEN status='skipped' THEN 1 ELSE 0 END), 0) as skipped
               FROM sources WHERE session_id = ?""",
            (session_id,),
        )
        row = await cursor.fetchone()
        return dict(row) if row else {}

    # --- Sources ---

    async def add_source(self, session_id: int, url: str, source_type: str,
                         depth: int = 0, title: Optional[str] = None) -> Optional[int]:
        """Register a source URL for a session.

        Returns:
            The new row id, or ``None`` when the (session, url) pair already
            exists or the insert failed. Failures are logged, not raised.
        """
        try:
            cursor = await self.db.execute(
                """INSERT OR IGNORE INTO sources (session_id, url, title, source_type, depth)
                   VALUES (?, ?, ?, ?, ?)""",
                (session_id, url, title, source_type, depth),
            )
            await self.db.commit()
        except Exception as exc:
            # Best-effort by design: callers treat None as "not added".
            logger.warning("add_source failed", session_id=session_id,
                           url=url, error=str(exc))
            return None
        # rowcount is 0 when OR IGNORE skipped a duplicate.
        return cursor.lastrowid if cursor.rowcount > 0 else None

    async def update_source(self, source_id: int, **kwargs):
        """Update the given columns of a source; no-op without keywords.

        Raises:
            ValueError: if a keyword is not a valid column identifier.
        """
        if not kwargs:
            # An empty SET clause would be a SQL syntax error.
            return
        sets = self._set_clause(kwargs)
        values = list(kwargs.values()) + [source_id]
        await self.db.execute(f"UPDATE sources SET {sets} WHERE id = ?", values)
        await self.db.commit()

    async def get_pending_sources(self, session_id: int, limit: int = 10) -> list[dict]:
        """Return up to ``limit`` pending sources, shallowest and oldest first."""
        cursor = await self.db.execute(
            """SELECT * FROM sources WHERE session_id = ? AND status = 'pending'
               ORDER BY depth ASC, id ASC LIMIT ?""",
            (session_id, limit),
        )
        rows = await cursor.fetchall()
        return [dict(r) for r in rows]

    async def get_all_sources(self, session_id: int) -> list[dict]:
        """Return every source of a session, highest quality score first."""
        cursor = await self.db.execute(
            "SELECT * FROM sources WHERE session_id = ? ORDER BY quality_score DESC",
            (session_id,),
        )
        rows = await cursor.fetchall()
        return [dict(r) for r in rows]

    async def source_exists(self, session_id: int, url: str) -> bool:
        """Return whether the session already has a source with this URL."""
        cursor = await self.db.execute(
            "SELECT 1 FROM sources WHERE session_id = ? AND url = ?",
            (session_id, url),
        )
        return await cursor.fetchone() is not None

    # --- Chunks ---

    async def add_chunk(self, session_id: int, source_id: int, content: str,
                        chunk_index: int, token_count: int, quality_score: float,
                        embedding: Optional[list] = None) -> int:
        """Store one content chunk and return its row id.

        ``embedding`` is serialized as a JSON array; a missing or empty
        embedding is stored as NULL.
        """
        cursor = await self.db.execute(
            """INSERT INTO chunks (session_id, source_id, content, chunk_index,
                                   token_count, quality_score, embedding, created_at)
               VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
            (session_id, source_id, content, chunk_index,
             token_count, quality_score,
             json.dumps(embedding) if embedding else None,
             time.time()),
        )
        await self.db.commit()
        return cursor.lastrowid

    async def get_top_chunks(self, session_id: int, limit: int = 50) -> list[dict]:
        """Return the best-scoring chunks of a session with source metadata.

        Only chunks scoring at least ``settings.quality_threshold`` are
        included; each dict also carries the source's ``url``, ``title`` and
        ``source_type``.
        """
        cursor = await self.db.execute(
            """SELECT c.*, s.url, s.title, s.source_type FROM chunks c
               JOIN sources s ON c.source_id = s.id
               WHERE c.session_id = ? AND c.quality_score >= ?
               ORDER BY c.quality_score DESC LIMIT ?""",
            (session_id, settings.quality_threshold, limit),
        )
        rows = await cursor.fetchall()
        return [dict(r) for r in rows]

    async def get_chunks_count(self, session_id: int) -> int:
        """Return the number of chunks stored for a session."""
        cursor = await self.db.execute(
            "SELECT COUNT(*) FROM chunks WHERE session_id = ?", (session_id,)
        )
        row = await cursor.fetchone()
        return row[0]

    # --- Outputs ---

    async def save_output(self, session_id: int, output_type: str, content: str) -> int:
        """Store a generated output for a session and return its row id."""
        cursor = await self.db.execute(
            "INSERT INTO outputs (session_id, output_type, content, created_at) VALUES (?, ?, ?, ?)",
            (session_id, output_type, content, time.time()),
        )
        await self.db.commit()
        return cursor.lastrowid

    async def save_source_content(self, source_id: int, content: str):
        """Store the content of a source, replacing any previous copy."""
        await self.db.execute(
            """INSERT OR REPLACE INTO source_contents (source_id, content, created_at)
               VALUES (?, ?, ?)""",
            (source_id, content, time.time()),
        )
        await self.db.commit()

    async def get_source_content(self, source_id: int) -> Optional[str]:
        """Return the stored content of a source, or ``None``."""
        cursor = await self.db.execute(
            "SELECT content FROM source_contents WHERE source_id = ?", (source_id,)
        )
        row = await cursor.fetchone()
        return row[0] if row else None

    async def get_cached_content(self, url: str,
                                 max_age_days: int = 7) -> Optional[str]:
        """Return the newest stored content for a URL across all sessions.

        Only content saved within the last ``max_age_days`` days counts;
        returns ``None`` when nothing recent enough exists.
        """
        threshold = time.time() - (max_age_days * 86400)
        async with self.db.execute(
            """SELECT sc.content FROM source_contents sc
               JOIN sources s ON s.id = sc.source_id
               WHERE s.url = ? AND sc.created_at > ?
               ORDER BY sc.created_at DESC LIMIT 1""",
            (url, threshold),
        ) as cur:
            row = await cur.fetchone()
        return row[0] if row else None

    async def get_outputs(self, session_id: int) -> list[dict]:
        """Return all outputs of a session, newest first."""
        cursor = await self.db.execute(
            "SELECT * FROM outputs WHERE session_id = ? ORDER BY created_at DESC",
            (session_id,),
        )
        rows = await cursor.fetchall()
        return [dict(r) for r in rows]

    # --- API Usage ---

    async def log_api_call(self, session_id, call_type: str, model: str,
                           input_tokens: int, output_tokens: int):
        """Record one API call together with its estimated cost in USD.

        The cost is computed from the class-level per-million-token rates.
        """
        cost = (input_tokens * self._INPUT_USD_PER_MTOK
                + output_tokens * self._OUTPUT_USD_PER_MTOK) / 1_000_000
        await self.db.execute(
            """INSERT INTO api_usage
               (session_id, call_type, model, input_tokens, output_tokens, cost_usd, created_at)
               VALUES (?,?,?,?,?,?,?)""",
            (session_id, call_type, model, input_tokens, output_tokens, cost, time.time()),
        )
        await self.db.commit()

    async def get_usage_stats(self, session_id: int) -> list[dict]:
        """Return per-``call_type`` call count, token total and cost for a session."""
        cursor = await self.db.execute(
            """SELECT call_type,
                      COUNT(*) as calls,
                      SUM(input_tokens + output_tokens) as total_tokens,
                      SUM(cost_usd) as total_cost
               FROM api_usage WHERE session_id = ?
               GROUP BY call_type""",
            (session_id,),
        )
        rows = await cursor.fetchall()
        return [dict(r) for r in rows]

    async def get_total_usage_stats(self) -> dict:
        """Return the number of distinct sessions with usage and the total cost.

        ``total_cost`` is 0 when no usage has been logged.
        """
        # SUM() over zero rows is NULL, hence the COALESCE.
        cursor = await self.db.execute(
            """SELECT COUNT(DISTINCT session_id) as sessions,
                      COALESCE(SUM(cost_usd), 0) as total_cost
               FROM api_usage"""
        )
        row = await cursor.fetchone()
        return dict(row) if row else {"sessions": 0, "total_cost": 0}

    # --- Watched Topics ---

    async def add_watch(self, topic: str, chat_id: int, interval_hours: int) -> int:
        """Create or replace the watch for (topic, chat) and return its id.

        The first run is scheduled ``interval_hours`` from now. Replacing an
        existing watch inserts a fresh row, so it gets a new id and its
        ``last_run_at`` and ``enabled`` revert to the column defaults.
        """
        now = time.time()
        cursor = await self.db.execute(
            """INSERT OR REPLACE INTO watched_topics
               (topic, chat_id, interval_hours, next_run_at, created_at)
               VALUES (?, ?, ?, ?, ?)""",
            (topic, chat_id, interval_hours, now + interval_hours * 3600, now),
        )
        await self.db.commit()
        return cursor.lastrowid

    async def remove_watch(self, topic: str, chat_id: int) -> bool:
        """Delete the watch for (topic, chat); return whether one existed."""
        cursor = await self.db.execute(
            "DELETE FROM watched_topics WHERE topic = ? AND chat_id = ?",
            (topic, chat_id),
        )
        await self.db.commit()
        return cursor.rowcount > 0

    async def list_watches(self, chat_id: int) -> list[dict]:
        """Return all watches of a chat, oldest first."""
        cursor = await self.db.execute(
            "SELECT * FROM watched_topics WHERE chat_id = ? ORDER BY created_at ASC",
            (chat_id,),
        )
        rows = await cursor.fetchall()
        return [dict(r) for r in rows]

    async def get_due_watches(self) -> list[dict]:
        """Return enabled watches whose ``next_run_at`` is not in the future."""
        cursor = await self.db.execute(
            "SELECT * FROM watched_topics WHERE enabled = 1 AND next_run_at <= ?",
            (time.time(),),
        )
        rows = await cursor.fetchall()
        return [dict(r) for r in rows]

    async def update_watch_run(self, watch_id: int):
        """Mark a watch as run now and schedule its next run.

        Does nothing when the watch no longer exists.
        """
        cursor = await self.db.execute(
            "SELECT interval_hours FROM watched_topics WHERE id = ?", (watch_id,)
        )
        row = await cursor.fetchone()
        if not row:
            return
        now = time.time()
        await self.db.execute(
            "UPDATE watched_topics SET last_run_at = ?, next_run_at = ? WHERE id = ?",
            (now, now + row[0] * 3600, watch_id),
        )
        await self.db.commit()

    # --- Maintenance ---

    async def purge_old_sessions(self, max_age_days: int = 30) -> dict:
        """Delete non-running sessions older than ``max_age_days`` and their data.

        Removes each session's source contents, chunks, outputs and sources.
        ``api_usage`` rows are kept but detached from the session
        (``session_id`` set to NULL) so that cost totals survive the purge.

        Returns:
            Deleted-row counts keyed by ``sessions``, ``sources``,
            ``chunks`` and ``outputs``.

        Raises:
            Any database error; pending deletions are rolled back first.
        """
        await self.db.execute("PRAGMA foreign_keys = ON")

        threshold = time.time() - max_age_days * 86400
        cursor = await self.db.execute(
            "SELECT id FROM research_sessions WHERE created_at < ? AND status != 'running'",
            (threshold,),
        )
        session_ids = [row[0] for row in await cursor.fetchall()]

        counts = {"sessions": 0, "sources": 0, "chunks": 0, "outputs": 0}

        try:
            # Children are removed before their parents so that the foreign
            # keys enabled above are never violated.
            for sid in session_ids:
                await self.db.execute(
                    "DELETE FROM source_contents WHERE source_id IN (SELECT id FROM sources WHERE session_id = ?)",
                    (sid,),
                )
                cur = await self.db.execute("DELETE FROM chunks WHERE session_id = ?", (sid,))
                counts["chunks"] += cur.rowcount
                cur = await self.db.execute("DELETE FROM outputs WHERE session_id = ?", (sid,))
                counts["outputs"] += cur.rowcount
                cur = await self.db.execute("DELETE FROM sources WHERE session_id = ?", (sid,))
                counts["sources"] += cur.rowcount
                # api_usage also references the session; without this the
                # session delete below fails its foreign-key check.
                await self.db.execute(
                    "UPDATE api_usage SET session_id = NULL WHERE session_id = ?", (sid,)
                )
                cur = await self.db.execute("DELETE FROM research_sessions WHERE id = ?", (sid,))
                counts["sessions"] += cur.rowcount
            await self.db.commit()
        except Exception:
            # Do not leave a half-purged transaction pending on the connection.
            await self.db.rollback()
            raise

        logger.info("Purged sessions older than days",
                    sessions=counts["sessions"], days=max_age_days)
        return counts
|