# Demo-2025 / mission_planner.py
# Last commit 34b56b2 by zye0616: "Update: Add mission pipeline registry"
from __future__ import annotations

import json
import logging
import math
from dataclasses import asdict, dataclass, replace
from datetime import datetime
from typing import Any, Dict, List, Mapping, Tuple

from coco_classes import canonicalize_coco_name, coco_class_catalog
from mission_context import (
    MissionClass,
    MissionContext,
    MissionPlan,
    MISSION_TYPE_OPTIONS,
    LOCATION_TYPE_OPTIONS,
    TIME_OF_DAY_OPTIONS,
    PRIORITY_LEVEL_OPTIONS,
    PipelineRecommendation,
    build_prompt_hints,
)
from pipeline_registry import (
    PIPELINE_SPECS,
    fallback_pipeline_for_context,
    filter_pipelines_for_context,
    get_pipeline_spec,
)
from prompt import mission_planner_system_prompt, mission_planner_user_prompt
from utils.openai_client import get_openai_client
DEFAULT_OPENAI_MODEL = "gpt-4o-mini"


class MissionReasoner:
    """Turn a free-text mission prompt into a ``MissionPlan`` with an LLM.

    The reasoner narrows the pipeline registry to candidates compatible with the
    mission context. It then asks an OpenAI chat model (JSON mode) for relevant
    COCO classes, context refinements and a pipeline choice. Everything the model
    returns is validated against local catalogs before the plan is built.
    """

    def __init__(
        self,
        *,
        model_name: str = DEFAULT_OPENAI_MODEL,
        top_k: int = 10,
    ) -> None:
        """Configure the reasoner.

        Args:
            model_name: Chat-completions model used for planning.
            top_k: Maximum number of relevant classes kept in a plan. It is also
                forwarded to the user prompt.
        """
        self._model_name = model_name
        self._top_k = top_k
        # Loaded once and reused for every prompt this instance builds.
        self._coco_catalog = coco_class_catalog()

    def plan(
        self,
        mission: str,
        *,
        context: MissionContext,
        cues: Mapping[str, Any] | None = None,
    ) -> MissionPlan:
        """Build a mission plan for ``mission`` under ``context``.

        Args:
            mission: Free-text mission description. Surrounding whitespace is ignored.
            context: Mission context used to filter pipelines and seed the prompt.
            cues: Optional prompt hints forwarded to the user prompt.

        Returns:
            A ``MissionPlan`` containing:

            * at most ``top_k`` relevant classes;
            * the context merged with any valid LLM refinements;
            * a validated pipeline recommendation, which may be ``None``.

        Raises:
            ValueError: If ``mission`` is empty or whitespace-only.
            RuntimeError: If the LLM response yields no usable COCO classes.
        """
        mission = (mission or "").strip()
        if not mission:
            raise ValueError("Mission prompt cannot be empty.")
        available_pipelines = self._candidate_pipelines(mission, context, cues)
        candidate_ids = [spec["id"] for spec in available_pipelines] or [PIPELINE_SPECS[0]["id"]]
        # With a single compatible pipeline there is nothing for the LLM to choose.
        lock_pipeline_id = candidate_ids[0] if len(candidate_ids) == 1 else None
        response_payload = self._query_llm(
            mission,
            context=context,
            cues=cues,  # forward the caller's hints instead of discarding them
            pipeline_ids=candidate_ids,
        )
        relevant = self._parse_plan(response_payload, fallback_mission=mission)
        enriched_context = self._merge_context(context, response_payload.get("context"))
        if lock_pipeline_id:
            pipeline_rec = PipelineRecommendation(
                primary_id=lock_pipeline_id,
                primary_reason="Only pipeline compatible with mission context.",
            )
        else:
            pipeline_rec = self._parse_pipeline_recommendation(
                response_payload.get("pipelines") or response_payload.get("pipeline"),
                available_pipelines,
                context,
            )
        # Keep the LLM's restated mission only when it is a non-blank string;
        # this matches the fallback rule used by _parse_plan.
        plan_mission = response_payload.get("mission")
        if not isinstance(plan_mission, str) or not plan_mission.strip():
            plan_mission = mission
        return MissionPlan(
            mission=plan_mission,
            relevant_classes=relevant[: self._top_k],
            context=enriched_context,
            pipeline=pipeline_rec,
        )

    def _render_pipeline_catalog(self, specs: List[Dict[str, object]]) -> str:
        """Render pipeline specs as a human-readable catalog.

        Each spec becomes a multi-line section listing:

        * the pipeline id;
        * modalities, locations and time of day;
        * availability;
        * the Hugging Face detection, segmentation and tracking models;
        * notes.

        Sections are separated by a blank line.
        """
        if not specs:
            return "No compatible pipelines available."

        def _format_models(models: List[Dict[str, object]]) -> str:
            # Comma-separated model labels, marking optional ones; "none" if empty.
            if not models:
                return "none"
            labels = []
            for entry in models:
                model_id = entry.get("model_id") or entry.get("name") or "unknown"
                label = entry.get("label") or model_id
                suffix = " (optional)" if entry.get("optional") else ""
                labels.append(f"{label}{suffix}")
            return ", ".join(labels)

        sections: List[str] = []
        for spec in specs:
            reason = spec.get("availability_reason") or "Compatible with mission context."
            hf_bindings = spec.get("huggingface") or {}
            detection_models = _format_models(hf_bindings.get("detection", []))
            segmentation_models = _format_models(hf_bindings.get("segmentation", []))
            tracking_models = _format_models(hf_bindings.get("tracking", []))
            hf_notes = hf_bindings.get("notes") or ""
            sections.append(
                "\n".join(
                    [
                        f"{spec['id']} pipeline",
                        f" Modalities: {', '.join(spec.get('modalities', ())) or 'unspecified'}",
                        f" Locations: {', '.join(spec.get('location_types', ())) or 'any'}",
                        f" Time of day: {', '.join(spec.get('time_of_day', ())) or 'any'}",
                        f" Availability: {reason}",
                        f" HF detection: {detection_models}",
                        f" HF segmentation: {segmentation_models}",
                        f" Tracking: {tracking_models}",
                        f" Notes: {hf_notes or 'n/a'}",
                    ]
                )
            )
        return "\n\n".join(sections)

    def _candidate_pipelines(
        self,
        mission: str,
        context: MissionContext,
        cues: Mapping[str, Any] | None,
    ) -> List[Dict[str, object]]:
        """Return the pipeline specs the LLM may choose from for ``context``.

        The specs filtered by ``filter_pipelines_for_context`` are preferred. If
        none match, a single copy of ``fallback_pipeline_for_context``'s choice is
        returned, annotated with an ``availability_reason``. If there is no fallback
        either, copies of every registered spec are returned.

        ``mission`` and ``cues`` are accepted but not currently used.
        """
        filtered = filter_pipelines_for_context(context)
        if filtered:
            return filtered
        fallback_spec = fallback_pipeline_for_context(context, [])
        if fallback_spec is None:
            logging.error("No fallback pipeline available; mission context=%s", context)
            return [dict(spec) for spec in PIPELINE_SPECS]
        logging.warning(
            "No compatible pipelines for context %s; selecting fallback %s.",
            context,
            fallback_spec["id"],
        )
        # Copy so the annotation below never mutates the shared registry entry.
        fallback_copy = dict(fallback_spec)
        fallback_copy["availability_reason"] = (
            "Fallback engaged because no specialized pipeline matched this mission context."
        )
        return [fallback_copy]

    def _query_llm(
        self,
        mission: str,
        *,
        context: MissionContext,
        cues: Mapping[str, Any] | None = None,
        pipeline_ids: List[str] | None,
    ) -> Dict[str, object]:
        """Ask the chat model for a plan and return its JSON object response.

        The request uses JSON mode with ``temperature=0.1``. On unparsable output,
        or on JSON that is not an object, the response is logged and a stub of the
        form ``{"mission": mission, "classes": []}`` is returned. ``_parse_plan``
        rejects that stub.
        """
        client = get_openai_client()
        system_prompt = mission_planner_system_prompt()
        context_payload = context.to_prompt_payload()
        user_prompt = mission_planner_user_prompt(
            mission,
            self._top_k,
            context=context_payload,
            cues=cues,
            pipeline_candidates=pipeline_ids,
            coco_catalog=self._coco_catalog,
        )
        completion = client.chat.completions.create(
            model=self._model_name,
            temperature=0.1,
            response_format={"type": "json_object"},
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )
        content = completion.choices[0].message.content or "{}"
        try:
            payload = json.loads(content)
        except json.JSONDecodeError:
            logging.exception("LLM returned non-JSON content: %s", content)
            return {"mission": mission, "classes": []}
        if not isinstance(payload, dict):
            # Callers use .get() on the result, so only JSON objects are usable.
            logging.error("LLM returned JSON that is not an object: %s", content)
            return {"mission": mission, "classes": []}
        return payload

    def _parse_plan(self, payload: Dict[str, object], fallback_mission: str) -> List[MissionClass]:
        """Extract de-duplicated, COCO-canonical mission classes from the LLM payload.

        Entries come from the first non-empty key among ``entities``, ``classes``
        and ``relevant_classes``. An entry is skipped if it is not an object, has
        no name, or does not canonicalize to a COCO class. When a canonical name
        repeats, only the first occurrence is kept. Scores are clamped to [0, 1].
        Missing, unparsable, NaN or overflowing scores become 0.5.

        Raises:
            RuntimeError: If no usable entry remains.
        """
        entries = (
            payload.get("entities")
            or payload.get("classes")
            or payload.get("relevant_classes")
            or []
        )
        mission = payload.get("mission") or fallback_mission
        parsed: List[MissionClass] = []
        seen = set()
        for entry in entries:
            if not isinstance(entry, dict):
                continue
            name = str(entry.get("name") or "").strip()
            if not name:
                continue
            canonical_name = canonicalize_coco_name(name)
            if not canonical_name:
                logging.warning("Skipping non-COCO entity '%s'.", name)
                continue
            if canonical_name in seen:
                continue
            seen.add(canonical_name)
            try:
                # OverflowError: float() of a huge JSON integer.
                score = float(entry.get("score"))
            except (TypeError, ValueError, OverflowError):
                score = 0.5
            if math.isnan(score):
                # min/max clamping below would otherwise turn NaN into 1.0.
                score = 0.5
            rationale = str(entry.get("rationale") or f"Track '{name}' for mission '{mission}'.")
            parsed.append(
                MissionClass(
                    name=canonical_name,
                    score=max(0.0, min(1.0, score)),
                    rationale=rationale,
                )
            )
        if not parsed:
            raise RuntimeError("LLM returned no semantic entities; aborting instead of fabricating outputs.")
        return parsed

    def _merge_context(
        self,
        base_context: MissionContext,
        context_payload: Dict[str, object] | None,
    ) -> MissionContext:
        """Apply the LLM's context refinements to ``base_context``.

        A field is updated only when the payload value, stripped and lower-cased,
        is one of that field's allowed options. The fields are ``mission_type``,
        ``location_type``, ``time_of_day`` and ``priority_level``. A payload that
        is empty, non-dict, or contains no valid refinement returns
        ``base_context`` unchanged. Otherwise a new context is built with
        ``dataclasses.replace``.
        """
        if not isinstance(context_payload, dict) or not context_payload:
            return base_context
        payload = context_payload

        def _coerce_choice(value: object | None, allowed: Tuple[str, ...]) -> str | None:
            if value is None:
                return None
            candidate = str(value).strip().lower()
            return candidate if candidate in allowed else None

        updates: Dict[str, Any] = {}
        new_mission_type = _coerce_choice(payload.get("mission_type"), MISSION_TYPE_OPTIONS)
        new_location_type = _coerce_choice(payload.get("location_type"), LOCATION_TYPE_OPTIONS)
        new_time_of_day = _coerce_choice(payload.get("time_of_day"), TIME_OF_DAY_OPTIONS)
        new_priority = _coerce_choice(payload.get("priority_level"), PRIORITY_LEVEL_OPTIONS)
        if new_mission_type:
            updates["mission_type"] = new_mission_type
        if new_location_type:
            updates["location_type"] = new_location_type
        if new_time_of_day:
            updates["time_of_day"] = new_time_of_day
        if new_priority:
            updates["priority_level"] = new_priority
        if not updates:
            return base_context
        return replace(base_context, **updates)

    def _parse_pipeline_recommendation(
        self,
        payload: object,
        available_specs: List[Dict[str, object]],
        context: MissionContext,
    ) -> PipelineRecommendation | None:
        """Interpret the LLM's pipeline section, then validate it.

        Accepted shapes:

        * a bare pipeline id string, e.g. ``"rgb"``;
        * a flat object with ``id``/``pipeline_id``/``pipeline`` and an optional
          ``reason``;
        * an object with ``primary`` and/or ``fallback`` sub-objects in the flat
          shape above.

        Any other value falls through to the automatic selection in
        ``_validate_pipeline_selection``.
        """
        if isinstance(payload, str):
            candidate = PipelineRecommendation(primary_id=payload.strip() or None, primary_reason=None)
            return self._validate_pipeline_selection(candidate, available_specs, context)
        if not isinstance(payload, dict):
            return self._validate_pipeline_selection(None, available_specs, context)
        if "id" in payload or "pipeline_id" in payload or "pipeline" in payload:
            pipeline_id_raw = payload.get("id") or payload.get("pipeline_id") or payload.get("pipeline")
            pipeline_id = str(pipeline_id_raw or "").strip()
            reason = str(payload.get("reason") or "").strip() or None
            candidate = PipelineRecommendation(primary_id=pipeline_id or None, primary_reason=reason)
            return self._validate_pipeline_selection(candidate, available_specs, context)

        def _extract_entry(entry_key: str) -> tuple[str | None, str | None]:
            # (pipeline_id, reason) when the entry names a registered pipeline,
            # otherwise (None, None).
            value = payload.get(entry_key)
            if not isinstance(value, dict):
                return None, None
            pipeline_id_raw = value.get("id") or value.get("pipeline_id") or value.get("pipeline")
            # `or ""` stops str(None) == "None" from being treated as an id.
            pipeline_id = str(pipeline_id_raw or "").strip()
            if not pipeline_id or not get_pipeline_spec(pipeline_id):
                return None, None
            reason = str(value.get("reason") or "").strip() or None
            return pipeline_id, reason

        primary_id, primary_reason = _extract_entry("primary")
        fallback_id, fallback_reason = _extract_entry("fallback")
        rec = PipelineRecommendation(
            primary_id=primary_id,
            primary_reason=primary_reason,
            fallback_id=fallback_id,
            fallback_reason=fallback_reason,
        )
        return self._validate_pipeline_selection(rec, available_specs, context)

    def _validate_pipeline_selection(
        self,
        candidate: PipelineRecommendation | None,
        available_specs: List[Dict[str, object]],
        context: MissionContext,
    ) -> PipelineRecommendation | None:
        """Constrain a recommendation to the pipelines actually available.

        Behavior by case:

        * No available specs: return ``None``.
        * Candidate primary is missing or not available: auto-select with
          ``fallback_pipeline_for_context``. This returns ``None`` if that also
          yields nothing.
        * Secondary (fallback) pipeline: kept only when ``context.priority_level``
          is ``"elevated"`` or ``"high"`` and the fallback is available and differs
          from the primary.
        """
        if not available_specs:
            return None
        available_ids = {spec["id"] for spec in available_specs}

        def _normalize_reason(reason: str | None, default: str) -> str:
            text = (reason or "").strip()
            return text or default

        primary_id = candidate.primary_id if candidate and candidate.primary_id in available_ids else None
        if not primary_id:
            fallback_spec = fallback_pipeline_for_context(context, available_specs)
            if fallback_spec is None:
                logging.warning("No pipelines available even after fallback.")
                return None
            logging.warning(
                "Pipeline recommendation invalid or missing. Defaulting to %s.", fallback_spec["id"]
            )
            # Any LLM reason justified a different, rejected pipeline, so it must
            # not be attached to the auto-selected one.
            return PipelineRecommendation(
                primary_id=fallback_spec["id"],
                primary_reason="Auto-selected based on available sensors and context.",
                fallback_id=None,
                fallback_reason=None,
            )
        # A truthy primary_id implies candidate is not None.
        primary_reason = _normalize_reason(candidate.primary_reason, "LLM-selected.")
        fallback_allowed = context.priority_level in {"elevated", "high"}
        fallback_id = candidate.fallback_id
        if not fallback_allowed or fallback_id not in available_ids or fallback_id == primary_id:
            if fallback_id:
                logging.info("Dropping fallback pipeline %s due to priority/context constraints.", fallback_id)
            fallback_id_valid = None
            fallback_reason_valid = None
        else:
            fallback_id_valid = fallback_id
            fallback_reason_valid = _normalize_reason(
                candidate.fallback_reason, "Fallback allowed due to priority level."
            )
        return PipelineRecommendation(
            primary_id=primary_id,
            primary_reason=primary_reason,
            fallback_id=fallback_id_valid,
            fallback_reason=fallback_reason_valid,
        )
# Shared MissionReasoner, built on the first get_mission_plan() call and reused afterwards.
_REASONER: MissionReasoner | None = None


def get_mission_plan(
    mission: str,
    *,
    latitude: float | None = None,
    longitude: float | None = None,
    context_overrides: MissionContext | None = None,
) -> MissionPlan:
    """Plan ``mission`` with the module-level reasoner.

    Steps:

    1. Start from ``context_overrides``, or a default ``MissionContext()``.
    2. Derive prompt hints with ``build_prompt_hints(mission, latitude, longitude)``
       and log the coordinates and any local-time, timezone or locality hints.
    3. When the context carries no ``time_of_day``, fill it in from the
       ``local_time`` hint.
    4. Delegate to ``MissionReasoner.plan``.
    """
    global _REASONER
    if _REASONER is None:
        _REASONER = MissionReasoner()

    mission_context = context_overrides or MissionContext()
    hints = build_prompt_hints(mission, latitude, longitude)

    if latitude is not None and longitude is not None:
        logging.info("Mission location coordinates: lat=%s, lon=%s", latitude, longitude)

    # Non-mapping hint objects contribute nothing to the log output.
    hint_lookup: Mapping[str, Any] = hints if isinstance(hints, Mapping) else {}
    for hint_key, log_template in (
        ("local_time", "Derived local mission time: %s"),
        ("timezone", "Derived local timezone: %s"),
        ("nearest_locality", "Reverse geocoded locality: %s"),
    ):
        hint_value = hint_lookup.get(hint_key)
        if hint_value:
            logging.info(log_template, hint_value)

    derived_time = _infer_time_of_day_from_cues(mission_context, hints)
    if derived_time and derived_time != mission_context.time_of_day:
        mission_context = replace(mission_context, time_of_day=derived_time)

    return _REASONER.plan(mission, context=mission_context, cues=hints)
def _infer_time_of_day_from_cues(context: MissionContext, cues: Mapping[str, Any] | None) -> str | None:
    """Derive a coarse ``"day"``/``"night"`` label from the ``local_time`` cue.

    Resolution order:

    * An explicit ``context.time_of_day`` wins and is returned unchanged. This
      also applies when ``cues`` is empty or ``None``.
    * Otherwise ``cues["local_time"]`` is read. It may be a ``datetime`` or an
      ISO-8601 string. A trailing ``Z`` UTC designator is accepted even on
      Python versions whose ``fromisoformat`` rejects it.
    * Hours 06:00-17:59 map to ``"day"``; all other hours map to ``"night"``.
    * A missing or unparsable value yields ``None``.
    """
    if context.time_of_day or not cues:
        return context.time_of_day
    local_time_raw = cues.get("local_time") if isinstance(cues, Mapping) else None
    if not local_time_raw:
        return None
    if isinstance(local_time_raw, datetime):
        local_dt = local_time_raw
    else:
        text = str(local_time_raw).strip()
        # datetime.fromisoformat only accepts a "Z" suffix from Python 3.11 on.
        if text.endswith(("Z", "z")):
            text = text[:-1] + "+00:00"
        try:
            local_dt = datetime.fromisoformat(text)
        except (ValueError, TypeError):
            return None
    hour = local_dt.hour
    return "day" if 6 <= hour < 18 else "night"