Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import logging | |
| from dataclasses import asdict, dataclass, replace | |
| from datetime import datetime | |
| from typing import Any, Dict, List, Mapping, Tuple | |
| from coco_classes import canonicalize_coco_name, coco_class_catalog | |
| from mission_context import ( | |
| MissionClass, | |
| MissionContext, | |
| MissionPlan, | |
| MISSION_TYPE_OPTIONS, | |
| LOCATION_TYPE_OPTIONS, | |
| TIME_OF_DAY_OPTIONS, | |
| PRIORITY_LEVEL_OPTIONS, | |
| PipelineRecommendation, | |
| build_prompt_hints, | |
| ) | |
| from pipeline_registry import ( | |
| PIPELINE_SPECS, | |
| fallback_pipeline_for_context, | |
| filter_pipelines_for_context, | |
| get_pipeline_spec, | |
| ) | |
| from prompt import mission_planner_system_prompt, mission_planner_user_prompt | |
| from utils.openai_client import get_openai_client | |
DEFAULT_OPENAI_MODEL = "gpt-4o-mini"  # Chat model used when MissionReasoner gets no model_name.


class MissionReasoner:
    """Turn a free-text mission into a ``MissionPlan`` with the help of an LLM.

    The reasoner narrows the pipeline registry to pipelines compatible with the
    mission context, asks an OpenAI chat model (JSON mode) for relevant COCO
    classes, context refinements and a pipeline choice, and then validates each
    part of the answer against local catalogs before building the plan.
    """

    def __init__(
        self,
        *,
        model_name: str = DEFAULT_OPENAI_MODEL,
        top_k: int = 10,
    ) -> None:
        """Configure the reasoner.

        Args:
            model_name: Chat-completions model identifier passed to the OpenAI client.
            top_k: Maximum number of relevant classes kept in a returned plan.

        Raises:
            ValueError: If ``top_k`` is smaller than 1. A zero value would always
                yield an empty class list and a negative one would slice from the
                end of the list, silently dropping classes.
        """
        if top_k < 1:
            raise ValueError(f"top_k must be a positive integer, got {top_k!r}.")
        self._model_name = model_name
        self._top_k = top_k
        self._coco_catalog = coco_class_catalog()

    def plan(
        self,
        mission: str,
        *,
        context: MissionContext,
        cues: Mapping[str, Any] | None = None,
    ) -> MissionPlan:
        """Build a mission plan for ``mission`` under ``context``.

        Args:
            mission: Free-text mission description; surrounding whitespace is ignored.
            context: Mission context used to filter compatible pipelines.
            cues: Optional derived hints. They are forwarded to
                ``_candidate_pipelines`` but are not sent to the LLM (see note below).

        Returns:
            A ``MissionPlan`` with at most ``top_k`` relevant classes, the context
            merged with any valid LLM refinements, and a pipeline recommendation
            (``None`` if no pipeline could be validated).

        Raises:
            ValueError: If ``mission`` is empty or whitespace only.
            RuntimeError: If the LLM response yields no usable COCO classes.
        """
        mission = (mission or "").strip()
        if not mission:
            raise ValueError("Mission prompt cannot be empty.")

        available_pipelines = self._candidate_pipelines(mission, context, cues)
        candidate_ids = [spec["id"] for spec in available_pipelines] or [PIPELINE_SPECS[0]["id"]]
        # With exactly one compatible pipeline there is nothing for the LLM to choose.
        lock_pipeline_id = candidate_ids[0] if len(candidate_ids) == 1 else None

        # NOTE(review): cues are explicitly withheld from the LLM prompt even though
        # callers compute and pass them; confirm this is intentional.
        response_payload = self._query_llm(
            mission,
            context=context,
            cues=None,
            pipeline_ids=candidate_ids,
        )
        relevant = self._parse_plan(response_payload, fallback_mission=mission)
        enriched_context = self._merge_context(context, response_payload.get("context"))

        if lock_pipeline_id:
            pipeline_rec = PipelineRecommendation(
                primary_id=lock_pipeline_id,
                primary_reason="Only pipeline compatible with mission context.",
            )
        else:
            pipeline_rec = self._parse_pipeline_recommendation(
                response_payload.get("pipelines") or response_payload.get("pipeline"),
                available_pipelines,
                context,
            )

        return MissionPlan(
            # Same fallback rule as _parse_plan: an empty/null LLM "mission" keeps the input.
            mission=response_payload.get("mission") or mission,
            relevant_classes=relevant[: self._top_k],
            context=enriched_context,
            pipeline=pipeline_rec,
        )

    @staticmethod
    def _format_model_bindings(models: List[Dict[str, object]] | None) -> str:
        """Render model binding dicts as ``"label[ (optional)], ..."`` or ``"none"``."""
        if not models:
            return "none"
        labels = []
        for entry in models:
            model_id = entry.get("model_id") or entry.get("name") or "unknown"
            label = entry.get("label") or model_id
            suffix = " (optional)" if entry.get("optional") else ""
            labels.append(f"{label}{suffix}")
        return ", ".join(labels)

    def _render_pipeline_catalog(self, specs: List[Dict[str, object]]) -> str:
        """Render pipeline specs as a plain-text catalog, one section per spec.

        Each section lists modalities, locations, time of day, the availability
        reason and the Hugging Face detection/segmentation/tracking bindings.
        Missing *or* ``None`` list fields render as a placeholder.

        Returns:
            Sections separated by blank lines, or a fixed message if ``specs`` is empty.
        """
        if not specs:
            return "No compatible pipelines available."
        sections: List[str] = []
        for spec in specs:
            reason = spec.get("availability_reason") or "Compatible with mission context."
            hf_bindings = spec.get("huggingface") or {}
            detection_models = self._format_model_bindings(hf_bindings.get("detection"))
            segmentation_models = self._format_model_bindings(hf_bindings.get("segmentation"))
            tracking_models = self._format_model_bindings(hf_bindings.get("tracking"))
            hf_notes = hf_bindings.get("notes") or "n/a"
            # ``or ()`` also covers keys that are present but explicitly None.
            modalities = ", ".join(spec.get("modalities") or ()) or "unspecified"
            locations = ", ".join(spec.get("location_types") or ()) or "any"
            times_of_day = ", ".join(spec.get("time_of_day") or ()) or "any"
            sections.append(
                "\n".join(
                    [
                        f"{spec['id']} pipeline",
                        f" Modalities: {modalities}",
                        f" Locations: {locations}",
                        f" Time of day: {times_of_day}",
                        f" Availability: {reason}",
                        f" HF detection: {detection_models}",
                        f" HF segmentation: {segmentation_models}",
                        f" Tracking: {tracking_models}",
                        f" Notes: {hf_notes}",
                    ]
                )
            )
        return "\n\n".join(sections)

    def _candidate_pipelines(
        self,
        mission: str,
        context: MissionContext,
        cues: Mapping[str, Any] | None,
    ) -> List[Dict[str, object]]:
        """Return pipeline specs the LLM may choose from for ``context``.

        Uses ``filter_pipelines_for_context``. If nothing matches, a single
        fallback spec is returned, annotated with an ``availability_reason``.
        If there is no fallback either, copies of every registered spec are returned.
        ``mission`` and ``cues`` are accepted but not currently used.
        """
        filtered = filter_pipelines_for_context(context)
        if filtered:
            return filtered
        fallback_spec = fallback_pipeline_for_context(context, [])
        if fallback_spec is None:
            logging.error("No fallback pipeline available; mission context=%s", context)
            return [dict(spec) for spec in PIPELINE_SPECS]
        logging.warning(
            "No compatible pipelines for context %s; selecting fallback %s.",
            context,
            fallback_spec["id"],
        )
        # Copy so the annotation below does not mutate the shared registry entry.
        fallback_copy = dict(fallback_spec)
        fallback_copy["availability_reason"] = (
            "Fallback engaged because no specialized pipeline matched this mission context."
        )
        return [fallback_copy]

    def _query_llm(
        self,
        mission: str,
        *,
        context: MissionContext,
        cues: Mapping[str, Any] | None = None,
        pipeline_ids: List[str] | None,
    ) -> Dict[str, object]:
        """Ask the chat model for a mission plan and return the parsed JSON object.

        The request uses ``response_format={"type": "json_object"}`` and
        temperature 0.1. If the reply is not valid JSON, or is JSON but not an
        object, ``{"mission": mission, "classes": []}`` is returned. That payload
        makes ``_parse_plan`` raise rather than fabricate classes.
        """
        client = get_openai_client()
        system_prompt = mission_planner_system_prompt()
        context_payload = context.to_prompt_payload()
        user_prompt = mission_planner_user_prompt(
            mission,
            self._top_k,
            context=context_payload,
            cues=cues,
            pipeline_candidates=pipeline_ids,
            coco_catalog=self._coco_catalog,
        )
        completion = client.chat.completions.create(
            model=self._model_name,
            temperature=0.1,
            response_format={"type": "json_object"},
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )
        content = completion.choices[0].message.content or "{}"
        try:
            payload = json.loads(content)
        except json.JSONDecodeError:
            logging.exception("LLM returned non-JSON content: %s", content)
            return {"mission": mission, "classes": []}
        if not isinstance(payload, dict):
            # Valid JSON but not an object (e.g. a list): every caller uses ``.get``.
            logging.error("LLM returned non-object JSON: %s", content)
            return {"mission": mission, "classes": []}
        return payload

    def _parse_plan(self, payload: Dict[str, object], fallback_mission: str) -> List[MissionClass]:
        """Extract de-duplicated, COCO-canonical classes from an LLM payload.

        Classes are read from the first non-empty key among ``entities``,
        ``classes`` and ``relevant_classes``. Entries that are not dicts, have no
        name, or do not canonicalize to a COCO class are skipped. Duplicates
        after canonicalization keep their first occurrence. Scores are clamped
        to ``[0, 1]``. Missing, non-numeric or NaN scores default to 0.5.

        Raises:
            RuntimeError: If no usable class remains after parsing.
        """
        import math

        entries = (
            payload.get("entities")
            or payload.get("classes")
            or payload.get("relevant_classes")
            or []
        )
        mission = payload.get("mission") or fallback_mission
        parsed: List[MissionClass] = []
        seen = set()
        for entry in entries:
            if not isinstance(entry, dict):
                continue
            name = str(entry.get("name") or "").strip()
            if not name:
                continue
            canonical_name = canonicalize_coco_name(name)
            if not canonical_name:
                logging.warning("Skipping non-COCO entity '%s'.", name)
                continue
            if canonical_name in seen:
                continue
            seen.add(canonical_name)
            try:
                score = float(entry.get("score"))
            except (TypeError, ValueError, OverflowError):
                score = 0.5
            if math.isnan(score):
                # The min/max clamp below would otherwise map NaN to 1.0 (top relevance).
                score = 0.5
            rationale = str(entry.get("rationale") or f"Track '{name}' for mission '{mission}'.")
            parsed.append(
                MissionClass(
                    name=canonical_name,
                    score=max(0.0, min(1.0, score)),
                    rationale=rationale,
                )
            )
        if not parsed:
            raise RuntimeError("LLM returned no semantic entities; aborting instead of fabricating outputs.")
        return parsed

    def _merge_context(
        self,
        base_context: MissionContext,
        context_payload: Dict[str, object] | None,
    ) -> MissionContext:
        """Overlay valid LLM context refinements onto ``base_context``.

        Only ``mission_type``, ``location_type``, ``time_of_day`` and
        ``priority_level`` are considered. Each value is lower-cased and kept
        only if it is one of that field's allowed options. Returns
        ``base_context`` itself when nothing valid is supplied, otherwise a
        ``dataclasses.replace`` copy.
        """
        payload = context_payload or {}
        if not isinstance(payload, dict):
            return base_context

        def _coerce_choice(value: object | None, allowed: Tuple[str, ...]) -> str | None:
            if value is None:
                return None
            candidate = str(value).strip().lower()
            return candidate if candidate in allowed else None

        allowed_by_field = {
            "mission_type": MISSION_TYPE_OPTIONS,
            "location_type": LOCATION_TYPE_OPTIONS,
            "time_of_day": TIME_OF_DAY_OPTIONS,
            "priority_level": PRIORITY_LEVEL_OPTIONS,
        }
        updates: Dict[str, Any] = {}
        for field_name, allowed in allowed_by_field.items():
            choice = _coerce_choice(payload.get(field_name), allowed)
            if choice:
                updates[field_name] = choice
        if not updates:
            return base_context
        return replace(base_context, **updates)

    def _parse_pipeline_recommendation(
        self,
        payload: object,
        available_specs: List[Dict[str, object]],
        context: MissionContext,
    ) -> PipelineRecommendation | None:
        """Parse the LLM's pipeline choice and validate it against ``available_specs``.

        Accepted shapes:

        * a bare pipeline id string;
        * a flat object with ``id`` / ``pipeline_id`` / ``pipeline`` and an optional ``reason``;
        * ``{"primary": {...}, "fallback": {...}}`` using the same inner keys.

        Any other shape is treated as "no recommendation" and resolved by
        ``_validate_pipeline_selection``.
        """
        if isinstance(payload, str):
            # The model sometimes names the pipeline directly instead of wrapping it.
            pipeline_id = payload.strip()
            candidate = PipelineRecommendation(primary_id=pipeline_id or None, primary_reason=None)
            return self._validate_pipeline_selection(candidate, available_specs, context)
        if not isinstance(payload, dict):
            return self._validate_pipeline_selection(None, available_specs, context)
        if "id" in payload or "pipeline_id" in payload or "pipeline" in payload:
            pipeline_id_raw = payload.get("id") or payload.get("pipeline_id") or payload.get("pipeline")
            pipeline_id = str(pipeline_id_raw or "").strip()
            reason = str(payload.get("reason") or "").strip() or None
            candidate = PipelineRecommendation(primary_id=pipeline_id or None, primary_reason=reason)
            return self._validate_pipeline_selection(candidate, available_specs, context)

        def _extract_entry(entry_key: str) -> tuple[str | None, str | None]:
            value = payload.get(entry_key)
            if not isinstance(value, dict):
                return None, None
            pipeline_id_raw = value.get("id") or value.get("pipeline_id") or value.get("pipeline")
            # ``or ""`` keeps a missing id from becoming the literal string "None".
            pipeline_id = str(pipeline_id_raw or "").strip()
            if not pipeline_id or not get_pipeline_spec(pipeline_id):
                return None, None
            reason = str(value.get("reason") or "").strip() or None
            return pipeline_id, reason

        primary_id, primary_reason = _extract_entry("primary")
        fallback_id, fallback_reason = _extract_entry("fallback")
        rec = PipelineRecommendation(
            primary_id=primary_id,
            primary_reason=primary_reason,
            fallback_id=fallback_id,
            fallback_reason=fallback_reason,
        )
        return self._validate_pipeline_selection(rec, available_specs, context)

    def _validate_pipeline_selection(
        self,
        candidate: PipelineRecommendation | None,
        available_specs: List[Dict[str, object]],
        context: MissionContext,
    ) -> PipelineRecommendation | None:
        """Constrain ``candidate`` to pipelines in ``available_specs``.

        * If the primary choice is missing or unavailable, the pipeline from
          ``fallback_pipeline_for_context`` is used with a generic reason.
        * A secondary (fallback) pipeline is kept only when
          ``context.priority_level`` is ``"elevated"`` or ``"high"``, the
          secondary is available, and it differs from the primary.

        Returns ``None`` when no pipeline is available at all.
        """
        if not available_specs:
            return None
        available_ids = {spec["id"] for spec in available_specs}

        def _normalize_reason(reason: str | None, default: str) -> str:
            text = (reason or "").strip()
            return text or default

        primary_id = candidate.primary_id if candidate and candidate.primary_id in available_ids else None
        if not primary_id:
            fallback_spec = fallback_pipeline_for_context(context, available_specs)
            if fallback_spec is None:
                logging.warning("No pipelines available even after fallback.")
                return None
            logging.warning(
                "Pipeline recommendation invalid or missing. Defaulting to %s.", fallback_spec["id"]
            )
            # The LLM's reason (if any) argued for a different, rejected pipeline,
            # so it must not be attributed to the auto-selected one.
            return PipelineRecommendation(
                primary_id=fallback_spec["id"],
                primary_reason="Auto-selected based on available sensors and context.",
                fallback_id=None,
                fallback_reason=None,
            )

        primary_reason = _normalize_reason(candidate.primary_reason if candidate else None, "LLM-selected.")
        fallback_allowed = context.priority_level in {"elevated", "high"}
        fallback_id = candidate.fallback_id if candidate else None
        fallback_reason = candidate.fallback_reason if candidate else None
        if not fallback_allowed or fallback_id not in available_ids or fallback_id == primary_id:
            if fallback_id:
                logging.info("Dropping fallback pipeline %s due to priority/context constraints.", fallback_id)
            fallback_id_valid = None
            fallback_reason_valid = None
        else:
            fallback_id_valid = fallback_id
            fallback_reason_valid = _normalize_reason(fallback_reason, "Fallback allowed due to priority level.")
        return PipelineRecommendation(
            primary_id=primary_id,
            primary_reason=primary_reason,
            fallback_id=fallback_id_valid,
            fallback_reason=fallback_reason_valid,
        )
_REASONER: MissionReasoner | None = None  # Shared reasoner, created on the first get_mission_plan call.


def get_mission_plan(
    mission: str,
    *,
    latitude: float | None = None,
    longitude: float | None = None,
    context_overrides: MissionContext | None = None,
) -> MissionPlan:
    """Plan ``mission`` using a module-wide ``MissionReasoner`` built on first use.

    Location hints are derived via ``build_prompt_hints`` and logged. If the
    effective context has no ``time_of_day``, one is inferred from the
    ``local_time`` hint before planning.

    Args:
        mission: Free-text mission description.
        latitude: Optional latitude of the mission area.
        longitude: Optional longitude of the mission area.
        context_overrides: Context to start from; a default ``MissionContext()``
            is used when it is not supplied (or is falsy).

    Returns:
        The ``MissionPlan`` produced by ``MissionReasoner.plan``.
    """
    global _REASONER
    if _REASONER is None:
        _REASONER = MissionReasoner()
    reasoner = _REASONER

    context = context_overrides if context_overrides else MissionContext()
    cues = build_prompt_hints(mission, latitude, longitude)

    if latitude is not None and longitude is not None:
        logging.info("Mission location coordinates: lat=%s, lon=%s", latitude, longitude)

    # Non-mapping cues behave like an empty mapping: nothing gets logged.
    hint_lookup = cues if isinstance(cues, Mapping) else {}
    cue_log_templates = (
        ("local_time", "Derived local mission time: %s"),
        ("timezone", "Derived local timezone: %s"),
        ("nearest_locality", "Reverse geocoded locality: %s"),
    )
    for cue_key, log_template in cue_log_templates:
        cue_value = hint_lookup.get(cue_key)
        if cue_value:
            logging.info(log_template, cue_value)

    inferred_time = _infer_time_of_day_from_cues(context, cues)
    if inferred_time and inferred_time != context.time_of_day:
        context = replace(context, time_of_day=inferred_time)
    return reasoner.plan(mission, context=context, cues=cues)
def _infer_time_of_day_from_cues(
    context: MissionContext,
    cues: Mapping[str, Any] | None,
    *,
    day_start_hour: int = 6,
    day_end_hour: int = 18,
) -> str | None:
    """Infer ``"day"`` or ``"night"`` from the ``local_time`` cue.

    An explicitly set ``context.time_of_day`` always wins and is returned as-is.
    Otherwise the ``local_time`` cue is interpreted. It may be an ISO-8601
    string (a trailing ``Z`` UTC designator is accepted on every Python
    version) or a ``datetime``. Its hour is then compared against the
    half-open window ``[day_start_hour, day_end_hour)``.

    Args:
        context: Mission context; only ``time_of_day`` is read.
        cues: Optional hint mapping, e.g. as produced by ``build_prompt_hints``.
        day_start_hour: First hour (inclusive) treated as daytime.
        day_end_hour: Hour (exclusive) at which night begins.

    Returns:
        The context's own ``time_of_day`` when set. Otherwise ``"day"`` or
        ``"night"``, or ``None`` when no parseable local time is available.
    """
    if context.time_of_day or not cues:
        return context.time_of_day
    local_time_raw = cues.get("local_time") if isinstance(cues, Mapping) else None
    if not local_time_raw:
        return None
    if isinstance(local_time_raw, datetime):
        local_dt = local_time_raw
    else:
        text = str(local_time_raw).strip()
        # datetime.fromisoformat only accepts a trailing "Z" from Python 3.11 onwards.
        if text.endswith(("Z", "z")):
            text = text[:-1] + "+00:00"
        try:
            local_dt = datetime.fromisoformat(text)
        except (ValueError, TypeError):
            return None
    return "day" if day_start_hour <= local_dt.hour < day_end_hour else "night"