| import io, requests | |
| import torch | |
| import torch.nn as nn | |
| from dall_e.encoder import Encoder | |
| from dall_e.decoder import Decoder | |
| from dall_e.utils import map_pixels, unmap_pixels | |
def load_model(
    path: str,
    device: "torch.device | str | None" = None,
    *,
    timeout: "float | tuple[float, float] | None" = 60.0,
) -> nn.Module:
    """Load a ``torch.save``-serialized object from a local path or an HTTP(S) URL.

    Args:
        path: Local filesystem path, or a URL starting with ``http://`` or
            ``https://``. A URL's response body is read fully into memory
            before deserialization.
        device: Forwarded unchanged to ``torch.load`` as ``map_location``.
        timeout: Forwarded to ``requests.get`` for URL paths. It is either
            seconds applied to connect/read, or a ``(connect, read)`` tuple.
            ``None`` waits indefinitely, which was the previous behavior.
            It is ignored for local paths.

    Returns:
        The object deserialized by ``torch.load``. It is annotated as an
        ``nn.Module``, but the type is not checked here.

    Raises:
        requests.HTTPError: if the URL download returns an error status.
        requests.Timeout: if the server does not respond within ``timeout``.
        OSError: if a local ``path`` cannot be opened.
    """
    # NOTE(review): torch.load unpickles arbitrary Python objects. Only load
    # checkpoints from trusted sources, especially when ``path`` is a remote
    # URL. weights_only is deliberately not forced here, because this helper
    # appears to load whole pickled modules; confirm against callers.
    if path.startswith(("http://", "https://")):
        # Without a timeout, requests can block forever on a stalled server.
        resp = requests.get(path, timeout=timeout)
        resp.raise_for_status()
        with io.BytesIO(resp.content) as buf:
            return torch.load(buf, map_location=device)

    with open(path, "rb") as f:
        return torch.load(f, map_location=device)