diff --git a/src/bot/bot.py b/src/bot/bot.py index 31de065..eb48759 100644 --- a/src/bot/bot.py +++ b/src/bot/bot.py @@ -78,6 +78,7 @@ async def cmd_start(update: Update, ctx: ContextTypes.DEFAULT_TYPE): "`/generate ` — Generate output (podcast|blog|report|thread)\n" "`/sources` — List all sources found\n" "`/outputs` — List generated outputs\n" + "`/costs` — Show API usage costs\n" "`/cancel` — Cancel current research\n" "`/help` — Show this message", parse_mode=ParseMode.MARKDOWN @@ -419,6 +420,57 @@ async def cmd_outputs(update: Update, ctx: ContextTypes.DEFAULT_TYPE): await db_conn.close() +async def cmd_costs(update: Update, ctx: ContextTypes.DEFAULT_TYPE): + if not is_authorized(update.effective_user.id): + return + + chat_id = update.effective_chat.id + db_conn = await get_db() + db = ResearchDB(db_conn) + + try: + cursor = await db_conn.execute( + "SELECT * FROM research_sessions WHERE telegram_chat_id = ? ORDER BY created_at DESC LIMIT 1", + (chat_id,) + ) + row = await cursor.fetchone() + if not row: + await update.message.reply_text("No sessions found.") + return + + session_id = row["id"] + topic = row["topic"] + + by_type = {r["call_type"]: r for r in await db.get_usage_stats(session_id)} + totals = await db.get_total_usage_stats() + + lines = [f"📊 *Costes ResearchOwl*\n"] + lines.append(f"Última sesión (`{topic}`):") + + session_total = 0.0 + for call_type, label in [("scoring", "Scoring"), ("generation", "Generación")]: + row_data = by_type.get(call_type) + if row_data: + calls = row_data["calls"] + tokens = row_data["total_tokens"] + cost = row_data["total_cost"] + session_total += cost + lines.append(f" {label}: {calls} llamadas · {tokens:,} tokens · ${cost:.4f}") + else: + lines.append(f" {label}: —") + + lines.append(f" Total: ${session_total:.4f}") + lines.append("") + lines.append("Acumulado total:") + acc_cost = totals.get("total_cost") or 0.0 + acc_sessions = totals.get("sessions") or 0 + lines.append(f" ${acc_cost:.4f} ({acc_sessions} sesiones)") + + await update.message.reply_text("\n".join(lines), parse_mode=ParseMode.MARKDOWN) + finally: + await db_conn.close() + + async def cmd_process(update: Update, ctx: ContextTypes.DEFAULT_TYPE): if not is_authorized(update.effective_user.id): return @@ -576,6 +628,7 @@ def create_bot() -> Application: app.add_handler(CommandHandler("generate", cmd_generate)) app.add_handler(CommandHandler("sources", cmd_sources)) app.add_handler(CommandHandler("outputs", cmd_outputs)) + app.add_handler(CommandHandler("costs", cmd_costs)) app.add_handler(CommandHandler("process", cmd_process)) app.add_handler(CommandHandler("cancel", cmd_cancel)) app.add_handler(CommandHandler("purge", cmd_purge)) diff --git a/src/db/database.py b/src/db/database.py index c8561f8..cebdf6a 100644 --- a/src/db/database.py +++ b/src/db/database.py @@ -88,6 +88,17 @@ CREATE INDEX IF NOT EXISTS idx_sources_session ON sources(session_id); CREATE INDEX IF NOT EXISTS idx_chunks_session ON chunks(session_id); CREATE INDEX IF NOT EXISTS idx_chunks_quality ON chunks(session_id, quality_score DESC); CREATE INDEX IF NOT EXISTS idx_source_contents ON source_contents(source_id); + +CREATE TABLE IF NOT EXISTS api_usage ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id INTEGER REFERENCES research_sessions(id), + call_type TEXT NOT NULL, + model TEXT NOT NULL, + input_tokens INTEGER NOT NULL, + output_tokens INTEGER NOT NULL, + cost_usd REAL NOT NULL, + created_at REAL NOT NULL +); """ @@ -271,6 +282,43 @@ class ResearchDB: rows = await cursor.fetchall() return [dict(r) for r in rows] + # --- API Usage --- + + async def log_api_call(self, session_id, call_type: str, model: str, + input_tokens: int, output_tokens: int): + # Precios Claude Haiku (claude-haiku-4-5): + # input: $0.80 / 1M tokens output: $4.00 / 1M tokens + cost = (input_tokens * 0.80 + output_tokens * 4.00) / 1_000_000 + await self.db.execute( + """INSERT INTO api_usage + (session_id, call_type, model, input_tokens, output_tokens, cost_usd, created_at) + VALUES (?,?,?,?,?,?,?)""", + (session_id, call_type, model, input_tokens, output_tokens, cost, time.time()) + ) + await self.db.commit() + + async def get_usage_stats(self, session_id: int) -> list[dict]: + cursor = await self.db.execute( + """SELECT call_type, + COUNT(*) as calls, + SUM(input_tokens + output_tokens) as total_tokens, + SUM(cost_usd) as total_cost + FROM api_usage WHERE session_id = ? + GROUP BY call_type""", + (session_id,) + ) + rows = await cursor.fetchall() + return [dict(r) for r in rows] + + async def get_total_usage_stats(self) -> dict: + cursor = await self.db.execute( + """SELECT COUNT(DISTINCT session_id) as sessions, + SUM(cost_usd) as total_cost + FROM api_usage""" + ) + row = await cursor.fetchone() + return dict(row) if row else {"sessions": 0, "total_cost": 0} + # --- Maintenance --- async def purge_old_sessions(self, max_age_days: int = 30) -> dict: diff --git a/src/generator/generator.py b/src/generator/generator.py index 4d8d7ea..0039d70 100644 --- a/src/generator/generator.py +++ b/src/generator/generator.py @@ -179,7 +179,7 @@ class OutputGenerator: system = self._get_system(output_type) prompt = PROMPTS[output_type].format(topic=topic, context=context) - output = await self._generate(prompt, system, output_type) + output = await self._generate(prompt, system, output_type, session_id) # Add metadata header stats = await self.db.get_session_stats(session_id) @@ -192,12 +192,14 @@ class OutputGenerator: logger.info("Output generated", type=output_type, length=len(full_output)) return full_output - async def _generate(self, prompt: str, system: str, output_type: OutputType) -> str: + async def _generate(self, prompt: str, system: str, output_type: OutputType, + session_id: int | None = None) -> str: if settings.anthropic_api_key: - return await self._generate_with_claude(prompt, system, output_type) + return await self._generate_with_claude(prompt, system, output_type, session_id) return await self._generate_with_ollama(prompt, system) - async def _generate_with_claude(self, prompt: str, system: str, output_type: OutputType) -> str: + async def _generate_with_claude(self, prompt: str, system: str, output_type: OutputType, + session_id: int | None = None) -> str: import anthropic max_tokens = 4096 if output_type == OutputType.THREAD else 8192 try: @@ -208,6 +210,14 @@ class OutputGenerator: system=system, messages=[{"role": "user", "content": prompt}], ) + if session_id is not None: + try: + await self.db.log_api_call( + session_id, "generation", settings.claude_model, + msg.usage.input_tokens, msg.usage.output_tokens + ) + except Exception as log_err: + logger.warning("Failed to log API usage", error=str(log_err)) return msg.content[0].text.strip() except Exception as e: logger.warning("Claude generation failed, falling back to Ollama", error=str(e)) diff --git a/src/processor/processor.py b/src/processor/processor.py index 294ac0f..af7d8f3 100644 --- a/src/processor/processor.py +++ b/src/processor/processor.py @@ -182,7 +182,7 @@ class ContentProcessor: if words < 30: continue - quality = await self._score_quality(chunk, topic) + quality = await self._score_quality(chunk, topic, session_id) if quality < settings.quality_threshold: filtered_quality += 1 logger.debug("Chunk filtered by quality", source_id=source_id, @@ -215,13 +215,15 @@ class ContentProcessor: logger.info("Source processed", source_id=source_id, stored=stored) return stored - async def _score_quality(self, chunk: str, topic: str) -> float: + async def _score_quality(self, chunk: str, topic: str, + session_id: int | None = None) -> float: """Score 0-1 relevance to topic. Uses Claude Haiku if API key set, else Ollama.""" if settings.anthropic_api_key: - return await self._score_with_claude(chunk, topic) + return await self._score_with_claude(chunk, topic, session_id) return await self._score_with_ollama(chunk, topic) - async def _score_with_claude(self, chunk: str, topic: str) -> float: + async def _score_with_claude(self, chunk: str, topic: str, + session_id: int | None = None) -> float: import anthropic prompt = ( f'Rate 0-10 how relevant this text is to the topic "{topic}". ' @@ -234,6 +236,14 @@ class ContentProcessor: max_tokens=10, messages=[{"role": "user", "content": prompt}] ) + if session_id is not None: + try: + await self.db.log_api_call( + session_id, "scoring", settings.claude_model, + msg.usage.input_tokens, msg.usage.output_tokens + ) + except Exception as log_err: + logger.warning("Failed to log API usage", error=str(log_err)) response = msg.content[0].text.strip() numbers = re.findall(r'\b(\d+(?:\.\d+)?)\b', response) if numbers: