lstm_predict.py
5.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import torch
import os
import logging
from LSTM_model import lstm_model_manager
# Configure logging: INFO threshold, with timestamp, logger name, level and
# message in every record. The module logger is named 'lstm_predict'.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger('lstm_predict')
class LSTMPredictor:
    """LSTM predictor that wraps ``lstm_model_manager`` behind the system's predictor interface.

    The methods log problems and signal failure through their return value
    (``False``, ``None`` or ``(None, None)``) rather than raising.
    """

    def __init__(self):
        """Start in the not-loaded state and record the preferred torch device.

        The device is only stored and logged here; no method of this class
        reads it afterwards.
        """
        self.model_loaded = False
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"初始化LSTM预测器,使用设备: {self.device}")

    @staticmethod
    def _is_empty(texts):
        """Return True when ``texts`` contains nothing to predict on.

        Length is checked before truthiness so that array-like inputs whose
        ``bool()`` is ambiguous (e.g. NumPy arrays, pandas Series) do not raise.
        """
        if texts is None:
            return True
        try:
            return len(texts) == 0
        except TypeError:
            # Unsized objects (generators, iterators) keep the plain
            # truthiness rule.
            return not texts

    def load_models(self, model_save_path, bert_model_path, tokenizer_path=None):
        """Check that the LSTM model is available (interface-compatible with
        ``model_manager.load_models``).

        No weights are read by this method: it verifies that both paths exist
        and that ``lstm_model_manager.model`` is not ``None``.

        Args:
            model_save_path: Path of the saved LSTM model.
            bert_model_path: Path of the BERT model.
            tokenizer_path: Tokenizer path; accepted for interface
                compatibility and not used here.

        Returns:
            True when both paths exist and the manager holds a model;
            False otherwise, including when an exception occurs.
        """
        try:
            # Both artefacts must be present on disk.
            if not os.path.exists(model_save_path):
                logger.warning(f"模型文件 {model_save_path} 不存在,需要先训练模型")
                return False
            if not os.path.exists(bert_model_path):
                logger.error(f"BERT模型路径 {bert_model_path} 不存在")
                return False

            # Per the original author's note, lstm_model_manager loads the
            # model when it is initialised, so only the outcome is checked.
            if lstm_model_manager.model is not None:
                self.model_loaded = True
                logger.info("LSTM模型已加载成功")
                return True

            logger.error("LSTM模型加载失败")
            return False
        except Exception as e:
            # logger.exception keeps the traceback that a plain error() drops.
            logger.exception(f"加载模型过程中出错: {e}")
            return False

    def predict_batch(self, texts):
        """Predict the sentiment of a batch of texts.

        Args:
            texts: Collection of texts to classify.

        Returns:
            ``(predictions, probabilities)`` as returned by
            ``lstm_model_manager.predict_batch`` (per the original docs,
            0 means good and 1 means bad), or ``(None, None)`` when no model
            is available, the input is empty, or prediction fails.
        """
        # Refuse only when neither this wrapper nor the manager has a model.
        if not self.model_loaded and lstm_model_manager.model is None:
            logger.error("模型未加载,无法进行预测")
            return None, None

        if self._is_empty(texts):
            logger.warning("未提供文本,无法进行预测")
            return None, None

        try:
            predictions, probabilities = lstm_model_manager.predict_batch(texts)
            return predictions, probabilities
        except Exception as e:
            logger.exception(f"预测过程中出错: {e}")
            return None, None

    def predict(self, text):
        """Predict the sentiment of a single text.

        Args:
            text: Text string to classify.

        Returns:
            ``(prediction, probability)`` as returned by
            ``lstm_model_manager.predict`` (per the original docs, 0 means
            good and 1 means bad), or ``(None, None)`` when no model is
            available, the input is not a non-blank string, or prediction
            fails.
        """
        # Refuse only when neither this wrapper nor the manager has a model.
        if not self.model_loaded and lstm_model_manager.model is None:
            logger.error("模型未加载,无法进行预测")
            return None, None

        # Non-string input is rejected here instead of crashing on .strip().
        if not isinstance(text, str) or not text.strip():
            logger.warning("未提供文本或文本为空,无法进行预测")
            return None, None

        try:
            prediction, probability = lstm_model_manager.predict(text)
            return prediction, probability
        except Exception as e:
            logger.exception(f"预测过程中出错: {e}")
            return None, None

    def train_model(self, train_texts, train_labels, val_texts=None, val_labels=None,
                    batch_size=32, learning_rate=2e-5, epochs=10):
        """Train the model through ``lstm_model_manager.train``.

        Args:
            train_texts: Training texts.
            train_labels: Training labels.
            val_texts: Validation texts (optional).
            val_labels: Validation labels (optional).
            batch_size: Batch size.
            learning_rate: Learning rate.
            epochs: Number of training epochs.

        Returns:
            Whatever ``lstm_model_manager.train`` returns, or ``None`` when
            training raises.
        """
        try:
            results = lstm_model_manager.train(
                train_texts, train_labels, val_texts, val_labels,
                batch_size, learning_rate, epochs
            )
            # A training run that returns normally marks the model as usable.
            self.model_loaded = True
            return results
        except Exception as e:
            logger.exception(f"训练模型过程中出错: {e}")
            return None
# Module-level predictor instance; the module-level predict_batch/load_models
# wrappers and the __main__ smoke test all go through this object.
lstm_predictor = LSTMPredictor()
def predict_batch(texts):
    """Module-level ``predict_batch`` kept for compatibility with existing code
    that expects the same function as ``model_manager``.

    Delegates to the global ``lstm_predictor`` and returns its result unchanged.
    """
    outcome = lstm_predictor.predict_batch(texts)
    return outcome
def load_models(model_save_path, bert_model_path, tokenizer_path=None):
    """Module-level ``load_models`` kept for compatibility with existing code
    that expects the same function as ``model_manager``.

    Delegates to the global ``lstm_predictor`` and returns its result unchanged.
    """
    loaded = lstm_predictor.load_models(
        model_save_path=model_save_path,
        bert_model_path=bert_model_path,
        tokenizer_path=tokenizer_path,
    )
    return loaded
# Smoke test: check the model, then print a prediction for a few sample sentences.
if __name__ == "__main__":
    load_models(
        model_save_path="model_pro/lstm_model.pt",
        bert_model_path="model_pro/bert_model",
    )

    # Sample inputs for the prediction check.
    samples = (
        "这件事情做得非常好",
        "服务太差了,态度恶劣",
        "这个产品质量一般,但价格便宜",
        "我对这家公司非常满意",
    )

    for sample in samples:
        label_id, scores = lstm_predictor.predict(sample)

        # A None prediction is how the predictor signals failure.
        if label_id is None:
            print(f"句子: '{sample}' 预测失败")
            continue

        verdict = '良好' if label_id == 0 else '不良'
        certainty = scores[label_id]
        print(f"句子: '{sample}' 预测结果: {verdict} (置信度: {certainty:.2%})")