awt_quant.forecast.lag_llama_forecast
Module Contents
- awt_quant.forecast.lag_llama_forecast.LAG_LLAMA_CKPT_PATH = 'resources/lag_llama/model/lag_llama.ckpt'
- awt_quant.forecast.lag_llama_forecast.get_device()
Returns the appropriate device for computation.
Uses CUDA if available, otherwise falls back to CPU.
- Returns:
The device to use for model computations.
- Return type:
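A minimal sketch of the CUDA-or-CPU fallback described above. It assumes the module uses PyTorch, the usual backend for Lag-Llama. The name `pick_device` is hypothetical. The `ImportError` branch that returns the plain string `"cpu"` exists only for this illustration, so the snippet still runs where PyTorch is not installed.

```python
def pick_device():
    """Return a CUDA device if one is available, else the CPU (sketch)."""
    try:
        import torch  # assumption: the real module relies on PyTorch
    except ImportError:
        # Illustration only: fall back to a plain string when torch is absent.
        return "cpu"
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


device = pick_device()
print(device)  # "cuda" on a GPU machine, otherwise "cpu"
```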
- awt_quant.forecast.lag_llama_forecast.fetch_stock_data(ticker, start_date, end_date)
Fetches stock price data from Yahoo Finance and formats it for Lag-Llama.
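The reference does not say what "formats it for Lag-Llama" involves. One common step is assumed here, not taken from the module: put daily closing prices on a regular calendar and forward-fill non-trading days, because GluonTS datasets expect a fixed frequency. The helper `to_daily_series` is hypothetical and uses only the standard library. A pandas pipeline would typically do the same with `asfreq("D")` followed by `ffill()`.

```python
from datetime import date, timedelta


def to_daily_series(closes):
    """Reindex {date: close} onto a contiguous daily calendar (hypothetical helper).

    Days with no quote (weekends, holidays) reuse the last known close, so the
    series has the regular frequency a GluonTS-style dataset expects.
    """
    if not closes:
        return [], []
    days = sorted(closes)
    current, end = days[0], days[-1]
    dates, values = [], []
    last = None
    while current <= end:
        if current in closes:
            last = float(closes[current])  # a new quote replaces the carried value
        dates.append(current)
        values.append(last)
        current += timedelta(days=1)
    return dates, values


# Friday and Monday quotes; Saturday and Sunday are forward-filled.
dates, values = to_daily_series({date(2024, 1, 5): 100.0, date(2024, 1, 8): 102.0})
print(values)  # [100.0, 100.0, 100.0, 102.0]
```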
- awt_quant.forecast.lag_llama_forecast.get_lag_llama_predictions(dataset, prediction_length, num_samples=100, context_length=32, use_rope_scaling=False)
Runs Lag-Llama predictions on a given dataset.
- Parameters:
dataset (PandasDataset) – The dataset for forecasting.
prediction_length (int) – Forecast horizon.
num_samples (int, optional) – Number of Monte Carlo samples per timestep. Defaults to 100.
context_length (int, optional) – Context length for model. Defaults to 32.
use_rope_scaling (bool, optional) – Whether to use RoPE scaling for extended context. Defaults to False.
- Returns:
Forecasts and actual time series.
- Return type:
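With `use_rope_scaling=True`, the model must attend over more positions than it was trained on. The sketch below shows one way to size the scaling. It is modelled on the linear RoPE scaling arguments used in the public Lag-Llama demo notebook and assumes a native context length of 32, matching the `context_length` default above. The helper name and the `native_context_length` parameter are hypothetical, and this module may compute the factor differently.

```python
def rope_scaling_config(context_length, prediction_length, native_context_length=32):
    """Linear RoPE scaling arguments for an extended context (sketch).

    Linear scaling divides position indices by `factor`, so a window of
    context_length + prediction_length steps maps back into the native
    position range. The factor never drops below 1.0.
    """
    factor = max(1.0, (context_length + prediction_length) / native_context_length)
    return {"type": "linear", "factor": factor}


print(rope_scaling_config(64, 24))  # {'type': 'linear', 'factor': 2.75}
print(rope_scaling_config(16, 8))   # {'type': 'linear', 'factor': 1.0}
```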
- awt_quant.forecast.lag_llama_forecast.plot_forecasts(forecasts, tss, ticker, prediction_length)
Plots actual stock prices along with forecasted values.
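Probabilistic forecasts are usually plotted as a median line with a shaded interval across the sample paths. The sketch below uses only the standard library to compute those per-step summaries from a list of sample paths, such as the `num_samples` draws behind each forecast. The helper names, the nearest-rank quantile rule, and the 5th/95th percentile band are assumptions for illustration, not necessarily what `plot_forecasts` draws.

```python
import math


def quantile(values, q):
    """Nearest-rank quantile of a non-empty sequence (q in [0, 1])."""
    ordered = sorted(values)
    rank = max(1, math.ceil(q * len(ordered)))
    return ordered[rank - 1]


def forecast_bands(sample_paths, low=0.05, high=0.95):
    """Per-step (low, median, high) across sample paths of equal length."""
    bands = []
    for step_values in zip(*sample_paths):  # all sampled values for one step
        bands.append((quantile(step_values, low),
                      quantile(step_values, 0.5),
                      quantile(step_values, high)))
    return bands


paths = [[100, 101], [102, 104], [98, 99], [101, 103], [99, 100]]
for low_v, med, high_v in forecast_bands(paths):
    print(low_v, med, high_v)
```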
- awt_quant.forecast.lag_llama_forecast.evaluate_forecasts(forecasts, tss)
Evaluates forecasts using the GluonTS Evaluator.
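The GluonTS Evaluator reports many metrics. As a rough standard-library illustration of two of them, the sketch below computes MSE against the mean of the sample paths and MAPE against their median. This follows the convention GluonTS uses when it picks point forecasts for these per-series metrics. It is a simplification for intuition, not a replacement for the Evaluator, and the function name `point_metrics` is hypothetical.

```python
from statistics import mean, median


def point_metrics(sample_paths, actual):
    """MSE (vs. sample mean) and MAPE (vs. sample median) for one series."""
    steps = list(zip(*sample_paths))  # one tuple of sampled values per step
    mean_fcst = [mean(s) for s in steps]
    median_fcst = [median(s) for s in steps]
    mse = mean((y - f) ** 2 for y, f in zip(actual, mean_fcst))
    mape = mean(abs(y - f) / abs(y) for y, f in zip(actual, median_fcst))
    return {"MSE": mse, "MAPE": mape}


paths = [[100.0, 102.0], [104.0, 106.0], [96.0, 98.0]]
print(point_metrics(paths, actual=[100.0, 100.0]))
```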