CodeNight-KI/text2image/text2image.py

135 lines
5.0 KiB
Python

import asyncio
import json
import os
import random
import urllib.parse
from typing import Optional

import requests
import websocket
from pydantic import BaseModel
class Nodes(BaseModel):
    """IDs of the ComfyUI workflow nodes that receive each request value.

    Each ID is turned into a string and used as a key into the parsed workflow
    JSON. The value is then written into that node's ``inputs`` mapping.
    The optional IDs may be left as ``None``, in which case the corresponding
    value is not patched into the workflow.
    """

    # Node whose inputs["text"] receives the positive prompt.
    prompt: int
    # Node whose inputs["width"] receives the requested width.
    width: int
    # Node whose inputs["height"] receives the requested height.
    height: int
    # Node whose inputs["text"] receives the negative prompt (optional).
    negative_prompt: Optional[int] = None
    # Node whose inputs["seed"] receives the seed (optional).
    seed: Optional[int] = None
    # Node whose inputs["steps"] receives the step count (optional).
    steps: Optional[int] = None
class ComfyUIGenerateImageForm(BaseModel):
    """Request payload for a single ComfyUI text-to-image generation.

    ``workflow`` carries the workflow definition as a JSON string. ``nodes``
    names the workflow nodes into which the remaining values are written.
    Optional values left as ``None`` are not written, so the workflow keeps
    its own values for them.
    """

    # Workflow definition as a JSON string (decoded with json.loads before use).
    workflow: str
    # Positive prompt text.
    prompt: str
    # Mapping from each value below to the ID of the workflow node it targets.
    nodes: Nodes
    # Negative prompt text (optional).
    negative_prompt: Optional[str] = None
    # Value written to the "width" input of nodes.width (presumably pixels).
    width: int
    # Value written to the "height" input of nodes.height (presumably pixels).
    height: int
    # Value written to the "steps" input of nodes.steps (optional).
    steps: Optional[int] = None
    # Value written to the "seed" input of nodes.seed (optional).
    seed: Optional[int] = None
class GenerateImage:
    """Generate one image on a ComfyUI server and save it to disk.

    Construction does all the work:
    - the payload's workflow is patched with the requested values;
    - it is queued on the server;
    - completion is awaited over the ComfyUI WebSocket;
    - the first output image is downloaded to
      ``image_file_folder/image_file_name``.

    NOTE: ``__init__`` calls ``asyncio.run``, which raises ``RuntimeError``
    when called from a thread whose event loop is already running. Do not
    instantiate this class from inside async code.
    """

    def __init__(
        self,
        payload,
        client_id,
        base_url,
        image_file_folder,
        image_file_name,
        api_key=None,
        timeout=30,
    ):
        """Store the settings and immediately run the generation.

        Args:
            payload: ``ComfyUIGenerateImageForm``-like object (workflow JSON
                string, prompt, sizes, optional seed/steps/negative prompt and
                the ``nodes`` ID mapping).
            client_id: ComfyUI client ID, sent both in the WebSocket URL and
                with the queued prompt.
            base_url: HTTP(S) root URL of the ComfyUI server. A trailing
                slash is stripped so endpoint URLs do not contain ``//``.
            image_file_folder: Directory the image is saved in (created if
                missing).
            image_file_name: File name of the saved image.
            api_key: Optional bearer token. When falsy, no ``Authorization``
                header is sent (previously an empty, malformed ``Bearer ``
                header was always sent).
            timeout: Timeout in seconds for each HTTP request, so a stalled
                server cannot hang the caller forever.
        """
        self.payload = payload
        self.client_id = client_id
        self.base_url = base_url.rstrip("/")
        self.image_file_folder = image_file_folder
        self.image_file_name = image_file_name
        self.api_key = api_key
        self.timeout = timeout
        asyncio.run(self.__generate())

    def _auth_headers(self):
        """Return the auth headers for HTTP and WebSocket calls (empty without an API key)."""
        if self.api_key:
            return {"Authorization": f"Bearer {self.api_key}"}
        return {}

    def save_image(self, image_url):
        """Download *image_url* and write it to ``image_file_folder/image_file_name``.

        Prints the saved path on success. If the server does not answer with
        200, prints the HTTP status code instead.
        """
        # The ``with`` block closes the streamed response so its connection is released.
        with requests.get(
            image_url,
            stream=True,
            headers=self._auth_headers(),
            timeout=self.timeout,
        ) as response:
            if response.status_code != 200:
                print(f"Fehler beim Download des Bildes: {response.status_code}")
                return
            os.makedirs(self.image_file_folder, exist_ok=True)  # create the folder if it does not exist
            file_path = os.path.join(self.image_file_folder, self.image_file_name)
            with open(file_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)
        print(f"Bild gespeichert unter: {file_path}")

    def get_image_url(self, filename, subfolder, img_type):
        """Build the ``/view`` URL for one output image listed in the history."""
        # Percent-encode every value: file names may contain spaces, '&', '#', '/', ...
        query = urllib.parse.urlencode(
            {"filename": filename, "subfolder": subfolder, "type": img_type},
            quote_via=urllib.parse.quote,
        )
        return f"{self.base_url}/view?{query}"

    def queue_prompt(self):
        """POST ``self.workflow`` to ``/prompt`` and return the decoded JSON reply.

        Raises:
            requests.HTTPError: if the server answers with an HTTP error status.
        """
        response = requests.post(
            f"{self.base_url}/prompt",
            json={"prompt": self.workflow, "client_id": self.client_id},
            headers=self._auth_headers(),
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

    def get_history(self, prompt_id):
        """GET ``/history/<prompt_id>`` and return the decoded JSON reply.

        Raises:
            requests.HTTPError: if the server answers with an HTTP error status.
        """
        response = requests.get(
            f"{self.base_url}/history/{prompt_id}",
            headers=self._auth_headers(),
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

    def get_images(self, ws):
        """Queue the workflow, wait for it to finish and collect its image URLs.

        Blocks on ``ws.recv()`` until an ``executing`` message with
        ``node is None`` arrives for this prompt, which marks the end of
        execution. Returns ``{"data": [{"url": ...}, ...]}`` with one entry
        per image in the prompt's history outputs.
        """
        prompt_id = self.queue_prompt()["prompt_id"]
        while True:
            out = ws.recv()
            if not isinstance(out, str):
                continue  # binary frames are preview images
            message = json.loads(out)
            if message["type"] != "executing":
                continue
            data = message["data"]
            if data["node"] is None and data["prompt_id"] == prompt_id:
                break  # execution of our prompt is done
        history = self.get_history(prompt_id)[prompt_id]
        output_images = [
            {"url": self.get_image_url(image["filename"], image["subfolder"], image["type"])}
            for node_output in history["outputs"].values()
            for image in node_output.get("images", [])
        ]
        return {"data": output_images}

    def _build_workflow(self):
        """Decode ``payload.workflow`` and write the requested values into its nodes.

        Prompt, width and height are always written. Seed, steps and the
        negative prompt are written only when both the value and the target
        node ID are not ``None``. This means ``seed=0`` or node ID ``0`` are
        honoured; the old truthiness checks silently dropped them.
        """
        payload = self.payload
        nodes = payload.nodes
        workflow = json.loads(payload.workflow)

        def set_input(node_id, key, value):
            # The workflow JSON uses string node IDs as keys.
            workflow[str(node_id)]["inputs"][key] = value

        set_input(nodes.prompt, "text", payload.prompt)
        set_input(nodes.width, "width", payload.width)
        set_input(nodes.height, "height", payload.height)
        # Same order as before, so overlapping node IDs resolve identically.
        optional_inputs = (
            (nodes.seed, "seed", payload.seed),
            (nodes.steps, "steps", payload.steps),
            (nodes.negative_prompt, "text", payload.negative_prompt),
        )
        for node_id, key, value in optional_inputs:
            if node_id is not None and value is not None:
                set_input(node_id, key, value)
        return workflow

    async def comfyui_generate_image(self):
        """Run the payload's workflow on the server and return its image URLs.

        Returns ``{"data": [{"url": ...}, ...]}`` on success. Returns ``None``
        if the WebSocket connection or the generation fails; the error is
        printed instead of being silently discarded.

        Raises:
            json.JSONDecodeError / KeyError: if the workflow JSON is invalid
                or does not contain a referenced node.
        """
        self.workflow = self._build_workflow()
        ws_url = self.base_url.replace("http://", "ws://").replace("https://", "wss://")
        try:
            ws = websocket.WebSocket()
            ws.connect(f"{ws_url}/ws?clientId={self.client_id}", header=self._auth_headers())
        except Exception as e:
            print(f"Fehler bei der Verbindung zum ComfyUI-WebSocket: {e}")
            return None
        try:
            # get_images blocks on ws.recv(), so run it in a worker thread.
            return await asyncio.to_thread(self.get_images, ws)
        except Exception as e:
            print(f"Fehler bei der Bildgenerierung: {e}")
            return None
        finally:
            ws.close()

    async def __generate(self):
        """Generate the images and save the first one, if any was produced."""
        images = await self.comfyui_generate_image()
        if images and images.get("data"):
            self.save_image(images["data"][0]["url"])
        else:
            print("Kein Bild erhalten.")
def load_workflow(file_path):
    """Read a workflow file and return its complete contents as a string.

    The file is decoded as UTF-8 text. The returned string is suitable for
    the ``workflow`` field of ``ComfyUIGenerateImageForm``, which is later
    decoded with ``json.loads``.
    """
    with open(file_path, encoding="utf-8") as handle:
        contents = handle.read()
    return contents