evalstate HF Staff committed on
Commit
d010ca8
·
verified ·
1 Parent(s): b596d41

Upload demo_train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. demo_train.py +84 -0
demo_train.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# /// script
# dependencies = [
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "transformers>=4.36.0",
#     "accelerate>=0.24.0",
#     "datasets",
#     "trackio",
# ]
# ///
"""Quick demo: LoRA supervised fine-tuning of Qwen2.5-0.5B with TRL.

The script loads the first 50 examples of ``trl-lib/Capybara``, runs 20
optimizer steps of SFT with a LoRA adapter on the ``q_proj`` / ``v_proj``
modules, reports metrics to a Trackio dashboard Space, and pushes the result
to the Hugging Face Hub.
"""

import trackio
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig

# Single source of truth for the run. The Trackio metadata below is built from
# these same values so the dashboard cannot drift from what was actually run.
MODEL_ID = "Qwen/Qwen2.5-0.5B"
DATASET_ID = "trl-lib/Capybara"
NUM_EXAMPLES = 50
MAX_STEPS = 20

PROJECT_NAME = "qwen-demo-sft"
OUTPUT_DIR = "qwen-demo-sft"
HUB_MODEL_ID = "evalstate/qwen-demo-sft"
TRACKIO_SPACE_ID = "evalstate/demo-trackio-dashboard"

# Initialize Trackio for real-time monitoring
trackio.init(
    project=PROJECT_NAME,
    space_id=TRACKIO_SPACE_ID,
    config={
        "model": MODEL_ID,
        "dataset": DATASET_ID,
        "examples": NUM_EXAMPLES,
        "max_steps": MAX_STEPS,
        "note": "Quick demo training",
    },
)

# Everything after init runs under try/finally so the Trackio run is always
# finished, even if dataset loading, training, or the Hub push raises.
try:
    # Load dataset (only a small slice for a quick demo)
    dataset = load_dataset(DATASET_ID, split=f"train[:{NUM_EXAMPLES}]")
    print(f"✅ Dataset loaded: {len(dataset)} examples")

    # Training configuration
    config = SFTConfig(
        # Hub settings - CRITICAL for saving results
        output_dir=OUTPUT_DIR,
        push_to_hub=True,
        hub_model_id=HUB_MODEL_ID,

        # Quick training settings
        max_steps=MAX_STEPS,  # Very short for demo
        per_device_train_batch_size=2,
        gradient_accumulation_steps=2,
        learning_rate=2e-5,

        # Logging
        logging_steps=5,

        # Checkpointing
        save_strategy="steps",
        save_steps=10,

        # Monitoring
        report_to="trackio",
    )

    # LoRA configuration (memory efficient)
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "v_proj"],
    )

    # Initialize and train
    trainer = SFTTrainer(
        model=MODEL_ID,
        train_dataset=dataset,
        args=config,
        peft_config=peft_config,
    )

    print("🚀 Starting demo training...")
    trainer.train()

    print("💾 Pushing to Hub...")
    trainer.push_to_hub()
finally:
    # Finish Trackio tracking
    trackio.finish()

print("✅ Demo complete!")
print(f"📦 Model: https://huggingface.co/{HUB_MODEL_ID}")
print(f"📊 Metrics: https://huggingface.co/spaces/{TRACKIO_SPACE_ID}")