This commit is contained in:
@@ -0,0 +1,265 @@
|
||||
import aiosqlite
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
|
||||
from src.config import settings
|
||||
|
||||
|
||||
class ResearchStatus(str, Enum):
|
||||
RUNNING = "running"
|
||||
SATURATED = "saturated"
|
||||
FINISHED = "finished"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class OutputType(str, Enum):
|
||||
PODCAST = "podcast"
|
||||
BLOG = "blog"
|
||||
REPORT = "report"
|
||||
THREAD = "thread"
|
||||
|
||||
|
||||
SCHEMA = """
|
||||
CREATE TABLE IF NOT EXISTS research_sessions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
topic TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'running',
|
||||
telegram_chat_id INTEGER NOT NULL,
|
||||
telegram_message_id INTEGER,
|
||||
created_at REAL NOT NULL,
|
||||
updated_at REAL NOT NULL,
|
||||
iterations INTEGER DEFAULT 0,
|
||||
total_sources INTEGER DEFAULT 0,
|
||||
total_chunks INTEGER DEFAULT 0,
|
||||
total_words INTEGER DEFAULT 0,
|
||||
meta JSON DEFAULT '{}'
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS sources (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id INTEGER NOT NULL REFERENCES research_sessions(id),
|
||||
url TEXT NOT NULL,
|
||||
title TEXT,
|
||||
source_type TEXT, -- wikipedia, reddit, youtube, pdf, web, rss
|
||||
depth INTEGER DEFAULT 0,
|
||||
quality_score REAL DEFAULT 0,
|
||||
word_count INTEGER DEFAULT 0,
|
||||
scraped_at REAL,
|
||||
status TEXT DEFAULT 'pending', -- pending, scraped, failed, skipped
|
||||
error TEXT,
|
||||
UNIQUE(session_id, url)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS chunks (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id INTEGER NOT NULL REFERENCES research_sessions(id),
|
||||
source_id INTEGER NOT NULL REFERENCES sources(id),
|
||||
content TEXT NOT NULL,
|
||||
chunk_index INTEGER NOT NULL,
|
||||
token_count INTEGER,
|
||||
quality_score REAL DEFAULT 0,
|
||||
embedding JSON, -- stored as JSON array for sqlite-vec compat
|
||||
created_at REAL NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS outputs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
session_id INTEGER NOT NULL REFERENCES research_sessions(id),
|
||||
output_type TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_at REAL NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS source_contents (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
source_id INTEGER NOT NULL UNIQUE REFERENCES sources(id),
|
||||
content TEXT NOT NULL,
|
||||
created_at REAL NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_sources_session ON sources(session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_chunks_session ON chunks(session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_chunks_quality ON chunks(session_id, quality_score DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_source_contents ON source_contents(source_id);
|
||||
"""
|
||||
|
||||
|
||||
async def get_db() -> aiosqlite.Connection:
|
||||
Path(settings.db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
db = await aiosqlite.connect(settings.db_path)
|
||||
db.row_factory = aiosqlite.Row
|
||||
await db.executescript(SCHEMA)
|
||||
await db.commit()
|
||||
return db
|
||||
|
||||
|
||||
class ResearchDB:
|
||||
def __init__(self, db: aiosqlite.Connection):
|
||||
self.db = db
|
||||
|
||||
# --- Sessions ---
|
||||
|
||||
async def create_session(self, topic: str, chat_id: int) -> int:
|
||||
now = time.time()
|
||||
cursor = await self.db.execute(
|
||||
"""INSERT INTO research_sessions (topic, status, telegram_chat_id, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?)""",
|
||||
(topic, ResearchStatus.RUNNING, chat_id, now, now)
|
||||
)
|
||||
await self.db.commit()
|
||||
return cursor.lastrowid
|
||||
|
||||
async def get_session(self, session_id: int) -> Optional[dict]:
|
||||
cursor = await self.db.execute(
|
||||
"SELECT * FROM research_sessions WHERE id = ?", (session_id,)
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
async def get_active_session(self, chat_id: int) -> Optional[dict]:
|
||||
cursor = await self.db.execute(
|
||||
"""SELECT * FROM research_sessions
|
||||
WHERE telegram_chat_id = ? AND status = 'running'
|
||||
ORDER BY created_at DESC LIMIT 1""",
|
||||
(chat_id,)
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
async def update_session(self, session_id: int, **kwargs):
|
||||
kwargs["updated_at"] = time.time()
|
||||
sets = ", ".join(f"{k} = ?" for k in kwargs)
|
||||
values = list(kwargs.values()) + [session_id]
|
||||
await self.db.execute(
|
||||
f"UPDATE research_sessions SET {sets} WHERE id = ?", values
|
||||
)
|
||||
await self.db.commit()
|
||||
|
||||
async def get_session_stats(self, session_id: int) -> dict:
|
||||
cursor = await self.db.execute(
|
||||
"""SELECT
|
||||
COUNT(*) as total,
|
||||
SUM(CASE WHEN status='scraped' THEN 1 ELSE 0 END) as scraped,
|
||||
SUM(CASE WHEN status='failed' THEN 1 ELSE 0 END) as failed,
|
||||
SUM(CASE WHEN status='pending' THEN 1 ELSE 0 END) as pending
|
||||
FROM sources WHERE session_id = ?""",
|
||||
(session_id,)
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
return dict(row) if row else {}
|
||||
|
||||
# --- Sources ---
|
||||
|
||||
async def add_source(self, session_id: int, url: str, source_type: str,
|
||||
depth: int = 0, title: str = None) -> Optional[int]:
|
||||
try:
|
||||
cursor = await self.db.execute(
|
||||
"""INSERT OR IGNORE INTO sources (session_id, url, title, source_type, depth)
|
||||
VALUES (?, ?, ?, ?, ?)""",
|
||||
(session_id, url, title, source_type, depth)
|
||||
)
|
||||
await self.db.commit()
|
||||
return cursor.lastrowid if cursor.rowcount > 0 else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def update_source(self, source_id: int, **kwargs):
|
||||
sets = ", ".join(f"{k} = ?" for k in kwargs)
|
||||
values = list(kwargs.values()) + [source_id]
|
||||
await self.db.execute(f"UPDATE sources SET {sets} WHERE id = ?", values)
|
||||
await self.db.commit()
|
||||
|
||||
async def get_pending_sources(self, session_id: int, limit: int = 10) -> list[dict]:
|
||||
cursor = await self.db.execute(
|
||||
"""SELECT * FROM sources WHERE session_id = ? AND status = 'pending'
|
||||
ORDER BY depth ASC, id ASC LIMIT ?""",
|
||||
(session_id, limit)
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
async def get_all_sources(self, session_id: int) -> list[dict]:
|
||||
cursor = await self.db.execute(
|
||||
"SELECT * FROM sources WHERE session_id = ? ORDER BY quality_score DESC",
|
||||
(session_id,)
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
async def source_exists(self, session_id: int, url: str) -> bool:
|
||||
cursor = await self.db.execute(
|
||||
"SELECT 1 FROM sources WHERE session_id = ? AND url = ?",
|
||||
(session_id, url)
|
||||
)
|
||||
return await cursor.fetchone() is not None
|
||||
|
||||
# --- Chunks ---
|
||||
|
||||
async def add_chunk(self, session_id: int, source_id: int, content: str,
|
||||
chunk_index: int, token_count: int, quality_score: float,
|
||||
embedding: Optional[list] = None) -> int:
|
||||
cursor = await self.db.execute(
|
||||
"""INSERT INTO chunks (session_id, source_id, content, chunk_index,
|
||||
token_count, quality_score, embedding, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(session_id, source_id, content, chunk_index,
|
||||
token_count, quality_score,
|
||||
json.dumps(embedding) if embedding else None,
|
||||
time.time())
|
||||
)
|
||||
await self.db.commit()
|
||||
return cursor.lastrowid
|
||||
|
||||
async def get_top_chunks(self, session_id: int, limit: int = 50) -> list[dict]:
|
||||
cursor = await self.db.execute(
|
||||
"""SELECT c.*, s.url, s.title, s.source_type FROM chunks c
|
||||
JOIN sources s ON c.source_id = s.id
|
||||
WHERE c.session_id = ? AND c.quality_score >= ?
|
||||
ORDER BY c.quality_score DESC LIMIT ?""",
|
||||
(session_id, settings.quality_threshold, limit)
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
async def get_chunks_count(self, session_id: int) -> int:
|
||||
cursor = await self.db.execute(
|
||||
"SELECT COUNT(*) FROM chunks WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
return row[0]
|
||||
|
||||
# --- Outputs ---
|
||||
|
||||
async def save_output(self, session_id: int, output_type: str, content: str) -> int:
|
||||
cursor = await self.db.execute(
|
||||
"INSERT INTO outputs (session_id, output_type, content, created_at) VALUES (?, ?, ?, ?)",
|
||||
(session_id, output_type, content, time.time())
|
||||
)
|
||||
await self.db.commit()
|
||||
return cursor.lastrowid
|
||||
|
||||
async def save_source_content(self, source_id: int, content: str):
|
||||
await self.db.execute(
|
||||
"""INSERT OR REPLACE INTO source_contents (source_id, content, created_at)
|
||||
VALUES (?, ?, ?)""",
|
||||
(source_id, content, time.time())
|
||||
)
|
||||
await self.db.commit()
|
||||
|
||||
async def get_source_content(self, source_id: int) -> Optional[str]:
|
||||
cursor = await self.db.execute(
|
||||
"SELECT content FROM source_contents WHERE source_id = ?", (source_id,)
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
return row[0] if row else None
|
||||
|
||||
async def get_outputs(self, session_id: int) -> list[dict]:
|
||||
cursor = await self.db.execute(
|
||||
"SELECT * FROM outputs WHERE session_id = ? ORDER BY created_at DESC",
|
||||
(session_id,)
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
Reference in New Issue
Block a user