"""Export a pretrained STR model checkpoint to an ONNX file by tracing it."""
import torch

from utils.model import STRModel

# Prefix stripped from checkpoint keys before loading. Presumably present
# because the checkpoint was saved from a DataParallel-wrapped model -- confirm
# against how the checkpoint was produced.
_WRAPPER_PREFIX = "module."

# Create PyTorch model object.
model = STRModel(input_channels=1, output_channels=512, num_classes=37)

# Load model weights from an external file onto the CPU.
# NOTE(review): torch.load unpickles the file; only use checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
state = torch.load("None-ResNet-None-CTC.pth", map_location=torch.device('cpu'))

# Strip the wrapper prefix only where it actually leads the key. A plain
# str.replace would also mangle keys that merely contain "module." further in
# (e.g. "encoder.submodule.weight").
state = {
    (key[len(_WRAPPER_PREFIX):] if key.startswith(_WRAPPER_PREFIX) else key): value
    for key, value in state.items()
}
model.load_state_dict(state)

# Switch to inference mode explicitly so any train/eval-dependent layers the
# model may contain are traced with their inference behaviour, regardless of
# the exporter's defaults.
model.eval()

# Create the ONNX file by tracing the model with a random dummy input.
# Presumably laid out as (batch, channels, height, width); the channel
# dimension of 1 matches input_channels=1 above.
trace_input = torch.randn(1, 1, 32, 100)
torch.onnx.export(model, trace_input, "str.onnx", verbose=True)