Showing
3 changed files
with
25 additions
and
15 deletions
| @@ -347,9 +347,9 @@ def main(): | @@ -347,9 +347,9 @@ def main(): | ||
| 347 | print("Qwen3-Embedding模型训练") | 347 | print("Qwen3-Embedding模型训练") |
| 348 | print("="*40) | 348 | print("="*40) |
| 349 | print("可用模型大小:") | 349 | print("可用模型大小:") |
| 350 | - print(" 1. 0.6B - 轻量级,训练快速,显存需求约2GB") | ||
| 351 | - print(" 2. 4B - 中等规模,性能均衡,显存需求约8GB") | ||
| 352 | - print(" 3. 8B - 大规模,性能最佳,显存需求约16GB") | 350 | + print(" 1. 0.6B - 轻量级,训练快速,显存需求约4GB") |
| 351 | + print(" 2. 4B - 中等规模,性能均衡,显存需求约16GB") | ||
| 352 | + print(" 3. 8B - 大规模,性能最佳,显存需求约32GB") | ||
| 353 | 353 | ||
| 354 | while True: | 354 | while True: |
| 355 | choice = input("\n请选择模型大小 (1/2/3): ").strip() | 355 | choice = input("\n请选择模型大小 (1/2/3): ").strip() |
| @@ -155,7 +155,7 @@ class Qwen3LoRAUniversal(BaseQwenModel): | @@ -155,7 +155,7 @@ class Qwen3LoRAUniversal(BaseQwenModel): | ||
| 155 | tokenized = self.tokenizer( | 155 | tokenized = self.tokenizer( |
| 156 | examples["text"], | 156 | examples["text"], |
| 157 | truncation=True, | 157 | truncation=True, |
| 158 | - padding=False, | 158 | + padding="max_length", |
| 159 | max_length=512, | 159 | max_length=512, |
| 160 | return_tensors=None | 160 | return_tensors=None |
| 161 | ) | 161 | ) |
| @@ -178,9 +178,15 @@ class Qwen3LoRAUniversal(BaseQwenModel): | @@ -178,9 +178,15 @@ class Qwen3LoRAUniversal(BaseQwenModel): | ||
| 178 | 178 | ||
| 179 | self.lora_model = get_peft_model(self.base_model, lora_config) | 179 | self.lora_model = get_peft_model(self.base_model, lora_config) |
| 180 | 180 | ||
| 181 | + # 统计参数 | ||
| 182 | + total_params = sum(p.numel() for p in self.lora_model.parameters()) | ||
| 183 | + trainable_params = sum(p.numel() for p in self.lora_model.parameters() if p.requires_grad) | ||
| 184 | + | ||
| 181 | print(f"LoRA配置完成 (r={lora_r}, alpha={lora_alpha})") | 185 | print(f"LoRA配置完成 (r={lora_r}, alpha={lora_alpha})") |
| 182 | - print(f"可训练参数: {self.lora_model.num_parameters():,}") | ||
| 183 | - print(f"参数比例: {self.lora_model.num_parameters() / self.lora_model.base_model.num_parameters() * 100:.2f}%") | 186 | + print(f"总参数: {total_params:,}") |
| 187 | + print(f"可训练参数: {trainable_params:,}") | ||
| 188 | + print(f"可训练参数比例: {trainable_params / total_params * 100:.2f}%") | ||
| 189 | + self.lora_model.print_trainable_parameters() # PEFT库自带的参数统计 | ||
| 184 | 190 | ||
| 185 | return lora_config | 191 | return lora_config |
| 186 | 192 | ||
| @@ -360,7 +366,7 @@ def main(): | @@ -360,7 +366,7 @@ def main(): | ||
| 360 | parser.add_argument('--batch_size', type=int, help='批大小(可选,使用推荐值)') | 366 | parser.add_argument('--batch_size', type=int, help='批大小(可选,使用推荐值)') |
| 361 | parser.add_argument('--learning_rate', type=float, help='学习率(可选,使用推荐值)') | 367 | parser.add_argument('--learning_rate', type=float, help='学习率(可选,使用推荐值)') |
| 362 | parser.add_argument('--lora_r', type=int, help='LoRA秩(可选,使用推荐值)') | 368 | parser.add_argument('--lora_r', type=int, help='LoRA秩(可选,使用推荐值)') |
| 363 | - parser.add_argument('--max_samples', type=int, default=1000, help='最大训练样本数') | 369 | + parser.add_argument('--max_samples', type=int, default=0, help='最大训练样本数(0表示使用全部数据)') |
| 364 | parser.add_argument('--eval_only', action='store_true', help='仅评估模式') | 370 | parser.add_argument('--eval_only', action='store_true', help='仅评估模式') |
| 365 | 371 | ||
| 366 | args = parser.parse_args() | 372 | args = parser.parse_args() |
| @@ -370,9 +376,9 @@ def main(): | @@ -370,9 +376,9 @@ def main(): | ||
| 370 | print("Qwen3-LoRA模型训练") | 376 | print("Qwen3-LoRA模型训练") |
| 371 | print("="*40) | 377 | print("="*40) |
| 372 | print("可用模型大小:") | 378 | print("可用模型大小:") |
| 373 | - print(" 1. 0.6B - 轻量级,训练快速,显存需求约4GB") | ||
| 374 | - print(" 2. 4B - 中等规模,性能均衡,显存需求约16GB") | ||
| 375 | - print(" 3. 8B - 大规模,性能最佳,显存需求约32GB") | 379 | + print(" 1. 0.6B - 轻量级,训练快速,显存需求约8GB") |
| 380 | + print(" 2. 4B - 中等规模,性能均衡,显存需求约32GB") | ||
| 381 | + print(" 3. 8B - 大规模,性能最佳,显存需求约64GB") | ||
| 376 | print("\n注意: LoRA微调比Embedding方法需要更多显存") | 382 | print("\n注意: LoRA微调比Embedding方法需要更多显存") |
| 377 | 383 | ||
| 378 | while True: | 384 | while True: |
| @@ -414,9 +420,13 @@ def main(): | @@ -414,9 +420,13 @@ def main(): | ||
| 414 | # 训练模式 | 420 | # 训练模式 |
| 415 | train_data, test_data = BaseQwenModel.load_data(args.train_path, args.test_path) | 421 | train_data, test_data = BaseQwenModel.load_data(args.train_path, args.test_path) |
| 416 | 422 | ||
| 417 | - # 由于LoRA训练资源消耗大,使用部分数据 | 423 | + # 训练数据处理 |
| 424 | + if args.max_samples > 0: | ||
| 418 | train_subset = train_data[:args.max_samples] | 425 | train_subset = train_data[:args.max_samples] |
| 419 | print(f"使用 {len(train_subset)} 条数据进行LoRA训练") | 426 | print(f"使用 {len(train_subset)} 条数据进行LoRA训练") |
| 427 | + else: | ||
| 428 | + train_subset = train_data | ||
| 429 | + print(f"使用全部 {len(train_subset)} 条数据进行LoRA训练") | ||
| 420 | 430 | ||
| 421 | # 准备训练参数 | 431 | # 准备训练参数 |
| 422 | train_kwargs = {'num_epochs': args.epochs} | 432 | train_kwargs = {'num_epochs': args.epochs} |
| @@ -10,7 +10,7 @@ qwen 0.6B模型加线性分类器,做特定领域的文本分类和序列标 | @@ -10,7 +10,7 @@ qwen 0.6B模型加线性分类器,做特定领域的文本分类和序列标 | ||
| 10 | 10 | ||
| 11 | 在经过了一些相关的调研之后,我觉的将Qwen3的一些小参数模型用在本系统中是一个不错的选择。 | 11 | 在经过了一些相关的调研之后,我觉的将Qwen3的一些小参数模型用在本系统中是一个不错的选择。 |
| 12 | 12 | ||
| 13 | -虽然这个参数在LLM时代算小,但作为个人开发者计算资源有限,微调他们还是实属不易。 | 13 | +虽然这个参数在LLM时代算小,但作为个人开发者计算资源有限,微调他们还是实属不易,在一张A100上训练了整整四天,求求star了 |
| 14 | 14 | ||
| 15 | ## 问题探究 | 15 | ## 问题探究 |
| 16 | 16 | ||
| @@ -87,9 +87,9 @@ python predict_universal.py --load_all --text "这个电影太棒了" | @@ -87,9 +87,9 @@ python predict_universal.py --load_all --text "这个电影太棒了" | ||
| 87 | ### 注意事项 | 87 | ### 注意事项 |
| 88 | 88 | ||
| 89 | 1. **显存要求**: | 89 | 1. **显存要求**: |
| 90 | - - 0.6B: 最低2GB显存 | ||
| 91 | - - 4B: 最低8GB显存 | ||
| 92 | - - 8B: 最低16GB显存 | 90 | + - 0.6B: 最低4GB显存 |
| 91 | + - 4B: 最低16GB显存 | ||
| 92 | + - 8B: 最低32GB显存 | ||
| 93 | 93 | ||
| 94 | 2. **数据格式**:每行格式为`文本内容\t标签`,标签为0(负面)或1(正面) | 94 | 2. **数据格式**:每行格式为`文本内容\t标签`,标签为0(负面)或1(正面) |
| 95 | 95 |
-
Please register or login to post a comment