Spaces:
Runtime error
Runtime error
| # Copyright 2024-present, David Berenstein, Inc. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import io | |
| import os | |
| import random | |
| import time | |
| import requests | |
| from PIL import Image | |
| from dataset_viber import AnnotatorInterFace | |
# Hugging Face access token, read from the environment. Indexing os.environ
# raises KeyError at import time when HF_TOKEN is not set.
HF_TOKEN = os.environ["HF_TOKEN"]
# Bearer-token header sent with every datasets-server and inference request.
HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}
# Base URL of the Hugging Face datasets-server API (/rows, /size endpoints).
DATASET_SERVER_URL = "https://datasets-server.huggingface.co"
# Query-string fragment interpolated right after ``dataset=``: the URL-encoded
# dataset id followed by the config and split parameters.
DATASET_NAME = "poloclub%2Fdiffusiondb&config=2m_random_1k&split=train"
# Hosted inference endpoint for the FLUX.1-schnell image-generation model.
MODEL_URL = (
    "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
)
def retrieve_sample(idx, timeout=30):
    """Fetch a single dataset row from the datasets-server ``/rows`` API.

    Args:
        idx: Zero-based row offset to request (``length=1``).
        timeout: Seconds to wait for the HTTP response before ``requests``
            raises a timeout error.

    Returns:
        A ``(img_url, prompt)`` tuple: the ``src`` of the row's ``image``
        column and the row's ``prompt`` text.

    Raises:
        requests.HTTPError: If the server replies with an error status.
    """
    api_url = (
        f"{DATASET_SERVER_URL}/rows?dataset={DATASET_NAME}"
        f"&offset={idx}&length=1"
    )
    response = requests.get(api_url, headers=HEADERS, timeout=timeout)
    # Report HTTP failures directly instead of a confusing KeyError on "rows".
    response.raise_for_status()
    row = response.json()["rows"][0]["row"]
    return row["image"]["src"], row["prompt"]
def get_rows(timeout=30):
    """Return the row count reported by the datasets-server ``/size`` API.

    Args:
        timeout: Seconds to wait for the HTTP response before ``requests``
            raises a timeout error.

    Returns:
        The ``size.config.num_rows`` value from the JSON response.

    Raises:
        requests.HTTPError: If the server replies with an error status.
    """
    api_url = f"{DATASET_SERVER_URL}/size?dataset={DATASET_NAME}"
    response = requests.get(api_url, headers=HEADERS, timeout=timeout)
    # Report HTTP failures directly instead of a confusing KeyError on "size".
    response.raise_for_status()
    # NOTE(review): this is the count for the whole config; if the config has
    # splits other than the one named in DATASET_NAME, offsets near the top of
    # this range may not exist in that split — confirm.
    return response.json()["size"]["config"]["num_rows"]
def generate_response(prompt, max_retries=None, retry_delay=10):
    """Generate an image for ``prompt`` with the hosted inference model.

    POSTs ``{"inputs": prompt}`` to ``MODEL_URL``. Whenever the reply is not
    HTTP 200, sleeps ``retry_delay`` seconds and tries again. Retries run in
    a loop rather than by recursion, so repeated failures do not grow the
    call stack.

    Args:
        prompt: Text prompt sent as the ``inputs`` field of the payload.
        max_retries: Maximum number of retries after the first failed
            attempt. ``None`` (the default) retries indefinitely.
        retry_delay: Seconds to sleep between attempts.

    Returns:
        A ``PIL.Image.Image`` decoded from the successful response body.

    Raises:
        RuntimeError: If ``max_retries`` is set and every attempt returned a
            non-200 status.
    """
    payload = {"inputs": prompt}
    attempt = 0
    while True:
        response = requests.post(MODEL_URL, headers=HEADERS, json=payload)
        if response.status_code == 200:
            break
        if max_retries is not None and attempt >= max_retries:
            raise RuntimeError(
                f"Model request failed after {attempt + 1} attempts "
                f"(last status {response.status_code})"
            )
        attempt += 1
        time.sleep(retry_delay)
    # The successful body is decoded as image bytes by PIL.
    return Image.open(io.BytesIO(response.content))
def next_input(_prompt, _completion_a, _completion_b):
    """Produce the next sample to annotate.

    The three arguments are ignored (the leading underscores mark them
    unused). Presumably the annotator interface passes the current prompt
    and completions; verify against ``AnnotatorInterFace``.

    Returns:
        ``(prompt, img_url, generated_image)``: a randomly chosen dataset
        prompt, the URL of that row's image, and a newly generated image for
        the same prompt.
    """
    # randrange(n) draws from 0..n-1, exactly the valid row offsets;
    # randint(0, n) - 1 could yield -1, which is not a valid offset.
    random_idx = random.randrange(get_rows())
    img_url, prompt = retrieve_sample(random_idx)
    generated_image = generate_response(prompt)
    return (prompt, img_url, generated_image)
if __name__ == "__main__":
    # Build a non-interactive image-generation preference annotator whose
    # samples come from next_input, then start serving it.
    annotator = AnnotatorInterFace.for_image_generation_preference(
        fn_next_input=next_input,
        interactive=False,
    )
    annotator.launch()