The old emotion recognition model has been replaced with the new model_pro, and the results have been integrated into the project.
Showing 3 changed files with 181 additions and 61 deletions.
| @@ -6,12 +6,12 @@ from tqdm import tqdm | @@ -6,12 +6,12 @@ from tqdm import tqdm | ||
| 6 | import os | 6 | import os |
| 7 | import sys | 7 | import sys |
| 8 | import json | 8 | import json |
| 9 | -import chardet # 导入 chardet | 9 | +import chardet |
| 10 | 10 | ||
| 11 | -# 导入您定义的模型和模块 | ||
| 12 | -from MHA import MultiHeadAttentionLayer | ||
| 13 | -from classifier import FinalClassifier | ||
| 14 | -from BERT_CTM import BERT_CTM_Model | 11 | +# 导入改进版模型的组件 |
| 12 | +from model_pro.MHA import MultiHeadAttentionLayer | ||
| 13 | +from model_pro.classifier import FinalClassifier | ||
| 14 | +from model_pro.BERT_CTM import BERT_CTM_Model | ||
| 15 | 15 | ||
| 16 | # 设置设备 | 16 | # 设置设备 |
| 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| @@ -30,7 +30,7 @@ def detect_file_encoding(file_path, num_bytes=10000): | @@ -30,7 +30,7 @@ def detect_file_encoding(file_path, num_bytes=10000): | ||
| 30 | result = chardet.detect(rawdata) | 30 | result = chardet.detect(rawdata) |
| 31 | encoding = result['encoding'] | 31 | encoding = result['encoding'] |
| 32 | confidence = result['confidence'] | 32 | confidence = result['confidence'] |
| 33 | - print(f"Detected encoding: {encoding} with confidence {confidence}") | 33 | + print(f"检测到的编码: {encoding}, 置信度: {confidence}") |
| 34 | return encoding | 34 | return encoding |
| 35 | 35 | ||
| 36 | 36 | ||
| @@ -42,8 +42,6 @@ def get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path, n_compon | @@ -42,8 +42,6 @@ def get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path, n_compon | ||
| 42 | n_components=n_components, | 42 | n_components=n_components, |
| 43 | num_epochs=num_epochs | 43 | num_epochs=num_epochs |
| 44 | ) | 44 | ) |
| 45 | - # 加载已保存的CTM模型 | ||
| 46 | - bert_ctm_model.load_model() | ||
| 47 | # 获取嵌入 | 45 | # 获取嵌入 |
| 48 | embeddings = bert_ctm_model.get_bert_embeddings(texts) | 46 | embeddings = bert_ctm_model.get_bert_embeddings(texts) |
| 49 | return embeddings | 47 | return embeddings |
| @@ -60,15 +58,11 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_ | @@ -60,15 +58,11 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_ | ||
| 60 | num_classes=2): | 58 | num_classes=2): |
| 61 | try: | 59 | try: |
| 62 | # 加载模型 | 60 | # 加载模型 |
| 63 | - # 修改这里,设置 weights_only=True 以消除 FutureWarning | ||
| 64 | - checkpoint = torch.load(model_save_path, map_location=device, weights_only=False) | ||
| 65 | - classifier_model = FinalClassifier(input_dim=768, num_classes=num_classes) | ||
| 66 | - classifier_model.load_state_dict(checkpoint['classifier_model_state_dict']) | ||
| 67 | - classifier_model.to(device) | 61 | + print("加载模型...") |
| 62 | + classifier_model = torch.load(model_save_path, map_location=device) | ||
| 68 | classifier_model.eval() | 63 | classifier_model.eval() |
| 69 | 64 | ||
| 70 | attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8) | 65 | attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8) |
| 71 | - attention_model.load_state_dict(checkpoint['attention_model_state_dict']) | ||
| 72 | attention_model.to(device) | 66 | attention_model.to(device) |
| 73 | attention_model.eval() | 67 | attention_model.eval() |
| 74 | 68 | ||
| @@ -76,11 +70,12 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_ | @@ -76,11 +70,12 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_ | ||
| 76 | encoding = detect_file_encoding(input_data_path) | 70 | encoding = detect_file_encoding(input_data_path) |
| 77 | 71 | ||
| 78 | # 读取输入数据 | 72 | # 读取输入数据 |
| 73 | + print("读取输入数据...") | ||
| 79 | data = pd.read_csv(input_data_path, encoding=encoding) | 74 | data = pd.read_csv(input_data_path, encoding=encoding) |
| 80 | texts = data['TEXT'].tolist() | 75 | texts = data['TEXT'].tolist() |
| 81 | 76 | ||
| 82 | # 生成嵌入 | 77 | # 生成嵌入 |
| 83 | - print("Generating embeddings...") | 78 | + print("生成文本嵌入...") |
| 84 | embeddings = get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path) | 79 | embeddings = get_bert_ctm_embeddings(texts, bert_model_path, ctm_tokenizer_path) |
| 85 | 80 | ||
| 86 | # 准备DataLoader | 81 | # 准备DataLoader |
| @@ -88,63 +83,89 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_ | @@ -88,63 +83,89 @@ def predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_ | ||
| 88 | 83 | ||
| 89 | # 存储预测结果 | 84 | # 存储预测结果 |
| 90 | all_predictions = [] | 85 | all_predictions = [] |
| 86 | + all_probabilities = [] | ||
| 91 | 87 | ||
| 88 | + print("开始预测...") | ||
| 92 | with torch.no_grad(): | 89 | with torch.no_grad(): |
| 93 | - for batch in tqdm(data_loader, desc="Predicting"): | 90 | + for batch in tqdm(data_loader, desc="预测进度"): |
| 94 | batch_x = batch[0].to(device) | 91 | batch_x = batch[0].to(device) |
| 95 | batch_x = torch.mean(batch_x, dim=1) | 92 | batch_x = torch.mean(batch_x, dim=1) |
| 93 | + | ||
| 94 | + # 使用注意力机制 | ||
| 96 | attention_output = attention_model(batch_x, batch_x, batch_x) | 95 | attention_output = attention_model(batch_x, batch_x, batch_x) |
| 96 | + | ||
| 97 | + # 获取分类结果 | ||
| 97 | outputs = classifier_model(attention_output) | 98 | outputs = classifier_model(attention_output) |
| 98 | outputs = torch.mean(outputs, dim=1) | 99 | outputs = torch.mean(outputs, dim=1) |
| 100 | + | ||
| 101 | + # 获取预测概率 | ||
| 102 | + probabilities = torch.softmax(outputs, dim=1) | ||
| 103 | + | ||
| 104 | + # 获取预测标签 | ||
| 99 | _, predicted = torch.max(outputs, 1) | 105 | _, predicted = torch.max(outputs, 1) |
| 106 | + | ||
| 100 | all_predictions.extend(predicted.cpu().numpy()) | 107 | all_predictions.extend(predicted.cpu().numpy()) |
| 108 | + all_probabilities.extend(probabilities.cpu().numpy()) | ||
| 101 | 109 | ||
| 102 | - # 保存预测结果 | 110 | + # 添加预测结果和概率到数据框 |
| 103 | data['Predicted_Label'] = all_predictions | 111 | data['Predicted_Label'] = all_predictions |
| 112 | + data['Confidence'] = [prob[pred] for prob, pred in zip(all_probabilities, all_predictions)] | ||
| 113 | + | ||
| 114 | + # 保存预测结果 | ||
| 104 | data.to_csv(output_path, index=False, encoding='utf-8') | 115 | data.to_csv(output_path, index=False, encoding='utf-8') |
| 105 | - print(f"Predictions saved to {output_path}") | 116 | + print(f"预测结果已保存到 {output_path}") |
| 106 | 117 | ||
| 107 | # 统计标签的个数和占比 | 118 | # 统计标签的个数和占比 |
| 108 | label_counts = data['Predicted_Label'].value_counts() | 119 | label_counts = data['Predicted_Label'].value_counts() |
| 109 | total_count = len(data) | 120 | total_count = len(data) |
| 110 | - stats = {} | 121 | + stats = { |
| 122 | + '统计信息': { | ||
| 123 | + '总样本数': total_count, | ||
| 124 | + '各类别统计': {} | ||
| 125 | + } | ||
| 126 | + } | ||
| 127 | + | ||
| 111 | for label, count in label_counts.items(): | 128 | for label, count in label_counts.items(): |
| 112 | label_name = "良好" if label == 0 else "不良" | 129 | label_name = "良好" if label == 0 else "不良" |
| 113 | percentage = (count / total_count) * 100 | 130 | percentage = (count / total_count) * 100 |
| 114 | - stats[label_name] = { | ||
| 115 | - 'count': count, | ||
| 116 | - 'percentage': f"{percentage:.2f}%" | 131 | + confidence_mean = data[data['Predicted_Label'] == label]['Confidence'].mean() |
| 132 | + | ||
| 133 | + stats['统计信息']['各类别统计'][label_name] = { | ||
| 134 | + '数量': int(count), | ||
| 135 | + '占比': f"{percentage:.2f}%", | ||
| 136 | + '平均置信度': f"{confidence_mean:.2f}" | ||
| 117 | } | 137 | } |
| 118 | - print(f"Label: {label_name}, Count: {count}, Percentage: {percentage:.2f}%") | 138 | + print(f"标签: {label_name}, 数量: {count}, 占比: {percentage:.2f}%, 平均置信度: {confidence_mean:.2f}") |
| 119 | 139 | ||
| 120 | # 将统计信息保存到 JSON 文件 | 140 | # 将统计信息保存到 JSON 文件 |
| 121 | with open(stats_output_path, 'w', encoding='utf-8') as f: | 141 | with open(stats_output_path, 'w', encoding='utf-8') as f: |
| 122 | - json.dump(stats, f, ensure_ascii=False) | 142 | + json.dump(stats, f, ensure_ascii=False, indent=4) |
| 123 | 143 | ||
| 124 | - return True # 成功执行 | 144 | + return True |
| 125 | except Exception as e: | 145 | except Exception as e: |
| 126 | - print(f"Error during prediction: {e}") | ||
| 127 | - return False # 执行失败 | 146 | + print(f"预测过程中出现错误: {e}") |
| 147 | + return False | ||
| 128 | 148 | ||
| 129 | 149 | ||
| 130 | if __name__ == "__main__": | 150 | if __name__ == "__main__": |
| 131 | if len(sys.argv) != 3: | 151 | if len(sys.argv) != 3: |
| 132 | - print("Usage: python using_example.py <input_data_path> <stats_output_path>") | 152 | + print("使用方法: python predict.py <input_data_path> <stats_output_path>") |
| 133 | sys.exit(1) | 153 | sys.exit(1) |
| 134 | 154 | ||
| 135 | input_data_path = sys.argv[1] | 155 | input_data_path = sys.argv[1] |
| 136 | stats_output_path = sys.argv[2] | 156 | stats_output_path = sys.argv[2] |
| 157 | + | ||
| 137 | # 定义路径 | 158 | # 定义路径 |
| 138 | - model_save_path = 'BCAT/final_model.pt' | ||
| 139 | - output_path = 'BCAT/predictions.csv' # 保存预测结果的文件 | ||
| 140 | - bert_model_path = 'BCAT/bert_model' | ||
| 141 | - ctm_tokenizer_path = 'BCAT/sentence_bert_model' | 159 | + model_save_path = 'model_pro/final_model.pt' |
| 160 | + output_path = 'model_pro/predictions.csv' | ||
| 161 | + bert_model_path = 'model_pro/bert_model' | ||
| 162 | + ctm_tokenizer_path = 'model_pro/sentence_bert_model' | ||
| 142 | 163 | ||
| 143 | # 执行预测 | 164 | # 执行预测 |
| 144 | success = predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_tokenizer_path, | 165 | success = predict(model_save_path, input_data_path, output_path, bert_model_path, ctm_tokenizer_path, |
| 145 | stats_output_path) | 166 | stats_output_path) |
| 146 | 167 | ||
| 147 | if success: | 168 | if success: |
| 148 | - sys.exit(0) # 成功 | 169 | + sys.exit(0) |
| 149 | else: | 170 | else: |
| 150 | - sys.exit(1) # 失败 | 171 | + sys.exit(1) |
| 1 | from utils.getPublicData import * # Import utility functions for data retrieval | 1 | from utils.getPublicData import * # Import utility functions for data retrieval |
| 2 | from utils.mynlp import SnowNLP # Import SnowNLP for sentiment analysis | 2 | from utils.mynlp import SnowNLP # Import SnowNLP for sentiment analysis |
| 3 | from collections import Counter # Import Counter for counting occurrences | 3 | from collections import Counter # Import Counter for counting occurrences |
| 4 | +import torch | ||
| 5 | +from model_pro.MHA import MultiHeadAttentionLayer | ||
| 6 | +from model_pro.classifier import FinalClassifier | ||
| 7 | +from model_pro.BERT_CTM import BERT_CTM_Model | ||
| 4 | 8 | ||
| 5 | articleList = getAllArticleData() # Retrieve all article data | 9 | articleList = getAllArticleData() # Retrieve all article data |
| 6 | commentList = getAllCommentsData() # Retrieve all comment data | 10 | commentList = getAllCommentsData() # Retrieve all comment data |
| 7 | 11 | ||
| 12 | +# 设置设备 | ||
| 13 | +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
| 14 | + | ||
| 15 | +# 加载模型(全局变量,避免重复加载) | ||
| 16 | +model_save_path = 'model_pro/final_model.pt' | ||
| 17 | +bert_model_path = 'model_pro/bert_model' | ||
| 18 | +ctm_tokenizer_path = 'model_pro/sentence_bert_model' | ||
| 19 | + | ||
| 20 | +try: | ||
| 21 | + classifier_model = torch.load(model_save_path, map_location=device) | ||
| 22 | + classifier_model.eval() | ||
| 23 | + attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8) | ||
| 24 | + attention_model.to(device) | ||
| 25 | + attention_model.eval() | ||
| 26 | + bert_ctm_model = BERT_CTM_Model( | ||
| 27 | + bert_model_path=bert_model_path, | ||
| 28 | + ctm_tokenizer_path=ctm_tokenizer_path | ||
| 29 | + ) | ||
| 30 | +except Exception as e: | ||
| 31 | + print(f"模型加载失败: {e}") | ||
| 32 | + | ||
| 33 | +def predict_sentiment(texts): | ||
| 34 | + """使用改进版模型预测情感""" | ||
| 35 | + try: | ||
| 36 | + # 获取文本嵌入 | ||
| 37 | + embeddings = bert_ctm_model.get_bert_embeddings(texts) | ||
| 38 | + | ||
| 39 | + # 转换为tensor | ||
| 40 | + batch_x = torch.tensor(embeddings, dtype=torch.float32).to(device) | ||
| 41 | + batch_x = torch.mean(batch_x, dim=1) | ||
| 42 | + | ||
| 43 | + with torch.no_grad(): | ||
| 44 | + # 使用注意力机制 | ||
| 45 | + attention_output = attention_model(batch_x, batch_x, batch_x) | ||
| 46 | + # 获取分类结果 | ||
| 47 | + outputs = classifier_model(attention_output) | ||
| 48 | + outputs = torch.mean(outputs, dim=1) | ||
| 49 | + # 获取预测标签 | ||
| 50 | + _, predicted = torch.max(outputs, 1) | ||
| 51 | + | ||
| 52 | + return predicted.cpu().numpy() | ||
| 53 | + except Exception as e: | ||
| 54 | + print(f"预测过程中出现错误: {e}") | ||
| 55 | + return None | ||
| 56 | + | ||
| 8 | def getTypeList(): | 57 | def getTypeList(): |
| 9 | # Return a list of unique article types | 58 | # Return a list of unique article types |
| 10 | return list(set([x[8] for x in articleList])) | 59 | return list(set([x[8] for x in articleList])) |
| @@ -119,32 +168,32 @@ def getYuQingCharDataOne(): | @@ -119,32 +168,32 @@ def getYuQingCharDataOne(): | ||
| 119 | return X, Y, biedata | 168 | return X, Y, biedata |
| 120 | 169 | ||
| 121 | def getYuQingCharDataTwo(): | 170 | def getYuQingCharDataTwo(): |
| 122 | - # Analyze sentiment of comments and articles | ||
| 123 | - comment_sentiments = [] | ||
| 124 | - for comment in commentList: | ||
| 125 | - emotionValue = SnowNLP(comment[4]).sentiments | ||
| 126 | - if emotionValue > 0.4: | ||
| 127 | - comment_sentiments.append('正面') | ||
| 128 | - elif emotionValue < 0.2: | ||
| 129 | - comment_sentiments.append('负面') | 171 | + # 分析评论和文章的情感 |
| 172 | + comment_texts = [comment[4] for comment in commentList] | ||
| 173 | + article_texts = [article[5] for article in articleList] | ||
| 174 | + | ||
| 175 | + # 预测评论情感 | ||
| 176 | + comment_predictions = predict_sentiment(comment_texts) | ||
| 177 | + if comment_predictions is not None: | ||
| 178 | + comment_sentiments = ['良好' if pred == 0 else '不良' for pred in comment_predictions] | ||
| 130 | else: | 179 | else: |
| 131 | - comment_sentiments.append('中性') | ||
| 132 | - comment_counts = Counter(comment_sentiments) | 180 | + comment_sentiments = [] |
| 133 | 181 | ||
| 134 | - article_sentiments = [] | ||
| 135 | - for article in articleList: | ||
| 136 | - emotionValue = SnowNLP(article[5]).sentiments | ||
| 137 | - if emotionValue > 0.4: | ||
| 138 | - article_sentiments.append('正面') | ||
| 139 | - elif emotionValue < 0.2: | ||
| 140 | - article_sentiments.append('负面') | 182 | + # 预测文章情感 |
| 183 | + article_predictions = predict_sentiment(article_texts) | ||
| 184 | + if article_predictions is not None: | ||
| 185 | + article_sentiments = ['良好' if pred == 0 else '不良' for pred in article_predictions] | ||
| 141 | else: | 186 | else: |
| 142 | - article_sentiments.append('中性') | 187 | + article_sentiments = [] |
| 188 | + | ||
| 189 | + # 统计结果 | ||
| 190 | + comment_counts = Counter(comment_sentiments) | ||
| 143 | article_counts = Counter(article_sentiments) | 191 | article_counts = Counter(article_sentiments) |
| 144 | 192 | ||
| 145 | - X = ['正面', '中性', '负面'] | 193 | + X = ['良好', '不良'] |
| 146 | biedata1 = [{'name': x, 'value': comment_counts.get(x, 0)} for x in X] | 194 | biedata1 = [{'name': x, 'value': comment_counts.get(x, 0)} for x in X] |
| 147 | biedata2 = [{'name': x, 'value': article_counts.get(x, 0)} for x in X] | 195 | biedata2 = [{'name': x, 'value': article_counts.get(x, 0)} for x in X] |
| 196 | + | ||
| 148 | return biedata1, biedata2 | 197 | return biedata1, biedata2 |
| 149 | 198 | ||
| 150 | def getYuQingCharDataThree(): | 199 | def getYuQingCharDataThree(): |
| @@ -8,12 +8,61 @@ from utils.getEchartsData import * | @@ -8,12 +8,61 @@ from utils.getEchartsData import * | ||
| 8 | from utils.getTopicPageData import * | 8 | from utils.getTopicPageData import * |
| 9 | from utils.yuqingpredict import * | 9 | from utils.yuqingpredict import * |
| 10 | from utils.logger import app_logger as logging | 10 | from utils.logger import app_logger as logging |
| 11 | +import torch | ||
| 12 | +from model_pro.MHA import MultiHeadAttentionLayer | ||
| 13 | +from model_pro.classifier import FinalClassifier | ||
| 14 | +from model_pro.BERT_CTM import BERT_CTM_Model | ||
| 11 | 15 | ||
| 12 | pb = Blueprint('page', | 16 | pb = Blueprint('page', |
| 13 | __name__, | 17 | __name__, |
| 14 | url_prefix='/page', | 18 | url_prefix='/page', |
| 15 | template_folder='templates') | 19 | template_folder='templates') |
| 16 | 20 | ||
| 21 | +# 设置设备 | ||
| 22 | +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
| 23 | + | ||
| 24 | +# 加载模型(全局变量,避免重复加载) | ||
| 25 | +model_save_path = 'model_pro/final_model.pt' | ||
| 26 | +bert_model_path = 'model_pro/bert_model' | ||
| 27 | +ctm_tokenizer_path = 'model_pro/sentence_bert_model' | ||
| 28 | + | ||
| 29 | +try: | ||
| 30 | + classifier_model = torch.load(model_save_path, map_location=device) | ||
| 31 | + classifier_model.eval() | ||
| 32 | + attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8) | ||
| 33 | + attention_model.to(device) | ||
| 34 | + attention_model.eval() | ||
| 35 | + bert_ctm_model = BERT_CTM_Model( | ||
| 36 | + bert_model_path=bert_model_path, | ||
| 37 | + ctm_tokenizer_path=ctm_tokenizer_path | ||
| 38 | + ) | ||
| 39 | +except Exception as e: | ||
| 40 | + print(f"模型加载失败: {e}") | ||
| 41 | + | ||
| 42 | +def predict_sentiment(text): | ||
| 43 | + """使用改进版模型预测单个文本的情感""" | ||
| 44 | + try: | ||
| 45 | + # 获取文本嵌入 | ||
| 46 | + embeddings = bert_ctm_model.get_bert_embeddings([text]) | ||
| 47 | + | ||
| 48 | + # 转换为tensor | ||
| 49 | + batch_x = torch.tensor(embeddings, dtype=torch.float32).to(device) | ||
| 50 | + batch_x = torch.mean(batch_x, dim=1) | ||
| 51 | + | ||
| 52 | + with torch.no_grad(): | ||
| 53 | + # 使用注意力机制 | ||
| 54 | + attention_output = attention_model(batch_x, batch_x, batch_x) | ||
| 55 | + # 获取分类结果 | ||
| 56 | + outputs = classifier_model(attention_output) | ||
| 57 | + outputs = torch.mean(outputs, dim=1) | ||
| 58 | + # 获取预测标签和概率 | ||
| 59 | + probabilities = torch.softmax(outputs, dim=1) | ||
| 60 | + _, predicted = torch.max(outputs, 1) | ||
| 61 | + | ||
| 62 | + return predicted.item(), probabilities[0][predicted.item()].item() | ||
| 63 | + except Exception as e: | ||
| 64 | + print(f"预测过程中出现错误: {e}") | ||
| 65 | + return None, None | ||
| 17 | 66 | ||
| 18 | @pb.route('/home') | 67 | @pb.route('/home') |
| 19 | def home(): | 68 | def home(): |
| @@ -172,14 +221,15 @@ def yuqingpredict(): | @@ -172,14 +221,15 @@ def yuqingpredict(): | ||
| 172 | defaultTopic = request.args.get('Topic') | 221 | defaultTopic = request.args.get('Topic') |
| 173 | TopicLen = getTopicLen(defaultTopic) | 222 | TopicLen = getTopicLen(defaultTopic) |
| 174 | X, Y = getTopicCreatedAtandpredictData(defaultTopic) | 223 | X, Y = getTopicCreatedAtandpredictData(defaultTopic) |
| 175 | - sentences = '' | ||
| 176 | - value = SnowNLP(defaultTopic).sentiments | ||
| 177 | - if value == 0.5: | ||
| 178 | - sentences = '中性' | ||
| 179 | - elif value > 0.5: | ||
| 180 | - sentences = '正面' | ||
| 181 | - elif value < 0.5: | ||
| 182 | - sentences = '负面' | 224 | + |
| 225 | + # 使用改进版模型进行情感预测 | ||
| 226 | + predicted_label, confidence = predict_sentiment(defaultTopic) | ||
| 227 | + if predicted_label is not None: | ||
| 228 | + sentences = '良好' if predicted_label == 0 else '不良' | ||
| 229 | + sentences = f"{sentences} (置信度: {confidence:.2f})" | ||
| 230 | + else: | ||
| 231 | + sentences = '预测失败' | ||
| 232 | + | ||
| 183 | comments = getCommentFilterDataTopic(defaultTopic) | 233 | comments = getCommentFilterDataTopic(defaultTopic) |
| 184 | return render_template('yuqingpredict.html', | 234 | return render_template('yuqingpredict.html', |
| 185 | username=username, | 235 | username=username, |
-
Please register or login to post a comment