page.py 11 KB
from flask import Flask, session, render_template, redirect, Blueprint, request
from utils.mynlp import SnowNLP
from utils.getHomePageData import *
from utils.getHotWordPageData import *
from utils.getTableData import *
from utils.getPublicData import getAllHotWords, getAllTopics, getArticleByType, getArticleById
from utils.getEchartsData import *
from utils.getTopicPageData import *
from utils.yuqingpredict import *
from utils.logger import app_logger as logging
import torch
from model_pro.MHA import MultiHeadAttentionLayer
from model_pro.classifier import FinalClassifier
from model_pro.BERT_CTM import BERT_CTM_Model

pb = Blueprint('page',
               __name__,
               url_prefix='/page',
               template_folder='templates')

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 加载模型(全局变量,避免重复加载)
model_save_path = 'model_pro/final_model.pt'
bert_model_path = 'model_pro/bert_model'
ctm_tokenizer_path = 'model_pro/sentence_bert_model'

try:
    classifier_model = torch.load(model_save_path, map_location=device)
    classifier_model.eval()
    attention_model = MultiHeadAttentionLayer(embed_size=768, num_heads=8)
    attention_model.to(device)
    attention_model.eval()
    bert_ctm_model = BERT_CTM_Model(
        bert_model_path=bert_model_path,
        ctm_tokenizer_path=ctm_tokenizer_path
    )
except Exception as e:
    print(f"模型加载失败: {e}")

def predict_sentiment(text):
    """使用改进版模型预测单个文本的情感"""
    try:
        # 获取文本嵌入
        embeddings = bert_ctm_model.get_bert_embeddings([text])
        
        # 转换为tensor
        batch_x = torch.tensor(embeddings, dtype=torch.float32).to(device)
        batch_x = torch.mean(batch_x, dim=1)
        
        with torch.no_grad():
            # 使用注意力机制
            attention_output = attention_model(batch_x, batch_x, batch_x)
            # 获取分类结果
            outputs = classifier_model(attention_output)
            outputs = torch.mean(outputs, dim=1)
            # 获取预测标签和概率
            probabilities = torch.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs, 1)
            
        return predicted.item(), probabilities[0][predicted.item()].item()
    except Exception as e:
        print(f"预测过程中出现错误: {e}")
        return None, None

@pb.route('/home')
def home():
    username = session.get('username')
    articleLenMax, likeCountMaxAuthorName, cityMax = getHomeTagsData()
    commentsLikeCountTopFore = getHomeCommentsLikeCountTopFore()
    X, Y = getHomeArticleCreatedAtChart()
    typeChart = getHomeTypeChart()
    createAtChart = getHomeCommentCreatedChart()
    # getUserNameWordCloud()
    return render_template('index.html',
                           username=username,
                           articleLenMax=articleLenMax,
                           likeCountMaxAuthorName=likeCountMaxAuthorName,
                           cityMax=cityMax,
                           commentsLikeCountTopFore=commentsLikeCountTopFore,
                           xData=X,
                           yData=Y,
                           typeChart=typeChart,
                           createAtChart=createAtChart)


@pb.route('/hotWord')
def hotWord():
    username = session.get('username')
    hotWordList = getAllHotWords()
    print(hotWordList)
    defaultHotWord = hotWordList[0][0]
    if request.args.get('hotWord'):
        defaultHotWord = request.args.get('hotWord')
    hotWordLen = getHotWordLen(defaultHotWord)
    X, Y = getHotWordPageCreatedAtCharData(defaultHotWord)
    sentences = ''
    value = SnowNLP(defaultHotWord).sentiments
    if value == 0.5:
        sentences = '中性'
    elif value > 0.5:
        sentences = '正面'
    elif value < 0.5:
        sentences = '负面'
    comments = getCommentFilterData(defaultHotWord)
    return render_template('hotWord.html',
                           username=username,
                           hotWordList=hotWordList,
                           defaultHotWord=defaultHotWord,
                           hotWordLen=hotWordLen,
                           sentences=sentences,
                           xData=X,
                           yData=Y,
                           comments=comments)


@pb.route('/hotTopic')
def hotTopic():
    username = session.get('username')
    topicList = getAllTopics()
    defaultTopic = topicList[0][0]
    if request.args.get('topic'):
        defaultTopic = request.args.get('topic')
    topicLen = getTopicLen(defaultTopic)
    X, Y = getTopicPageCreatedAtCharData()
    sentences = ''
    
    # ... 这里要嵌入 topic 相关内容(热度?)来填充 sentences
    
    comments = getCommentFilterDataTopic(defaultTopic)
    return render_template('hotWord.html',
                           username=username,
                           topicList=topicList,
                           defaultTopic=defaultTopic,
                           topicLen=topicLen,
                           sentences=sentences,
                           xData=X,
                           yData=Y,
                           comments=comments)


@pb.route('/tableData')
def tableData():
    username = session.get('username')
    defaultFlag = False
    if request.args.get('flag'): defaultFlag = True
    tableData = getTableDataList(defaultFlag)
    return render_template('tableData.html',
                           username=username,
                           tableData=tableData,
                           defaultFlag=defaultFlag)


@pb.route('/articleChar')
def articleChar():
    username = session.get('username')
    typeList = getTypeList()
    defaultType = typeList[0]
    if request.args.get('type'): defaultType = request.args.get('type')
    X, Y = getArticleLikeCount(defaultType)
    x1Data, y1Data = getArticleCommentsLen(defaultType)
    x2Data, y2Data = getArticleRepotsLen(defaultType)
    return render_template('articleChar.html',
                           username=username,
                           typeList=typeList,
                           defaultType=defaultType,
                           xData=X,
                           yData=Y,
                           x1Data=x1Data,
                           y1Data=y1Data,
                           x2Data=x2Data,
                           y2Data=y2Data)


@pb.route('/ipChar')
def ipChar():
    username = session.get('username')
    articleRegionData = getIPByArticleRegion()
    commentRegionData = getIPByCommentsRegion()
    return render_template('ipChar.html',
                           username=username,
                           articleRegionData=articleRegionData,
                           commentRegionData=commentRegionData)


@pb.route('/commentChar')
def commentChar():
    username = session.get('username')
    X, Y = getCommentDataOne()
    genderPieData = getCommentDataTwo()
    return render_template('commentChar.html',
                           username=username,
                           xData=X,
                           yData=Y,
                           genderPieData=genderPieData)


@pb.route('/yuqingChar')
def yuqingChar():
    username = session.get('username')
    # 获取模型选择参数
    model_type = request.args.get('model', 'pro')  # 默认使用改进模型
    
    X, Y, biedata = getYuQingCharDataOne()
    biedata1, biedata2 = getYuQingCharDataTwo(model_type)
    x1Data, y1Data = getYuQingCharDataThree()
    return render_template('yuqingChar.html',
                           username=username,
                           xData=X,
                           yData=Y,
                           biedata=biedata,
                           biedata1=biedata1,
                           biedata2=biedata2,
                           x1Data=x1Data,
                           y1Data=y1Data,
                           model_type=model_type)

@pb.route('/yuqingpredict')
def yuqingpredict():
    username = session.get('username')
    TopicList = getAllTopicData()
    defaultTopic = TopicList[0][0]
    if request.args.get('Topic'):
        defaultTopic = request.args.get('Topic')
    TopicLen = getTopicLen(defaultTopic)
    X, Y = getTopicCreatedAtandpredictData(defaultTopic)
    
    # 获取模型选择参数
    model_type = request.args.get('model', 'pro')  # 默认使用改进模型
    
    if model_type == 'basic':
        # 使用基础模型(SnowNLP)
        value = SnowNLP(defaultTopic).sentiments
        if value == 0.5:
            sentences = '中性'
        elif value > 0.5:
            sentences = '正面'
        elif value < 0.5:
            sentences = '负面'
    else:
        # 使用改进模型
        predicted_label, confidence = predict_sentiment(defaultTopic)
        if predicted_label is not None:
            sentences = '良好' if predicted_label == 0 else '不良'
            sentences = f"{sentences} (置信度: {confidence:.2f})"
        else:
            sentences = '预测失败'
    
    comments = getCommentFilterDataTopic(defaultTopic)
    return render_template('yuqingpredict.html',
                           username=username,
                           hotWordList=TopicList,
                           defaultHotWord=defaultTopic,
                           hotWordLen=TopicLen,
                           sentences=sentences,
                           xData=X,
                           yData=Y,
                           comments=comments,
                           model_type=model_type)


@pb.route('/articleCloud')
def articleCloud():
    username = session.get('username')
    return render_template('articleContentCloud.html', username=username)


@pb.route('/page/index')
def index():
    """首页路由"""
    try:
        hotWordList = getAllHotWords()
        logging.info("成功获取热词列表")
        return render_template('index.html', hotWordList=hotWordList)
    except Exception as e:
        logging.error(f"渲染首页时发生错误: {e}")
        return render_template('error.html', error_message="加载首页失败")

@pb.route('/page/article/<type>')
def article(type):
    """文章列表页路由"""
    try:
        articleList = getArticleByType(type)
        logging.info(f"成功获取类型为 {type} 的文章列表")
        return render_template('article.html', articleList=articleList)
    except Exception as e:
        logging.error(f"获取文章列表时发生错误: {e}")
        return render_template('error.html', error_message="加载文章列表失败")

@pb.route('/page/articleChar/<id>')
def articleChar(id):
    """文章详情页路由"""
    try:
        article = getArticleById(id)
        if not article:
            logging.warning(f"未找到ID为 {id} 的文章")
            return render_template('error.html', error_message="文章不存在")
        logging.info(f"成功获取ID为 {id} 的文章详情")
        return render_template('articleChar.html', article=article)
    except Exception as e:
        logging.error(f"获取文章详情时发生错误: {e}")
        return render_template('error.html', error_message="加载文章详情失败")