sesame-csm-space / generator.py
karagmercola's picture
Update generator.py
1f1e573 verified
raw
history blame
643 Bytes
def load_csm_1b(device=None, model_dir=None):
    """Load the CSM 1B model and processor and wrap them in a ``CSMGenerator``.

    Args:
        device: Device to move the model to. When ``None``, ``"cuda"`` is
            chosen if ``torch.cuda.is_available()`` is true, else ``"cpu"``.
        model_dir: Source passed to ``from_pretrained`` for both the
            processor and the model. When ``None`` (the default),
            ``/opt/ml/model`` is used if it is an existing directory;
            otherwise the Hub repo id ``karagmercola/sesame-csm-1b`` is used.

    Returns:
        ``CSMGenerator(model, processor, device)``.

    Raises:
        Exception: Any exception raised while loading, moving, or wrapping
            the model is re-raised unchanged after an error message is
            printed.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading CSM 1B model on {device}...")

    if model_dir is None:
        model_dir = "/opt/ml/model"
        # from_pretrained needs a directory here; a plain file at this path
        # must not shadow the Hub fallback (os.path.exists would accept it).
        if not os.path.isdir(model_dir):
            model_dir = "karagmercola/sesame-csm-1b"

    try:
        processor = AutoProcessor.from_pretrained(model_dir)
        model = CsmModel.from_pretrained(model_dir)
        model.to(device)
        print(f"CSM 1B model loaded successfully on {device}")
        return CSMGenerator(model, processor, device)
    except Exception as e:
        # Report the failure, then propagate the original exception untouched.
        print(f"Error loading CSM 1B model: {e}")
        raise