180 lines
7.2 KiB
Python
180 lines
7.2 KiB
Python
import io
|
|
import json
|
|
import uuid
|
|
import requests
|
|
import websocket
|
|
from typing import Literal
|
|
|
|
def open_websocket_connection(server_address):
    """Open a websocket to ``ws://<server_address>/ws`` under a freshly generated client id.

    Returns:
        A ``(ws, client_id)`` tuple, where ``client_id`` is a random UUID4 string
        that was sent as the ``clientId`` query parameter. If anything raises,
        the error is printed and ``None`` is returned instead.
    """
    try:
        new_client_id = str(uuid.uuid4())
        endpoint = f"ws://{server_address}/ws?clientId={new_client_id}"
        connection = websocket.WebSocket()
        connection.connect(endpoint)
    except Exception as err:
        print(f"Error open websocket connection: {err}")
        return None
    return connection, new_client_id
|
|
|
|
def get_embeddings(server_address):
    """Return the decoded JSON body of ``GET /embeddings``."""
    embeddings_url = f"http://{server_address}/embeddings"
    return requests.get(embeddings_url).json()
|
|
|
|
def get_models_types(server_address):
    """Return the decoded JSON body of ``GET /models``."""
    reply = requests.get(f"http://{server_address}/models")
    decoded = reply.json()
    return decoded
|
|
|
|
def get_models_folder(server_address, folder):
    """Return the decoded JSON body of ``GET /models/<folder>``."""
    folder_url = f"http://{server_address}/models/{folder}"
    return requests.get(folder_url).json()
|
|
|
|
def get_extensions(server_address):
    """Return the decoded JSON body of ``GET /extensions``."""
    extensions_url = f"http://{server_address}/extensions"
    reply = requests.get(extensions_url)
    return reply.json()
|
|
|
|
def upload_image_file(server_address, file, subfolder=None, overwrite=False,
                      type: Literal["input", "temp", "output"] = "input"):
    """Upload an image through ``POST /upload/image`` and return its server-side path.

    Args:
        server_address: ``host:port`` of the server; ``http://`` is prepended.
        file: value sent as the multipart ``image`` field (anything accepted
            by ``requests``' ``files=`` argument).
        subfolder: optional destination subfolder (``None`` = not specified).
        overwrite: whether an existing file with the same name may be replaced.
        type: sent as the ``type`` form field ("input", "temp" or "output").

    Returns:
        ``"<subfolder>/<name>"`` built from the JSON reply, or just ``"<name>"``
        when the reply's ``subfolder`` is empty. On any error the message is
        printed and ``None`` is returned.
    """
    try:
        files = {"image": file}
        # requests would encode a Python bool as "True"/"False". ComfyUI's upload
        # handler compares the field against "true"/"1", so "True" was ignored.
        # NOTE(review): confirm against the server version in use.
        data = {
            "type": type,
            "subfolder": subfolder,
            "overwrite": "true" if overwrite else "false",
        }
        response = requests.post(f"http://{server_address}/upload/image", files=files, data=data)
        response.raise_for_status()
        file_info = response.json()
        file_path = file_info["name"]
        if file_info.get("subfolder"):
            file_path = f"{file_info['subfolder']}/{file_path}"
        return file_path
    except Exception as e:
        # Best-effort helper: report and signal failure with None.
        print(f"Error upload image file: {e}")
        return None
|
|
|
|
def upload_mask_file(server_address, file, subfolder=None, overwrite=False,
                     type: Literal["input", "temp", "output"] = "input", original_ref=None):
    """Upload a mask through ``POST /upload/mask`` and return its server-side path.

    Args:
        server_address: ``host:port`` of the server; ``http://`` is prepended.
        file: value sent as the multipart ``image`` field (anything accepted
            by ``requests``' ``files=`` argument).
        subfolder: optional destination subfolder (``None`` = not specified).
        overwrite: whether an existing file with the same name may be replaced.
        type: sent as the ``type`` form field ("input", "temp" or "output").
        original_ref: optional reference to the image the mask belongs to,
            sent as the ``original_ref`` form field. A ``str`` is sent
            verbatim; anything else is JSON-encoded. Omitted when ``None``.
            NOTE(review): ComfyUI's mask endpoint appears to require this
            (e.g. ``{"filename": ..., "subfolder": ..., "type": ...}``) —
            confirm against the server version in use.

    Returns:
        ``"<subfolder>/<name>"`` built from the JSON reply, or just ``"<name>"``
        when the reply's ``subfolder`` is empty. On any error the message is
        printed and ``None`` is returned.
    """
    try:
        files = {"image": file}
        # Lowercase spelling: str(True) would be "True", which a server
        # comparing against "true"/"1" does not honor.
        data = {
            "type": type,
            "subfolder": subfolder,
            "overwrite": "true" if overwrite else "false",
        }
        if original_ref is not None:
            data["original_ref"] = (
                original_ref if isinstance(original_ref, str) else json.dumps(original_ref)
            )
        response = requests.post(f"http://{server_address}/upload/mask", files=files, data=data)
        response.raise_for_status()
        file_info = response.json()
        file_path = file_info["name"]
        if file_info.get("subfolder"):
            file_path = f"{file_info['subfolder']}/{file_path}"
        return file_path
    except Exception as e:
        # Best-effort helper: report and signal failure with None.
        print(f"Error upload mask file: {e}")
        return None
|
|
|
|
def get_image_file(server_address, filename, subfolder=None, preview=None, channel=None,
                   type: Literal["input", "temp", "output"] = "output"):
    """Download a file through ``GET /view`` into an in-memory buffer.

    Args:
        server_address: ``host:port`` of the server; ``http://`` is prepended.
        filename: sent as the ``filename`` query parameter; also used as the
            returned buffer's ``name``.
        subfolder, preview, channel: optional query parameters, each sent
            only when truthy.
        type: sent as the ``type`` query parameter ("input", "temp" or "output").

    Returns:
        An ``io.BytesIO`` holding the response body, positioned at offset 0,
        with ``.name`` set to ``filename``. On any error (including a non-2xx
        status) the message is printed and ``None`` is returned.
    """
    try:
        params = {"filename": filename, "type": type}
        optional = {"subfolder": subfolder, "preview": preview, "channel": channel}
        params.update({key: value for key, value in optional.items() if value})
        response = requests.get(f"http://{server_address}/view", params=params)
        response.raise_for_status()
        # Initialising BytesIO with the payload leaves the position at 0,
        # so no explicit write/seek is needed.
        image_file = io.BytesIO(response.content)
        # Some consumers (e.g. multipart uploads) read the file name from .name.
        image_file.name = filename
        return image_file
    except Exception as e:
        print(f"Error get image file: {e}")
        return None
|
|
|
|
def get_images_files(server_address, prompt_id, download_preview=False):
    """Download every image listed in the history entry of ``prompt_id``.

    Walks ``history[prompt_id]["outputs"]`` and fetches each entry under a
    node's ``"images"`` key with ``get_image_file``.

    Args:
        server_address: ``host:port`` of the server.
        prompt_id: id of a prompt that appears in ``/history/<prompt_id>``.
        download_preview: when true, also fetch images whose ``type`` is
            ``"temp"``; otherwise only ``"output"`` images are fetched.

    Returns:
        A list of the downloaded buffers, in history order. Images whose
        download fails (``get_image_file`` returned ``None``) are skipped
        instead of being returned as ``None`` entries.

    Raises:
        KeyError: if the history reply has no entry for ``prompt_id``.
    """
    history = get_history_prompt(server_address, prompt_id)[prompt_id]
    wanted_types = {"output", "temp"} if download_preview else {"output"}

    images_files = []
    for node_output in history["outputs"].values():
        for image in node_output.get("images", []):
            if image["type"] not in wanted_types:
                continue
            image_file = get_image_file(server_address, image["filename"],
                                        image.get("subfolder"), type=image["type"])
            # get_image_file reports its own errors and returns None on failure.
            if image_file is not None:
                images_files.append(image_file)
    return images_files
|
|
|
|
def get_metadata(server_address, folder_name, filename=".safetensors"):
    """Return the decoded JSON of ``GET /view_metadata/<folder_name>?filename=<filename>``."""
    metadata_url = f"http://{server_address}/view_metadata/{folder_name}"
    reply = requests.get(metadata_url, params=dict(filename=filename))
    return reply.json()
|
|
|
|
def get_system_stats(server_address):
    """Return the decoded JSON body of ``GET /system_stats``."""
    stats_url = f"http://{server_address}/system_stats"
    return requests.get(stats_url).json()
|
|
|
|
def get_prompt(server_address):
    """Return the decoded JSON body of ``GET /prompt``."""
    reply = requests.get(f"http://{server_address}/prompt")
    decoded = reply.json()
    return decoded
|
|
|
|
def get_object_info(server_address):
    """Return the decoded JSON body of ``GET /object_info``."""
    info_url = f"http://{server_address}/object_info"
    return requests.get(info_url).json()
|
|
|
|
def get_object_info_node(server_address, node_class):
    """Return the decoded JSON body of ``GET /object_info/<node_class>``."""
    node_info_url = f"http://{server_address}/object_info/{node_class}"
    reply = requests.get(node_info_url)
    return reply.json()
|
|
|
|
def get_history(server_address, max_items=None):
    """Return the decoded JSON body of ``GET /history``.

    ``max_items`` is passed through as the ``max_items`` query parameter.
    """
    history_url = f"http://{server_address}/history"
    query = dict(max_items=max_items)
    return requests.get(history_url, params=query).json()
|
|
|
|
def get_history_prompt(server_address, prompt_id):
    """Return the decoded JSON body of ``GET /history/<prompt_id>``."""
    entry_url = f"http://{server_address}/history/{prompt_id}"
    return requests.get(entry_url).json()
|
|
|
|
def get_queue(server_address):
    """Return the decoded JSON body of ``GET /queue``."""
    reply = requests.get(f"http://{server_address}/queue")
    return reply.json()
|
|
|
|
def queue_prompt(server_address, prompt, client_id=None):
    """Submit ``prompt`` with ``POST /prompt`` and return the decoded JSON reply.

    The request body is ``{"prompt": prompt}``, plus ``"client_id"`` when a
    truthy ``client_id`` is given.
    """
    payload = {"prompt": prompt}
    if client_id:
        payload.update(client_id=client_id)
    reply = requests.post(f"http://{server_address}/prompt", json=payload)
    return reply.json()
|
|
|
|
def queue_clear_or_delete(server_address, clear=False, delete_prompt_id=None):
    """Clear the pending queue and/or remove specific prompts via ``POST /queue``.

    Args:
        server_address: ``host:port`` of the server.
        clear: sent as the ``clear`` flag of the request body.
        delete_prompt_id: a single prompt id, or a list/tuple/set of ids, to
            remove from the queue. Ignored when falsy.

    Returns:
        The ``requests.Response`` of the POST.
    """
    payload = {"clear": clear}
    if delete_prompt_id:  # delete specific queue entries
        # ComfyUI iterates over "delete", so a bare id string would be
        # treated as a sequence of characters. Always send a JSON list.
        if isinstance(delete_prompt_id, (list, tuple, set, frozenset)):
            payload["delete"] = list(delete_prompt_id)
        else:
            payload["delete"] = [delete_prompt_id]
    return requests.post(f"http://{server_address}/queue", json=payload)
|
|
|
|
def queue_interrupt(server_address):
    """Send ``POST /interrupt`` and return the raw ``requests.Response``."""
    interrupt_url = f"http://{server_address}/interrupt"
    return requests.post(interrupt_url)
|
|
|
|
def queue_free(server_address, unload_models=False, free_memory=False):
    """POST the ``unload_models`` / ``free_memory`` flags to ``/free``.

    Returns the raw ``requests.Response``.
    """
    body = dict(unload_models=unload_models, free_memory=free_memory)
    free_url = f"http://{server_address}/free"
    return requests.post(free_url, json=body)
|
|
|
|
def history_clear_or_delete(server_address, clear=False, delete_prompt_id=None):
    """Clear the history and/or remove specific history entries via ``POST /history``.

    Args:
        server_address: ``host:port`` of the server.
        clear: sent as the ``clear`` flag of the request body.
        delete_prompt_id: a single prompt id, or a list/tuple/set of ids,
            whose history entries should be removed. Ignored when falsy.

    Returns:
        The ``requests.Response`` of the POST.
    """
    payload = {"clear": clear}
    if delete_prompt_id:  # delete specific history records
        # ComfyUI iterates over "delete", so a bare id string would be
        # treated as a sequence of characters. Always send a JSON list.
        if isinstance(delete_prompt_id, (list, tuple, set, frozenset)):
            payload["delete"] = list(delete_prompt_id)
        else:
            payload["delete"] = [delete_prompt_id]
    return requests.post(f"http://{server_address}/history", json=payload)
|
|
|
|
def track_progress(ws, prompt, prompt_id):
    """Block until ``prompt_id`` finishes, printing progress along the way.

    Reads frames from ``ws.recv()`` in a loop. Non-``str`` (binary) frames are
    ignored; text frames are decoded as JSON status messages:

    * ``progress`` — prints the sampler step ``value`` of ``max``.
    * ``execution_cached`` — every id in ``data["nodes"]`` is counted as done.
    * ``executing`` — the reported node is counted as done. A ``None`` node
      ends the loop.

    Messages whose ``data["prompt_id"]`` differs from ``prompt_id`` are
    ignored. Messages without a ``prompt_id`` are treated as belonging to it.
    After each message that adds newly seen nodes, the running total is
    printed against ``len(prompt)``. If the server never sends the ``None``
    node for this prompt, this blocks forever.

    Args:
        ws: object with a ``recv()`` method, e.g. the socket returned by
            ``open_websocket_connection``.
        prompt: the queued workflow dict; only its number of keys is used.
        prompt_id: id of the queued prompt to wait for.

    Returns:
        None.
    """
    total_nodes = len(prompt)
    finished_nodes = set()

    def _record(nodes):
        # Count only node ids not seen before; print once per message.
        new_nodes = [node for node in nodes if node not in finished_nodes]
        if new_nodes:
            finished_nodes.update(new_nodes)
            print(f"Total Progress: Tasks completed {len(finished_nodes)}/{total_nodes}")

    while True:
        message = ws.recv()
        if not isinstance(message, str):
            # Binary frames carry no status information.
            continue
        message = json.loads(message)
        msg_type = message.get("type")
        data = message.get("data") or {}

        # Skip status updates that belong to a different prompt.
        if data.get("prompt_id", prompt_id) != prompt_id:
            continue

        if msg_type == "progress":
            print(f"K-Sampler Progress: Step {data['value']} of {data['max']}")
        elif msg_type == "execution_cached":
            # "nodes" is a list of node ids, not a single id.
            _record(data.get("nodes") or [])
        elif msg_type == "executing":
            node_id = data.get("node")
            if node_id is None:
                # A None node marks the end of this prompt's execution.
                break
            _record([node_id])

    return