File size: 3,446 Bytes
0a58567
c4b87d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a58567
c4b87d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a58567
 
c4b87d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from typing import Any

import numpy as np

from src.data.containers import TimeSeriesContainer
from src.synthetic_generation.abstract_classes import GeneratorWrapper
from src.synthetic_generation.generator_params import SawToothGeneratorParams
from src.synthetic_generation.sawtooth.sawtooth_generator import SawToothGenerator


class SawToothGeneratorWrapper(GeneratorWrapper):
    """
    Batch adapter that drives a single :class:`SawToothGenerator`.

    Each call to :meth:`generate_batch` produces ``batch_size`` series by invoking
    ``SawToothGenerator.generate_time_series`` once per batch element and stacking
    the results with ``np.array``. All generator settings come from the
    :class:`SawToothGeneratorParams` dataclass passed at construction time.
    """

    # SawToothGeneratorParams fields copied into the sampled-parameter dict and
    # later forwarded, under the same names, as SawToothGenerator keyword arguments.
    _SAWTOOTH_FIELDS = (
        "length",
        "periods",
        "amplitude_range",
        "phase_range",
        "trend_slope_range",
        "seasonality_amplitude_range",
        "add_trend",
        "add_seasonality",
    )

    def __init__(self, params: SawToothGeneratorParams):
        super().__init__(params)
        # Re-annotate with the concrete params type for static type checkers.
        self.params: SawToothGeneratorParams = params

    def _sample_parameters(self, batch_size: int) -> dict[str, Any]:
        """
        Build the parameter dict consumed by :meth:`generate_batch`.

        Parameters
        ----------
        batch_size : int
            Batch size, passed through to the base-class sampler.

        Returns
        -------
        Dict[str, Any]
            The base-class sampled parameters, extended (or overwritten) with every
            field named in ``_SAWTOOTH_FIELDS``, read from ``self.params``.
        """
        sampled = super()._sample_parameters(batch_size)
        sampled.update(
            {field: getattr(self.params, field) for field in self._SAWTOOTH_FIELDS}
        )
        return sampled

    def generate_batch(
        self,
        batch_size: int,
        seed: int | None = None,
        params: dict[str, Any] | None = None,
    ) -> TimeSeriesContainer:
        """
        Produce ``batch_size`` sawtooth series and wrap them in a container.

        Parameters
        ----------
        batch_size : int
            How many series to generate.
        seed : int, optional
            When given, ``_set_random_seeds(seed)`` is called before any sampling,
            the generator is built with ``random_seed=seed``, and batch element
            ``i`` is generated with ``random_seed=seed + i``. When None (default),
            no seeding is performed and each element gets ``random_seed=None``.
        params : Dict[str, Any], optional
            Pre-sampled parameters. Must contain every key in ``_SAWTOOTH_FIELDS``
            plus ``"start"`` and ``"frequency"``. When None, they are drawn with
            :meth:`_sample_parameters`.

        Returns
        -------
        TimeSeriesContainer
            Container holding ``np.array`` of the per-element series together with
            ``params["start"]`` and ``params["frequency"]``.
        """
        # Seeding must happen before parameter sampling so sampling is reproducible.
        if seed is not None:
            self._set_random_seeds(seed)
        if params is None:
            params = self._sample_parameters(batch_size)

        sawtooth = SawToothGenerator(
            **{field: params[field] for field in self._SAWTOOTH_FIELDS},
            random_seed=seed,
        )

        # NOTE(review): per-element seeds are seed + i, so batches generated with
        # consecutive seeds may share most of their series — confirm intended.
        series = [
            sawtooth.generate_time_series(
                random_seed=None if seed is None else seed + offset
            )
            for offset in range(batch_size)
        ]

        values = np.array(series)
        return TimeSeriesContainer(
            values=values,
            start=params["start"],
            frequency=params["frequency"],
        )