"""Database layer using asyncpg for PostgreSQL."""

import logging
import os
from typing import Optional

import asyncpg

log = logging.getLogger(__name__)


class Database:
    """Async wrapper around an asyncpg connection pool.

    The pool is created by :meth:`connect`; every query method acquires a
    connection from ``self._pool``, so ``connect()`` must be awaited before
    any other method is used.
    """

    def __init__(self) -> None:
        # Connection URL comes from DATABASE_URL, with a local default.
        self._url = os.getenv("DATABASE_URL", "postgresql://bot:bot@localhost:5432/polymarket")
        self._pool: Optional[asyncpg.Pool] = None

    async def connect(self) -> None:
        """Create the connection pool for the configured URL."""
        self._pool = await asyncpg.create_pool(self._url)
        log.info("Database connected")

    async def disconnect(self) -> None:
        """Close the connection pool if one is open.

        The pool reference is cleared after closing, so calling this more
        than once (or before ``connect``) is a no-op.
        """
        if self._pool is not None:
            await self._pool.close()
            self._pool = None

    async def run_migrations(self) -> None:
        """Execute ``schema.sql`` (located next to this module) in one call.

        An argument-less ``execute`` in asyncpg uses the simple query
        protocol, so the file may contain multiple SQL statements.
        """
        schema_path = os.path.join(os.path.dirname(__file__), "schema.sql")
        # Explicit encoding so reading the schema doesn't depend on host locale.
        with open(schema_path, encoding="utf-8") as f:
            schema = f.read()
        async with self._pool.acquire() as conn:
            await conn.execute(schema)
        log.info("Migrations applied")

    async def save_trade(self, trade) -> None:
        """Insert one row into ``trades`` from the attributes of ``trade``.

        A row whose ``id`` already exists is silently skipped
        (``ON CONFLICT (id) DO NOTHING``), so re-saving a trade is harmless.
        """
        async with self._pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO trades (
                    id, market_id, question, direction, size_usdc,
                    entry_price, shares, fee_usdc, net_cost, timestamp,
                    reasoning, paper,
                    edge_gross, edge_net, prior_prob, final_prob,
                    mid_price, spread_estimate, commission, family_key
                ) VALUES (
                    $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,
                    $13,$14,$15,$16,$17,$18,$19,$20
                )
                ON CONFLICT (id) DO NOTHING
                """,
                trade.id, trade.market_id, trade.question, trade.direction,
                trade.size_usdc, trade.entry_price, trade.shares,
                trade.fee_usdc, trade.net_cost, trade.timestamp,
                trade.reasoning, trade.paper,
                # Phase 1 fields
                trade.edge_gross, trade.edge_net, trade.prior_prob,
                trade.final_prob, trade.mid_price, trade.spread_estimate,
                trade.commission, trade.family_key,
            )

    async def save_daily_metrics(self, metrics: dict) -> None:
        """Insert one row into ``metrics_daily``.

        ``metrics`` must contain every key listed below; a missing key raises
        ``KeyError`` before anything is written.
        """
        async with self._pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO metrics_daily (
                    timestamp, total_trades, total_deployed, total_fees,
                    total_pnl, win_rate, avg_edge, sharpe_ratio,
                    calibration_score, paper_mode
                ) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10)
                """,
                metrics["timestamp"], metrics["total_trades"],
                metrics["total_deployed"], metrics["total_fees"],
                metrics["total_pnl"], metrics["win_rate"],
                metrics["avg_edge"], metrics["sharpe_ratio"],
                metrics["calibration_score"], metrics["paper_mode"],
            )

    async def get_open_positions(self) -> dict[str, float]:
        """Return {market_id: total_net_cost} for all open (not closed) trades in DB.

        ``COALESCE`` guards against markets whose open trades all have a NULL
        ``net_cost``: plain ``SUM`` would return NULL there and ``float(None)``
        would raise ``TypeError``.
        """
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(
                "SELECT market_id, COALESCE(SUM(net_cost), 0) AS total "
                "FROM trades WHERE closed_at IS NULL GROUP BY market_id"
            )
        return {r["market_id"]: float(r["total"]) for r in rows}

    async def get_open_families(self) -> set[str]:
        """Return the set of family_key values from all open positions.

        Used at startup to rebuild occupied_families from DB state so the
        family-deduplication logic survives pod restarts. Empty keys are
        dropped. Unlike ``get_open_position_details``, this does not filter on
        ``paper``.
        """
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(
                "SELECT DISTINCT family_key FROM trades "
                "WHERE family_key IS NOT NULL AND closed_at IS NULL"
            )
        return {r["family_key"] for r in rows if r["family_key"]}

    async def get_open_position_details(self) -> list[dict]:
        """Return one row per open paper position with family_key and direction.

        Used at startup to detect positions that share a family_key (same
        underlying event), which indicates a contradictory paper trade entered
        before the general-election family fix was deployed.

        ``DISTINCT ON (market_id)`` with ``ORDER BY market_id, timestamp DESC``
        keeps only the most recent open trade per market.
        """
        async with self._pool.acquire() as conn:
            rows = await conn.fetch("""
                SELECT DISTINCT ON (market_id)
                    market_id, question, direction, edge_net, family_key, timestamp
                FROM trades
                WHERE paper = TRUE AND closed_at IS NULL
                ORDER BY market_id, timestamp DESC
            """)
        return [dict(r) for r in rows]

    async def close_paper_position(self, market_id: str, reason: str = "") -> None:
        """Mark a paper position as closed.

        Sets ``closed_at = NOW()`` and ``close_reason = reason`` on every still
        open trade row for ``market_id``.
        """
        # NOTE(review): the WHERE clause does not filter on paper = TRUE, so
        # live trades for the same market would also be marked closed. Confirm
        # callers only use this for paper-only markets.
        async with self._pool.acquire() as conn:
            await conn.execute(
                "UPDATE trades SET closed_at = NOW(), close_reason = $2 "
                "WHERE market_id = $1 AND closed_at IS NULL",
                market_id, reason,
            )

    async def update_family_key(self, market_id: str, new_key: str) -> None:
        """Persist a corrected family_key for all open trades of a market."""
        async with self._pool.acquire() as conn:
            await conn.execute(
                "UPDATE trades SET family_key = $2 WHERE market_id = $1 AND closed_at IS NULL",
                market_id, new_key,
            )

    async def get_recently_closed_inverted(self, hours: int = 24) -> set[str]:
        """Return market_ids closed for inversion bug within the last N hours.

        Used as a reentry guard: prevents re-entering a market that was just
        closed because the signal direction was inverted. The match on
        ``close_reason`` is a case-insensitive substring match (``ILIKE``).
        """
        async with self._pool.acquire() as conn:
            # The interval is built as text ("<hours> hours") and cast, which
            # also accepts fractional hours.
            rows = await conn.fetch("""
                SELECT DISTINCT market_id FROM trades
                WHERE closed_at > NOW() - ($1 || ' hours')::interval
                  AND close_reason ILIKE '%inversion bug%'
            """, str(hours))
        return {r["market_id"] for r in rows}

    async def get_recent_trades(self, limit: int = 100, status: Optional[str] = None) -> list[dict]:
        """Return trades ordered by timestamp DESC.

        status: None (all) | "open" (closed_at IS NULL) | "closed" (closed_at IS NOT NULL)
        Any other value is treated like None (no filter).

        Each row includes a computed "status" field ("open" or "closed").
        """
        # ``where`` is chosen from fixed literals only, so the f-string below
        # never interpolates caller-supplied text.
        if status == "open":
            where = "WHERE closed_at IS NULL"
        elif status == "closed":
            where = "WHERE closed_at IS NOT NULL"
        else:
            where = ""
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(
                f"SELECT * FROM trades {where} ORDER BY timestamp DESC LIMIT $1",
                limit,
            )
        result = []
        for r in rows:
            d = dict(r)
            d["status"] = "closed" if d.get("closed_at") else "open"
            result.append(d)
        return result

    async def get_metrics_history(self, days: int = 42) -> list[dict]:
        """Return the most recent ``metrics_daily`` rows, newest first.

        ``days`` is applied as a row ``LIMIT``, not a calendar window: it
        equals days of history only if exactly one row is saved per day.
        """
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(
                "SELECT * FROM metrics_daily ORDER BY timestamp DESC LIMIT $1", days
            )
        return [dict(r) for r in rows]