File size: 3,446 Bytes
0a58567 c4b87d2 0a58567 c4b87d2 0a58567 c4b87d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
from typing import Any
import numpy as np
from src.data.containers import TimeSeriesContainer
from src.synthetic_generation.abstract_classes import GeneratorWrapper
from src.synthetic_generation.generator_params import SawToothGeneratorParams
from src.synthetic_generation.sawtooth.sawtooth_generator import SawToothGenerator
class SawToothGeneratorWrapper(GeneratorWrapper):
    """
    Batch wrapper around SawToothGenerator.

    Configured with a SawToothGeneratorParams dataclass. Each call to
    ``generate_batch`` builds one SawToothGenerator from the sawtooth
    settings, draws ``batch_size`` series from it, and stacks them with
    ``np.array`` into a TimeSeriesContainer.
    """

    # Names of the SawToothGeneratorParams fields that are copied unchanged
    # into the parameter dict and then passed as keyword arguments to
    # SawToothGenerator.
    _FORWARDED_FIELDS = (
        "length",
        "periods",
        "amplitude_range",
        "phase_range",
        "trend_slope_range",
        "seasonality_amplitude_range",
        "add_trend",
        "add_seasonality",
    )

    def __init__(self, params: SawToothGeneratorParams):
        super().__init__(params)
        # Re-annotated so type checkers see the concrete params subclass.
        self.params: SawToothGeneratorParams = params

    def _sample_parameters(self, batch_size: int) -> dict[str, Any]:
        """
        Build the parameter dict consumed by ``generate_batch``.

        Parameters
        ----------
        batch_size : int
            Forwarded to the base-class ``_sample_parameters``.

        Returns
        -------
        dict[str, Any]
            The dict returned by the base class, extended (and overridden on
            key collision) with the sawtooth settings copied from
            ``self.params``. The settings themselves, including the ``*_range``
            values, are copied as-is rather than sampled here.
        """
        sampled = super()._sample_parameters(batch_size)
        sawtooth_settings = {
            name: getattr(self.params, name) for name in self._FORWARDED_FIELDS
        }
        sampled.update(sawtooth_settings)
        return sampled

    def generate_batch(
        self,
        batch_size: int,
        seed: int | None = None,
        params: dict[str, Any] | None = None,
    ) -> TimeSeriesContainer:
        """
        Generate ``batch_size`` sawtooth series and stack them into a container.

        Parameters
        ----------
        batch_size : int
            Number of series to draw.
        seed : int, optional
            When given, it is passed to ``_set_random_seeds``. It is also used
            as the generator's ``random_seed``, and series ``i`` is drawn with
            ``seed + i``. When None, every call receives None. Default: None.
        params : dict[str, Any], optional
            Pre-sampled parameters. When None, ``_sample_parameters`` is
            called to produce them. The dict must contain every forwarded
            sawtooth field plus ``"start"`` and ``"frequency"``.

        Returns
        -------
        TimeSeriesContainer
            ``values`` holds ``np.array`` of the per-series outputs. The shape
            is presumably (batch_size, length); TODO confirm against
            SawToothGenerator.generate_time_series.
        """
        if seed is not None:
            self._set_random_seeds(seed)
        if params is None:
            params = self._sample_parameters(batch_size)

        generator_kwargs = {name: params[name] for name in self._FORWARDED_FIELDS}
        generator = SawToothGenerator(**generator_kwargs, random_seed=seed)

        # Series i receives seed + i, so each member of a seeded batch gets its
        # own seed. Without a seed, None is passed for every series.
        series = [
            generator.generate_time_series(
                random_seed=None if seed is None else seed + offset
            )
            for offset in range(batch_size)
        ]

        # NOTE(review): "start" and "frequency" are not added by this class;
        # presumably the base-class _sample_parameters supplies them. Verify.
        return TimeSeriesContainer(
            values=np.array(series),
            start=params["start"],
            frequency=params["frequency"],
        )
|