File size: 2,300 Bytes
c8dfbc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08d13a6
 
c8dfbc0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from typing import Any, Dict, List, Tuple
import numpy as np
from sklearn.metrics import mean_squared_error, roc_auc_score


def _all_sentence_keys(
    docs_sentences: List[List[Tuple[str, str]]]
) -> List[str]:
    keys: List[str] = []
    for doc in docs_sentences:
        for key, _ in doc:
            keys.append(key)
    return keys


def trace_from_attributes(
    attrs: Dict[str, Any],
    docs_sentences: List[List[Tuple[str, str]]],
) -> Dict[str, float]:
    """Compute relevance, utilization, completeness and adherence scores.

    Args:
        attrs: Attribute mapping. The entries read are
            ``"all_relevant_sentence_keys"`` and
            ``"all_utilized_sentence_keys"`` (iterables of sentence keys;
            a missing entry or ``None`` is treated as empty) and
            ``"overall_supported"`` (evaluated for truthiness, default
            ``False``).
        docs_sentences: One list per document; each entry is a 2-tuple
            whose first element is the sentence key.

    Returns:
        A dict with four float entries:

        * ``"relevance"``: relevant keys / total sentence entries.
        * ``"utilization"``: utilized keys / total sentence entries.
        * ``"completeness"``: share of relevant keys that are also
          utilized, or ``0.0`` when no key is relevant.
        * ``"adherence"``: ``1.0`` if ``overall_supported`` is truthy,
          else ``0.0``.

        All four are ``0.0`` when ``docs_sentences`` holds no sentences.
    """
    all_keys = [key for doc in docs_sentences for key, _ in doc]
    total = len(all_keys)
    if total == 0:
        return {
            "relevance": 0.0,
            "utilization": 0.0,
            "completeness": 0.0,
            "adherence": 0.0,
        }

    # NOTE: `total` counts every sentence entry (duplicate keys included),
    # whereas the key sets below are de-duplicated.
    known_keys = set(all_keys)

    # `dict.get(key, default)` only applies the default when the key is
    # absent; `or ()` additionally covers an explicit None value, which
    # would otherwise make `set(None)` raise TypeError.
    # Keys that do not occur in `docs_sentences` are discarded.
    relevant = set(attrs.get("all_relevant_sentence_keys") or ()) & known_keys
    utilized = set(attrs.get("all_utilized_sentence_keys") or ()) & known_keys

    # `total` is guaranteed non-zero here thanks to the early return above.
    relevance = len(relevant) / total
    utilization = len(utilized) / total
    completeness = (
        len(relevant & utilized) / len(relevant) if relevant else 0.0
    )
    adherence = 1.0 if attrs.get("overall_supported", False) else 0.0

    return {
        "relevance": float(relevance),
        "utilization": float(utilization),
        "completeness": float(completeness),
        "adherence": float(adherence),
    }


def _root_mean_squared_error(
    y_true: List[float], y_pred: List[float]
) -> float:
    """Return the RMSE between two equal-length, non-empty 1-D sequences.

    Raises:
        ValueError: If the inputs differ in shape or are empty.
    """
    actual = np.asarray(y_true, dtype=float)
    predicted = np.asarray(y_pred, dtype=float)
    # Validate explicitly: NumPy would silently broadcast a length-1 input
    # against a longer one and would return NaN for empty input.
    if actual.shape != predicted.shape:
        raise ValueError(
            "y_true and y_pred must have the same shape, got "
            f"{actual.shape} and {predicted.shape}"
        )
    if actual.size == 0:
        raise ValueError("y_true and y_pred must not be empty")
    return float(np.sqrt(np.mean((actual - predicted) ** 2)))


def compute_rmse_auc(
    y_true_rel: List[float],
    y_pred_rel: List[float],
    y_true_util: List[float],
    y_pred_util: List[float],
    y_true_comp: List[float],
    y_pred_comp: List[float],
    y_true_adh: List[int],
    y_pred_adh: List[float],
) -> Dict[str, float]:
    """Compute RMSE for the three regression scores and AUROC for adherence.

    RMSE is computed with NumPy rather than
    ``mean_squared_error(..., squared=False)`` because the ``squared``
    keyword was deprecated in scikit-learn 1.4 and removed in 1.6.

    Args:
        y_true_rel, y_pred_rel: True and predicted relevance values.
        y_true_util, y_pred_util: True and predicted utilization values.
        y_true_comp, y_pred_comp: True and predicted completeness values.
        y_true_adh: True adherence labels.
        y_pred_adh: Predicted adherence scores.

    Returns:
        A dict with ``"rmse_relevance"``, ``"rmse_utilization"``,
        ``"rmse_completeness"`` and ``"auroc_adherence"``.

    Raises:
        ValueError: If any true/predicted regression pair differs in
            length or is empty.
    """
    metrics = {
        "rmse_relevance": _root_mean_squared_error(y_true_rel, y_pred_rel),
        "rmse_utilization": _root_mean_squared_error(
            y_true_util, y_pred_util
        ),
        "rmse_completeness": _root_mean_squared_error(
            y_true_comp, y_pred_comp
        ),
    }

    if len(set(y_true_adh)) > 1:
        metrics["auroc_adherence"] = float(
            roc_auc_score(y_true_adh, y_pred_adh)
        )
    else:
        # AUROC is undefined unless both classes occur in the labels.
        # Report the chance-level value 0.5 rather than NaN so the result
        # is always a finite number.
        metrics["auroc_adherence"] = 0.5

    return metrics