Committed by
GitHub
Merge pull request #7 from wjhgq/main
Add a new time-series model to complete the public opinion prediction feature.
Showing
2 changed files
with
63 additions
and
114 deletions
| 1 | import numpy as np | 1 | import numpy as np |
| 2 | import datetime | 2 | import datetime |
| 3 | -import matplotlib.pyplot as plt | 3 | +import pandas as pd |
| 4 | +from pmdarima import auto_arima | ||
| 4 | 5 | ||
| 5 | - | ||
def datetime_to_number(date: str, base: str = "2024-01-01") -> int:
    """Convert a 'YYYY-MM-DD' date string to a day offset from a base date.

    Parameters:
        date: date string in 'YYYY-MM-DD' format.
        base: reference date string in the same format; defaults to
              2024-01-01, the epoch previously hard-coded in this module.
              (Kept as a keyword with the same default so existing callers
              are unaffected.)

    Returns:
        Number of whole days from *base* to *date*; negative when *date*
        precedes *base*.

    Raises:
        ValueError: if either string does not match '%Y-%m-%d'.
    """
    date_number = datetime.datetime.strptime(date, "%Y-%m-%d")
    base_number = datetime.datetime.strptime(base, "%Y-%m-%d")
    return (date_number - base_number).days
| 10 | 11 | ||
def predict_future_values(data, forecast_days=5):
    """
    Fit an ARIMA model (orders chosen automatically by pmdarima's
    auto_arima) to a daily count series and forecast the coming days.

    Parameters:
        data: dict mapping 'YYYY-MM-DD' date strings to integer counts.
        forecast_days: int, number of days to predict into the future.

    Returns:
        dict mapping future 'YYYY-MM-DD' date strings to predicted
        non-negative integers; {} when *data* is empty.
    """
    if not data:
        return {}

    # Order the observed dates chronologically.
    ordered = sorted(data, key=lambda day: datetime.datetime.strptime(day, "%Y-%m-%d"))
    first_day, last_day = ordered[0], ordered[-1]

    # Build a gap-free daily series; days with no observation count as 0.
    calendar = pd.date_range(start=first_day, end=last_day, freq='D')
    series = pd.Series(0, index=calendar, dtype=float)
    for day, count in data.items():
        series[pd.to_datetime(day)] = count

    # Mild 3-day moving average to damp noise before fitting.
    # Drop this step if the raw series should be modeled directly.
    smoothed = series.rolling(window=3, min_periods=1).mean()

    # Let auto_arima search for suitable (p, d, q) orders.
    fitted = auto_arima(
        smoothed,
        start_p=1, start_q=1,
        max_p=5, max_q=5,
        seasonal=False,
        trace=False, error_action='ignore',
        suppress_warnings=True, stepwise=True,
    )

    horizon = fitted.predict(n_periods=forecast_days)

    # Pair each forecast value with its calendar date; clamp at zero and
    # round to an integer count.
    anchor = pd.to_datetime(last_day)
    result = {}
    for offset, value in enumerate(horizon, start=1):
        when = (anchor + datetime.timedelta(days=offset)).strftime("%Y-%m-%d")
        result[when] = max(int(round(value)), 0)
    return result
| 40 | 61 | ||
| 41 | - | ||
if __name__ == '__main__':
    # Quick manual smoke test with a small sample of daily counts.
    sample = {
        '2024-06-15': 1, '2024-06-18': 1, '2024-06-22': 1,
        '2024-06-23': 1, '2024-07-01': 3, '2024-07-02': 4,
        '2024-07-03': 4, '2024-07-04': 14,
    }
    print(predict_future_values(sample))
| 1 | from utils.getPublicData import * | 1 | from utils.getPublicData import * |
| 2 | -from utils.predict import * | ||
| 3 | -articleList = getAllArticleData() | ||
| 4 | -commentList = getAllCommentsData() | 2 | +from utils.predict import predict_future_values # Use the new function |
| 5 | import csv | 3 | import csv |
| 6 | import os | 4 | import os |
| 7 | import datetime | 5 | import datetime |
| 8 | -def getTopicByArticle():# 返回文章内容的话题字典 | ||
| 9 | - articleTopicDic = {} | ||
| 10 | - for i in articleList: | ||
| 11 | - if i[14] != None: | ||
| 12 | - if i[14] in articleTopicDic.keys(): | ||
| 13 | - articleTopicDic[i[14]] += 1 | ||
| 14 | - else: | ||
| 15 | - articleTopicDic[i[14]] = 1 | ||
| 16 | - resultData = [] | ||
| 17 | - for key,value in articleTopicDic.items(): | ||
| 18 | - resultData.append({ | ||
| 19 | - 'name':key, | ||
| 20 | - 'value':value | ||
| 21 | - }) | ||
| 22 | - return resultData | ||
| 23 | - | ||
| 24 | -def getTopicByComments():# 返回评论内容的话题字典 | ||
| 25 | - commentsTopicDic = {} | ||
| 26 | - for i in commentList: | ||
| 27 | - if i[9] != None: | ||
| 28 | - if i[9] in commentsTopicDic: | ||
| 29 | - commentsTopicDic[i[9]] += 1 | ||
| 30 | - else: | ||
| 31 | - commentsTopicDic[i[9]] = 1 | ||
| 32 | - resultData = [] | ||
| 33 | - for key,value in commentsTopicDic.items(): | ||
| 34 | - resultData.append({ | ||
| 35 | - 'name':key, | ||
| 36 | - 'value':value | ||
| 37 | - }) | ||
| 38 | - return resultData | ||
| 39 | - | ||
| 40 | -def mergeTopics(article_topics, comment_topics):# 合并话题 | ||
| 41 | - merged_dict = {} | ||
| 42 | - for topic in article_topics + comment_topics: | ||
| 43 | - if topic['name'] in merged_dict: | ||
| 44 | - merged_dict[topic['name']] += topic['value'] | ||
| 45 | - else: | ||
| 46 | - merged_dict[topic['name']] = topic['value'] | ||
| 47 | - merged_dict = sorted(merged_dict.items(), key=lambda item: item[1], reverse=True) | ||
| 48 | - merged_list = [[key, str(value)] for key, value in merged_dict] | ||
| 49 | - return merged_list | ||
| 50 | -def getAllTopicData(): | ||
| 51 | - # 读取合并文件 merge.csv | ||
| 52 | - # data = [] | ||
| 53 | - # df = pd.read_csv('./merged_topics.csv',encoding='utf8') | ||
| 54 | - # for i in df.values: | ||
| 55 | - # try: | ||
| 56 | - # data.append([ | ||
| 57 | - # re.search('[\u4e00-\u9fa5]+',str(i)).group(), | ||
| 58 | - # re.search('\d+',str(i)).group() | ||
| 59 | - # ]) | ||
| 60 | - # except: | ||
| 61 | - # continue | ||
| 62 | - return mergeTopics(getTopicByArticle(), getTopicByComments()) | 6 | +import pandas as pd |
| 63 | 7 | ||
| 64 | -def getTopicCreatedAtandpredictData(topic):# 统计特定话题的评论在每个日期的数量,并返回日期和对应的评论数量 | 8 | +def getTopicCreatedAtandpredictData(topic): |
| 65 | createdAt = {} | 9 | createdAt = {} |
| 66 | for i in articleList: | 10 | for i in articleList: |
| 67 | if i[14]==topic: | 11 | if i[14]==topic: |
| @@ -75,30 +19,13 @@ def getTopicCreatedAtandpredictData(topic):# 统计特定话题的评论在每 | @@ -75,30 +19,13 @@ def getTopicCreatedAtandpredictData(topic):# 统计特定话题的评论在每 | ||
| 75 | createdAt[i[1]] += 1 | 19 | createdAt[i[1]] += 1 |
| 76 | else: | 20 | else: |
| 77 | createdAt[i[1]] = 1 | 21 | createdAt[i[1]] = 1 |
| 78 | - createdAt = {k: createdAt[k] for k in sorted(createdAt, key=lambda date: datetime.datetime.strptime(date, "%Y-%m-%d"))} | ||
| 79 | - createdAt.update(predict_future_values(createdAt)) | ||
| 80 | - sorted_data = {k: createdAt[k] for k in sorted(createdAt, key=lambda date: datetime.datetime.strptime(date, "%Y-%m-%d"))} | ||
| 81 | - # result_list = [0] * (len(sorted_data) - 5) + [1] * 5 | ||
| 82 | - print(list(createdAt.keys()),list(createdAt.values())) | ||
| 83 | - return list(createdAt.keys()),list(createdAt.values()) | ||
| 84 | 22 | ||
| 85 | -def writeTopicsToCSV(topics, file_name): | ||
| 86 | - # 检查文件是否存在,如果存在则附加写入,否则新建一个 | ||
| 87 | - file_exists = os.path.isfile(file_name) | ||
| 88 | - # 按值的降序排序 | ||
| 89 | - sorted_topics = sorted(topics, key=lambda x: x['value'], reverse=True) | ||
| 90 | - with open(file_name, 'w', newline='', encoding='utf-8') as csvfile: | ||
| 91 | - fieldnames = ['name', 'value'] | ||
| 92 | - writer = csv.DictWriter(csvfile, fieldnames=fieldnames) | ||
| 93 | - # 如果文件不存在,则写入表头 | ||
| 94 | - if not file_exists: | ||
| 95 | - writer.writeheader() | ||
| 96 | - # 写入数据 | ||
| 97 | - for topic in sorted_topics: | ||
| 98 | - writer.writerow(topic) | ||
| 99 | -if __name__ == '__main__': | ||
| 100 | - # 将话题数据写入 CSV 文件 | ||
| 101 | - # print(mergeTopics(getTopicByArticle(), getTopicByComments())) | ||
| 102 | - # writeTopicsToCSV(merged_topics, 'merged_topics.csv') | ||
| 103 | - print(getAllTopicData()) | 23 | + # Use the improved time series prediction approach |
| 24 | + predictions = predict_future_values(createdAt, forecast_days=5) | ||
| 25 | + | ||
| 26 | + # Merge historical data and predictions | ||
| 27 | + combined_data = {**createdAt, **predictions} | ||
| 28 | + combined_data = {k: combined_data[k] for k in sorted(combined_data, key=lambda date: datetime.datetime.strptime(date, "%Y-%m-%d"))} | ||
| 104 | 29 | ||
| 30 | + print(list(combined_data.keys()), list(combined_data.values())) | ||
| 31 | + return list(combined_data.keys()), list(combined_data.values()) |
-
Please register or login to post a comment