# modelx/tests/unit/test_utils.py
# Uploaded by nivakaran via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 752f5cc (verified).
"""
Unit Tests for Utility Functions
Tests for src/utils module including tool functions.
"""
import pytest
import json
import sys
from pathlib import Path
from unittest.mock import patch, MagicMock
# Put the repository root (three directory levels above this file) at the
# front of the import path -- presumably so the project's `src` package can be
# imported without installing it (TODO confirm against the repo layout).
PROJECT_ROOT = Path(__file__).parent.parent.parent
sys.path.insert(0, f"{PROJECT_ROOT}")
class TestToolResponseParsing:
    """Check how JSON payloads returned by tools decode (and fail to decode)."""

    def test_parse_valid_json_response(self):
        """A well-formed success payload decodes into nested dicts."""
        payload = json.loads('{"status": "success", "data": {"temperature": 28}}')
        assert payload["status"] == "success"
        assert payload["data"]["temperature"] == 28

    def test_parse_error_response(self):
        """An error payload carries both an error message and a suggested fix."""
        payload = json.loads(
            '{"error": "API timeout", "solution": "Retry in 5 seconds"}'
        )
        assert {"error", "solution"}.issubset(payload)

    def test_handle_invalid_json(self):
        """Malformed JSON text raises JSONDecodeError."""
        pytest.raises(json.JSONDecodeError, json.loads, "Not valid JSON {")

    def test_handle_empty_response(self):
        """An empty body is not valid JSON and raises JSONDecodeError."""
        pytest.raises(json.JSONDecodeError, json.loads, "")
class TestDistrictMapping:
    """Tests for Sri Lankan district mapping."""

    # All 25 administrative districts of Sri Lanka, grouped by province.
    # Vavuniya and Mullaitivu (Northern Province) were previously missing,
    # which is why the count check had been loosened to ">= 23".
    DISTRICTS = (
        # Western
        "Colombo",
        "Gampaha",
        "Kalutara",
        # Central
        "Kandy",
        "Matale",
        "Nuwara Eliya",
        # Southern
        "Galle",
        "Matara",
        "Hambantota",
        # Northern
        "Jaffna",
        "Kilinochchi",
        "Mannar",
        "Vavuniya",
        "Mullaitivu",
        # Eastern
        "Batticaloa",
        "Ampara",
        "Trincomalee",
        # North Western
        "Kurunegala",
        "Puttalam",
        # North Central
        "Anuradhapura",
        "Polonnaruwa",
        # Uva
        "Badulla",
        "Monaragala",
        # Sabaragamuwa
        "Ratnapura",
        "Kegalle",
    )

    @pytest.fixture
    def district_list(self):
        """Return a fresh, mutable list of Sri Lankan district names."""
        return list(self.DISTRICTS)

    def test_district_count(self, district_list):
        """Verify all 25 districts are present exactly once."""
        assert len(district_list) == 25, "Sri Lanka has 25 districts"
        assert len(set(district_list)) == len(district_list), "Duplicate district name"

    def test_district_name_format(self, district_list):
        """Verify district names are non-empty and every word is capitalized."""
        for district in district_list:
            words = district.split()
            # An empty name yields no words and fails here with a clear
            # message, rather than raising IndexError on district[0].
            assert words and all(word[0].isupper() for word in words), (
                f"District {district} should be capitalized"
            )

    def test_major_districts_present(self, district_list):
        """Verify major districts are present."""
        major = {"Colombo", "Kandy", "Galle", "Jaffna"}
        missing = major - set(district_list)
        assert not missing, f"Missing major districts: {sorted(missing)}"
class TestDataValidation:
    """Tests for data validation functions."""

    # Fields every feed item must carry; shared by the validation tests below.
    REQUIRED_FIELDS = ("title", "summary", "source")

    def test_validate_feed_item(self):
        """A complete feed item contains every required field."""
        valid_item = {
            "title": "Test Title",
            "summary": "Test summary",
            "source": "Test Source",
            "timestamp": "2024-01-01T00:00:00",
        }
        missing = [f for f in self.REQUIRED_FIELDS if f not in valid_item]
        assert not missing, f"Missing required fields: {missing}"

    def test_validate_missing_fields(self):
        """Missing required fields are detected, in declaration order."""
        invalid_item = {
            "title": "Test Title",
            # summary and source deliberately omitted
        }
        missing = [f for f in self.REQUIRED_FIELDS if f not in invalid_item]
        assert missing == ["summary", "source"]

    def test_sanitize_summary(self):
        """Test summary text sanitization."""

        def sanitize(text: str, max_length: int = 500) -> str:
            """Collapse whitespace runs and cap the result at max_length chars."""
            if not text:
                return ""
            # split() with no argument drops leading/trailing whitespace and
            # treats any run of spaces/tabs/newlines as one separator.
            text = " ".join(text.split())
            if len(text) > max_length:
                # Reserve 3 chars for the ellipsis so the result is exactly max_length.
                text = text[: max_length - 3] + "..."
            return text

        # Normal text passes through unchanged
        assert sanitize("Hello World") == "Hello World"
        # Whitespace normalization: the input must actually contain extra
        # whitespace, otherwise the assertion proves nothing.
        assert sanitize("  Hello \t  World\n ") == "Hello World"
        # Empty and whitespace-only input
        assert sanitize("") == ""
        assert sanitize("   \n\t ") == ""
        # Exactly max_length is kept as-is (no ellipsis)
        assert sanitize("a" * 500) == "a" * 500
        # Truncation
        result = sanitize("a" * 600)
        assert len(result) == 500
        assert result == "a" * 497 + "..."
class TestRiskScoring:
    """Tests for risk scoring logic."""

    @staticmethod
    def _aggregate(scores: list) -> dict:
        """Return min/max/mean of ``scores``; all zeros for an empty list.

        Shared by the aggregation tests so both exercise the same logic
        (it was previously copy-pasted into each test).
        """
        if not scores:
            return {"min": 0, "max": 0, "avg": 0}
        return {
            "min": min(scores),
            "max": max(scores),
            "avg": sum(scores) / len(scores),
        }

    def test_calculate_severity_score(self):
        """Test severity score calculation."""

        def calculate_severity(risk_type: str, confidence: float) -> float:
            """Scale a per-type base weight by confidence; unknown types weigh 0.5."""
            severity_weights = {
                "Flood": 0.9,
                "Storm": 0.8,
                "Economic": 0.7,
                "Political": 0.6,
                "Social": 0.5,
            }
            base = severity_weights.get(risk_type, 0.5)
            return base * confidence

        # High priority risk
        assert calculate_severity("Flood", 0.9) == pytest.approx(0.81)
        # Low priority risk
        assert calculate_severity("Social", 0.5) == pytest.approx(0.25)
        # Unknown risk type falls back to the default weight
        assert calculate_severity("Unknown", 1.0) == pytest.approx(0.5)

    def test_aggregate_risk_scores(self):
        """Test aggregation of multiple risk scores."""
        result = self._aggregate([0.3, 0.5, 0.7, 0.9])
        assert result["min"] == 0.3
        assert result["max"] == 0.9
        assert result["avg"] == pytest.approx(0.6)
        # A single score is its own min, max and mean
        assert self._aggregate([0.4]) == {"min": 0.4, "max": 0.4, "avg": 0.4}

    def test_empty_score_handling(self):
        """Test handling of empty score list."""
        assert self._aggregate([]) == {"min": 0, "max": 0, "avg": 0}
class TestTimestampHandling:
    """Checks for ISO-8601 parsing and strftime formatting of timestamps."""

    def test_parse_iso_timestamp(self):
        """fromisoformat splits an ISO string into its date/time components."""
        import datetime as dt

        parsed = dt.datetime.fromisoformat("2024-01-15T10:30:00")
        expected = {"year": 2024, "month": 1, "day": 15, "hour": 10, "minute": 30}
        for attr, value in expected.items():
            assert getattr(parsed, attr) == value

    def test_format_timestamp(self):
        """strftime renders a datetime as 'YYYY-MM-DD HH:MM'."""
        import datetime as dt

        stamp = dt.datetime(2024, 1, 15, 10, 30, 0)
        assert stamp.strftime("%Y-%m-%d %H:%M") == "2024-01-15 10:30"

    def test_handle_invalid_timestamp(self):
        """Text that is not ISO-8601 is rejected with ValueError."""
        import datetime as dt

        pytest.raises(ValueError, dt.datetime.fromisoformat, "not a timestamp")
if __name__ == "__main__":
    # pytest.main() returns the session exit code rather than exiting; pass it
    # to sys.exit so a failing run is reported as a non-zero status to the
    # shell/CI instead of always exiting 0.
    sys.exit(pytest.main([__file__, "-v"]))