| | import logging |
| | from fastapi import APIRouter, Depends, HTTPException |
| | from jinja2 import Environment |
| | from litellm.router import Router |
| | from dependencies import get_llm_router, get_prompt_templates |
| | from schemas import _ReqGroupingCategory, _ReqGroupingOutput, ReqGroupingCategory, ReqGroupingRequest, ReqGroupingResponse, ReqSearchLLMResponse, ReqSearchRequest, ReqSearchResponse |
| |
|
| | |
# Router that the endpoints below register themselves on via @router.post;
# it is created with the "requirement processing" tag.
router = APIRouter(tags=["requirement processing"])
| |
|
| |
|
@router.post("/get_reqs_from_query", response_model=ReqSearchResponse)
def find_requirements_from_problem_description(req: ReqSearchRequest, llm_router: Router = Depends(get_llm_router)):
    """Find the requirements that address a given problem description.

    Every candidate requirement is rendered into one prompt line tagged with
    its 'Selection ID', and the LLM is asked to return the IDs of the
    requirements that best cover the problem described by the query.

    Args:
        req: Request carrying the candidate ``requirements`` and the ``query``.
        llm_router: LiteLLM router used to run the completion.

    Returns:
        A ``ReqSearchResponse`` holding the selected requirements. The list is
        empty when the LLM selects nothing.

    Raises:
        HTTPException: 500 when the LLM returns an ID that is not a valid
            position in ``req.requirements``.
    """
    requirements = req.requirements
    query = req.query

    # NOTE(review): the prompt exposes ``req_id`` as the 'Selection ID', while
    # the answer is used below as a position in ``requirements``. This assumes
    # req_id equals the list index -- confirm against the caller.
    requirements_text = "\n".join(
        f"[Selection ID: {r.req_id} | Document: {r.document} | Context: {r.context} | Requirement: {r.requirement}]"
        for r in requirements
    )
    prompt = (
        f"Given all the requirements : \n {requirements_text} \n and the problem description \"{query}\", "
        "return a list of 'Selection ID' for the most relevant corresponding requirements that reference "
        "or best cover the problem. If none of the requirements covers the problem, simply return an empty list"
    )

    resp_ai = llm_router.completion(
        model="gemini-v2",
        messages=[{"role": "user", "content": prompt}],
        response_format=ReqSearchLLMResponse,
    )

    selected = ReqSearchLLMResponse.model_validate_json(
        resp_ai.choices[0].message.content).selected

    logging.info("Found %d reqs matching case.", len(selected))

    # Check both bounds: a negative index would silently pick an item from the
    # end of the list. An empty selection is a legitimate "no match" answer
    # (the prompt asks for it) and must not be treated as an error.
    n_requirements = len(requirements)
    if any(not 0 <= idx < n_requirements for idx in selected):
        raise HTTPException(
            status_code=500, detail="LLM error : Generated a wrong index, please try again.")

    return ReqSearchResponse(requirements=[requirements[idx] for idx in selected])
| |
|
| |
|
@router.post("/categorize_requirements")
async def categorize_reqs(params: ReqGroupingRequest, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router)) -> ReqGroupingResponse:
    """Categorize the given service requirements into categories.

    The LLM is prompted with the "public/classify.txt" template. Its answer is
    accepted once every requirement index appears in at least one category, or
    immediately when ``params.disable_sort_checks`` is set. Otherwise the
    missing indices are sent back to the LLM and it is asked to retry, up to
    ``MAX_ATTEMPTS`` times.

    Args:
        params: Request carrying the ``requirements`` to group, the
            ``max_n_categories`` limit and the ``disable_sort_checks`` flag.
        prompt_env: Jinja2 environment used to render the classification prompt.
        llm_router: LiteLLM router used to run the completions.

    Returns:
        A ``ReqGroupingResponse`` whose categories are numbered in the order
        returned by the LLM and contain the referenced requirements.

    Raises:
        HTTPException: 500 when the LLM still leaves requirements
            uncategorized after ``MAX_ATTEMPTS`` attempts.
    """
    MAX_ATTEMPTS = 5

    # Loop-invariant: category items are treated as positions in
    # ``params.requirements``, so the valid IDs are 0..n-1.
    n_requirements = len(params.requirements)
    valid_ids_universe = set(range(n_requirements))

    req_prompt = await prompt_env.get_template("public/classify.txt").render_async(**{
        "requirements": [rq.model_dump() for rq in params.requirements],
        "max_n_categories": params.max_n_categories,
        "response_schema": _ReqGroupingOutput.model_json_schema()})

    messages = [{"role": "user", "content": req_prompt}]

    for _ in range(MAX_ATTEMPTS):
        req_completion = await llm_router.acompletion(model="gemini-v2", messages=messages, response_format=_ReqGroupingOutput)
        output = _ReqGroupingOutput.model_validate_json(
            req_completion.choices[0].message.content)

        assigned_ids = {
            req_id for cat in output.categories for req_id in cat.items}

        # Out-of-range IDs produced by the LLM do not count as coverage; the
        # set difference ignores them automatically.
        unassigned_ids = valid_ids_universe - assigned_ids

        if not unassigned_ids or params.disable_sort_checks:
            break

        # Keep the conversation going: echo the LLM answer back and tell it
        # which requirements are still missing.
        messages.append(req_completion.choices[0].message)
        messages.append(
            {"role": "user", "content": f"You haven't categorized the following requirements in at least one category {unassigned_ids}. Please do so."})
    else:
        # Every attempt left some requirement uncategorized. Report it the
        # same way as the sibling endpoint reports LLM failures.
        raise HTTPException(
            status_code=500, detail="Failed to classify all requirements")

    # Drop indices outside 0..n-1. The lower bound matters: a negative index
    # would otherwise silently select a requirement from the end of the list.
    final_categories = [
        ReqGroupingCategory(
            id=idx,
            title=cat.title,
            requirements=[params.requirements[i]
                          for i in cat.items if 0 <= i < n_requirements],
        )
        for idx, cat in enumerate(output.categories)
    ]

    return ReqGroupingResponse(categories=final_categories)
| |
|