Wenkai Liang
Committed by GitHub

Merge pull request #7 from wjhgq/main

Add a new time-series model to implement the public opinion prediction function.
1 import numpy as np 1 import numpy as np
2 import datetime 2 import datetime
3 -import matplotlib.pyplot as plt 3 +import pandas as pd
  4 +from pmdarima import auto_arima
4 5
5 -  
6 -def datetime_to_number(date: str): # 格式化日期转换为 integer 6 +def datetime_to_number(date: str):
  7 + """Convert a date string 'YYYY-MM-DD' to the number of days since 2024-01-01."""
7 date_number = datetime.datetime.strptime(date, "%Y-%m-%d") 8 date_number = datetime.datetime.strptime(date, "%Y-%m-%d")
8 base_number = datetime.datetime.strptime("2024-1-1", "%Y-%m-%d") 9 base_number = datetime.datetime.strptime("2024-1-1", "%Y-%m-%d")
9 return (date_number - base_number).days 10 return (date_number - base_number).days
10 11
  12 +def predict_future_values(data, forecast_days=5):
  13 + """
  14 + Use auto_arima from pmdarima to fit a suitable ARIMA/SARIMA model for the time series,
  15 + then predict future values for the specified number of days.
  16 +
  17 + Parameters:
  18 + data: dict, keys are date strings 'YYYY-MM-DD', values are integer counts
  19 + forecast_days: int, number of days to predict into the future
  20 +
  21 + Returns:
  22 + predictions: dict, keys are future date strings 'YYYY-MM-DD', values are predicted integers (≥0)
  23 + """
  24 + if not data:
  25 + return {}
11 26
12 -def predict_future_values(data):  
13 - # 提取并排序日期  
14 - sorted_dates = sorted(data.keys(), key=lambda date: datetime.datetime.strptime(date, "%Y-%m-%d"))  
15 - sorted_data = {k: data[k] for k in sorted_dates} 27 + # Sort data by date
  28 + sorted_dates = sorted(data.keys(), key=lambda d: datetime.datetime.strptime(d, "%Y-%m-%d"))
  29 + start_date = sorted_dates[0]
  30 + end_date = sorted_dates[-1]
16 31
17 - # 将日期转换为整数并提取相应的值  
18 - xs = np.array([datetime_to_number(date) for date in sorted_data.keys()])  
19 - ys = np.array([data[date] for date in sorted_data.keys()]) 32 + # Create a full date range to ensure continuity in the time series
  33 + full_range = pd.date_range(start=start_date, end=end_date, freq='D')
  34 + ts = pd.Series(0, index=full_range, dtype=float)
  35 + for d in data:
  36 + ts[pd.to_datetime(d)] = data[d]
20 37
21 - # 拟合线性回归模型  
22 - fit = np.polyfit(xs, ys, 1)  
23 - fn = np.poly1d(fit) 38 + # Simple smoothing: optional step to reduce noise (moving average over 3 days)
  39 + # This is mild smoothing to handle noisy data. If it is not needed, pass `ts` to auto_arima below instead of `ts_smoothed`.
  40 + ts_smoothed = ts.rolling(window=3, min_periods=1).mean()
24 41
25 - # 获取最新日期,并生成未来三天的日期  
26 - latest_date = sorted_dates[-1]  
27 - latest_date_obj = datetime.datetime.strptime(latest_date, "%Y-%m-%d")  
28 - future_dates = [(latest_date_obj + datetime.timedelta(days=i)).strftime("%Y-%m-%d") for i in range(1, 6)] 42 + # Fit the time series with auto_arima to find the best parameters
  43 + model = auto_arima(ts_smoothed,
  44 + start_p=1, start_q=1,
  45 + max_p=5, max_q=5,
  46 + seasonal=False,
  47 + trace=False, error_action='ignore', suppress_warnings=True, stepwise=True)
29 48
30 - # 预测未来日期的值 49 + # Predict the future values
  50 + forecast = model.predict(n_periods=forecast_days)
  51 + # Construct future dates
  52 + last_date = pd.to_datetime(end_date)
  53 + future_dates = [last_date + datetime.timedelta(days=i) for i in range(1, forecast_days+1)]
  54 +
  55 + # Convert forecast results to dict with non-negative integers
31 predictions = {} 56 predictions = {}
32 - for date in future_dates:  
33 - date_num = datetime_to_number(date)  
34 - if int(fn(date_num))<=0:  
35 - predictions[date] = 0  
36 - else:  
37 - predictions[date] = int(fn(date_num)) 57 + for d, v in zip(future_dates, forecast):
  58 + predictions[d.strftime("%Y-%m-%d")] = max(int(round(v)), 0)
38 59
39 return predictions 60 return predictions
40 61
41 -  
42 if __name__ == '__main__': 62 if __name__ == '__main__':
43 - data = {'2024-06-15': 1, '2024-06-18': 1, '2024-06-22': 1, '2024-06-23': 1, '2024-07-01': 3, '2024-07-02': 4, '2024-07-03': 4, '2024-07-04': 14}  
44 - predictions = predict_future_values(data)  
45 - print(predictions)  
46 - # for date, value in predictions.items():  
47 - # print(f'{date} PREDICTION: {value}') 63 + data = {
  64 + '2024-06-15': 1, '2024-06-18': 1, '2024-06-22': 1,
  65 + '2024-06-23': 1, '2024-07-01': 3, '2024-07-02': 4,
  66 + '2024-07-03': 4, '2024-07-04': 14
  67 + }
  68 + preds = predict_future_values(data)
  69 + print(preds)
1 from utils.getPublicData import * 1 from utils.getPublicData import *
2 -from utils.predict import *  
3 -articleList = getAllArticleData()  
4 -commentList = getAllCommentsData() 2 +from utils.predict import predict_future_values # Use the new function
5 import csv 3 import csv
6 import os 4 import os
7 import datetime 5 import datetime
8 -def getTopicByArticle():# 返回文章内容的话题字典  
9 - articleTopicDic = {}  
10 - for i in articleList:  
11 - if i[14] != None:  
12 - if i[14] in articleTopicDic.keys():  
13 - articleTopicDic[i[14]] += 1  
14 - else:  
15 - articleTopicDic[i[14]] = 1  
16 - resultData = []  
17 - for key,value in articleTopicDic.items():  
18 - resultData.append({  
19 - 'name':key,  
20 - 'value':value  
21 - })  
22 - return resultData  
23 -  
24 -def getTopicByComments():# 返回评论内容的话题字典  
25 - commentsTopicDic = {}  
26 - for i in commentList:  
27 - if i[9] != None:  
28 - if i[9] in commentsTopicDic:  
29 - commentsTopicDic[i[9]] += 1  
30 - else:  
31 - commentsTopicDic[i[9]] = 1  
32 - resultData = []  
33 - for key,value in commentsTopicDic.items():  
34 - resultData.append({  
35 - 'name':key,  
36 - 'value':value  
37 - })  
38 - return resultData  
39 -  
40 -def mergeTopics(article_topics, comment_topics):# 合并话题  
41 - merged_dict = {}  
42 - for topic in article_topics + comment_topics:  
43 - if topic['name'] in merged_dict:  
44 - merged_dict[topic['name']] += topic['value']  
45 - else:  
46 - merged_dict[topic['name']] = topic['value']  
47 - merged_dict = sorted(merged_dict.items(), key=lambda item: item[1], reverse=True)  
48 - merged_list = [[key, str(value)] for key, value in merged_dict]  
49 - return merged_list  
50 -def getAllTopicData():  
51 - # 读取合并文件 merge.csv  
52 - # data = []  
53 - # df = pd.read_csv('./merged_topics.csv',encoding='utf8')  
54 - # for i in df.values:  
55 - # try:  
56 - # data.append([  
57 - # re.search('[\u4e00-\u9fa5]+',str(i)).group(),  
58 - # re.search('\d+',str(i)).group()  
59 - # ])  
60 - # except:  
61 - # continue  
62 - return mergeTopics(getTopicByArticle(), getTopicByComments()) 6 +import pandas as pd
63 7
64 -def getTopicCreatedAtandpredictData(topic):# 统计特定话题的评论在每个日期的数量,并返回日期和对应的评论数量 8 +def getTopicCreatedAtandpredictData(topic):
65 createdAt = {} 9 createdAt = {}
66 for i in articleList: 10 for i in articleList:
67 if i[14]==topic: 11 if i[14]==topic:
@@ -75,30 +19,13 @@ def getTopicCreatedAtandpredictData(topic):# 统计特定话题的评论在每 @@ -75,30 +19,13 @@ def getTopicCreatedAtandpredictData(topic):# 统计特定话题的评论在每
75 createdAt[i[1]] += 1 19 createdAt[i[1]] += 1
76 else: 20 else:
77 createdAt[i[1]] = 1 21 createdAt[i[1]] = 1
78 - createdAt = {k: createdAt[k] for k in sorted(createdAt, key=lambda date: datetime.datetime.strptime(date, "%Y-%m-%d"))}  
79 - createdAt.update(predict_future_values(createdAt))  
80 - sorted_data = {k: createdAt[k] for k in sorted(createdAt, key=lambda date: datetime.datetime.strptime(date, "%Y-%m-%d"))}  
81 - # result_list = [0] * (len(sorted_data) - 5) + [1] * 5  
82 - print(list(createdAt.keys()),list(createdAt.values()))  
83 - return list(createdAt.keys()),list(createdAt.values())  
84 22
85 -def writeTopicsToCSV(topics, file_name):  
86 - # 检查文件是否存在,如果存在则附加写入,否则新建一个  
87 - file_exists = os.path.isfile(file_name)  
88 - # 按值的降序排序  
89 - sorted_topics = sorted(topics, key=lambda x: x['value'], reverse=True)  
90 - with open(file_name, 'w', newline='', encoding='utf-8') as csvfile:  
91 - fieldnames = ['name', 'value']  
92 - writer = csv.DictWriter(csvfile, fieldnames=fieldnames)  
93 - # 如果文件不存在,则写入表头  
94 - if not file_exists:  
95 - writer.writeheader()  
96 - # 写入数据  
97 - for topic in sorted_topics:  
98 - writer.writerow(topic)  
99 -if __name__ == '__main__':  
100 - # 将话题数据写入 CSV 文件  
101 - # print(mergeTopics(getTopicByArticle(), getTopicByComments()))  
102 - # writeTopicsToCSV(merged_topics, 'merged_topics.csv')  
103 - print(getAllTopicData()) 23 + # Use the improved time series prediction approach
  24 + predictions = predict_future_values(createdAt, forecast_days=5)
  25 +
  26 + # Merge historical data and predictions
  27 + combined_data = {**createdAt, **predictions}
  28 + combined_data = {k: combined_data[k] for k in sorted(combined_data, key=lambda date: datetime.datetime.strptime(date, "%Y-%m-%d"))}
104 29
  30 + print(list(combined_data.keys()), list(combined_data.values()))
  31 + return list(combined_data.keys()), list(combined_data.values())