|
|
""" |
|
|
Unit Tests for Utility Functions |
|
|
|
|
|
Tests for src/utils module including tool functions. |
|
|
""" |
|
|
|
|
|
import pytest |
|
|
import json |
|
|
import sys |
|
|
from pathlib import Path |
|
|
from unittest.mock import patch, MagicMock |
|
|
|
|
|
|
|
|
# Root of the project: three `.parent` hops up from this file, i.e. the
# directory two levels above the one containing this test module.
PROJECT_ROOT = Path(__file__).parent.parent.parent

# Prepend the root to the import path (equivalent to sys.path.insert(0, ...)),
# presumably so project packages such as ``src`` are importable without
# installation — confirm against the project layout.
sys.path[:0] = [str(PROJECT_ROOT)]
|
|
|
|
|
|
|
|
class TestToolResponseParsing:
    """Check how raw tool response strings behave when decoded with ``json``."""

    def test_parse_valid_json_response(self):
        """A well-formed success payload decodes into nested dictionaries."""
        raw = '{"status": "success", "data": {"temperature": 28}}'
        payload = json.loads(raw)

        assert payload["status"] == "success"
        assert payload["data"]["temperature"] == 28

    def test_parse_error_response(self):
        """An error payload carries both an error message and a suggested fix."""
        payload = json.loads(
            '{"error": "API timeout", "solution": "Retry in 5 seconds"}'
        )

        # Both keys must be present; a set-subset check covers them at once.
        assert {"error", "solution"} <= payload.keys()

    def test_handle_invalid_json(self):
        """Truncated / non-JSON text raises ``json.JSONDecodeError``."""
        malformed = "Not valid JSON {"

        with pytest.raises(json.JSONDecodeError):
            json.loads(malformed)

    def test_handle_empty_response(self):
        """An empty body is not valid JSON and raises ``json.JSONDecodeError``."""
        with pytest.raises(json.JSONDecodeError):
            json.loads("")
|
|
|
|
|
|
|
|
class TestDistrictMapping:
    """Tests for Sri Lankan district mapping."""

    # All 25 administrative districts of Sri Lanka, grouped roughly by
    # province. Mullaitivu and Vavuniya (Northern Province) were previously
    # missing, which is why the count check had been loosened to ">= 23".
    SRI_LANKA_DISTRICTS = (
        # Western
        "Colombo",
        "Gampaha",
        "Kalutara",
        # Central
        "Kandy",
        "Matale",
        "Nuwara Eliya",
        # Southern
        "Galle",
        "Matara",
        "Hambantota",
        # Northern
        "Jaffna",
        "Kilinochchi",
        "Mannar",
        "Mullaitivu",
        "Vavuniya",
        # Eastern
        "Batticaloa",
        "Ampara",
        "Trincomalee",
        # North Western
        "Kurunegala",
        "Puttalam",
        # North Central
        "Anuradhapura",
        "Polonnaruwa",
        # Uva
        "Badulla",
        "Monaragala",
        # Sabaragamuwa
        "Ratnapura",
        "Kegalle",
    )

    @pytest.fixture
    def district_list(self):
        """Return a fresh list of the Sri Lankan districts.

        A new list is built per test so no test can mutate shared state.
        """
        return list(self.SRI_LANKA_DISTRICTS)

    def test_district_count(self, district_list):
        """Verify all 25 districts are present, each exactly once."""
        assert len(district_list) == 25, "Should have all 25 districts"
        assert len(set(district_list)) == len(district_list), (
            "District names should be unique"
        )

    def test_district_name_format(self, district_list):
        """Verify every word of each district name is capitalized."""
        for district in district_list:
            # istitle() checks each word (e.g. "Nuwara Eliya") and, unlike
            # district[0], returns False instead of raising on "".
            assert district.istitle(), f"District {district} should be capitalized"

    def test_major_districts_present(self, district_list):
        """Verify major districts are present."""
        major = ["Colombo", "Kandy", "Galle", "Jaffna"]
        for district in major:
            assert district in district_list
|
|
|
|
|
|
|
|
class TestDataValidation:
    """Tests for data validation functions."""

    def test_validate_feed_item(self):
        """A complete feed item contains every required field."""
        valid_item = {
            "title": "Test Title",
            "summary": "Test summary",
            "source": "Test Source",
            "timestamp": "2024-01-01T00:00:00",
        }

        required_fields = ["title", "summary", "source"]
        for field in required_fields:
            assert field in valid_item

    def test_validate_missing_fields(self):
        """Detect which required fields are absent from an item."""
        invalid_item = {
            "title": "Test Title",
        }

        required_fields = ["title", "summary", "source"]
        missing = [f for f in required_fields if f not in invalid_item]

        assert len(missing) == 2
        assert "summary" in missing
        assert "source" in missing

    def test_sanitize_summary(self):
        """Summary sanitization normalizes whitespace and truncates long text."""

        def sanitize(text: str, max_length: int = 500) -> str:
            # Falsy input (e.g. "") short-circuits to an empty string.
            if not text:
                return ""

            # str.split() with no argument splits on any whitespace run and
            # drops leading/trailing whitespace, so this collapses runs of
            # spaces/tabs/newlines into single spaces and trims the ends.
            text = " ".join(text.split())

            # Reserve 3 characters for the ellipsis so the truncated result
            # is exactly max_length long.
            if len(text) > max_length:
                text = text[: max_length - 3] + "..."
            return text

        # Already-clean text passes through unchanged.
        assert sanitize("Hello World") == "Hello World"

        # Leading/trailing whitespace is stripped and internal whitespace
        # runs collapse to a single space.
        assert sanitize("  Hello \t\n  World  ") == "Hello World"

        # Empty input yields an empty string.
        assert sanitize("") == ""

        # Text exactly at the limit is left untouched (no ellipsis).
        assert sanitize("a" * 500) == "a" * 500

        # Over-long text is cut to max_length, ending in "...".
        long_text = "a" * 600
        result = sanitize(long_text)
        assert len(result) == 500
        assert result.endswith("...")
        assert result == "a" * 497 + "..."
|
|
|
|
|
|
|
|
class TestRiskScoring:
    """Tests for risk scoring logic."""

    @staticmethod
    def _aggregate(scores: list) -> dict:
        """Summarize ``scores`` as min / max / mean.

        Shared by the aggregation tests so both exercise one implementation
        instead of two copy-pasted ones.

        Returns:
            ``{"min": ..., "max": ..., "avg": ...}``; all zeros for an
            empty list (avoids ``min()``/``max()`` raising and division by
            zero).
        """
        if not scores:
            return {"min": 0, "max": 0, "avg": 0}
        return {
            "min": min(scores),
            "max": max(scores),
            "avg": sum(scores) / len(scores),
        }

    def test_calculate_severity_score(self):
        """Severity is the category's base weight scaled by confidence."""

        def calculate_severity(risk_type: str, confidence: float) -> float:
            severity_weights = {
                "Flood": 0.9,
                "Storm": 0.8,
                "Economic": 0.7,
                "Political": 0.6,
                "Social": 0.5,
            }
            # Unknown categories fall back to a 0.5 base weight.
            base = severity_weights.get(risk_type, 0.5)
            return base * confidence

        assert calculate_severity("Flood", 0.9) == pytest.approx(0.81)
        assert calculate_severity("Social", 0.5) == pytest.approx(0.25)
        assert calculate_severity("Unknown", 1.0) == pytest.approx(0.5)

    def test_aggregate_risk_scores(self):
        """Test aggregation of multiple risk scores."""
        scores = [0.3, 0.5, 0.7, 0.9]
        result = self._aggregate(scores)

        assert result["min"] == 0.3
        assert result["max"] == 0.9
        assert result["avg"] == pytest.approx(0.6)

    def test_empty_score_handling(self):
        """Test handling of empty score list."""
        result = self._aggregate([])
        assert result == {"min": 0, "max": 0, "avg": 0}
|
|
|
|
|
|
|
|
class TestTimestampHandling:
    """Tests covering ISO-8601 parsing and strftime formatting via ``datetime``."""

    def test_parse_iso_timestamp(self):
        """An ISO-8601 string parses into the expected date and time fields."""
        import datetime as dt_module

        parsed = dt_module.datetime.fromisoformat("2024-01-15T10:30:00")

        # Compare all checked fields at once instead of one assert per field.
        observed = (parsed.year, parsed.month, parsed.day, parsed.hour, parsed.minute)
        assert observed == (2024, 1, 15, 10, 30)

    def test_format_timestamp(self):
        """A datetime renders as ``YYYY-MM-DD HH:MM`` with the given pattern."""
        import datetime as dt_module

        moment = dt_module.datetime(2024, 1, 15, 10, 30, 0)

        assert moment.strftime("%Y-%m-%d %H:%M") == "2024-01-15 10:30"

    def test_handle_invalid_timestamp(self):
        """Non-ISO text is rejected by ``fromisoformat`` with ``ValueError``."""
        import datetime as dt_module

        with pytest.raises(ValueError):
            dt_module.datetime.fromisoformat("not a timestamp")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main returns the session's exit code; pass it to sys.exit so
    # running this file directly reports test failures to the shell/CI
    # instead of always exiting with status 0.
    sys.exit(pytest.main([__file__, "-v"]))
|
|
|