Spaces:
Sleeping
Sleeping
Switch Space trainer defaults to math_conjecture_sota profile and remove DeepSeek references
9a4f619 verified | #!/usr/bin/env python3 | |
| """Production preflight checks for the Math Conjecture Trainer Space.""" | |
| from __future__ import annotations | |
| import argparse | |
| import importlib | |
| import json | |
| import os | |
| import subprocess | |
| import sys | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any, Callable, Dict, List | |
| import yaml | |
# Repository root: the parent of the directory that contains this file.
ROOT = Path(__file__).resolve().parents[1]

# Training/eval profile that the preflight checks validate.
CONFIG_PATH = ROOT.joinpath("configs", "math_conjecture_sota.yaml")

# Workspace-local Hugging Face cache directories; used as fallback values for
# the cache environment variables passed to the dry-run subprocess.
HF_HOME_DIR = ROOT.joinpath("workspace", ".hf_home")
HF_DATASETS_CACHE_DIR = HF_HOME_DIR.joinpath("datasets")
HF_HUB_CACHE_DIR = HF_HOME_DIR.joinpath("hub")
@dataclass
class CheckResult:
    """Outcome of a single preflight check.

    The ``@dataclass`` decorator is required: it generates the ``__init__``
    that accepts ``name``, ``ok`` and ``detail``, which is how results are
    constructed in ``run_checks``. Without it the class has only bare
    annotations and ``CheckResult(name=..., ...)`` raises ``TypeError``.
    """

    # Identifier of the check, e.g. "required_files".
    name: str
    # True when the check function returned without raising.
    ok: bool
    # The check's success message, or "<ExceptionType>: <message>" on failure.
    detail: str
def check_required_files() -> str:
    """Verify that the files the Space depends on exist under ``ROOT``.

    Returns:
        A summary string stating how many files were checked.

    Raises:
        FileNotFoundError: If any file is absent; the message lists every
            missing path, not just the first one.
    """
    scripts_dir = ROOT / "scripts"
    expected = (
        ROOT / "app.py",
        scripts_dir / "train_sota.py",
        scripts_dir / "eval_sota.py",
        CONFIG_PATH,
        ROOT / "requirements.txt",
    )

    absent = []
    for candidate in expected:
        if not candidate.exists():
            absent.append(str(candidate))

    if absent:
        raise FileNotFoundError("Missing required files: " + ", ".join(absent))
    return f"{len(expected)} required files present."
def check_config_shape() -> str:
    """Parse the YAML config at ``CONFIG_PATH`` and validate its top-level shape.

    The root must be a mapping containing the ``model``, ``data`` and
    ``stages`` sections, and ``stages`` must be a non-empty list.

    Returns:
        A summary string with the number of stages found.

    Raises:
        ValueError: If the root is not a mapping, a required section is
            missing, or ``stages`` is not a non-empty list.
    """
    raw_text = CONFIG_PATH.read_text(encoding="utf-8")
    config = yaml.safe_load(raw_text)
    if not isinstance(config, dict):
        raise ValueError("Config root must be a mapping.")

    # Sections are checked in this order, so the first missing one is reported.
    for key in ("model", "data", "stages"):
        if key not in config:
            raise ValueError(f"Missing config section: {key}")

    stage_list = config["stages"]
    if not (isinstance(stage_list, list) and stage_list):
        raise ValueError("Config must contain at least one stage.")
    return f"Config valid with {len(stage_list)} stage(s)."
def check_python_imports() -> str:
    """Import each required third-party package and report its version.

    Packages are imported in a fixed order; the first one that cannot be
    imported aborts the check by letting the import error propagate.

    Returns:
        ``"Imports OK: "`` followed by comma-separated ``name=version``
        pairs, where the version is the module's ``__version__`` attribute
        or ``"unknown"`` when the module does not define one.
    """
    required = (
        "gradio",
        "torch",
        "yaml",
        "huggingface_hub",
        "datasets",
        "transformers",
        "peft",
    )
    report = []
    for package in required:
        module = importlib.import_module(package)
        version = str(getattr(module, "__version__", "unknown"))
        report.append(f"{package}={version}")
    return "Imports OK: " + ", ".join(report)
def check_module_integrity() -> str:
    """Import the app and the train/eval scripts and smoke-test their entry points.

    Calls ``app.run_runtime_snapshot()`` and each script's ``load_config``
    with ``CONFIG_PATH``, and checks the types of what they return.

    Returns:
        A fixed success message.

    Raises:
        ValueError: If the runtime snapshot is not a dict, lacks the
            ``python`` or ``torch`` keys, or either config loader returns
            something other than a dict.
    """
    # Put the repository root on sys.path so "app" and "scripts.*" resolve.
    repo_root = str(ROOT)
    if repo_root not in sys.path:
        sys.path.insert(0, repo_root)

    app_module = importlib.import_module("app")
    trainer = importlib.import_module("scripts.train_sota")
    evaluator = importlib.import_module("scripts.eval_sota")

    snapshot = app_module.run_runtime_snapshot()
    if not isinstance(snapshot, dict):
        raise ValueError("Runtime snapshot is not a dictionary.")
    if not ("python" in snapshot and "torch" in snapshot):
        raise ValueError("Runtime snapshot missing expected keys.")

    # Both loaders run (train first, then eval) before either result is inspected.
    loaded = (trainer.load_config(CONFIG_PATH), evaluator.load_config(CONFIG_PATH))
    if not all(isinstance(cfg, dict) for cfg in loaded):
        raise ValueError("Config loaders did not return dictionaries.")
    return "App/train/eval module imports and config loaders are healthy."
def run_optional_training_dry_run(timeout_seconds: int) -> str:
    """Run ``scripts/train_sota.py --dry-run`` for stage 1 in a subprocess.

    The Hugging Face cache directories are created first. The cache
    environment variables are set for the child only when the caller's
    environment does not already define them.

    Args:
        timeout_seconds: Timeout passed to ``subprocess.run``; if it expires,
            ``subprocess.TimeoutExpired`` propagates to the caller.

    Returns:
        A fixed success message when the child exits with code 0.

    Raises:
        RuntimeError: If the child exits non-zero; the message carries the
            exit code and the last 30 lines of its combined stdout/stderr.
    """
    for cache_dir in (HF_HOME_DIR, HF_DATASETS_CACHE_DIR, HF_HUB_CACHE_DIR):
        cache_dir.mkdir(parents=True, exist_ok=True)

    child_env = os.environ.copy()
    cache_defaults = (
        ("HF_HOME", HF_HOME_DIR),
        ("HF_DATASETS_CACHE", HF_DATASETS_CACHE_DIR),
        ("HUGGINGFACE_HUB_CACHE", HF_HUB_CACHE_DIR),
    )
    for variable, location in cache_defaults:
        # setdefault keeps any value the caller exported.
        child_env.setdefault(variable, str(location))

    trainer_script = ROOT / "scripts" / "train_sota.py"
    command = [
        sys.executable,
        str(trainer_script),
        "--config",
        str(CONFIG_PATH),
        "--start-stage",
        "1",
        "--max-stages",
        "1",
        "--dry-run",
    ]
    result = subprocess.run(
        command,
        cwd=str(ROOT),
        check=False,
        env=child_env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr into stdout so the tail shows both
        text=True,
        timeout=timeout_seconds,
    )

    if result.returncode == 0:
        return "Optional training dry-run passed."

    output_lines = (result.stdout or "").splitlines()
    tail = "\n".join(output_lines[-30:])
    raise RuntimeError(f"Dry-run failed with exit code {result.returncode}.\n{tail}")
def run_checks(checks: List[tuple[str, Callable[[], str]]]) -> List[CheckResult]:
    """Run each named check and collect one ``CheckResult`` per check.

    A check that raises any ``Exception`` is recorded as failed rather than
    aborting the run, so later checks still execute.

    Args:
        checks: ``(name, callable)`` pairs; each callable takes no arguments
            and returns a detail string on success.

    Returns:
        Results in the same order as ``checks``. Failed results carry
        ``"<ExceptionType>: <message>"`` as their detail.
    """

    def _evaluate(label: str, check_fn: Callable[[], str]) -> CheckResult:
        # Run a single check and convert its outcome into a CheckResult.
        try:
            message = check_fn()
            return CheckResult(name=label, ok=True, detail=message)
        except Exception as exc:
            return CheckResult(name=label, ok=False, detail=f"{type(exc).__name__}: {exc}")

    return [_evaluate(label, check_fn) for label, check_fn in checks]
def parse_args() -> argparse.Namespace:
    """Parse the preflight command-line options from ``sys.argv``.

    Returns:
        A namespace with ``run_training_dry_run`` (bool, default False),
        ``dry_run_timeout_seconds`` (int, default 1800) and ``json``
        (bool, default False).
    """
    cli = argparse.ArgumentParser(description="Run production preflight checks for the Space trainer.")

    # One (flag, add_argument keyword arguments) entry per supported option.
    option_table = (
        (
            "--run-training-dry-run",
            {
                "action": "store_true",
                "help": "Also execute scripts/train_sota.py in --dry-run mode (stage 1 only).",
            },
        ),
        (
            "--dry-run-timeout-seconds",
            {
                "type": int,
                "default": 1800,
                "help": "Timeout for optional training dry-run step.",
            },
        ),
        (
            "--json",
            {
                "action": "store_true",
                "help": "Print machine-readable JSON output.",
            },
        ),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()
def main() -> None:
    """Run the preflight checks, print a report, and exit non-zero on failure.

    The four static checks always run; the training dry-run is appended only
    when ``--run-training-dry-run`` is given. Output is JSON with ``--json``,
    otherwise one ``[PASS]``/``[FAIL]`` line per check plus an overall line.

    Raises:
        SystemExit: With code 1 when at least one check failed.
    """
    args = parse_args()

    checks: List[tuple[str, Callable[[], str]]] = [
        ("required_files", check_required_files),
        ("config_shape", check_config_shape),
        ("python_imports", check_python_imports),
        ("module_integrity", check_module_integrity),
    ]
    if args.run_training_dry_run:
        # Enforce a 30-second floor on the user-supplied timeout.
        timeout = max(30, args.dry_run_timeout_seconds)
        checks.append(
            ("training_dry_run", lambda: run_optional_training_dry_run(timeout_seconds=timeout))
        )

    results = run_checks(checks)
    all_passed = all(entry.ok for entry in results)

    def verdict(flag: bool) -> str:
        # Map a pass/fail flag to the label used in the text report.
        return "PASS" if flag else "FAIL"

    if args.json:
        report: Dict[str, Any] = {
            "ok": all_passed,
            "checks": [
                {"name": entry.name, "ok": entry.ok, "detail": entry.detail}
                for entry in results
            ],
        }
        print(json.dumps(report, ensure_ascii=True, indent=2))
    else:
        for entry in results:
            print(f"[{verdict(entry.ok)}] {entry.name}: {entry.detail}")
        print("Overall:", verdict(all_passed))

    if not all_passed:
        raise SystemExit(1)
# Run the preflight checks only when executed as a script, not when imported.
if __name__ == "__main__":
    main()