wjhgq
Committed by GitHub

Update predict.py. The prediction model is optimized to a time series model, whi…

…ch significantly improves the modeling fitness.

In the original method, only linear regression is used to perform simple trend extrapolation, which leads to insufficient prediction accuracy. This optimization adopts time series model, and uses the auto_arima method of pmdarima to automatically select appropriate model parameters (including p, d, q and seasonal parameters) according to historical data. It significantly improves the suitability of the model in time series modeling. In this way, the model can better capture the trend and periodicity of the data, and predict the future heat more reasonable and accurate.
1 import numpy as np 1 import numpy as np
2 import datetime 2 import datetime
3 -import matplotlib.pyplot as plt 3 +import pandas as pd
  4 +from pmdarima import auto_arima
4 5
5 -  
6 -def datetime_to_number(date: str): # 格式化日期转换为 integer 6 +def datetime_to_number(date: str):
  7 + """Convert a date string 'YYYY-MM-DD' to a relative day number."""
7 date_number = datetime.datetime.strptime(date, "%Y-%m-%d") 8 date_number = datetime.datetime.strptime(date, "%Y-%m-%d")
8 base_number = datetime.datetime.strptime("2024-1-1", "%Y-%m-%d") 9 base_number = datetime.datetime.strptime("2024-1-1", "%Y-%m-%d")
9 return (date_number - base_number).days 10 return (date_number - base_number).days
10 11
  12 +def predict_future_values(data, forecast_days=5):
  13 + """
  14 + Use auto_arima from pmdarima to fit a suitable ARIMA/SARIMA model for the time series,
  15 + then predict future values for the specified number of days.
11 16
12 -def predict_future_values(data):  
13 - # 提取并排序日期  
14 - sorted_dates = sorted(data.keys(), key=lambda date: datetime.datetime.strptime(date, "%Y-%m-%d"))  
15 - sorted_data = {k: data[k] for k in sorted_dates}  
16 -  
17 - # 将日期转换为整数并提取相应的值  
18 - xs = np.array([datetime_to_number(date) for date in sorted_data.keys()])  
19 - ys = np.array([data[date] for date in sorted_data.keys()]) 17 + Parameters:
  18 + data: dict, keys are date strings 'YYYY-MM-DD', values are integer counts
  19 + forecast_days: int, number of days to predict into the future
20 20
21 - # 拟合线性回归模型  
22 - fit = np.polyfit(xs, ys, 1)  
23 - fn = np.poly1d(fit) 21 + Returns:
  22 + predictions: dict, keys are future date strings 'YYYY-MM-DD', values are predicted integers (≥0)
  23 + """
  24 + if not data:
  25 + return {}
24 26
25 - # 获取最新日期,并生成未来三天的日期  
26 - latest_date = sorted_dates[-1]  
27 - latest_date_obj = datetime.datetime.strptime(latest_date, "%Y-%m-%d")  
28 - future_dates = [(latest_date_obj + datetime.timedelta(days=i)).strftime("%Y-%m-%d") for i in range(1, 6)] 27 + # Sort data by date
  28 + sorted_dates = sorted(data.keys(), key=lambda d: datetime.datetime.strptime(d, "%Y-%m-%d"))
  29 + start_date = sorted_dates[0]
  30 + end_date = sorted_dates[-1]
  31 +
  32 + # Create a full date range to ensure continuity in the time series
  33 + full_range = pd.date_range(start=start_date, end=end_date, freq='D')
  34 + ts = pd.Series(0, index=full_range, dtype=float)
  35 + for d in data:
  36 + ts[pd.to_datetime(d)] = data[d]
29 37
30 - # 预测未来日期的值 38 + # Simple smoothing: optional step to reduce noise (moving average over 3 days)
  39 + # This is a mild smoothing to handle noisy data. You can comment this out if not needed.
  40 + ts_smoothed = ts.rolling(window=3, min_periods=1).mean()
  41 +
  42 + # Fit the time series with auto_arima to find the best parameters
  43 + model = auto_arima(ts_smoothed,
  44 + start_p=1, start_q=1,
  45 + max_p=5, max_q=5,
  46 + seasonal=False,
  47 + trace=False, error_action='ignore', suppress_warnings=True, stepwise=True)
  48 +
  49 + # Predict the future values
  50 + forecast = model.predict(n_periods=forecast_days)
  51 + # Construct future dates
  52 + last_date = pd.to_datetime(end_date)
  53 + future_dates = [last_date + datetime.timedelta(days=i) for i in range(1, forecast_days+1)]
  54 +
  55 + # Convert forecast results to dict with non-negative integers
31 predictions = {} 56 predictions = {}
32 - for date in future_dates:  
33 - date_num = datetime_to_number(date)  
34 - if int(fn(date_num))<=0:  
35 - predictions[date] = 0  
36 - else:  
37 - predictions[date] = int(fn(date_num)) 57 + for d, v in zip(future_dates, forecast):
  58 + predictions[d.strftime("%Y-%m-%d")] = max(int(round(v)), 0)
38 59
39 return predictions 60 return predictions
40 61
41 -  
42 if __name__ == '__main__': 62 if __name__ == '__main__':
43 - data = {'2024-06-15': 1, '2024-06-18': 1, '2024-06-22': 1, '2024-06-23': 1, '2024-07-01': 3, '2024-07-02': 4, '2024-07-03': 4, '2024-07-04': 14}  
44 - predictions = predict_future_values(data)  
45 - print(predictions)  
46 - # for date, value in predictions.items():  
47 - # print(f'{date} PREDICTION: {value}') 63 + data = {
  64 + '2024-06-15': 1, '2024-06-18': 1, '2024-06-22': 1,
  65 + '2024-06-23': 1, '2024-07-01': 3, '2024-07-02': 4,
  66 + '2024-07-03': 4, '2024-07-04': 14
  67 + }
  68 + preds = predict_future_values(data)
  69 + print(preds)