戒酒的李白

The old emotion recognition model has been replaced with the new model_pro, and the results have been integrated into the project.
@@ -6,12 +6,12 @@ from tqdm import tqdm @@ -6,12 +6,12 @@ from tqdm import tqdm
6 import os 6 import os
7 import sys 7 import sys
8 import json 8 import json
9 -import chardet # 导入 chardet 9 +import chardet
10 10
11 -# 导入您定义的模型和模块  
12 -from MHA import MultiHeadAttentionLayer  
13 -from classifier import FinalClassifier  
14 -from BERT_CTM import BERT_CTM_Model 11 +# 导入改进版模型的组件
  12 +from model_pro.MHA import MultiHeadAttentionLayer
  13 +from model_pro.classifier import FinalClassifier
  14 +from model_pro.BERT_CTM import BERT_CTM_Model
15 15
16 # 设置设备 16 # 设置设备
17 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 17 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -30,7 +30,7 @@ def detect_file_encoding(file_path, num_bytes=10000): @@ -30,7 +30,7 @@ def detect_file_encoding(file_path, num_bytes=10000):
30 result = chardet.detect(rawdata) 30 result = chardet.detect(rawdata)
31 encoding = result['encoding'] 31 encoding = result['encoding']
32 confidence = result['confidence'] 32 confidence = result['confidence']
33 - print(f"Detected encoding: {encoding} with confidence {confidence}") 33 + print(f"检测到的编码: {encoding}, 置信度: {confidence}")
34 return encoding 34 return encoding
35 35
36 36
@@ -42,8 +42,6 @@ def get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path, n_compon @@ -42,8 +42,6 @@ def get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path, n_compon
42 n_components=n_components, 42 n_components=n_components,
43 num_epochs=num_epochs 43 num_epochs=num_epochs
44 ) 44 )
45 - # 加载已保存的CTM模型  
46 - bert_ctm_model.load_model()  
47 # 获取嵌入 45 # 获取嵌入
48 embeddings = bert_ctm_model.get_bert_embeddings(texts) 46 embeddings = bert_ctm_model.get_bert_embeddings(texts)
49 return embeddings 47 return embeddings
@@ -60,15 +58,11 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_ @@ -60,15 +58,11 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_
60 num_classes=2): 58 num_classes=2):
61 try: 59 try:
62 # 加载模型 60 # 加载模型
63 - # 修改这里,设置 weights_only=True 以消除 FutureWarning  
64 - checkpoint = torch.load(model_save_path, map_location=device, weights_only=False)  
65 - classifier_model = FinalClassifier(input_dim=768, num_classes=num_classes)  
66 - classifier_model.load_state_dict(checkpoint['classifier_model_state_dict'])  
67 - classifier_model.to(device) 61 + print("加载模型...")
  62 + classifier_model = torch.load(model_save_path, map_location=device)
68 classifier_model.eval() 63 classifier_model.eval()
69 64
70 attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8) 65 attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)
71 - attention_model.load_state_dict(checkpoint['attention_model_state_dict'])  
72 attention_model.to(device) 66 attention_model.to(device)
73 attention_model.eval() 67 attention_model.eval()
74 68
@@ -76,11 +70,12 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_ @@ -76,11 +70,12 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_
76 encoding = detect_file_encoding(input_data_path) 70 encoding = detect_file_encoding(input_data_path)
77 71
78 # 读取输入数据 72 # 读取输入数据
  73 + print("读取输入数据...")
79 data = pd.read_csv(input_data_path, encoding=encoding) 74 data = pd.read_csv(input_data_path, encoding=encoding)
80 texts = data['TEXT'].tolist() 75 texts = data['TEXT'].tolist()
81 76
82 # 生成嵌入 77 # 生成嵌入
83 - print("Generating embeddings...") 78 + print("生成文本嵌入...")
84 embeddings = get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path) 79 embeddings = get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path)
85 80
86 # 准备DataLoader 81 # 准备DataLoader
@@ -88,63 +83,89 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_ @@ -88,63 +83,89 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_
88 83
89 # 存储预测结果 84 # 存储预测结果
90 all_predictions = [] 85 all_predictions = []
  86 + all_probabilities = []
91 87
  88 + print("开始预测...")
92 with torch.no_grad(): 89 with torch.no_grad():
93 - for batch in tqdm(data_loader, desc="Predicting"): 90 + for batch in tqdm(data_loader, desc="预测进度"):
94 batch_x = batch[0].to(device) 91 batch_x = batch[0].to(device)
95 batch_x = torch.mean(batch_x, dim=1) 92 batch_x = torch.mean(batch_x, dim=1)
  93 +
  94 + # 使用注意力机制
96 attention_output = attention_model(batch_x, batch_x, batch_x) 95 attention_output = attention_model(batch_x, batch_x, batch_x)
  96 +
  97 + # 获取分类结果
97 outputs = classifier_model(attention_output) 98 outputs = classifier_model(attention_output)
98 outputs = torch.mean(outputs, dim=1) 99 outputs = torch.mean(outputs, dim=1)
  100 +
  101 + # 获取预测概率
  102 + probabilities = torch.softmax(outputs, dim=1)
  103 +
  104 + # 获取预测标签
99 _, predicted = torch.max(outputs, 1) 105 _, predicted = torch.max(outputs, 1)
  106 +
100 all_predictions.extend(predicted.cpu().numpy()) 107 all_predictions.extend(predicted.cpu().numpy())
  108 + all_probabilities.extend(probabilities.cpu().numpy())
101 109
102 - # 保存预测结果 110 + # 添加预测结果和概率到数据框
103 data['Predicted_Label'] = all_predictions 111 data['Predicted_Label'] = all_predictions
  112 + data['Confidence'] = [prob[pred] for prob, pred in zip(all_probabilities, all_predictions)]
  113 +
  114 + # 保存预测结果
104 data.to_csv(output_path, index=False, encoding='utf-8') 115 data.to_csv(output_path, index=False, encoding='utf-8')
105 - print(f"Predictions saved to {output_path}") 116 + print(f"预测结果已保存到 {output_path}")
106 117
107 # 统计标签的个数和占比 118 # 统计标签的个数和占比
108 label_counts = data['Predicted_Label'].value_counts() 119 label_counts = data['Predicted_Label'].value_counts()
109 total_count = len(data) 120 total_count = len(data)
110 - stats = {} 121 + stats = {
  122 + '统计信息': {
  123 + '总样本数': total_count,
  124 + '各类别统计': {}
  125 + }
  126 + }
  127 +
111 for label, count in label_counts.items(): 128 for label, count in label_counts.items():
112 label_name = "良好" if label == 0 else "不良" 129 label_name = "良好" if label == 0 else "不良"
113 percentage = (count / total_count) * 100 130 percentage = (count / total_count) * 100
114 - stats[label_name] = {  
115 - 'count': count,  
116 - 'percentage': f"{percentage:.2f}%" 131 + confidence_mean = data[data['Predicted_Label'] == label]['Confidence'].mean()
  132 +
  133 + stats['统计信息']['各类别统计'][label_name] = {
  134 + '数量': int(count),
  135 + '占比': f"{percentage:.2f}%",
  136 + '平均置信度': f"{confidence_mean:.2f}"
117 } 137 }
118 - print(f"Label: {label_name}, Count: {count}, Percentage: {percentage:.2f}%") 138 + print(f"标签: {label_name}, 数量: {count}, 占比: {percentage:.2f}%, 平均置信度: {confidence_mean:.2f}")
119 139
120 # 将统计信息保存到 JSON 文件 140 # 将统计信息保存到 JSON 文件
121 with open(stats_output_path, 'w', encoding='utf-8') as f: 141 with open(stats_output_path, 'w', encoding='utf-8') as f:
122 - json.dump(stats, f, ensure_ascii=False) 142 + json.dump(stats, f, ensure_ascii=False, indent=4)
123 143
124 - return True # 成功执行 144 + return True
125 except Exception as e: 145 except Exception as e:
126 - print(f"Error during prediction: {e}")  
127 - return False # 执行失败 146 + print(f"预测过程中出现错误: {e}")
  147 + return False
128 148
129 149
130 if __name__ == "__main__": 150 if __name__ == "__main__":
131 if len(sys.argv) != 3: 151 if len(sys.argv) != 3:
132 - print("Usage: python using_example.py <input_data_path> <stats_output_path>") 152 + print("使用方法: python predict.py <input_data_path> <stats_output_path>")
133 sys.exit(1) 153 sys.exit(1)
134 154
135 input_data_path = sys.argv[1] 155 input_data_path = sys.argv[1]
136 stats_output_path = sys.argv[2] 156 stats_output_path = sys.argv[2]
  157 +
137 # 定义路径 158 # 定义路径
138 - model_save_path = 'BCAT/final_model.pt'  
139 - output_path = 'BCAT/predictions.csv' # 保存预测结果的文件  
140 - bert_model_path = 'BCAT/bert_model'  
141 - ctm_tokenizer_path = 'BCAT/sentence_bert_model' 159 + model_save_path = 'model_pro/final_model.pt'
  160 + output_path = 'model_pro/predictions.csv'
  161 + bert_model_path = 'model_pro/bert_model'
  162 + ctm_tokenizer_path = 'model_pro/sentence_bert_model'
142 163
143 # 执行预测 164 # 执行预测
144 success = predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_tokenizer_path, 165 success = predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_tokenizer_path,
145 - stats_output_path) 166 + stats_output_path)
146 167
147 if success: 168 if success:
148 - sys.exit(0) # 成功 169 + sys.exit(0)
149 else: 170 else:
150 - sys.exit(1) # 失败 171 + sys.exit(1)
1 from utils.getPublicData import * # Import utility functions for data retrieval 1 from utils.getPublicData import * # Import utility functions for data retrieval
2 from utils.mynlp import SnowNLP # Import SnowNLP for sentiment analysis 2 from utils.mynlp import SnowNLP # Import SnowNLP for sentiment analysis
3 from collections import Counter # Import Counter for counting occurrences 3 from collections import Counter # Import Counter for counting occurrences
  4 +import torch
  5 +from model_pro.MHA import MultiHeadAttentionLayer
  6 +from model_pro.classifier import FinalClassifier
  7 +from model_pro.BERT_CTM import BERT_CTM_Model
4 8
5 articleList = getAllArticleData() # Retrieve all article data 9 articleList = getAllArticleData() # Retrieve all article data
6 commentList = getAllCommentsData() # Retrieve all comment data 10 commentList = getAllCommentsData() # Retrieve all comment data
7 11
  12 +# 设置设备
  13 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  14 +
  15 +# 加载模型(全局变量,避免重复加载)
  16 +model_save_path = 'model_pro/final_model.pt'
  17 +bert_model_path = 'model_pro/bert_model'
  18 +ctm_tokenizer_path = 'model_pro/sentence_bert_model'
  19 +
  20 +try:
  21 + classifier_model = torch.load(model_save_path, map_location=device)
  22 + classifier_model.eval()
  23 + attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)
  24 + attention_model.to(device)
  25 + attention_model.eval()
  26 + bert_ctm_model = BERT_CTM_Model(
  27 + bert_model_path=bert_model_path,
  28 + ctm_tokenizer_path=ctm_tokenizer_path
  29 + )
  30 +except Exception as e:
  31 + print(f"模型加载失败: {e}")
  32 +
  33 +def predict_sentiment(texts):
  34 + """使用改进版模型预测情感"""
  35 + try:
  36 + # 获取文本嵌入
  37 + embeddings = bert_ctm_model.get_bert_embeddings(texts)
  38 +
  39 + # 转换为tensor
  40 + batch_x = torch.tensor(embeddings, dtype=torch.float32).to(device)
  41 + batch_x = torch.mean(batch_x, dim=1)
  42 +
  43 + with torch.no_grad():
  44 + # 使用注意力机制
  45 + attention_output = attention_model(batch_x, batch_x, batch_x)
  46 + # 获取分类结果
  47 + outputs = classifier_model(attention_output)
  48 + outputs = torch.mean(outputs, dim=1)
  49 + # 获取预测标签
  50 + _, predicted = torch.max(outputs, 1)
  51 +
  52 + return predicted.cpu().numpy()
  53 + except Exception as e:
  54 + print(f"预测过程中出现错误: {e}")
  55 + return None
  56 +
8 def getTypeList(): 57 def getTypeList():
9 # Return a list of unique article types 58 # Return a list of unique article types
10 return list(set([x[8] for x in articleList])) 59 return list(set([x[8] for x in articleList]))
@@ -119,32 +168,32 @@ def getYuQingCharDataOne(): @@ -119,32 +168,32 @@ def getYuQingCharDataOne():
119 return X, Y, biedata 168 return X, Y, biedata
120 169
121 def getYuQingCharDataTwo(): 170 def getYuQingCharDataTwo():
122 - # Analyze sentiment of comments and articles  
123 - comment_sentiments = []  
124 - for comment in commentList:  
125 - emotionValue = SnowNLP(comment[4]).sentiments  
126 - if emotionValue > 0.4:  
127 - comment_sentiments.append('正面')  
128 - elif emotionValue < 0.2:  
129 - comment_sentiments.append('负面')  
130 - else:  
131 - comment_sentiments.append('中性')  
132 - comment_counts = Counter(comment_sentiments) 171 + # 分析评论和文章的情感
  172 + comment_texts = [comment[4] for comment in commentList]
  173 + article_texts = [article[5] for article in articleList]
133 174
134 - article_sentiments = []  
135 - for article in articleList:  
136 - emotionValue = SnowNLP(article[5]).sentiments  
137 - if emotionValue > 0.4:  
138 - article_sentiments.append('正面')  
139 - elif emotionValue < 0.2:  
140 - article_sentiments.append('负面')  
141 - else:  
142 - article_sentiments.append('中性') 175 + # 预测评论情感
  176 + comment_predictions = predict_sentiment(comment_texts)
  177 + if comment_predictions is not None:
  178 + comment_sentiments = ['良好' if pred == 0 else '不良' for pred in comment_predictions]
  179 + else:
  180 + comment_sentiments = []
  181 +
  182 + # 预测文章情感
  183 + article_predictions = predict_sentiment(article_texts)
  184 + if article_predictions is not None:
  185 + article_sentiments = ['良好' if pred == 0 else '不良' for pred in article_predictions]
  186 + else:
  187 + article_sentiments = []
  188 +
  189 + # 统计结果
  190 + comment_counts = Counter(comment_sentiments)
143 article_counts = Counter(article_sentiments) 191 article_counts = Counter(article_sentiments)
144 192
145 - X = ['正面', '中性', '负面'] 193 + X = ['良好', '不良']
146 biedata1 = [{'name': x, 'value': comment_counts.get(x, 0)} for x in X] 194 biedata1 = [{'name': x, 'value': comment_counts.get(x, 0)} for x in X]
147 biedata2 = [{'name': x, 'value': article_counts.get(x, 0)} for x in X] 195 biedata2 = [{'name': x, 'value': article_counts.get(x, 0)} for x in X]
  196 +
148 return biedata1, biedata2 197 return biedata1, biedata2
149 198
150 def getYuQingCharDataThree(): 199 def getYuQingCharDataThree():
@@ -8,12 +8,61 @@ from utils.getEchartsData import * @@ -8,12 +8,61 @@ from utils.getEchartsData import *
8 from utils.getTopicPageData import * 8 from utils.getTopicPageData import *
9 from utils.yuqingpredict import * 9 from utils.yuqingpredict import *
10 from utils.logger import app_logger as logging 10 from utils.logger import app_logger as logging
  11 +import torch
  12 +from model_pro.MHA import MultiHeadAttentionLayer
  13 +from model_pro.classifier import FinalClassifier
  14 +from model_pro.BERT_CTM import BERT_CTM_Model
11 15
12 pb = Blueprint('page', 16 pb = Blueprint('page',
13 __name__, 17 __name__,
14 url_prefix='/page', 18 url_prefix='/page',
15 template_folder='templates') 19 template_folder='templates')
16 20
  21 +# 设置设备
  22 +device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  23 +
  24 +# 加载模型(全局变量,避免重复加载)
  25 +model_save_path = 'model_pro/final_model.pt'
  26 +bert_model_path = 'model_pro/bert_model'
  27 +ctm_tokenizer_path = 'model_pro/sentence_bert_model'
  28 +
  29 +try:
  30 + classifier_model = torch.load(model_save_path, map_location=device)
  31 + classifier_model.eval()
  32 + attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)
  33 + attention_model.to(device)
  34 + attention_model.eval()
  35 + bert_ctm_model = BERT_CTM_Model(
  36 + bert_model_path=bert_model_path,
  37 + ctm_tokenizer_path=ctm_tokenizer_path
  38 + )
  39 +except Exception as e:
  40 + print(f"模型加载失败: {e}")
  41 +
  42 +def predict_sentiment(text):
  43 + """使用改进版模型预测单个文本的情感"""
  44 + try:
  45 + # 获取文本嵌入
  46 + embeddings = bert_ctm_model.get_bert_embeddings([text])
  47 +
  48 + # 转换为tensor
  49 + batch_x = torch.tensor(embeddings, dtype=torch.float32).to(device)
  50 + batch_x = torch.mean(batch_x, dim=1)
  51 +
  52 + with torch.no_grad():
  53 + # 使用注意力机制
  54 + attention_output = attention_model(batch_x, batch_x, batch_x)
  55 + # 获取分类结果
  56 + outputs = classifier_model(attention_output)
  57 + outputs = torch.mean(outputs, dim=1)
  58 + # 获取预测标签和概率
  59 + probabilities = torch.softmax(outputs, dim=1)
  60 + _, predicted = torch.max(outputs, 1)
  61 +
  62 + return predicted.item(), probabilities[0][predicted.item()].item()
  63 + except Exception as e:
  64 + print(f"预测过程中出现错误: {e}")
  65 + return None, None
17 66
18 @pb.route('/home') 67 @pb.route('/home')
19 def home(): 68 def home():
@@ -172,14 +221,15 @@ def yuqingpredict(): @@ -172,14 +221,15 @@ def yuqingpredict():
172 defaultTopic = request.args.get('Topic') 221 defaultTopic = request.args.get('Topic')
173 TopicLen = getTopicLen(defaultTopic) 222 TopicLen = getTopicLen(defaultTopic)
174 X, Y = getTopicCreatedAtandpredictData(defaultTopic) 223 X, Y = getTopicCreatedAtandpredictData(defaultTopic)
175 - sentences = ''  
176 - value = SnowNLP(defaultTopic).sentiments  
177 - if value == 0.5:  
178 - sentences = '中性'  
179 - elif value > 0.5:  
180 - sentences = '正面'  
181 - elif value < 0.5:  
182 - sentences = '负面' 224 +
  225 + # 使用改进版模型进行情感预测
  226 + predicted_label, confidence = predict_sentiment(defaultTopic)
  227 + if predicted_label is not None:
  228 + sentences = '良好' if predicted_label == 0 else '不良'
  229 + sentences = f"{sentences} (置信度: {confidence:.2f})"
  230 + else:
  231 + sentences = '预测失败'
  232 +
183 comments = getCommentFilterDataTopic(defaultTopic) 233 comments = getCommentFilterDataTopic(defaultTopic)
184 return render_template('yuqingpredict.html', 234 return render_template('yuqingpredict.html',
185 username=username, 235 username=username,