Files
chemavx adf2917cda
CI/CD / build-and-push (push) Successful in 1m52s
feat(attribution): dominant_feature per trade + /api/metrics/attribution
Adds alpha attribution by dominant signal feature — which feat_*_lo had
the largest absolute log-odds value on each trade.

Changes:
- _dominant_feature() helper in api/main.py: picks the winning feature
  from signal_components (threshold 0.0001, same as "triggered" in
  /api/metrics/features)
- _enrich_trade() refactored to single exit point; adds dominant_feature
  field to every open trade in /api/trades
- compute_attribution_from_db() in db.py: VALUES subquery finds dominant
  feature per trade in SQL, then aggregates trade_count/avg_edge_net/
  unrealized_pnl_est/realized_pnl/resolved_count/win_rate per group
- /api/metrics/attribution endpoint: returns attribution dict + total_attributed_trades

No schema changes, no strategy changes. Pure observability.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-22 16:35:24 +00:00

499 lines
25 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Database layer using asyncpg for PostgreSQL."""
import logging
import os
from typing import Optional
import asyncpg
# Module-level logger named after this module; used by the Database methods below.
log = logging.getLogger(__name__)
class Database:
    """Async PostgreSQL access layer for the ``trades`` and ``metrics_daily`` tables.

    One asyncpg connection pool is created by :meth:`connect` and shared by
    every query method. Calling a query method before :meth:`connect` (or
    after :meth:`disconnect`) raises ``RuntimeError``.
    """

    # Minimum number of resolved trades required before a win rate is reported.
    _MIN_RESOLVED_FOR_WIN_RATE = 5

    def __init__(self) -> None:
        """Read the connection URL from DATABASE_URL; no connection is opened yet."""
        self._url = os.getenv("DATABASE_URL", "postgresql://bot:bot@localhost:5432/polymarket")
        self._pool: Optional[asyncpg.Pool] = None

    def _require_pool(self) -> "asyncpg.Pool":
        """Return the live pool, or raise RuntimeError if connect() has not run."""
        if self._pool is None:
            raise RuntimeError("Database is not connected; call connect() first")
        return self._pool

    @classmethod
    def _win_rate(cls, wins: int, resolved: int) -> Optional[float]:
        """Return wins / resolved, or None when too few trades are resolved."""
        if resolved < cls._MIN_RESOLVED_FOR_WIN_RATE:
            return None
        return wins / resolved

    async def connect(self) -> None:
        """Create the connection pool for the configured URL."""
        self._pool = await asyncpg.create_pool(self._url)
        log.info("Database connected")

    async def disconnect(self) -> None:
        """Close the pool if one is open. Safe to call more than once."""
        # Detach the pool first so a second call (or a later query) sees
        # "not connected" instead of operating on a closed pool.
        pool, self._pool = self._pool, None
        if pool is not None:
            await pool.close()

    async def run_migrations(self) -> None:
        """Execute schema.sql (located next to this module) against the database."""
        schema_path = os.path.join(os.path.dirname(__file__), "schema.sql")
        # Explicit encoding: do not depend on the platform default locale.
        with open(schema_path, encoding="utf-8") as f:
            schema = f.read()
        async with self._require_pool().acquire() as conn:
            await conn.execute(schema)
        log.info("Migrations applied")

    async def save_trade(self, trade) -> None:
        """Insert one trade row; an existing row with the same id is left untouched.

        ``trade`` must expose the attributes read below (id, market_id, ...,
        feat_btc_dom_lo).
        """
        async with self._require_pool().acquire() as conn:
            await conn.execute(
                """
                INSERT INTO trades (
                    id, market_id, question, direction, size_usdc,
                    entry_price, shares, fee_usdc, net_cost, timestamp, reasoning, paper,
                    edge_gross, edge_net, prior_prob, final_prob,
                    mid_price, spread_estimate, commission, family_key,
                    feat_fg_lo, feat_mom_lo, feat_news_lo, feat_mfld_lo, feat_btc_dom_lo
                ) VALUES (
                    $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,
                    $13,$14,$15,$16,$17,$18,$19,$20,
                    $21,$22,$23,$24,$25
                )
                ON CONFLICT (id) DO NOTHING
                """,
                trade.id, trade.market_id, trade.question, trade.direction,
                trade.size_usdc, trade.entry_price, trade.shares, trade.fee_usdc,
                trade.net_cost, trade.timestamp, trade.reasoning, trade.paper,
                # Phase 1 fields
                trade.edge_gross, trade.edge_net, trade.prior_prob, trade.final_prob,
                trade.mid_price, trade.spread_estimate, trade.commission, trade.family_key,
                # Phase 6 feature log-odds
                trade.feat_fg_lo, trade.feat_mom_lo, trade.feat_news_lo,
                trade.feat_mfld_lo, trade.feat_btc_dom_lo,
            )

    async def save_daily_metrics(self, metrics: dict) -> None:
        """Append one snapshot row to metrics_daily.

        Raises KeyError if ``metrics`` lacks any of the keys read below.
        """
        async with self._require_pool().acquire() as conn:
            await conn.execute(
                """
                INSERT INTO metrics_daily (
                    timestamp, total_trades, total_deployed, total_fees,
                    unrealized_pnl_est, realized_pnl, total_pnl,
                    win_rate, avg_edge, sharpe_ratio, calibration_score, paper_mode,
                    open_count, closed_count, resolved_count
                ) VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15)
                """,
                metrics["timestamp"],
                metrics["total_trades"],
                metrics["total_deployed"],
                metrics["total_fees"],
                metrics["unrealized_pnl_est"],
                metrics["realized_pnl"],
                metrics["total_pnl"],
                metrics["win_rate"],
                metrics["avg_edge"],
                metrics["sharpe_ratio"],
                metrics["calibration_score"],
                metrics["paper_mode"],
                metrics["open_count"],
                metrics["closed_count"],
                metrics["resolved_count"],
            )

    async def get_open_positions(self) -> dict[str, float]:
        """Return {market_id: total_net_cost} for all open (not closed) trades in DB."""
        async with self._require_pool().acquire() as conn:
            rows = await conn.fetch(
                "SELECT market_id, SUM(net_cost) AS total "
                "FROM trades WHERE closed_at IS NULL GROUP BY market_id"
            )
        return {r["market_id"]: float(r["total"]) for r in rows}

    async def get_open_position_data(self) -> tuple[dict[str, float], float]:
        """Return (positions_by_size_usdc, total_net_cost) for all open trades.

        positions_by_size_usdc - {market_id: size_usdc}; mirrors what live
                                 trading stores in portfolio.positions
                                 (no fee included).
        total_net_cost         - SUM(net_cost) across all open trades, used to
                                 reconstruct cash = bankroll - total_net_cost.

        Together these let initialize() replicate the exact same accounting model
        that execute() uses at runtime, eliminating the phantom exposure overage
        caused by the old net_cost-in-positions approach.
        """
        async with self._require_pool().acquire() as conn:
            rows = await conn.fetch(
                "SELECT market_id, SUM(size_usdc) AS sz, SUM(net_cost) AS nc "
                "FROM trades WHERE closed_at IS NULL GROUP BY market_id"
            )
        positions = {r["market_id"]: float(r["sz"]) for r in rows}
        total_net_cost = sum(float(r["nc"]) for r in rows)
        return positions, total_net_cost

    async def get_open_families(self) -> set[str]:
        """Return the set of family_key values from all open positions.

        Used at startup to rebuild occupied_families from DB state so the
        family-deduplication logic survives pod restarts.
        """
        async with self._require_pool().acquire() as conn:
            rows = await conn.fetch(
                "SELECT DISTINCT family_key FROM trades "
                "WHERE family_key IS NOT NULL AND closed_at IS NULL"
            )
        # The truthiness filter also drops empty-string keys, which the SQL
        # NOT NULL check lets through.
        return {r["family_key"] for r in rows if r["family_key"]}

    async def get_open_position_details(self) -> list[dict]:
        """Return one row per open paper position with family_key and direction.

        For each market_id only the most recent open paper trade is returned
        (DISTINCT ON ordered by timestamp DESC).

        Used at startup to detect positions that share a family_key (same
        underlying event), which indicates a contradictory paper trade entered
        before the general-election family fix was deployed.
        """
        async with self._require_pool().acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT DISTINCT ON (market_id)
                    market_id, question, direction, edge_net, family_key, timestamp
                FROM trades
                WHERE paper = TRUE AND closed_at IS NULL
                ORDER BY market_id, timestamp DESC
                """
            )
        return [dict(r) for r in rows]

    async def close_paper_position(
        self, market_id: str, reason: str = "", resolution: Optional[float] = None
    ) -> None:
        """Mark a paper position as closed.

        resolution: 1.0 if YES resolved, 0.0 if NO resolved, None if unknown
        (legacy closes, inversion fixes). When resolution is provided, close_pnl
        is computed in SQL so it matches the stored entry_price and shares exactly.

        Note: the UPDATE matches every open trade for ``market_id``; it does
        not itself filter on the ``paper`` column.
        """
        async with self._require_pool().acquire() as conn:
            await conn.execute(
                """
                UPDATE trades
                SET closed_at = NOW(),
                    close_reason = $2,
                    resolution = $3,
                    close_pnl = CASE
                        WHEN $3 IS NOT NULL AND direction = 'BUY_YES'
                            THEN ($3::double precision - entry_price) * shares
                        WHEN $3 IS NOT NULL AND direction = 'BUY_NO'
                            THEN ((1.0 - $3::double precision) - entry_price) * shares
                        ELSE NULL
                    END
                WHERE market_id = $1 AND closed_at IS NULL
                """,
                market_id, reason, resolution,
            )

    async def update_family_key(self, market_id: str, new_key: str) -> None:
        """Persist a corrected family_key for all open trades of a market."""
        async with self._require_pool().acquire() as conn:
            await conn.execute(
                "UPDATE trades SET family_key = $2 WHERE market_id = $1 AND closed_at IS NULL",
                market_id, new_key,
            )

    async def get_legacy_incomplete_count(self) -> int:
        """Return count of open trades with NULL edge_net (legacy data without signal values)."""
        async with self._require_pool().acquire() as conn:
            row = await conn.fetchrow(
                "SELECT COUNT(*) FROM trades WHERE closed_at IS NULL AND edge_net IS NULL"
            )
        return int(row[0])

    async def get_recently_closed_inverted(self, hours: int = 24) -> set[str]:
        """Return market_ids closed for inversion bug within the last N hours.

        Used as a reentry guard: prevents re-entering a market that was just
        closed because the signal direction was inverted.
        """
        async with self._require_pool().acquire() as conn:
            # hours is passed as text because the SQL concatenates it with
            # ' hours' before casting the result to an interval.
            rows = await conn.fetch(
                """
                SELECT DISTINCT market_id FROM trades
                WHERE closed_at > NOW() - ($1 || ' hours')::interval
                  AND close_reason ILIKE '%inversion bug%'
                """,
                str(hours),
            )
        return {r["market_id"] for r in rows}

    async def compute_metrics_from_db(self) -> dict:
        """Compute all trading metrics directly from the trades table.

        This is the single source of truth for MetricsTracker - no in-memory
        state required. Safe to call after pod restarts: always reflects the
        full DB history.

        Returns a dict with keys:
            total_trades, open_count, closed_count, resolved_count,
            total_deployed, total_fees,
            unrealized_pnl_est - estimated, open trades with edge_net
            realized_pnl       - exact, closed trades with close_pnl
            wins_realized      - closed trades where close_pnl > 0
            calibration_score  - Brier-based (1 - MSE), NULL if resolved < 10
        """
        async with self._require_pool().acquire() as conn:
            row = await conn.fetchrow(
                """
                SELECT
                    COUNT(*) AS total_trades,
                    COUNT(*) FILTER (WHERE closed_at IS NULL) AS open_count,
                    COUNT(*) FILTER (WHERE closed_at IS NOT NULL) AS closed_count,
                    COUNT(*) FILTER (WHERE resolution IS NOT NULL
                                       AND final_prob IS NOT NULL) AS resolved_count,
                    COALESCE(SUM(net_cost)
                        FILTER (WHERE closed_at IS NULL), 0) AS total_deployed,
                    COALESCE(SUM(fee_usdc), 0) AS total_fees,
                    -- Estimated unrealized PnL: open trades with known edge.
                    -- Formula: edge_net * net_cost - fee_usdc.
                    -- Trades with NULL edge_net (legacy data) are excluded.
                    COALESCE(SUM(edge_net * net_cost - fee_usdc)
                        FILTER (WHERE closed_at IS NULL
                                  AND edge_net IS NOT NULL), 0) AS unrealized_pnl_est,
                    -- Realized PnL: closed trades with a known resolution.
                    -- close_pnl is computed at close time from actual resolution.
                    COALESCE(SUM(close_pnl)
                        FILTER (WHERE closed_at IS NOT NULL
                                  AND close_pnl IS NOT NULL), 0) AS realized_pnl,
                    COUNT(*) FILTER (WHERE closed_at IS NOT NULL
                                       AND close_pnl IS NOT NULL
                                       AND close_pnl > 0) AS wins_realized,
                    -- Calibration (Brier score transformed to higher-is-better):
                    -- 1 - AVG((final_prob - resolution)^2) on resolved trades.
                    -- final_prob is the model's estimated YES probability at entry.
                    -- resolution is 1.0 (YES won) or 0.0 (NO won).
                    -- Perfect calibration -> 1.0 | Random -> ~0.75 | Worst -> 0.0
                    -- Returns NULL if fewer than 10 resolved trades with final_prob.
                    CASE
                        WHEN COUNT(*) FILTER (WHERE resolution IS NOT NULL
                                                AND final_prob IS NOT NULL) >= 10
                        THEN 1.0 - AVG((final_prob - resolution) * (final_prob - resolution))
                                 FILTER (WHERE resolution IS NOT NULL
                                           AND final_prob IS NOT NULL)
                        ELSE NULL
                    END AS calibration_score
                FROM trades
                """
            )
        return dict(row)

    async def get_recent_trades(self, limit: int = 100, status: Optional[str] = None) -> list[dict]:
        """Return trades ordered by timestamp DESC.

        status: None (all) | "open" (closed_at IS NULL) | "closed" (closed_at IS NOT NULL).
        Any other value is treated like None.
        Each row includes a computed "status" field ("open" or "closed").
        """
        # The WHERE fragment comes from a fixed set of literals, never from
        # caller-supplied text, so interpolating it into the SQL is safe.
        if status == "open":
            where = "WHERE closed_at IS NULL"
        elif status == "closed":
            where = "WHERE closed_at IS NOT NULL"
        else:
            where = ""
        async with self._require_pool().acquire() as conn:
            rows = await conn.fetch(
                f"SELECT * FROM trades {where} ORDER BY timestamp DESC LIMIT $1", limit
            )
        trades = []
        for r in rows:
            d = dict(r)
            d["status"] = "closed" if d.get("closed_at") else "open"
            trades.append(d)
        return trades

    async def get_metrics_history(self, days: int = 42) -> list[dict]:
        """Return the newest ``days`` rows of metrics_daily, newest first.

        ``days`` is applied as a row LIMIT, not as a date range.
        """
        async with self._require_pool().acquire() as conn:
            rows = await conn.fetch(
                "SELECT * FROM metrics_daily ORDER BY timestamp DESC LIMIT $1", days
            )
        return [dict(r) for r in rows]

    async def backfill_feature_columns(self) -> int:
        """Back-populate feat_*_lo for trades created before Phase 6.

        Parses the reasoning string (format: 'fg=+0.0600 mom=... news=... mfld=...').
        fg / mom raw values are multiplied by 2 to convert to log-odds.
        news / mfld are already in log-odds (no scaling).
        feat_btc_dom_lo cannot be recovered from the old reasoning string and
        remains NULL for legacy trades.

        Returns the number of rows updated.
        """
        async with self._require_pool().acquire() as conn:
            result = await conn.execute(
                """
                UPDATE trades
                SET
                    feat_fg_lo      = ((regexp_match(reasoning, 'fg=([^ |]+)'))[1])::DOUBLE PRECISION * 2,
                    feat_mom_lo     = ((regexp_match(reasoning, 'mom=([^ |]+)'))[1])::DOUBLE PRECISION * 2,
                    feat_news_lo    = ((regexp_match(reasoning, 'news=([^ |]+)'))[1])::DOUBLE PRECISION,
                    feat_mfld_lo    = ((regexp_match(reasoning, 'mfld=([^ |]+)'))[1])::DOUBLE PRECISION,
                    feat_btc_dom_lo = NULL
                WHERE feat_fg_lo IS NULL
                  AND reasoning IS NOT NULL
                  AND reasoning LIKE '%fg=%'
                  AND reasoning NOT LIKE '%fg_lo=%'
                """
            )
        # execute() returns the command status string (e.g. "UPDATE 3");
        # the trailing token is the affected-row count.
        updated = int(result.split()[-1]) if result else 0
        if updated:
            log.info("backfill_feature_columns: updated %d trade(s)", updated)
        return updated

    async def get_legacy_incomplete_trades(self) -> list[dict]:
        """Return trades with NULL edge_net - pre-Phase-1 data with no signal quality info.

        Includes both open and closed trades, newest first.
        """
        async with self._require_pool().acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT id, market_id, question, direction, net_cost, entry_price,
                       timestamp, reasoning, closed_at, close_reason, family_key,
                       feat_fg_lo, feat_mom_lo, feat_news_lo, feat_mfld_lo, feat_btc_dom_lo
                FROM trades
                WHERE edge_net IS NULL
                ORDER BY timestamp DESC
                """
            )
        return [dict(r) for r in rows]

    async def compute_feature_metrics_from_db(self) -> dict:
        """Per-feature performance metrics, all in log-odds space.

        For each feature (fg, mom, news, mfld, btc_dom) returns:
            unit                       - always "log_odds"
            materiality_threshold      - |lo| threshold for "material" classification
            triggered_count            - trades where |feat_lo| > 0.0001
            material_count             - trades where |feat_lo| >= materiality_threshold
            avg_contribution_lo        - mean signed lo value (triggered trades)
            avg_abs_contribution_lo    - mean absolute lo value (triggered trades)
            avg_edge_net_when_material - mean edge_net for material trades
            unrealized_pnl_est         - sum of edge_net*net_cost - fee for triggered open trades
            realized_pnl               - sum close_pnl for triggered resolved trades
            resolved_count             - triggered trades with a non-NULL close_pnl
            win_rate                   - None if resolved_count < 5
            net_positive_count         - triggered trades where feat_lo > 0
            net_negative_count         - triggered trades where feat_lo < 0

        A feature with no non-NULL rows in the table is absent from the result.
        """
        async with self._require_pool().acquire() as conn:
            rows = await conn.fetch(
                """
                WITH feature_values AS (
                    SELECT 'fg' AS feature,
                           0.05::DOUBLE PRECISION AS mat_thresh,
                           feat_fg_lo AS fval,
                           edge_net, net_cost, fee_usdc, closed_at, close_pnl
                    FROM trades WHERE feat_fg_lo IS NOT NULL
                    UNION ALL
                    SELECT 'mom', 0.05, feat_mom_lo,
                           edge_net, net_cost, fee_usdc, closed_at, close_pnl
                    FROM trades WHERE feat_mom_lo IS NOT NULL
                    UNION ALL
                    SELECT 'news', 0.10, feat_news_lo,
                           edge_net, net_cost, fee_usdc, closed_at, close_pnl
                    FROM trades WHERE feat_news_lo IS NOT NULL
                    UNION ALL
                    SELECT 'mfld', 0.10, feat_mfld_lo,
                           edge_net, net_cost, fee_usdc, closed_at, close_pnl
                    FROM trades WHERE feat_mfld_lo IS NOT NULL
                    UNION ALL
                    SELECT 'btc_dom', 0.05, feat_btc_dom_lo,
                           edge_net, net_cost, fee_usdc, closed_at, close_pnl
                    FROM trades WHERE feat_btc_dom_lo IS NOT NULL
                )
                SELECT
                    feature,
                    mat_thresh AS materiality_threshold,
                    COUNT(*) FILTER (WHERE ABS(fval) > 0.0001) AS triggered_count,
                    COUNT(*) FILTER (WHERE ABS(fval) >= mat_thresh) AS material_count,
                    AVG(fval) FILTER (WHERE ABS(fval) > 0.0001) AS avg_contribution_lo,
                    AVG(ABS(fval)) FILTER (WHERE ABS(fval) > 0.0001) AS avg_abs_contribution_lo,
                    AVG(edge_net) FILTER (WHERE ABS(fval) >= mat_thresh
                                            AND edge_net IS NOT NULL) AS avg_edge_net_when_material,
                    COALESCE(SUM(edge_net * net_cost - fee_usdc)
                        FILTER (WHERE ABS(fval) > 0.0001
                                  AND closed_at IS NULL
                                  AND edge_net IS NOT NULL), 0) AS unrealized_pnl_est,
                    COALESCE(SUM(close_pnl)
                        FILTER (WHERE ABS(fval) > 0.0001
                                  AND close_pnl IS NOT NULL), 0) AS realized_pnl,
                    COUNT(*) FILTER (WHERE ABS(fval) > 0.0001
                                       AND close_pnl IS NOT NULL
                                       AND close_pnl > 0) AS wins_realized,
                    COUNT(*) FILTER (WHERE ABS(fval) > 0.0001
                                       AND close_pnl IS NOT NULL) AS resolved_count,
                    COUNT(*) FILTER (WHERE fval > 0.0001) AS net_positive_count,
                    COUNT(*) FILTER (WHERE fval < -0.0001) AS net_negative_count
                FROM feature_values
                GROUP BY feature, mat_thresh
                ORDER BY feature
                """
            )
        result: dict[str, dict] = {}
        for r in rows:
            d = dict(r)
            resolved = int(d.get("resolved_count") or 0)
            wins = int(d.get("wins_realized") or 0)
            result[d["feature"]] = {
                "unit": "log_odds",
                "materiality_threshold": float(d["materiality_threshold"]),
                "triggered_count": int(d.get("triggered_count") or 0),
                "material_count": int(d.get("material_count") or 0),
                "avg_contribution_lo": _f(d.get("avg_contribution_lo")),
                "avg_abs_contribution_lo": _f(d.get("avg_abs_contribution_lo")),
                "avg_edge_net_when_material": _f(d.get("avg_edge_net_when_material")),
                "unrealized_pnl_est": float(d.get("unrealized_pnl_est") or 0),
                "realized_pnl": float(d.get("realized_pnl") or 0),
                "resolved_count": resolved,
                "win_rate": self._win_rate(wins, resolved),
                "net_positive_count": int(d.get("net_positive_count") or 0),
                "net_negative_count": int(d.get("net_negative_count") or 0),
            }
        return result

    async def compute_attribution_from_db(self) -> dict:
        """Alpha attribution grouped by dominant signal feature.

        For each Phase 6 trade (feat_fg_lo IS NOT NULL), the dominant feature
        is the feat_*_lo with the largest absolute value (> 0.0001). Trades are
        then aggregated per group.

        Returns {feature_name: {trade_count, avg_edge_net, unrealized_pnl_est,
        realized_pnl, resolved_count, win_rate}}; win_rate is None when fewer
        than 5 trades in the group are resolved.
        The "none" group collects trades where all features are below threshold.
        """
        async with self._require_pool().acquire() as conn:
            # The ord column breaks ties between equal |values| so the chosen
            # feature is deterministic (ORDER BY val alone leaves the winner
            # unspecified). Listed order wins.
            # NOTE(review): confirm this matches the tie-break used by
            # _dominant_feature() in api/main.py.
            rows = await conn.fetch(
                """
                WITH dominant_per_trade AS (
                    SELECT
                        edge_net, net_cost, fee_usdc, closed_at, close_pnl,
                        (
                            SELECT key
                            FROM (VALUES
                                ('fg',      1, ABS(COALESCE(feat_fg_lo, 0))),
                                ('mom',     2, ABS(COALESCE(feat_mom_lo, 0))),
                                ('news',    3, ABS(COALESCE(feat_news_lo, 0))),
                                ('mfld',    4, ABS(COALESCE(feat_mfld_lo, 0))),
                                ('btc_dom', 5, ABS(COALESCE(feat_btc_dom_lo, 0)))
                            ) AS t(key, ord, val)
                            WHERE val > 0.0001
                            ORDER BY val DESC, ord
                            LIMIT 1
                        ) AS dominant
                    FROM trades
                    WHERE feat_fg_lo IS NOT NULL
                )
                SELECT
                    COALESCE(dominant, 'none') AS dominant_feature,
                    COUNT(*) AS trade_count,
                    AVG(edge_net) AS avg_edge_net,
                    COALESCE(SUM(edge_net * net_cost - fee_usdc)
                        FILTER (WHERE closed_at IS NULL
                                  AND edge_net IS NOT NULL), 0) AS unrealized_pnl_est,
                    COALESCE(SUM(close_pnl)
                        FILTER (WHERE close_pnl IS NOT NULL), 0) AS realized_pnl,
                    COUNT(*) FILTER (WHERE close_pnl IS NOT NULL) AS resolved_count,
                    COUNT(*) FILTER (WHERE close_pnl IS NOT NULL AND close_pnl > 0) AS wins
                FROM dominant_per_trade
                GROUP BY dominant_feature
                ORDER BY trade_count DESC
                """
            )
        result: dict[str, dict] = {}
        for r in rows:
            d = dict(r)
            resolved = int(d.get("resolved_count") or 0)
            wins = int(d.get("wins") or 0)
            result[d["dominant_feature"]] = {
                "trade_count": int(d["trade_count"]),
                "avg_edge_net": _f(d.get("avg_edge_net")),
                "unrealized_pnl_est": float(d.get("unrealized_pnl_est") or 0),
                "realized_pnl": float(d.get("realized_pnl") or 0),
                "resolved_count": resolved,
                "win_rate": self._win_rate(wins, resolved),
            }
        return result
def _f(v) -> Optional[float]:
    """Convert *v* to float, passing None through unchanged.

    Used on aggregate columns returned by asyncpg, which may be NULL or a
    non-float numeric type.
    """
    if v is None:
        return None
    return float(v)