戒酒的李白

Tuning Qwen3 fine-tuning hyperparameters

@@ -347,9 +347,9 @@ def main(): @@ -347,9 +347,9 @@ def main():
347 print("Qwen3-Embedding模型训练") 347 print("Qwen3-Embedding模型训练")
348 print("="*40) 348 print("="*40)
349 print("可用模型大小:") 349 print("可用模型大小:")
350 - print(" 1. 0.6B - 轻量级,训练快速,显存需求约2GB")  
351 - print(" 2. 4B - 中等规模,性能均衡,显存需求约8GB")  
352 - print(" 3. 8B - 大规模,性能最佳,显存需求约16GB") 350 + print(" 1. 0.6B - 轻量级,训练快速,显存需求约4GB")
  351 + print(" 2. 4B - 中等规模,性能均衡,显存需求约16GB")
  352 + print(" 3. 8B - 大规模,性能最佳,显存需求约32GB")
353 353
354 while True: 354 while True:
355 choice = input("\n请选择模型大小 (1/2/3): ").strip() 355 choice = input("\n请选择模型大小 (1/2/3): ").strip()
@@ -155,7 +155,7 @@ class Qwen3LoRAUniversal(BaseQwenModel): @@ -155,7 +155,7 @@ class Qwen3LoRAUniversal(BaseQwenModel):
155 tokenized = self.tokenizer( 155 tokenized = self.tokenizer(
156 examples["text"], 156 examples["text"],
157 truncation=True, 157 truncation=True,
158 - padding=False, 158 + padding="max_length",
159 max_length=512, 159 max_length=512,
160 return_tensors=None 160 return_tensors=None
161 ) 161 )
@@ -178,9 +178,15 @@ class Qwen3LoRAUniversal(BaseQwenModel): @@ -178,9 +178,15 @@ class Qwen3LoRAUniversal(BaseQwenModel):
178 178
179 self.lora_model = get_peft_model(self.base_model, lora_config) 179 self.lora_model = get_peft_model(self.base_model, lora_config)
180 180
  181 + # 统计参数
  182 + total_params = sum(p.numel() for p in self.lora_model.parameters())
  183 + trainable_params = sum(p.numel() for p in self.lora_model.parameters() if p.requires_grad)
  184 +
181 print(f"LoRA配置完成 (r={lora_r}, alpha={lora_alpha})") 185 print(f"LoRA配置完成 (r={lora_r}, alpha={lora_alpha})")
182 - print(f"可训练参数: {self.lora_model.num_parameters():,}")  
183 - print(f"参数比例: {self.lora_model.num_parameters() / self.lora_model.base_model.num_parameters() * 100:.2f}%") 186 + print(f"总参数: {total_params:,}")
  187 + print(f"可训练参数: {trainable_params:,}")
  188 + print(f"可训练参数比例: {trainable_params / total_params * 100:.2f}%")
  189 + self.lora_model.print_trainable_parameters() # PEFT库自带的参数统计
184 190
185 return lora_config 191 return lora_config
186 192
@@ -360,7 +366,7 @@ def main(): @@ -360,7 +366,7 @@ def main():
360 parser.add_argument('--batch_size', type=int, help='批大小(可选,使用推荐值)') 366 parser.add_argument('--batch_size', type=int, help='批大小(可选,使用推荐值)')
361 parser.add_argument('--learning_rate', type=float, help='学习率(可选,使用推荐值)') 367 parser.add_argument('--learning_rate', type=float, help='学习率(可选,使用推荐值)')
362 parser.add_argument('--lora_r', type=int, help='LoRA秩(可选,使用推荐值)') 368 parser.add_argument('--lora_r', type=int, help='LoRA秩(可选,使用推荐值)')
363 - parser.add_argument('--max_samples', type=int, default=1000, help='最大训练样本数') 369 + parser.add_argument('--max_samples', type=int, default=0, help='最大训练样本数(0表示使用全部数据)')
364 parser.add_argument('--eval_only', action='store_true', help='仅评估模式') 370 parser.add_argument('--eval_only', action='store_true', help='仅评估模式')
365 371
366 args = parser.parse_args() 372 args = parser.parse_args()
@@ -370,9 +376,9 @@ def main(): @@ -370,9 +376,9 @@ def main():
370 print("Qwen3-LoRA模型训练") 376 print("Qwen3-LoRA模型训练")
371 print("="*40) 377 print("="*40)
372 print("可用模型大小:") 378 print("可用模型大小:")
373 - print(" 1. 0.6B - 轻量级,训练快速,显存需求约4GB")  
374 - print(" 2. 4B - 中等规模,性能均衡,显存需求约16GB")  
375 - print(" 3. 8B - 大规模,性能最佳,显存需求约32GB") 379 + print(" 1. 0.6B - 轻量级,训练快速,显存需求约8GB")
  380 + print(" 2. 4B - 中等规模,性能均衡,显存需求约32GB")
  381 + print(" 3. 8B - 大规模,性能最佳,显存需求约64GB")
376 print("\n注意: LoRA微调比Embedding方法需要更多显存") 382 print("\n注意: LoRA微调比Embedding方法需要更多显存")
377 383
378 while True: 384 while True:
@@ -414,9 +420,13 @@ def main(): @@ -414,9 +420,13 @@ def main():
414 # 训练模式 420 # 训练模式
415 train_data, test_data = BaseQwenModel.load_data(args.train_path, args.test_path) 421 train_data, test_data = BaseQwenModel.load_data(args.train_path, args.test_path)
416 422
417 - # 由于LoRA训练资源消耗大,使用部分数据 423 + # 训练数据处理
  424 + if args.max_samples > 0:
418 train_subset = train_data[:args.max_samples] 425 train_subset = train_data[:args.max_samples]
419 print(f"使用 {len(train_subset)} 条数据进行LoRA训练") 426 print(f"使用 {len(train_subset)} 条数据进行LoRA训练")
  427 + else:
  428 + train_subset = train_data
  429 + print(f"使用全部 {len(train_subset)} 条数据进行LoRA训练")
420 430
421 # 准备训练参数 431 # 准备训练参数
422 train_kwargs = {'num_epochs': args.epochs} 432 train_kwargs = {'num_epochs': args.epochs}
@@ -10,7 +10,7 @@ qwen 0.6B模型加线性分类器,做特定领域的文本分类和序列标 @@ -10,7 +10,7 @@ qwen 0.6B模型加线性分类器,做特定领域的文本分类和序列标
10 10
11 在经过了一些相关的调研之后,我觉的将Qwen3的一些小参数模型用在本系统中是一个不错的选择。 11 在经过了一些相关的调研之后,我觉的将Qwen3的一些小参数模型用在本系统中是一个不错的选择。
12 12
13 -虽然这个参数在LLM时代算小,但作为个人开发者计算资源有限,微调他们还是实属不易 13 +虽然这个参数在LLM时代算小,但作为个人开发者计算资源有限,微调他们还是实属不易,在一张A100上训练了整整四天,求求star了
14 14
15 ## 问题探究 15 ## 问题探究
16 16
@@ -87,9 +87,9 @@ python predict_universal.py --load_all --text "这个电影太棒了" @@ -87,9 +87,9 @@ python predict_universal.py --load_all --text "这个电影太棒了"
87 ### 注意事项 87 ### 注意事项
88 88
89 1. **显存要求** 89 1. **显存要求**
90 - - 0.6B: 最低2GB显存  
91 - - 4B: 最低8GB显存  
92 - - 8B: 最低16GB显存 90 + - 0.6B: 最低4GB显存
  91 + - 4B: 最低16GB显存
  92 + - 8B: 最低32GB显存
93 93
94 2. **数据格式**:每行格式为`文本内容\t标签`,标签为0(负面)或1(正面) 94 2. **数据格式**:每行格式为`文本内容\t标签`,标签为0(负面)或1(正面)
95 95