use std::{any::Any, collections::HashMap};
use dyn_clone::DynClone;
use serde::{Deserialize, Serialize};
/// An API-format workflow: every entry maps a node id to that node's definition.
///
/// `#[serde(flatten)]` makes the prompt (de)serialize as a bare map keyed by node id, with
/// no wrapping `workflow` field.
#[derive(Default, Serialize, Deserialize, Debug, Clone)]
pub struct Prompt {
    /// Node id -> node (typed if its `class_type` is known, otherwise a `GenericNode`).
    #[serde(flatten)]
    pub workflow: HashMap<String, NodeOrUnknown>,
}
impl Prompt {
    /// Returns the node stored under `id` as a `dyn Node`.
    ///
    /// Both typed nodes and `GenericNode` fallbacks are returned; `None` means the
    /// workflow has no entry with that id.
    pub fn get_node_by_id(&self, id: &str) -> Option<&dyn Node> {
        let entry = self.workflow.get(id)?;
        let node: &dyn Node = match entry {
            NodeOrUnknown::Node(boxed) => &**boxed,
            NodeOrUnknown::GenericNode(generic) => generic,
        };
        Some(node)
    }

    /// Mutable counterpart of [`Prompt::get_node_by_id`].
    pub fn get_node_by_id_mut(&mut self, id: &str) -> Option<&mut dyn Node> {
        let entry = self.workflow.get_mut(id)?;
        let node: &mut dyn Node = match entry {
            NodeOrUnknown::Node(boxed) => &mut **boxed,
            NodeOrUnknown::GenericNode(generic) => generic,
        };
        Some(node)
    }

    /// Iterates over `(node id, node)` for every entry whose concrete type is `T`.
    ///
    /// Iteration follows the `HashMap` order of `workflow`, so it is unspecified.
    pub fn get_nodes_by_type<T: Node + 'static>(&self) -> impl Iterator<Item = (&str, &T)> {
        self.workflow.iter().filter_map(|(id, entry)| {
            let node: &dyn Node = match entry {
                NodeOrUnknown::Node(boxed) => &**boxed,
                NodeOrUnknown::GenericNode(generic) => generic,
            };
            as_node::<T>(node).map(|typed| (id.as_str(), typed))
        })
    }

    /// Mutable counterpart of [`Prompt::get_nodes_by_type`].
    pub fn get_nodes_by_type_mut<T: Node + 'static>(
        &mut self,
    ) -> impl Iterator<Item = (&str, &mut T)> {
        self.workflow.iter_mut().filter_map(|(id, entry)| {
            let node: &mut dyn Node = match entry {
                NodeOrUnknown::Node(boxed) => &mut **boxed,
                NodeOrUnknown::GenericNode(generic) => generic,
            };
            as_node_mut::<T>(node).map(|typed| (id.as_str(), typed))
        })
    }
}
/// One workflow entry: either a node type known to this crate or an opaque fallback.
///
/// `#[serde(untagged)]` tries the variants in declaration order. An entry is first
/// deserialized through the typetag registry of `Node` implementations. It falls back to
/// `GenericNode` only when that fails, for example because the `class_type` is unknown or
/// the inputs do not fit the typed struct. Keep `Node` as the first variant.
// NOTE(review): the typed node structs do not use `deny_unknown_fields`, so inputs a typed
// struct does not model are silently dropped and will be missing when re-serialized —
// confirm that is acceptable.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum NodeOrUnknown {
    /// A node whose `class_type` matched a registered `Node` implementation.
    Node(Box<dyn Node>),
    /// Any other node, kept as raw `class_type` / `inputs` / `_meta`.
    GenericNode(GenericNode),
}
/// Upcasting helper that exposes a value as `dyn Any` so it can be downcast later.
///
/// `Node` lists this trait as a supertrait. That is what lets a `dyn Node` be turned back
/// into its concrete type.
pub trait AsAny {
    /// Borrows `self` as `&dyn Any`.
    fn as_any(&self) -> &dyn Any;
    /// Mutably borrows `self` as `&mut dyn Any`.
    fn as_any_mut(&mut self) -> &mut dyn Any;
}

/// Blanket implementation for every sized type that is `Any` (that is, `'static`).
impl<T: Any> AsAny for T {
    fn as_any(&self) -> &dyn Any {
        let erased: &dyn Any = self;
        erased
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        let erased: &mut dyn Any = self;
        erased
    }
}
/// Downcasts a type-erased node to the concrete node type `T`.
///
/// Returns `None` when `node` is not a `T`.
pub fn as_node<T: Node + 'static>(node: &dyn Node) -> Option<&T> {
    let erased = node.as_any();
    erased.downcast_ref::<T>()
}
/// Mutable counterpart of `as_node`: downcasts `node` to `&mut T`, or `None` if it is
/// another type.
pub fn as_node_mut<T: Node + 'static>(node: &mut dyn Node) -> Option<&mut T> {
    let erased = node.as_any_mut();
    erased.downcast_mut::<T>()
}
// Implements `Clone` for `Box<dyn Node>` through the `DynClone` supertrait. The
// `#[derive(Clone)]` on `NodeOrUnknown` needs this for its `Box<dyn Node>` variant.
dyn_clone::clone_trait_object!(Node);
/// A workflow node whose inputs have a statically known schema.
///
/// Through `#[typetag::serde]`, a `Box<dyn Node>` is (de)serialized as
/// `{"class_type": <registered name>, "inputs": <the implementor's fields>}`. On input,
/// `class_type` selects the concrete implementation. `AsAny` allows downcasting back to
/// the concrete type.
#[typetag::serde(tag = "class_type", content = "inputs")]
pub trait Node: std::fmt::Debug + Send + Sync + AsAny + DynClone {
    /// Ids of the upstream nodes that this node's inputs are linked to.
    ///
    /// The order is implementation-defined. An id may appear more than once if several
    /// inputs link to the same node.
    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_>;
    /// The node's `class_type`; by default the name registered with typetag.
    fn name(&self) -> &str {
        self.typetag_name()
    }
}
/// Contents of a node's `_meta` object; only `title` is modelled.
// NOTE(review): `title` is required, so a `_meta` object without it fails to deserialize —
// confirm that is intended.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Meta {
    /// Value of the `title` key.
    pub title: String,
}
/// A node whose `class_type` has no typed `Node` implementation, kept in raw form so it
/// can be round-tripped.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct GenericNode {
    /// The node's class name, exactly as it appeared in the workflow.
    pub class_type: String,
    /// Input name -> literal value or link.
    pub inputs: HashMap<String, GenericValue>,
    /// Optional `_meta` object.
    ///
    /// `_meta` is omitted on output when absent. This keeps a node that arrived without
    /// `_meta` from gaining a `"_meta": null` entry on re-serialization.
    #[serde(rename = "_meta", default, skip_serializing_if = "Option::is_none")]
    pub meta: Option<Meta>,
}
#[typetag::serde]
impl Node for GenericNode {
fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
Box::new(self.inputs.values().filter_map(|input| input.node_id()))
}
fn name(&self) -> &str {
&self.class_type
}
}
/// A literal or linked value in a `GenericNode`'s `inputs` map.
///
/// `#[serde(untagged)]` tries the variants top to bottom, so their order is significant:
/// booleans, then integers that fit in `i64`, then other numbers as `f32`, then strings,
/// and finally `(node_id, output_index)` links.
// NOTE(review): integers above `i64::MAX` fall through to `Float(f32)` and lose precision,
// and `f32` also rounds fractional values. Inputs that are null, objects, or non-link
// arrays match no variant, so the `GenericNode` fails to deserialize. Consider widening
// this enum.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum GenericValue {
    Bool(bool),
    Int(i64),
    Float(f32),
    String(String),
    NodeConnection(NodeConnection),
}
impl GenericValue {
    /// Returns the upstream node id if this value is a link, otherwise `None`.
    pub fn node_id(&self) -> Option<&str> {
        if let GenericValue::NodeConnection(link) = self {
            Some(link.node_id.as_str())
        } else {
            None
        }
    }
}
/// A link to output slot `output_index` of the node whose id is `node_id`.
///
/// Serialized as a two-element `(node_id, output_index)` tuple, using the `From`
/// conversions to and from `(String, u32)`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(from = "(String, u32)")]
#[serde(into = "(String, u32)")]
pub struct NodeConnection {
    /// Id of the upstream node (a key of `Prompt::workflow`).
    pub node_id: String,
    /// Index of the upstream node's output.
    pub output_index: u32,
}
impl From<(String, u32)> for NodeConnection {
fn from((node_id, output_index): (String, u32)) -> Self {
Self {
node_id,
output_index,
}
}
}
impl From<NodeConnection> for (String, u32) {
fn from(
NodeConnection {
node_id,
output_index,
}: NodeConnection,
) -> Self {
(node_id, output_index)
}
}
/// A node input that is either linked to another node's output or given a literal value.
///
/// `#[serde(untagged)]` tries `NodeConnection` first. A value that also parses as a
/// `(String, u32)` pair is therefore treated as a link, so keep this variant order.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum Input<T> {
    /// Linked to an upstream node's output.
    NodeConnection(NodeConnection),
    /// A literal value.
    Value(T),
}
impl<T> Input<T> {
    /// The literal value, or `None` when this input is linked to another node.
    pub fn value(&self) -> Option<&T> {
        if let Input::Value(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    /// Mutable access to the literal value; `None` for linked inputs.
    pub fn value_mut(&mut self) -> Option<&mut T> {
        if let Input::Value(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    /// The link, or `None` when this input holds a literal value.
    pub fn node_connection(&self) -> Option<&NodeConnection> {
        if let Input::NodeConnection(link) = self {
            Some(link)
        } else {
            None
        }
    }

    /// Id of the upstream node when this input is linked, otherwise `None`.
    pub fn node_id(&self) -> Option<&str> {
        match self {
            Input::NodeConnection(link) => Some(link.node_id.as_str()),
            Input::Value(_) => None,
        }
    }
}
/// Typed `inputs` of a `KSampler` node.
///
/// Each `Input<_>` field may hold a literal or a link. Each `NodeConnection` field must be
/// a link.
// NOTE(review): `seed` is `i64`. If the server accepts seeds above `i64::MAX`, such a node
// will not match this type and will fall back to `GenericNode` — confirm the range.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct KSampler {
    pub cfg: Input<f32>,
    pub denoise: Input<f32>,
    pub sampler_name: Input<String>,
    pub scheduler: Input<String>,
    pub seed: Input<i64>,
    pub steps: Input<u32>,
    pub positive: NodeConnection,
    pub negative: NodeConnection,
    pub model: NodeConnection,
    pub latent_image: NodeConnection,
}
#[typetag::serde]
impl Node for KSampler {
    /// Yields the linked `Input` fields first (cfg, denoise, sampler_name, scheduler,
    /// seed, steps, each only when it is a link). Then yields positive, negative, model
    /// and latent_image.
    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
        let all_links = [
            self.cfg.node_id(),
            self.denoise.node_id(),
            self.sampler_name.node_id(),
            self.scheduler.node_id(),
            self.seed.node_id(),
            self.steps.node_id(),
            Some(self.positive.node_id.as_str()),
            Some(self.negative.node_id.as_str()),
            Some(self.model.node_id.as_str()),
            Some(self.latent_image.node_id.as_str()),
        ];
        Box::new(all_links.into_iter().flatten())
    }
}
/// Typed `inputs` of a `CLIPTextEncode` node.
///
/// `text` may be a literal or a link; `clip` must be a link.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CLIPTextEncode {
    pub text: Input<String>,
    pub clip: NodeConnection,
}
#[typetag::serde]
impl Node for CLIPTextEncode {
    /// Yields `text`'s upstream id (if linked), then `clip`'s.
    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
        let text_link = self.text.node_id().into_iter();
        let clip_link = std::iter::once(self.clip.node_id.as_str());
        Box::new(text_link.chain(clip_link))
    }
}
/// Typed `inputs` of an `EmptyLatentImage` node; every field may be a literal or a link.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct EmptyLatentImage {
    pub batch_size: Input<u32>,
    pub width: Input<u32>,
    pub height: Input<u32>,
}
#[typetag::serde]
impl Node for EmptyLatentImage {
fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
Box::new(
[
self.batch_size.node_id(),
self.width.node_id(),
self.height.node_id(),
]
.into_iter()
.flatten(),
)
}
}
/// Typed `inputs` of a `CheckpointLoaderSimple` node.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CheckpointLoaderSimple {
    /// Literal checkpoint name, or a link.
    pub ckpt_name: Input<String>,
}
#[typetag::serde]
impl Node for CheckpointLoaderSimple {
    /// Yields `ckpt_name`'s upstream id if it is linked; otherwise yields nothing.
    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
        Box::new(self.ckpt_name.node_id().into_iter())
    }
}
/// Typed `inputs` of a `VAELoader` node.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct VAELoader {
    /// Literal VAE name, or a link.
    pub vae_name: Input<String>,
}
#[typetag::serde]
impl Node for VAELoader {
    /// Yields `vae_name`'s upstream id if it is linked; otherwise yields nothing.
    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
        Box::new(self.vae_name.node_id().into_iter())
    }
}
/// Typed `inputs` of a `VAEDecode` node; both inputs must be links.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct VAEDecode {
    pub samples: NodeConnection,
    pub vae: NodeConnection,
}
#[typetag::serde]
impl Node for VAEDecode {
    /// Yields the upstream ids of `samples` and then `vae`.
    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
        let links = [&self.samples, &self.vae];
        Box::new(links.into_iter().map(|link| link.node_id.as_str()))
    }
}
/// Typed `inputs` of a `PreviewImage` node.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PreviewImage {
    /// Link to the node providing the images.
    pub images: NodeConnection,
}
#[typetag::serde]
impl Node for PreviewImage {
    /// Yields exactly one id: the upstream node of `images`.
    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
        Box::new(std::iter::once(self.images.node_id.as_str()))
    }
}
/// Typed `inputs` of a `KSamplerSelect` node.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct KSamplerSelect {
    /// Literal sampler name, or a link.
    pub sampler_name: Input<String>,
}
#[typetag::serde]
impl Node for KSamplerSelect {
    /// Yields `sampler_name`'s upstream id if it is linked; otherwise yields nothing.
    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
        Box::new(self.sampler_name.node_id().into_iter())
    }
}
/// Typed `inputs` of a `SamplerCustom` node.
///
/// Each `Input<_>` field may hold a literal or a link. Each `NodeConnection` field must be
/// a link.
// NOTE(review): `noise_seed` is `i64`. If the server accepts seeds above `i64::MAX`, such a
// node will fall back to `GenericNode` — confirm the range.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SamplerCustom {
    pub add_noise: Input<bool>,
    pub cfg: Input<f32>,
    pub noise_seed: Input<i64>,
    pub latent_image: NodeConnection,
    pub model: NodeConnection,
    pub positive: NodeConnection,
    pub negative: NodeConnection,
    pub sampler: NodeConnection,
    pub sigmas: NodeConnection,
}
#[typetag::serde]
impl Node for SamplerCustom {
    /// Yields the linked `Input` fields first (add_noise, cfg, noise_seed). Then yields
    /// latent_image, model, positive, negative, sampler and sigmas.
    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
        let all_links = [
            self.add_noise.node_id(),
            self.cfg.node_id(),
            self.noise_seed.node_id(),
            Some(self.latent_image.node_id.as_str()),
            Some(self.model.node_id.as_str()),
            Some(self.positive.node_id.as_str()),
            Some(self.negative.node_id.as_str()),
            Some(self.sampler.node_id.as_str()),
            Some(self.sigmas.node_id.as_str()),
        ];
        Box::new(all_links.into_iter().flatten())
    }
}
/// Typed `inputs` of an `SDTurboScheduler` node.
///
/// `steps` may be a literal or a link; `model` must be a link.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SDTurboScheduler {
    pub steps: Input<u32>,
    pub model: NodeConnection,
}
#[typetag::serde]
impl Node for SDTurboScheduler {
    /// Yields `steps`'s upstream id (if linked), then `model`'s.
    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
        let steps_link = self.steps.node_id().into_iter();
        let model_link = std::iter::once(self.model.node_id.as_str());
        Box::new(steps_link.chain(model_link))
    }
}
/// Typed `inputs` of an `ImageOnlyCheckpointLoader` node.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ImageOnlyCheckpointLoader {
    /// Literal checkpoint name, or a link.
    pub ckpt_name: Input<String>,
}
#[typetag::serde]
impl Node for ImageOnlyCheckpointLoader {
    /// Yields `ckpt_name`'s upstream id if it is linked; otherwise yields nothing.
    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
        Box::new(self.ckpt_name.node_id().into_iter())
    }
}
/// Typed `inputs` of a `LoadImage` node; both fields may be literals or links.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LoadImage {
    pub upload: Input<String>,
    pub image: Input<String>,
}
#[typetag::serde]
impl Node for LoadImage {
    /// Yields the upstream ids of `upload` and then `image`, each only if it is linked.
    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
        let upload_link = self.upload.node_id().into_iter();
        Box::new(upload_link.chain(self.image.node_id()))
    }
}
/// Typed `inputs` of a `SVD_img2vid_Conditioning` node.
///
/// Each `Input<_>` field may hold a literal or a link. Each `NodeConnection` field must be
/// a link.
// NOTE: this `serde(rename)` only changes serde's struct name. The `class_type` tag used
// when (de)serializing through `dyn Node` is the name registered by the
// `#[typetag::serde]` impl.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename = "SVD_img2vid_Conditioning")]
pub struct SVDimg2vidConditioning {
    pub augmentation_level: Input<f32>,
    pub fps: Input<u32>,
    pub width: Input<u32>,
    pub height: Input<u32>,
    pub motion_bucket_id: Input<u32>,
    pub video_frames: Input<u32>,
    pub clip_vision: NodeConnection,
    pub init_image: NodeConnection,
    pub vae: NodeConnection,
}
// typetag takes the `class_type` tag from the impl, defaulting to the Rust type name
// (`SVDimg2vidConditioning`), and ignores the struct's `serde(rename)`. Without `name`,
// workflows tagged `SVD_img2vid_Conditioning` never matched this type, serialized
// prompts carried the wrong tag, and `Node::name()` returned the Rust type name.
#[typetag::serde(name = "SVD_img2vid_Conditioning")]
impl Node for SVDimg2vidConditioning {
    /// Yields the linked `Input` fields first (augmentation_level, fps, width, height,
    /// motion_bucket_id, video_frames). Then yields clip_vision, init_image and vae.
    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
        let inputs = [
            self.augmentation_level.node_id(),
            self.fps.node_id(),
            self.width.node_id(),
            self.height.node_id(),
            self.motion_bucket_id.node_id(),
            self.video_frames.node_id(),
        ]
        .into_iter()
        .flatten();
        Box::new(inputs.chain([
            self.clip_vision.node_id.as_str(),
            self.init_image.node_id.as_str(),
            self.vae.node_id.as_str(),
        ]))
    }
}
/// Typed `inputs` of a `VideoLinearCFGGuidance` node.
///
/// `min_cfg` may be a literal or a link; `model` must be a link.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct VideoLinearCFGGuidance {
    pub min_cfg: Input<f32>,
    pub model: NodeConnection,
}
#[typetag::serde]
impl Node for VideoLinearCFGGuidance {
    /// Yields `min_cfg`'s upstream id (if linked), then `model`'s.
    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
        let cfg_link = self.min_cfg.node_id().into_iter();
        let model_link = std::iter::once(self.model.node_id.as_str());
        Box::new(cfg_link.chain(model_link))
    }
}
/// Typed `inputs` of a `SaveAnimatedWEBP` node.
///
/// Each `Input<_>` field may hold a literal or a link; `images` must be a link.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SaveAnimatedWEBP {
    pub filename_prefix: Input<String>,
    pub fps: Input<u32>,
    pub lossless: Input<bool>,
    pub method: Input<String>,
    pub quality: Input<u32>,
    pub images: NodeConnection,
}
#[typetag::serde]
impl Node for SaveAnimatedWEBP {
    /// Yields the linked `Input` fields first (filename_prefix, fps, lossless, method,
    /// quality), then `images`.
    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
        let all_links = [
            self.filename_prefix.node_id(),
            self.fps.node_id(),
            self.lossless.node_id(),
            self.method.node_id(),
            self.quality.node_id(),
            Some(self.images.node_id.as_str()),
        ];
        Box::new(all_links.into_iter().flatten())
    }
}
/// Typed `inputs` of a `LoraLoader` node.
///
/// Each `Input<_>` field may hold a literal or a link; `model` and `clip` must be links.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LoraLoader {
    pub lora_name: Input<String>,
    pub strength_model: Input<f32>,
    pub strength_clip: Input<f32>,
    pub model: NodeConnection,
    pub clip: NodeConnection,
}
#[typetag::serde]
impl Node for LoraLoader {
    /// Yields the linked `Input` fields first (lora_name, strength_model, strength_clip),
    /// then `model` and `clip`.
    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
        let all_links = [
            self.lora_name.node_id(),
            self.strength_model.node_id(),
            self.strength_clip.node_id(),
            Some(self.model.node_id.as_str()),
            Some(self.clip.node_id.as_str()),
        ];
        Box::new(all_links.into_iter().flatten())
    }
}
/// Typed `inputs` of a `ModelSamplingDiscrete` node.
///
/// `sampling` and `zsnr` may be literals or links; `model` must be a link.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ModelSamplingDiscrete {
    pub sampling: Input<String>,
    pub zsnr: Input<bool>,
    pub model: NodeConnection,
}
#[typetag::serde]
impl Node for ModelSamplingDiscrete {
    /// Yields the linked `sampling` / `zsnr` ids (in that order), then `model`.
    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
        let all_links = [
            self.sampling.node_id(),
            self.zsnr.node_id(),
            Some(self.model.node_id.as_str()),
        ];
        Box::new(all_links.into_iter().flatten())
    }
}
/// Typed `inputs` of a `SaveImage` node.
///
/// `filename_prefix` may be a literal or a link; `images` must be a link.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SaveImage {
    pub filename_prefix: Input<String>,
    pub images: NodeConnection,
}
#[typetag::serde]
impl Node for SaveImage {
    /// Yields `filename_prefix`'s upstream id (if linked), then `images`'s.
    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
        let prefix_link = self.filename_prefix.node_id().into_iter();
        let images_link = std::iter::once(self.images.node_id.as_str());
        Box::new(prefix_link.chain(images_link))
    }
}
/// Reply returned after submitting a prompt.
///
/// This is presumed from the field names (`prompt_id`, `number`, `node_errors`); confirm
/// against the caller.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    /// Id assigned to the submitted prompt.
    pub prompt_id: uuid::Uuid,
    /// Server-reported `number`; presumably a queue position — TODO confirm.
    pub number: u64,
    /// Per-node error payloads keyed by node id, kept as untyped JSON.
    pub node_errors: HashMap<String, serde_json::Value>,
}