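"""Mission planning against an OpenAI chat model.

Turns a free-text mission prompt into a MissionPlan: ranked COCO classes to
track, an enriched MissionContext, and a pipeline recommendation validated
against the pipeline registry.
"""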
from __future__ import annotations

import json
import logging
from dataclasses import replace
from datetime import datetime
from typing import Any, Dict, List, Mapping, Tuple

from coco_classes import canonicalize_coco_name, coco_class_catalog
from mission_context import (
    MissionClass,
    MissionContext,
    MissionPlan,
    MISSION_TYPE_OPTIONS,
    LOCATION_TYPE_OPTIONS,
    TIME_OF_DAY_OPTIONS,
    PRIORITY_LEVEL_OPTIONS,
    PipelineRecommendation,
    build_prompt_hints,
)
from pipeline_registry import (
    PIPELINE_SPECS,
    fallback_pipeline_for_context,
    filter_pipelines_for_context,
    get_pipeline_spec,
)
from prompt import mission_planner_system_prompt, mission_planner_user_prompt
from utils.openai_client import get_openai_client


DEFAULT_OPENAI_MODEL = "gpt-4o-mini"


class MissionReasoner:
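    """Plans missions by querying an LLM and validating everything it returns
    against the COCO catalog and the pipeline registry."""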
    def __init__(
        self,
        *,
        model_name: str = DEFAULT_OPENAI_MODEL,
        top_k: int = 10,
    ) -> None:
        self._model_name = model_name
        self._top_k = top_k
        self._coco_catalog = coco_class_catalog()

    def plan(
        self,
        mission: str,
        *,
        context: MissionContext,
        cues: Mapping[str, Any] | None = None,
    ) -> MissionPlan:
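        """Build a MissionPlan for ``mission``, constraining the LLM to the
        pipelines compatible with ``context`` and validating its response."""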
        mission = (mission or "").strip()
        if not mission:
            raise ValueError("Mission prompt cannot be empty.")
        available_pipelines = self._candidate_pipelines(mission, context, cues)
        candidate_ids = [spec["id"] for spec in available_pipelines] or [PIPELINE_SPECS[0]["id"]]
        lock_pipeline_id = candidate_ids[0] if len(candidate_ids) == 1 else None
        response_payload = self._query_llm(
            mission,
            context=context,
            cues=cues,
            pipeline_ids=candidate_ids,
        )
        relevant = self._parse_plan(response_payload, fallback_mission=mission)
        enriched_context = self._merge_context(context, response_payload.get("context"))
        if lock_pipeline_id:
            pipeline_rec = PipelineRecommendation(
                primary_id=lock_pipeline_id,
                primary_reason="Only pipeline compatible with mission context.",
            )
        else:
            pipeline_rec = self._parse_pipeline_recommendation(
                response_payload.get("pipelines") or response_payload.get("pipeline"),
                available_pipelines,
                context,
            )
        return MissionPlan(
            mission=response_payload.get("mission", mission),
            relevant_classes=relevant[: self._top_k],
            context=enriched_context,
            pipeline=pipeline_rec,
        )

    def _render_pipeline_catalog(self, specs: List[Dict[str, object]]) -> str:
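        """Render ``specs`` as a human-readable catalog: one section per
        pipeline covering modalities, locations, time-of-day constraints,
        availability, and Hugging Face model bindings."""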
        if not specs:
            return "No compatible pipelines available."
        sections: List[str] = []
        for spec in specs:
            reason = spec.get("availability_reason") or "Compatible with mission context."
            hf_bindings = spec.get("huggingface") or {}

            def _format_models(models: List[Dict[str, object]]) -> str:
                if not models:
                    return "none"
                labels = []
                for entry in models:
                    model_id = entry.get("model_id") or entry.get("name") or "unknown"
                    label = entry.get("label") or model_id
                    suffix = " (optional)" if entry.get("optional") else ""
                    labels.append(f"{label}{suffix}")
                return ", ".join(labels)

            detection_models = _format_models(hf_bindings.get("detection", []))
            segmentation_models = _format_models(hf_bindings.get("segmentation", []))
            tracking_models = _format_models(hf_bindings.get("tracking", []))
            hf_notes = hf_bindings.get("notes") or ""
            sections.append(
                "\n".join(
                    [
                        f"{spec['id']} pipeline",
                        f"  Modalities: {', '.join(spec.get('modalities', ())) or 'unspecified'}",
                        f"  Locations: {', '.join(spec.get('location_types', ())) or 'any'}",
                        f"  Time of day: {', '.join(spec.get('time_of_day', ())) or 'any'}",
                        f"  Availability: {reason}",
                        f"  HF detection: {detection_models}",
                        f"  HF segmentation: {segmentation_models}",
                        f"  Tracking: {tracking_models}",
                        f"  Notes: {hf_notes or 'n/a'}",
                    ]
                )
            )
        return "\n\n".join(sections)

    def _candidate_pipelines(
        self,
        mission: str,
        context: MissionContext,
        cues: Mapping[str, Any] | None,
    ) -> List[Dict[str, object]]:
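        """Return the pipelines compatible with ``context``; when nothing
        matches, return the single context fallback or, failing that, the
        full registry."""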
        filtered = filter_pipelines_for_context(context)
        if filtered:
            return filtered
        fallback_spec = fallback_pipeline_for_context(context, [])
        if fallback_spec is None:
            logging.error("No fallback pipeline available; mission context=%s", context)
            return [dict(spec) for spec in PIPELINE_SPECS]
        logging.warning(
            "No compatible pipelines for context %s; selecting fallback %s.",
            context,
            fallback_spec["id"],
        )
        fallback_copy = dict(fallback_spec)
        fallback_copy["availability_reason"] = (
            "Fallback engaged because no specialized pipeline matched this mission context."
        )
        return [fallback_copy]

    def _query_llm(
        self,
        mission: str,
        *,
        context: MissionContext,
        cues: Mapping[str, Any] | None = None,
        pipeline_ids: List[str] | None,
    ) -> Dict[str, object]:
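        """Send the planner prompts to the chat model in JSON mode and return
        the decoded payload; falls back to an empty class list on malformed
        JSON."""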
        client = get_openai_client()
        system_prompt = mission_planner_system_prompt()
        context_payload = context.to_prompt_payload()
        user_prompt = mission_planner_user_prompt(
            mission,
            self._top_k,
            context=context_payload,
            cues=cues,
            pipeline_candidates=pipeline_ids,
            coco_catalog=self._coco_catalog,
        )
        completion = client.chat.completions.create(
            model=self._model_name,
            temperature=0.1,
            response_format={"type": "json_object"},
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )
        content = completion.choices[0].message.content or "{}"
        try:
            return json.loads(content)
        except json.JSONDecodeError:
            logging.exception("LLM returned non-JSON content: %s", content)
            return {"mission": mission, "classes": []}

    def _parse_plan(self, payload: Dict[str, object], fallback_mission: str) -> List[MissionClass]:
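        """Extract MissionClass entries from the LLM payload, canonicalizing
        names against COCO, de-duplicating, and clamping scores to [0, 1]."""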
        entries = (
            payload.get("entities")
            or payload.get("classes")
            or payload.get("relevant_classes")
            or []
        )
        mission = payload.get("mission") or fallback_mission
        parsed: List[MissionClass] = []
        seen = set()
        for entry in entries:
            if not isinstance(entry, dict):
                continue
            name = str(entry.get("name") or "").strip()
            if not name:
                continue
            canonical_name = canonicalize_coco_name(name)
            if not canonical_name:
                logging.warning("Skipping non-COCO entity '%s'.", name)
                continue
            if canonical_name in seen:
                continue
            seen.add(canonical_name)
            score_raw = entry.get("score")
            try:
                score = float(score_raw)
            except (TypeError, ValueError):
                score = 0.5
            rationale = str(entry.get("rationale") or f"Track '{name}' for mission '{mission}'.")
            parsed.append(
                MissionClass(
                    name=canonical_name,
                    score=max(0.0, min(1.0, score)),
                    rationale=rationale,
                )
            )

        if not parsed:
            raise RuntimeError("LLM returned no semantic entities; aborting instead of fabricating outputs.")
        return parsed

    def _merge_context(
        self,
        base_context: MissionContext,
        context_payload: Dict[str, object] | None,
    ) -> MissionContext:
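        """Overlay LLM-suggested context fields onto ``base_context``, keeping
        only values that appear in the allowed option sets."""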
        payload = context_payload or {}
        if not isinstance(payload, dict):
            return base_context

        def _coerce_choice(value: object | None, allowed: Tuple[str, ...]) -> str | None:
            if value is None:
                return None
            candidate = str(value).strip().lower()
            return candidate if candidate in allowed else None

        updates: Dict[str, Any] = {}
        new_mission_type = _coerce_choice(payload.get("mission_type"), MISSION_TYPE_OPTIONS)
        new_location_type = _coerce_choice(payload.get("location_type"), LOCATION_TYPE_OPTIONS)
        new_time_of_day = _coerce_choice(payload.get("time_of_day"), TIME_OF_DAY_OPTIONS)
        new_priority = _coerce_choice(payload.get("priority_level"), PRIORITY_LEVEL_OPTIONS)

        if new_mission_type:
            updates["mission_type"] = new_mission_type
        if new_location_type:
            updates["location_type"] = new_location_type
        if new_time_of_day:
            updates["time_of_day"] = new_time_of_day
        if new_priority:
            updates["priority_level"] = new_priority

        if not updates:
            return base_context
        return replace(base_context, **updates)

    def _parse_pipeline_recommendation(
        self,
        payload: object,
        available_specs: List[Dict[str, object]],
        context: MissionContext,
    ) -> PipelineRecommendation | None:
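        """Parse the LLM's pipeline payload, accepting either a flat id/reason
        object or a primary/fallback pair, then validate the selection."""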
        if not isinstance(payload, dict):
            return self._validate_pipeline_selection(None, available_specs, context)

        if "id" in payload or "pipeline_id" in payload or "pipeline" in payload:
            pipeline_id_raw = payload.get("id") or payload.get("pipeline_id") or payload.get("pipeline")
            pipeline_id = str(pipeline_id_raw or "").strip()
            reason = str(payload.get("reason") or "").strip() or None
            candidate = PipelineRecommendation(primary_id=pipeline_id or None, primary_reason=reason)
            return self._validate_pipeline_selection(candidate, available_specs, context)

        def _extract_entry(entry_key: str) -> tuple[str | None, str | None]:
            value = payload.get(entry_key)
            if not isinstance(value, dict):
                return None, None
            pipeline_id_raw = value.get("id") or value.get("pipeline_id") or value.get("pipeline")
            pipeline_id = str(pipeline_id_raw or "").strip()
            if not pipeline_id:
                return None, None
            if not get_pipeline_spec(pipeline_id):
                return None, None
            reason = str(value.get("reason") or "").strip() or None
            return pipeline_id, reason

        primary_id, primary_reason = _extract_entry("primary")
        fallback_id, fallback_reason = _extract_entry("fallback")

        rec = PipelineRecommendation(
            primary_id=primary_id,
            primary_reason=primary_reason,
            fallback_id=fallback_id,
            fallback_reason=fallback_reason,
        )
        return self._validate_pipeline_selection(rec, available_specs, context)

    def _validate_pipeline_selection(
        self,
        candidate: PipelineRecommendation | None,
        available_specs: List[Dict[str, object]],
        context: MissionContext,
    ) -> PipelineRecommendation | None:
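        """Ensure the recommendation names an available pipeline, substituting
        the context fallback otherwise; drop the secondary fallback unless the
        priority level permits one."""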
        if not available_specs:
            return None
        available_ids = {spec["id"] for spec in available_specs}

        def _normalize_reason(reason: str | None, default: str) -> str:
            text = (reason or "").strip()
            return text or default

        primary_id = candidate.primary_id if candidate and candidate.primary_id in available_ids else None
        if not primary_id:
            fallback_spec = fallback_pipeline_for_context(context, available_specs)
            if fallback_spec is None:
                logging.warning("No pipelines available even after fallback.")
                return None
            logging.warning(
                "Pipeline recommendation invalid or missing. Defaulting to %s.", fallback_spec["id"]
            )
            return PipelineRecommendation(
                primary_id=fallback_spec["id"],
                primary_reason=_normalize_reason(
                    candidate.primary_reason if candidate else None,
                    "Auto-selected based on available sensors and context.",
                ),
                fallback_id=None,
                fallback_reason=None,
            )

        primary_reason = _normalize_reason(candidate.primary_reason if candidate else None, "LLM-selected.")

        fallback_allowed = context.priority_level in {"elevated", "high"}
        fallback_id = candidate.fallback_id if candidate else None
        fallback_reason = candidate.fallback_reason if candidate else None

        if not fallback_allowed or fallback_id not in available_ids or fallback_id == primary_id:
            if fallback_id:
                logging.info("Dropping fallback pipeline %s due to priority/context constraints.", fallback_id)
            fallback_id_valid = None
            fallback_reason_valid = None
        else:
            fallback_id_valid = fallback_id
            fallback_reason_valid = _normalize_reason(fallback_reason, "Fallback allowed due to priority level.")

        return PipelineRecommendation(
            primary_id=primary_id,
            primary_reason=primary_reason,
            fallback_id=fallback_id_valid,
            fallback_reason=fallback_reason_valid,
        )


_REASONER: MissionReasoner | None = None


def get_mission_plan(
    mission: str,
    *,
    latitude: float | None = None,
    longitude: float | None = None,
    context_overrides: MissionContext | None = None,
) -> MissionPlan:
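    """Module-level entry point: build cues from coordinates, infer time of
    day, and delegate to a lazily constructed MissionReasoner singleton."""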
    global _REASONER
    if _REASONER is None:
        _REASONER = MissionReasoner()
    context = context_overrides or MissionContext()
    cues = build_prompt_hints(mission, latitude, longitude)
    if latitude is not None and longitude is not None:
        logging.info("Mission location coordinates: lat=%s, lon=%s", latitude, longitude)
    local_time_hint = cues.get("local_time") if isinstance(cues, Mapping) else None
    if local_time_hint:
        logging.info("Derived local mission time: %s", local_time_hint)
    timezone_hint = cues.get("timezone") if isinstance(cues, Mapping) else None
    if timezone_hint:
        logging.info("Derived local timezone: %s", timezone_hint)
    locality_hint = cues.get("nearest_locality") if isinstance(cues, Mapping) else None
    if locality_hint:
        logging.info("Reverse geocoded locality: %s", locality_hint)
    inferred_time = _infer_time_of_day_from_cues(context, cues)
    if inferred_time and context.time_of_day != inferred_time:
        context = replace(context, time_of_day=inferred_time)
    return _REASONER.plan(mission, context=context, cues=cues)


def _infer_time_of_day_from_cues(context: MissionContext, cues: Mapping[str, Any] | None) -> str | None:
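    """Derive 'day' or 'night' from the cues' ISO local time when ``context``
    does not already specify a time of day (day is 06:00 through 17:59)."""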
    if context.time_of_day or not cues:
        return context.time_of_day
    local_time_raw = cues.get("local_time") if isinstance(cues, Mapping) else None
    if not local_time_raw:
        return None
    try:
        local_dt = datetime.fromisoformat(str(local_time_raw))
    except (ValueError, TypeError):
        return None
    hour = local_dt.hour
    return "day" if 6 <= hour < 18 else "night"