diff --git a/api_client.py b/api_client.py
new file mode 100644
index 0000000..44b44b2
--- /dev/null
+++ b/api_client.py
@@ -0,0 +1,180 @@
+import io
+import json
+import uuid
+import requests
+import websocket
+from typing import Literal
+
+def open_websocket_connection(server_address):
+    try:
+        client_id = str(uuid.uuid4())
+        ws = websocket.WebSocket()
+        ws.connect(f"ws://{server_address}/ws?clientId={client_id}")
+        return ws, client_id
+    except Exception as e:
+        print(f"Error opening websocket connection: {e}")
+
+def get_embeddings(server_address):
+    response = requests.get(f"http://{server_address}/embeddings")
+    return response.json()
+
+def get_models_types(server_address):
+    response = requests.get(f"http://{server_address}/models")
+    return response.json()
+
+def get_models_folder(server_address, folder):
+    response = requests.get(f"http://{server_address}/models/{folder}")
+    return response.json()
+
+def get_extensions(server_address):
+    response = requests.get(f"http://{server_address}/extensions")
+    return response.json()
+
+def upload_image_file(server_address, file, subfolder=None, overwrite=False,
+                      type: Literal["input", "temp", "output"] = "input"):
+    try:
+        files = {"image": file}
+        data = {"type": type, "subfolder": subfolder, "overwrite": overwrite}
+        response = requests.post(f"http://{server_address}/upload/image", files=files, data=data)
+        response.raise_for_status()
+        file_info = response.json()
+        file_path = file_info["name"]
+        if file_info["subfolder"]:
+            file_path = file_info["subfolder"] + "/" + file_path
+        return file_path
+    except Exception as e:
+        print(f"Error uploading image file: {e}")
+
+def upload_mask_file(server_address, file, subfolder=None, overwrite=False,
+                     type: Literal["input", "temp", "output"] = "input"):
+    try:
+        files = {"image": file}
+        data = {"type": type, "subfolder": subfolder, "overwrite": overwrite}
+        response = requests.post(f"http://{server_address}/upload/mask", files=files, data=data)
+        response.raise_for_status()
+        file_info = response.json()
+        file_path = file_info["name"]
+        if file_info["subfolder"]:
+            file_path = file_info["subfolder"] + "/" + file_path
+        return file_path
+    except Exception as e:
+        print(f"Error uploading mask file: {e}")
+
+def get_image_file(server_address, filename, subfolder=None, preview=None, channel=None,
+                   type: Literal["input", "temp", "output"] = "output"):
+    try:
+        params = {"filename": filename, "type": type}
+        if subfolder:
+            params["subfolder"] = subfolder
+        if preview:
+            params["preview"] = preview
+        if channel:
+            params["channel"] = channel
+        response = requests.get(f"http://{server_address}/view", params=params)
+        response.raise_for_status()
+        image_file = io.BytesIO()
+        image_file.name = filename
+        image_file.write(response.content)
+        image_file.seek(0)
+        return image_file
+    except Exception as e:
+        print(f"Error getting image file: {e}")
+
+def get_images_files(server_address, prompt_id, download_preview=False):
+    history = get_history_prompt(server_address, prompt_id)[prompt_id]
+    images_files = []
+    for node_id in history["outputs"]:
+        node_output = history["outputs"][node_id]
+        if "images" in node_output:
+            for image in node_output["images"]:
+                if image["type"] == "output" or (image["type"] == "temp" and download_preview):
+                    image_file = get_image_file(server_address, image["filename"], image["subfolder"], type=image["type"])
+                    images_files.append(image_file)
+    return images_files
+
+def get_metadata(server_address, folder_name, filename=".safetensors"):
+    params = {"filename": filename}
+    response = requests.get(f"http://{server_address}/view_metadata/{folder_name}", params=params)
+    return response.json()
+
+def get_system_stats(server_address):
+    response = requests.get(f"http://{server_address}/system_stats")
+    return response.json()
+
+def get_prompt(server_address):
+    response = requests.get(f"http://{server_address}/prompt")
+    return response.json()
+
+def get_object_info(server_address):
+    response = requests.get(f"http://{server_address}/object_info")
+    return response.json()
+
+def get_object_info_node(server_address, node_class):
+    response = requests.get(f"http://{server_address}/object_info/{node_class}")
+    return response.json()
+
+def get_history(server_address, max_items=None):
+    params = {"max_items": max_items}
+    response = requests.get(f"http://{server_address}/history", params=params)
+    return response.json()
+
+def get_history_prompt(server_address, prompt_id):
+    response = requests.get(f"http://{server_address}/history/{prompt_id}")
+    return response.json()
+
+def get_queue(server_address):
+    response = requests.get(f"http://{server_address}/queue")
+    return response.json()
+
+def queue_prompt(server_address, prompt, client_id=None):
+    payload = {"prompt": prompt}
+    if client_id:
+        payload["client_id"] = client_id
+    response = requests.post(f"http://{server_address}/prompt", json=payload)
+    return response.json()
+
+def queue_clear_or_delete(server_address, clear=False, delete_prompt_id=None):
+    payload = {"clear": clear}
+    if delete_prompt_id:  # delete specific queued prompts by id
+        payload["delete"] = delete_prompt_id
+    return requests.post(f"http://{server_address}/queue", json=payload)
+
+def queue_interrupt(server_address):
+    return requests.post(f"http://{server_address}/interrupt")
+
+def queue_free(server_address, unload_models=False, free_memory=False):
+    payload = {"unload_models": unload_models, "free_memory": free_memory}
+    return requests.post(f"http://{server_address}/free", json=payload)
+
+def history_clear_or_delete(server_address, clear=False, delete_prompt_id=None):
+    payload = {"clear": clear}
+    if delete_prompt_id:  # delete specific history entries by id
+        payload["delete"] = delete_prompt_id
+    return requests.post(f"http://{server_address}/history", json=payload)
+
+def track_progress(ws, prompt, prompt_id):
+    node_ids = list(prompt.keys())
+    finished_nodes = []
+    while True:
+        message = ws.recv()
+        if isinstance(message, str):  # binary frames (previews) are skipped
+            message = json.loads(message)
+            if message["type"] == "progress":
+                step = message["data"]["value"]
+                max_step = message["data"]["max"]
+                print(f"K-Sampler Progress: Step {step} of {max_step}")
+            elif message["type"] == "execution_cached":
+                # "nodes" is a list of node ids whose cached results were reused
+                for node_id in message["data"]["nodes"]:
+                    if node_id not in finished_nodes:
+                        finished_nodes.append(node_id)
+                        print(f"Total Progress: Tasks completed {len(finished_nodes)}/{len(node_ids)}")
+            elif message["type"] == "executing":
+                node_id = message["data"]["node"]
+                if node_id is not None and node_id not in finished_nodes:
+                    finished_nodes.append(node_id)
+                    print(f"Total Progress: Tasks completed {len(finished_nodes)}/{len(node_ids)}")
+                # "node" is None once the whole prompt has finished executing
+                if node_id is None and message["data"]["prompt_id"] == prompt_id:
+                    break
+    return
\ No newline at end of file
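For orientation, a minimal end-to-end sketch of these helpers (the server address, workflow path, and output filenames are illustrative assumptions, not part of this patch):

```python
# Queue an exported workflow, wait for it to finish, and save the images.
import json
from api_client import (open_websocket_connection, queue_prompt,
                        track_progress, get_images_files)

server = "127.0.0.1:8181"                # assumed ComfyUI address
with open("prompt/text2img.json") as f:  # any API-format workflow
    workflow = json.load(f)

ws, client_id = open_websocket_connection(server)
prompt_id = queue_prompt(server, workflow, client_id)["prompt_id"]
track_progress(ws, workflow, prompt_id)  # blocks until execution finishes
for i, image_file in enumerate(get_images_files(server, prompt_id, download_preview=True)):
    with open(f"out_{i}.png", "wb") as out:  # hypothetical output names
        out.write(image_file.read())
```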
requests.get(f"http://{server_address}/view_metadata/{folder_name}", params=params) + return response.json() + +def get_system_stats(server_address): + response = requests.get(f"http://{server_address}/system_stats") + return response.json() + +def get_prompt(server_address): + response = requests.get(f"http://{server_address}/prompt") + return response.json() + +def get_object_info(server_address): + response = requests.get(f"http://{server_address}/object_info") + return response.json() + +def get_object_info_node(server_address, node_class): + response = requests.get(f"http://{server_address}/object_info/{node_class}") + return response.json() + +def get_history(server_address, max_items=None): + params = {"max_items": max_items} + response = requests.get(f"http://{server_address}/history", params=params) + return response.json() + +def get_history_prompt(server_address, prompt_id): + response = requests.get(f"http://{server_address}/history/{prompt_id}") + return response.json() + +def get_queue(server_address): + response = requests.get(f"http://{server_address}/queue") + return response.json() + +def queue_prompt(server_address, prompt, client_id=None): + json = {"prompt": prompt} + if client_id: + json["client_id"] = client_id + response = requests.post(f"http://{server_address}/prompt", json=json) + return response.json() + +def queue_clear_or_delete(server_address, clear=False, delete_prompt_id=None): + json = {"clear": clear} + if delete_prompt_id: # 删除指定队列 + json["delete"] = delete_prompt_id + return requests.post(f"http://{server_address}/queue", json=json) + +def queue_interrupt(server_address): + return requests.post(f"http://{server_address}/interrupt") + +def queue_free(server_address, unload_models=False, free_memory=False): + json = {"unload_models": unload_models, "free_memory": free_memory} + return requests.post(f"http://{server_address}/free", json=json) + +def history_clear_or_delete(server_address, clear=False, delete_prompt_id=None): + json = {"clear": clear} + if delete_prompt_id: # 删除历史记录 + json["delete"] = delete_prompt_id + return requests.post(f"http://{server_address}/history", json=json) + +def track_progress(ws, prompt, prompt_id): + node_ids = list(prompt.keys()) + finished_nodes = [] + while True: + message = ws.recv() + if isinstance(message, str): + message = json.loads(message) + if message["type"] == "progress": + step = message["data"]["value"] + max_step = message["data"]["max"] + print(f"K-Sampler Progress: Step {step} of {max_step}") + elif message["type"] == "execution_cached": + node_id = message["data"]["nodes"] + if node_id not in finished_nodes: + finished_nodes.append(node_id) + print(f"Total Progress: Tasks completed {len(finished_nodes)}/{len(node_ids)}") + if node_id is None and message["data"]["prompt_id"] == prompt_id: + break + elif message["type"] == "executing": + node_id = message["data"]["node"] + if node_id not in finished_nodes: + finished_nodes.append(node_id) + print(f"Total Progress: Tasks completed {len(finished_nodes)}/{len(node_ids)}") + if node_id is None and message["data"]["prompt_id"] == prompt_id: + break + return \ No newline at end of file diff --git a/img2vid.py b/img2vid.py new file mode 100644 index 0000000..9c447e9 --- /dev/null +++ b/img2vid.py @@ -0,0 +1,57 @@ +import io +import json +import random +from api_client import * +from pydantic import BaseModel +from fastapi import FastAPI, UploadFile, File, Form, responses + +API_URL = "127.0.0.1:8181" + +with open("prompt/img2vid.json", "r") as f: + prompt = 
diff --git a/prompt/img2vid.json b/prompt/img2vid.json
new file mode 100644
index 0000000..2e9d6ae
--- /dev/null
+++ b/prompt/img2vid.json
@@ -0,0 +1,123 @@
+{
+  "1": {
+    "inputs": {
+      "ckpt_name": "svd_xt_image_decoder.safetensors"
+    },
+    "class_type": "ImageOnlyCheckpointLoader",
+    "_meta": {
+      "title": "Image Only Checkpoint Loader (img2vid model)"
+    }
+  },
+  "2": {
+    "inputs": {
+      "min_cfg": 1,
+      "model": [
+        "1",
+        0
+      ]
+    },
+    "class_type": "VideoLinearCFGGuidance",
+    "_meta": {
+      "title": "VideoLinearCFGGuidance"
+    }
+  },
+  "3": {
+    "inputs": {
+      "width": 512,
+      "height": 512,
+      "video_frames": 25,
+      "motion_bucket_id": 100,
+      "fps": 8,
+      "augmentation_level": 0,
+      "clip_vision": [
+        "1",
+        1
+      ],
+      "init_image": [
+        "4",
+        0
+      ],
+      "vae": [
+        "1",
+        2
+      ]
+    },
+    "class_type": "SVD_img2vid_Conditioning",
+    "_meta": {
+      "title": "SVD_img2vid_Conditioning"
+    }
+  },
+  "4": {
+    "inputs": {
+      "image": "chenglong.png",
+      "upload": "image"
+    },
+    "class_type": "LoadImage",
+    "_meta": {
+      "title": "Load Image"
+    }
+  },
+  "5": {
+    "inputs": {
+      "seed": 670980613132736,
+      "steps": 20,
+      "cfg": 2,
+      "sampler_name": "dpmpp_2m",
+      "scheduler": "karras",
+      "denoise": 1,
+      "model": [
+        "2",
+        0
+      ],
+      "positive": [
+        "3",
+        0
+      ],
+      "negative": [
+        "3",
+        1
+      ],
+      "latent_image": [
+        "3",
+        2
+      ]
+    },
+    "class_type": "KSampler",
+    "_meta": {
+      "title": "KSampler"
+    }
+  },
+  "6": {
+    "inputs": {
+      "samples": [
+        "5",
+        0
+      ],
+      "vae": [
+        "1",
+        2
+      ]
+    },
+    "class_type": "VAEDecode",
+    "_meta": {
+      "title": "VAE Decode"
+    }
+  },
+  "7": {
+    "inputs": {
+      "filename_prefix": "ComfyUI",
+      "fps": 8,
+      "lossless": false,
+      "quality": 80,
+      "method": "default",
+      "images": [
+        "6",
+        0
+      ]
+    },
+    "class_type": "SaveAnimatedWEBP",
+    "_meta": {
+      "title": "SaveAnimatedWEBP"
+    }
+  }
+}
\ No newline at end of file
diff --git a/prompt/text2img.json b/prompt/text2img.json
new file mode 100644
index 0000000..087eb4d
--- /dev/null
+++ b/prompt/text2img.json
@@ -0,0 +1,106 @@
+{
+  "1": {
+    "inputs": {
+      "ckpt_name": "helloyoung25d_V15j.safetensors"
+    },
+    "class_type": "CheckpointLoaderSimple",
+    "_meta": {
+      "title": "Load Checkpoint"
+    }
+  },
+  "2": {
+    "inputs": {
+      "text": "a boy",
+      "clip": [
+        "1",
+        1
+      ]
+    },
+    "class_type": "CLIPTextEncode",
+    "_meta": {
+      "title": "CLIP Text Encode (Prompt)"
+    }
+  },
+  "3": {
+    "inputs": {
+      "text": "text",
+      "clip": [
+        "1",
+        1
+      ]
+    },
+    "class_type": "CLIPTextEncode",
+    "_meta": {
+      "title": "CLIP Text Encode (Prompt)"
+    }
+  },
+  "4": {
+    "inputs": {
+      "seed": 971044526173875,
+      "steps": 20,
+      "cfg": 8,
+      "sampler_name": "dpmpp_2m",
+      "scheduler": "karras",
+      "denoise": 1,
+      "model": [
+        "1",
+        0
+      ],
+      "positive": [
+        "2",
+        0
+      ],
+      "negative": [
+        "3",
+        0
+      ],
+      "latent_image": [
+        "5",
+        0
+      ]
+    },
+    "class_type": "KSampler",
+    "_meta": {
+      "title": "KSampler"
+    }
+  },
+  "5": {
+    "inputs": {
+      "width": 512,
+      "height": 512,
+      "batch_size": 1
+    },
+    "class_type": "EmptyLatentImage",
+    "_meta": {
+      "title": "Empty Latent Image"
+    }
+  },
+  "6": {
+    "inputs": {
+      "samples": [
+        "4",
+        0
+      ],
+      "vae": [
+        "1",
+        2
+      ]
+    },
+    "class_type": "VAEDecode",
+    "_meta": {
+      "title": "VAE Decode"
+    }
+  },
+  "7": {
+    "inputs": {
+      "images": [
+        "6",
+        0
+      ]
+    },
+    "class_type": "PreviewImage",
+    "_meta": {
+      "title": "Preview Image"
+    }
+  }
+}
\ No newline at end of file
diff --git a/prompt/text2img_lora.json b/prompt/text2img_lora.json
new file mode 100644
index 0000000..729cdbf
--- /dev/null
+++ b/prompt/text2img_lora.json
@@ -0,0 +1,125 @@
+{
+  "1": {
+    "inputs": {
+      "ckpt_name": "helloyoung25d_V15j.safetensors"
+    },
+    "class_type": "CheckpointLoaderSimple",
+    "_meta": {
+      "title": "Load Checkpoint"
+    }
+  },
+  "2": {
+    "inputs": {
+      "lora_name": "ClayAnimation.safetensors",
+      "strength_model": 1,
+      "strength_clip": 1,
+      "model": [
+        "1",
+        0
+      ],
+      "clip": [
+        "1",
+        1
+      ]
+    },
+    "class_type": "LoraLoader",
+    "_meta": {
+      "title": "Load LoRA"
+    }
+  },
+  "3": {
+    "inputs": {
+      "text": "a boy",
+      "clip": [
+        "2",
+        1
+      ]
+    },
+    "class_type": "CLIPTextEncode",
+    "_meta": {
+      "title": "CLIP Text Encode (Prompt)"
+    }
+  },
+  "4": {
+    "inputs": {
+      "text": "text",
+      "clip": [
+        "2",
+        1
+      ]
+    },
+    "class_type": "CLIPTextEncode",
+    "_meta": {
+      "title": "CLIP Text Encode (Prompt)"
+    }
+  },
+  "5": {
+    "inputs": {
+      "seed": 101018934678671,
+      "steps": 20,
+      "cfg": 8,
+      "sampler_name": "dpmpp_2m",
+      "scheduler": "karras",
+      "denoise": 1,
+      "model": [
+        "2",
+        0
+      ],
+      "positive": [
+        "3",
+        0
+      ],
+      "negative": [
+        "4",
+        0
+      ],
+      "latent_image": [
+        "6",
+        0
+      ]
+    },
+    "class_type": "KSampler",
+    "_meta": {
+      "title": "KSampler"
+    }
+  },
+  "6": {
+    "inputs": {
+      "width": 512,
+      "height": 512,
+      "batch_size": 1
+    },
+    "class_type": "EmptyLatentImage",
+    "_meta": {
+      "title": "Empty Latent Image"
+    }
+  },
+  "7": {
+    "inputs": {
+      "samples": [
+        "5",
+        0
+      ],
+      "vae": [
+        "1",
+        2
+      ]
+    },
+    "class_type": "VAEDecode",
+    "_meta": {
+      "title": "VAE Decode"
+    }
+  },
+  "8": {
+    "inputs": {
+      "images": [
+        "7",
+        0
+      ]
+    },
+    "class_type": "PreviewImage",
+    "_meta": {
+      "title": "Preview Image"
+    }
+  }
+}
\ No newline at end of file
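text2img_lora.json pairs with text2img_lora.py further down, which rewrites lora_name to <style>.safetensors; every style in the endpoint's Literal therefore needs a matching file in the server's models/loras folder. A hedged sanity check (the server address is assumed, and /models/{folder} is assumed to return a list of filenames, as api_client.get_models_folder expects):

```python
# Verify that each selectable LoRA style exists on the ComfyUI server.
from api_client import get_models_folder

styles = ["qban", "guofeng", "Clay", "Cyberpunk"]
available = get_models_folder("127.0.0.1:8181", "loras")
for style in styles:
    name = style + ".safetensors"  # same convention text2img_lora.py uses
    print(name, "OK" if name in available else "MISSING")
```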
diff --git a/test/input/test.png b/test/input/test.png
new file mode 100644
index 0000000..4c85694
Binary files /dev/null and b/test/input/test.png differ
diff --git a/test/input/test2.png b/test/input/test2.png
new file mode 100644
index 0000000..547de0c
Binary files /dev/null and b/test/input/test2.png differ
diff --git a/test/output/test.png b/test/output/test.png
new file mode 100644
index 0000000..bd8c51d
Binary files /dev/null and b/test/output/test.png differ
diff --git a/test/output/test.webp b/test/output/test.webp
new file mode 100644
index 0000000..f238797
Binary files /dev/null and b/test/output/test.webp differ
diff --git a/test/output/test2.png b/test/output/test2.png
new file mode 100644
index 0000000..d24d1f7
Binary files /dev/null and b/test/output/test2.png differ
diff --git a/test/output/test2.webp b/test/output/test2.webp
new file mode 100644
index 0000000..e2a2a14
Binary files /dev/null and b/test/output/test2.webp differ
diff --git a/test/test.ipynb b/test/test.ipynb
new file mode 100644
index 0000000..62388b8
--- /dev/null
+++ b/test/test.ipynb
@@ -0,0 +1,56 @@
+{
+  "cells": [
+    {
+      "cell_type": "code",
+      "execution_count": 2,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "import requests\n",
+        "data = {\n",
+        "    \"width\": 512, \"height\": 512, \"video_frames\": 25, \"motion_bucket_id\": 100, \"fps\": 8,  # conditioning\n",
+        "    \"seed\": None, \"steps\": 20, \"cfg\": 2.0,  # sampling\n",
+        "    \"save_fps\": 8,  # saving\n",
+        "}\n",
+        "files = {\"file\": open(\"input/test.png\", \"rb\")}\n",
+        "response = requests.post(\"http://192.168.13.121:8182/img2vid\", files=files, data=data)\n",
+        "with open(\"output/test.webp\", \"wb\") as f:\n",
+        "    f.write(response.content)"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 4,
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "import requests\n",
+        "payload = {\"prompt\": \"a cute boy\", \"style\": \"Clay\", \"seed\": None, \"num\": 1}\n",
+        "response = requests.post(\"http://192.168.13.121:8183/text2img_lora\", json=payload)\n",
+        "with open(\"output/test.png\", \"wb\") as f:\n",
+        "    f.write(response.content)"
+      ]
+    }
+  ],
+  "metadata": {
+    "kernelspec": {
+      "display_name": "diffusion",
+      "language": "python",
+      "name": "python3"
+    },
+    "language_info": {
+      "codemirror_mode": {
+        "name": "ipython",
+        "version": 3
+      },
+      "file_extension": ".py",
+      "mimetype": "text/x-python",
+      "name": "python",
+      "nbconvert_exporter": "python",
+      "pygments_lexer": "ipython3",
+      "version": "3.11.9"
+    }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+}
diff --git a/text2img_lora.py b/text2img_lora.py
new file mode 100644
index 0000000..d0d2c10
--- /dev/null
+++ b/text2img_lora.py
@@ -0,0 +1,39 @@
+import copy
+import json
+import random
+from typing import Literal
+from pydantic import BaseModel
+from fastapi import FastAPI, responses
+from api_client import (open_websocket_connection, queue_prompt,
+                        track_progress, get_images_files)
+
+API_URL = "127.0.0.1:8181"  # address of the ComfyUI server
+
+with open("prompt/text2img_lora.json", "r") as f:
+    prompt = json.load(f)
+
+class Text2imgParams(BaseModel):
+    prompt: str = "a cute boy"
+    style: Literal["qban", "guofeng", "Clay", "Cyberpunk"] = "Clay"
+    seed: int | None = None
+    num: int = 1
+
+app = FastAPI()
+
+@app.post("/text2img_lora")
+async def text2img(params: Text2imgParams):
+    # patch a copy of the workflow so concurrent requests don't clash
+    workflow = copy.deepcopy(prompt)
+    workflow["2"]["inputs"]["lora_name"] = params.style + ".safetensors"
+    workflow["3"]["inputs"]["text"] = params.style + ", " + params.prompt
+    workflow["5"]["inputs"]["seed"] = params.seed if params.seed is not None else random.randint(0, 10**16)
+    workflow["6"]["inputs"]["batch_size"] = params.num
+
+    ws, client_id = open_websocket_connection(API_URL)
+    print("client_id: ", client_id)
+    response = queue_prompt(API_URL, workflow, client_id)
+    prompt_id = response["prompt_id"]
+    print("prompt_id: ", prompt_id)
+    track_progress(ws, workflow, prompt_id)
+    outputs = get_images_files(API_URL, prompt_id, download_preview=True)
+    return responses.Response(content=outputs[0].read(), media_type="image/png")
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8183)
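The services above only queue work; api_client.py also wraps ComfyUI's maintenance endpoints. A short sketch of how they compose (the address and response shapes are assumptions taken from the ComfyUI HTTP API):

```python
# Inspect the queue, stop the current job, drop pending ones, and free memory.
from api_client import get_queue, queue_interrupt, queue_clear_or_delete, queue_free

server = "127.0.0.1:8181"                  # assumed ComfyUI address
queue = get_queue(server)                  # expected: {"queue_running": [...], "queue_pending": [...]}
print("running:", len(queue["queue_running"]), "pending:", len(queue["queue_pending"]))
queue_interrupt(server)                    # interrupt the currently executing prompt
queue_clear_or_delete(server, clear=True)  # clear everything still pending
queue_free(server, unload_models=True, free_memory=True)  # release models and VRAM
```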