yan123yan
commited on
Commit
·
47fe089
1
Parent(s):
8cf3e4f
first version
Browse files- data/file1.dat.npz +3 -0
- data/file10.dat.npz +3 -0
- data/file2.dat.npz +3 -0
- data/file3.dat.npz +3 -0
- data/file4.dat.npz +3 -0
- data/file5.dat.npz +3 -0
- data/file6.dat.npz +3 -0
- data/file7.dat.npz +3 -0
- data/file8.dat.npz +3 -0
- data/file9.dat.npz +3 -0
- model/__pycache__/lstm.cpython-310.pyc +0 -0
- model/__pycache__/tcn.cpython-310.pyc +0 -0
- model/__pycache__/tcn_module.cpython-310.pyc +0 -0
- model/lstm.ckpt +3 -0
- model/lstm.py +22 -0
- model/tcn.ckpt +3 -0
- model/tcn.py +40 -0
- model/tcn_module.py +511 -0
- pages/inference.py +547 -0
- prediction.py +196 -0
- requirements.txt +7 -0
- utils/__pycache__/highlevel.cpython-310.pyc +0 -0
- utils/__pycache__/lowlevel.cpython-310.pyc +0 -0
- utils/__pycache__/metrics.cpython-310.pyc +0 -0
- utils/__pycache__/midpoint.cpython-310.pyc +0 -0
- utils/__pycache__/transform.cpython-310.pyc +0 -0
- utils/highlevel.py +160 -0
- utils/lowlevel.py +158 -0
- utils/metrics.py +28 -0
- utils/midpoint.py +164 -0
- utils/transform.py +8 -0
data/file1.dat.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:90713e007e25e2b4467711981274dec5f15548666bf7867a30cdf5e189b94b80
|
| 3 |
+
size 7200262
|
data/file10.dat.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:123c044fc18c7e1b0671261fc081aa7e9a60eab726f35b2758b75f3e70ebe76e
|
| 3 |
+
size 7200262
|
data/file2.dat.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ebdbdcb3465e49f16073eb828dc18163a12e8b5968cb7f4661e3931c13ac0cea
|
| 3 |
+
size 7200262
|
data/file3.dat.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6b14cf26774c287dd0fb2351a6fc6ce3a2eae135a57c265a93c03f017d479d3a
|
| 3 |
+
size 7200262
|
data/file4.dat.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7810f50b9c30a19fbfcb3db77bf838b02df4fc54f6caa63e8dd3ca0f2abba6c1
|
| 3 |
+
size 7200262
|
data/file5.dat.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:34f3485194b57325b0f32554c4f14218389a6d47d5b6edb7e37626fc48d6aae8
|
| 3 |
+
size 7200262
|
data/file6.dat.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d59084be5aff0eb578c7ff5ee62027ef853d1f5d8d2794d6230d5b706aa0f6aa
|
| 3 |
+
size 7200262
|
data/file7.dat.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9035c00bb34d17940b033e3bae40097296c493a4852630fbcf963c6f80391f5a
|
| 3 |
+
size 7200262
|
data/file8.dat.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:955406a5ec5af34373c4ae8573195ca580bfcfbbef13efdcb410fc05d58d66f8
|
| 3 |
+
size 7200262
|
data/file9.dat.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3ea42398e0062a3280981795ea8fd3a49bd5e104de178decfad9e69e41c0de8c
|
| 3 |
+
size 7200262
|
model/__pycache__/lstm.cpython-310.pyc
ADDED
|
Binary file (1.2 kB). View file
|
|
|
model/__pycache__/tcn.cpython-310.pyc
ADDED
|
Binary file (1.59 kB). View file
|
|
|
model/__pycache__/tcn_module.cpython-310.pyc
ADDED
|
Binary file (15.6 kB). View file
|
|
|
model/lstm.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1080dbdc37acdb1e9e6a29c140711908d2426b39735617eafbadd49fb5772ef4
|
| 3 |
+
size 7286190
|
model/lstm.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytorch_lightning as pl
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
class LSTMModel(pl.LightningModule):
|
| 6 |
+
def __init__(self, **config):
|
| 7 |
+
super(LSTMModel, self).__init__()
|
| 8 |
+
self.save_hyperparameters(config)
|
| 9 |
+
|
| 10 |
+
self.lstm = nn.LSTM(input_size=21,hidden_size=512,num_layers=3,proj_size=21,batch_first=True)
|
| 11 |
+
self.linear = nn.Linear(in_features=21, out_features=7)
|
| 12 |
+
|
| 13 |
+
def forward(self, x):
|
| 14 |
+
outputs = []
|
| 15 |
+
hidden, cell = None, None
|
| 16 |
+
for i in range(20):
|
| 17 |
+
if i == 0:
|
| 18 |
+
output, (hidden, cell) = self.lstm(x)
|
| 19 |
+
else:
|
| 20 |
+
output, (hidden, cell) = self.lstm(output[:, -1, :].unsqueeze(1), (hidden, cell))
|
| 21 |
+
outputs.append(self.linear(output[:, -1, :]))
|
| 22 |
+
return torch.stack(outputs, dim=1)
|
model/tcn.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fc975af786670412a2e419978636038452467676e1a9dac59d2ed77033f9f67b
|
| 3 |
+
size 43742454
|
model/tcn.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ["KERAS_BACKEND"] = "torch"
|
| 3 |
+
import pytorch_lightning as pl
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
from keras.layers import Input, Dense
|
| 8 |
+
from keras.models import Model
|
| 9 |
+
from model.tcn_module import TCN
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TCNModel(pl.LightningModule):
|
| 13 |
+
def __init__(self, **config):
|
| 14 |
+
super(TCNModel, self).__init__()
|
| 15 |
+
self.save_hyperparameters(config)
|
| 16 |
+
|
| 17 |
+
input_layer = Input(shape=(self.hparams.windows_size, self.hparams.input_size))
|
| 18 |
+
self.tcn = TCN(input_shape=(self.hparams.windows_size, self.hparams.input_size))(input_layer)
|
| 19 |
+
self.linear = Dense(7)(self.tcn)
|
| 20 |
+
self.model = Model(inputs=input_layer, outputs=self.linear)
|
| 21 |
+
|
| 22 |
+
def forward(self, x):
|
| 23 |
+
output = self.model(x)
|
| 24 |
+
return torch.stack([output], dim=1)
|
| 25 |
+
|
| 26 |
+
def move_custom_layers_to_device(model, device):
|
| 27 |
+
for name, module in model.named_children():
|
| 28 |
+
# 如果是标准层,named_children已经处理了
|
| 29 |
+
if isinstance(module, nn.Module):
|
| 30 |
+
continue
|
| 31 |
+
|
| 32 |
+
# 对于非标准层,例如包含在列表或字典中的层
|
| 33 |
+
if isinstance(module, list):
|
| 34 |
+
for sub_module in module:
|
| 35 |
+
if isinstance(sub_module, nn.Module):
|
| 36 |
+
sub_module.to(device)
|
| 37 |
+
elif isinstance(module, dict):
|
| 38 |
+
for sub_module in module.values():
|
| 39 |
+
if isinstance(sub_module, nn.Module):
|
| 40 |
+
sub_module.to(device)
|
model/tcn_module.py
ADDED
|
@@ -0,0 +1,511 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
os.environ["KERAS_BACKEND"] = "torch"
|
| 6 |
+
import keras
|
| 7 |
+
|
| 8 |
+
# from keras_core import backend as K, Model, Input, optimizers
|
| 9 |
+
# from keras_core import backend as Model, Input, optimizers
|
| 10 |
+
# from keras_core import backend as K
|
| 11 |
+
|
| 12 |
+
from keras import Model
|
| 13 |
+
from keras import optimizers
|
| 14 |
+
from keras import ops as K
|
| 15 |
+
from keras import config as KK
|
| 16 |
+
|
| 17 |
+
from keras import layers
|
| 18 |
+
from keras.layers import Input, Layer, Conv1D, Dense, BatchNormalization, LayerNormalization, Activation, SpatialDropout1D, Lambda
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def is_power_of_two(num: int):
|
| 22 |
+
return num != 0 and ((num & (num - 1)) == 0)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def adjust_dilations(dilations: list):
|
| 26 |
+
if all([is_power_of_two(i) for i in dilations]):
|
| 27 |
+
return dilations
|
| 28 |
+
else:
|
| 29 |
+
new_dilations = [2 ** i for i in dilations]
|
| 30 |
+
return new_dilations
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class ResidualBlock(Layer):
|
| 34 |
+
|
| 35 |
+
def __init__(self,
|
| 36 |
+
dilation_rate: int,
|
| 37 |
+
nb_filters: int,
|
| 38 |
+
kernel_size: int,
|
| 39 |
+
padding: str,
|
| 40 |
+
activation: str = 'relu',
|
| 41 |
+
dropout_rate: float = 0,
|
| 42 |
+
kernel_initializer: str = 'he_normal',
|
| 43 |
+
use_batch_norm: bool = False,
|
| 44 |
+
use_layer_norm: bool = False,
|
| 45 |
+
use_weight_norm: bool = False,
|
| 46 |
+
**kwargs):
|
| 47 |
+
"""Defines the residual block for the WaveNet TCN
|
| 48 |
+
Args:
|
| 49 |
+
x: The previous layer in the model
|
| 50 |
+
training: boolean indicating whether the layer should behave in training mode or in inference mode
|
| 51 |
+
dilation_rate: The dilation power of 2 we are using for this residual block
|
| 52 |
+
nb_filters: The number of convolutional filters to use in this block
|
| 53 |
+
kernel_size: The size of the convolutional kernel
|
| 54 |
+
padding: The padding used in the convolutional layers, 'same' or 'causal'.
|
| 55 |
+
activation: The final activation used in o = Activation(x + F(x))
|
| 56 |
+
dropout_rate: Float between 0 and 1. Fraction of the input units to drop.
|
| 57 |
+
kernel_initializer: Initializer for the kernel weights matrix (Conv1D).
|
| 58 |
+
use_batch_norm: Whether to use batch normalization in the residual layers or not.
|
| 59 |
+
use_layer_norm: Whether to use layer normalization in the residual layers or not.
|
| 60 |
+
use_weight_norm: Whether to use weight normalization in the residual layers or not.
|
| 61 |
+
kwargs: Any initializers for Layer class.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
self.dilation_rate = dilation_rate
|
| 65 |
+
self.nb_filters = nb_filters
|
| 66 |
+
self.kernel_size = kernel_size
|
| 67 |
+
self.padding = padding
|
| 68 |
+
self.activation = activation
|
| 69 |
+
self.dropout_rate = dropout_rate
|
| 70 |
+
self.use_batch_norm = use_batch_norm
|
| 71 |
+
self.use_layer_norm = use_layer_norm
|
| 72 |
+
self.use_weight_norm = use_weight_norm
|
| 73 |
+
self.kernel_initializer = kernel_initializer
|
| 74 |
+
self.layers = []
|
| 75 |
+
self.shape_match_conv = None
|
| 76 |
+
self.res_output_shape = None
|
| 77 |
+
self.final_activation = None
|
| 78 |
+
|
| 79 |
+
super(ResidualBlock, self).__init__(**kwargs)
|
| 80 |
+
|
| 81 |
+
def _build_layer(self, layer):
|
| 82 |
+
"""Helper function for building layer
|
| 83 |
+
Args:
|
| 84 |
+
layer: Appends layer to internal layer list and builds it based on the current output
|
| 85 |
+
shape of ResidualBlocK. Updates current output shape.
|
| 86 |
+
"""
|
| 87 |
+
self.layers.append(layer)
|
| 88 |
+
self.layers[-1].build(self.res_output_shape)
|
| 89 |
+
self.res_output_shape = self.layers[-1].compute_output_shape(self.res_output_shape)
|
| 90 |
+
|
| 91 |
+
def build(self, input_shape):
|
| 92 |
+
|
| 93 |
+
#with K.name_scope(self.name): # name scope used to make sure weights get unique names
|
| 94 |
+
self.layers = []
|
| 95 |
+
self.res_output_shape = input_shape
|
| 96 |
+
|
| 97 |
+
for k in range(2): # dilated conv block.
|
| 98 |
+
name = 'conv1D_{}'.format(k)
|
| 99 |
+
# with K.name_scope(name): # name scope used to make sure weights get unique names
|
| 100 |
+
conv = Conv1D(
|
| 101 |
+
filters=self.nb_filters,
|
| 102 |
+
kernel_size=self.kernel_size,
|
| 103 |
+
dilation_rate=self.dilation_rate,
|
| 104 |
+
padding=self.padding,
|
| 105 |
+
name=name,
|
| 106 |
+
kernel_initializer=self.kernel_initializer
|
| 107 |
+
)
|
| 108 |
+
if self.use_weight_norm:
|
| 109 |
+
from tensorflow_addons.layers import WeightNormalization
|
| 110 |
+
# wrap it. WeightNormalization API is different than BatchNormalization or LayerNormalization.
|
| 111 |
+
#with K.name_scope('norm_{}'.format(k)):
|
| 112 |
+
conv = WeightNormalization(conv)
|
| 113 |
+
self._build_layer(conv)
|
| 114 |
+
|
| 115 |
+
#with K.name_scope('norm_{}'.format(k)):
|
| 116 |
+
if self.use_batch_norm:
|
| 117 |
+
self._build_layer(BatchNormalization())
|
| 118 |
+
elif self.use_layer_norm:
|
| 119 |
+
self._build_layer(LayerNormalization())
|
| 120 |
+
elif self.use_weight_norm:
|
| 121 |
+
pass # done above.
|
| 122 |
+
|
| 123 |
+
# with K.name_scope('act_and_dropout_{}'.format(k)):
|
| 124 |
+
self._build_layer(Activation(self.activation, name='Act_Conv1D_{}'.format(k)))
|
| 125 |
+
self._build_layer(SpatialDropout1D(rate=self.dropout_rate, name='SDropout_{}'.format(k)))
|
| 126 |
+
|
| 127 |
+
if self.nb_filters != input_shape[-1]:
|
| 128 |
+
# 1x1 conv to match the shapes (channel dimension).
|
| 129 |
+
name = 'matching_conv1D'
|
| 130 |
+
#with K.name_scope(name):
|
| 131 |
+
# make and build this layer separately because it directly uses input_shape.
|
| 132 |
+
# 1x1 conv.
|
| 133 |
+
self.shape_match_conv = Conv1D(
|
| 134 |
+
filters=self.nb_filters,
|
| 135 |
+
kernel_size=1,
|
| 136 |
+
padding='same',
|
| 137 |
+
name=name,
|
| 138 |
+
kernel_initializer=self.kernel_initializer
|
| 139 |
+
)
|
| 140 |
+
else:
|
| 141 |
+
name = 'matching_identity'
|
| 142 |
+
self.shape_match_conv = Lambda(lambda x: x, name=name)
|
| 143 |
+
|
| 144 |
+
#with K.name_scope(name):
|
| 145 |
+
self.shape_match_conv.build(input_shape)
|
| 146 |
+
self.res_output_shape = self.shape_match_conv.compute_output_shape(input_shape)
|
| 147 |
+
|
| 148 |
+
self._build_layer(Activation(self.activation, name='Act_Conv_Blocks'))
|
| 149 |
+
self.final_activation = Activation(self.activation, name='Act_Res_Block')
|
| 150 |
+
self.final_activation.build(self.res_output_shape) # probably isn't necessary
|
| 151 |
+
|
| 152 |
+
# this is done to force Keras to add the layers in the list to self._layers
|
| 153 |
+
for layer in self.layers:
|
| 154 |
+
self.__setattr__(layer.name, layer)
|
| 155 |
+
self.__setattr__(self.shape_match_conv.name, self.shape_match_conv)
|
| 156 |
+
self.__setattr__(self.final_activation.name, self.final_activation)
|
| 157 |
+
|
| 158 |
+
super(ResidualBlock, self).build(input_shape) # done to make sure self.built is set True
|
| 159 |
+
|
| 160 |
+
def call(self, inputs, training=None, **kwargs):
|
| 161 |
+
"""
|
| 162 |
+
Returns: A tuple where the first element is the residual model tensor, and the second
|
| 163 |
+
is the skip connection tensor.
|
| 164 |
+
"""
|
| 165 |
+
# https://arxiv.org/pdf/1803.01271.pdf page 4, Figure 1 (b).
|
| 166 |
+
# x1: Dilated Conv -> Norm -> Dropout (x2).
|
| 167 |
+
# x2: Residual (1x1 matching conv - optional).
|
| 168 |
+
# Output: x1 + x2.
|
| 169 |
+
# x1 -> connected to skip connections.
|
| 170 |
+
# x1 + x2 -> connected to the next block.
|
| 171 |
+
# input
|
| 172 |
+
# x1 x2
|
| 173 |
+
# conv1D 1x1 Conv1D (optional)
|
| 174 |
+
# ...
|
| 175 |
+
# conv1D
|
| 176 |
+
# ...
|
| 177 |
+
# x1 + x2
|
| 178 |
+
x1 = inputs
|
| 179 |
+
for layer in self.layers:
|
| 180 |
+
training_flag = 'training' in dict(inspect.signature(layer.call).parameters)
|
| 181 |
+
x1 = layer(x1, training=training) if training_flag else layer(x1)
|
| 182 |
+
x2 = self.shape_match_conv(inputs)
|
| 183 |
+
x1_x2 = self.final_activation(layers.add([x2, x1], name='Add_Res'))
|
| 184 |
+
return [x1_x2, x1]
|
| 185 |
+
|
| 186 |
+
def compute_output_shape(self, input_shape):
|
| 187 |
+
return [self.res_output_shape, self.res_output_shape]
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class TCN(Layer):
|
| 191 |
+
"""Creates a TCN layer.
|
| 192 |
+
Input shape:
|
| 193 |
+
A tensor of shape (batch_size, timesteps, input_dim).
|
| 194 |
+
Args:
|
| 195 |
+
nb_filters: The number of filters to use in the convolutional layers. Can be a list.
|
| 196 |
+
kernel_size: The size of the kernel to use in each convolutional layer.
|
| 197 |
+
dilations: The list of the dilations. Example is: [1, 2, 4, 8, 16, 32, 64].
|
| 198 |
+
nb_stacks : The number of stacks of residual blocks to use.
|
| 199 |
+
padding: The padding to use in the convolutional layers, 'causal' or 'same'.
|
| 200 |
+
use_skip_connections: Boolean. If we want to add skip connections from input to each residual blocK.
|
| 201 |
+
return_sequences: Boolean. Whether to return the last output in the output sequence, or the full sequence.
|
| 202 |
+
activation: The activation used in the residual blocks o = Activation(x + F(x)).
|
| 203 |
+
dropout_rate: Float between 0 and 1. Fraction of the input units to drop.
|
| 204 |
+
kernel_initializer: Initializer for the kernel weights matrix (Conv1D).
|
| 205 |
+
use_batch_norm: Whether to use batch normalization in the residual layers or not.
|
| 206 |
+
use_layer_norm: Whether to use layer normalization in the residual layers or not.
|
| 207 |
+
use_weight_norm: Whether to use weight normalization in the residual layers or not.
|
| 208 |
+
kwargs: Any other arguments for configuring parent class Layer. For example "name=str", Name of the model.
|
| 209 |
+
Use unique names when using multiple TCN.
|
| 210 |
+
Returns:
|
| 211 |
+
A TCN layer.
|
| 212 |
+
"""
|
| 213 |
+
|
| 214 |
+
def __init__(self,
|
| 215 |
+
nb_filters=256,
|
| 216 |
+
kernel_size=5,
|
| 217 |
+
nb_stacks=1,
|
| 218 |
+
dilations=(1, 2, 4, 8, 16, 32),
|
| 219 |
+
padding='causal',
|
| 220 |
+
use_skip_connections=True,
|
| 221 |
+
dropout_rate=0.0,
|
| 222 |
+
return_sequences=False,
|
| 223 |
+
activation='relu',
|
| 224 |
+
kernel_initializer='he_normal',
|
| 225 |
+
use_batch_norm=False,
|
| 226 |
+
use_layer_norm=False,
|
| 227 |
+
use_weight_norm=False,
|
| 228 |
+
**kwargs):
|
| 229 |
+
print("nb_filters:", nb_filters, "kernel_size", kernel_size)
|
| 230 |
+
self.return_sequences = return_sequences
|
| 231 |
+
self.dropout_rate = dropout_rate
|
| 232 |
+
self.use_skip_connections = use_skip_connections
|
| 233 |
+
self.dilations = dilations
|
| 234 |
+
self.nb_stacks = nb_stacks
|
| 235 |
+
self.kernel_size = kernel_size
|
| 236 |
+
self.nb_filters = nb_filters
|
| 237 |
+
self.activation_name = activation
|
| 238 |
+
self.padding = padding
|
| 239 |
+
self.kernel_initializer = kernel_initializer
|
| 240 |
+
self.use_batch_norm = use_batch_norm
|
| 241 |
+
self.use_layer_norm = use_layer_norm
|
| 242 |
+
self.use_weight_norm = use_weight_norm
|
| 243 |
+
self.skip_connections = []
|
| 244 |
+
self.residual_blocks = []
|
| 245 |
+
self.layers_outputs = []
|
| 246 |
+
self.build_output_shape = None
|
| 247 |
+
self.slicer_layer = None # in case return_sequence=False
|
| 248 |
+
self.output_slice_index = None # in case return_sequence=False
|
| 249 |
+
self.padding_same_and_time_dim_unknown = False # edge case if padding='same' and time_dim = None
|
| 250 |
+
|
| 251 |
+
if self.use_batch_norm + self.use_layer_norm + self.use_weight_norm > 1:
|
| 252 |
+
raise ValueError('Only one normalization can be specified at once.')
|
| 253 |
+
|
| 254 |
+
if isinstance(self.nb_filters, list):
|
| 255 |
+
assert len(self.nb_filters) == len(self.dilations)
|
| 256 |
+
if len(set(self.nb_filters)) > 1 and self.use_skip_connections:
|
| 257 |
+
raise ValueError('Skip connections are not compatible '
|
| 258 |
+
'with a list of filters, unless they are all equal.')
|
| 259 |
+
|
| 260 |
+
if padding != 'causal' and padding != 'same':
|
| 261 |
+
raise ValueError("Only 'causal' or 'same' padding are compatible for this layer.")
|
| 262 |
+
|
| 263 |
+
# initialize parent class
|
| 264 |
+
super(TCN, self).__init__(**kwargs)
|
| 265 |
+
|
| 266 |
+
@property
|
| 267 |
+
def receptive_field(self):
|
| 268 |
+
return 1 + 2 * (self.kernel_size - 1) * self.nb_stacks * sum(self.dilations)
|
| 269 |
+
|
| 270 |
+
def build(self, input_shape):
|
| 271 |
+
|
| 272 |
+
# member to hold current output shape of the layer for building purposes
|
| 273 |
+
self.build_output_shape = input_shape
|
| 274 |
+
|
| 275 |
+
# list to hold all the member ResidualBlocks
|
| 276 |
+
self.residual_blocks = []
|
| 277 |
+
total_num_blocks = self.nb_stacks * len(self.dilations)
|
| 278 |
+
if not self.use_skip_connections:
|
| 279 |
+
total_num_blocks += 1 # cheap way to do a false case for below
|
| 280 |
+
|
| 281 |
+
for s in range(self.nb_stacks):
|
| 282 |
+
for i, d in enumerate(self.dilations):
|
| 283 |
+
res_block_filters = self.nb_filters[i] if isinstance(self.nb_filters, list) else self.nb_filters
|
| 284 |
+
self.residual_blocks.append(ResidualBlock(dilation_rate=d,
|
| 285 |
+
nb_filters=res_block_filters,
|
| 286 |
+
kernel_size=self.kernel_size,
|
| 287 |
+
padding=self.padding,
|
| 288 |
+
activation=self.activation_name,
|
| 289 |
+
dropout_rate=self.dropout_rate,
|
| 290 |
+
use_batch_norm=self.use_batch_norm,
|
| 291 |
+
use_layer_norm=self.use_layer_norm,
|
| 292 |
+
use_weight_norm=self.use_weight_norm,
|
| 293 |
+
kernel_initializer=self.kernel_initializer,
|
| 294 |
+
name='residual_block_{}'.format(len(self.residual_blocks))))
|
| 295 |
+
# build newest residual block
|
| 296 |
+
self.residual_blocks[-1].build(self.build_output_shape)
|
| 297 |
+
self.build_output_shape = self.residual_blocks[-1].res_output_shape
|
| 298 |
+
|
| 299 |
+
# this is done to force keras to add the layers in the list to self._layers
|
| 300 |
+
for layer in self.residual_blocks:
|
| 301 |
+
self.__setattr__(layer.name, layer)
|
| 302 |
+
|
| 303 |
+
self.output_slice_index = None
|
| 304 |
+
if self.padding == 'same':
|
| 305 |
+
time = self.build_output_shape.as_list()[1]
|
| 306 |
+
if time is not None: # if time dimension is defined. e.g. shape = (bs, 500, input_dim).
|
| 307 |
+
self.output_slice_index = int(self.build_output_shape.as_list()[1] / 2)
|
| 308 |
+
else:
|
| 309 |
+
# It will known at call time. c.f. self.call.
|
| 310 |
+
self.padding_same_and_time_dim_unknown = True
|
| 311 |
+
|
| 312 |
+
else:
|
| 313 |
+
self.output_slice_index = -1 # causal case.
|
| 314 |
+
self.slicer_layer = Lambda(lambda tt: tt[:, self.output_slice_index, :], name='Slice_Output')
|
| 315 |
+
|
| 316 |
+
if type(self.build_output_shape) == tuple:
|
| 317 |
+
static = list(self.build_output_shape)
|
| 318 |
+
else:
|
| 319 |
+
static = self.build_output_shape.as_list()
|
| 320 |
+
self.slicer_layer.build(static)
|
| 321 |
+
|
| 322 |
+
def compute_output_shape(self, input_shape):
|
| 323 |
+
"""
|
| 324 |
+
Overridden in case keras uses it somewhere... no idea. Just trying to avoid future errors.
|
| 325 |
+
"""
|
| 326 |
+
if not self.built:
|
| 327 |
+
self.build(input_shape)
|
| 328 |
+
if not self.return_sequences:
|
| 329 |
+
batch_size = self.build_output_shape[0]
|
| 330 |
+
batch_size = batch_size.value if hasattr(batch_size, 'value') else batch_size
|
| 331 |
+
nb_filters = self.build_output_shape[-1]
|
| 332 |
+
return [batch_size, nb_filters]
|
| 333 |
+
else:
|
| 334 |
+
# Compatibility tensorflow 1.x
|
| 335 |
+
return [v.value if hasattr(v, 'value') else v for v in self.build_output_shape]
|
| 336 |
+
|
| 337 |
+
def call(self, inputs, training=None, **kwargs):
|
| 338 |
+
x = inputs
|
| 339 |
+
self.layers_outputs = [x]
|
| 340 |
+
self.skip_connections = []
|
| 341 |
+
for res_block in self.residual_blocks:
|
| 342 |
+
# try:
|
| 343 |
+
# x, skip_out = res_block(x, training=training)
|
| 344 |
+
# except TypeError: # compatibility with tensorflow 1.x
|
| 345 |
+
# x, skip_out = res_block(K.cast(x, 'float32'), training=training)
|
| 346 |
+
x, skip_out = res_block(x, training=training)
|
| 347 |
+
|
| 348 |
+
self.skip_connections.append(skip_out)
|
| 349 |
+
self.layers_outputs.append(x)
|
| 350 |
+
|
| 351 |
+
if self.use_skip_connections:
|
| 352 |
+
x = layers.add(self.skip_connections, name='Add_Skip_Connections')
|
| 353 |
+
self.layers_outputs.append(x)
|
| 354 |
+
|
| 355 |
+
if not self.return_sequences:
|
| 356 |
+
# case: time dimension is unknown. e.g. (bs, None, input_dim).
|
| 357 |
+
if self.padding_same_and_time_dim_unknown:
|
| 358 |
+
self.output_slice_index = K.shape(self.layers_outputs[-1])[1] // 2
|
| 359 |
+
x = self.slicer_layer(x)
|
| 360 |
+
self.layers_outputs.append(x)
|
| 361 |
+
return x
|
| 362 |
+
|
| 363 |
+
def get_config(self):
|
| 364 |
+
"""
|
| 365 |
+
Returns the config of a the layer. This is used for saving and loading from a model
|
| 366 |
+
:return: python dictionary with specs to rebuild layer
|
| 367 |
+
"""
|
| 368 |
+
config = super(TCN, self).get_config()
|
| 369 |
+
config['nb_filters'] = self.nb_filters
|
| 370 |
+
config['kernel_size'] = self.kernel_size
|
| 371 |
+
config['nb_stacks'] = self.nb_stacks
|
| 372 |
+
config['dilations'] = self.dilations
|
| 373 |
+
config['padding'] = self.padding
|
| 374 |
+
config['use_skip_connections'] = self.use_skip_connections
|
| 375 |
+
config['dropout_rate'] = self.dropout_rate
|
| 376 |
+
config['return_sequences'] = self.return_sequences
|
| 377 |
+
config['activation'] = self.activation_name
|
| 378 |
+
config['use_batch_norm'] = self.use_batch_norm
|
| 379 |
+
config['use_layer_norm'] = self.use_layer_norm
|
| 380 |
+
config['use_weight_norm'] = self.use_weight_norm
|
| 381 |
+
config['kernel_initializer'] = self.kernel_initializer
|
| 382 |
+
return config
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
def compiled_tcn(num_feat, # type: int
|
| 386 |
+
num_classes, # type: int
|
| 387 |
+
nb_filters, # type: int
|
| 388 |
+
kernel_size, # type: int
|
| 389 |
+
dilations, # type: List[int]
|
| 390 |
+
nb_stacks, # type: int
|
| 391 |
+
max_len, # type: int
|
| 392 |
+
output_len=1, # type: int
|
| 393 |
+
padding='causal', # type: str
|
| 394 |
+
use_skip_connections=False, # type: bool
|
| 395 |
+
return_sequences=True,
|
| 396 |
+
regression=False, # type: bool
|
| 397 |
+
dropout_rate=0.05, # type: float
|
| 398 |
+
name='tcn', # type: str,
|
| 399 |
+
kernel_initializer='he_normal', # type: str,
|
| 400 |
+
activation='relu', # type:str,
|
| 401 |
+
opt='adam',
|
| 402 |
+
lr=0.002,
|
| 403 |
+
use_batch_norm=False,
|
| 404 |
+
use_layer_norm=False,
|
| 405 |
+
use_weight_norm=False):
|
| 406 |
+
# type: (...) -> Model
|
| 407 |
+
"""Creates a compiled TCN model for a given task (i.e. regression or classification).
|
| 408 |
+
Classification uses a sparse categorical loss. Please input class ids and not one-hot encodings.
|
| 409 |
+
Args:
|
| 410 |
+
num_feat: The number of features of your input, i.e. the last dimension of: (batch_size, timesteps, input_dim).
|
| 411 |
+
num_classes: The size of the final dense layer, how many classes we are predicting.
|
| 412 |
+
nb_filters: The number of filters to use in the convolutional layers.
|
| 413 |
+
kernel_size: The size of the kernel to use in each convolutional layer.
|
| 414 |
+
dilations: The list of the dilations. Example is: [1, 2, 4, 8, 16, 32, 64].
|
| 415 |
+
nb_stacks : The number of stacks of residual blocks to use.
|
| 416 |
+
max_len: The maximum sequence length, use None if the sequence length is dynamic.
|
| 417 |
+
padding: The padding to use in the convolutional layers.
|
| 418 |
+
use_skip_connections: Boolean. If we want to add skip connections from input to each residual blocK.
|
| 419 |
+
return_sequences: Boolean. Whether to return the last output in the output sequence, or the full sequence.
|
| 420 |
+
regression: Whether the output should be continuous or discrete.
|
| 421 |
+
dropout_rate: Float between 0 and 1. Fraction of the input units to drop.
|
| 422 |
+
activation: The activation used in the residual blocks o = Activation(x + F(x)).
|
| 423 |
+
name: Name of the model. Useful when having multiple TCN.
|
| 424 |
+
kernel_initializer: Initializer for the kernel weights matrix (Conv1D).
|
| 425 |
+
opt: Optimizer name.
|
| 426 |
+
lr: Learning rate.
|
| 427 |
+
use_batch_norm: Whether to use batch normalization in the residual layers or not.
|
| 428 |
+
use_layer_norm: Whether to use layer normalization in the residual layers or not.
|
| 429 |
+
use_weight_norm: Whether to use weight normalization in the residual layers or not.
|
| 430 |
+
Returns:
|
| 431 |
+
A compiled keras TCN.
|
| 432 |
+
"""
|
| 433 |
+
|
| 434 |
+
dilations = adjust_dilations(dilations)
|
| 435 |
+
|
| 436 |
+
input_layer = Input(shape=(max_len, num_feat))
|
| 437 |
+
|
| 438 |
+
x = TCN(nb_filters, kernel_size, nb_stacks, dilations, padding,
|
| 439 |
+
use_skip_connections, dropout_rate, return_sequences,
|
| 440 |
+
activation, kernel_initializer, use_batch_norm, use_layer_norm,
|
| 441 |
+
use_weight_norm, name=name)(input_layer)
|
| 442 |
+
|
| 443 |
+
print('x.shape=', x.shape)
|
| 444 |
+
|
| 445 |
+
def get_opt():
|
| 446 |
+
if opt == 'adam':
|
| 447 |
+
return optimizers.Adam(lr=lr, clipnorm=1.)
|
| 448 |
+
elif opt == 'rmsprop':
|
| 449 |
+
return optimizers.RMSprop(lr=lr, clipnorm=1.)
|
| 450 |
+
else:
|
| 451 |
+
raise Exception('Only Adam and RMSProp are available here')
|
| 452 |
+
|
| 453 |
+
if not regression:
|
| 454 |
+
# classification
|
| 455 |
+
print('asdasfdasfa')
|
| 456 |
+
x = Dense(num_classes)(x)
|
| 457 |
+
x = Activation('softmax')(x)
|
| 458 |
+
output_layer = x
|
| 459 |
+
model = Model(input_layer, output_layer)
|
| 460 |
+
|
| 461 |
+
# https://github.com/keras-team/keras/pull/11373
|
| 462 |
+
# It's now in Keras@master but still not available with pip.
|
| 463 |
+
# TODO remove later.
|
| 464 |
+
def accuracy(y_true, y_pred):
|
| 465 |
+
# reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
|
| 466 |
+
if K.ndim(y_true) == K.ndim(y_pred):
|
| 467 |
+
y_true = K.squeeze(y_true, -1)
|
| 468 |
+
# convert dense predictions to labels
|
| 469 |
+
y_pred_labels = K.argmax(y_pred, axis=-1)
|
| 470 |
+
y_pred_labels = K.cast(y_pred_labels, KK.floatx())
|
| 471 |
+
return K.cast(K.equal(y_true, y_pred_labels), KK.floatx())
|
| 472 |
+
|
| 473 |
+
model.compile(get_opt(), loss='sparse_categorical_crossentropy', metrics=[accuracy])
|
| 474 |
+
else:
|
| 475 |
+
# regression
|
| 476 |
+
x = Dense(output_len)(x)
|
| 477 |
+
x = Activation('linear')(x)
|
| 478 |
+
output_layer = x
|
| 479 |
+
model = Model(input_layer, output_layer)
|
| 480 |
+
model.compile(get_opt(), loss='mean_squared_error')
|
| 481 |
+
print('model.x = {}'.format(input_layer.shape))
|
| 482 |
+
print('model.y = {}'.format(output_layer.shape))
|
| 483 |
+
return model
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def tcn_full_summary(model: Model, expand_residual_blocks=True):
|
| 487 |
+
|
| 488 |
+
layers = model._layers.copy() # store existing layers
|
| 489 |
+
model._layers.clear() # clear layers
|
| 490 |
+
|
| 491 |
+
for i in range(len(layers)):
|
| 492 |
+
if isinstance(layers[i], TCN):
|
| 493 |
+
for layer in layers[i]._layers:
|
| 494 |
+
if not isinstance(layer, ResidualBlock):
|
| 495 |
+
if not hasattr(layer, '__iter__'):
|
| 496 |
+
model._layers.append(layer)
|
| 497 |
+
else:
|
| 498 |
+
if expand_residual_blocks:
|
| 499 |
+
for lyr in layer._layers:
|
| 500 |
+
if not hasattr(lyr, '__iter__'):
|
| 501 |
+
model._layers.append(lyr)
|
| 502 |
+
else:
|
| 503 |
+
model._layers.append(layer)
|
| 504 |
+
else:
|
| 505 |
+
model._layers.append(layers[i])
|
| 506 |
+
|
| 507 |
+
model.summary() # print summary
|
| 508 |
+
|
| 509 |
+
# restore original layers
|
| 510 |
+
model._layers.clear()
|
| 511 |
+
[model._layers.append(lyr) for lyr in layers]
|
pages/inference.py
ADDED
|
@@ -0,0 +1,547 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
import streamlit as st
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import random
|
| 8 |
+
from sklearn.preprocessing import MinMaxScaler
|
| 9 |
+
import numpy as np
|
| 10 |
+
import pandas as pd
|
| 11 |
+
|
| 12 |
+
from model.lstm import LSTMModel
|
| 13 |
+
from model.tcn import TCNModel
|
| 14 |
+
from model.tcn import move_custom_layers_to_device
|
| 15 |
+
from utils.lowlevel import LowLevel
|
| 16 |
+
from utils.highlevel import HighLevel
|
| 17 |
+
from utils.midpoint import MidPoint
|
| 18 |
+
|
| 19 |
+
from utils.transform import compute_gradient
|
| 20 |
+
|
| 21 |
+
st.set_page_config(page_title="Inference", page_icon=":chart_with_upwards_trend:", layout="wide", initial_sidebar_state="auto")
|
| 22 |
+
|
| 23 |
+
def uniform_sampling(data, n_sample):
|
| 24 |
+
k = len(data) // n_sample
|
| 25 |
+
return data[::k]
|
| 26 |
+
|
| 27 |
+
def low_level(option_time, slider_sample_orbit, progress_bar):
|
| 28 |
+
time.sleep(0.1)
|
| 29 |
+
low_level_total_start_time = time.time()
|
| 30 |
+
low_level_30000_start_time = time.time()
|
| 31 |
+
|
| 32 |
+
lowlevelhelper = LowLevel(j=slider_sample_orbit)
|
| 33 |
+
j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = lowlevelhelper.initial()
|
| 34 |
+
|
| 35 |
+
a1 = 1 / (2 - 2 ** (1 / 3))
|
| 36 |
+
a2 = 1 - 2 * a1
|
| 37 |
+
jn = 0
|
| 38 |
+
t = 0.1
|
| 39 |
+
|
| 40 |
+
# Calculate the total number of iterations for the progress bar update
|
| 41 |
+
total_iterations = (float(option_time) - t) / h
|
| 42 |
+
current_iteration = 0
|
| 43 |
+
|
| 44 |
+
original_low_level_data = []
|
| 45 |
+
|
| 46 |
+
while t < float(option_time):
|
| 47 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
| 48 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
| 49 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
| 50 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa,pya, pza)
|
| 51 |
+
|
| 52 |
+
t = t + h
|
| 53 |
+
|
| 54 |
+
if jn % 10 == 0:
|
| 55 |
+
original_low_level_data.append([b, x, y, z, px, py, pz])
|
| 56 |
+
# Update progress bar
|
| 57 |
+
progress_percentage = int((current_iteration / total_iterations) * 100)
|
| 58 |
+
progress_bar.progress(progress_percentage)
|
| 59 |
+
|
| 60 |
+
if jn == 300000:
|
| 61 |
+
low_level_30000_end_time = time.time()
|
| 62 |
+
low_level_30000_execute_time = low_level_30000_end_time - low_level_30000_start_time
|
| 63 |
+
low_level_2000_start_time = time.time()
|
| 64 |
+
jn = jn + 1
|
| 65 |
+
current_iteration += 1
|
| 66 |
+
|
| 67 |
+
progress_bar.progress(100)
|
| 68 |
+
|
| 69 |
+
low_level_2000_end_time = time.time()
|
| 70 |
+
low_level_2000_execute_time = low_level_2000_end_time - low_level_2000_start_time
|
| 71 |
+
low_level_total_end_time = time.time()
|
| 72 |
+
low_level_total_execute_time = low_level_total_end_time - low_level_total_start_time
|
| 73 |
+
|
| 74 |
+
result = uniform_sampling(np.array(original_low_level_data), n_sample=int(option_time/100))
|
| 75 |
+
|
| 76 |
+
return low_level_30000_execute_time, low_level_2000_execute_time, low_level_total_execute_time, result
|
| 77 |
+
|
| 78 |
+
def high_level(option_time, slider_sample_orbit, progress_bar):
|
| 79 |
+
time.sleep(0.1)
|
| 80 |
+
high_level_total_start_time = time.time()
|
| 81 |
+
high_level_30000_start_time = time.time()
|
| 82 |
+
|
| 83 |
+
highlevelhelper = HighLevel(j=slider_sample_orbit)
|
| 84 |
+
j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = highlevelhelper.initial()
|
| 85 |
+
|
| 86 |
+
a1 = 1 / (2 - 2 ** (1 / 3))
|
| 87 |
+
a2 = 1 - 2 * a1
|
| 88 |
+
jn = 0
|
| 89 |
+
t = 0.1
|
| 90 |
+
|
| 91 |
+
# Calculate the total number of iterations for the progress bar update
|
| 92 |
+
total_iterations = (float(option_time) - t) / h
|
| 93 |
+
current_iteration = 0
|
| 94 |
+
|
| 95 |
+
original_high_level_data = []
|
| 96 |
+
|
| 97 |
+
while t < float(option_time):
|
| 98 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = highlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
| 99 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = highlevelhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
| 100 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = highlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
| 101 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = highlevelhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa,pya, pza)
|
| 102 |
+
|
| 103 |
+
t = t + h
|
| 104 |
+
vx, vy, vz, vpx, vpy, vpz, e = highlevelhelper.f(x, y, z, px, py, pz, b)
|
| 105 |
+
en = np.asarray(e).astype(np.float64)
|
| 106 |
+
|
| 107 |
+
if jn % 10 == 0:
|
| 108 |
+
original_high_level_data.append([b, x, y, z, px, py, pz])
|
| 109 |
+
if jn == 300000:
|
| 110 |
+
high_level_30000_end_time = time.time()
|
| 111 |
+
high_level_30000_execute_time = high_level_30000_end_time - high_level_30000_start_time
|
| 112 |
+
high_level_2000_start_time = time.time()
|
| 113 |
+
jn = jn + 1
|
| 114 |
+
|
| 115 |
+
# Update progress bar
|
| 116 |
+
progress_percentage = int((current_iteration / total_iterations) * 100)
|
| 117 |
+
progress_bar.progress(progress_percentage)
|
| 118 |
+
current_iteration += 1
|
| 119 |
+
|
| 120 |
+
progress_bar.progress(100)
|
| 121 |
+
high_level_2000_end_time = time.time()
|
| 122 |
+
high_level_2000_execute_time = high_level_2000_end_time - high_level_2000_start_time
|
| 123 |
+
high_level_total_end_time = time.time()
|
| 124 |
+
high_level_total_execute_time = high_level_total_end_time - high_level_total_start_time
|
| 125 |
+
|
| 126 |
+
result = uniform_sampling(np.array(original_high_level_data), n_sample=int(option_time / 100))
|
| 127 |
+
|
| 128 |
+
return high_level_30000_execute_time, high_level_2000_execute_time, high_level_total_execute_time, result
|
| 129 |
+
|
| 130 |
+
def midpoint(option_time, slider_sample_orbit, progress_bar):
|
| 131 |
+
time.sleep(0.1)
|
| 132 |
+
mid_point_total_start_time = time.time()
|
| 133 |
+
mid_point_30000_start_time = time.time()
|
| 134 |
+
|
| 135 |
+
midpointhelper = MidPoint(j=slider_sample_orbit)
|
| 136 |
+
j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = midpointhelper.initial()
|
| 137 |
+
|
| 138 |
+
#en0 = np.asarray(e0).astype(np.float64)
|
| 139 |
+
a1 = 1 / (2 - 2 ** (1 / 3))
|
| 140 |
+
a2 = 1 - 2 * a1
|
| 141 |
+
jn = 0
|
| 142 |
+
t = 0.1
|
| 143 |
+
|
| 144 |
+
# Calculate the total number of iterations for the progress bar update
|
| 145 |
+
total_iterations = (float(option_time) - t) / h
|
| 146 |
+
current_iteration = 0
|
| 147 |
+
|
| 148 |
+
original_mid_point_data = []
|
| 149 |
+
|
| 150 |
+
while t < float(option_time):
|
| 151 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
| 152 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
| 153 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
| 154 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa,pya, pza)
|
| 155 |
+
|
| 156 |
+
t = t + h
|
| 157 |
+
|
| 158 |
+
if jn % 10 == 0:
|
| 159 |
+
original_mid_point_data.append([b, x, y, z, px, py, pz])
|
| 160 |
+
if jn == 300000:
|
| 161 |
+
mid_point_30000_end_time = time.time()
|
| 162 |
+
mid_point_30000_execute_time = mid_point_30000_end_time - mid_point_30000_start_time
|
| 163 |
+
mid_point_2000_start_time = time.time()
|
| 164 |
+
jn = jn + 1
|
| 165 |
+
|
| 166 |
+
# Update progress bar
|
| 167 |
+
progress_percentage = int((current_iteration / total_iterations) * 100)
|
| 168 |
+
progress_bar.progress(progress_percentage)
|
| 169 |
+
current_iteration += 1
|
| 170 |
+
|
| 171 |
+
#mid_point_df.to_excel('mid_point_df_output.xlsx', index=False)
|
| 172 |
+
progress_bar.progress(100)
|
| 173 |
+
mid_point_2000_end_time = time.time()
|
| 174 |
+
mid_point_2000_execute_time = mid_point_2000_end_time - mid_point_2000_start_time
|
| 175 |
+
mid_point_total_end_time = time.time()
|
| 176 |
+
mid_point_total_execute_time = mid_point_total_end_time - mid_point_total_start_time
|
| 177 |
+
|
| 178 |
+
result = uniform_sampling(np.array(original_mid_point_data), n_sample=int(option_time / 100))
|
| 179 |
+
|
| 180 |
+
return mid_point_30000_execute_time, mid_point_2000_execute_time, mid_point_total_execute_time, result
|
| 181 |
+
|
| 182 |
+
def low_level_lstm(slider_sample_orbit, lstm_progress_bar):
|
| 183 |
+
time.sleep(0.1)
|
| 184 |
+
total_start_time = time.time()
|
| 185 |
+
|
| 186 |
+
lstm_ckpt_file = os.path.join("model", "lstm.ckpt")
|
| 187 |
+
lstm_model = LSTMModel.load_from_checkpoint(lstm_ckpt_file)
|
| 188 |
+
lstm_model.to("cpu")
|
| 189 |
+
lstm_model.eval()
|
| 190 |
+
|
| 191 |
+
# Initialize variables for the classical method
|
| 192 |
+
lowlevelhelper = LowLevel(j=slider_sample_orbit)
|
| 193 |
+
j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = lowlevelhelper.initial()
|
| 194 |
+
|
| 195 |
+
a1 = 1 / (2 - 2 ** (1 / 3))
|
| 196 |
+
a2 = 1 - 2 * a1
|
| 197 |
+
jn = 0
|
| 198 |
+
t = 0.1
|
| 199 |
+
|
| 200 |
+
# Calculate the total number of iterations for the progress bar update
|
| 201 |
+
total_iterations = (float(30000) - t) / h
|
| 202 |
+
current_iteration = 0
|
| 203 |
+
|
| 204 |
+
original_low_level_data = []
|
| 205 |
+
|
| 206 |
+
low_level_start_time = time.time()
|
| 207 |
+
|
| 208 |
+
# Perform classical method prediction for the initial segment
|
| 209 |
+
while t < float(30000):
|
| 210 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
| 211 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
| 212 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
| 213 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa,pya, pza)
|
| 214 |
+
t = t + h
|
| 215 |
+
|
| 216 |
+
if jn % 10 == 0:
|
| 217 |
+
original_low_level_data.append([b, x, y, z, px, py, pz])
|
| 218 |
+
# Update progress bar
|
| 219 |
+
progress_percentage = int((current_iteration / total_iterations) * 100)
|
| 220 |
+
lstm_progress_bar.progress(progress_percentage)
|
| 221 |
+
|
| 222 |
+
jn = jn + 1
|
| 223 |
+
current_iteration += 1
|
| 224 |
+
|
| 225 |
+
original_low_level_data = np.array(original_low_level_data)
|
| 226 |
+
low_level_end_time = time.time()
|
| 227 |
+
low_level_data = original_low_level_data.copy()
|
| 228 |
+
low_level_data = uniform_sampling(low_level_data, n_sample=300)
|
| 229 |
+
scaler = MinMaxScaler()
|
| 230 |
+
low_level_data = scaler.fit_transform(low_level_data)
|
| 231 |
+
low_level_data = torch.tensor(np.stack(low_level_data)).float()
|
| 232 |
+
low_level_data = torch.stack([compute_gradient(i, degree=2) for i in low_level_data]).unsqueeze(0)
|
| 233 |
+
|
| 234 |
+
lstm_start_time = time.time()
|
| 235 |
+
with torch.no_grad():
|
| 236 |
+
lstm_preds = lstm_model(low_level_data[:, 100:300, :])
|
| 237 |
+
lstm_innv_preds = scaler.inverse_transform(lstm_preds.squeeze().cpu().numpy())
|
| 238 |
+
|
| 239 |
+
original_low_level_data = uniform_sampling(original_low_level_data, n_sample=300)
|
| 240 |
+
|
| 241 |
+
lstm_end_time = time.time()
|
| 242 |
+
lstm_progress_bar.progress(100)
|
| 243 |
+
|
| 244 |
+
combined_preds = np.concatenate([original_low_level_data, lstm_innv_preds], axis=0)
|
| 245 |
+
|
| 246 |
+
lstm_total_time = lstm_end_time - lstm_start_time
|
| 247 |
+
low_level_total_time = low_level_end_time - low_level_start_time
|
| 248 |
+
|
| 249 |
+
total_end_time = time.time()
|
| 250 |
+
total_time = total_end_time - total_start_time
|
| 251 |
+
|
| 252 |
+
return low_level_total_time, lstm_total_time, total_time, combined_preds
|
| 253 |
+
|
| 254 |
+
def mid_point_lstm(slider_sample_orbit, lstm_progress_bar):
|
| 255 |
+
time.sleep(0.1)
|
| 256 |
+
total_start_time = time.time()
|
| 257 |
+
|
| 258 |
+
lstm_ckpt_file = os.path.join("model", "lstm.ckpt")
|
| 259 |
+
lstm_model = LSTMModel.load_from_checkpoint(lstm_ckpt_file)
|
| 260 |
+
lstm_model.to("cpu")
|
| 261 |
+
lstm_model.eval()
|
| 262 |
+
|
| 263 |
+
midpointhelper = MidPoint(j=slider_sample_orbit)
|
| 264 |
+
j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = midpointhelper.initial()
|
| 265 |
+
|
| 266 |
+
a1 = 1 / (2 - 2 ** (1 / 3))
|
| 267 |
+
a2 = 1 - 2 * a1
|
| 268 |
+
jn = 0
|
| 269 |
+
t = 0.1
|
| 270 |
+
|
| 271 |
+
# Calculate the total number of iterations for the progress bar update
|
| 272 |
+
total_iterations = (float(30000) - t) / h
|
| 273 |
+
current_iteration = 0
|
| 274 |
+
|
| 275 |
+
original_mid_point_data = []
|
| 276 |
+
|
| 277 |
+
mid_point_start_time = time.time()
|
| 278 |
+
|
| 279 |
+
while t < float(30000):
|
| 280 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
| 281 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
| 282 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
| 283 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa,pya, pza)
|
| 284 |
+
|
| 285 |
+
t = t + h
|
| 286 |
+
|
| 287 |
+
if jn % 10 == 0:
|
| 288 |
+
original_mid_point_data.append([b, x, y, z, px, py, pz])
|
| 289 |
+
# Update progress bar
|
| 290 |
+
progress_percentage = int((current_iteration / total_iterations) * 100)
|
| 291 |
+
lstm_progress_bar.progress(progress_percentage)
|
| 292 |
+
jn = jn + 1
|
| 293 |
+
current_iteration += 1
|
| 294 |
+
|
| 295 |
+
original_mid_point_data = np.array(original_mid_point_data)
|
| 296 |
+
mid_point_end_time = time.time()
|
| 297 |
+
mid_point_data = original_mid_point_data.copy()
|
| 298 |
+
mid_point_data = uniform_sampling(mid_point_data, n_sample=300)
|
| 299 |
+
scaler = MinMaxScaler()
|
| 300 |
+
mid_point_data = scaler.fit_transform(mid_point_data)
|
| 301 |
+
mid_point_data = torch.tensor(np.stack(mid_point_data)).float()
|
| 302 |
+
mid_point_data = torch.stack([compute_gradient(i, degree=2) for i in mid_point_data]).unsqueeze(0)
|
| 303 |
+
|
| 304 |
+
lstm_start_time = time.time()
|
| 305 |
+
with torch.no_grad():
|
| 306 |
+
lstm_preds = lstm_model(mid_point_data[:, 100:300, :])
|
| 307 |
+
lstm_innv_preds = scaler.inverse_transform(lstm_preds.squeeze().cpu().numpy())
|
| 308 |
+
|
| 309 |
+
original_mid_point_data = uniform_sampling(original_mid_point_data, n_sample=300)
|
| 310 |
+
|
| 311 |
+
lstm_end_time = time.time()
|
| 312 |
+
lstm_progress_bar.progress(100)
|
| 313 |
+
|
| 314 |
+
combined_preds = np.concatenate([original_mid_point_data, lstm_innv_preds], axis=0)
|
| 315 |
+
|
| 316 |
+
lstm_total_time = lstm_end_time - lstm_start_time
|
| 317 |
+
mid_point_total_time = mid_point_end_time - mid_point_start_time
|
| 318 |
+
|
| 319 |
+
total_end_time = time.time()
|
| 320 |
+
total_time = total_end_time - total_start_time
|
| 321 |
+
|
| 322 |
+
return mid_point_total_time, lstm_total_time, total_time, combined_preds
|
| 323 |
+
|
| 324 |
+
def low_level_tcn(slider_sample_orbit, tcn_progress_bar):
|
| 325 |
+
time.sleep(0.1)
|
| 326 |
+
total_start_time = time.time()
|
| 327 |
+
|
| 328 |
+
tcn_ckpt_file = os.path.join("model", "tcn.ckpt")
|
| 329 |
+
tcn_model = TCNModel.load_from_checkpoint(tcn_ckpt_file)
|
| 330 |
+
move_custom_layers_to_device(tcn_model, "cpu")
|
| 331 |
+
tcn_model.eval()
|
| 332 |
+
|
| 333 |
+
# Initialize variables for the classical method
|
| 334 |
+
lowlevelhelper = LowLevel(j=slider_sample_orbit)
|
| 335 |
+
j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = lowlevelhelper.initial()
|
| 336 |
+
|
| 337 |
+
a1 = 1 / (2 - 2 ** (1 / 3))
|
| 338 |
+
a2 = 1 - 2 * a1
|
| 339 |
+
jn = 0
|
| 340 |
+
t = 0.1
|
| 341 |
+
|
| 342 |
+
# Calculate the total number of iterations for the progress bar update
|
| 343 |
+
total_iterations = (float(30000) - t) / h
|
| 344 |
+
current_iteration = 0
|
| 345 |
+
|
| 346 |
+
original_low_level_data = []
|
| 347 |
+
|
| 348 |
+
low_level_start_time = time.time()
|
| 349 |
+
|
| 350 |
+
# Perform classical method prediction for the initial segment
|
| 351 |
+
while t < float(30000):
|
| 352 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
| 353 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
| 354 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
| 355 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = lowlevelhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza)
|
| 356 |
+
|
| 357 |
+
t = t + h
|
| 358 |
+
|
| 359 |
+
if jn % 10 == 0:
|
| 360 |
+
original_low_level_data.append([b, x, y, z, px, py, pz])
|
| 361 |
+
# Update progress bar
|
| 362 |
+
progress_percentage = int((current_iteration / total_iterations) * 100)
|
| 363 |
+
tcn_progress_bar.progress(progress_percentage)
|
| 364 |
+
|
| 365 |
+
jn = jn + 1
|
| 366 |
+
current_iteration += 1
|
| 367 |
+
|
| 368 |
+
original_low_level_data = np.array(original_low_level_data)
|
| 369 |
+
low_level_end_time = time.time()
|
| 370 |
+
low_level_data = original_low_level_data.copy()
|
| 371 |
+
low_level_data = uniform_sampling(low_level_data, n_sample=300)
|
| 372 |
+
scaler = MinMaxScaler()
|
| 373 |
+
low_level_data = scaler.fit_transform(low_level_data)
|
| 374 |
+
low_level_data = torch.tensor(np.stack(low_level_data)).float()
|
| 375 |
+
low_level_data = torch.stack([compute_gradient(i, degree=2) for i in low_level_data]).unsqueeze(0)
|
| 376 |
+
|
| 377 |
+
tcn_start_time = time.time()
|
| 378 |
+
with torch.no_grad():
|
| 379 |
+
tcn_preds = None
|
| 380 |
+
for i in range(20):
|
| 381 |
+
if i == 0:
|
| 382 |
+
tcn_preds = tcn_model(low_level_data[:, :300, :])
|
| 383 |
+
else:
|
| 384 |
+
gd_y_hat = compute_gradient(tcn_preds[:, :i, :], degree=2).to('cpu')
|
| 385 |
+
output = tcn_model(torch.cat([low_level_data[:, i:300, :], gd_y_hat], dim=1).to('cpu'))
|
| 386 |
+
tcn_preds = torch.cat([tcn_preds, output], dim=1)
|
| 387 |
+
tcn_innv_preds = scaler.inverse_transform(tcn_preds.squeeze().cpu().numpy())
|
| 388 |
+
|
| 389 |
+
original_low_level_data = uniform_sampling(original_low_level_data, n_sample=300)
|
| 390 |
+
|
| 391 |
+
tcn_end_time = time.time()
|
| 392 |
+
tcn_progress_bar.progress(100)
|
| 393 |
+
|
| 394 |
+
combined_preds = np.concatenate([original_low_level_data, tcn_innv_preds], axis=0)
|
| 395 |
+
|
| 396 |
+
tcn_total_time = tcn_end_time - tcn_start_time
|
| 397 |
+
low_level_total_time = low_level_end_time - low_level_start_time
|
| 398 |
+
|
| 399 |
+
total_end_time = time.time()
|
| 400 |
+
total_time = total_end_time - total_start_time
|
| 401 |
+
|
| 402 |
+
return low_level_total_time, tcn_total_time, total_time, combined_preds
|
| 403 |
+
|
| 404 |
+
def mid_point_tcn(slider_sample_orbit, tcn_progress_bar):
|
| 405 |
+
time.sleep(0.1)
|
| 406 |
+
total_start_time = time.time()
|
| 407 |
+
|
| 408 |
+
tcn_ckpt_file = os.path.join("model", "tcn.ckpt")
|
| 409 |
+
tcn_model = TCNModel.load_from_checkpoint(tcn_ckpt_file)
|
| 410 |
+
move_custom_layers_to_device(tcn_model, "cpu")
|
| 411 |
+
tcn_model.eval()
|
| 412 |
+
|
| 413 |
+
# Initialize variables for the classical method
|
| 414 |
+
midpointhelper = MidPoint(j=slider_sample_orbit)
|
| 415 |
+
j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza = midpointhelper.initial()
|
| 416 |
+
|
| 417 |
+
a1 = 1 / (2 - 2 ** (1 / 3))
|
| 418 |
+
a2 = 1 - 2 * a1
|
| 419 |
+
jn = 0
|
| 420 |
+
t = 0.1
|
| 421 |
+
|
| 422 |
+
# Calculate the total number of iterations for the progress bar update
|
| 423 |
+
total_iterations = (float(30000) - t) / h
|
| 424 |
+
current_iteration = 0
|
| 425 |
+
|
| 426 |
+
original_mid_point_data = []
|
| 427 |
+
|
| 428 |
+
mid_point_start_time = time.time()
|
| 429 |
+
|
| 430 |
+
# Perform classical method prediction for the initial segment
|
| 431 |
+
while t < float(30000):
|
| 432 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
| 433 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a2, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
| 434 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.symplectic(h * a1, x, y, z, px, py, pz, xa, ya,za, pxa, pya, pza, b)
|
| 435 |
+
x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza = midpointhelper.rejust(x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza)
|
| 436 |
+
|
| 437 |
+
t = t + h
|
| 438 |
+
|
| 439 |
+
if jn % 10 == 0:
|
| 440 |
+
original_mid_point_data.append([b, x, y, z, px, py, pz])
|
| 441 |
+
# Update progress bar
|
| 442 |
+
progress_percentage = int((current_iteration / total_iterations) * 100)
|
| 443 |
+
tcn_progress_bar.progress(progress_percentage)
|
| 444 |
+
jn = jn + 1
|
| 445 |
+
current_iteration += 1
|
| 446 |
+
|
| 447 |
+
original_mid_point_data = np.array(original_mid_point_data)
|
| 448 |
+
mid_point_end_time = time.time()
|
| 449 |
+
mid_point_data = original_mid_point_data.copy()
|
| 450 |
+
mid_point_data = uniform_sampling(mid_point_data, n_sample=300)
|
| 451 |
+
scaler = MinMaxScaler()
|
| 452 |
+
mid_point_data = scaler.fit_transform(mid_point_data)
|
| 453 |
+
mid_point_data = torch.tensor(np.stack(mid_point_data)).float()
|
| 454 |
+
mid_point_data = torch.stack([compute_gradient(i, degree=2) for i in mid_point_data]).unsqueeze(0)
|
| 455 |
+
|
| 456 |
+
tcn_start_time = time.time()
|
| 457 |
+
with torch.no_grad():
|
| 458 |
+
tcn_preds = None
|
| 459 |
+
for i in range(20):
|
| 460 |
+
if i == 0:
|
| 461 |
+
tcn_preds = tcn_model(mid_point_data[:, :300, :])
|
| 462 |
+
else:
|
| 463 |
+
gd_y_hat = compute_gradient(tcn_preds[:, :i, :], degree=2).to('cpu')
|
| 464 |
+
output = tcn_model(torch.cat([mid_point_data[:, i:300, :], gd_y_hat], dim=1).to('cpu'))
|
| 465 |
+
tcn_preds = torch.cat([tcn_preds, output], dim=1)
|
| 466 |
+
tcn_innv_preds = scaler.inverse_transform(tcn_preds.squeeze().cpu().numpy())
|
| 467 |
+
|
| 468 |
+
original_mid_point_data = uniform_sampling(original_mid_point_data, n_sample=300)
|
| 469 |
+
|
| 470 |
+
tcn_end_time = time.time()
|
| 471 |
+
tcn_progress_bar.progress(100)
|
| 472 |
+
|
| 473 |
+
combined_preds = np.concatenate([original_mid_point_data, tcn_innv_preds], axis=0)
|
| 474 |
+
|
| 475 |
+
tcn_total_time = tcn_end_time - tcn_start_time
|
| 476 |
+
mid_point_total_time = mid_point_end_time - mid_point_start_time
|
| 477 |
+
|
| 478 |
+
total_end_time = time.time()
|
| 479 |
+
total_time = total_end_time - total_start_time
|
| 480 |
+
|
| 481 |
+
return mid_point_total_time, tcn_total_time, total_time, combined_preds
|
| 482 |
+
|
| 483 |
+
container = st.container()
|
| 484 |
+
container1, container2 = st.columns(2)
|
| 485 |
+
plot_container = st.container()
|
| 486 |
+
|
| 487 |
+
with st.sidebar:
|
| 488 |
+
slider_sample_orbit = st.slider('Orbit Sample ID', 1, 10, 1)
|
| 489 |
+
option_time = 32000
|
| 490 |
+
st.write(f'Total Time Step: {option_time}')
|
| 491 |
+
options_method = st.multiselect(
|
| 492 |
+
'Compared Methods',
|
| 493 |
+
['Low-Level', 'High-Level', 'Midpoint', 'Low-Level with LSTM', 'Low-Level with TCN', 'Midpoint with LSTM', 'Midpoint with TCN'],
|
| 494 |
+
['Low-Level'])
|
| 495 |
+
btn_go = st.button("Go", type="primary", use_container_width=True)
|
| 496 |
+
|
| 497 |
+
if btn_go:
|
| 498 |
+
if 'Low-Level' in options_method:
|
| 499 |
+
with container1:
|
| 500 |
+
st.write('Low Level Progress Bar')
|
| 501 |
+
low_level_progress_bar = st.progress(0)
|
| 502 |
+
low_level_30000_time, low_level_2000_time, low_level_total_time, low_level_result = low_level(option_time, slider_sample_orbit, low_level_progress_bar)
|
| 503 |
+
with container2:
|
| 504 |
+
st.table(pd.DataFrame({'Model':"Low Level", '30000 Time Steps (s)': [low_level_30000_time], '2000 Time Steps (s)': [low_level_2000_time], 'Total Time (s)': [low_level_total_time]}))
|
| 505 |
+
if 'High-Level' in options_method:
|
| 506 |
+
with container1:
|
| 507 |
+
st.write('High Level Progress Bar')
|
| 508 |
+
high_level_progress_bar = st.progress(0)
|
| 509 |
+
high_level_30000_time, high_level_2000_time, high_level_total_time, high_level_result = high_level(option_time, slider_sample_orbit, high_level_progress_bar)
|
| 510 |
+
with container2:
|
| 511 |
+
st.table(pd.DataFrame({'Model':"High Level", '30000 Time Steps (s)': [high_level_30000_time], '2000 Time Steps (s)': [high_level_2000_time], 'Total Time (s)': [high_level_total_time]}))
|
| 512 |
+
if 'Midpoint' in options_method:
|
| 513 |
+
with container1:
|
| 514 |
+
st.write('Midpoint Progress Bar')
|
| 515 |
+
mid_point_progress_bar = st.progress(0)
|
| 516 |
+
mid_point_30000_time, mid_point_2000_time, mid_point_total_time, mid_point_result = midpoint(option_time, slider_sample_orbit, mid_point_progress_bar)
|
| 517 |
+
with container2:
|
| 518 |
+
st.table(pd.DataFrame({'Model':"Midpoint", '30000 Time Steps (s)': [mid_point_30000_time], '2000 Time Steps (s)': [mid_point_2000_time], 'Total Time (s)': [mid_point_total_time]}))
|
| 519 |
+
if 'Low-Level with LSTM' in options_method:
|
| 520 |
+
with container1:
|
| 521 |
+
st.write('Low Level LSTM Progress Bar')
|
| 522 |
+
low_level_lstm_progress_bar = st.progress(0)
|
| 523 |
+
lstm_30000_time, lstm_2000_time, lstm_total_time, lstm_result = low_level_lstm(slider_sample_orbit, low_level_lstm_progress_bar)
|
| 524 |
+
with container2:
|
| 525 |
+
st.table(pd.DataFrame({'Model':"Low Level + LSTM", '30000 Time Steps (s)': [lstm_30000_time], '2000 Time Steps (s)': [lstm_2000_time], 'Total Time (s)': [lstm_total_time]}))
|
| 526 |
+
if 'Low-Level with TCN' in options_method:
|
| 527 |
+
with container1:
|
| 528 |
+
st.write('Low Level TCN Progress Bar')
|
| 529 |
+
low_level_tcn_progress_bar = st.progress(0)
|
| 530 |
+
tcn_30000_time, tcn_2000_time, tcn_total_time, tcn_result = low_level_tcn(slider_sample_orbit, low_level_tcn_progress_bar)
|
| 531 |
+
with container2:
|
| 532 |
+
st.table(pd.DataFrame({'Model':"Low Level + TCN", '30000 Time Steps (s)': [tcn_30000_time], '2000 Time Steps (s)': [tcn_2000_time], 'Total Time (s)': [tcn_total_time]}))
|
| 533 |
+
if 'Midpoint with LSTM' in options_method:
|
| 534 |
+
with container1:
|
| 535 |
+
st.write('Midpoint LSTM Progress Bar')
|
| 536 |
+
mid_point_lstm_progress_bar = st.progress(0)
|
| 537 |
+
md_lstm_30000_time, md_lstm_2000_time, md_lstm_total_time, md_lstm_result = mid_point_lstm(slider_sample_orbit, mid_point_lstm_progress_bar)
|
| 538 |
+
with container2:
|
| 539 |
+
st.table(pd.DataFrame({'Model':"Midpoint + LSTM", '30000 Time Steps (s)': [md_lstm_30000_time], '2000 Time Steps (s)': [md_lstm_2000_time], 'Total Time (s)': [md_lstm_total_time]}))
|
| 540 |
+
if 'Midpoint with TCN' in options_method:
|
| 541 |
+
with container1:
|
| 542 |
+
st.write('Midpoint TCN Progress Bar')
|
| 543 |
+
mid_point_tcn_progress_bar = st.progress(0)
|
| 544 |
+
md_tcn_30000_time, md_tcn_2000_time, md_tcn_total_time, md_tcn_result = mid_point_tcn(slider_sample_orbit, mid_point_tcn_progress_bar)
|
| 545 |
+
with container2:
|
| 546 |
+
st.table(pd.DataFrame({'Model':"Midpoint + TCN", '30000 Time Steps (s)': [md_tcn_30000_time], '2000 Time Steps (s)': [md_tcn_2000_time], 'Total Time (s)': [md_tcn_total_time]}))
|
| 547 |
+
|
prediction.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import os
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from sklearn.preprocessing import MinMaxScaler
|
| 6 |
+
import torch
|
| 7 |
+
import time
|
| 8 |
+
|
| 9 |
+
from utils.transform import compute_gradient
|
| 10 |
+
from model.lstm import LSTMModel
|
| 11 |
+
from model.tcn import TCNModel
|
| 12 |
+
from model.tcn import move_custom_layers_to_device
|
| 13 |
+
from utils.metrics import calculate_metrics
|
| 14 |
+
|
| 15 |
+
each_feature_name = ["q_1","q_2","q_3","p_1","p_2","p_3"]
|
| 16 |
+
|
| 17 |
+
def uniform_sampling(data, n_sample):
|
| 18 |
+
k = len(data) // n_sample
|
| 19 |
+
return data[::k]
|
| 20 |
+
|
| 21 |
+
st.set_page_config(page_title="Prediction", page_icon=":chart_with_upwards_trend:", layout="wide", initial_sidebar_state="auto")
|
| 22 |
+
|
| 23 |
+
#st.title("Prediction")
|
| 24 |
+
|
| 25 |
+
with st.sidebar:
|
| 26 |
+
slider_predict_step = st.slider('Predicted Step', 0, 20, 20)
|
| 27 |
+
|
| 28 |
+
number_input_sample_id = st.number_input("Select Sample ID 1~10", value=1, placeholder="Type a number...", min_value=1, max_value=10, step=1)
|
| 29 |
+
|
| 30 |
+
squences_start_idx = st.slider('Squences Start Index', 0, 700 - slider_predict_step, 0)
|
| 31 |
+
|
| 32 |
+
st.subheader("Model Configuration")
|
| 33 |
+
st.write("LSTM Window Size: ", 200)
|
| 34 |
+
st.write("TCN Window Size: ", 300)
|
| 35 |
+
st.write("Predicted Step: ", slider_predict_step)
|
| 36 |
+
st.write("Feature Augmentation: ", "Second-order derivative")
|
| 37 |
+
|
| 38 |
+
file_path = os.path.join("data", "file"+str(number_input_sample_id)+".dat.npz")
|
| 39 |
+
data = pd.DataFrame(np.load(file_path)['data'])
|
| 40 |
+
scaler = MinMaxScaler()
|
| 41 |
+
uniform_data = uniform_sampling(data, n_sample=1000).sort_index().values[:, 1:8]
|
| 42 |
+
normal_uniform_data = scaler.fit_transform(uniform_data)
|
| 43 |
+
data_sequences = torch.tensor(np.stack(normal_uniform_data)).float()
|
| 44 |
+
original_data_sequences = torch.tensor(np.stack(uniform_data)).float()
|
| 45 |
+
selected_data = data_sequences[squences_start_idx:squences_start_idx+300+slider_predict_step]
|
| 46 |
+
original_selected_data = original_data_sequences[squences_start_idx:squences_start_idx+300+slider_predict_step]
|
| 47 |
+
input_data = torch.stack([compute_gradient(i, degree=2) for i in selected_data]).unsqueeze(0)
|
| 48 |
+
|
| 49 |
+
with st.sidebar:
|
| 50 |
+
st.subheader("Data Configuration")
|
| 51 |
+
st.write("Sample ID: ", number_input_sample_id)
|
| 52 |
+
#st.write("Origianl Shape: ", data_sequences.shape)
|
| 53 |
+
st.write("Squences Start Index: ", squences_start_idx)
|
| 54 |
+
#st.write("Selected Shape: ", selected_data.shape)
|
| 55 |
+
#st.write("Input Shape: ", input_data.shape)
|
| 56 |
+
|
| 57 |
+
# st.write(selected_data[0])
|
| 58 |
+
# st.write(input_data[0][0])
|
| 59 |
+
|
| 60 |
+
#################################################
|
| 61 |
+
## LSTM GPU Inference
|
| 62 |
+
#################################################
|
| 63 |
+
lstm_ckpt_file = os.path.join("model", "lstm.ckpt")
|
| 64 |
+
lstm_model = LSTMModel.load_from_checkpoint(lstm_ckpt_file)
|
| 65 |
+
lstm_model.eval()
|
| 66 |
+
lstm_start_time = time.time()
|
| 67 |
+
with torch.no_grad():
|
| 68 |
+
lstm_preds = lstm_model(input_data[:, 100:300, :].cuda())
|
| 69 |
+
lstm_end_time = time.time()
|
| 70 |
+
lstm_innv_preds = scaler.inverse_transform(lstm_preds.squeeze().cpu().numpy())
|
| 71 |
+
lstm_normal_preds = lstm_preds.squeeze().cpu().numpy()
|
| 72 |
+
|
| 73 |
+
#lstm_model.to_onnx("model/lstm.onnx", torch.randn((1, 200, 21)), export_params=True)
|
| 74 |
+
|
| 75 |
+
del lstm_model
|
| 76 |
+
|
| 77 |
+
#################################################
|
| 78 |
+
## LSTM CPU Inference
|
| 79 |
+
#################################################
|
| 80 |
+
lstm_cpu_ckpt_file = os.path.join("model", "lstm.ckpt")
|
| 81 |
+
lstm_cpu_model = LSTMModel.load_from_checkpoint(lstm_ckpt_file)
|
| 82 |
+
lstm_cpu_model.to("cpu")
|
| 83 |
+
lstm_cpu_model.eval()
|
| 84 |
+
lstm_cpu_start_time = time.time()
|
| 85 |
+
with torch.no_grad():
|
| 86 |
+
lstm_cpu_preds = lstm_cpu_model(input_data[:, 100:300, :])
|
| 87 |
+
lstm_cpu_end_time = time.time()
|
| 88 |
+
|
| 89 |
+
del lstm_cpu_model
|
| 90 |
+
|
| 91 |
+
#################################################
|
| 92 |
+
## TCN GPU Inference
|
| 93 |
+
#################################################
|
| 94 |
+
tcn_ckpt_file = os.path.join("model", "tcn.ckpt")
|
| 95 |
+
tcn_model = TCNModel.load_from_checkpoint(tcn_ckpt_file)
|
| 96 |
+
tcn_model.eval()
|
| 97 |
+
tcn_start_time = time.time()
|
| 98 |
+
with torch.no_grad():
|
| 99 |
+
input_data_cuda = input_data[:,:300,:].cuda()
|
| 100 |
+
y_hat = tcn_model(input_data_cuda)
|
| 101 |
+
for i in range(1, slider_predict_step):
|
| 102 |
+
gd_y_hat = compute_gradient(y_hat[:, :i, :], degree=2).cuda()
|
| 103 |
+
output = tcn_model(torch.cat([input_data[:, i:300, :].cuda(), gd_y_hat], dim=1)).cuda()
|
| 104 |
+
y_hat = torch.cat([y_hat, output], dim=1)
|
| 105 |
+
tcn_end_time = time.time()
|
| 106 |
+
tcn_preds = y_hat
|
| 107 |
+
tcn_innv_preds = scaler.inverse_transform(tcn_preds.squeeze().cpu().numpy())
|
| 108 |
+
tcn_normal_preds = tcn_preds.squeeze().cpu().numpy()
|
| 109 |
+
|
| 110 |
+
#tcn_model.to_onnx("model/tcn.onnx", torch.randn((1, 300, 21)), export_params=True)
|
| 111 |
+
|
| 112 |
+
del tcn_model
|
| 113 |
+
del y_hat, gd_y_hat, output
|
| 114 |
+
|
| 115 |
+
#################################################
|
| 116 |
+
## TCN CPU Inference
|
| 117 |
+
#################################################
|
| 118 |
+
input_data_cpu = input_data.to("cpu")
|
| 119 |
+
tcn_cpu_ckpt_file = os.path.join("model", "tcn.ckpt")
|
| 120 |
+
tcn_cpu_model = TCNModel.load_from_checkpoint(tcn_cpu_ckpt_file)
|
| 121 |
+
move_custom_layers_to_device(tcn_cpu_model, "cpu")
|
| 122 |
+
tcn_cpu_model.eval()
|
| 123 |
+
tcn_cpu_start_time = time.time()
|
| 124 |
+
with torch.no_grad():
|
| 125 |
+
y_hat = None
|
| 126 |
+
for i in range(slider_predict_step):
|
| 127 |
+
if i == 0:
|
| 128 |
+
y_hat = tcn_cpu_model(input_data_cpu[:,:300,:])
|
| 129 |
+
else:
|
| 130 |
+
gd_y_hat = compute_gradient(y_hat[:, :i, :], degree=2).to('cpu')
|
| 131 |
+
output = tcn_cpu_model(torch.concatenate([input_data_cpu[:, i:300, :], gd_y_hat], dim=1).to('cpu'))
|
| 132 |
+
y_hat = torch.concatenate([y_hat, output], dim=1)
|
| 133 |
+
tcn_cpu_preds = y_hat
|
| 134 |
+
tcn_cpu_end_time = time.time()
|
| 135 |
+
|
| 136 |
+
del tcn_cpu_model
|
| 137 |
+
|
| 138 |
+
st.subheader("Normalized Prediction")
|
| 139 |
+
|
| 140 |
+
i = 1
|
| 141 |
+
for each_col in st.columns(6):
|
| 142 |
+
with each_col:
|
| 143 |
+
raw_data = selected_data[:, i]
|
| 144 |
+
lstm_data = [np.nan] * 300 + lstm_normal_preds[:slider_predict_step, :][:, i].tolist()
|
| 145 |
+
tcn_data = [np.nan] * 300 + tcn_normal_preds[:, i].tolist()
|
| 146 |
+
st.markdown(f"<div style='text-align: center'>{each_feature_name[i-1]}</div>", unsafe_allow_html=True)
|
| 147 |
+
#st.write(np.array(raw_data).shape, np.array(lstm_data).shape, np.array(tcn_data).shape)
|
| 148 |
+
st.line_chart(pd.DataFrame({"Original": raw_data, "LSTM": lstm_data, "TCN": tcn_data}),
|
| 149 |
+
color=["#EE4035", "#0077BB", "#7BC043"])
|
| 150 |
+
i += 1
|
| 151 |
+
|
| 152 |
+
# with st.sidebar:
|
| 153 |
+
# st.write("Predicted Shape: ", lstm_preds.shape)
|
| 154 |
+
|
| 155 |
+
st.subheader("Inverse Normalized Prediction")
|
| 156 |
+
|
| 157 |
+
i = 1
|
| 158 |
+
for each_col in st.columns(6):
|
| 159 |
+
with each_col:
|
| 160 |
+
raw_data = original_selected_data[:, i]
|
| 161 |
+
lstm_data = [np.nan] * 300 + lstm_innv_preds[:slider_predict_step, :][:, i].tolist()
|
| 162 |
+
tcn_data = [np.nan] * 300 + tcn_innv_preds[:, i].tolist()
|
| 163 |
+
st.markdown(f"<div style='text-align: center'>{each_feature_name[i - 1]}</div>", unsafe_allow_html=True)
|
| 164 |
+
st.line_chart(pd.DataFrame({"Original": raw_data, "LSTM": lstm_data, "TCN": tcn_data}),
|
| 165 |
+
color=["#EE4035", "#0077BB", "#7BC043"])
|
| 166 |
+
i += 1
|
| 167 |
+
|
| 168 |
+
LSTM_SMAPE, LSTM_MSE, LSTM_RMSE, LSTM_MAE, LSTM_R2, LSTM_PSD = calculate_metrics(selected_data[300:300+slider_predict_step, :].cpu().numpy(), lstm_normal_preds[:slider_predict_step, :])
|
| 169 |
+
TCN_SMAPE, TCN_MSE, TCN_RMSE, TCN_MAE, TCN_R2, TCN_PSD = calculate_metrics(selected_data[300:300+slider_predict_step, :].cpu().numpy(), tcn_normal_preds)
|
| 170 |
+
|
| 171 |
+
results_df = pd.DataFrame({
|
| 172 |
+
"Model": ["LSTM", "TCN"],
|
| 173 |
+
"SMAPE": [LSTM_SMAPE, TCN_SMAPE],
|
| 174 |
+
"MSE": [LSTM_MSE, TCN_MSE],
|
| 175 |
+
"RMSE": [LSTM_RMSE, TCN_RMSE],
|
| 176 |
+
"MAE": [LSTM_MAE, TCN_MAE],
|
| 177 |
+
"R2": [LSTM_R2, TCN_R2],
|
| 178 |
+
"PSD": [LSTM_PSD, TCN_PSD]
|
| 179 |
+
})
|
| 180 |
+
|
| 181 |
+
time_df = pd.DataFrame({
|
| 182 |
+
"Model": ["LSTM-GPU", "TCN-GPU", "LSTM-CPU", "TCN-CPU"],
|
| 183 |
+
"Time(ms)": [(lstm_end_time - lstm_start_time)*1000,
|
| 184 |
+
(tcn_end_time - tcn_start_time)*1000,
|
| 185 |
+
(lstm_cpu_end_time - lstm_cpu_start_time)*1000,
|
| 186 |
+
(tcn_cpu_end_time - tcn_cpu_start_time)*1000]
|
| 187 |
+
})
|
| 188 |
+
|
| 189 |
+
col1, col2 = st.columns(2)
|
| 190 |
+
with col1:
|
| 191 |
+
st.subheader("Evaluation Metrics")
|
| 192 |
+
st.write(results_df)
|
| 193 |
+
|
| 194 |
+
with col2:
|
| 195 |
+
st.subheader("Prediction Time")
|
| 196 |
+
st.write(time_df)
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pandas==1.5.3
|
| 2 |
+
scikit-learn==1.3.2
|
| 3 |
+
torch==2.1.1
|
| 4 |
+
streamlit==1.32.2
|
| 5 |
+
keras==3.0.0
|
| 6 |
+
torchvision==0.16.1
|
| 7 |
+
pytorch-lightning==2.1.2
|
utils/__pycache__/highlevel.cpython-310.pyc
ADDED
|
Binary file (4.08 kB). View file
|
|
|
utils/__pycache__/lowlevel.cpython-310.pyc
ADDED
|
Binary file (6.04 kB). View file
|
|
|
utils/__pycache__/metrics.cpython-310.pyc
ADDED
|
Binary file (2.22 kB). View file
|
|
|
utils/__pycache__/midpoint.cpython-310.pyc
ADDED
|
Binary file (6.02 kB). View file
|
|
|
utils/__pycache__/transform.cpython-310.pyc
ADDED
|
Binary file (448 Bytes). View file
|
|
|
utils/highlevel.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sympy as sp
|
| 2 |
+
|
| 3 |
+
class HighLevel():
|
| 4 |
+
|
| 5 |
+
def __init__(self, j):
|
| 6 |
+
self.j = j
|
| 7 |
+
|
| 8 |
+
def initial(self):
|
| 9 |
+
j = self.j
|
| 10 |
+
|
| 11 |
+
#init parameters
|
| 12 |
+
h, b, n, u = 0.1, None, None, None
|
| 13 |
+
x, y, z, px, py, pz = None, 0.1, 0.001, 0.01, None, 0.0001
|
| 14 |
+
xa, ya, za, pxa, pya, pza = None, None, None, None, None, None
|
| 15 |
+
|
| 16 |
+
if j == 1:
|
| 17 |
+
b = 5.0 / 4
|
| 18 |
+
n = b / (1 + b) ** 2
|
| 19 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
| 20 |
+
x = 10.0
|
| 21 |
+
py = 0.5
|
| 22 |
+
elif j == 2:
|
| 23 |
+
b = 3.0 / 4
|
| 24 |
+
n = b / (1 + b) ** 2
|
| 25 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
| 26 |
+
x = 8.3
|
| 27 |
+
py = 0.6
|
| 28 |
+
elif j == 3:
|
| 29 |
+
b = 3.0 / 2
|
| 30 |
+
x = 12.0
|
| 31 |
+
py = 0.4
|
| 32 |
+
elif j == 4:
|
| 33 |
+
b = 7.0 / 4
|
| 34 |
+
n = b / (1 + b) ** 2
|
| 35 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
| 36 |
+
x = 15.0
|
| 37 |
+
py = 0.35
|
| 38 |
+
elif j == 5:
|
| 39 |
+
b = 1.0
|
| 40 |
+
n = b / (1 + b) ** 2
|
| 41 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
| 42 |
+
x = 18.0
|
| 43 |
+
py = 0.3
|
| 44 |
+
elif j == 6:
|
| 45 |
+
b = 3.0 / 5
|
| 46 |
+
n = b / (1 + b) ** 2
|
| 47 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
| 48 |
+
x = 20.0
|
| 49 |
+
py = 0.25
|
| 50 |
+
elif j == 7:
|
| 51 |
+
b = 5.0 / 7
|
| 52 |
+
n = b / (1 + b) ** 2
|
| 53 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
| 54 |
+
x = 22.0
|
| 55 |
+
py = 0.22
|
| 56 |
+
elif j == 8:
|
| 57 |
+
b = 2.0
|
| 58 |
+
x = 26.0
|
| 59 |
+
py = 0.2
|
| 60 |
+
elif j == 9:
|
| 61 |
+
b = 0.5
|
| 62 |
+
n = b / (1 + b) ** 2
|
| 63 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
| 64 |
+
x = 30.0
|
| 65 |
+
y = 0.5
|
| 66 |
+
z = 0.1
|
| 67 |
+
pz = 0.01
|
| 68 |
+
elif j == 10:
|
| 69 |
+
b = 5.0
|
| 70 |
+
n = b / (1 + b) ** 2
|
| 71 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
| 72 |
+
x = 35.0
|
| 73 |
+
y = 2.0
|
| 74 |
+
z = 0.1
|
| 75 |
+
pz = 0.03
|
| 76 |
+
py = 0.15
|
| 77 |
+
|
| 78 |
+
xa, ya, za, pxa, pya, pza = x, y, z, px, py, pz
|
| 79 |
+
return j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza
|
| 80 |
+
|
| 81 |
+
def f(self, x, y, z, px, py, pz, b):
|
| 82 |
+
x_val, y_val, z_val, px_val, py_val, pz_val, b_val = x, y, z, px, py, pz, b
|
| 83 |
+
x, y, z, px, py, pz, b = sp.symbols('x y z px py pz b')
|
| 84 |
+
|
| 85 |
+
c = 1.0
|
| 86 |
+
|
| 87 |
+
u = 1 / (1 / b + b + 2)
|
| 88 |
+
ht = px ** 2 / 2 + py ** 2 / 2 + pz ** 2 / 2
|
| 89 |
+
hv = -1 / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2)
|
| 90 |
+
h1pn = 1 / (2 * x ** 2 + 2 * y ** 2 + 2 * z ** 2) - (((u + 3) * (px ** 2 + py ** 2 + pz ** 2)) / 2 + (u * (
|
| 91 |
+
(px * x) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2) + (py * y) / (x ** 2 + y ** 2 + z ** 2) ** (
|
| 92 |
+
1 / 2) + (pz * z) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2)) ** 2) / 2) / (
|
| 93 |
+
x ** 2 + y ** 2 + z ** 2) ** (1 / 2) + ((3 * u) / 8 - 1 / 8) * (
|
| 94 |
+
px ** 2 + py ** 2 + pz ** 2) ** 2
|
| 95 |
+
|
| 96 |
+
e = ht + hv + h1pn
|
| 97 |
+
|
| 98 |
+
de_dx = sp.diff(e, x)
|
| 99 |
+
de_dy = sp.diff(e, y)
|
| 100 |
+
de_dz = sp.diff(e, z)
|
| 101 |
+
de_dpx = sp.diff(e, px)
|
| 102 |
+
de_dpy = sp.diff(e, py)
|
| 103 |
+
de_dpz = sp.diff(e, pz)
|
| 104 |
+
|
| 105 |
+
de_dx_val = de_dx.subs({x: x_val, y: y_val, z: z_val, px: px_val, py: py_val, pz: pz_val, b: b_val})
|
| 106 |
+
de_dy_val = de_dy.subs({x: x_val, y: y_val, z: z_val, px: px_val, py: py_val, pz: pz_val, b: b_val})
|
| 107 |
+
de_dz_val = de_dz.subs({x: x_val, y: y_val, z: z_val, px: px_val, py: py_val, pz: pz_val, b: b_val})
|
| 108 |
+
de_dpx_val = de_dpx.subs({x: x_val, y: y_val, z: z_val, px: px_val, py: py_val, pz: pz_val, b: b_val})
|
| 109 |
+
de_dpy_val = de_dpy.subs({x: x_val, y: y_val, z: z_val, px: px_val, py: py_val, pz: pz_val, b: b_val})
|
| 110 |
+
de_dpz_val = de_dpz.subs({x: x_val, y: y_val, z: z_val, px: px_val, py: py_val, pz: pz_val, b: b_val})
|
| 111 |
+
|
| 112 |
+
e_val = e.subs({x: x_val, y: y_val, z: z_val, px: px_val, py: py_val, pz: pz_val, b: b_val})
|
| 113 |
+
|
| 114 |
+
return de_dx_val, de_dy_val, de_dz_val, de_dpx_val, de_dpy_val, de_dpz_val, e_val
|
| 115 |
+
|
| 116 |
+
def rejust(self, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza):
|
| 117 |
+
|
| 118 |
+
x = (x + xa) / 2
|
| 119 |
+
y = (y + ya) / 2
|
| 120 |
+
z = (z + za) / 2
|
| 121 |
+
|
| 122 |
+
px = (px + pxa) / 2
|
| 123 |
+
py = (py + pya) / 2
|
| 124 |
+
pz = (pz + pza) / 2
|
| 125 |
+
xa = x
|
| 126 |
+
ya = y
|
| 127 |
+
za = z
|
| 128 |
+
pxa = px
|
| 129 |
+
pya = py
|
| 130 |
+
pza = pz
|
| 131 |
+
|
| 132 |
+
return x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza
|
| 133 |
+
|
| 134 |
+
def symplectic(self, h, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b):
|
| 135 |
+
|
| 136 |
+
vxa, vya, vza, vpx, vpy, vpz, e = self.f(xa, ya, za, px, py, pz, b)
|
| 137 |
+
x = x + h / 2 * vpx
|
| 138 |
+
y = y + h / 2 * vpy
|
| 139 |
+
z = z + h / 2 * vpz
|
| 140 |
+
pxa = pxa - h / 2 * vxa
|
| 141 |
+
pya = pya - h / 2 * vya
|
| 142 |
+
pza = pza - h / 2 * vza
|
| 143 |
+
|
| 144 |
+
vx, vy, vz, vpxa, vpya, vpza, e = self.f(x, y, z, pxa, pya, pza, b)
|
| 145 |
+
xa = xa + h * vpxa
|
| 146 |
+
ya = ya + h * vpya
|
| 147 |
+
za = za + h * vpza
|
| 148 |
+
px = px - h * vx
|
| 149 |
+
py = py - h * vy
|
| 150 |
+
pz = pz - h * vz
|
| 151 |
+
|
| 152 |
+
vxa, vya, vza, vpx, vpy, vpz, e = self.f(xa, ya, za, px, py, pz, b)
|
| 153 |
+
x = x + h / 2 * vpx
|
| 154 |
+
y = y + h / 2 * vpy
|
| 155 |
+
z = z + h / 2 * vpz
|
| 156 |
+
pxa = pxa - h / 2 * vxa
|
| 157 |
+
pya = pya - h / 2 * vya
|
| 158 |
+
pza = pza - h / 2 * vza
|
| 159 |
+
|
| 160 |
+
return x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza
|
utils/lowlevel.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
class LowLevel():
|
| 5 |
+
|
| 6 |
+
def __init__(self, j):
|
| 7 |
+
self.j = j
|
| 8 |
+
|
| 9 |
+
def initial(self):
|
| 10 |
+
j = self.j
|
| 11 |
+
|
| 12 |
+
#init parameters
|
| 13 |
+
h, b, n, u = 0.1, None, None, None
|
| 14 |
+
x, y, z, px, py, pz = None, 0.1, 0.001, 0.01, None, 0.0001
|
| 15 |
+
xa, ya, za, pxa, pya, pza = None, None, None, None, None, None
|
| 16 |
+
|
| 17 |
+
if j == 1:
|
| 18 |
+
b = 5.0 / 4
|
| 19 |
+
n = b / (1 + b) ** 2
|
| 20 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
| 21 |
+
x = 10.0
|
| 22 |
+
py = 0.5
|
| 23 |
+
elif j == 2:
|
| 24 |
+
b = 3.0 / 4
|
| 25 |
+
n = b / (1 + b) ** 2
|
| 26 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
| 27 |
+
x = 8.3
|
| 28 |
+
py = 0.6
|
| 29 |
+
elif j == 3:
|
| 30 |
+
b = 3.0 / 2
|
| 31 |
+
x = 12.0
|
| 32 |
+
py = 0.4
|
| 33 |
+
elif j == 4:
|
| 34 |
+
b = 7.0 / 4
|
| 35 |
+
n = b / (1 + b) ** 2
|
| 36 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
| 37 |
+
x = 15.0
|
| 38 |
+
py = 0.35
|
| 39 |
+
elif j == 5:
|
| 40 |
+
b = 1.0
|
| 41 |
+
n = b / (1 + b) ** 2
|
| 42 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
| 43 |
+
x = 18.0
|
| 44 |
+
py = 0.3
|
| 45 |
+
elif j == 6:
|
| 46 |
+
b = 3.0 / 5
|
| 47 |
+
n = b / (1 + b) ** 2
|
| 48 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
| 49 |
+
x = 20.0
|
| 50 |
+
py = 0.25
|
| 51 |
+
elif j == 7:
|
| 52 |
+
b = 5.0 / 7
|
| 53 |
+
n = b / (1 + b) ** 2
|
| 54 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
| 55 |
+
x = 22.0
|
| 56 |
+
py = 0.22
|
| 57 |
+
elif j == 8:
|
| 58 |
+
b = 2.0
|
| 59 |
+
x = 26.0
|
| 60 |
+
py = 0.2
|
| 61 |
+
elif j == 9:
|
| 62 |
+
b = 0.5
|
| 63 |
+
n = b / (1 + b) ** 2
|
| 64 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
| 65 |
+
x = 30.0
|
| 66 |
+
y = 0.5
|
| 67 |
+
z = 0.1
|
| 68 |
+
pz = 0.01
|
| 69 |
+
elif j == 10:
|
| 70 |
+
b = 5.0
|
| 71 |
+
n = b / (1 + b) ** 2
|
| 72 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
| 73 |
+
x = 35.0
|
| 74 |
+
y = 2.0
|
| 75 |
+
z = 0.1
|
| 76 |
+
pz = 0.03
|
| 77 |
+
py = 0.15
|
| 78 |
+
|
| 79 |
+
xa, ya, za, pxa, pya, pza = x, y, z, px, py, pz
|
| 80 |
+
return j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza
|
| 81 |
+
|
| 82 |
+
def f(self, x, y, z, px, py, pz, b):
|
| 83 |
+
n = b / (1 + b)**2
|
| 84 |
+
u = 1 / (1 / b + b + 2)
|
| 85 |
+
ht = px**2 / 2 + py**2 / 2 + pz**2 / 2
|
| 86 |
+
hv = -1 / (x**2 + y**2 + z**2)**(1/2)
|
| 87 |
+
h1pn = 1/(2*x**2 + 2*y**2 + 2*z**2) - (((u + 3)*(px**2 + py**2 +pz**2))/2 + (u*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/ (x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))**2)/2)/(x**2 + y**2 + z**2)**(1/2) + ((3*u)/8 -1/8)*(px**2 + py**2 + pz**2)**2
|
| 88 |
+
|
| 89 |
+
e = ht + hv + h1pn
|
| 90 |
+
|
| 91 |
+
vnpx=px
|
| 92 |
+
v1pnpx=4*px*((3*n)/8 - 1/8)*(px**2 + py**2 + pz**2) - (px*(n + 3) + (n*x*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2)))/(x**2 + y**2 + z**2)**(1/2))/(x**2 + y**2 + z**2)**(1/2)
|
| 93 |
+
vpx=vnpx+v1pnpx
|
| 94 |
+
|
| 95 |
+
vnpy=py
|
| 96 |
+
v1pnpy=4*py*((3*n)/8 - 1/8)*(px**2 + py**2 + pz**2) - (py*(n + 3) + (n*y*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2)))/(x**2 + y**2 + z**2)**(1/2))/(x**2 + y**2 + z**2)**(1/2)
|
| 97 |
+
vpy=vnpy+v1pnpy
|
| 98 |
+
|
| 99 |
+
vnpz=pz
|
| 100 |
+
v1pnpz=4*pz*((3*n)/8 - 1/8)*(px**2 + py**2 + pz**2) - (pz*(n + 3) + (n*z*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2)))/(x**2 + y**2 + z**2)**(1/2))/(x**2 + y**2 + z**2)**(1/2)
|
| 101 |
+
vpz=vnpz+v1pnpz
|
| 102 |
+
|
| 103 |
+
vnx=x/(x**2 + y**2 + z**2)**(3/2)
|
| 104 |
+
v1pnx=(x*((n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 +y**2 + z**2)**(1/2))**2)/2 + ((n + 3)*(px**2 + py**2 + pz**2))/2))/(x**2 + y**2 + z**2)**(3/2) -(4*x)/(2*x**2 + 2*y**2 + 2*z**2)**2 + (n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))*((px*x**2)/(x**2 + y**2 + z**2)**(3/2) -px/(x**2 + y**2 + z**2)**(1/2) + (py*x*y)/(x**2 + y**2 + z**2)**(3/2) + (pz*x*z)/(x**2 + y**2 +z**2)**(3/2)))/(x**2 + y**2 + z**2)**(1/2)
|
| 105 |
+
vx=vnx+v1pnx
|
| 106 |
+
|
| 107 |
+
vny=y/(x**2 + y**2 + z**2)**(3/2)
|
| 108 |
+
v1pny=(y*((n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 +y**2 + z**2)**(1/2))**2)/2 + ((n + 3)*(px**2 + py**2 + pz**2))/2))/(x**2 + y**2 + z**2)**(3/2) -(4*y)/(2*x**2 + 2*y**2 + 2*z**2)**2 + (n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))*((py*y**2)/(x**2 + y**2 + z**2)**(3/2) -py/(x**2 + y**2 + z**2)**(1/2) + (px*x*y)/(x**2 + y**2 + z**2)**(3/2) + (pz*y*z)/(x**2 + y**2 +z**2)**(3/2)))/(x**2 + y**2 + z**2)**(1/2)
|
| 109 |
+
vy=vny+v1pny
|
| 110 |
+
|
| 111 |
+
vnz=z/(x**2 + y**2 + z**2)**(3/2)
|
| 112 |
+
v1pnz=(z*((n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 +y**2 + z**2)**(1/2))**2)/2 + ((n + 3)*(px**2 + py**2 + pz**2))/2))/(x**2 + y**2 + z**2)**(3/2) -(4*z)/(2*x**2 + 2*y**2 + 2*z**2)**2 + (n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))*((pz*z**2)/(x**2 + y**2 + z**2)**(3/2) -pz/(x**2 + y**2 + z**2)**(1/2) + (px*x*z)/(x**2 + y**2 + z**2)**(3/2) + (py*y*z)/(x**2 + y**2 +z**2)**(3/2)))/(x**2 + y**2 + z**2)**(1/2)
|
| 113 |
+
vz=vnz+v1pnz
|
| 114 |
+
|
| 115 |
+
return vx,vy,vz,vpx,vpy,vpz,e
|
| 116 |
+
|
| 117 |
+
def rejust(self, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza):
|
| 118 |
+
x = (x + xa) / 2
|
| 119 |
+
y = (y + ya) / 2
|
| 120 |
+
z = (z + za) / 2
|
| 121 |
+
|
| 122 |
+
px = (px + pxa) / 2
|
| 123 |
+
py = (py + pya) / 2
|
| 124 |
+
pz = (pz + pza) / 2
|
| 125 |
+
xa = x
|
| 126 |
+
ya = y
|
| 127 |
+
za = z
|
| 128 |
+
pxa = px
|
| 129 |
+
pya = py
|
| 130 |
+
pza = pz
|
| 131 |
+
return x,y,z,px,py,pz,xa,ya,za,pxa,pya,pza
|
| 132 |
+
|
| 133 |
+
def symplectic(self, h, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b):
|
| 134 |
+
vxa, vya, vza, vpx, vpy, vpz, e = self.f(xa, ya, za, px, py, pz, b)
|
| 135 |
+
x = x + h / 2 * vpx
|
| 136 |
+
y = y + h / 2 * vpy
|
| 137 |
+
z = z + h / 2 * vpz
|
| 138 |
+
pxa = pxa - h / 2 * vxa
|
| 139 |
+
pya = pya - h / 2 * vya
|
| 140 |
+
pza = pza - h / 2 * vza
|
| 141 |
+
|
| 142 |
+
vx, vy, vz, vpxa, vpya, vpza, e = self.f(x, y, z, pxa, pya, pza, b)
|
| 143 |
+
xa = xa + h * vpxa
|
| 144 |
+
ya = ya + h * vpya
|
| 145 |
+
za = za + h * vpza
|
| 146 |
+
px = px - h * vx
|
| 147 |
+
py = py - h * vy
|
| 148 |
+
pz = pz - h * vz
|
| 149 |
+
|
| 150 |
+
vxa, vya, vza, vpx, vpy, vpz, e = self.f(xa, ya, za, px, py, pz, b)
|
| 151 |
+
x = x + h / 2 * vpx
|
| 152 |
+
y = y + h / 2 * vpy
|
| 153 |
+
z = z + h / 2 * vpz
|
| 154 |
+
pxa = pxa - h / 2 * vxa
|
| 155 |
+
pya = pya - h / 2 * vya
|
| 156 |
+
pza = pza - h / 2 * vza
|
| 157 |
+
|
| 158 |
+
return x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza
|
utils/metrics.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
|
| 3 |
+
|
| 4 |
+
def calculate_metrics(y, y_hat, y_train=None):
|
| 5 |
+
def smape(a, f):
|
| 6 |
+
return 1/len(a) * np.sum(2 * np.abs(f - a) / (np.abs(a) + np.abs(f) + np.finfo(float).eps))
|
| 7 |
+
|
| 8 |
+
def mase(y_actual, y_pred, y_train):
|
| 9 |
+
n = y_train.shape[1]
|
| 10 |
+
d = np.abs(np.diff(y_train)).sum() / (n - 1)
|
| 11 |
+
errors = np.abs(y_actual - y_pred)
|
| 12 |
+
return errors.mean() / d
|
| 13 |
+
|
| 14 |
+
def phase_space_distance(y_actual, y_pred):
|
| 15 |
+
return np.sqrt(np.sum(np.square(y_actual - y_pred)))
|
| 16 |
+
|
| 17 |
+
SMAPE = np.mean([smape(yi.reshape(-1), y_hati.reshape(-1)) for yi, y_hati in zip(y, y_hat)])
|
| 18 |
+
MSE = np.mean([mean_squared_error(yi.reshape(-1), y_hati.reshape(-1)) for yi, y_hati in zip(y, y_hat)])
|
| 19 |
+
RMSE = np.mean([np.sqrt(mean_squared_error(yi.reshape(-1), y_hati.reshape(-1))) for yi, y_hati in zip(y, y_hat)])
|
| 20 |
+
MAE = np.mean([mean_absolute_error(yi.reshape(-1), y_hati.reshape(-1)) for yi, y_hati in zip(y, y_hat)])
|
| 21 |
+
R2 = np.mean([r2_score(yi.reshape(-1), y_hati.reshape(-1)) for yi, y_hati in zip(y, y_hat)])
|
| 22 |
+
PSD = np.mean([phase_space_distance(yi.reshape(-1), y_hati.reshape(-1)) for yi, y_hati in zip(y, y_hat)])
|
| 23 |
+
|
| 24 |
+
if y_train is None:
|
| 25 |
+
return SMAPE, MSE, RMSE, MAE, R2, PSD
|
| 26 |
+
else:
|
| 27 |
+
MASE = np.mean([mase(yi, y_hati, yt) for yi, y_hati, yt in zip(y, y_hat, y_train)])
|
| 28 |
+
return SMAPE, MSE, RMSE, MAE, R2, MASE, PSD
|
utils/midpoint.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
class MidPoint():
|
| 3 |
+
|
| 4 |
+
def __init__(self, j):
|
| 5 |
+
self.j = j
|
| 6 |
+
|
| 7 |
+
def initial(self):
|
| 8 |
+
j = self.j
|
| 9 |
+
|
| 10 |
+
#init parameters
|
| 11 |
+
h, b, n, u = 0.1, None, None, None
|
| 12 |
+
x, y, z, px, py, pz = None, 0.1, 0.001, 0.01, None, 0.0001
|
| 13 |
+
xa, ya, za, pxa, pya, pza = None, None, None, None, None, None
|
| 14 |
+
|
| 15 |
+
if j == 1:
|
| 16 |
+
b = 5.0 / 4
|
| 17 |
+
n = b / (1 + b) ** 2
|
| 18 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
| 19 |
+
x = 10.0
|
| 20 |
+
py = 0.5
|
| 21 |
+
elif j == 2:
|
| 22 |
+
b = 3.0 / 4
|
| 23 |
+
n = b / (1 + b) ** 2
|
| 24 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
| 25 |
+
x = 8.3
|
| 26 |
+
py = 0.6
|
| 27 |
+
elif j == 3:
|
| 28 |
+
b = 3.0 / 2
|
| 29 |
+
x = 12.0
|
| 30 |
+
py = 0.4
|
| 31 |
+
elif j == 4:
|
| 32 |
+
b = 7.0 / 4
|
| 33 |
+
n = b / (1 + b) ** 2
|
| 34 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
| 35 |
+
x = 15.0
|
| 36 |
+
py = 0.35
|
| 37 |
+
elif j == 5:
|
| 38 |
+
b = 1.0
|
| 39 |
+
n = b / (1 + b) ** 2
|
| 40 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
| 41 |
+
x = 18.0
|
| 42 |
+
py = 0.3
|
| 43 |
+
elif j == 6:
|
| 44 |
+
b = 3.0 / 5
|
| 45 |
+
n = b / (1 + b) ** 2
|
| 46 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
| 47 |
+
x = 20.0
|
| 48 |
+
py = 0.25
|
| 49 |
+
elif j == 7:
|
| 50 |
+
b = 5.0 / 7
|
| 51 |
+
n = b / (1 + b) ** 2
|
| 52 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
| 53 |
+
x = 22.0
|
| 54 |
+
py = 0.22
|
| 55 |
+
elif j == 8:
|
| 56 |
+
b = 2.0
|
| 57 |
+
x = 26.0
|
| 58 |
+
py = 0.2
|
| 59 |
+
elif j == 9:
|
| 60 |
+
b = 0.5
|
| 61 |
+
n = b / (1 + b) ** 2
|
| 62 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
| 63 |
+
x = 30.0
|
| 64 |
+
y = 0.5
|
| 65 |
+
z = 0.1
|
| 66 |
+
pz = 0.01
|
| 67 |
+
elif j == 10:
|
| 68 |
+
b = 5.0
|
| 69 |
+
n = b / (1 + b) ** 2
|
| 70 |
+
u = 1.0 / (1.0 / b + b + 2.0)
|
| 71 |
+
x = 35.0
|
| 72 |
+
y = 2.0
|
| 73 |
+
z = 0.1
|
| 74 |
+
pz = 0.03
|
| 75 |
+
py = 0.15
|
| 76 |
+
|
| 77 |
+
xa, ya, za, pxa, pya, pza = x, y, z, px, py, pz
|
| 78 |
+
return j, h, b, n, x, y, z, xa, ya, za, px, py, pz, pxa, pya, pza
|
| 79 |
+
|
| 80 |
+
def f(self, x, y, z, px, py, pz, b):
|
| 81 |
+
|
| 82 |
+
n = b / (1 + b) ** 2
|
| 83 |
+
u = 1 / (1 / b + b + 2)
|
| 84 |
+
ht = px ** 2 / 2 + py ** 2 / 2 + pz ** 2 / 2
|
| 85 |
+
hv = -1 / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2)
|
| 86 |
+
h1pn = 1 / (2 * x ** 2 + 2 * y ** 2 + 2 * z ** 2) - (((u + 3) * (px ** 2 + py ** 2 + pz ** 2)) / 2 + (u * (
|
| 87 |
+
(px * x) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2) + (py * y) / (x ** 2 + y ** 2 + z ** 2) ** (
|
| 88 |
+
1 / 2) + (pz * z) / (x ** 2 + y ** 2 + z ** 2) ** (1 / 2)) ** 2) / 2) / (
|
| 89 |
+
x ** 2 + y ** 2 + z ** 2) ** (1 / 2) + ((3 * u) / 8 - 1 / 8) * (
|
| 90 |
+
px ** 2 + py ** 2 + pz ** 2) ** 2
|
| 91 |
+
|
| 92 |
+
e = ht + hv + h1pn
|
| 93 |
+
|
| 94 |
+
vnpx=px
|
| 95 |
+
v1pnpx=4*px*((3*n)/8 - 1/8)*(px**2 + py**2 + pz**2) - (px*(n + 3) + (n*x*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2)))/(x**2 + y**2 + z**2)**(1/2))/(x**2 + y**2 + z**2)**(1/2)
|
| 96 |
+
vpx=vnpx+v1pnpx
|
| 97 |
+
|
| 98 |
+
vnpy=py
|
| 99 |
+
v1pnpy=4*py*((3*n)/8 - 1/8)*(px**2 + py**2 + pz**2) - (py*(n + 3) + (n*y*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2)))/(x**2 + y**2 + z**2)**(1/2))/(x**2 + y**2 + z**2)**(1/2)
|
| 100 |
+
vpy=vnpy+v1pnpy
|
| 101 |
+
|
| 102 |
+
vnpz=pz
|
| 103 |
+
v1pnpz=4*pz*((3*n)/8 - 1/8)*(px**2 + py**2 + pz**2) - (pz*(n + 3) + (n*z*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2)))/(x**2 + y**2 + z**2)**(1/2))/(x**2 + y**2 + z**2)**(1/2)
|
| 104 |
+
vpz=vnpz+v1pnpz
|
| 105 |
+
|
| 106 |
+
vnx=x/(x**2 + y**2 + z**2)**(3/2)
|
| 107 |
+
v1pnx=(x*((n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 +y**2 + z**2)**(1/2))**2)/2 + ((n + 3)*(px**2 + py**2 + pz**2))/2))/(x**2 + y**2 + z**2)**(3/2) -(4*x)/(2*x**2 + 2*y**2 + 2*z**2)**2 + (n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))*((px*x**2)/(x**2 + y**2 + z**2)**(3/2) -px/(x**2 + y**2 + z**2)**(1/2) + (py*x*y)/(x**2 + y**2 + z**2)**(3/2) + (pz*x*z)/(x**2 + y**2 +z**2)**(3/2)))/(x**2 + y**2 + z**2)**(1/2)
|
| 108 |
+
vx=vnx+v1pnx
|
| 109 |
+
|
| 110 |
+
vny=y/(x**2 + y**2 + z**2)**(3/2)
|
| 111 |
+
v1pny=(y*((n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 +y**2 + z**2)**(1/2))**2)/2 + ((n + 3)*(px**2 + py**2 + pz**2))/2))/(x**2 + y**2 + z**2)**(3/2) -(4*y)/(2*x**2 + 2*y**2 + 2*z**2)**2 + (n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))*((py*y**2)/(x**2 + y**2 + z**2)**(3/2) -py/(x**2 + y**2 + z**2)**(1/2) + (px*x*y)/(x**2 + y**2 + z**2)**(3/2) + (pz*y*z)/(x**2 + y**2 +z**2)**(3/2)))/(x**2 + y**2 + z**2)**(1/2)
|
| 112 |
+
vy=vny+v1pny
|
| 113 |
+
|
| 114 |
+
vnz=z/(x**2 + y**2 + z**2)**(3/2)
|
| 115 |
+
v1pnz=(z*((n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 +y**2 + z**2)**(1/2))**2)/2 + ((n + 3)*(px**2 + py**2 + pz**2))/2))/(x**2 + y**2 + z**2)**(3/2) -(4*z)/(2*x**2 + 2*y**2 + 2*z**2)**2 + (n*((px*x)/(x**2 + y**2 + z**2)**(1/2) + (py*y)/(x**2 + y**2 + z**2)**(1/2) + (pz*z)/(x**2 + y**2 + z**2)**(1/2))*((pz*z**2)/(x**2 + y**2 + z**2)**(3/2) -pz/(x**2 + y**2 + z**2)**(1/2) + (px*x*z)/(x**2 + y**2 + z**2)**(3/2) + (py*y*z)/(x**2 + y**2 +z**2)**(3/2)))/(x**2 + y**2 + z**2)**(1/2)
|
| 116 |
+
vz=vnz+v1pnz
|
| 117 |
+
|
| 118 |
+
return vx,vy,vz,vpx,vpy,vpz,e
|
| 119 |
+
|
| 120 |
+
def rejust(self, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza):
|
| 121 |
+
|
| 122 |
+
x = (x + xa) / 2
|
| 123 |
+
y = (y + ya) / 2
|
| 124 |
+
z = (z + za) / 2
|
| 125 |
+
|
| 126 |
+
px = (px + pxa) / 2
|
| 127 |
+
py = (py + pya) / 2
|
| 128 |
+
pz = (pz + pza) / 2
|
| 129 |
+
xa = x
|
| 130 |
+
ya = y
|
| 131 |
+
za = z
|
| 132 |
+
pxa = px
|
| 133 |
+
pya = py
|
| 134 |
+
pza = pz
|
| 135 |
+
|
| 136 |
+
return x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza
|
| 137 |
+
|
| 138 |
+
def symplectic(self, h, x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza, b):
|
| 139 |
+
|
| 140 |
+
vxa, vya, vza, vpx, vpy, vpz, e = self.f(xa, ya, za, px, py, pz, b)
|
| 141 |
+
x = x + h / 2 * vpx
|
| 142 |
+
y = y + h / 2 * vpy
|
| 143 |
+
z = z + h / 2 * vpz
|
| 144 |
+
pxa = pxa - h / 2 * vxa
|
| 145 |
+
pya = pya - h / 2 * vya
|
| 146 |
+
pza = pza - h / 2 * vza
|
| 147 |
+
|
| 148 |
+
vx, vy, vz, vpxa, vpya, vpza, e = self.f(x, y, z, pxa, pya, pza, b)
|
| 149 |
+
xa = xa + h * vpxa
|
| 150 |
+
ya = ya + h * vpya
|
| 151 |
+
za = za + h * vpza
|
| 152 |
+
px = px - h * vx
|
| 153 |
+
py = py - h * vy
|
| 154 |
+
pz = pz - h * vz
|
| 155 |
+
|
| 156 |
+
vxa, vya, vza, vpx, vpy, vpz, e = self.f(xa, ya, za, px, py, pz, b)
|
| 157 |
+
x = x + h / 2 * vpx
|
| 158 |
+
y = y + h / 2 * vpy
|
| 159 |
+
z = z + h / 2 * vpz
|
| 160 |
+
pxa = pxa - h / 2 * vxa
|
| 161 |
+
pya = pya - h / 2 * vya
|
| 162 |
+
pza = pza - h / 2 * vza
|
| 163 |
+
|
| 164 |
+
return x, y, z, px, py, pz, xa, ya, za, pxa, pya, pza
|
utils/transform.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
def compute_gradient(x, degree):
|
| 4 |
+
gradients = [x]
|
| 5 |
+
for i in range(degree):
|
| 6 |
+
x = torch.diff(x, dim=-1, prepend=x[..., 0:1])
|
| 7 |
+
gradients.append(x)
|
| 8 |
+
return torch.concatenate(gradients, dim=-1)
|