awt_quant.forecast.lag_llama_forecast

Module Contents

awt_quant.forecast.lag_llama_forecast.LAG_LLAMA_CKPT_PATH = 'resources/lag_llama/model/lag_llama.ckpt'

Relative path to the Lag-Llama model checkpoint file.

awt_quant.forecast.lag_llama_forecast.get_device()

Returns the appropriate device for computation.

Uses CUDA if available, otherwise falls back to CPU.

Returns:

The device to use for model computations.

Return type:

torch.device

awt_quant.forecast.lag_llama_forecast.fetch_stock_data(ticker, start_date, end_date)

Fetches stock price data from Yahoo Finance and formats it for Lag-Llama.

Parameters:
  • ticker (str) – Stock symbol.

  • start_date (str) – Start date in ‘YYYY-MM-DD’ format.

  • end_date (str) – End date in ‘YYYY-MM-DD’ format.

Returns:

The dataset formatted for Lag-Llama.

Return type:

PandasDataset

awt_quant.forecast.lag_llama_forecast.get_lag_llama_predictions(dataset, prediction_length, num_samples=100, context_length=32, use_rope_scaling=False)

Runs probabilistic Lag-Llama forecasting on the given dataset.

Parameters:
  • dataset (PandasDataset) – The dataset for forecasting.

  • prediction_length (int) – Forecast horizon.

  • num_samples (int, optional) – Number of Monte Carlo sample paths drawn from the predictive distribution for each series; every path covers the full forecast horizon. Defaults to 100.

  • context_length (int, optional) – Number of past timesteps in the model's context window. Defaults to 32.

  • use_rope_scaling (bool, optional) – Whether to use RoPE scaling for extended context. Defaults to False.

Returns:

A pair (forecasts, tss): the forecasts and the corresponding actual time series.

Return type:

Tuple[list, list]
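A minimal sketch of how context_length and prediction_length partition a price series. It assumes, as in GluonTS's make_evaluation_predictions (which the (forecasts, tss) return value suggests but the source does not confirm), that the last prediction_length observations are held out as ground truth. The helper split_context_target and the toy prices are hypothetical. Lag-Llama's lag features can also look further back than the context window; the sketch ignores that.

```python
from typing import List, Sequence, Tuple


def split_context_target(
    series: Sequence[float], context_length: int, prediction_length: int
) -> Tuple[List[float], List[float]]:
    """Hypothetical helper: hold out the last `prediction_length` points as the
    target and take up to `context_length` points immediately before them as
    the context window (lag features reaching further back are ignored)."""
    if prediction_length <= 0 or context_length <= 0:
        raise ValueError("context_length and prediction_length must be positive")
    if len(series) <= prediction_length:
        raise ValueError("series must be longer than prediction_length")
    history = list(series[:-prediction_length])  # everything before the horizon
    target = list(series[-prediction_length:])   # held-out ground truth
    context = history[-context_length:]          # most recent context_length points
    return context, target


prices = [float(p) for p in range(100, 140)]  # 40 toy "closing prices"
context, target = split_context_target(prices, context_length=32, prediction_length=5)
print(len(context), len(target))            # 32 5
print(context[0], context[-1], target[0])   # 103.0 134.0 135.0
```

If the history is shorter than context_length, the sketch simply uses all available points before the horizon.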

awt_quant.forecast.lag_llama_forecast.plot_forecasts(forecasts, tss, ticker, prediction_length)

Plots actual stock prices along with forecasted values.

Parameters:
  • forecasts (list) – List of forecasted series.

  • tss (list) – List of actual time series.

  • ticker (str) – Stock ticker symbol.

  • prediction_length (int) – Forecast horizon.
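Each forecast is probabilistic: with num_samples=100 it holds 100 sample paths over the horizon, and a forecast plot typically shows the per-timestep median plus a shaded prediction interval. The sketch below computes those summaries from a samples matrix of shape (num_samples, prediction_length). The layout mirrors GluonTS's SampleForecast.samples, but the helper names and the 90% interval are assumptions, not taken from the source.

```python
import random
from typing import Dict, List, Tuple


def quantile(values: List[float], q: float) -> float:
    """Empirical quantile with linear interpolation between order statistics."""
    xs = sorted(values)
    pos = q * (len(xs) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


def summarize_samples(
    samples: List[List[float]], levels: Tuple[float, ...] = (0.05, 0.5, 0.95)
) -> Dict[float, List[float]]:
    """Hypothetical helper: per-timestep quantiles of a (num_samples, horizon) matrix."""
    horizon = len(samples[0])
    return {
        q: [quantile([path[t] for path in samples], q) for t in range(horizon)]
        for q in levels
    }


rng = random.Random(0)
# 100 toy sample paths over a 5-step horizon, centred on 100 with growing spread
samples = [[100 + rng.gauss(0, 1 + t) for t in range(5)] for _ in range(100)]
bands = summarize_samples(samples)
for t in range(5):
    print(f"t={t}: median={bands[0.5][t]:.2f}, "
          f"90% band=[{bands[0.05][t]:.2f}, {bands[0.95][t]:.2f}]")
```

The widening band in the toy data reflects the growing sample spread, which is the usual shape of a multi-step forecast fan.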

awt_quant.forecast.lag_llama_forecast.evaluate_forecasts(forecasts, tss)

Evaluates forecasts using the GluonTS Evaluator.

Parameters:
  • forecasts (list) – Forecasted time series.

  • tss (list) – Actual time series.

Returns:

Aggregated evaluation metrics including CRPS.

Return type:

dict
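CRPS (continuous ranked probability score) measures how well the whole predictive distribution matches the observed value; lower is better. For a sample-based forecast it can be estimated as CRPS(F, y) ≈ mean|X - y| - 0.5 * mean|X - X'|, where X and X' range over the samples. The sketch below computes that estimator per timestep and averages it over the horizon. It only illustrates the metric: the GluonTS Evaluator is generally understood to approximate CRPS from weighted quantile losses (e.g. mean_wQuantileLoss), so its numbers will not match this sketch exactly. The function names are hypothetical.

```python
from typing import List, Sequence


def crps_from_samples(samples: Sequence[float], observed: float) -> float:
    """Sample-based CRPS estimate: E|X - y| - 0.5 * E|X - X'|."""
    m = len(samples)
    accuracy = sum(abs(x - observed) for x in samples) / m
    spread = sum(abs(a - b) for a in samples for b in samples) / (m * m)
    return accuracy - 0.5 * spread


def mean_crps(sample_paths: List[List[float]], actual: List[float]) -> float:
    """Average per-timestep CRPS over the horizon (illustrative, not GluonTS's code)."""
    horizon = len(actual)
    per_step = [
        crps_from_samples([path[t] for path in sample_paths], actual[t])
        for t in range(horizon)
    ]
    return sum(per_step) / horizon


# A point forecast (all samples equal) reduces CRPS to the absolute error.
print(crps_from_samples([101.0, 101.0, 101.0], 100.0))  # 1.0
```

The spread term rewards sharp distributions only when they are also accurate, which is why CRPS is a natural single-number summary for probabilistic forecasts.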

awt_quant.forecast.lag_llama_forecast.backtest(forecasts, actual_series)

Computes backtest evaluation metrics by comparing forecasts against actual values.

Parameters:
  • forecasts (list) – List of forecasted time series.

  • actual_series (list) – List of actual time series.

Returns:

Evaluation metrics including mean error and quantiles.

Return type:

dict
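The source documents backtest's output only as "mean error and quantiles". The sketch below shows one plausible set of such metrics for a single series: the signed mean error and mean absolute error of the median forecast, plus the empirical coverage of a 5%-95% interval. The metric names, the interval, and the function backtest_metrics are all hypothetical and are not the module's actual output.

```python
from typing import Dict, List


def _quantile(values: List[float], q: float) -> float:
    """Empirical quantile with linear interpolation between order statistics."""
    xs = sorted(values)
    pos = q * (len(xs) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


def backtest_metrics(sample_paths: List[List[float]], actual: List[float]) -> Dict[str, float]:
    """Hypothetical backtest summary for one series over the forecast horizon."""
    errors: List[float] = []
    covered = 0
    for t, y in enumerate(actual):
        column = [path[t] for path in sample_paths]
        errors.append(_quantile(column, 0.5) - y)       # signed error of the median forecast
        low, high = _quantile(column, 0.05), _quantile(column, 0.95)
        covered += int(low <= y <= high)                # did the 90% interval contain y?
    n = len(actual)
    return {
        "mean_error": sum(errors) / n,                  # bias: > 0 means over-forecasting
        "mean_abs_error": sum(abs(e) for e in errors) / n,
        "coverage_90": covered / n,                     # ideally close to 0.90
    }


paths = [[99.0, 110.0], [100.0, 111.0], [101.0, 112.0]]
print(backtest_metrics(paths, [100.0, 120.0]))
# {'mean_error': -4.5, 'mean_abs_error': 4.5, 'coverage_90': 0.5}
```

Coverage complements the error metrics: a forecast can have small mean error yet intervals that are far too narrow.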

awt_quant.forecast.lag_llama_forecast.main()

Runs the end-to-end pipeline:
  • Fetches stock data.

  • Runs Lag-Llama forecasting with context length 32.

  • Evaluates and plots the forecasts.

  • Performs backtesting.