1use std::{any::Any, collections::HashMap};
2
3use dyn_clone::DynClone;
4use serde::{Deserialize, Serialize};
5
6#[derive(Default, Serialize, Deserialize, Debug, Clone)]
8pub struct Prompt {
9 #[serde(flatten)]
11 pub workflow: HashMap<String, NodeOrUnknown>,
12}
13
14impl Prompt {
15 pub fn get_node_by_id(&self, id: &str) -> Option<&dyn Node> {
16 match self.workflow.get(id) {
17 Some(NodeOrUnknown::Node(node)) => Some(node.as_ref()),
18 Some(NodeOrUnknown::GenericNode(node)) => Some(node),
19 _ => None,
20 }
21 }
22
23 pub fn get_node_by_id_mut(&mut self, id: &str) -> Option<&mut dyn Node> {
24 match self.workflow.get_mut(id) {
25 Some(NodeOrUnknown::Node(node)) => Some(node.as_mut()),
26 Some(NodeOrUnknown::GenericNode(node)) => Some(node),
27 _ => None,
28 }
29 }
30
31 pub fn get_nodes_by_type<T: Node + 'static>(&self) -> impl Iterator<Item = (&str, &T)> {
32 self.workflow.iter().filter_map(|(key, node)| match node {
33 NodeOrUnknown::Node(node) => as_node::<T>(node.as_ref()).map(|n| (key.as_str(), n)),
34 NodeOrUnknown::GenericNode(node) => as_node::<T>(node).map(|n| (key.as_str(), n)),
35 })
36 }
37
38 pub fn get_nodes_by_type_mut<T: Node + 'static>(
39 &mut self,
40 ) -> impl Iterator<Item = (&str, &mut T)> {
41 self.workflow
42 .iter_mut()
43 .filter_map(|(key, node)| match node {
44 NodeOrUnknown::Node(node) => {
45 as_node_mut::<T>(node.as_mut()).map(|n| (key.as_str(), n))
46 }
47 NodeOrUnknown::GenericNode(node) => {
48 as_node_mut::<T>(node).map(|n| (key.as_str(), n))
49 }
50 })
51 }
52}
53
54#[derive(Serialize, Deserialize, Debug, Clone)]
56#[serde(untagged)]
57pub enum NodeOrUnknown {
58 Node(Box<dyn Node>),
60 GenericNode(GenericNode),
62}
63
/// Upcast helper exposing a value as [`Any`], so trait objects that list it
/// as a supertrait (such as `dyn Node`) can be downcast to a concrete type.
pub trait AsAny {
    /// Borrows `self` as `&dyn Any`.
    fn as_any(&self) -> &dyn Any;

    /// Mutably borrows `self` as `&mut dyn Any`.
    fn as_any_mut(&mut self) -> &mut dyn Any;
}

// Blanket implementation: every sized `'static` type gets `AsAny` for free.
impl<T: Any> AsAny for T {
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self as &mut dyn Any
    }
}
81
82pub fn as_node<T: Node + 'static>(node: &dyn Node) -> Option<&T> {
92 node.as_any().downcast_ref::<T>()
93}
94
95pub fn as_node_mut<T: Node + 'static>(node: &mut dyn Node) -> Option<&mut T> {
105 node.as_any_mut().downcast_mut::<T>()
106}
107
108dyn_clone::clone_trait_object!(Node);
109
110#[typetag::serde(tag = "class_type", content = "inputs")]
111pub trait Node: std::fmt::Debug + Send + Sync + AsAny + DynClone {
112 fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_>;
113 fn name(&self) -> &str {
114 self.typetag_name()
115 }
116}
117
118#[derive(Serialize, Deserialize, Debug, Clone)]
120pub struct Meta {
121 pub title: String,
123}
124
125#[derive(Serialize, Deserialize, Debug, Clone)]
127pub struct GenericNode {
128 pub class_type: String,
130 pub inputs: HashMap<String, GenericValue>,
132 #[serde(rename = "_meta")]
134 pub meta: Option<Meta>,
135}
136
137#[typetag::serde]
138impl Node for GenericNode {
139 fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
140 Box::new(self.inputs.values().filter_map(|input| input.node_id()))
141 }
142 fn name(&self) -> &str {
143 &self.class_type
144 }
145}
146
147#[derive(Serialize, Deserialize, Debug, Clone)]
149#[serde(untagged)]
150pub enum GenericValue {
151 Bool(bool),
153 Int(i64),
155 Float(f32),
157 String(String),
159 NodeConnection(NodeConnection),
161}
162
163impl GenericValue {
164 pub fn node_id(&self) -> Option<&str> {
166 match self {
167 GenericValue::NodeConnection(node_connection) => Some(&node_connection.node_id),
168 _ => None,
169 }
170 }
171}
172
173#[derive(Serialize, Deserialize, Debug, Clone)]
175#[serde(from = "(String, u32)")]
176#[serde(into = "(String, u32)")]
177pub struct NodeConnection {
178 pub node_id: String,
180 pub output_index: u32,
182}
183
184impl From<(String, u32)> for NodeConnection {
185 fn from((node_id, output_index): (String, u32)) -> Self {
186 Self {
187 node_id,
188 output_index,
189 }
190 }
191}
192
193impl From<NodeConnection> for (String, u32) {
194 fn from(
195 NodeConnection {
196 node_id,
197 output_index,
198 }: NodeConnection,
199 ) -> Self {
200 (node_id, output_index)
201 }
202}
203
204#[derive(Serialize, Deserialize, Debug, Clone)]
206#[serde(untagged)]
207pub enum Input<T> {
208 NodeConnection(NodeConnection),
210 Value(T),
212}
213
214impl<T> Input<T> {
215 pub fn value(&self) -> Option<&T> {
217 match self {
218 Input::NodeConnection(_) => None,
219 Input::Value(value) => Some(value),
220 }
221 }
222
223 pub fn value_mut(&mut self) -> Option<&mut T> {
225 match self {
226 Input::NodeConnection(_) => None,
227 Input::Value(value) => Some(value),
228 }
229 }
230
231 pub fn node_connection(&self) -> Option<&NodeConnection> {
233 match self {
234 Input::NodeConnection(node_connection) => Some(node_connection),
235 Input::Value(_) => None,
236 }
237 }
238
239 pub fn node_id(&self) -> Option<&str> {
241 self.node_connection()
242 .map(|node_connection| node_connection.node_id.as_str())
243 }
244}
245
246#[derive(Serialize, Deserialize, Debug, Clone)]
248pub struct KSampler {
249 pub cfg: Input<f32>,
251 pub denoise: Input<f32>,
253 pub sampler_name: Input<String>,
255 pub scheduler: Input<String>,
257 pub seed: Input<i64>,
259 pub steps: Input<u32>,
261 pub positive: NodeConnection,
263 pub negative: NodeConnection,
265 pub model: NodeConnection,
267 pub latent_image: NodeConnection,
269}
270
271#[typetag::serde]
272impl Node for KSampler {
273 fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
274 let inputs = [
275 self.cfg.node_id(),
276 self.denoise.node_id(),
277 self.sampler_name.node_id(),
278 self.scheduler.node_id(),
279 self.seed.node_id(),
280 self.steps.node_id(),
281 ]
282 .into_iter()
283 .flatten();
284 Box::new(inputs.chain([
285 self.positive.node_id.as_str(),
286 self.negative.node_id.as_str(),
287 self.model.node_id.as_str(),
288 self.latent_image.node_id.as_str(),
289 ]))
290 }
291}
292
293#[derive(Serialize, Deserialize, Debug, Clone)]
295pub struct CLIPTextEncode {
296 pub text: Input<String>,
298 pub clip: NodeConnection,
300}
301
302#[typetag::serde]
303impl Node for CLIPTextEncode {
304 fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
305 Box::new(
306 [self.text.node_id(), Some(self.clip.node_id.as_str())]
307 .into_iter()
308 .flatten(),
309 )
310 }
311}
312
313#[derive(Serialize, Deserialize, Debug, Clone)]
315pub struct EmptyLatentImage {
316 pub batch_size: Input<u32>,
318 pub width: Input<u32>,
320 pub height: Input<u32>,
322}
323
324#[typetag::serde]
325impl Node for EmptyLatentImage {
326 fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
327 Box::new(
328 [
329 self.batch_size.node_id(),
330 self.width.node_id(),
331 self.height.node_id(),
332 ]
333 .into_iter()
334 .flatten(),
335 )
336 }
337}
338
339#[derive(Serialize, Deserialize, Debug, Clone)]
341pub struct CheckpointLoaderSimple {
342 pub ckpt_name: Input<String>,
344}
345
346#[typetag::serde]
347impl Node for CheckpointLoaderSimple {
348 fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
349 Box::new([self.ckpt_name.node_id()].into_iter().flatten())
350 }
351}
352
353#[derive(Serialize, Deserialize, Debug, Clone)]
355pub struct VAELoader {
356 pub vae_name: Input<String>,
358}
359
360#[typetag::serde]
361impl Node for VAELoader {
362 fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
363 Box::new([self.vae_name.node_id()].into_iter().flatten())
364 }
365}
366
367#[derive(Serialize, Deserialize, Debug, Clone)]
369pub struct VAEDecode {
370 pub samples: NodeConnection,
372 pub vae: NodeConnection,
374}
375
376#[typetag::serde]
377impl Node for VAEDecode {
378 fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
379 Box::new([self.samples.node_id.as_str(), self.vae.node_id.as_str()].into_iter())
380 }
381}
382
383#[derive(Serialize, Deserialize, Debug, Clone)]
385pub struct PreviewImage {
386 pub images: NodeConnection,
388}
389
390#[typetag::serde]
391impl Node for PreviewImage {
392 fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
393 Box::new([self.images.node_id.as_str()].into_iter())
394 }
395}
396
397#[derive(Serialize, Deserialize, Debug, Clone)]
399pub struct KSamplerSelect {
400 pub sampler_name: Input<String>,
402}
403
404#[typetag::serde]
405impl Node for KSamplerSelect {
406 fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
407 Box::new([self.sampler_name.node_id()].into_iter().flatten())
408 }
409}
410
411#[derive(Serialize, Deserialize, Debug, Clone)]
413pub struct SamplerCustom {
414 pub add_noise: Input<bool>,
416 pub cfg: Input<f32>,
418 pub noise_seed: Input<i64>,
420 pub latent_image: NodeConnection,
422 pub model: NodeConnection,
424 pub positive: NodeConnection,
426 pub negative: NodeConnection,
428 pub sampler: NodeConnection,
430 pub sigmas: NodeConnection,
432}
433
434#[typetag::serde]
435impl Node for SamplerCustom {
436 fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
437 let inputs = [
438 self.add_noise.node_id(),
439 self.cfg.node_id(),
440 self.noise_seed.node_id(),
441 ]
442 .into_iter()
443 .flatten();
444 Box::new(inputs.chain([
445 self.latent_image.node_id.as_str(),
446 self.model.node_id.as_str(),
447 self.positive.node_id.as_str(),
448 self.negative.node_id.as_str(),
449 self.sampler.node_id.as_str(),
450 self.sigmas.node_id.as_str(),
451 ]))
452 }
453}
454
455#[derive(Serialize, Deserialize, Debug, Clone)]
457pub struct SDTurboScheduler {
458 pub steps: Input<u32>,
460 pub model: NodeConnection,
462}
463
464#[typetag::serde]
465impl Node for SDTurboScheduler {
466 fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
467 Box::new(
468 [self.steps.node_id(), Some(self.model.node_id.as_str())]
469 .into_iter()
470 .flatten(),
471 )
472 }
473}
474
475#[derive(Serialize, Deserialize, Debug, Clone)]
477pub struct ImageOnlyCheckpointLoader {
478 pub ckpt_name: Input<String>,
480}
481
482#[typetag::serde]
483impl Node for ImageOnlyCheckpointLoader {
484 fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
485 Box::new([self.ckpt_name.node_id()].into_iter().flatten())
486 }
487}
488
489#[derive(Serialize, Deserialize, Debug, Clone)]
491pub struct LoadImage {
492 pub upload: Input<String>,
494 pub image: Input<String>,
496}
497
498#[typetag::serde]
499impl Node for LoadImage {
500 fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
501 Box::new(
502 [self.upload.node_id(), self.image.node_id()]
503 .into_iter()
504 .flatten(),
505 )
506 }
507}
508
509#[derive(Serialize, Deserialize, Debug, Clone)]
511#[serde(rename = "SVD_img2vid_Conditioning")]
512pub struct SVDimg2vidConditioning {
513 pub augmentation_level: Input<f32>,
515 pub fps: Input<u32>,
517 pub width: Input<u32>,
519 pub height: Input<u32>,
521 pub motion_bucket_id: Input<u32>,
523 pub video_frames: Input<u32>,
525 pub clip_vision: NodeConnection,
527 pub init_image: NodeConnection,
529 pub vae: NodeConnection,
531}
532
533#[typetag::serde]
534impl Node for SVDimg2vidConditioning {
535 fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
536 let inputs = [
537 self.augmentation_level.node_id(),
538 self.fps.node_id(),
539 self.width.node_id(),
540 self.height.node_id(),
541 self.motion_bucket_id.node_id(),
542 self.video_frames.node_id(),
543 ]
544 .into_iter()
545 .flatten();
546 Box::new(inputs.chain([
547 self.clip_vision.node_id.as_str(),
548 self.init_image.node_id.as_str(),
549 self.vae.node_id.as_str(),
550 ]))
551 }
552}
553
554#[derive(Serialize, Deserialize, Debug, Clone)]
556pub struct VideoLinearCFGGuidance {
557 pub min_cfg: Input<f32>,
559 pub model: NodeConnection,
561}
562
563#[typetag::serde]
564impl Node for VideoLinearCFGGuidance {
565 fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
566 Box::new(
567 [self.min_cfg.node_id(), Some(self.model.node_id.as_str())]
568 .into_iter()
569 .flatten(),
570 )
571 }
572}
573
574#[derive(Serialize, Deserialize, Debug, Clone)]
576pub struct SaveAnimatedWEBP {
577 pub filename_prefix: Input<String>,
579 pub fps: Input<u32>,
581 pub lossless: Input<bool>,
583 pub method: Input<String>,
585 pub quality: Input<u32>,
587 pub images: NodeConnection,
589}
590
591#[typetag::serde]
592impl Node for SaveAnimatedWEBP {
593 fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
594 let inputs = [
595 self.filename_prefix.node_id(),
596 self.fps.node_id(),
597 self.lossless.node_id(),
598 self.method.node_id(),
599 self.quality.node_id(),
600 ]
601 .into_iter()
602 .flatten();
603 Box::new(inputs.chain([self.images.node_id.as_str()]))
604 }
605}
606
607#[derive(Serialize, Deserialize, Debug, Clone)]
609pub struct LoraLoader {
610 pub lora_name: Input<String>,
612 pub strength_model: Input<f32>,
614 pub strength_clip: Input<f32>,
616 pub model: NodeConnection,
618 pub clip: NodeConnection,
620}
621
622#[typetag::serde]
623impl Node for LoraLoader {
624 fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
625 let inputs = [
626 self.lora_name.node_id(),
627 self.strength_model.node_id(),
628 self.strength_clip.node_id(),
629 ]
630 .into_iter()
631 .flatten();
632 Box::new(inputs.chain([self.model.node_id.as_str(), self.clip.node_id.as_str()]))
633 }
634}
635
636#[derive(Serialize, Deserialize, Debug, Clone)]
638pub struct ModelSamplingDiscrete {
639 pub sampling: Input<String>,
641 pub zsnr: Input<bool>,
643 pub model: NodeConnection,
645}
646
647#[typetag::serde]
648impl Node for ModelSamplingDiscrete {
649 fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
650 let inputs = [self.sampling.node_id(), self.zsnr.node_id()]
651 .into_iter()
652 .flatten();
653 Box::new(inputs.chain([self.model.node_id.as_str()]))
654 }
655}
656
657#[derive(Serialize, Deserialize, Debug, Clone)]
659pub struct SaveImage {
660 pub filename_prefix: Input<String>,
662 pub images: NodeConnection,
664}
665
666#[typetag::serde]
667impl Node for SaveImage {
668 fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
669 Box::new(
670 [
671 self.filename_prefix.node_id(),
672 Some(self.images.node_id.as_str()),
673 ]
674 .into_iter()
675 .flatten(),
676 )
677 }
678}
679
680#[derive(Serialize, Deserialize, Debug)]
682pub struct Response {
683 pub prompt_id: uuid::Uuid,
685 pub number: u64,
687 pub node_errors: HashMap<String, serde_json::Value>,
689}