|
|
""" |
|
|
main.py |
|
|
Production-Ready Real-Time Intelligence Platform Backend |
|
|
- Uses combinedAgentGraph for multi-agent orchestration |
|
|
- Threading for concurrent graph execution and WebSocket server |
|
|
- Database-driven feed updates with polling |
|
|
- Duplicate prevention |
|
|
- District-based feed categorization for map display |
|
|
|
|
|
Updated: Resilient WebSocket handling for long scraping operations (60s+ cycles) |
|
|
""" |
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from pydantic import BaseModel |
|
|
from typing import Dict, Any, List, Set, Optional |
|
|
import asyncio |
|
|
import json |
|
|
from datetime import datetime, timedelta, timezone |
|
|
import sys |
|
|
import os |
|
|
import logging |
|
|
import threading |
|
|
import time |
|
|
import uuid |
|
|
|
|
|
|
|
|
def utc_now() -> datetime: |
|
|
"""Return current UTC time (Python 3.12+ compatible).""" |
|
|
return datetime.now(timezone.utc) |
|
|
|
|
|
|
|
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) |
|
|
|
|
|
from src.graphs.combinedAgentGraph import graph |
|
|
from src.states.combinedAgentState import CombinedAgentState |
|
|
from src.storage.storage_manager import StorageManager |
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
logger = logging.getLogger("Roger_api") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_and_train_models(): |
|
|
""" |
|
|
Check if ML models are trained. If not, trigger training in background. |
|
|
Called on startup to ensure models are available. |
|
|
""" |
|
|
from pathlib import Path |
|
|
import subprocess |
|
|
|
|
|
PROJECT_ROOT = Path(__file__).parent |
|
|
|
|
|
|
|
|
model_checks = [ |
|
|
{ |
|
|
"name": "Anomaly Detection", |
|
|
"check_paths": [ |
|
|
PROJECT_ROOT / "models" / "anomaly-detection" / "artifacts" / "models", |
|
|
], |
|
|
"check_files": ["*.joblib", "*.pkl"], |
|
|
"train_cmd": [ |
|
|
sys.executable, |
|
|
str(PROJECT_ROOT / "models" / "anomaly-detection" / "main.py") |
|
|
] |
|
|
}, |
|
|
{ |
|
|
"name": "Weather Prediction", |
|
|
"check_paths": [ |
|
|
PROJECT_ROOT / "models" / "weather-prediction" / "artifacts" / "models", |
|
|
], |
|
|
"check_files": ["*.h5", "*.keras"], |
|
|
"train_cmd": [ |
|
|
sys.executable, |
|
|
str(PROJECT_ROOT / "models" / "weather-prediction" / "main.py"), |
|
|
"--mode", "full" |
|
|
] |
|
|
}, |
|
|
{ |
|
|
"name": "Currency Prediction", |
|
|
"check_paths": [ |
|
|
PROJECT_ROOT / "models" / "currency-volatility-prediction" |
|
|
/ "artifacts" / "models", |
|
|
], |
|
|
"check_files": ["*.h5", "*.keras"], |
|
|
"train_cmd": [ |
|
|
sys.executable, |
|
|
str(PROJECT_ROOT / "models" / "currency-volatility-prediction" |
|
|
/ "main.py"), |
|
|
"--mode", "full" |
|
|
] |
|
|
}, |
|
|
{ |
|
|
"name": "Stock Prediction", |
|
|
"check_paths": [ |
|
|
PROJECT_ROOT / "models" / "stock-price-prediction" |
|
|
/ "Artifacts", |
|
|
], |
|
|
"check_files": ["*.pkl", "*.h5", "*.keras"], |
|
|
"train_cmd": [ |
|
|
sys.executable, |
|
|
str(PROJECT_ROOT / "models" / "stock-price-prediction" |
|
|
/ "main.py") |
|
|
] |
|
|
}, |
|
|
] |
|
|
|
|
|
def has_trained_model(check_paths, check_files): |
|
|
"""Check if any trained model files exist.""" |
|
|
for path in check_paths: |
|
|
if path.exists(): |
|
|
for pattern in check_files: |
|
|
if list(path.glob(pattern)): |
|
|
return True |
|
|
|
|
|
if list(path.glob(f"**/{pattern}")): |
|
|
return True |
|
|
return False |
|
|
|
|
|
def train_in_background(name, cmd): |
|
|
"""Run training in a background thread.""" |
|
|
def _train(): |
|
|
logger.info(f"[AUTO-TRAIN] Starting {name} training...") |
|
|
try: |
|
|
result = subprocess.run( |
|
|
cmd, |
|
|
cwd=str(PROJECT_ROOT), |
|
|
capture_output=True, |
|
|
text=True, |
|
|
timeout=1800 |
|
|
) |
|
|
if result.returncode == 0: |
|
|
logger.info(f"[AUTO-TRAIN] ✓ {name} training complete!") |
|
|
else: |
|
|
logger.warning(f"[AUTO-TRAIN] ⚠ {name} training failed: {result.stderr[:500]}") |
|
|
except subprocess.TimeoutExpired: |
|
|
logger.error(f"[AUTO-TRAIN] ✗ {name} training timed out (30 min)") |
|
|
except Exception as e: |
|
|
logger.error(f"[AUTO-TRAIN] ✗ {name} training error: {e}") |
|
|
|
|
|
thread = threading.Thread(target=_train, daemon=True, name=f"train_{name}") |
|
|
thread.start() |
|
|
return thread |
|
|
|
|
|
|
|
|
training_threads = [] |
|
|
for model in model_checks: |
|
|
if has_trained_model(model["check_paths"], model["check_files"]): |
|
|
logger.info(f"[MODEL CHECK] ✓ {model['name']} - Model found") |
|
|
else: |
|
|
logger.warning(f"[MODEL CHECK] ⚠ {model['name']} - No model found, starting training...") |
|
|
thread = train_in_background(model["name"], model["train_cmd"]) |
|
|
training_threads.append((model["name"], thread)) |
|
|
|
|
|
if training_threads: |
|
|
logger.info(f"[AUTO-TRAIN] Started {len(training_threads)} background training jobs") |
|
|
else: |
|
|
logger.info("[MODEL CHECK] All models found - no training needed") |
|
|
|
|
|
return training_threads |
|
|
|
|
|
|
|
|
|
|
|
logger.info("=" * 60) |
|
|
logger.info("[STARTUP] Checking ML models...") |
|
|
logger.info("=" * 60) |
|
|
_training_threads = check_and_train_models() |
|
|
|
|
|
app = FastAPI(title="Roger Intelligence Platform API") |
|
|
|
|
|
app.add_middleware( |
|
|
CORSMiddleware, |
|
|
allow_origins=["*"], |
|
|
allow_credentials=True, |
|
|
allow_methods=["*"], |
|
|
allow_headers=["*"], |
|
|
) |
|
|
|
|
|
|
|
|
current_state: Dict[str, Any] = { |
|
|
"final_ranked_feed": [], |
|
|
"risk_dashboard_snapshot": { |
|
|
"logistics_friction": 0.0, |
|
|
"compliance_volatility": 0.0, |
|
|
"market_instability": 0.0, |
|
|
"opportunity_index": 0.0, |
|
|
"avg_confidence": 0.0, |
|
|
"high_priority_count": 0, |
|
|
"total_events": 0, |
|
|
"last_updated": utc_now().isoformat() |
|
|
}, |
|
|
"run_count": 0, |
|
|
"status": "initializing", |
|
|
"first_run_complete": False |
|
|
} |
|
|
|
|
|
|
|
|
feed_update_queue = asyncio.Queue() |
|
|
seen_event_ids: Set[str] = set() |
|
|
|
|
|
|
|
|
main_event_loop = None |
|
|
|
|
|
|
|
|
storage_manager = StorageManager() |
|
|
|
|
|
|
|
|
|
|
|
HEARTBEAT_INTERVAL = 60.0 |
|
|
HEARTBEAT_TIMEOUT = 45.0 |
|
|
HEARTBEAT_MISS_THRESHOLD = 5 |
|
|
SEND_TIMEOUT = 15.0 |
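# With these values, a connected but silent client (one that never answers pings)
# accumulates one miss roughly every HEARTBEAT_INTERVAL + HEARTBEAT_TIMEOUT = 105s,
# so it is only dropped after about 5 * 105s (roughly nine minutes), comfortably
# longer than the 60s+ scraping cycles this module is designed to survive.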
|
|
|
|
|
class ConnectionManager: |
|
|
"""Manages active WebSocket with heartbeat""" |
|
|
def __init__(self): |
|
|
self.active_connections: Dict[WebSocket, Dict[str, Any]] = {} |
|
|
self._lock = asyncio.Lock() |
|
|
|
|
|
async def connect(self, websocket: WebSocket): |
|
|
await websocket.accept() |
|
|
async with self._lock: |
|
|
meta = { |
|
|
"heartbeat_task": asyncio.create_task(self._heartbeat_loop(websocket)), |
|
|
"last_pong": utc_now(), |
|
|
"misses": 0 |
|
|
} |
|
|
self.active_connections[websocket] = meta |
|
|
logger.info(f"[WebSocket] Connected. Total: {len(self.active_connections)}") |
|
|
|
|
|
async def disconnect(self, websocket: WebSocket): |
|
|
async with self._lock: |
|
|
meta = self.active_connections.pop(websocket, None) |
|
|
if meta: |
|
|
task = meta.get("heartbeat_task") |
|
|
if task and not task.done(): |
|
|
task.cancel() |
|
|
try: |
|
|
await task |
|
|
except asyncio.CancelledError: |
|
|
pass |
|
|
try: |
|
|
await websocket.close() |
|
|
except Exception: |
|
|
pass |
|
|
logger.info(f"[WebSocket] Disconnected. Total: {len(self.active_connections)}") |
|
|
|
|
|
async def _send_with_timeout(self, websocket: WebSocket, message_json: str): |
|
|
try: |
|
|
await asyncio.wait_for(websocket.send_text(message_json), timeout=SEND_TIMEOUT) |
|
|
return True |
|
|
except Exception as e: |
|
|
logger.debug(f"[WebSocket] Send failed: {e}") |
|
|
return False |
|
|
|
|
|
async def _heartbeat_loop(self, websocket: WebSocket): |
|
|
"""Per-connection heartbeat task""" |
|
|
try: |
|
|
while True: |
|
|
await asyncio.sleep(HEARTBEAT_INTERVAL) |
|
|
if websocket not in self.active_connections: |
|
|
break |
|
|
|
|
|
ping_payload = json.dumps({"type": "ping"}) |
|
|
ok = await self._send_with_timeout(websocket, ping_payload) |
|
|
if not ok: |
|
|
async with self._lock: |
|
|
meta = self.active_connections.get(websocket) |
|
|
if meta is not None: |
|
|
meta['misses'] += 1 |
|
|
else: |
|
|
waited = 0.0 |
|
|
sleep_step = 0.5 |
|
|
pong_received = False |
|
|
while waited < HEARTBEAT_TIMEOUT: |
|
|
await asyncio.sleep(sleep_step) |
|
|
waited += sleep_step |
|
|
async with self._lock: |
|
|
meta = self.active_connections.get(websocket) |
|
|
if meta is None: |
|
|
return |
|
|
last_pong = meta.get("last_pong") |
|
|
if last_pong and (utc_now() - last_pong).total_seconds() < (HEARTBEAT_INTERVAL + HEARTBEAT_TIMEOUT): |
|
|
pong_received = True |
|
|
meta['misses'] = 0 |
|
|
break |
|
|
if not pong_received: |
|
|
async with self._lock: |
|
|
meta = self.active_connections.get(websocket) |
|
|
if meta is not None: |
|
|
meta['misses'] += 1 |
|
|
|
|
|
async with self._lock: |
|
|
meta = self.active_connections.get(websocket) |
|
|
if meta is None: |
|
|
return |
|
|
if meta.get('misses', 0) >= HEARTBEAT_MISS_THRESHOLD: |
|
|
logger.warning("[WebSocket] Miss threshold exceeded, disconnecting") |
|
|
try: |
|
|
await websocket.close(code=1001) |
|
|
except Exception: |
|
|
pass |
|
|
await self.disconnect(websocket) |
|
|
return |
|
|
|
|
|
except asyncio.CancelledError: |
|
|
return |
|
|
except Exception as e: |
|
|
logger.exception(f"[WebSocket] Heartbeat error: {e}") |
|
|
try: |
|
|
await self.disconnect(websocket) |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
async def broadcast(self, message: dict): |
|
|
"""Broadcast to all connections""" |
|
|
async with self._lock: |
|
|
conns = list(self.active_connections.keys()) |
|
|
if not conns: |
|
|
return |
|
|
message_json = json.dumps(message, default=str) |
|
|
dead: List[WebSocket] = [] |
|
|
for conn in conns: |
|
|
ok = await self._send_with_timeout(conn, message_json) |
|
|
if not ok: |
|
|
dead.append(conn) |
|
|
for conn in dead: |
|
|
logger.info("[WebSocket] Removing dead connection") |
|
|
await self.disconnect(conn) |
|
|
|
|
|
manager = ConnectionManager() |
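# Illustrative client sketch (not executed here). This shows how a frontend is
# expected to answer the server's heartbeat pings; the "/ws" path and the
# third-party `websockets` package are assumptions, since the actual endpoint
# is registered elsewhere in this app.
#
#   import asyncio
#   import json
#   import websockets
#
#   async def client():
#       async with websockets.connect("ws://localhost:8000/ws") as ws:
#           async for raw in ws:
#               msg = json.loads(raw)
#               if msg.get("type") == "ping":
#                   # Answering keeps the server-side `last_pong` fresh so the
#                   # heartbeat loop does not accumulate misses.
#                   await ws.send(json.dumps({"type": "pong"}))
#
#   asyncio.run(client())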
|
|
|
|
|
|
|
|
def categorize_feed_by_district(feed: Dict[str, Any]) -> str: |
|
|
""" |
|
|
Categorize feed by Sri Lankan district based on summary text. |
|
|
Returns district name or "National" if not district-specific. |
|
|
NOTE: If several districts match, one is returned arbitrarily (matches are collected in a set). Use get_all_matching_districts() for multi-district feeds.
|
|
""" |
|
|
districts = get_all_matching_districts(feed) |
|
|
return districts[0] if districts else "National" |
|
|
|
|
|
|
|
|
def get_all_matching_districts(feed: Dict[str, Any]) -> List[str]: |
|
|
""" |
|
|
Get ALL districts mentioned in a feed (direct or via province). |
|
|
|
|
|
Supports: |
|
|
- Direct district names (Colombo, Kandy, etc.) |
|
|
- Province names that map to multiple districts |
|
|
- Commonly referenced regions |
|
|
|
|
|
Returns list of all matching district names. |
|
|
""" |
|
|
summary = feed.get("summary", "").lower() |
|
|
|
|
|
|
|
|
districts = [ |
|
|
"Colombo", "Gampaha", "Kalutara", "Kandy", "Matale", "Nuwara Eliya", |
|
|
"Galle", "Matara", "Hambantota", "Jaffna", "Kilinochchi", "Mannar", |
|
|
"Vavuniya", "Mullaitivu", "Batticaloa", "Ampara", "Trincomalee", |
|
|
"Kurunegala", "Puttalam", "Anuradhapura", "Polonnaruwa", "Badulla", |
|
|
"Moneragala", "Ratnapura", "Kegalle" |
|
|
] |
|
|
|
|
|
|
|
|
province_mapping = { |
|
|
"western province": ["Colombo", "Gampaha", "Kalutara"], |
|
|
"western": ["Colombo", "Gampaha", "Kalutara"], |
|
|
"central province": ["Kandy", "Matale", "Nuwara Eliya"], |
|
|
"central": ["Kandy", "Matale", "Nuwara Eliya"], |
|
|
"southern province": ["Galle", "Matara", "Hambantota"], |
|
|
"southern provinces": ["Galle", "Matara", "Hambantota"], |
|
|
"southern": ["Galle", "Matara", "Hambantota"], |
|
|
"south": ["Galle", "Matara", "Hambantota"], |
|
|
"northern province": ["Jaffna", "Kilinochchi", "Mannar", "Vavuniya", "Mullaitivu"], |
|
|
"northern": ["Jaffna", "Kilinochchi", "Mannar", "Vavuniya", "Mullaitivu"], |
|
|
"north": ["Jaffna", "Kilinochchi", "Mannar", "Vavuniya", "Mullaitivu"], |
|
|
"eastern province": ["Batticaloa", "Ampara", "Trincomalee"], |
|
|
"eastern": ["Batticaloa", "Ampara", "Trincomalee"], |
|
|
"east": ["Batticaloa", "Ampara", "Trincomalee"], |
|
|
"north western province": ["Kurunegala", "Puttalam"], |
|
|
"north western": ["Kurunegala", "Puttalam"], |
|
|
"north central province": ["Anuradhapura", "Polonnaruwa"], |
|
|
"north central": ["Anuradhapura", "Polonnaruwa"], |
|
|
"uva province": ["Badulla", "Moneragala"], |
|
|
"uva": ["Badulla", "Moneragala"], |
|
|
"sabaragamuwa province": ["Ratnapura", "Kegalle"], |
|
|
"sabaragamuwa": ["Ratnapura", "Kegalle"], |
|
|
} |
|
|
|
|
|
matched_districts = set() |
|
|
|
|
|
|
|
|
for province, province_districts in province_mapping.items(): |
|
|
if province in summary: |
|
|
matched_districts.update(province_districts) |
|
|
|
|
|
|
|
|
for district in districts: |
|
|
if district.lower() in summary: |
|
|
matched_districts.add(district) |
|
|
|
|
|
return list(matched_districts) |
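# Example (illustrative): a summary naming a province maps to every district in
# that province, while a direct mention adds just that district. Ordering is not
# guaranteed because matches are collected in a set.
#
#   get_all_matching_districts({"summary": "Flood warnings for the Southern Province"})
#   -> ["Galle", "Matara", "Hambantota"]  (in some order)
#   categorize_feed_by_district({"summary": "Traffic disruption in Kandy"})
#   -> "Kandy"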
|
|
|
|
|
|
|
|
def run_graph_loop(): |
|
|
""" |
|
|
Graph execution in separate thread. |
|
|
Runs the combinedAgentGraph every 60 seconds (non-blocking pattern). |
|
|
|
|
|
UPDATED: Graph now runs single cycles and this loop handles the 60s interval |
|
|
externally, making the pattern non-blocking and interruptible. |
|
|
""" |
|
|
REFRESH_INTERVAL_SECONDS = 60 |
|
|
shutdown_event = threading.Event() |
|
|
|
|
|
logger.info("="*80) |
|
|
logger.info("[GRAPH THREAD] Starting Roger combinedAgentGraph loop (60s interval)") |
|
|
logger.info("="*80) |
|
|
|
|
|
cycle_count = 0 |
|
|
|
|
|
while not shutdown_event.is_set(): |
|
|
cycle_count += 1 |
|
|
cycle_start = time.time() |
|
|
|
|
|
logger.info(f"[GRAPH THREAD] Starting cycle #{cycle_count}") |
|
|
|
|
|
initial_state = CombinedAgentState( |
|
|
domain_insights=[], |
|
|
final_ranked_feed=[], |
|
|
run_count=cycle_count, |
|
|
max_runs=1, |
|
|
route=None |
|
|
) |
|
|
|
|
|
try: |
|
|
|
|
|
config = {"recursion_limit": 100} |
|
|
for event in graph.stream(initial_state, config=config): |
|
|
logger.info(f"[GRAPH] Event nodes: {list(event.keys())}") |
|
|
|
|
|
for node_name, node_output in event.items(): |
|
|
|
|
|
if hasattr(node_output, 'final_ranked_feed'): |
|
|
feeds = node_output.final_ranked_feed |
|
|
elif isinstance(node_output, dict): |
|
|
feeds = node_output.get('final_ranked_feed', []) |
|
|
else: |
|
|
continue |
|
|
|
|
|
if feeds: |
|
|
logger.info(f"[GRAPH] {node_name} produced {len(feeds)} feeds") |
|
|
|
|
|
|
|
|
for feed_item in feeds: |
|
|
if isinstance(feed_item, dict): |
|
|
event_data = feed_item |
|
|
else: |
|
|
event_data = feed_item.__dict__ if hasattr(feed_item, '__dict__') else {} |
|
|
|
|
|
|
|
|
event_id = event_data.get("event_id", str(uuid.uuid4())) |
|
|
summary = event_data.get("content_summary") or event_data.get("summary", "") |
|
|
domain = event_data.get("target_agent") or event_data.get("domain", "unknown") |
|
|
severity = event_data.get("severity", "medium") |
|
|
impact_type = event_data.get("impact_type", "risk") |
|
|
confidence = event_data.get("confidence_score", event_data.get("confidence", 0.5)) |
|
|
timestamp = event_data.get("timestamp", utc_now().isoformat()) |
|
|
|
|
|
|
|
|
is_dup, _, _ = storage_manager.is_duplicate(summary) |
|
|
|
|
|
if not is_dup: |
|
|
try: |
|
|
storage_manager.store_event( |
|
|
event_id=event_id, |
|
|
summary=summary, |
|
|
domain=domain, |
|
|
severity=severity, |
|
|
impact_type=impact_type, |
|
|
confidence_score=confidence |
|
|
) |
|
|
logger.info(f"[GRAPH] Stored new feed: {summary[:60]}...") |
|
|
except Exception as storage_error: |
|
|
logger.warning(f"[GRAPH] Storage error (continuing): {storage_error}") |
|
|
|
|
|
|
|
|
if not current_state.get('first_run_complete'): |
|
|
current_state['first_run_complete'] = True |
|
|
current_state['status'] = 'operational' |
|
|
logger.info("[GRAPH] FIRST RUN COMPLETE - Broadcasting to frontend!") |
|
|
|
|
|
|
|
|
if main_event_loop: |
|
|
asyncio.run_coroutine_threadsafe( |
|
|
manager.broadcast(current_state), |
|
|
main_event_loop |
|
|
) |
|
|
|
|
|
except RuntimeError as e: |
|
|
if "cannot schedule new futures after interpreter shutdown" in str(e): |
|
|
logger.warning("[GRAPH THREAD] Interpreter shutting down, stopping graph loop gracefully") |
|
|
break |
|
|
else: |
|
|
logger.error(f"[GRAPH THREAD] RuntimeError in cycle #{cycle_count}: {e}", exc_info=True) |
|
|
except Exception as e: |
|
|
logger.error(f"[GRAPH THREAD] Error in cycle #{cycle_count}: {e}", exc_info=True) |
|
|
|
|
|
|
|
|
cycle_duration = time.time() - cycle_start |
|
|
logger.info(f"[GRAPH THREAD] Cycle #{cycle_count} completed in {cycle_duration:.1f}s") |
|
|
|
|
|
|
|
|
wait_time = max(0, REFRESH_INTERVAL_SECONDS - cycle_duration) |
|
|
if wait_time > 0: |
|
|
logger.info(f"[GRAPH THREAD] Waiting {wait_time:.1f}s before next cycle...") |
|
|
|
|
|
shutdown_event.wait(timeout=wait_time) |
|
|
|
|
|
logger.info("[GRAPH THREAD] Graph loop stopped") |
|
|
|
|
|
|
|
|
|
|
|
async def database_polling_loop(): |
|
|
""" |
|
|
Polls database for new feeds and broadcasts via WebSocket. |
|
|
Runs concurrently with graph thread. |
|
|
""" |
|
|
global current_state |
|
|
last_check = utc_now() |
|
|
|
|
|
logger.info("[DB_POLLER] Starting database polling loop") |
|
|
|
|
|
while True: |
|
|
try: |
|
|
await asyncio.sleep(2.0) |
|
|
|
|
|
|
|
|
new_feeds = storage_manager.get_feeds_since(last_check) |
|
|
last_check = utc_now() |
|
|
|
|
|
if new_feeds: |
|
|
logger.info(f"[DB_POLLER] Found {len(new_feeds)} new feeds") |
|
|
|
|
|
|
|
|
unique_feeds = [] |
|
|
for feed in new_feeds: |
|
|
event_id = feed.get("event_id") |
|
|
if event_id and event_id not in seen_event_ids: |
|
|
seen_event_ids.add(event_id) |
|
|
|
|
|
|
|
|
feed["district"] = categorize_feed_by_district(feed) |
|
|
unique_feeds.append(feed) |
|
|
|
|
|
if unique_feeds: |
|
|
|
|
|
current_state['final_ranked_feed'] = unique_feeds + current_state.get('final_ranked_feed', []) |
|
|
current_state['final_ranked_feed'] = current_state['final_ranked_feed'][:100] |
|
|
current_state['status'] = 'operational' |
|
|
current_state['last_update'] = utc_now().isoformat() |
|
|
|
|
|
|
|
|
if not current_state.get('first_run_complete'): |
|
|
current_state['first_run_complete'] = True |
|
|
logger.info("[DB_POLLER] First graph run complete! Frontend loading screen can now hide.") |
|
|
|
|
|
|
|
|
await manager.broadcast(current_state) |
|
|
logger.info(f"[DB_POLLER] Broadcasted {len(unique_feeds)} unique feeds") |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"[DB_POLLER] Error: {e}") |
|
|
|
|
|
|
|
|
|
|
|
@app.on_event("startup") |
|
|
async def startup_event(): |
|
|
global main_event_loop |
|
|
main_event_loop = asyncio.get_running_loop()
|
|
|
|
|
logger.info("[API] Starting Roger API...") |
|
|
|
|
|
|
|
|
graph_thread = threading.Thread(target=run_graph_loop, daemon=True) |
|
|
graph_thread.start() |
|
|
logger.info("[API] Graph thread started") |
|
|
|
|
|
|
|
|
asyncio.create_task(database_polling_loop()) |
|
|
logger.info("[API] Database polling started") |
|
|
|
|
|
|
|
|
@app.get("/") |
|
|
def read_root(): |
|
|
return { |
|
|
"service": "Roger Intelligence Platform", |
|
|
"status": current_state.get("status"), |
|
|
"version": "2.0.0 (Database-Driven)" |
|
|
} |
|
|
|
|
|
|
|
|
@app.get("/api/status") |
|
|
def get_status(): |
|
|
return { |
|
|
"status": current_state.get("status"), |
|
|
"run_count": current_state.get("run_count"), |
|
|
"last_update": current_state.get("last_update"), |
|
|
"active_connections": len(manager.active_connections), |
|
|
"total_events": len(current_state.get("final_ranked_feed", [])) |
|
|
} |
|
|
|
|
|
|
|
|
@app.get("/api/dashboard") |
|
|
def get_dashboard(): |
|
|
return current_state.get("risk_dashboard_snapshot", {}) |
|
|
|
|
|
|
|
|
@app.get("/api/feed") |
|
|
def get_feed(): |
|
|
"""Get current feed from memory""" |
|
|
return { |
|
|
"events": current_state.get("final_ranked_feed", []), |
|
|
"total": len(current_state.get("final_ranked_feed", [])) |
|
|
} |
|
|
|
|
|
|
|
|
@app.get("/api/feeds") |
|
|
def get_feeds_from_db(limit: int = 100): |
|
|
"""Get feeds directly from database (for initial load)""" |
|
|
try: |
|
|
feeds = storage_manager.get_recent_feeds(limit=limit) |
|
|
|
|
|
|
|
|
normalized_feeds = [] |
|
|
for feed in feeds: |
|
|
|
|
|
normalized = { |
|
|
"event_id": feed.get("event_id"), |
|
|
"summary": feed.get("summary", ""), |
|
|
"domain": feed.get("domain", "unknown"), |
|
|
"severity": feed.get("severity", "medium"), |
|
|
"impact_type": feed.get("impact_type", "risk"), |
|
|
"confidence": feed.get("confidence", 0.5), |
|
|
"timestamp": feed.get("timestamp"), |
|
|
"district": categorize_feed_by_district(feed) |
|
|
} |
|
|
normalized_feeds.append(normalized) |
|
|
|
|
|
return { |
|
|
"events": normalized_feeds, |
|
|
"total": len(normalized_feeds), |
|
|
"source": "database" |
|
|
} |
|
|
except Exception as e: |
|
|
logger.error(f"[API] Error fetching feeds: {e}") |
|
|
return {"events": [], "total": 0, "error": str(e)} |
|
|
|
|
|
|
|
|
@app.get("/api/feeds/by_district/{district}") |
|
|
def get_feeds_by_district(district: str, limit: int = 50): |
|
|
"""Get feeds for specific district""" |
|
|
try: |
|
|
all_feeds = storage_manager.get_recent_feeds(limit=200) |
|
|
|
|
|
|
|
|
district_feeds = [] |
|
|
for feed in all_feeds: |
|
|
feed["district"] = categorize_feed_by_district(feed) |
|
|
if feed["district"].lower() == district.lower(): |
|
|
district_feeds.append(feed) |
|
|
if len(district_feeds) >= limit: |
|
|
break |
|
|
|
|
|
return { |
|
|
"district": district, |
|
|
"events": district_feeds, |
|
|
"total": len(district_feeds) |
|
|
} |
|
|
except Exception as e: |
|
|
logger.error(f"[API] Error fetching district feeds: {e}") |
|
|
return {"events": [], "total": 0, "error": str(e)} |
|
|
|
|
|
|
|
|
@app.get("/api/rivernet") |
|
|
def get_rivernet_status(): |
|
|
"""Get real-time river monitoring data from RiverNet.lk""" |
|
|
try: |
|
|
from src.utils.utils import tool_rivernet_status |
|
|
river_data = tool_rivernet_status() |
|
|
return river_data |
|
|
except Exception as e: |
|
|
logger.error(f"[API] Error fetching rivernet data: {e}") |
|
|
return { |
|
|
"rivers": [], |
|
|
"alerts": [], |
|
|
"summary": {"total_monitored": 0, "overall_status": "error", "has_alerts": False}, |
|
|
"error": str(e) |
|
|
} |
|
|
|
|
|
|
|
|
@app.get("/api/weather/historical") |
|
|
def get_historical_climate_data(): |
|
|
""" |
|
|
Get 30-year historical flood pattern analysis. |
|
|
|
|
|
Returns climate trend data including: |
|
|
- Average annual rainfall |
|
|
- Maximum daily rainfall records |
|
|
- Heavy/extreme rain day counts |
|
|
- Decadal comparison (1995-2025) |
|
|
- Key climate change findings |
|
|
""" |
|
|
try: |
|
|
from src.utils.utils import tool_floodwatch_historical |
|
|
historical_data = tool_floodwatch_historical() |
|
|
return { |
|
|
"status": "success", |
|
|
"data": historical_data |
|
|
} |
|
|
except Exception as e: |
|
|
logger.error(f"[API] Error fetching historical data: {e}") |
|
|
return { |
|
|
"status": "error", |
|
|
"error": str(e) |
|
|
} |
|
|
|
|
|
|
|
|
@app.get("/api/weather/threat") |
|
|
def get_national_threat_score(): |
|
|
""" |
|
|
Get national flood threat score (0-100). |
|
|
|
|
|
Aggregates river status, DMC alerts, and seasonal factors |
|
|
to compute an overall threat level for Sri Lanka. |
|
|
|
|
|
Returns: |
|
|
- national_threat_score (0-100) |
|
|
- threat_level (CRITICAL/HIGH/MODERATE/LOW) |
|
|
- breakdown by category |
|
|
- risk district lists |
|
|
""" |
|
|
try: |
|
|
from src.utils.utils import tool_rivernet_status, tool_calculate_national_threat, tool_dmc_alerts |
|
|
|
|
|
|
|
|
river_data = None |
|
|
try: |
|
|
river_data = tool_rivernet_status() |
|
|
except Exception as e: |
|
|
logger.warning(f"[ThreatAPI] RiverNet unavailable: {e}") |
|
|
|
|
|
|
|
|
dmc_data = None |
|
|
try: |
|
|
dmc_result = tool_dmc_alerts() |
|
|
dmc_data = dmc_result.get("alerts", []) |
|
|
except Exception as e: |
|
|
logger.warning(f"[ThreatAPI] DMC unavailable: {e}") |
|
|
|
|
|
|
|
|
threat_data = tool_calculate_national_threat( |
|
|
river_data=river_data, |
|
|
dmc_alerts=dmc_data |
|
|
) |
|
|
|
|
|
return { |
|
|
"status": "success", |
|
|
**threat_data |
|
|
} |
|
|
except Exception as e: |
|
|
logger.error(f"[API] Error calculating threat: {e}") |
|
|
return { |
|
|
"status": "error", |
|
|
"national_threat_score": 0, |
|
|
"threat_level": "UNKNOWN", |
|
|
"error": str(e) |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
INTEL_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "data", "intel_config.json") |
|
|
|
|
|
|
|
|
DEFAULT_INTEL_CONFIG = { |
|
|
"user_profiles": { |
|
|
"twitter": [], |
|
|
"facebook": [], |
|
|
"linkedin": [] |
|
|
}, |
|
|
"user_keywords": [], |
|
|
"user_products": [] |
|
|
} |
|
|
|
|
|
|
|
|
def load_intel_config() -> dict: |
|
|
"""Load intel config from JSON file.""" |
|
|
try: |
|
|
if os.path.exists(INTEL_CONFIG_PATH): |
|
|
with open(INTEL_CONFIG_PATH, "r", encoding="utf-8") as f: |
|
|
return json.load(f) |
|
|
except Exception as e: |
|
|
logger.warning(f"[Intel Config] Error loading config: {e}") |
|
|
return DEFAULT_INTEL_CONFIG.copy() |
|
|
|
|
|
|
|
|
def save_intel_config(config: dict) -> bool: |
|
|
"""Save intel config to JSON file.""" |
|
|
try: |
|
|
os.makedirs(os.path.dirname(INTEL_CONFIG_PATH), exist_ok=True) |
|
|
with open(INTEL_CONFIG_PATH, "w", encoding="utf-8") as f: |
|
|
json.dump(config, f, indent=2, ensure_ascii=False) |
|
|
return True |
|
|
except Exception as e: |
|
|
logger.error(f"[Intel Config] Error saving config: {e}") |
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
intel_config = load_intel_config() |
|
|
|
|
|
|
|
|
@app.get("/api/intel/config") |
|
|
def get_intel_config(): |
|
|
""" |
|
|
Get current intelligence configuration. |
|
|
|
|
|
Returns user-defined keywords, products, and social profiles to monitor. |
|
|
""" |
|
|
global intel_config |
|
|
intel_config = load_intel_config() |
|
|
return { |
|
|
"status": "success", |
|
|
"config": intel_config |
|
|
} |
|
|
|
|
|
|
|
|
class IntelConfigUpdate(BaseModel): |
|
|
user_profiles: Optional[dict] = None


user_keywords: Optional[list] = None


user_products: Optional[list] = None
|
|
|
|
|
|
|
|
@app.post("/api/intel/config") |
|
|
def update_intel_config(config_update: IntelConfigUpdate): |
|
|
""" |
|
|
Update intelligence configuration. |
|
|
|
|
|
Accepts user-defined keywords, products, and social profiles. |
|
|
Changes take effect on the next agent collection cycle. |
|
|
""" |
|
|
global intel_config |
|
|
|
|
|
try: |
|
|
|
|
|
if config_update.user_profiles is not None: |
|
|
intel_config["user_profiles"] = config_update.user_profiles |
|
|
if config_update.user_keywords is not None: |
|
|
intel_config["user_keywords"] = config_update.user_keywords |
|
|
if config_update.user_products is not None: |
|
|
intel_config["user_products"] = config_update.user_products |
|
|
|
|
|
|
|
|
if save_intel_config(intel_config): |
|
|
logger.info(f"[Intel Config] Updated: {len(intel_config.get('user_keywords', []))} keywords, " |
|
|
f"{sum(len(v) for v in intel_config.get('user_profiles', {}).values())} profiles") |
|
|
return { |
|
|
"status": "updated", |
|
|
"config": intel_config |
|
|
} |
|
|
else: |
|
|
return {"status": "error", "error": "Failed to save configuration"} |
|
|
except Exception as e: |
|
|
logger.error(f"[Intel Config] Update error: {e}") |
|
|
return {"status": "error", "error": str(e)} |
|
|
|
|
|
|
|
|
def get_user_intel_config() -> dict: |
|
|
""" |
|
|
Get the current intel config for use by agents. |
|
|
This function is called by social agents to get user-defined keywords and profiles. |
|
|
""" |
|
|
global intel_config |
|
|
return intel_config |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/api/power") |
|
|
def get_power_status(): |
|
|
""" |
|
|
Get CEB power outage / load shedding status. |
|
|
|
|
|
Returns current power supply status, active load shedding schedules, |
|
|
and any CEB announcements. |
|
|
""" |
|
|
try: |
|
|
from src.utils.utils import tool_ceb_power_status |
|
|
power_data = tool_ceb_power_status() |
|
|
return { |
|
|
"status": "success", |
|
|
**power_data |
|
|
} |
|
|
except Exception as e: |
|
|
logger.error(f"[API] Error fetching power status: {e}") |
|
|
return { |
|
|
"status": "error", |
|
|
"load_shedding_active": False, |
|
|
"error": str(e) |
|
|
} |
|
|
|
|
|
|
|
|
@app.get("/api/fuel") |
|
|
def get_fuel_prices(): |
|
|
""" |
|
|
Get current fuel prices in Sri Lanka. |
|
|
|
|
|
Returns prices for Petrol 92/95, Diesel, Super Diesel, and Kerosene. |
|
|
""" |
|
|
try: |
|
|
from src.utils.utils import tool_fuel_prices |
|
|
fuel_data = tool_fuel_prices() |
|
|
return { |
|
|
"status": "success", |
|
|
**fuel_data |
|
|
} |
|
|
except Exception as e: |
|
|
logger.error(f"[API] Error fetching fuel prices: {e}") |
|
|
return { |
|
|
"status": "error", |
|
|
"prices": {}, |
|
|
"error": str(e) |
|
|
} |
|
|
|
|
|
|
|
|
@app.get("/api/economy") |
|
|
def get_economic_indicators(): |
|
|
""" |
|
|
Get key economic indicators from CBSL. |
|
|
|
|
|
Returns inflation rates, policy rates, exchange rates, and forex reserves. |
|
|
""" |
|
|
try: |
|
|
from src.utils.utils import tool_cbsl_indicators |
|
|
economy_data = tool_cbsl_indicators() |
|
|
return { |
|
|
"status": "success", |
|
|
**economy_data |
|
|
} |
|
|
except Exception as e: |
|
|
logger.error(f"[API] Error fetching economic indicators: {e}") |
|
|
return { |
|
|
"status": "error", |
|
|
"indicators": {}, |
|
|
"error": str(e) |
|
|
} |
|
|
|
|
|
|
|
|
@app.get("/api/health") |
|
|
def get_health_alerts(): |
|
|
""" |
|
|
Get health alerts and disease information. |
|
|
|
|
|
Returns current health alerts, dengue case data, and health advisories. |
|
|
""" |
|
|
try: |
|
|
from src.utils.utils import tool_health_alerts |
|
|
health_data = tool_health_alerts() |
|
|
return { |
|
|
"status": "success", |
|
|
**health_data |
|
|
} |
|
|
except Exception as e: |
|
|
logger.error(f"[API] Error fetching health data: {e}") |
|
|
return { |
|
|
"status": "error", |
|
|
"alerts": [], |
|
|
"dengue": {}, |
|
|
"error": str(e) |
|
|
} |
|
|
|
|
|
|
|
|
@app.get("/api/commodities") |
|
|
def get_commodity_prices(): |
|
|
""" |
|
|
Get prices for essential commodities. |
|
|
|
|
|
Returns current prices for rice, sugar, dhal, milk powder, and other staples. |
|
|
""" |
|
|
try: |
|
|
from src.utils.utils import tool_commodity_prices |
|
|
commodity_data = tool_commodity_prices() |
|
|
return { |
|
|
"status": "success", |
|
|
**commodity_data |
|
|
} |
|
|
except Exception as e: |
|
|
logger.error(f"[API] Error fetching commodity prices: {e}") |
|
|
return { |
|
|
"status": "error", |
|
|
"commodities": [], |
|
|
"error": str(e) |
|
|
} |
|
|
|
|
|
|
|
|
@app.get("/api/water") |
|
|
def get_water_supply_status(): |
|
|
""" |
|
|
Get water supply disruption alerts from NWSDB. |
|
|
|
|
|
Returns active disruptions, affected areas, and restoration estimates. |
|
|
""" |
|
|
try: |
|
|
from src.utils.utils import tool_water_supply_alerts |
|
|
water_data = tool_water_supply_alerts() |
|
|
return { |
|
|
"status": "success", |
|
|
**water_data |
|
|
} |
|
|
except Exception as e: |
|
|
logger.error(f"[API] Error fetching water status: {e}") |
|
|
return { |
|
|
"status": "error", |
|
|
"active_disruptions": [], |
|
|
"error": str(e) |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/api/currency/history") |
|
|
def get_currency_history(days: int = 7): |
|
|
""" |
|
|
Get historical USD/LKR exchange rate data. |
|
|
|
|
|
Args: |
|
|
days: Number of days of history to return (default 7) |
|
|
|
|
|
Returns: |
|
|
List of historical rates with date and close price. |
|
|
""" |
|
|
try: |
|
|
from pathlib import Path |
|
|
import pandas as pd |
|
|
|
|
|
|
|
|
data_dir = Path(__file__).parent / "models" / "currency-volatility-prediction" / "artifacts" / "data" |
|
|
|
|
|
|
|
|
data_files = list(data_dir.glob("currency_data_*.csv")) if data_dir.exists() else [] |
|
|
|
|
|
if data_files: |
|
|
|
|
|
latest_file = max(data_files, key=lambda p: p.stem) |
|
|
df = pd.read_csv(latest_file) |
|
|
|
|
|
|
|
|
df['date'] = pd.to_datetime(df['date']) |
|
|
df = df.sort_values('date', ascending=False).head(days) |
|
|
df = df.sort_values('date', ascending=True) |
|
|
|
|
|
history = [] |
|
|
for _, row in df.iterrows(): |
|
|
history.append({ |
|
|
"date": row['date'].strftime("%Y-%m-%d"), |
|
|
"close": float(row['close']), |
|
|
"high": float(row.get('high', row['close'])), |
|
|
"low": float(row.get('low', row['close'])) |
|
|
}) |
|
|
|
|
|
return { |
|
|
"status": "success", |
|
|
"history": history, |
|
|
"days": len(history) |
|
|
} |
|
|
|
|
|
return { |
|
|
"status": "no_data", |
|
|
"message": "No historical data available. Run data ingestion first.", |
|
|
"history": [] |
|
|
} |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"[CurrencyAPI] Error fetching history: {e}") |
|
|
return { |
|
|
"status": "error", |
|
|
"error": str(e), |
|
|
"history": [] |
|
|
} |
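# Example (illustrative): GET /api/currency/history?days=30 returns up to the 30
# most recent rows of the latest currency_data_*.csv, oldest first, shaped like
#   {"status": "success", "history": [{"date": "2025-01-02", "close": 292.5,
#    "high": 293.1, "low": 291.8}, ...], "days": 30}
# (the numbers shown are made up).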
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/api/trending") |
|
|
def get_trending_topics(limit: int = 10): |
|
|
""" |
|
|
Get currently trending topics. |
|
|
|
|
|
Returns topics with momentum > 2x (gaining traction). |
|
|
""" |
|
|
try: |
|
|
from src.utils.trending_detector import get_trending_now, get_spikes |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
trending = get_trending_now(limit=limit) |
|
|
spikes = get_spikes() |
|
|
|
|
|
|
|
|
for topic in trending[:5]: |
|
|
keyword = topic["topic"] |
|
|
|
|
|
try: |
|
|
related = storage_manager.search_feeds(keyword, limit=2) |
|
|
topic["related_feeds"] = related |
|
|
except Exception as e: |
|
|
logger.warning(f"Error searching feeds for topic {keyword}: {e}") |
|
|
topic["related_feeds"] = [] |
|
|
|
|
|
return { |
|
|
"status": "success", |
|
|
"trending_topics": trending, |
|
|
"spike_alerts": spikes, |
|
|
"total_trending": len(trending), |
|
|
"total_spikes": len(spikes) |
|
|
} |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"[TrendingAPI] Error: {e}") |
|
|
return { |
|
|
"status": "error", |
|
|
"error": str(e), |
|
|
"trending_topics": [], |
|
|
"spike_alerts": [] |
|
|
} |
|
|
|
|
|
|
|
|
@app.get("/api/trending/topic/{topic}") |
|
|
def get_topic_history(topic: str, hours: int = 24): |
|
|
""" |
|
|
Get hourly mention history for a specific topic. |
|
|
|
|
|
Args: |
|
|
topic: Topic name to get history for |
|
|
hours: Number of hours of history to return (default 24) |
|
|
""" |
|
|
try: |
|
|
from src.utils.trending_detector import get_trending_detector |
|
|
|
|
|
detector = get_trending_detector() |
|
|
history = detector.get_topic_history(topic, hours=hours) |
|
|
momentum = detector.get_momentum(topic) |
|
|
is_spike = detector.is_spike(topic) |
|
|
|
|
|
return { |
|
|
"status": "success", |
|
|
"topic": topic, |
|
|
"momentum": momentum, |
|
|
"is_spike": is_spike, |
|
|
"history": history |
|
|
} |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"[TrendingAPI] Error getting history for {topic}: {e}") |
|
|
return { |
|
|
"status": "error", |
|
|
"error": str(e), |
|
|
"topic": topic, |
|
|
"momentum": 1.0, |
|
|
"is_spike": False, |
|
|
"history": [] |
|
|
} |
|
|
|
|
|
|
|
|
@app.post("/api/trending/record") |
|
|
def record_topic_mention(topic: str, source: str = "manual", domain: str = "general"): |
|
|
""" |
|
|
Record a topic mention (for testing/manual tracking). |
|
|
|
|
|
Args: |
|
|
topic: Topic/keyword being mentioned |
|
|
source: Source of the mention (twitter, news, etc.) |
|
|
domain: Domain category (political, economical, etc.) |
|
|
""" |
|
|
try: |
|
|
from src.utils.trending_detector import record_topic_mention as record_mention |
|
|
|
|
|
record_mention(topic=topic, source=source, domain=domain) |
|
|
|
|
|
|
|
|
from src.utils.trending_detector import get_trending_detector |
|
|
detector = get_trending_detector() |
|
|
momentum = detector.get_momentum(topic) |
|
|
|
|
|
return { |
|
|
"status": "success", |
|
|
"message": f"Recorded mention for '{topic}'", |
|
|
"current_momentum": momentum, |
|
|
"is_spike": detector.is_spike(topic) |
|
|
} |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"[TrendingAPI] Error recording mention: {e}") |
|
|
return { |
|
|
"status": "error", |
|
|
"error": str(e) |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_anomaly_models = {} |
|
|
_vectorizer = None |
|
|
_language_detector = None |
|
|
|
|
|
|
|
|
def _load_anomaly_components(): |
|
|
"""Load per-language anomaly detection models and vectorizer""" |
|
|
global _anomaly_models, _vectorizer, _language_detector |
|
|
|
|
|
if _anomaly_models: |
|
|
return True |
|
|
|
|
|
try: |
|
|
import joblib |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
output_dir = Path(__file__).parent / "models" / "anomaly-detection" / "output" |
|
|
artifacts_dir = Path(__file__).parent / "models" / "anomaly-detection" / "artifacts" / "model_trainer" |
|
|
|
|
|
|
|
|
for lang in ["english", "sinhala", "tamil"]: |
|
|
for search_dir in [artifacts_dir, output_dir]: |
|
|
model_path = search_dir / f"isolation_forest_{lang}.joblib" |
|
|
if model_path.exists(): |
|
|
_anomaly_models[lang] = joblib.load(model_path) |
|
|
logger.info(f"[AnomalyAPI] Loaded {lang} model from {model_path.name}") |
|
|
break |
|
|
|
|
|
|
|
|
if not _anomaly_models: |
|
|
legacy_paths = [ |
|
|
output_dir / "isolation_forest_embeddings_only.joblib", |
|
|
output_dir / "isolation_forest_model.joblib", |
|
|
] |
|
|
for legacy_path in legacy_paths: |
|
|
if legacy_path.exists(): |
|
|
_anomaly_models["english"] = joblib.load(legacy_path) |
|
|
logger.info(f"[AnomalyAPI] Loaded legacy model: {legacy_path.name}") |
|
|
break |
|
|
|
|
|
if not _anomaly_models: |
|
|
logger.warning("[AnomalyAPI] No trained models found. Run training first.") |
|
|
return False |
|
|
|
|
|
|
|
|
from models.anomaly_detection.src.utils.vectorizer import get_vectorizer |
|
|
from models.anomaly_detection.src.utils.language_detector import detect_language |
|
|
|
|
|
_vectorizer = get_vectorizer() |
|
|
_language_detector = detect_language |
|
|
|
|
|
logger.info(f"[AnomalyAPI] ✓ Loaded models for: {list(_anomaly_models.keys())}") |
|
|
return True |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"[AnomalyAPI] Failed to load components: {e}") |
|
|
return False |
|
|
|
|
|
|
|
|
@app.post("/api/predict") |
|
|
def predict_anomaly(texts: Optional[List[str]] = None, text: Optional[str] = None):
|
|
""" |
|
|
Run anomaly detection on text(s) using per-language models. |
|
|
|
|
|
Args: |
|
|
texts: List of texts to analyze |
|
|
text: Single text to analyze (alternative to texts) |
|
|
|
|
|
Returns: |
|
|
Predictions with anomaly scores |
|
|
""" |
|
|
try: |
|
|
|
|
|
if text and not texts: |
|
|
texts = [text] |
|
|
|
|
|
if not texts: |
|
|
return {"error": "No text provided. Use 'text' or 'texts' field.", "predictions": []} |
|
|
|
|
|
|
|
|
if not _load_anomaly_components(): |
|
|
|
|
|
return { |
|
|
"predictions": [ |
|
|
{ |
|
|
"text": t[:100] + "..." if len(t) > 100 else t, |
|
|
"is_anomaly": False, |
|
|
"anomaly_score": 0.0, |
|
|
"method": "heuristic" |
|
|
} |
|
|
for t in texts |
|
|
], |
|
|
"model_status": "not_trained", |
|
|
"message": "Model not trained yet. Using default scores." |
|
|
} |
|
|
|
|
|
|
|
|
predictions = [] |
|
|
for t in texts: |
|
|
try: |
|
|
|
|
|
lang, lang_conf = _language_detector(t) |
|
|
|
|
|
|
|
|
vector = _vectorizer.vectorize(t, lang) |
|
|
|
|
|
|
|
|
if lang in _anomaly_models: |
|
|
model = _anomaly_models[lang] |
|
|
method = f"isolation_forest_{lang}" |
|
|
elif "english" in _anomaly_models: |
|
|
model = _anomaly_models["english"] |
|
|
method = "isolation_forest_english_fallback" |
|
|
else: |
|
|
|
|
|
predictions.append({ |
|
|
"text": t[:100] + "..." if len(t) > 100 else t, |
|
|
"is_anomaly": False, |
|
|
"anomaly_score": 0.0, |
|
|
"language": lang, |
|
|
"method": "no_model" |
|
|
}) |
|
|
continue |
|
|
|
|
|
|
|
|
prediction = model.predict([vector])[0] |
|
|
|
|
|
|
|
|
if hasattr(model, 'decision_function'): |
|
|
score = -model.decision_function([vector])[0] |
|
|
elif hasattr(model, 'score_samples'): |
|
|
score = -model.score_samples([vector])[0] |
|
|
else: |
|
|
score = 1.0 if prediction == -1 else 0.0 |
|
|
|
|
|
predictions.append({ |
|
|
"text": t[:100] + "..." if len(t) > 100 else t, |
|
|
"is_anomaly": prediction == -1, |
|
|
"anomaly_score": float(score), |
|
|
"language": lang, |
|
|
"method": method |
|
|
}) |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"[AnomalyAPI] Error predicting: {e}") |
|
|
predictions.append({ |
|
|
"text": t[:100] + "..." if len(t) > 100 else t, |
|
|
"is_anomaly": False, |
|
|
"anomaly_score": 0.0, |
|
|
"error": str(e) |
|
|
}) |
|
|
|
|
|
return { |
|
|
"predictions": predictions, |
|
|
"total": len(predictions), |
|
|
"anomalies_found": sum(1 for p in predictions if p.get("is_anomaly")), |
|
|
"model_status": "loaded", |
|
|
"models_available": list(_anomaly_models.keys()) |
|
|
} |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"[AnomalyAPI] Predict error: {e}", exc_info=True) |
|
|
return {"error": str(e), "predictions": []} |
|
|
|
|
|
|
|
|
@app.get("/api/anomalies") |
|
|
def get_anomalies(limit: int = 20, threshold: float = 0.5): |
|
|
""" |
|
|
Get recent feeds that are flagged as anomalies. |
|
|
|
|
|
Args: |
|
|
limit: Max number of results |
|
|
threshold: Anomaly score threshold (0-1) |
|
|
|
|
|
Returns: |
|
|
List of anomalous events |
|
|
""" |
|
|
try: |
|
|
|
|
|
feeds = storage_manager.get_recent_feeds(limit=100) |
|
|
|
|
|
if not feeds: |
|
|
|
|
|
return { |
|
|
"anomalies": [], |
|
|
"total": 0, |
|
|
"model_status": "no_data", |
|
|
"message": "No feed data available yet. Wait for graph execution to complete." |
|
|
} |
|
|
|
|
|
if not _load_anomaly_components(): |
|
|
|
|
|
anomalies = [] |
|
|
anomaly_keywords = ["emergency", "crisis", "breaking", "urgent", "alert", |
|
|
"warning", "critical", "disaster", "flood", "protest"] |
|
|
|
|
|
for f in feeds: |
|
|
score = 0.0 |
|
|
summary = str(f.get("summary", "")).lower() |
|
|
severity = f.get("severity", "low") |
|
|
|
|
|
|
|
|
if severity == "critical": score = 0.9 |
|
|
elif severity == "high": score = 0.75 |
|
|
elif severity == "medium": score = 0.5 |
|
|
else: score = 0.25 |
|
|
|
|
|
|
|
|
keyword_matches = sum(1 for kw in anomaly_keywords if kw in summary) |
|
|
if keyword_matches > 0: |
|
|
score = min(1.0, score + (keyword_matches * 0.1)) |
|
|
|
|
|
|
|
|
if score >= threshold: |
|
|
anomalies.append({ |
|
|
**f, |
|
|
"anomaly_score": round(score, 3), |
|
|
"is_anomaly": score >= 0.7 |
|
|
}) |
|
|
|
|
|
|
|
|
anomalies.sort(key=lambda x: x.get("anomaly_score", 0), reverse=True) |
|
|
|
|
|
return { |
|
|
"anomalies": anomalies[:limit], |
|
|
"total": len(anomalies), |
|
|
"threshold": threshold, |
|
|
"model_status": "fallback_scoring", |
|
|
"message": "Using severity + keyword scoring. Train ML model for advanced detection." |
|
|
} |
|
|
|
|
|
|
|
|
anomalies = [] |
|
|
per_lang_counts = {"english": 0, "sinhala": 0, "tamil": 0} |
|
|
|
|
|
for feed in feeds: |
|
|
summary = feed.get("summary", "") |
|
|
if not summary: |
|
|
continue |
|
|
|
|
|
try: |
|
|
lang, _ = _language_detector(summary) |
|
|
vector = _vectorizer.vectorize(summary, lang) |
|
|
|
|
|
|
|
|
if lang in _anomaly_models: |
|
|
model = _anomaly_models[lang] |
|
|
method = f"isolation_forest_{lang}" |
|
|
elif "english" in _anomaly_models: |
|
|
model = _anomaly_models["english"] |
|
|
method = "isolation_forest_english_fallback" |
|
|
else: |
|
|
continue |
|
|
|
|
|
per_lang_counts[lang] = per_lang_counts.get(lang, 0) + 1 |
|
|
prediction = model.predict([vector])[0] |
|
|
|
|
|
if hasattr(model, 'decision_function'): |
|
|
score = -model.decision_function([vector])[0] |
|
|
else: |
|
|
score = 1.0 if prediction == -1 else 0.0 |
|
|
|
|
|
|
|
|
normalized_score = max(0, min(1, (score + 0.5))) |
|
|
|
|
|
if prediction == -1 or normalized_score >= threshold: |
|
|
anomalies.append({ |
|
|
**feed, |
|
|
"anomaly_score": float(round(normalized_score, 3)), |
|
|
"is_anomaly": prediction == -1, |
|
|
"language": lang, |
|
|
"detection_method": method |
|
|
}) |
|
|
|
|
|
if len(anomalies) >= limit: |
|
|
break |
|
|
|
|
|
except Exception as e: |
|
|
logger.debug(f"[AnomalyAPI] Error scoring feed: {e}") |
|
|
continue |
|
|
|
|
|
|
|
|
anomalies.sort(key=lambda x: x.get("anomaly_score", 0), reverse=True) |
|
|
|
|
|
return { |
|
|
"anomalies": anomalies, |
|
|
"total": len(anomalies), |
|
|
"threshold": threshold, |
|
|
"model_status": "ml_active", |
|
|
"models_loaded": list(_anomaly_models.keys()), |
|
|
"per_language_counts": per_lang_counts |
|
|
} |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"[AnomalyAPI] Get anomalies error: {e}") |
|
|
return {"anomalies": [], "total": 0, "error": str(e)} |
|
|
|
|
|
|
|
|
@app.get("/api/model/status") |
|
|
def get_model_status(): |
|
|
"""Get anomaly detection model status""" |
|
|
try: |
|
|
from pathlib import Path |
|
|
|
|
|
output_dir = Path(__file__).parent / "models" / "anomaly-detection" / "output" |
|
|
models_found = [] |
|
|
|
|
|
if output_dir.exists(): |
|
|
for f in output_dir.glob("*.joblib"): |
|
|
models_found.append(f.name) |
|
|
|
|
|
loaded = bool(_anomaly_models)
|
|
|
|
|
return { |
|
|
"model_loaded": loaded, |
|
|
"models_available": models_found, |
|
|
"vectorizer_loaded": _vectorizer is not None, |
|
|
"batch_threshold": int(os.getenv("BATCH_THRESHOLD", "1000")), |
|
|
"output_directory": str(output_dir) |
|
|
} |
|
|
|
|
|
except Exception as e: |
|
|
return {"error": str(e), "model_loaded": False} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_rag_instance = None |
|
|
|
|
|
|
|
|
def _get_rag(): |
|
|
"""Get or create RAG instance""" |
|
|
global _rag_instance |
|
|
if _rag_instance is None: |
|
|
try: |
|
|
from src.rag import RogerRAG |
|
|
_rag_instance = RogerRAG() |
|
|
logger.info("[RAG API] ✓ RAG instance initialized") |
|
|
except Exception as e: |
|
|
logger.error(f"[RAG API] Failed to initialize RAG: {e}") |
|
|
return None |
|
|
return _rag_instance |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChatRequest(BaseModel): |
|
|
message: str |
|
|
domain_filter: Optional[str] = None |
|
|
use_history: bool = True |
|
|
|
|
|
|
|
|
class ChatResponse(BaseModel): |
|
|
answer: str |
|
|
sources: List[Dict[str, Any]] = [] |
|
|
reformulated: Optional[str] = None |
|
|
docs_found: int = 0 |
|
|
error: Optional[str] = None |
|
|
|
|
|
|
|
|
@app.post("/api/rag/chat", response_model=ChatResponse) |
|
|
def rag_chat(request: ChatRequest): |
|
|
""" |
|
|
Chat with the RAG system. |
|
|
|
|
|
Args: |
|
|
message: User's question |
|
|
domain_filter: Optional domain (political, economic, weather, social, intelligence) |
|
|
use_history: Whether to use chat history for context (default: True) |
|
|
|
|
|
Returns: |
|
|
AI response with sources |
|
|
""" |
|
|
try: |
|
|
rag = _get_rag() |
|
|
if not rag: |
|
|
return ChatResponse( |
|
|
answer="RAG system not available. Please check server logs.", |
|
|
error="RAG initialization failed" |
|
|
) |
|
|
|
|
|
result = rag.query( |
|
|
question=request.message, |
|
|
domain_filter=request.domain_filter, |
|
|
use_history=request.use_history |
|
|
) |
|
|
|
|
|
return ChatResponse( |
|
|
answer=result.get("answer", "No response generated."), |
|
|
sources=result.get("sources", []), |
|
|
reformulated=result.get("reformulated"), |
|
|
docs_found=result.get("docs_found", 0), |
|
|
error=result.get("error") |
|
|
) |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"[RAG API] Chat error: {e}", exc_info=True) |
|
|
return ChatResponse( |
|
|
answer=f"Error processing your request: {str(e)}", |
|
|
error=str(e) |
|
|
) |
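# Example request (illustrative; server address assumed):
#
#   curl -X POST http://localhost:8000/api/rag/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Any flood risk in Galle this week?", "domain_filter": "weather"}'
#
# The response follows ChatResponse, e.g.
#   {"answer": "...", "sources": [...], "reformulated": null, "docs_found": 3, "error": null}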
|
|
|
|
|
|
|
|
@app.get("/api/rag/stats") |
|
|
def rag_stats(): |
|
|
"""Get RAG system statistics""" |
|
|
try: |
|
|
rag = _get_rag() |
|
|
if not rag: |
|
|
return {"error": "RAG not available", "status": "offline"} |
|
|
|
|
|
stats = rag.get_stats() |
|
|
stats["status"] = "online" |
|
|
return stats |
|
|
|
|
|
except Exception as e: |
|
|
return {"error": str(e), "status": "error"} |
|
|
|
|
|
|
|
|
@app.post("/api/rag/clear") |
|
|
def rag_clear_history(): |
|
|
"""Clear RAG chat history""" |
|
|
try: |
|
|
rag = _get_rag() |
|
|
if rag: |
|
|
rag.clear_history() |
|
|
return {"message": "Chat history cleared", "success": True} |
|
|
return {"message": "RAG not available", "success": False} |
|
|
|
|
|
except Exception as e: |
|
|
return {"error": str(e), "success": False} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
INTEL_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "src", "config", "intel_config.json") |
|
|
|
|
|
|
|
|
def _ensure_intel_config() -> str: |
|
|
"""Ensure config directory and file exist with default structure""" |
|
|
os.makedirs(os.path.dirname(INTEL_CONFIG_PATH), exist_ok=True) |
|
|
if not os.path.exists(INTEL_CONFIG_PATH): |
|
|
default_config = { |
|
|
"user_profiles": {"twitter": [], "facebook": [], "linkedin": []}, |
|
|
"user_keywords": [], |
|
|
"user_products": [] |
|
|
} |
|
|
with open(INTEL_CONFIG_PATH, "w", encoding="utf-8") as f: |
|
|
json.dump(default_config, f, indent=2) |
|
|
logger.info(f"[IntelConfig] Created default config at {INTEL_CONFIG_PATH}") |
|
|
return INTEL_CONFIG_PATH |
|
|
|
|
|
|
|
|
@app.get("/api/intel/config") |
|
|
def get_intel_config(): |
|
|
""" |
|
|
Get current intelligence monitoring configuration. |
|
|
|
|
|
Returns user-defined profiles, keywords, and products that the |
|
|
Intelligence Agent monitors in addition to defaults. |
|
|
""" |
|
|
try: |
|
|
path = _ensure_intel_config() |
|
|
with open(path, "r", encoding="utf-8") as f: |
|
|
config = json.load(f) |
|
|
return {"status": "success", "config": config} |
|
|
except Exception as e: |
|
|
logger.error(f"[IntelConfig] Error reading config: {e}") |
|
|
return {"status": "error", "error": str(e)} |
|
|
|
|
|
|
|
|
class IntelConfigUpdate(BaseModel): |
|
|
user_profiles: Optional[Dict[str, List[str]]] = None |
|
|
user_keywords: Optional[List[str]] = None |
|
|
user_products: Optional[List[str]] = None |
|
|
|
|
|
|
|
|
@app.post("/api/intel/config") |
|
|
def update_intel_config(config: IntelConfigUpdate): |
|
|
""" |
|
|
Update intelligence monitoring configuration. |
|
|
|
|
|
Replaces the entire user config with the provided values. |
|
|
""" |
|
|
try: |
|
|
path = _ensure_intel_config() |
|
|
|
|
|
|
|
|
with open(path, "r", encoding="utf-8") as f: |
|
|
existing = json.load(f) |
|
|
|
|
|
|
|
|
if config.user_profiles is not None: |
|
|
existing["user_profiles"] = config.user_profiles |
|
|
if config.user_keywords is not None: |
|
|
existing["user_keywords"] = config.user_keywords |
|
|
if config.user_products is not None: |
|
|
existing["user_products"] = config.user_products |
|
|
|
|
|
|
|
|
with open(path, "w", encoding="utf-8") as f: |
|
|
json.dump(existing, f, indent=2) |
|
|
|
|
|
logger.info(f"[IntelConfig] Updated config: {len(existing.get('user_keywords', []))} keywords, {sum(len(v) for v in existing.get('user_profiles', {}).values())} profiles") |
|
|
return {"status": "updated", "config": existing} |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"[IntelConfig] Error updating config: {e}") |
|
|
return {"status": "error", "error": str(e)} |
|
|
|
|
|
|
|
|
@app.post("/api/intel/config/add") |
|
|
def add_intel_target(target_type: str, value: str, platform: Optional[str] = None): |
|
|
""" |
|
|
Add a single monitoring target. |
|
|
|
|
|
Args: |
|
|
target_type: "keyword", "product", or "profile" |
|
|
value: The value to add |
|
|
platform: Required for "profile" type (twitter, facebook, linkedin) |
|
|
|
|
|
Example: |
|
|
POST /api/intel/config/add?target_type=keyword&value=Colombo+Port |
|
|
POST /api/intel/config/add?target_type=profile&value=CompetitorX&platform=twitter |
|
|
""" |
|
|
try: |
|
|
path = _ensure_intel_config() |
|
|
with open(path, "r", encoding="utf-8") as f: |
|
|
config = json.load(f) |
|
|
|
|
|
added = False |
|
|
|
|
|
if target_type == "keyword": |
|
|
if value not in config.get("user_keywords", []): |
|
|
config.setdefault("user_keywords", []).append(value) |
|
|
added = True |
|
|
elif target_type == "product": |
|
|
if value not in config.get("user_products", []): |
|
|
config.setdefault("user_products", []).append(value) |
|
|
added = True |
|
|
elif target_type == "profile": |
|
|
if not platform: |
|
|
return {"status": "error", "error": "platform is required for profile type"} |
|
|
profiles = config.setdefault("user_profiles", {}) |
|
|
platform_list = profiles.setdefault(platform, []) |
|
|
if value not in platform_list: |
|
|
platform_list.append(value) |
|
|
added = True |
|
|
else: |
|
|
return {"status": "error", "error": f"Invalid target_type: {target_type}"} |
|
|
|
|
|
if added: |
|
|
with open(path, "w", encoding="utf-8") as f: |
|
|
json.dump(config, f, indent=2) |
|
|
logger.info(f"[IntelConfig] Added {target_type}: {value}") |
|
|
|
|
|
return {"status": "added" if added else "already_exists", "config": config} |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"[IntelConfig] Error adding target: {e}") |
|
|
return {"status": "error", "error": str(e)} |
|
|
|
|
|
|
|
|
@app.delete("/api/intel/config/remove") |
|
|
def remove_intel_target(target_type: str, value: str, platform: Optional[str] = None): |
|
|
""" |
|
|
Remove a monitoring target. |
|
|
|
|
|
Args: |
|
|
target_type: "keyword", "product", or "profile" |
|
|
value: The value to remove |
|
|
platform: Required for "profile" type |
|
|
""" |
|
|
try: |
|
|
path = _ensure_intel_config() |
|
|
with open(path, "r", encoding="utf-8") as f: |
|
|
config = json.load(f) |
|
|
|
|
|
removed = False |
|
|
|
|
|
if target_type == "keyword": |
|
|
if value in config.get("user_keywords", []): |
|
|
config["user_keywords"].remove(value) |
|
|
removed = True |
|
|
elif target_type == "product": |
|
|
if value in config.get("user_products", []): |
|
|
config["user_products"].remove(value) |
|
|
removed = True |
|
|
elif target_type == "profile": |
|
|
if not platform: |
|
|
return {"status": "error", "error": "platform is required for profile type"} |
|
|
if platform in config.get("user_profiles", {}) and value in config["user_profiles"][platform]: |
|
|
config["user_profiles"][platform].remove(value) |
|
|
removed = True |
|
|
else: |
|
|
return {"status": "error", "error": f"Invalid target_type: {target_type}"} |
|
|
|
|
|
if removed: |
|
|
with open(path, "w", encoding="utf-8") as f: |
|
|
json.dump(config, f, indent=2) |
|
|
logger.info(f"[IntelConfig] Removed {target_type}: {value}") |
|
|
|
|
|
return {"status": "removed" if removed else "not_found", "config": config} |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"[IntelConfig] Error removing target: {e}") |
|
|
return {"status": "error", "error": str(e)} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_weather_predictor = None |
|
|
|
|
|
def get_weather_predictor(): |
|
|
"""Lazy-load the weather predictor using isolated import.""" |
|
|
global _weather_predictor |
|
|
if _weather_predictor is not None: |
|
|
return _weather_predictor |
|
|
|
|
|
try: |
|
|
import importlib.util |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
|
|
weather_src = Path(__file__).parent / "models" / "weather-prediction" / "src" |
|
|
predictor_path = weather_src / "components" / "predictor.py" |
|
|
|
|
|
if not predictor_path.exists(): |
|
|
logger.error(f"[WeatherAPI] predictor.py not found at {predictor_path}") |
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
currency_src = str(Path(__file__).parent / "models" / "currency-volatility-prediction" / "src") |
|
|
stock_src = str(Path(__file__).parent / "models" / "stock-price-prediction" / "src") |
|
|
anomaly_src = str(Path(__file__).parent / "models" / "anomaly-detection" / "src") |
|
|
|
|
|
original_path = sys.path.copy() |
|
|
sys.path = [p for p in sys.path if currency_src not in p and stock_src not in p and anomaly_src not in p] |
|
|
|
|
|
|
|
|
modules_to_clear = [k for k in sys.modules.keys() if 'entity' in k.lower() or 'config_entity' in k.lower()] |
|
|
saved_modules = {} |
|
|
for mod_name in modules_to_clear: |
|
|
saved_modules[mod_name] = sys.modules.pop(mod_name, None) |
|
|
|
|
|
|
|
|
weather_src_str = str(weather_src) |
|
|
if weather_src_str not in sys.path: |
|
|
sys.path.insert(0, weather_src_str) |
|
|
|
|
|
try: |
|
|
|
|
|
spec = importlib.util.spec_from_file_location( |
|
|
"weather_predictor_module", |
|
|
str(predictor_path) |
|
|
) |
|
|
module = importlib.util.module_from_spec(spec) |
|
|
spec.loader.exec_module(module) |
|
|
|
|
|
_weather_predictor = module.WeatherPredictor() |
|
|
logger.info("[WeatherAPI] ✓ Weather predictor initialized via isolated import") |
|
|
finally: |
|
|
|
|
|
sys.path = original_path |
|
|
|
|
|
for mod_name, mod in saved_modules.items(): |
|
|
if mod is not None: |
|
|
sys.modules[mod_name] = mod |
|
|
|
|
|
return _weather_predictor |
|
|
|
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"[WeatherAPI] Failed to initialize predictor: {e}") |
|
|
import traceback |
|
|
logger.error(f"[WeatherAPI] Full traceback:\n{traceback.format_exc()}") |
|
|
return None |
|
|
|
|
|
|
|
|
@app.get("/api/weather/predictions") |
|
|
async def get_weather_predictions(): |
|
|
""" |
|
|
Get weather predictions for all 25 Sri Lankan districts. |
|
|
|
|
|
Returns next-day predictions including: |
|
|
- Temperature (high/low) |
|
|
- Rainfall (amount and probability) |
|
|
- Flood risk |
|
|
- Severity classification |
|
|
""" |
|
|
predictor = get_weather_predictor() |
|
|
|
|
|
if predictor is None: |
|
|
return { |
|
|
"status": "unavailable", |
|
|
"message": "Weather prediction model not loaded", |
|
|
"predictions": None |
|
|
} |
|
|
|
|
|
try: |
|
|
|
|
|
predictions = predictor.get_latest_predictions() |
|
|
|
|
|
if predictions is None: |
|
|
|
|
|
logger.info("[WeatherAPI] Generating new predictions...") |
|
|
predictions = predictor.predict_all_districts() |
|
|
predictor.save_predictions(predictions) |
|
|
|
|
|
return { |
|
|
"status": "success", |
|
|
"prediction_date": predictions.get("prediction_date"), |
|
|
"generated_at": predictions.get("generated_at"), |
|
|
"districts": predictions.get("districts", {}), |
|
|
"total_districts": len(predictions.get("districts", {})) |
|
|
} |
|
|
except Exception as e: |
|
|
logger.error(f"[WeatherAPI] Error getting predictions: {e}") |
|
|
return {"status": "error", "message": str(e)} |
|
|
|
|
|
|
|
|
@app.get("/api/weather/predictions/{district}") |
|
|
async def get_district_weather(district: str): |
|
|
"""Get weather prediction for a specific district.""" |
|
|
predictor = get_weather_predictor() |
|
|
|
|
|
if predictor is None: |
|
|
return {"status": "unavailable", "message": "Weather predictor not loaded"} |
|
|
|
|
|
try: |
|
|
predictions = predictor.get_latest_predictions() |
|
|
|
|
|
if predictions is None: |
|
|
predictions = predictor.predict_all_districts() |
|
|
|
|
|
districts = predictions.get("districts", {}) |
|
|
|
|
|
|
|
|
district_key = None |
|
|
for d in districts.keys(): |
|
|
if d.lower() == district.lower(): |
|
|
district_key = d |
|
|
break |
|
|
|
|
|
if district_key is None: |
|
|
return { |
|
|
"status": "not_found", |
|
|
"message": f"District '{district}' not found", |
|
|
"available_districts": list(districts.keys()) |
|
|
} |
|
|
|
|
|
return { |
|
|
"status": "success", |
|
|
"district": district_key, |
|
|
"prediction_date": predictions.get("prediction_date"), |
|
|
"prediction": districts[district_key] |
|
|
} |
|
|
except Exception as e: |
|
|
return {"status": "error", "message": str(e)} |
|
|
|
|
|
|
|
|
@app.get("/api/weather/model/status") |
|
|
async def get_weather_model_status(): |
|
|
"""Get weather prediction model status and training info.""" |
|
|
from pathlib import Path |
|
|
|
|
|
models_dir = Path(__file__).parent / "models" / "weather-prediction" / "artifacts" / "models" |
|
|
predictions_dir = Path(__file__).parent / "models" / "weather-prediction" / "output" / "predictions" |
|
|
|
|
|
model_files = list(models_dir.glob("lstm_*.h5")) if models_dir.exists() else [] |
|
|
prediction_files = list(predictions_dir.glob("predictions_*.json")) if predictions_dir.exists() else [] |
|
|
|
|
|
latest_prediction = None |
|
|
if prediction_files: |
|
|
latest = max(prediction_files, key=lambda p: p.stat().st_mtime) |
|
|
latest_prediction = { |
|
|
"file": latest.name, |
|
|
"modified": datetime.fromtimestamp(latest.stat().st_mtime).isoformat() |
|
|
} |
|
|
|
|
|
return { |
|
|
"status": "available" if model_files else "not_trained", |
|
|
"models_trained": len(model_files), |
|
|
"trained_stations": [f.stem.replace("lstm_", "").upper() for f in model_files], |
|
|
"latest_prediction": latest_prediction, |
|
|
"predictions_available": len(prediction_files) |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Currency Prediction API
# ---------------------------------------------------------------------------

_currency_predictor = None


def get_currency_predictor():
    """Lazy-load the currency predictor."""
    global _currency_predictor
    if _currency_predictor is None:
        try:
            from pathlib import Path
            currency_path = Path(__file__).parent / "models" / "currency-volatility-prediction" / "src"
            # Avoid stacking duplicate entries on repeated initialization attempts.
            if str(currency_path) not in sys.path:
                sys.path.insert(0, str(currency_path))
            from components.predictor import CurrencyPredictor
            _currency_predictor = CurrencyPredictor()
            logger.info("[CurrencyAPI] Currency predictor initialized")
        except Exception as e:
            logger.warning(f"[CurrencyAPI] Failed to initialize predictor: {e}")
            _currency_predictor = None
    return _currency_predictor


@app.get("/api/currency/prediction") |
|
|
async def get_currency_prediction(): |
|
|
""" |
|
|
Get USD/LKR currency prediction for next day. |
|
|
|
|
|
Returns: |
|
|
- Current rate |
|
|
- Predicted rate |
|
|
- Expected change percentage |
|
|
- Direction (strengthening/weakening) |
|
|
- Volatility classification |
|
|
""" |
|
|
predictor = get_currency_predictor() |
|
|
|
|
|
if predictor is None: |
|
|
|
|
|
import numpy as np |
|
|
current_rate = 298.0 |
|
|
np.random.seed(int(datetime.now().timestamp()) % 2**31) |
|
|
change_pct = np.random.normal(0.05, 0.3) |
|
|
predicted_rate = current_rate * (1 + change_pct / 100) |
|
|
|
|
|
return { |
|
|
"status": "success", |
|
|
"prediction": { |
|
|
"prediction_date": (datetime.now() + timedelta(days=1)).strftime("%Y-%m-%d"), |
|
|
"generated_at": datetime.now().isoformat(), |
|
|
"model_version": "fallback", |
|
|
"is_fallback": True, |
|
|
"current_rate": round(current_rate, 2), |
|
|
"predicted_rate": round(predicted_rate, 2), |
|
|
"expected_change": round(predicted_rate - current_rate, 2), |
|
|
"expected_change_pct": round(change_pct, 3), |
|
|
"direction": "strengthening" if change_pct < 0 else "weakening", |
|
|
"direction_emoji": "📈" if change_pct < 0 else "📉", |
|
|
"volatility_class": "low", |
|
|
"note": "Using fallback - model initializing" |
|
|
} |
|
|
} |
|
|
|
|
|
try: |
|
|
|
|
|
prediction = predictor.get_latest_prediction() |
|
|
|
|
|
if prediction is None: |
|
|
|
|
|
logger.info("[CurrencyAPI] No prediction found, generating fallback...") |
|
|
prediction = predictor.generate_fallback_prediction() |
|
|
predictor.save_prediction(prediction) |
|
|
|
|
|
return { |
|
|
"status": "success", |
|
|
"prediction": prediction |
|
|
} |
|
|
except Exception as e: |
|
|
logger.error(f"[CurrencyAPI] Error: {e}") |
|
|
return {"status": "error", "message": str(e)} |
|
|
|
|
|
|
|
|
@app.get("/api/currency/history") |
|
|
async def get_currency_history(days: int = 30): |
|
|
"""Get historical USD/LKR rates.""" |
|
|
from pathlib import Path |
|
|
import pandas as pd |
|
|
|
|
|
try: |
|
|
data_dir = Path(__file__).parent / "models" / "currency-volatility-prediction" / "artifacts" / "data" |
|
|
csv_files = list(data_dir.glob("currency_data_*.csv")) if data_dir.exists() else [] |
|
|
|
|
|
if not csv_files: |
|
|
return {"status": "no_data", "message": "No currency data available"} |
|
|
|
|
|
latest = max(csv_files, key=lambda p: p.stat().st_mtime) |
|
|
df = pd.read_csv(latest, parse_dates=["date"]) |
|
|
|
|
|
|
|
|
df = df.tail(days) |
|
|
|
|
|
history = [] |
|
|
for _, row in df.iterrows(): |
|
|
history.append({ |
|
|
"date": row["date"].strftime("%Y-%m-%d") if hasattr(row["date"], "strftime") else str(row["date"]), |
|
|
"close": round(row["close"], 2), |
|
|
"high": round(row.get("high", row["close"]), 2), |
|
|
"low": round(row.get("low", row["close"]), 2), |
|
|
"daily_return_pct": round(row.get("daily_return", 0) * 100, 3) |
|
|
}) |
|
|
|
|
|
return { |
|
|
"status": "success", |
|
|
"days": len(history), |
|
|
"history": history |
|
|
} |
|
|
except Exception as e: |
|
|
return {"status": "error", "message": str(e)} |
|
|
|
|
|
|
|
|
@app.get("/api/currency/model/status") |
|
|
async def get_currency_model_status(): |
|
|
"""Get currency prediction model status.""" |
|
|
from pathlib import Path |
|
|
|
|
|
models_dir = Path(__file__).parent / "models" / "currency-volatility-prediction" / "artifacts" / "models" |
|
|
predictions_dir = Path(__file__).parent / "models" / "currency-volatility-prediction" / "output" / "predictions" |
|
|
|
|
|
model_exists = (models_dir / "gru_usd_lkr.h5").exists() if models_dir.exists() else False |
|
|
prediction_files = list(predictions_dir.glob("currency_prediction_*.json")) if predictions_dir.exists() else [] |
|
|
|
|
|
latest_prediction = None |
|
|
if prediction_files: |
|
|
latest = max(prediction_files, key=lambda p: p.stat().st_mtime) |
|
|
latest_prediction = { |
|
|
"file": latest.name, |
|
|
"modified": datetime.fromtimestamp(latest.stat().st_mtime).isoformat() |
|
|
} |
|
|
|
|
|
return { |
|
|
"status": "available" if model_exists else "not_trained", |
|
|
"model_type": "GRU", |
|
|
"target": "USD/LKR", |
|
|
"latest_prediction": latest_prediction, |
|
|
"predictions_available": len(prediction_files) |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Stock Prediction API
# ---------------------------------------------------------------------------

_stock_predictor = None


def get_stock_predictor():
    """Lazy-load the stock predictor."""
    global _stock_predictor
    if _stock_predictor is None:
        try:
            from pathlib import Path
            stock_path = Path(__file__).parent / "models" / "stock-price-prediction" / "src"
            # Avoid stacking duplicate entries on repeated initialization attempts.
            if str(stock_path) not in sys.path:
                sys.path.insert(0, str(stock_path))
            from components.predictor import StockPredictor
            _stock_predictor = StockPredictor()
            logger.info("[StockAPI] Stock predictor initialized")
        except Exception as e:
            logger.warning(f"[StockAPI] Failed to initialize predictor: {e}")
            _stock_predictor = None
    return _stock_predictor


@app.get("/api/stocks/predictions") |
|
|
async def get_stock_predictions(): |
|
|
""" |
|
|
Get stock price predictions for all configured stocks. |
|
|
|
|
|
Returns predictions for 10 popular stocks with: |
|
|
- Current price |
|
|
- Predicted next-day price |
|
|
- Expected change percentage |
|
|
- Trend classification (bullish/bearish/neutral) |
|
|
- Model architecture used |
|
|
""" |
|
|
predictor = get_stock_predictor() |
|
|
|
|
|
if predictor is None: |
|
|
|
|
|
try: |
|
|
import sys |
|
|
from pathlib import Path |
|
|
stock_path = Path(__file__).parent / "models" / "stock-price-prediction" / "src" |
|
|
sys.path.insert(0, str(stock_path)) |
|
|
from constants.training_pipeline import STOCKS_TO_TRAIN |
|
|
|
|
|
from datetime import datetime |
|
|
predictions = { |
|
|
"prediction_date": (datetime.now()).strftime("%Y-%m-%d"), |
|
|
"generated_at": datetime.now().isoformat(), |
|
|
"stocks": {}, |
|
|
"summary": {"total_stocks": len(STOCKS_TO_TRAIN), "bullish": 0, "bearish": 0, "neutral": 0} |
|
|
} |
|
|
|
|
|
import numpy as np |
|
|
for code, info in STOCKS_TO_TRAIN.items(): |
|
|
np.random.seed(hash(code) % 2**31) |
|
|
change_pct = np.random.normal(0.1, 1.0) |
|
|
trend = "bullish" if change_pct > 0.5 else "bearish" if change_pct < -0.5 else "neutral" |
|
|
predictions["summary"][trend] = predictions["summary"].get(trend, 0) + 1 |
|
|
predictions["stocks"][code] = { |
|
|
"symbol": code, |
|
|
"name": info.get("name", code), |
|
|
"sector": info.get("sector", "Unknown"), |
|
|
"current_price": 100.0, |
|
|
"predicted_price": 100.0 * (1 + change_pct / 100), |
|
|
"expected_change_pct": round(change_pct, 3), |
|
|
"trend": trend, |
|
|
"trend_emoji": "📈" if trend == "bullish" else "📉" if trend == "bearish" else "➡️", |
|
|
"confidence": round(np.random.uniform(0.65, 0.85), 2), |
|
|
"is_fallback": True |
|
|
} |
|
|
|
|
|
return {"status": "success", "predictions": predictions} |
|
|
except Exception as e: |
|
|
return {"status": "unavailable", "message": f"Stock prediction model not loaded: {e}"} |
|
|
|
|
|
try: |
|
|
|
|
|
predictions = predictor.get_latest_predictions() |
|
|
|
|
|
if predictions is None: |
|
|
|
|
|
logger.info("[StockAPI] No predictions found, generating fallback...") |
|
|
predictions = predictor.predict_all_stocks() |
|
|
predictions = { |
|
|
"prediction_date": (datetime.now()).strftime("%Y-%m-%d"), |
|
|
"generated_at": datetime.now().isoformat(), |
|
|
"stocks": predictions, |
|
|
"summary": {"total_stocks": len(predictions)} |
|
|
} |
|
|
|
|
|
return { |
|
|
"status": "success", |
|
|
"predictions": predictions |
|
|
} |
|
|
except Exception as e: |
|
|
logger.error(f"[StockAPI] Error: {e}") |
|
|
return {"status": "error", "message": str(e)} |
|
|
|
|
|
|
|
|
@app.get("/api/stocks/predictions/{symbol}") |
|
|
async def get_stock_prediction_by_symbol(symbol: str): |
|
|
"""Get prediction for a specific stock symbol.""" |
|
|
predictor = get_stock_predictor() |
|
|
|
|
|
if predictor is None: |
|
|
return {"status": "unavailable", "message": "Stock prediction model not loaded"} |
|
|
|
|
|
try: |
|
|
predictions = predictor.get_latest_predictions() |
|
|
|
|
|
if predictions and symbol.upper() in predictions.get("stocks", {}): |
|
|
return { |
|
|
"status": "success", |
|
|
"prediction": predictions["stocks"][symbol.upper()] |
|
|
} |
|
|
else: |
|
|
|
|
|
return { |
|
|
"status": "success", |
|
|
"prediction": predictor._generate_fallback_prediction(symbol.upper()) |
|
|
} |
|
|
except Exception as e: |
|
|
return {"status": "error", "message": str(e)} |
|
|
|
|
|
|
|
|
@app.get("/api/stocks/model/status") |
|
|
async def get_stock_model_status(): |
|
|
"""Get stock prediction model status for all stocks.""" |
|
|
from pathlib import Path |
|
|
import json |
|
|
|
|
|
models_dir = Path(__file__).parent / "models" / "stock-price-prediction" / "artifacts" / "models" |
|
|
predictions_dir = Path(__file__).parent / "models" / "stock-price-prediction" / "output" / "predictions" |
|
|
|
|
|
model_files = list(models_dir.glob("*_model.h5")) if models_dir.exists() else [] |
|
|
prediction_files = list(predictions_dir.glob("stock_predictions_*.json")) if predictions_dir.exists() else [] |
|
|
|
|
|
|
|
|
summary_path = models_dir / "training_summary.json" if models_dir.exists() else None |
|
|
training_summary = None |
|
|
if summary_path and summary_path.exists(): |
|
|
with open(summary_path) as f: |
|
|
training_summary = json.load(f) |
|
|
|
|
|
latest_prediction = None |
|
|
if prediction_files: |
|
|
latest = max(prediction_files, key=lambda p: p.stat().st_mtime) |
|
|
latest_prediction = { |
|
|
"file": latest.name, |
|
|
"modified": datetime.fromtimestamp(latest.stat().st_mtime).isoformat() |
|
|
} |
|
|
|
|
|
return { |
|
|
"status": "available" if model_files else "not_trained", |
|
|
"models_trained": len(model_files), |
|
|
"trained_stocks": [f.stem.replace("_model", "").upper() for f in model_files], |
|
|
"training_summary": training_summary, |
|
|
"latest_prediction": latest_prediction, |
|
|
"predictions_available": len(prediction_files) |
|
|
} |
|
|
|
|
|
|
|
|
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await manager.connect(websocket)

    try:
        # Push the current feed state immediately so a new client does not
        # have to wait for the next broadcast cycle.
        try:
            await websocket.send_text(json.dumps(current_state, default=str))
        except Exception as e:
            logger.debug(f"[WS] Initial send failed: {e}")
            return  # the finally block below handles the disconnect

        # Receive loop: the only inbound messages acted on are heartbeat
        # pongs; everything else is ignored.
        while True:
            try:
                txt = await websocket.receive_text()
            except WebSocketDisconnect:
                logger.info("[WS] Client disconnected")
                break
            except Exception as e:
                logger.debug(f"[WS] Receive error: {e}")
                break

            try:
                payload = json.loads(txt)
                if isinstance(payload, dict) and payload.get("type") == "pong":
                    # Record liveness so the heartbeat loop does not reap
                    # this connection during long scraping cycles.
                    async with manager._lock:
                        meta = manager.active_connections.get(websocket)
                        if meta is not None:
                            meta['last_pong'] = utc_now()
                            meta['misses'] = 0
                    continue
            except json.JSONDecodeError:
                continue

    finally:
        await manager.disconnect(websocket)
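

# Hedged client sketch (not part of the server): a minimal asyncio client that
# keeps the connection alive by answering heartbeat pings with {"type": "pong"}.
# The "websockets" package and the ping message shape ({"type": "ping"}) are
# assumptions for this example; run it with asyncio.run(_example_ws_client()).
async def _example_ws_client(url: str = "ws://localhost:8000/ws") -> None:
    import websockets  # assumed installed in the client environment

    async with websockets.connect(url) as ws:
        async for raw in ws:
            msg = json.loads(raw)
            if isinstance(msg, dict) and msg.get("type") == "ping":
                await ws.send(json.dumps({"type": "pong"}))
            else:
                print("feed update received")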
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")