import asyncio
import os
import sys
from typing import Annotated, Optional, TypedDict

from dotenv import load_dotenv
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.graph import StateGraph, START
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

# Load variables from a local .env file (if present) before reading API keys
load_dotenv()

api_key = os.getenv("GEMINI_API_KEY")

llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    temperature=0.2,
    google_api_key=api_key,
)
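
# The MCP server (server.py) is launched below as a stdio subprocess, so its
# location on disk must be resolved first; the path can differ between local
# runs and the deployed HF Space.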
print("π DEBUG: MCP Server Path Resolution")
print("=" * 60)
print(f"π Current working directory: {os.getcwd()}")
print(f"π __file__ is: {__file__}")
print(f"π __file__ directory: {os.path.dirname(__file__)}")

# Try multiple path strategies
server_path_options = [
    os.path.join(os.path.dirname(__file__), "server.py"),
    os.path.join(os.getcwd(), "server.py"),
    os.path.abspath("server.py"),
    "server.py",
]

server_path = None
for path in server_path_options:
    print(f"🔍 Checking: {path} ... ", end="")
    if os.path.exists(path):
        print("✅ FOUND!")
        server_path = path
        break
    else:
        print("❌ Not found")

if not server_path:
    print("\n❌ ERROR: server.py not found in any expected location!")
    print(f"📁 Files in current directory: {os.listdir('.')}")
    if os.path.dirname(__file__):
        print(f"📁 Files in __file__ directory: {os.listdir(os.path.dirname(__file__))}")
    raise FileNotFoundError("server.py not found. Make sure it's uploaded to your HF Space.")
else:
    print(f"✅ Using server.py at: {server_path}")
    print(f"🐍 Python executable: {sys.executable}")

print("=" * 60)

# Initialize MCP Client
try:
    # In client.py - pass environment explicitly
    client = MultiServerMCPClient(
        {
            "FistalMCP": {
                "transport": "stdio",
                "command": sys.executable,
                "args": ["-u", server_path],
                "env": {  # Explicitly pass environment variables
                    **os.environ,  # Pass all current environment
                    "GROQ_API_KEY": os.getenv("GROQ_API_KEY", ""),
                    "HF_TOKEN": os.getenv("HF_TOKEN", ""),
                    "GOOGLE_API_KEY_1": os.getenv("GOOGLE_API_KEY_1", ""),
                    "GOOGLE_API_KEY_2": os.getenv("GOOGLE_API_KEY_2", ""),
                    "GOOGLE_API_KEY_3": os.getenv("GOOGLE_API_KEY_3", ""),
                    "PATH": os.getenv("PATH", ""),
                },
            }
        }
    )
    print("✅ MCP Client initialized successfully!")
except Exception as e:
    print(f"❌ Failed to initialize MCP Client: {e}")
    import traceback
    traceback.print_exc()
    raise
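

# Shared state flowing through the LangGraph graph; add_messages appends new
# messages to the history instead of overwriting it.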
class ChatState(TypedDict):
    messages: Annotated[list[BaseMessage], add_messages]
    dataset_topic: str
    num_samples: int
    model_name: str
    task_type: str
    dataset_path: Optional[str]
    converted_path: Optional[str]
    model_path: Optional[str]
    hf_url: Optional[str]
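

# Build the agent graph: a Gemini chat node bound to the MCP tools plus a
# ToolNode, looping until the model stops requesting tool calls.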
async def my_graph():
    """Agent graph that handles MCP tools"""
    tools = await client.get_tools()

    tool_order = ["generate_json_data", "format_json", "finetune_model", "llm_as_judge"]
    available_tools = []
    for tool_name in tool_order:
        for tool in tools:
            if tool.name == tool_name:
                available_tools.append(tool)
                break

    llm_toolkit = llm.bind_tools(available_tools)
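
    # chat_node: normalize config values from state (they may arrive wrapped in
    # lists), assemble the system prompt, and invoke the tool-bound LLM.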
    async def chat_node(state: ChatState):
        messages = state["messages"]

        dataset_topic = state['dataset_topic']
        if isinstance(dataset_topic, list):
            dataset_topic = dataset_topic[0] if dataset_topic else "unknown"

        num_samples = state['num_samples']
        if isinstance(num_samples, list):
            num_samples = num_samples[0] if num_samples else 100

        model_name = state['model_name']
        if isinstance(model_name, list):
            model_name = model_name[0] if model_name else "unknown"

        task_type = state['task_type']
        if isinstance(task_type, list):
            task_type = task_type[0] if task_type else "text-generation"
system_msg = f"""You are Fistal, an AI fine-tuning assistant.
**User's Configuration:**
- Dataset Topic: {dataset_topic}
- Number of Samples: {num_samples}
- Model to Fine-tune: {model_name}
- Task Type: {task_type}
- Evaluation : Using LLM
**Your Workflow:**
1. Use generate_json_data with topic="{dataset_topic}", task_type="{task_type}", num_samples={num_samples}
- This returns a dictionary with a "data" field containing the raw dataset
2. Use format_json with the "data" field from step 1
- Pass: raw_data=<the data list from step 1>
- This returns a dictionary with a "data" field containing formatted data
3. Use finetune_model with the "data" field from step 2 and model_name="{model_name}"
- Pass: formatted_data=<the data list from step 2>, model_name="{model_name}"
- This returns the Hugging Face repo URL
4. Use llm_as_judge with the repo_id from step 3
- Pass: repo_id=<the HF repo from step 3>, topic="{dataset_topic}", task_type="{task_type}"
**FINAL STEP - CRITICAL:**
5. After completing all tools, you MUST return:
- The Hugging Face model URL from step 3
- The evaluation report from step 4
- Format your final response as:
π **Fine-tuning Complete!**
**π€ Model Repository:** [HF Repo Link] \n\n
**π Evaluation Report:** [Full report from llm_as_judge]
**IMPORTANT:**
- Tools pass DATA directly, not file paths
- Always mention the tool you are going to use first and then proceed with the tool action
- Extract the "data" field from each tool's response and pass it to the next tool
- After llm_as_judge completes, return both the HF URL and evaluation report
- Keep the user informed of progress at each step
- If a step takes time, do not stay idle. Inform users about short interesting facts about Modal, Unsloth, Gemini, Gradio , HuggingFace and MCP, do not repeat them.
- Try to add atleast 1 new fact every 10 seconds.
- Report any errors clearly
- Do not mention internal data structures or file paths"""

        full_messages = [SystemMessage(content=system_msg)] + messages
        response = await llm_toolkit.ainvoke(full_messages)
        return {'messages': [response]}
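
    # Graph wiring: START -> chat_node; tools_condition routes to the tool node
    # whenever the LLM emits tool calls, and the tool node loops back to chat_node.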
    tool_node = ToolNode(available_tools)

    graph = StateGraph(ChatState)
    graph.add_node("chat_node", chat_node)
    graph.add_node("tools", tool_node)

    graph.add_edge(START, "chat_node")
    graph.add_conditional_edges("chat_node", tools_condition)
    graph.add_edge("tools", "chat_node")

    chat = graph.compile()
    return chat
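

# Async-generator entry point: runs the full workflow for one configuration and
# yields markdown-formatted progress chunks as they are produced.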
async def run_fistal(
    dataset_topic: str,
    num_samples: int,
    model_name: str,
    task_type: str
):
    chatbot = await my_graph()

    user_message = f"""Execute the complete fine-tuning workflow:
- Generate {num_samples} training examples about {dataset_topic}
- Fine-tune {model_name}
- Evaluate for {task_type} task
Start now!"""

    initial_state = {
        "messages": [HumanMessage(content=user_message)],
        "dataset_topic": dataset_topic,
        "num_samples": num_samples,
        "model_name": model_name,
        "task_type": task_type,
        "dataset_path": None,
        "converted_path": None,
        "model_path": None,
        "hf_url": None,
    }
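
    # Canned status messages shown while each tool runs, keyed by tool name.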
    facts = {
        "generate_json_data": [
            "💡 Using parallel batch generation with multiple API keys for 3x speed!",
            "📈 Quality over quantity - diverse examples lead to better models!",
            "🎯 Generating diverse prompt-response pairs...",
        ],
        "format_json": [
            "🔄 Converting to chat format optimized for instruction tuning...",
            "💬 Proper formatting helps models understand conversation structure!",
            "🎨 Applying ChatML format for consistency...",
            "✅ Validating JSON structure for training compatibility...",
            "🔧 Optimizing token distribution across examples...",
        ],
        "finetune_model": [
            "🏋️ Training on Modal's serverless T4 GPU...",
            "💡 Using 4-bit quantization to fit in 16GB VRAM!",
            "🦥 Unsloth makes training 2x faster with 70% less memory!",
            "⚡ LoRA fine-tuning updates only 0.1% of model parameters!",
            "🎯 Typical training time: 10-20 minutes for 500 samples...",
            "🔥 Your model is learning patterns from authentic data!",
            "☁️ Uploading to HuggingFace - your model will be public soon!",
        ],
        "llm_as_judge": [
            "📝 Generating evaluation test cases...",
            "🤖 LLM-as-judge provides qualitative insights!",
            "✨ Testing model coherence, relevance, and accuracy...",
            "📊 Creating comprehensive evaluation report...",
            "🔎 Analyzing response quality and task alignment...",
            "📋 Comparing outputs against expected responses...",
            "🎯 Assessing model's understanding of the domain...",
            "✅ Finalizing evaluation metrics...",
        ],
    }
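
    # Stream graph events: tool events identify the active tool and surface its
    # first fact; chat_node events carry the model's text, interleaved with the
    # remaining facts for that tool.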
    current_tool = None
    fact_i = 0

    async for event in chatbot.astream(initial_state):
        if "tools" in event:
            messages = event["tools"].get("messages", [])
            for msg in messages:
                if hasattr(msg, "name"):
                    tool_name = msg.name
                    current_tool = tool_name
                    fact_i = 0
                    yield f"\n{'-'*60}\n"
                    yield f"🔧 **Using: {tool_name}**\n\n"
                    if tool_name in facts:
                        yield f"{facts[tool_name][0]}\n"
                    await asyncio.sleep(0.3)

        if "chat_node" in event:
            messages = event["chat_node"].get("messages", [])
            for msg in messages:
                if hasattr(msg, 'content') and msg.content:
                    raw_content = msg.content
                    content = ""
                    if isinstance(raw_content, list):
                        for item in raw_content:
                            if isinstance(item, dict) and item.get('type') == 'text':
                                content += item.get('text', '')
                        content = content.strip()
                    elif isinstance(raw_content, str):
                        content = raw_content
                    else:
                        content = str(raw_content)

                    if content and len(content) > 20 and "tool_calls" not in content.lower():
                        yield f"\n🤖 **Fistal:** {content}\n"

                    if current_tool and current_tool in facts:
                        fact_i += 1
                        if fact_i < len(facts[current_tool]):
                            yield f"\n💡 {facts[current_tool][fact_i]}\n"
                            await asyncio.sleep(0.3)

    yield "✅ **Successfully finetuned!**\n"


async def main():
    """Smoke test for the agent. Only used when running client.py directly."""
    print("Testing Fistal Agent\n")
    # run_fistal is an async generator, so stream its output chunk by chunk
    async for chunk in run_fistal(
        "python programming",
        5,
        "unsloth/Llama-3.2-1B-Instruct-bnb-4bit",
        "text-generation",
    ):
        print(chunk, end="", flush=True)


if __name__ == '__main__':
    asyncio.run(main())