// comfyui_api/models/prompt.rs

1use std::{any::Any, collections::HashMap};
2
3use dyn_clone::DynClone;
4use serde::{Deserialize, Serialize};
5
6/// Struct representing a prompt workflow.
7#[derive(Default, Serialize, Deserialize, Debug, Clone)]
8pub struct Prompt {
9    /// The prompt workflow, indexed by node id.
10    #[serde(flatten)]
11    pub workflow: HashMap<String, NodeOrUnknown>,
12}
13
14impl Prompt {
15    pub fn get_node_by_id(&self, id: &str) -> Option<&dyn Node> {
16        match self.workflow.get(id) {
17            Some(NodeOrUnknown::Node(node)) => Some(node.as_ref()),
18            Some(NodeOrUnknown::GenericNode(node)) => Some(node),
19            _ => None,
20        }
21    }
22
23    pub fn get_node_by_id_mut(&mut self, id: &str) -> Option<&mut dyn Node> {
24        match self.workflow.get_mut(id) {
25            Some(NodeOrUnknown::Node(node)) => Some(node.as_mut()),
26            Some(NodeOrUnknown::GenericNode(node)) => Some(node),
27            _ => None,
28        }
29    }
30
31    pub fn get_nodes_by_type<T: Node + 'static>(&self) -> impl Iterator<Item = (&str, &T)> {
32        self.workflow.iter().filter_map(|(key, node)| match node {
33            NodeOrUnknown::Node(node) => as_node::<T>(node.as_ref()).map(|n| (key.as_str(), n)),
34            NodeOrUnknown::GenericNode(node) => as_node::<T>(node).map(|n| (key.as_str(), n)),
35        })
36    }
37
38    pub fn get_nodes_by_type_mut<T: Node + 'static>(
39        &mut self,
40    ) -> impl Iterator<Item = (&str, &mut T)> {
41        self.workflow
42            .iter_mut()
43            .filter_map(|(key, node)| match node {
44                NodeOrUnknown::Node(node) => {
45                    as_node_mut::<T>(node.as_mut()).map(|n| (key.as_str(), n))
46                }
47                NodeOrUnknown::GenericNode(node) => {
48                    as_node_mut::<T>(node).map(|n| (key.as_str(), n))
49                }
50            })
51    }
52}
53
54/// Enum capturing all possible node types.
55#[derive(Serialize, Deserialize, Debug, Clone)]
56#[serde(untagged)]
57pub enum NodeOrUnknown {
58    /// Enum variant representing a known node.
59    Node(Box<dyn Node>),
60    /// Variant capturing unknown nodes.
61    GenericNode(GenericNode),
62}
63
/// Upcasting helper that exposes a value as [`Any`], so that trait objects built on
/// top of it (such as `dyn Node`) can be downcast back to their concrete type.
///
/// Note: because of the blanket implementation below, a `Box<dyn Trait>` is itself
/// `AsAny`; calling `as_any` directly on the box yields the box's own `Any`, so call
/// it on the unboxed `&dyn Trait` (e.g. via `.as_ref()`) to reach the inner value.
pub trait AsAny {
    /// Borrows `self` as `&dyn Any`.
    fn as_any(&self) -> &dyn Any;

    /// Mutably borrows `self` as `&mut dyn Any`.
    fn as_any_mut(&mut self) -> &mut dyn Any;
}

// Every sized `'static` type gets `AsAny` for free.
impl<T> AsAny for T
where
    T: Any,
{
    fn as_any(&self) -> &dyn Any {
        self as &dyn Any
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self as &mut dyn Any
    }
}
81
82/// Get a reference to a node of a specific type.
83///
84/// # Arguments
85///
86/// * `node` - The node to get a reference to.
87///
88/// # Returns
89///
90/// A reference to the node of the specified type if the node is of the specified type, otherwise `None`.
91pub fn as_node<T: Node + 'static>(node: &dyn Node) -> Option<&T> {
92    node.as_any().downcast_ref::<T>()
93}
94
95/// Get a mutable reference to a node of a specific type.
96///
97/// # Arguments
98///
99/// * `node` - The node to get a reference to.
100///
101/// # Returns
102///
103/// A reference to the node of the specified type if the node is of the specified type, otherwise `None`.
104pub fn as_node_mut<T: Node + 'static>(node: &mut dyn Node) -> Option<&mut T> {
105    node.as_any_mut().downcast_mut::<T>()
106}
107
108dyn_clone::clone_trait_object!(Node);
109
110#[typetag::serde(tag = "class_type", content = "inputs")]
111pub trait Node: std::fmt::Debug + Send + Sync + AsAny + DynClone {
112    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_>;
113    fn name(&self) -> &str {
114        self.typetag_name()
115    }
116}
117
118/// Struct representing a node metadata.
119#[derive(Serialize, Deserialize, Debug, Clone)]
120pub struct Meta {
121    /// Node title.
122    pub title: String,
123}
124
125/// Struct representing a generic node.
126#[derive(Serialize, Deserialize, Debug, Clone)]
127pub struct GenericNode {
128    /// The node class type.
129    pub class_type: String,
130    /// The node inputs.
131    pub inputs: HashMap<String, GenericValue>,
132    /// Node metadata.
133    #[serde(rename = "_meta")]
134    pub meta: Option<Meta>,
135}
136
137#[typetag::serde]
138impl Node for GenericNode {
139    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
140        Box::new(self.inputs.values().filter_map(|input| input.node_id()))
141    }
142    fn name(&self) -> &str {
143        &self.class_type
144    }
145}
146
147/// Enum of possible generic node input types.
148#[derive(Serialize, Deserialize, Debug, Clone)]
149#[serde(untagged)]
150pub enum GenericValue {
151    /// Bool input variant.
152    Bool(bool),
153    /// Integer input variant.
154    Int(i64),
155    /// Float input variant.
156    Float(f32),
157    /// String input variant.
158    String(String),
159    /// Node connection input variant.
160    NodeConnection(NodeConnection),
161}
162
163impl GenericValue {
164    /// Get the node id of the input.
165    pub fn node_id(&self) -> Option<&str> {
166        match self {
167            GenericValue::NodeConnection(node_connection) => Some(&node_connection.node_id),
168            _ => None,
169        }
170    }
171}
172
173/// Struct representing a node input connection.
174#[derive(Serialize, Deserialize, Debug, Clone)]
175#[serde(from = "(String, u32)")]
176#[serde(into = "(String, u32)")]
177pub struct NodeConnection {
178    /// The node id of the node providing the input.
179    pub node_id: String,
180    /// The index of the output from the node providing the input.
181    pub output_index: u32,
182}
183
184impl From<(String, u32)> for NodeConnection {
185    fn from((node_id, output_index): (String, u32)) -> Self {
186        Self {
187            node_id,
188            output_index,
189        }
190    }
191}
192
193impl From<NodeConnection> for (String, u32) {
194    fn from(
195        NodeConnection {
196            node_id,
197            output_index,
198        }: NodeConnection,
199    ) -> Self {
200        (node_id, output_index)
201    }
202}
203
204/// Enum of inputs to a node.
205#[derive(Serialize, Deserialize, Debug, Clone)]
206#[serde(untagged)]
207pub enum Input<T> {
208    /// Node connection input variant.
209    NodeConnection(NodeConnection),
210    /// Widget input variant.
211    Value(T),
212}
213
214impl<T> Input<T> {
215    /// Get the value of the input.
216    pub fn value(&self) -> Option<&T> {
217        match self {
218            Input::NodeConnection(_) => None,
219            Input::Value(value) => Some(value),
220        }
221    }
222
223    /// Get a mutable value of the input.
224    pub fn value_mut(&mut self) -> Option<&mut T> {
225        match self {
226            Input::NodeConnection(_) => None,
227            Input::Value(value) => Some(value),
228        }
229    }
230
231    /// Get the node connection of the input.
232    pub fn node_connection(&self) -> Option<&NodeConnection> {
233        match self {
234            Input::NodeConnection(node_connection) => Some(node_connection),
235            Input::Value(_) => None,
236        }
237    }
238
239    /// Get the node id of the input.
240    pub fn node_id(&self) -> Option<&str> {
241        self.node_connection()
242            .map(|node_connection| node_connection.node_id.as_str())
243    }
244}
245
246/// Struct representing a KSampler node.
247#[derive(Serialize, Deserialize, Debug, Clone)]
248pub struct KSampler {
249    /// The cfg scale parameter.
250    pub cfg: Input<f32>,
251    /// The denoise parameter.
252    pub denoise: Input<f32>,
253    /// The sampler name.
254    pub sampler_name: Input<String>,
255    /// The scheduler used.
256    pub scheduler: Input<String>,
257    /// The seed.
258    pub seed: Input<i64>,
259    /// The number of steps.
260    pub steps: Input<u32>,
261    /// The positive conditioning input connection.
262    pub positive: NodeConnection,
263    /// The negative conditioning input connection.
264    pub negative: NodeConnection,
265    /// The model input connection.
266    pub model: NodeConnection,
267    /// The latent image input connection.
268    pub latent_image: NodeConnection,
269}
270
271#[typetag::serde]
272impl Node for KSampler {
273    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
274        let inputs = [
275            self.cfg.node_id(),
276            self.denoise.node_id(),
277            self.sampler_name.node_id(),
278            self.scheduler.node_id(),
279            self.seed.node_id(),
280            self.steps.node_id(),
281        ]
282        .into_iter()
283        .flatten();
284        Box::new(inputs.chain([
285            self.positive.node_id.as_str(),
286            self.negative.node_id.as_str(),
287            self.model.node_id.as_str(),
288            self.latent_image.node_id.as_str(),
289        ]))
290    }
291}
292
293/// Struct representing a CLIPTextEncode node.
294#[derive(Serialize, Deserialize, Debug, Clone)]
295pub struct CLIPTextEncode {
296    /// The text to encode.
297    pub text: Input<String>,
298    /// The CLIP model input connection.
299    pub clip: NodeConnection,
300}
301
302#[typetag::serde]
303impl Node for CLIPTextEncode {
304    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
305        Box::new(
306            [self.text.node_id(), Some(self.clip.node_id.as_str())]
307                .into_iter()
308                .flatten(),
309        )
310    }
311}
312
313/// Struct representing an EmptyLatentImage node.
314#[derive(Serialize, Deserialize, Debug, Clone)]
315pub struct EmptyLatentImage {
316    /// The batch size.
317    pub batch_size: Input<u32>,
318    /// The image width.
319    pub width: Input<u32>,
320    /// The image height.
321    pub height: Input<u32>,
322}
323
324#[typetag::serde]
325impl Node for EmptyLatentImage {
326    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
327        Box::new(
328            [
329                self.batch_size.node_id(),
330                self.width.node_id(),
331                self.height.node_id(),
332            ]
333            .into_iter()
334            .flatten(),
335        )
336    }
337}
338
339/// Struct representing a CheckpointLoaderSimple node.
340#[derive(Serialize, Deserialize, Debug, Clone)]
341pub struct CheckpointLoaderSimple {
342    /// The checkpoint name.
343    pub ckpt_name: Input<String>,
344}
345
346#[typetag::serde]
347impl Node for CheckpointLoaderSimple {
348    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
349        Box::new([self.ckpt_name.node_id()].into_iter().flatten())
350    }
351}
352
353/// Struct representing a VAELoader node.
354#[derive(Serialize, Deserialize, Debug, Clone)]
355pub struct VAELoader {
356    /// The VAE name.
357    pub vae_name: Input<String>,
358}
359
360#[typetag::serde]
361impl Node for VAELoader {
362    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
363        Box::new([self.vae_name.node_id()].into_iter().flatten())
364    }
365}
366
367/// Struct representing a VAEDecode node.
368#[derive(Serialize, Deserialize, Debug, Clone)]
369pub struct VAEDecode {
370    /// Latent output samples to decode.
371    pub samples: NodeConnection,
372    /// VAE model input connection.
373    pub vae: NodeConnection,
374}
375
376#[typetag::serde]
377impl Node for VAEDecode {
378    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
379        Box::new([self.samples.node_id.as_str(), self.vae.node_id.as_str()].into_iter())
380    }
381}
382
383/// Struct representing a PreviewImage node.
384#[derive(Serialize, Deserialize, Debug, Clone)]
385pub struct PreviewImage {
386    /// The images to preview.
387    pub images: NodeConnection,
388}
389
390#[typetag::serde]
391impl Node for PreviewImage {
392    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
393        Box::new([self.images.node_id.as_str()].into_iter())
394    }
395}
396
397/// Struct representing a KSamplerSelect node.
398#[derive(Serialize, Deserialize, Debug, Clone)]
399pub struct KSamplerSelect {
400    /// The sampler name.
401    pub sampler_name: Input<String>,
402}
403
404#[typetag::serde]
405impl Node for KSamplerSelect {
406    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
407        Box::new([self.sampler_name.node_id()].into_iter().flatten())
408    }
409}
410
411/// Struct representing a SamplerCustom node.
412#[derive(Serialize, Deserialize, Debug, Clone)]
413pub struct SamplerCustom {
414    /// Whether or not to add noise.
415    pub add_noise: Input<bool>,
416    /// The cfg scale.
417    pub cfg: Input<f32>,
418    /// The seed.
419    pub noise_seed: Input<i64>,
420    /// Latent image input connection.
421    pub latent_image: NodeConnection,
422    /// The model input connection.
423    pub model: NodeConnection,
424    /// The positive conditioning input connection.
425    pub positive: NodeConnection,
426    /// The negative conditioning input connection.
427    pub negative: NodeConnection,
428    /// The sampler input connection.
429    pub sampler: NodeConnection,
430    /// The sigmas from the scheduler.
431    pub sigmas: NodeConnection,
432}
433
434#[typetag::serde]
435impl Node for SamplerCustom {
436    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
437        let inputs = [
438            self.add_noise.node_id(),
439            self.cfg.node_id(),
440            self.noise_seed.node_id(),
441        ]
442        .into_iter()
443        .flatten();
444        Box::new(inputs.chain([
445            self.latent_image.node_id.as_str(),
446            self.model.node_id.as_str(),
447            self.positive.node_id.as_str(),
448            self.negative.node_id.as_str(),
449            self.sampler.node_id.as_str(),
450            self.sigmas.node_id.as_str(),
451        ]))
452    }
453}
454
455/// Struct representing a SDTurboScheduler node.
456#[derive(Serialize, Deserialize, Debug, Clone)]
457pub struct SDTurboScheduler {
458    /// The number of steps.
459    pub steps: Input<u32>,
460    /// The model input connection.
461    pub model: NodeConnection,
462}
463
464#[typetag::serde]
465impl Node for SDTurboScheduler {
466    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
467        Box::new(
468            [self.steps.node_id(), Some(self.model.node_id.as_str())]
469                .into_iter()
470                .flatten(),
471        )
472    }
473}
474
475/// Struct representing a ImageOnlyCheckpointLoader node.
476#[derive(Serialize, Deserialize, Debug, Clone)]
477pub struct ImageOnlyCheckpointLoader {
478    /// The checkpoint name.
479    pub ckpt_name: Input<String>,
480}
481
482#[typetag::serde]
483impl Node for ImageOnlyCheckpointLoader {
484    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
485        Box::new([self.ckpt_name.node_id()].into_iter().flatten())
486    }
487}
488
489/// Struct representing a LoadImage node.
490#[derive(Serialize, Deserialize, Debug, Clone)]
491pub struct LoadImage {
492    /// UI file selection button.
493    pub upload: Input<String>,
494    /// The name of the image to load.
495    pub image: Input<String>,
496}
497
498#[typetag::serde]
499impl Node for LoadImage {
500    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
501        Box::new(
502            [self.upload.node_id(), self.image.node_id()]
503                .into_iter()
504                .flatten(),
505        )
506    }
507}
508
509/// Struct representing a SVDimg2vidConditioning node.
510#[derive(Serialize, Deserialize, Debug, Clone)]
511#[serde(rename = "SVD_img2vid_Conditioning")]
512pub struct SVDimg2vidConditioning {
513    /// The augmentation level.
514    pub augmentation_level: Input<f32>,
515    /// The FPS.
516    pub fps: Input<u32>,
517    /// The video width.
518    pub width: Input<u32>,
519    /// The video height.
520    pub height: Input<u32>,
521    /// The motion bucket id.
522    pub motion_bucket_id: Input<u32>,
523    /// The number of frames.
524    pub video_frames: Input<u32>,
525    /// The CLIP vision model input connection.
526    pub clip_vision: NodeConnection,
527    /// The init image input connection.
528    pub init_image: NodeConnection,
529    /// The VAE model input connection.
530    pub vae: NodeConnection,
531}
532
533#[typetag::serde]
534impl Node for SVDimg2vidConditioning {
535    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
536        let inputs = [
537            self.augmentation_level.node_id(),
538            self.fps.node_id(),
539            self.width.node_id(),
540            self.height.node_id(),
541            self.motion_bucket_id.node_id(),
542            self.video_frames.node_id(),
543        ]
544        .into_iter()
545        .flatten();
546        Box::new(inputs.chain([
547            self.clip_vision.node_id.as_str(),
548            self.init_image.node_id.as_str(),
549            self.vae.node_id.as_str(),
550        ]))
551    }
552}
553
554/// Struct representing a VideoLinearCFGGuidance node.
555#[derive(Serialize, Deserialize, Debug, Clone)]
556pub struct VideoLinearCFGGuidance {
557    /// The minimum cfg scale.
558    pub min_cfg: Input<f32>,
559    /// The model input connection.
560    pub model: NodeConnection,
561}
562
563#[typetag::serde]
564impl Node for VideoLinearCFGGuidance {
565    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
566        Box::new(
567            [self.min_cfg.node_id(), Some(self.model.node_id.as_str())]
568                .into_iter()
569                .flatten(),
570        )
571    }
572}
573
574/// Struct representing a SaveAnimatedWEBP node.
575#[derive(Serialize, Deserialize, Debug, Clone)]
576pub struct SaveAnimatedWEBP {
577    /// The filename prefix.
578    pub filename_prefix: Input<String>,
579    /// The FPS.
580    pub fps: Input<u32>,
581    /// Whether or not to losslessly encode the video.
582    pub lossless: Input<bool>,
583    /// The encoding method.
584    pub method: Input<String>,
585    /// The quality.
586    pub quality: Input<u32>,
587    /// Input images connection.
588    pub images: NodeConnection,
589}
590
591#[typetag::serde]
592impl Node for SaveAnimatedWEBP {
593    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
594        let inputs = [
595            self.filename_prefix.node_id(),
596            self.fps.node_id(),
597            self.lossless.node_id(),
598            self.method.node_id(),
599            self.quality.node_id(),
600        ]
601        .into_iter()
602        .flatten();
603        Box::new(inputs.chain([self.images.node_id.as_str()]))
604    }
605}
606
607/// Struct representing a LoraLoader node.
608#[derive(Serialize, Deserialize, Debug, Clone)]
609pub struct LoraLoader {
610    /// The name of the LORA model.
611    pub lora_name: Input<String>,
612    /// The model strength.
613    pub strength_model: Input<f32>,
614    /// The CLIP strength.
615    pub strength_clip: Input<f32>,
616    /// The model input connection.
617    pub model: NodeConnection,
618    /// The CLIP input connection.
619    pub clip: NodeConnection,
620}
621
622#[typetag::serde]
623impl Node for LoraLoader {
624    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
625        let inputs = [
626            self.lora_name.node_id(),
627            self.strength_model.node_id(),
628            self.strength_clip.node_id(),
629        ]
630        .into_iter()
631        .flatten();
632        Box::new(inputs.chain([self.model.node_id.as_str(), self.clip.node_id.as_str()]))
633    }
634}
635
636/// Struct representing a ModelSamplingDiscrete node.
637#[derive(Serialize, Deserialize, Debug, Clone)]
638pub struct ModelSamplingDiscrete {
639    /// Sampling to use.
640    pub sampling: Input<String>,
641    /// Use ZSNR.
642    pub zsnr: Input<bool>,
643    /// The model input connection.
644    pub model: NodeConnection,
645}
646
647#[typetag::serde]
648impl Node for ModelSamplingDiscrete {
649    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
650        let inputs = [self.sampling.node_id(), self.zsnr.node_id()]
651            .into_iter()
652            .flatten();
653        Box::new(inputs.chain([self.model.node_id.as_str()]))
654    }
655}
656
657/// Struct representing a SaveImage node.
658#[derive(Serialize, Deserialize, Debug, Clone)]
659pub struct SaveImage {
660    /// The filename prefix.
661    pub filename_prefix: Input<String>,
662    /// The image input connection.
663    pub images: NodeConnection,
664}
665
666#[typetag::serde]
667impl Node for SaveImage {
668    fn connections(&'_ self) -> Box<dyn Iterator<Item = &str> + '_> {
669        Box::new(
670            [
671                self.filename_prefix.node_id(),
672                Some(self.images.node_id.as_str()),
673            ]
674            .into_iter()
675            .flatten(),
676        )
677    }
678}
679
680/// Struct representing a response to a prompt execution request.
681#[derive(Serialize, Deserialize, Debug)]
682pub struct Response {
683    /// The prompt id.
684    pub prompt_id: uuid::Uuid,
685    /// The prompt number.
686    pub number: u64,
687    /// Node errors that have occurred indexed by node id.
688    pub node_errors: HashMap<String, serde_json::Value>,
689}