| from typing import Any | |
| import numpy as np | |
| from src.data.containers import TimeSeriesContainer | |
| from src.synthetic_generation.abstract_classes import GeneratorWrapper | |
| from src.synthetic_generation.generator_params import SawToothGeneratorParams | |
| from src.synthetic_generation.sawtooth.sawtooth_generator import SawToothGenerator | |
class SawToothGeneratorWrapper(GeneratorWrapper):
    """
    Batch-level wrapper around ``SawToothGenerator``.

    Builds a ``TimeSeriesContainer`` by invoking the underlying generator once per
    batch element and stacking the returned series with ``np.array``. All
    generator settings come from a ``SawToothGeneratorParams`` dataclass.
    """

    # Configuration fields copied from ``SawToothGeneratorParams`` into the
    # sampled-parameter dict and later forwarded, under the same names, as
    # keyword arguments to ``SawToothGenerator``. Order matches the original
    # explicit listing so dict insertion order is unchanged.
    _GENERATOR_KWARGS: tuple[str, ...] = (
        "length",
        "periods",
        "amplitude_range",
        "phase_range",
        "trend_slope_range",
        "seasonality_amplitude_range",
        "add_trend",
        "add_seasonality",
    )

    def __init__(self, params: SawToothGeneratorParams):
        super().__init__(params)
        # Re-declared with the concrete type so attribute access is typed.
        self.params: SawToothGeneratorParams = params

    def _sample_parameters(self, batch_size: int) -> dict[str, Any]:
        """
        Collect the parameters needed to build a ``SawToothGenerator`` batch.

        Starts from the base-class sample and adds every field listed in
        ``_GENERATOR_KWARGS``, read directly from ``self.params``.

        Parameters
        ----------
        batch_size : int
            Number of series in the batch; passed through to the base class.

        Returns
        -------
        dict[str, Any]
            Base-class parameters extended with the sawtooth configuration.
        """
        sampled = super()._sample_parameters(batch_size)
        sampled.update(
            {name: getattr(self.params, name) for name in self._GENERATOR_KWARGS}
        )
        return sampled

    def generate_batch(
        self,
        batch_size: int,
        seed: int | None = None,
        params: dict[str, Any] | None = None,
    ) -> TimeSeriesContainer:
        """
        Produce ``batch_size`` sawtooth series packed into a ``TimeSeriesContainer``.

        Parameters
        ----------
        batch_size : int
            How many series to generate.
        seed : int, optional
            Base random seed. When given, global seeds are set via
            ``_set_random_seeds`` and series ``i`` is generated with seed
            ``seed + i``. When ``None``, no per-series seed is passed.
        params : dict[str, Any], optional
            Pre-sampled parameters. When ``None``, ``_sample_parameters`` is
            called (after seeding). Must contain every key in
            ``_GENERATOR_KWARGS`` plus ``"start"`` and ``"frequency"``.

        Returns
        -------
        TimeSeriesContainer
            Stacked series values with the ``start`` and ``frequency`` taken
            from ``params``.
        """
        # Seed first so that any randomness in parameter sampling is reproducible.
        if seed is not None:
            self._set_random_seeds(seed)
        if params is None:
            params = self._sample_parameters(batch_size)

        generator = SawToothGenerator(
            **{name: params[name] for name in self._GENERATOR_KWARGS},
            random_seed=seed,
        )

        # Each series gets its own deterministic offset from the base seed.
        series = [
            generator.generate_time_series(
                random_seed=None if seed is None else seed + offset
            )
            for offset in range(batch_size)
        ]

        return TimeSeriesContainer(
            values=np.array(series),
            start=params["start"],
            frequency=params["frequency"],
        )