戒酒的李白

Optimize the model loading logic of Qwen3.

@@ -37,11 +37,14 @@ class Qwen3UniversalPredictor:
             raise ValueError(f"不支持的模型大小: {model_size}")

         model_path = MODEL_PATHS[model_type][model_size]
+        model_key = self._get_model_key(model_type, model_size)
+
+        # 检查训练好的模型文件是否存在
         if not os.path.exists(model_path):
-            print(f"模型文件不存在: {model_path}")
+            print(f"训练好的模型文件不存在: {model_path}")
+            print(f"请先训练 {model_type.upper()}-{model_size} 模型,或检查模型路径配置")
            return

-        model_key = self._get_model_key(model_type, model_size)
         print(f"加载 {model_type.upper()}-{model_size} 模型...")

         try:
@@ -60,6 +63,7 @@ class Qwen3UniversalPredictor:

         except Exception as e:
             print(f"加载 {model_type.upper()}-{model_size} 模型失败: {e}")
+            print(f"这可能是因为基础模型下载失败或训练好的模型文件损坏")

     def load_all_models(self, model_dir: str = './models') -> None:
         """加载所有可用的模型"""
@@ -103,46 +103,61 @@ class Qwen3EmbeddingUniversal(BaseQwenModel):
         """加载Qwen3 Embedding模型"""
         print(f"加载{self.model_size}模型: {self.model_name_hf}")

+        # 第一步:检查当前文件夹的models目录
+        local_model_dir = f"./models/qwen3-embedding-{self.model_size.lower()}"
+        if os.path.exists(local_model_dir) and os.path.exists(os.path.join(local_model_dir, "config.json")):
             try:
+                print(f"发现本地模型,从本地加载: {local_model_dir}")
+                self.tokenizer = AutoTokenizer.from_pretrained(local_model_dir)
+                self.embedding_model = AutoModel.from_pretrained(local_model_dir).to(self.device)
+                print(f"从本地模型加载{self.model_size}模型成功")
+                return
+
+            except Exception as e:
+                print(f"本地模型加载失败: {e}")
+
+        # 第二步:检查HuggingFace缓存
+        try:
+            from transformers.utils import default_cache_path
+            cache_path = default_cache_path
+            print(f"检查HuggingFace缓存: {cache_path}")
+
             self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_hf)
             self.embedding_model = AutoModel.from_pretrained(self.model_name_hf).to(self.device)
-            print(f"{self.model_size}模型加载完成")
+            print(f"从HuggingFace缓存加载{self.model_size}模型成功")

-            # 立即保存到本地缓存
-            cache_dir = f"./models/qwen3-embedding-{self.model_size.lower()}"
-            if not os.path.exists(cache_dir):
-                print(f"保存模型到本地: {cache_dir}")
-                os.makedirs(cache_dir, exist_ok=True)
-                self.tokenizer.save_pretrained(cache_dir)
-                self.embedding_model.save_pretrained(cache_dir)
-                print(f"模型已保存到: {cache_dir}")
+            # 保存到本地models目录
+            print(f"保存模型到本地: {local_model_dir}")
+            os.makedirs(local_model_dir, exist_ok=True)
+            self.tokenizer.save_pretrained(local_model_dir)
+            self.embedding_model.save_pretrained(local_model_dir)
+            print(f"模型已保存到: {local_model_dir}")

         except Exception as e:
-            print(f"从Hugging Face加载失败: {e}")
+            print(f"从HuggingFace缓存加载失败: {e}")

-            # 尝试从本地缓存加载
-            cache_dir = f"./models/qwen3-embedding-{self.model_size.lower()}"
+            # 第三步:从HuggingFace下载
             try:
-                if os.path.exists(cache_dir):
-                    self.tokenizer = AutoTokenizer.from_pretrained(cache_dir)
-                    self.embedding_model = AutoModel.from_pretrained(cache_dir).to(self.device)
-                    print(f"从本地缓存加载{self.model_size}模型成功")
-                else:
-                    raise FileNotFoundError("本地缓存也不存在")
+                print(f"正在从HuggingFace下载{self.model_size}模型...")
+
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    self.model_name_hf,
+                    force_download=True
+                )
+                self.embedding_model = AutoModel.from_pretrained(
+                    self.model_name_hf,
+                    force_download=True
+                ).to(self.device)
+
+                # 保存到本地models目录
+                os.makedirs(local_model_dir, exist_ok=True)
+                self.tokenizer.save_pretrained(local_model_dir)
+                self.embedding_model.save_pretrained(local_model_dir)
+                print(f"{self.model_size}模型下载并保存到: {local_model_dir}")

             except Exception as e2:
-                print(f"本地加载也失败: {e2}")
-                print(f"正在下载{self.model_size}模型...")
-
-                # 创建缓存目录并下载
-                os.makedirs(cache_dir, exist_ok=True)
-                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_hf, cache_dir=cache_dir)
-                self.embedding_model = AutoModel.from_pretrained(self.model_name_hf, cache_dir=cache_dir).to(self.device)
-
-                # 保存到本地
-                self.tokenizer.save_pretrained(cache_dir)
-                self.embedding_model.save_pretrained(cache_dir)
-                print(f"{self.model_size}模型下载并保存到: {cache_dir}")
+                print(f"从HuggingFace下载也失败: {e2}")
+                raise RuntimeError(f"无法加载{self.model_size}模型,所有方法都失败了")

     def train(self, train_data: List[Tuple[str, int]], **kwargs) -> None:
         """训练模型"""
@@ -46,10 +46,14 @@ class Qwen3LoRAUniversal(BaseQwenModel):
         """加载Qwen3基础模型"""
         print(f"加载{self.model_size}基础模型: {self.model_name_hf}")

+        # 第一步:检查当前文件夹的models目录
+        local_model_dir = f"./models/qwen3-{self.model_size.lower()}"
+        if os.path.exists(local_model_dir) and os.path.exists(os.path.join(local_model_dir, "config.json")):
             try:
-            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_hf)
+                print(f"发现本地模型,从本地加载: {local_model_dir}")
+                self.tokenizer = AutoTokenizer.from_pretrained(local_model_dir)
                 self.base_model = AutoModelForCausalLM.from_pretrained(
-                    self.model_name_hf,
+                    local_model_dir,
                     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                     device_map="auto" if torch.cuda.is_available() else None
                 )
@@ -59,49 +63,53 @@ class Qwen3LoRAUniversal(BaseQwenModel):
                     self.tokenizer.pad_token = self.tokenizer.eos_token
                     self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

-            print(f"{self.model_size}基础模型加载完成")
-
-            # 立即保存到本地缓存
-            cache_dir = f"./models/qwen3-{self.model_size.lower()}"
-            if not os.path.exists(cache_dir):
-                print(f"保存模型到本地: {cache_dir}")
-                os.makedirs(cache_dir, exist_ok=True)
-                self.tokenizer.save_pretrained(cache_dir)
-                self.base_model.save_pretrained(cache_dir)
-                print(f"模型已保存到: {cache_dir}")
+                print(f"从本地模型加载{self.model_size}基础模型成功")
+                return

             except Exception as e:
-            print(f"从Hugging Face加载失败: {e}")
+                print(f"本地模型加载失败: {e}")

-        # 尝试从本地缓存加载
-        cache_dir = f"./models/qwen3-{self.model_size.lower()}"
+        # 第二步:检查HuggingFace缓存
         try:
-            if os.path.exists(cache_dir):
-                self.tokenizer = AutoTokenizer.from_pretrained(cache_dir)
+            from transformers.utils import default_cache_path
+            cache_path = default_cache_path
+            print(f"检查HuggingFace缓存: {cache_path}")
+
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_hf)
             self.base_model = AutoModelForCausalLM.from_pretrained(
-                cache_dir,
+                self.model_name_hf,
                 torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                 device_map="auto" if torch.cuda.is_available() else None
             )

+            # 设置pad_token
             if self.tokenizer.pad_token is None:
                 self.tokenizer.pad_token = self.tokenizer.eos_token
                 self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

-                print(f"从本地缓存加载{self.model_size}模型成功")
-            else:
-                raise FileNotFoundError("本地缓存也不存在")
+            print(f"从HuggingFace缓存加载{self.model_size}基础模型成功")

-        except Exception as e2:
-            print(f"本地加载也失败: {e2}")
-            print(f"正在下载{self.model_size}模型...")
+            # 保存到本地models目录
+            print(f"保存模型到本地: {local_model_dir}")
+            os.makedirs(local_model_dir, exist_ok=True)
+            self.tokenizer.save_pretrained(local_model_dir)
+            self.base_model.save_pretrained(local_model_dir)
+            print(f"模型已保存到: {local_model_dir}")

-            # 创建缓存目录并下载
-            os.makedirs(cache_dir, exist_ok=True)
-            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_hf, cache_dir=cache_dir)
+        except Exception as e:
+            print(f"从HuggingFace缓存加载失败: {e}")
+
+            # 第三步:从HuggingFace下载
+            try:
+                print(f"正在从HuggingFace下载{self.model_size}模型...")
+
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    self.model_name_hf,
+                    force_download=True
+                )
                 self.base_model = AutoModelForCausalLM.from_pretrained(
                     self.model_name_hf,
-                    cache_dir=cache_dir,
+                    force_download=True,
                     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                     device_map="auto" if torch.cuda.is_available() else None
                 )
@@ -110,10 +118,15 @@ class Qwen3LoRAUniversal(BaseQwenModel):
                     self.tokenizer.pad_token = self.tokenizer.eos_token
                     self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

-                # 保存到本地
-                self.tokenizer.save_pretrained(cache_dir)
-                self.base_model.save_pretrained(cache_dir)
-                print(f"{self.model_size}模型下载并保存到: {cache_dir}")
+                # 保存到本地models目录
+                os.makedirs(local_model_dir, exist_ok=True)
+                self.tokenizer.save_pretrained(local_model_dir)
+                self.base_model.save_pretrained(local_model_dir)
+                print(f"{self.model_size}模型下载并保存到: {local_model_dir}")
+
+            except Exception as e2:
+                print(f"从HuggingFace下载也失败: {e2}")
+                raise RuntimeError(f"无法加载{self.model_size}模型,所有方法都失败了")

     def _create_instruction_data(self, data: List[Tuple[str, int]]) -> Dataset:
         """创建指令格式的训练数据"""