Source code for awt_quant.forecast.lag_llama_forecast

import os
import yfinance as yf
import pandas as pd
import torch
from gluonts.dataset.pandas import PandasDataset
from resources.lag_llama.repo.lag_llama.gluon.estimator import LagLlamaEstimator
from gluonts.evaluation import make_evaluation_predictions, Evaluator
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from itertools import islice

# Constant for Lag-Llama checkpoint storage
# Location of the Lag-Llama checkpoint. The default is resolved relative to the
# current working directory; set the LAG_LLAMA_CKPT_PATH environment variable
# (non-empty) to point at a checkpoint elsewhere, e.g. when running from
# another directory.
LAG_LLAMA_CKPT_PATH = (
    os.environ.get("LAG_LLAMA_CKPT_PATH")
    or "resources/lag_llama/model/lag_llama.ckpt"
)
def get_device():
    """
    Pick the torch device that model computations should run on.

    The first CUDA device is chosen whenever CUDA is available; otherwise
    computation falls back to the CPU.

    Returns:
        torch.device: ``cuda:0`` or ``cpu``.
    """
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")
def fetch_stock_data(ticker, start_date, end_date, freq="B"):
    """
    Fetches stock price data from Yahoo Finance and formats it for Lag-Llama.

    The daily closing price becomes the ``target`` series, indexed by date.
    Trading days skip exchange holidays, which leaves gaps in the index, so
    the series is re-indexed onto a regular ``freq`` calendar (business days
    by default). Each inserted day is forward-filled with the previous close
    so the dataset has a uniformly spaced time index.

    Args:
        ticker (str): Stock symbol.
        start_date (str): Start date in 'YYYY-MM-DD' format.
        end_date (str): End date in 'YYYY-MM-DD' format.
        freq (str, optional): Pandas frequency alias for the output series.
            Defaults to "B" (business days).

    Returns:
        PandasDataset: The dataset formatted for Lag-Llama.

    Raises:
        ValueError: If Yahoo Finance returns no rows for the request.
    """
    df = yf.download(ticker, start=start_date, end=end_date)
    if df is None or df.empty:
        raise ValueError(
            f"No price data returned for {ticker!r} between {start_date} and {end_date}."
        )

    # Recent yfinance releases return (field, ticker) MultiIndex columns even
    # for a single ticker; keep only the field level so "Close" is a Series.
    if isinstance(df.columns, pd.MultiIndex):
        df.columns = df.columns.get_level_values(0)

    close = df["Close"].astype("float32")
    close.index = pd.DatetimeIndex(close.index)
    # Regularize the calendar: holidays become rows carrying the last close.
    close = close.asfreq(freq).ffill()

    # Long format: the date index is the timestamp, one item_id per ticker.
    long_df = pd.DataFrame({"target": close, "item_id": ticker})
    long_df.index.name = "Date"
    return PandasDataset.from_long_dataframe(
        long_df, target="target", item_id="item_id", freq=freq
    )
def get_lag_llama_predictions(dataset, prediction_length, num_samples=100, context_length=32, use_rope_scaling=False):
    """
    Runs Lag-Llama predictions on a given dataset.

    Args:
        dataset (PandasDataset): The dataset for forecasting.
        prediction_length (int): Forecast horizon.
        num_samples (int, optional): Number of Monte Carlo samples per timestep. Defaults to 100.
        context_length (int, optional): Context length for model. Defaults to 32.
        use_rope_scaling (bool, optional): Whether to use RoPE scaling for extended context. Defaults to False.

    Returns:
        Tuple[list, list]: Forecasts and actual time series.

    Raises:
        FileNotFoundError: If no checkpoint exists at ``LAG_LLAMA_CKPT_PATH``.
    """
    device = get_device()
    if not os.path.exists(LAG_LLAMA_CKPT_PATH):
        raise FileNotFoundError(f"Lag-Llama checkpoint not found at {LAG_LLAMA_CKPT_PATH}. Download it before proceeding.")

    # Only the hyper-parameters are read here; the estimator is handed
    # ckpt_path separately. Loading onto the CPU therefore avoids a throwaway
    # copy of the weights on the GPU.
    # weights_only=False: the checkpoint's hyper_parameters may hold pickled
    # Python objects, which torch>=2.6 refuses to load under its default
    # weights_only=True. NOTE: this unpickles the file — only load trusted
    # checkpoints.
    ckpt = torch.load(LAG_LLAMA_CKPT_PATH, map_location="cpu", weights_only=False)
    estimator_args = ckpt["hyper_parameters"]["model_kwargs"]
    del ckpt  # release the state dict; estimator_args keeps what is needed

    # Linear RoPE scaling stretches positions when the requested window
    # exceeds the checkpoint's own context_length; the factor never drops below 1.
    rope_scaling_args = {
        "type": "linear",
        "factor": max(1.0, (context_length + prediction_length) / estimator_args["context_length"]),
    }

    estimator = LagLlamaEstimator(
        ckpt_path=LAG_LLAMA_CKPT_PATH,
        prediction_length=prediction_length,
        context_length=context_length,
        input_size=estimator_args["input_size"],
        n_layer=estimator_args["n_layer"],
        n_embd_per_head=estimator_args["n_embd_per_head"],
        n_head=estimator_args["n_head"],
        scaling=estimator_args["scaling"],
        time_feat=estimator_args["time_feat"],
        rope_scaling=rope_scaling_args if use_rope_scaling else None,
        batch_size=1,
        num_parallel_samples=num_samples,
        device=device,
    )

    predictor = estimator.create_predictor(
        estimator.create_transformation(), estimator.create_lightning_module()
    )

    forecast_it, ts_it = make_evaluation_predictions(
        dataset=dataset, predictor=predictor, num_samples=num_samples
    )
    return list(forecast_it), list(ts_it)
def plot_forecasts(forecasts, tss, ticker, prediction_length):
    """
    Draw the recent history of a series together with its forecast.

    Only the first (forecast, series) pair is drawn. The history shown is the
    last ``4 * prediction_length`` points of the actual series, plotted in
    black. The forecast is drawn in green via its own ``plot`` method. The
    figure is then titled, given a legend and shown.

    Args:
        forecasts (list): Forecast objects, one per series.
        tss (list): Ground-truth series matching ``forecasts``.
        ticker (str): Stock ticker used in the title.
        prediction_length (int): Forecast horizon; sets the history window.
    """
    plt.figure(figsize=(12, 6))
    tick_format = mdates.DateFormatter('%b %d')

    first_pair = next(iter(zip(forecasts, tss)), None)
    if first_pair is not None:
        forecast, series = first_pair
        history = series[-4 * prediction_length:]
        plt.plot(history.to_timestamp(), label="Actual", color='black')
        forecast.plot(color='g')  # forecasted path
        plt.xticks(rotation=45)
        plt.gca().xaxis.set_major_formatter(tick_format)

    plt.title(f"Lag-Llama Forecast for {ticker}")
    plt.legend()
    plt.show()
def evaluate_forecasts(forecasts, tss):
    """
    Score forecasts against the ground truth with a default GluonTS Evaluator.

    Args:
        forecasts (list): Forecast objects, aligned with ``tss``.
        tss (list): Ground-truth series.

    Returns:
        dict: The evaluator's aggregate metrics. The per-series table is
        discarded. ``main`` reports the ``mean_wQuantileLoss`` entry as CRPS.
    """
    scorer = Evaluator()
    aggregate, _per_series = scorer(iter(tss), iter(forecasts))
    return aggregate
def backtest(forecasts, actual_series):
    """
    Computes backtest evaluation metrics by comparing forecasts against actual values.

    Only the first forecast / series pair is scored. The last
    ``len(horizon)`` actual values are aligned with the forecast steps, and
    each error is the percent error ``100 * (actual - predicted) / actual``.

    Args:
        forecasts (list): Forecast objects exposing ``samples``
            (sample paths x horizon) and ``quantile(q)``.
        actual_series (list): Ground-truth series as pandas objects; a
            one-column DataFrame or a Series.

    Returns:
        dict: ``"Mean Error"`` is the mean percent error of the sample-mean
        forecast. ``"Quantile Errors"`` maps each of the quantiles
        0.01/0.25/0.5/0.75/0.99 to its per-step percent-error array.
    """
    forecast = forecasts[0]
    forecast_vals = forecast.samples.mean(axis=0)  # per-step mean over sample paths
    horizon = len(forecast_vals)

    actual = actual_series[0]
    # Ground-truth series typically arrive as one-column DataFrames. Their
    # (T, 1) array would broadcast against the (horizon,) forecast into a
    # (horizon, horizon) matrix, so reduce to a 1-D Series first.
    if isinstance(actual, pd.DataFrame):
        actual = actual.iloc[:, 0]
    actual_vals = actual.to_numpy()[-horizon:]  # align with forecast length

    error = 100 * (actual_vals - forecast_vals) / actual_vals  # percent error
    quantiles = [0.01, 0.25, 0.5, 0.75, 0.99]
    quantile_errors = {
        q: 100 * (actual_vals - forecast.quantile(q)) / actual_vals for q in quantiles
    }
    return {
        "Mean Error": error.mean(),
        "Quantile Errors": quantile_errors,
    }
def main():
    """
    Run the end-to-end AAPL demo pipeline.

    Steps:
        1. Download 2023 AAPL closes and build the dataset.
        2. Forecast 30 steps with 100 samples, a 32-step context window and
           RoPE scaling disabled.
        3. Plot the forecast, then print the evaluator's CRPS figure.
        4. Print the backtest results.
    """
    ticker = "AAPL"
    start_date, end_date = "2023-01-01", "2024-01-01"
    prediction_length = 30
    num_samples = 100

    dataset = fetch_stock_data(ticker, start_date, end_date)

    # Forecast with context length 32, no RoPE scaling.
    forecasts, series = get_lag_llama_predictions(
        dataset,
        prediction_length,
        num_samples,
        context_length=32,
        use_rope_scaling=False,
    )
    plot_forecasts(forecasts, series, ticker, prediction_length)

    metrics = evaluate_forecasts(forecasts, series)
    print("CRPS:", metrics["mean_wQuantileLoss"])

    results = backtest(forecasts, series)
    print("Backtest Results:", results)


if __name__ == "__main__":
    main()