Added a logging utility module and standardized the logging output across all modules.
Showing 13 changed files with 455 additions and 175 deletions
| 1 | import os | 1 | import os |
| 2 | import re | 2 | import re |
| 3 | -import logging | ||
| 4 | import getpass | 3 | import getpass |
| 5 | import pymysql | 4 | import pymysql |
| 6 | import subprocess | 5 | import subprocess |
| @@ -9,16 +8,7 @@ from apscheduler.schedulers.background import BackgroundScheduler | @@ -9,16 +8,7 @@ from apscheduler.schedulers.background import BackgroundScheduler | ||
| 9 | from pytz import utc | 8 | from pytz import utc |
| 10 | from datetime import datetime, timedelta | 9 | from datetime import datetime, timedelta |
| 11 | import time | 10 | import time |
| 12 | - | ||
| 13 | -# Initialize logging | ||
| 14 | -logging.basicConfig( | ||
| 15 | - level=logging.INFO, | ||
| 16 | - format='%(asctime)s [%(levelname)s] %(message)s', | ||
| 17 | - handlers=[ | ||
| 18 | - logging.FileHandler("app.log"), | ||
| 19 | - logging.StreamHandler() | ||
| 20 | - ] | ||
| 21 | -) | 11 | +from utils.logger import app_logger as logging |
| 22 | 12 | ||
| 23 | def get_db_connection_interactive(): | 13 | def get_db_connection_interactive(): |
| 24 | """ | 14 | """ |
logs/app.log
0 → 100644
logs/model.log
0 → 100644
logs/spider.log
0 → 100644
| @@ -6,6 +6,11 @@ from sklearn.feature_extraction.text import TfidfVectorizer # for TF-IDF feature extraction | @@ -6,6 +6,11 @@ from sklearn.feature_extraction.text import TfidfVectorizer # for TF-IDF feature extraction | ||
| 6 | from sklearn.naive_bayes import MultinomialNB # for multinomial Naive Bayes classification | 6 | from sklearn.naive_bayes import MultinomialNB # for multinomial Naive Bayes classification |
| 7 | from sklearn.model_selection import train_test_split # for splitting into train and test sets | 7 | from sklearn.model_selection import train_test_split # for splitting into train and test sets |
| 8 | from sklearn.metrics import accuracy_score # for computing model accuracy | 8 | from sklearn.metrics import accuracy_score # for computing model accuracy |
| 9 | +import torch | ||
| 10 | +from transformers import BertTokenizer, BertModel | ||
| 11 | +from torch import nn | ||
| 12 | +from torch.utils.data import Dataset, DataLoader | ||
| 13 | +from utils.logger import model_logger as logging | ||
| 9 | 14 | ||
| 10 | def getSentiment_data(): | 15 | def getSentiment_data(): |
| 11 | # Read sentiment data from the CSV file | 16 | # Read sentiment data from the CSV file |
| @@ -16,31 +21,153 @@ def getSentiment_data(): | @@ -16,31 +21,153 @@ def getSentiment_data(): | ||
| 16 | sentiment_data.append(data) | 21 | sentiment_data.append(data) |
| 17 | return sentiment_data | 22 | return sentiment_data |
| 18 | 23 | ||
| 24 | +class TextClassificationDataset(Dataset): | ||
| 25 | + def __init__(self, texts, labels, tokenizer, max_len=128): | ||
| 26 | + self.texts = texts | ||
| 27 | + self.labels = labels | ||
| 28 | + self.tokenizer = tokenizer | ||
| 29 | + self.max_len = max_len | ||
| 30 | + | ||
| 31 | + def __len__(self): | ||
| 32 | + return len(self.texts) | ||
| 33 | + | ||
| 34 | + def __getitem__(self, idx): | ||
| 35 | + text = str(self.texts[idx]) | ||
| 36 | + label = self.labels[idx] | ||
| 37 | + | ||
| 38 | + encoding = self.tokenizer.encode_plus( | ||
| 39 | + text, | ||
| 40 | + add_special_tokens=True, | ||
| 41 | + max_length=self.max_len, | ||
| 42 | + return_token_type_ids=False, | ||
| 43 | + padding='max_length', | ||
| 44 | + truncation=True, | ||
| 45 | + return_attention_mask=True, | ||
| 46 | + return_tensors='pt' | ||
| 47 | + ) | ||
| 48 | + | ||
| 49 | + return { | ||
| 50 | + 'text': text, | ||
| 51 | + 'input_ids': encoding['input_ids'].flatten(), | ||
| 52 | + 'attention_mask': encoding['attention_mask'].flatten(), | ||
| 53 | + 'label': torch.tensor(label, dtype=torch.long) | ||
| 54 | + } | ||
| 55 | + | ||
| 56 | +class BertClassifier(nn.Module): | ||
| 57 | + def __init__(self, n_classes): | ||
| 58 | + super(BertClassifier, self).__init__() | ||
| 59 | + self.bert = BertModel.from_pretrained('bert-base-chinese') | ||
| 60 | + self.drop = nn.Dropout(p=0.3) | ||
| 61 | + self.fc = nn.Linear(self.bert.config.hidden_size, n_classes) | ||
| 62 | + | ||
| 63 | + def forward(self, input_ids, attention_mask): | ||
| 64 | + outputs = self.bert( | ||
| 65 | + input_ids=input_ids, | ||
| 66 | + attention_mask=attention_mask | ||
| 67 | + ) | ||
| 68 | + pooled_output = outputs[1] | ||
| 69 | + output = self.drop(pooled_output) | ||
| 70 | + return self.fc(output) | ||
| 71 | + | ||
| 72 | +def train_model(model, train_loader, val_loader, learning_rate=2e-5, epochs=4): | ||
| 73 | + """训练模型""" | ||
| 74 | + try: | ||
| 75 | + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
| 76 | + logging.info(f"Using device: {device}") | ||
| 77 | + | ||
| 78 | + model = model.to(device) | ||
| 79 | + optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) | ||
| 80 | + criterion = nn.CrossEntropyLoss() | ||
| 81 | + | ||
| 82 | + for epoch in range(epochs): | ||
| 83 | + model.train() | ||
| 84 | + total_loss = 0 | ||
| 85 | + logging.info(f"Starting epoch {epoch + 1}/{epochs}") | ||
| 86 | + | ||
| 87 | + for batch in train_loader: | ||
| 88 | + input_ids = batch['input_ids'].to(device) | ||
| 89 | + attention_mask = batch['attention_mask'].to(device) | ||
| 90 | + labels = batch['label'].to(device) | ||
| 91 | + | ||
| 92 | + outputs = model(input_ids=input_ids, attention_mask=attention_mask) | ||
| 93 | + loss = criterion(outputs, labels) | ||
| 94 | + | ||
| 95 | + optimizer.zero_grad() | ||
| 96 | + loss.backward() | ||
| 97 | + optimizer.step() | ||
| 98 | + | ||
| 99 | + total_loss += loss.item() | ||
| 100 | + | ||
| 101 | + avg_train_loss = total_loss / len(train_loader) | ||
| 102 | + logging.info(f"Epoch {epoch + 1} average training loss: {avg_train_loss:.4f}") | ||
| 103 | + | ||
| 104 | + # Validation | ||
| 105 | + model.eval() | ||
| 106 | + val_preds = [] | ||
| 107 | + val_labels = [] | ||
| 108 | + | ||
| 109 | + with torch.no_grad(): | ||
| 110 | + for batch in val_loader: | ||
| 111 | + input_ids = batch['input_ids'].to(device) | ||
| 112 | + attention_mask = batch['attention_mask'].to(device) | ||
| 113 | + labels = batch['label'].to(device) | ||
| 114 | + | ||
| 115 | + outputs = model(input_ids=input_ids, attention_mask=attention_mask) | ||
| 116 | + _, preds = torch.max(outputs, dim=1) | ||
| 117 | + | ||
| 118 | + val_preds.extend(preds.cpu().numpy()) | ||
| 119 | + val_labels.extend(labels.cpu().numpy()) | ||
| 120 | + | ||
| 121 | + val_accuracy = accuracy_score(val_labels, val_preds) | ||
| 122 | + logging.info(f"Epoch {epoch + 1} validation accuracy: {val_accuracy:.4f}") | ||
| 123 | + | ||
| 124 | + logging.info("Model training complete") | ||
| 125 | + return model | ||
| 126 | + | ||
| 127 | + except Exception as e: | ||
| 128 | + logging.error(f"Error during model training: {e}") | ||
| 129 | + raise | ||
| 130 | + | ||
| 19 | def model_train(): | 131 | def model_train(): |
| 20 | - # Fetch sentiment data and convert it to a DataFrame | ||
| 21 | - sentiment_data = getSentiment_data() | ||
| 22 | - df = pd.DataFrame(sentiment_data, columns=['text', 'sentiment']) | 132 | + """Train the model and compute its accuracy""" |
| 133 | + try: | ||
| 134 | + # Load the data | ||
| 135 | + logging.info("Loading data...") | ||
| 136 | + data = pd.read_csv('data/train_data.csv') | ||
| 137 | + texts = data['text'].values | ||
| 138 | + labels = data['label'].values | ||
| 139 | + | ||
| 140 | + # Split the dataset | ||
| 141 | + X_train, X_val, y_train, y_val = train_test_split( | ||
| 142 | + texts, labels, test_size=0.2, random_state=42 | ||
| 143 | + ) | ||
| 144 | + logging.info(f"Training set size: {len(X_train)}, validation set size: {len(X_val)}") | ||
| 145 | + | ||
| 146 | + # Initialize the tokenizer and datasets | ||
| 147 | + tokenizer = BertTokenizer.from_pretrained('bert-base-chinese') | ||
| 148 | + train_dataset = TextClassificationDataset(X_train, y_train, tokenizer) | ||
| 149 | + val_dataset = TextClassificationDataset(X_val, y_val, tokenizer) | ||
| 23 | 150 | ||
| 24 | - # Split the data into training and test sets, with 20% held out for testing | ||
| 25 | - train_data, test_data = train_test_split(df, test_size=0.2, random_state=42) | 151 | + train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True) |
| 152 | + val_loader = DataLoader(val_dataset, batch_size=16) | ||
| 26 | 153 | ||
| 27 | - # Initialize the TfidfVectorizer and extract text features for the train and test sets | ||
| 28 | - vectorize = TfidfVectorizer() | ||
| 29 | - X_train = vectorize.fit_transform(train_data['text']) | ||
| 30 | - y_train = train_data['sentiment'] | ||
| 31 | - X_test = vectorize.transform(test_data['text']) | ||
| 32 | - y_test = test_data['sentiment'] | 154 | + # Initialize the model |
| 155 | + model = BertClassifier(n_classes=len(np.unique(labels))) | ||
| 156 | + logging.info("Model and data loaders initialized") | ||
| 33 | 157 | ||
| 34 | - # Initialize and train the multinomial Naive Bayes classifier | ||
| 35 | - classifier = MultinomialNB() | ||
| 36 | - classifier.fit(X_train, y_train) | 158 | + # Train the model |
| 159 | + trained_model = train_model(model, train_loader, val_loader) | ||
| 37 | 160 | ||
| 38 | - # Predict on the test set | ||
| 39 | - y_pred = classifier.predict(X_test) | 161 | + # Save the model |
| 162 | + torch.save(trained_model.state_dict(), 'model/saved_model.pth') | ||
| 163 | + logging.info("Model saved to model/saved_model.pth") | ||
| 40 | 164 | ||
| 41 | - # Compute model accuracy | ||
| 42 | - accuracy = accuracy_score(y_test, y_pred) | ||
| 43 | - print(accuracy) | 165 | + except Exception as e: |
| 166 | + logging.error(f"Error in the model training entry point: {e}") | ||
| 167 | + raise | ||
| 44 | 168 | ||
| 45 | if __name__ == "__main__": | 169 | if __name__ == "__main__": |
| 46 | - model_train() # Train the model and compute its accuracy | 170 | + try: |
| 171 | + model_train() | ||
| 172 | + except Exception as e: | ||
| 173 | + logging.error(f"Program execution failed: {e}") |
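For reference, a minimal sketch of how the classifier saved above could be reloaded for inference. `BertClassifier` and the `model/saved_model.pth` path come from the diff; the class count, device, and sample text are assumptions.

```python
import torch
from transformers import BertTokenizer

# Assumes BertClassifier from the diff is importable; n_classes must match training.
model = BertClassifier(n_classes=2)  # assumed binary sentiment
model.load_state_dict(torch.load('model/saved_model.pth', map_location='cpu'))
model.eval()

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
enc = tokenizer('这部电影很好看', return_tensors='pt', padding='max_length',
                truncation=True, max_length=128)
with torch.no_grad():
    logits = model(input_ids=enc['input_ids'], attention_mask=enc['attention_mask'])
print(logits.argmax(dim=1).item())  # predicted label index
```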
| @@ -5,109 +5,126 @@ from tqdm import tqdm | @@ -5,109 +5,126 @@ from tqdm import tqdm | ||
| 5 | from transformers.models.bert import BertTokenizer, BertModel | 5 | from transformers.models.bert import BertTokenizer, BertModel |
| 6 | from contextualized_topic_models.models.ctm import CombinedTM | 6 | from contextualized_topic_models.models.ctm import CombinedTM |
| 7 | from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation | 7 | from contextualized_topic_models.utils.data_preparation import TopicModelDataPreparation |
| 8 | +from contextualized_topic_models.utils.preprocessing import WhiteSpacePreprocessing | ||
| 8 | import numpy as np | 9 | import numpy as np |
| 9 | import torch | 10 | import torch |
| 10 | import jieba | 11 | import jieba |
| 11 | import pickle # for saving and loading the model | 12 | import pickle # for saving and loading the model |
| 13 | +from utils.logger import model_logger as logging | ||
| 12 | 14 | ||
| 13 | -class BERT_CTM_Model: | ||
| 14 | - def __init__(self, bert_model_path, ctm_tokenizer_path, n_components=12, num_epochs=50, model_save_path='./ctm_model'): | ||
| 15 | - self.bert_model_path = bert_model_path | ||
| 16 | - self.ctm_tokenizer_path = ctm_tokenizer_path | ||
| 17 | - self.n_components = n_components | ||
| 18 | - self.num_epochs = num_epochs | 15 | +class BERT_CTM: |
| 16 | + def __init__(self, model_save_path='model_pro/saved_models/ctm_model.pkl'): | ||
| 19 | self.model_save_path = model_save_path | 17 | self.model_save_path = model_save_path |
| 20 | - # Load the BERT model and tokenizer | ||
| 21 | - self.tokenizer = BertTokenizer.from_pretrained(self.bert_model_path) | ||
| 22 | - self.model = BertModel.from_pretrained(self.bert_model_path) | 18 | + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| 19 | + self.bert_model = None | ||
| 20 | + self.tokenizer = None | ||
| 21 | + self.ctm_model = None | ||
| 22 | + self.vocab = None | ||
| 23 | + self.vectorizer = None | ||
| 24 | + | ||
| 25 | + def save_model(self): | ||
| 26 | + """保存模型和词袋""" | ||
| 27 | + try: | ||
| 28 | + with open(self.model_save_path, 'wb') as f: | ||
| 29 | + pickle.dump({ | ||
| 30 | + 'ctm_model': self.ctm_model, | ||
| 31 | + 'vocab': self.vocab, | ||
| 32 | + 'vectorizer': self.vectorizer | ||
| 33 | + }, f) | ||
| 34 | + logging.info(f"CTM model and vocabulary saved to: {self.model_save_path}") | ||
| 35 | + except Exception as e: | ||
| 36 | + logging.error(f"Error while saving the model: {e}") | ||
| 23 | 37 | ||
| 24 | - # Create the CTM data-preparation object | ||
| 25 | - self.tp = TopicModelDataPreparation(self.ctm_tokenizer_path) | 38 | + def load_model(self): |
| 39 | + """Load the model and bag-of-words data""" | ||
| 40 | + try: | ||
| 41 | + with open(self.model_save_path, 'rb') as f: | ||
| 42 | + saved_data = pickle.load(f) | ||
| 43 | + self.ctm_model = saved_data['ctm_model'] | ||
| 44 | + self.vocab = saved_data['vocab'] | ||
| 45 | + self.vectorizer = saved_data['vectorizer'] | ||
| 46 | + logging.info("CTM model, vocabulary, and vectorizer loaded successfully") | ||
| 47 | + except Exception as e: | ||
| 48 | + logging.error(f"Error while loading the model: {e}") | ||
| 49 | + raise | ||
| 50 | + | ||
| 51 | + def train(self, texts, num_topics=10, num_epochs=100): | ||
| 52 | + """训练CTM模型""" | ||
| 53 | + try: | ||
| 54 | + # 初始化BERT | ||
| 55 | + if not self.bert_model: | ||
| 56 | + self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese') | ||
| 57 | + self.bert_model = BertModel.from_pretrained('bert-base-chinese').to(self.device) | ||
| 58 | + | ||
| 59 | + # Extract BERT embeddings | ||
| 60 | + logging.info("Extracting BERT embeddings...") | ||
| 61 | + embeddings = self._get_bert_embeddings(texts) | ||
| 62 | + | ||
| 63 | + # Prepare the CTM training data | ||
| 64 | + logging.info("Preparing CTM training data...") | ||
| 65 | + preprocessor = WhiteSpacePreprocessing(texts) | ||
| 66 | + dataset = TopicModelDataPreparation(embeddings) | ||
| 67 | + | ||
| 68 | + # Train the CTM model | ||
| 69 | + logging.info("Training CTM model...") | ||
| 70 | + self.ctm_model = CombinedTM( | ||
| 71 | + bow_size=len(preprocessor.vocab), | ||
| 72 | + contextual_size=768, # BERT output dimension | ||
| 73 | + n_components=num_topics, | ||
| 74 | + num_epochs=num_epochs | ||
| 75 | + ) | ||
| 76 | + self.ctm_model.fit(dataset) | ||
| 77 | + | ||
| 78 | + # Keep the bag-of-words data | ||
| 79 | + self.vocab = preprocessor.vocab | ||
| 80 | + self.vectorizer = preprocessor.vectorizer | ||
| 81 | + | ||
| 82 | + # Save the model | ||
| 83 | + self.save_model() | ||
| 84 | + logging.info("Model trained and saved") | ||
| 26 | 85 | ||
| 27 | - def chinese_tokenize(self, text): | ||
| 28 | - """使用jieba对中文文本进行分词""" | ||
| 29 | - return " ".join(jieba.cut(text)) | 86 | + except Exception as e: |
| 87 | + logging.error(f"Error while training the model: {e}") | ||
| 88 | + raise | ||
| 30 | 89 | ||
| 31 | - def get_bert_embeddings(self, texts): | ||
| 32 | - """使用BERT模型生成文本的嵌入向量""" | 90 | + def _get_bert_embeddings(self, texts): |
| 91 | + """获取文本的BERT嵌入""" | ||
| 33 | embeddings = [] | 92 | embeddings = [] |
| 34 | - for text in tqdm(texts, desc="Processing texts with BERT"): | ||
| 35 | - inputs = self.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=80) | ||
| 36 | - with torch.no_grad(): | ||
| 37 | - outputs = self.model(**inputs) | ||
| 38 | - embeddings.append(outputs.last_hidden_state.cpu().numpy()) # [batch_size, sequence_length, hidden_size] | ||
| 39 | - return np.vstack(embeddings) | ||
| 40 | - | ||
| 41 | - def save_model(self, ctm): | ||
| 42 | - """保存CTM模型、词袋和BoW的vectorizer""" | ||
| 43 | - os.makedirs(self.model_save_path, exist_ok=True) | ||
| 44 | - with open(f"{self.model_save_path}/ctm_model.pkl", 'wb') as f: | ||
| 45 | - pickle.dump(ctm, f) | ||
| 46 | - with open(f"{self.model_save_path}/vocab.pkl", 'wb') as f: | ||
| 47 | - pickle.dump(self.tp.vocab, f) | ||
| 48 | - with open(f"{self.model_save_path}/vectorizer.pkl", 'wb') as f: # save the BoW vectorizer | ||
| 49 | - pickle.dump(self.tp.vectorizer, f) | ||
| 50 | - print(f"CTM模型和词袋保存到: {self.model_save_path}") | 93 | + try: |
| 94 | + for text in texts: | ||
| 95 | + inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512) | ||
| 96 | + inputs = {k: v.to(self.device) for k, v in inputs.items()} | ||
| 51 | 97 | ||
| 52 | - def load_model(self): | ||
| 53 | - """加载CTM模型、词袋和BoW的vectorizer""" | ||
| 54 | - with open(f"{self.model_save_path}/ctm_model.pkl", 'rb') as f: | ||
| 55 | - ctm = pickle.load(f) | ||
| 56 | - with open(f"{self.model_save_path}/vocab.pkl", 'rb') as f: | ||
| 57 | - self.tp.vocab = pickle.load(f) | ||
| 58 | - with open(f"{self.model_save_path}/vectorizer.pkl", 'rb') as f: # load the BoW vectorizer | ||
| 59 | - self.tp.vectorizer = pickle.load(f) | ||
| 60 | - print(f"CTM模型、词袋和vectorizer加载成功") | ||
| 61 | - return ctm | ||
| 62 | - | ||
| 63 | - def train(self, csv_file): | ||
| 64 | - """训练BERT + CTM模型并保存最终的特征向量和标签""" | ||
| 65 | - # 读取CSV文件中的文本和标签 | ||
| 66 | - data = pd.read_csv(csv_file) | ||
| 67 | - texts = data['TEXT'].tolist() | ||
| 68 | - labels = data['label'].tolist() | ||
| 69 | - | ||
| 70 | - # Step 1: Get the BERT embedding vectors | ||
| 71 | - print("Extracting BERT embeddings...") | ||
| 72 | - bert_embeddings = self.get_bert_embeddings(texts) # [batch_size, sequence_length, hidden_size] | ||
| 73 | - | ||
| 74 | - # Step 2: Prepare the CTM data | ||
| 75 | - print("Preparing data for CTM using training set...") | ||
| 76 | - bow_texts = [self.chinese_tokenize(text) for text in texts] | ||
| 77 | - training_dataset = self.tp.fit(text_for_contextual=texts, text_for_bow=bow_texts) | ||
| 78 | - | ||
| 79 | - # Step 3: Substitute in the BERT embeddings | ||
| 80 | - training_dataset._X = bert_embeddings[:, 0, :] # use only the first token's vector for CTM | ||
| 81 | - | ||
| 82 | - # Step 4: Train the CTM model | ||
| 83 | - print("Training CTM model...") | ||
| 84 | - ctm = CombinedTM(bow_size=len(self.tp.vocab), contextual_size=768, n_components=self.n_components, num_epochs=self.num_epochs) | ||
| 85 | - ctm.fit(train_dataset=training_dataset, verbose=True) | ||
| 86 | - | ||
| 87 | - # Step 5: Save the CTM model and vocabulary | ||
| 88 | - self.save_model(ctm) | ||
| 89 | - | ||
| 90 | - # Step 6: Get the CTM feature vectors | ||
| 91 | - print("Generating CTM features...") | ||
| 92 | - ctm_features = ctm.get_doc_topic_distribution(training_dataset) # [batch_size, n_components] | ||
| 93 | - | ||
| 94 | - # Step 7: Expand the CTM features to match BERT's sequence length | ||
| 95 | - sequence_length = bert_embeddings.shape[1] | ||
| 96 | - ctm_features_expanded = np.repeat(ctm_features[:, np.newaxis, :], sequence_length, axis=1) # [batch_size, sequence_length, n_components] | ||
| 97 | - | ||
| 98 | - # Step 8: Concatenate the BERT embeddings and CTM features | ||
| 99 | - final_embeddings = np.concatenate([bert_embeddings, ctm_features_expanded], axis=-1) # [batch_size, sequence_length, hidden_size + n_components] | ||
| 100 | - | ||
| 101 | - return bert_embeddings | 98 | + with torch.no_grad(): |
| 99 | + outputs = self.bert_model(**inputs) | ||
| 100 | + # Use the [CLS] token output as the document representation | ||
| 101 | + embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy() | ||
| 102 | + embeddings.append(embedding[0]) | ||
| 103 | + | ||
| 104 | + return np.array(embeddings) | ||
| 105 | + except Exception as e: | ||
| 106 | + logging.error(f"Error while computing BERT embeddings: {e}") | ||
| 107 | + raise | ||
| 108 | + | ||
| 109 | + def get_topics(self, num_words=10): | ||
| 110 | + """获取主题词""" | ||
| 111 | + try: | ||
| 112 | + if not self.ctm_model or not self.vocab: | ||
| 113 | + raise ValueError("Model not trained or not loaded") | ||
| 114 | + | ||
| 115 | + topics = [] | ||
| 116 | + for topic_idx in range(self.ctm_model.n_components): | ||
| 117 | + topic = self.ctm_model.get_topic_lists(top_n=num_words)[topic_idx] | ||
| 118 | + topics.append(topic) | ||
| 119 | + return topics | ||
| 120 | + except Exception as e: | ||
| 121 | + logging.error(f"Error while fetching topic words: {e}") | ||
| 122 | + raise | ||
| 102 | 123 | ||
| 103 | if __name__ == "__main__": | 124 | if __name__ == "__main__": |
| 104 | - # Create a BERT_CTM_Model instance | ||
| 105 | - model = BERT_CTM_Model( | ||
| 106 | - bert_model_path='./bert_model', # path to the BERT model | ||
| 107 | - ctm_tokenizer_path='./sentence_bert_model', # path to the CTM tokenizer | ||
| 108 | - n_components=12, # number of topics | ||
| 109 | - num_epochs=50, # training epochs | ||
| 110 | - model_save_path='./ctm_model', # save path | 125 | + # Create a BERT_CTM instance |
| 126 | + model = BERT_CTM( | ||
| 127 | + model_save_path='model_pro/saved_models/ctm_model.pkl', # save path | ||
| 111 | ) | 128 | ) |
| 112 | 129 | ||
| 113 | # Pass in the CSV file path for training | 130 | |
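For context, a hedged sketch of driving the reworked `BERT_CTM` class end to end. The method names and defaults follow the diff; the CSV path and the `text` column name are assumptions.

```python
import pandas as pd

texts = pd.read_csv('data/train_data.csv')['text'].tolist()  # assumed input file and column

ctm = BERT_CTM(model_save_path='model_pro/saved_models/ctm_model.pkl')
ctm.train(texts, num_topics=10, num_epochs=100)  # extracts embeddings, fits, and saves

for i, topic in enumerate(ctm.get_topics(num_words=10)):
    print(i, topic)  # top words per topic
```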
| @@ -2,17 +2,7 @@ import os | @@ -2,17 +2,7 @@ import os | ||
| 2 | import pandas as pd | 2 | import pandas as pd |
| 3 | from sqlalchemy import create_engine | 3 | from sqlalchemy import create_engine |
| 4 | from getpass import getpass | 4 | from getpass import getpass |
| 5 | -import logging | ||
| 6 | - | ||
| 7 | -# 配置日志 | ||
| 8 | -logging.basicConfig( | ||
| 9 | - level=logging.INFO, | ||
| 10 | - format='%(asctime)s [%(levelname)s] %(message)s', | ||
| 11 | - handlers=[ | ||
| 12 | - logging.FileHandler("save_data.log"), | ||
| 13 | - logging.StreamHandler() | ||
| 14 | - ] | ||
| 15 | -) | 5 | +from utils.logger import spider_logger as logging |
| 16 | 6 | ||
| 17 | # Assume articleAddr and commentsAddr are absolute paths or paths relative to the script | 7 | # Assume articleAddr and commentsAddr are absolute paths or paths relative to the script |
| 18 | from spiderDataPackage.settings import articleAddr, commentsAddr | 8 | from spiderDataPackage.settings import articleAddr, commentsAddr |
| 1 | -import time | ||
| 2 | import requests | 1 | import requests |
| 3 | -import csv | 2 | +import pandas as pd |
| 3 | +import time | ||
| 4 | import os | 4 | import os |
| 5 | import random | 5 | import random |
| 6 | from datetime import datetime | 6 | from datetime import datetime |
| 7 | -from .settings import articleAddr, commentsAddr | 7 | +from .settings import articleAddr, commentsAddr, commentsUrl |
| 8 | +from utils.logger import spider_logger as logging | ||
| 8 | from requests.exceptions import RequestException | 9 | from requests.exceptions import RequestException |
| 9 | 10 | ||
| 10 | # Initialization: create the comment data file | 11 | # Initialization: create the comment data file |
| @@ -59,19 +60,65 @@ def readJson(response, articleId): | @@ -59,19 +60,65 @@ def readJson(response, articleId): | ||
| 59 | authorAvatar = comment['user']['avatar_large'] | 60 | authorAvatar = comment['user']['avatar_large'] |
| 60 | write([articleId, created_at, likes_counts, region, content, authorName, authorGender, authorAddress, authorAvatar]) | 61 | write([articleId, created_at, likes_counts, region, content, authorName, authorGender, authorAddress, authorAvatar]) |
| 61 | 62 | ||
| 62 | -# Start the spider | ||
| 63 | -def start(headers_list, delay=2): | ||
| 64 | - commentUrl = 'https://weibo.com/ajax/statuses/buildComments' | ||
| 65 | - init() | ||
| 66 | - articleList = getArticleList() | ||
| 67 | - for article in articleList: | ||
| 68 | - articleId = article[0] | ||
| 69 | - print(f'Crawling comments for article id {articleId}') | ||
| 70 | - time.sleep(random.uniform(1, delay)) # random delay to avoid hitting the server too often | ||
| 71 | - params = {'id': int(articleId), 'is_show_bulletin': 2} | ||
| 72 | - response = fetchData(commentUrl, params, headers_list) | ||
| 73 | - if response: | ||
| 74 | - readJson(response, articleId) | 63 | +def getComments(articleId): |
| 64 | + """ | ||
| 65 | + 获取指定文章的评论数据 | ||
| 66 | + """ | ||
| 67 | + try: | ||
| 68 | + # 构建请求URL和头部 | ||
| 69 | + url = f"{commentsUrl}{articleId}" | ||
| 70 | + response = requests.get(url, headers=headers) | ||
| 71 | + response.raise_for_status() | ||
| 72 | + | ||
| 73 | + # Parse the response data | ||
| 74 | + data = response.json() | ||
| 75 | + if data['code'] == 200: | ||
| 76 | + return data['data'] | ||
| 77 | + else: | ||
| 78 | + logging.error(f"Failed to fetch comments, status code: {data['code']}") | ||
| 79 | + return None | ||
| 80 | + | ||
| 81 | + except requests.RequestException as e: | ||
| 82 | + logging.error(f"Request failed: {e}") | ||
| 83 | + return None | ||
| 84 | + | ||
| 85 | +def start(): | ||
| 86 | + """ | ||
| 87 | + Start crawling the comment data | ||
| 88 | + """ | ||
| 89 | + try: | ||
| 90 | + # Read the article data | ||
| 91 | + article_df = pd.read_csv(articleAddr) | ||
| 92 | + comments_data = [] | ||
| 93 | + | ||
| 94 | + # Iterate over the articles and fetch their comments | ||
| 95 | + for index, row in article_df.iterrows(): | ||
| 96 | + article_id = row['id'] | ||
| 97 | + logging.info(f'Crawling comments for article id {article_id}') | ||
| 98 | + | ||
| 99 | + comments = getComments(article_id) | ||
| 100 | + if comments: | ||
| 101 | + for comment in comments: | ||
| 102 | + comments_data.append({ | ||
| 103 | + 'article_id': article_id, | ||
| 104 | + 'content': comment.get('content', ''), | ||
| 105 | + 'created_at': comment.get('created_at', ''), | ||
| 106 | + 'like_count': comment.get('like_count', 0) | ||
| 107 | + }) | ||
| 108 | + | ||
| 109 | + # Avoid sending requests too frequently | ||
| 110 | + time.sleep(1) | ||
| 111 | + | ||
| 112 | + # Save the comment data | ||
| 113 | + if comments_data: | ||
| 114 | + comments_df = pd.DataFrame(comments_data) | ||
| 115 | + comments_df.to_csv(commentsAddr, index=False, encoding='utf-8') | ||
| 116 | + logging.info(f"Saved {len(comments_data)} comments") | ||
| 117 | + else: | ||
| 118 | + logging.warning("No comment data retrieved") | ||
| 119 | + | ||
| 120 | + except Exception as e: | ||
| 121 | + logging.error(f"Error while crawling comment data: {e}") | ||
| 75 | 122 | ||
| 76 | if __name__ == '__main__': | 123 | if __name__ == '__main__': |
| 77 | # headers_list here should contain cookies for multiple accounts | 124 |
| @@ -85,4 +132,4 @@ if __name__ == '__main__': | @@ -85,4 +132,4 @@ if __name__ == '__main__': | ||
| 85 | 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0' | 132 | 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:127.0) Gecko/20100101 Firefox/127.0' |
| 86 | } | 133 | } |
| 87 | ] | 134 | ] |
| 88 | - start(headers_list) | 135 | + start() |
utils/logger.py
0 → 100644
| 1 | +import os | ||
| 2 | +import logging | ||
| 3 | +from logging.handlers import RotatingFileHandler | ||
| 4 | + | ||
| 5 | +def setup_logger(name, log_file=None, level=logging.INFO): | ||
| 6 | + """ | ||
| 7 | + Set up a unified logger | ||
| 8 | + | ||
| 9 | + Args: | ||
| 10 | + name: logger name | ||
| 11 | + log_file: log file path; if None, log only to the console | ||
| 12 | + level: log level | ||
| 13 | + | ||
| 14 | + Returns: | ||
| 15 | + logger: the configured logger | ||
| 16 | + """ | ||
| 17 | + # Create the logger | ||
| 18 | + logger = logging.getLogger(name) | ||
| 19 | + logger.setLevel(level) | ||
| 20 | + | ||
| 21 | + # Unified log format | ||
| 22 | + formatter = logging.Formatter( | ||
| 23 | + '%(asctime)s [%(name)s] [%(levelname)s] %(message)s', | ||
| 24 | + datefmt='%Y-%m-%d %H:%M:%S' | ||
| 25 | + ) | ||
| 26 | + | ||
| 27 | + # Add a console handler | ||
| 28 | + console_handler = logging.StreamHandler() | ||
| 29 | + console_handler.setFormatter(formatter) | ||
| 30 | + logger.addHandler(console_handler) | ||
| 31 | + | ||
| 32 | + # If a log file is given, add a file handler | ||
| 33 | + if log_file: | ||
| 34 | + # Make sure the log directory exists | ||
| 35 | + log_dir = os.path.dirname(log_file) | ||
| 36 | + if log_dir and not os.path.exists(log_dir): | ||
| 37 | + os.makedirs(log_dir) | ||
| 38 | + | ||
| 39 | + # Use RotatingFileHandler for log rotation | ||
| 40 | + file_handler = RotatingFileHandler( | ||
| 41 | + log_file, | ||
| 42 | + maxBytes=10*1024*1024, # 10MB | ||
| 43 | + backupCount=5, | ||
| 44 | + encoding='utf-8' | ||
| 45 | + ) | ||
| 46 | + file_handler.setFormatter(formatter) | ||
| 47 | + logger.addHandler(file_handler) | ||
| 48 | + | ||
| 49 | + return logger | ||
| 50 | + | ||
| 51 | +# Create the default application loggers | ||
| 52 | +app_logger = setup_logger('weibo_analysis', 'logs/app.log') | ||
| 53 | +spider_logger = setup_logger('spider', 'logs/spider.log') | ||
| 54 | +model_logger = setup_logger('model', 'logs/model.log') | ||
| 55 | + | ||
| 56 | +# Export the loggers | ||
| 57 | +__all__ = ['setup_logger', 'app_logger', 'spider_logger', 'model_logger'] |
| 1 | -import getpass | ||
| 2 | import pymysql | 1 | import pymysql |
| 3 | -import logging | 2 | +from getpass import getpass |
| 3 | +from utils.logger import app_logger as logging | ||
| 4 | 4 | ||
| 5 | # Configure logging | 5 | # Configure logging |
| 6 | logging.basicConfig( | 6 | logging.basicConfig( |
| @@ -28,7 +28,7 @@ def get_db_connection_interactive(): | @@ -28,7 +28,7 @@ def get_db_connection_interactive(): | ||
| 28 | port = 3306 | 28 | port = 3306 |
| 29 | 29 | ||
| 30 | user = input(" 3. Username (default: root): ") or "root" | 30 | user = input(" 3. Username (default: root): ") or "root" |
| 31 | - password = getpass.getpass(" 4. Password (default: 12345678): ") or "12345678" | 31 | + password = getpass(" 4. Password (default: 12345678): ") or "12345678" |
| 32 | db_name = input(" 5. Database name (default: Weibo_PublicOpinion_AnalysisSystem): ") or "Weibo_PublicOpinion_AnalysisSystem" | 32 | db_name = input(" 5. Database name (default: Weibo_PublicOpinion_AnalysisSystem): ") or "Weibo_PublicOpinion_AnalysisSystem" |
| 33 | 33 | ||
| 34 | logging.info(f"Attempting to connect to database: {user}@{host}:{port}/{db_name}") | 34 | logging.info(f"Attempting to connect to database: {user}@{host}:{port}/{db_name}") |
| @@ -3,11 +3,11 @@ from utils.mynlp import SnowNLP | @@ -3,11 +3,11 @@ from utils.mynlp import SnowNLP | ||
| 3 | from utils.getHomePageData import * | 3 | from utils.getHomePageData import * |
| 4 | from utils.getHotWordPageData import * | 4 | from utils.getHotWordPageData import * |
| 5 | from utils.getTableData import * | 5 | from utils.getTableData import * |
| 6 | -from utils.getPublicData import getAllHotWords, getAllTopics | 6 | +from utils.getPublicData import getAllHotWords, getAllTopics, getArticleByType, getArticleById |
| 7 | from utils.getEchartsData import * | 7 | from utils.getEchartsData import * |
| 8 | from utils.getTopicPageData import * | 8 | from utils.getTopicPageData import * |
| 9 | from utils.yuqingpredict import * | 9 | from utils.yuqingpredict import * |
| 10 | -from utils.getPublicData import getAllHotWords | 10 | +from utils.logger import app_logger as logging |
| 11 | 11 | ||
| 12 | pb = Blueprint('page', | 12 | pb = Blueprint('page', |
| 13 | __name__, | 13 | __name__, |
| @@ -196,3 +196,40 @@ def yuqingpredict(): | @@ -196,3 +196,40 @@ def yuqingpredict(): | ||
| 196 | def articleCloud(): | 196 | def articleCloud(): |
| 197 | username = session.get('username') | 197 | username = session.get('username') |
| 198 | return render_template('articleContentCloud.html', username=username) | 198 | return render_template('articleContentCloud.html', username=username) |
| 199 | + | ||
| 200 | + | ||
| 201 | +@pb.route('/page/index') | ||
| 202 | +def index(): | ||
| 203 | + """首页路由""" | ||
| 204 | + try: | ||
| 205 | + hotWordList = getAllHotWords() | ||
| 206 | + logging.info("Hot word list fetched successfully") | ||
| 207 | + return render_template('index.html', hotWordList=hotWordList) | ||
| 208 | + except Exception as e: | ||
| 209 | + logging.error(f"Error while rendering the home page: {e}") | ||
| 210 | + return render_template('error.html', error_message="Failed to load the home page") | ||
| 211 | + | ||
| 212 | +@pb.route('/page/article/<type>') | ||
| 213 | +def article(type): | ||
| 214 | + """文章列表页路由""" | ||
| 215 | + try: | ||
| 216 | + articleList = getArticleByType(type) | ||
| 217 | + logging.info(f"Fetched article list of type {type}") | ||
| 218 | + return render_template('article.html', articleList=articleList) | ||
| 219 | + except Exception as e: | ||
| 220 | + logging.error(f"Error while fetching the article list: {e}") | ||
| 221 | + return render_template('error.html', error_message="Failed to load the article list") | ||
| 222 | + | ||
| 223 | +@pb.route('/page/articleChar/<id>') | ||
| 224 | +def articleChar(id): | ||
| 225 | + """文章详情页路由""" | ||
| 226 | + try: | ||
| 227 | + article = getArticleById(id) | ||
| 228 | + if not article: | ||
| 229 | + logging.warning(f"No article found with id {id}") | ||
| 230 | + return render_template('error.html', error_message="Article does not exist") | ||
| 231 | + logging.info(f"Fetched article detail for id {id}") | ||
| 232 | + return render_template('articleChar.html', article=article) | ||
| 233 | + except Exception as e: | ||
| 234 | + logging.error(f"Error while fetching the article detail: {e}") | ||
| 235 | + return render_template('error.html', error_message="Failed to load the article detail") |
| @@ -4,6 +4,7 @@ from flask import Blueprint, redirect, render_template, request, Flask, session | @@ -4,6 +4,7 @@ from flask import Blueprint, redirect, render_template, request, Flask, session | ||
| 4 | 4 | ||
| 5 | from utils.query import query | 5 | from utils.query import query |
| 6 | from utils.errorResponse import errorResponse | 6 | from utils.errorResponse import errorResponse |
| 7 | +from utils.logger import app_logger as logging | ||
| 7 | 8 | ||
| 8 | ub = Blueprint('user', | 9 | ub = Blueprint('user', |
| 9 | __name__, | 10 | __name__, |
| @@ -31,21 +32,29 @@ def login(): | @@ -31,21 +32,29 @@ def login(): | ||
| 31 | if request.method == 'GET': | 32 | if request.method == 'GET': |
| 32 | return render_template('login_and_register.html') # show the login page | 33 | return render_template('login_and_register.html') # show the login page |
| 33 | 34 | ||
| 34 | - # Extract the form data | ||
| 35 | - username = request.form.get('username', '').strip() | ||
| 36 | - password = hash_password(request.form.get('password', '').strip()) | 35 | + try: |
| 36 | + username = request.form.get('username') | ||
| 37 | + password = request.form.get('password') | ||
| 37 | 38 | ||
| 38 | - # Look up the user | ||
| 39 | - user_query = 'SELECT * FROM user WHERE username = %s AND password = %s' | ||
| 40 | - users = query(user_query, [username, password], 'select') | 39 | + if not username or not password: |
| 40 | + logging.warning("Login failed: username or password is empty") | ||
| 41 | + return render_template('login_and_register.html', msg='Username and password must not be empty') | ||
| 41 | 42 | ||
| 42 | - if not users: | ||
| 43 | - # Login failed: return to the login page with an error message | ||
| 44 | - return render_template('login_and_register.html', error='Incorrect account or password', username=username) | 43 | + # Query the user |
| 44 | + sql = "SELECT * FROM user WHERE username = %s AND password = %s" | ||
| 45 | + result = query(sql, [username, password], "select") | ||
| 45 | 46 | ||
| 46 | - # Login succeeded: set the session and redirect | 47 | + if result: |
| 47 | session['username'] = username | 48 | session['username'] = username |
| 49 | + logging.info(f"User {username} logged in successfully") | ||
| 48 | return redirect('/page/home') | 50 | return redirect('/page/home') |
| 51 | + else: | ||
| 52 | + logging.warning(f"Login failed for user {username}: incorrect username or password") | ||
| 53 | + return render_template('login_and_register.html', msg='Incorrect username or password') | ||
| 54 | + | ||
| 55 | + except Exception as e: | ||
| 56 | + logging.error(f"Error during login: {e}") | ||
| 57 | + return render_template('login_and_register.html', msg='Login failed, please try again later') | ||
| 49 | 58 | ||
| 50 | 59 | ||
| 51 | @ub.route('/register', methods=['GET', 'POST']) | 60 | @ub.route('/register', methods=['GET', 'POST']) |
| @@ -82,3 +91,15 @@ def register(): | @@ -82,3 +91,15 @@ def register(): | ||
| 82 | def logOut(): | 91 | def logOut(): |
| 83 | session.clear() | 92 | session.clear() |
| 84 | return redirect('/user/login') | 93 | return redirect('/user/login') |
| 94 | + | ||
| 95 | +@ub.route('/user/logout') | ||
| 96 | +def logout(): | ||
| 97 | + """用户登出""" | ||
| 98 | + try: | ||
| 99 | + username = session.get('username') | ||
| 100 | + session.clear() | ||
| 101 | + logging.info(f"User {username} logged out successfully") | ||
| 102 | + return redirect('/user/login') | ||
| 103 | + except Exception as e: | ||
| 104 | + logging.error(f"Error during logout: {e}") | ||
| 105 | + return redirect('/user/login') |
| @@ -5,17 +5,7 @@ import matplotlib.pyplot as plt | @@ -5,17 +5,7 @@ import matplotlib.pyplot as plt | ||
| 5 | from PIL import Image | 5 | from PIL import Image |
| 6 | import numpy as np | 6 | import numpy as np |
| 7 | import pymysql | 7 | import pymysql |
| 8 | -import logging | ||
| 9 | - | ||
| 10 | -# Configure logging | ||
| 11 | -logging.basicConfig( | ||
| 12 | - level=logging.INFO, | ||
| 13 | - format='%(asctime)s [%(levelname)s] %(message)s', | ||
| 14 | - handlers=[ | ||
| 15 | - logging.FileHandler("wordcloud_generator.log"), | ||
| 16 | - logging.StreamHandler() | ||
| 17 | - ] | ||
| 18 | -) | 8 | +from utils.logger import app_logger as logging |
| 19 | 9 | ||
| 20 | # Global cache for stop words | 10 | # Global cache for stop words |
| 21 | STOP_WORDS = set() | 11 | STOP_WORDS = set() |