TempoPFN / src /synthetic_generation /sawtooth /sawtooth_generator_wrapper.py
Vladyslav Moroshan
Apply ruff formatting
0a58567
from typing import Any
import numpy as np
from src.data.containers import TimeSeriesContainer
from src.synthetic_generation.abstract_classes import GeneratorWrapper
from src.synthetic_generation.generator_params import SawToothGeneratorParams
from src.synthetic_generation.sawtooth.sawtooth_generator import SawToothGenerator
class SawToothGeneratorWrapper(GeneratorWrapper):
    """
    Batch-oriented front end for SawToothGenerator.

    Configuration is supplied as a SawToothGeneratorParams instance. A batch is
    produced by calling the generator once per requested series and stacking the
    individual results into a single array.
    """

    def __init__(self, params: SawToothGeneratorParams):
        # The base class receives the params as well; the attribute is re-declared
        # here so it carries the more specific SawToothGeneratorParams type.
        super().__init__(params)
        self.params: SawToothGeneratorParams = params

    def _sample_parameters(self, batch_size: int) -> dict[str, Any]:
        """
        Build the parameter dictionary used to generate one batch.

        The dictionary returned by the base class is extended with the
        sawtooth-specific settings read from ``self.params``.

        Parameters
        ----------
        batch_size : int
            Forwarded unchanged to the base-class sampler.

        Returns
        -------
        Dict[str, Any]
            Base-class parameters plus the sawtooth generator settings.
        """
        sampled = super()._sample_parameters(batch_size)
        cfg = self.params
        sampled.update(
            length=cfg.length,
            periods=cfg.periods,
            amplitude_range=cfg.amplitude_range,
            phase_range=cfg.phase_range,
            trend_slope_range=cfg.trend_slope_range,
            seasonality_amplitude_range=cfg.seasonality_amplitude_range,
            add_trend=cfg.add_trend,
            add_seasonality=cfg.add_seasonality,
        )
        return sampled

    def generate_batch(
        self,
        batch_size: int,
        seed: int | None = None,
        params: dict[str, Any] | None = None,
    ) -> TimeSeriesContainer:
        """
        Produce ``batch_size`` sawtooth series wrapped in a TimeSeriesContainer.

        Parameters
        ----------
        batch_size : int
            How many series to generate.
        seed : int, optional
            When given, it is passed to ``_set_random_seeds`` and to the generator
            constructor, and the i-th series is generated with ``seed + i``.
            When None, no seeding is performed and every series gets
            ``random_seed=None``.
        params : Dict[str, Any], optional
            Parameter dictionary to use instead of sampling a fresh one. It must
            contain the generator settings as well as ``"start"`` and
            ``"frequency"``.

        Returns
        -------
        TimeSeriesContainer
            Container holding the stacked series together with the ``start`` and
            ``frequency`` entries taken from the parameter dictionary.
        """
        seeded = seed is not None
        if seeded:
            self._set_random_seeds(seed)

        if params is None:
            params = self._sample_parameters(batch_size)

        # Settings that map one-to-one onto SawToothGenerator keyword arguments.
        generator_keys = (
            "length",
            "periods",
            "amplitude_range",
            "phase_range",
            "trend_slope_range",
            "seasonality_amplitude_range",
            "add_trend",
            "add_seasonality",
        )
        generator_kwargs = {key: params[key] for key in generator_keys}
        generator = SawToothGenerator(random_seed=seed, **generator_kwargs)

        # Each series gets its own seed (offset by its index) when a seed is set.
        series = [
            generator.generate_time_series(random_seed=seed + idx if seeded else None)
            for idx in range(batch_size)
        ]

        return TimeSeriesContainer(
            values=np.array(series),
            start=params["start"],
            frequency=params["frequency"],
        )