// comfyui_api/comfy/getter.rs

1use std::collections::HashSet;
2
3use anyhow::{anyhow, Context};
4
5use crate::models::*;
6use crate::{comfy::visitor::FindNode, comfy::Visitor};
7
8use super::accessors;
9
10/// A trait for getting values from nodes.
11///
12/// This trait is used to get values from nodes in a `Prompt`. Implementations of this trait
13/// should override `get_value` and `get_value_mut` to get the value of interest from a node.
14/// Implementations may also override `find_node` to find the node to get the value from.
15/// Other methods are used to get values from nodes using heuristics and may not need to be
16/// overridden.
17///
18/// # Type Parameters
19/// * `T` - The type of the value to get.
20/// * `N` - The type of the node to get the value from.
21///
22/// # Examples
23///
24/// See the `Getter` implementations for `Prompt` for an example of how to implement this trait.
25pub trait Getter<T, N>
26where
27    N: Node + 'static,
28    Self: Default,
29{
30    /// Uses a heuristic to find a `Node` and get the value from it.
31    ///
32    /// # Inputs
33    ///
34    /// * `prompt` - A reference to a `Prompt`.
35    ///
36    /// # Returns
37    ///
38    /// A reference to the value on success, or an error if the node could not be found.
39    fn get<'a>(&self, prompt: &'a Prompt) -> anyhow::Result<&'a T> {
40        let node = if let Some(node) = Self::guess_node(prompt, None) {
41            node
42        } else {
43            return Err(anyhow!("Failed to find node"));
44        };
45        self.get_value(node)
46    }
47
48    /// Uses a heuristic to find a `Node` and get the value from it.
49    ///
50    /// # Inputs
51    ///
52    /// * `prompt` - A mutable reference to a `Prompt`.
53    ///
54    /// # Returns
55    ///
56    /// A mutable reference to the value on success, or an error if the node could not be found.
57    fn get_mut<'a>(&self, prompt: &'a mut Prompt) -> anyhow::Result<&'a mut T> {
58        let node = if let Some(node) = Self::guess_node_mut(prompt, None) {
59            node
60        } else {
61            return Err(anyhow!("Failed to find node"));
62        };
63        self.get_value_mut(node)
64    }
65
66    /// Finds a `Node` leading into the given `output_node` and gets the value from it.
67    ///
68    /// # Inputs
69    ///
70    /// * `prompt` - A reference to a `Prompt`.
71    ///
72    /// # Returns
73    ///
74    /// A reference to the value on success, or an error if the node could not be found.
75    fn get_from<'a>(&self, prompt: &'a Prompt, output_node: &str) -> anyhow::Result<&'a T> {
76        let node = if let Some(node) = Self::find_node(prompt, Some(output_node)) {
77            prompt
78                .get_node_by_id(&node)
79                .context("Failed to find node")?
80        } else {
81            return Err(anyhow!("Failed to find node"));
82        };
83        self.get_value(node)
84    }
85
86    /// Finds a `Node` leading into the given `output_node` and gets the value from it.
87    ///
88    /// # Inputs
89    ///
90    /// * `prompt` - A mutable reference to a `Prompt`.
91    ///
92    /// # Returns
93    ///
94    /// A mutable reference to the value on success, or an error if the node could not be found.
95    fn get_from_mut<'a>(
96        &self,
97        prompt: &'a mut Prompt,
98        output_node: &str,
99    ) -> anyhow::Result<&'a mut T> {
100        let node = if let Some(node) = Self::find_node(prompt, Some(output_node)) {
101            prompt
102                .get_node_by_id_mut(&node)
103                .context("Failed to find node")?
104        } else {
105            return Err(anyhow!("Failed to find node"));
106        };
107        self.get_value_mut(node)
108    }
109
110    /// Gets a value from the node with id `node`.
111    ///
112    /// # Inputs
113    ///
114    /// * `prompt` - A reference to a `Prompt`.
115    /// * `node` - The id of the node to set the value on.
116    ///
117    /// # Returns
118    ///
119    /// A reference to the value on success, or an error if the node could not be found.
120    fn get_node<'a>(&self, prompt: &'a Prompt, node: &str) -> anyhow::Result<&'a T> {
121        let node = prompt.get_node_by_id(node).unwrap();
122        self.get_value(node)
123    }
124
125    /// Gets a value from the node with id `node`.
126    ///
127    /// # Inputs
128    ///
129    /// * `prompt` - A mutable reference to a `Prompt`.
130    /// * `node` - The id of the node to set the value on.
131    ///
132    /// # Returns
133    ///
134    /// A mutable reference to the value on success, or an error if the node could not be found.
135    fn get_node_mut<'a>(&self, prompt: &'a mut Prompt, node: &str) -> anyhow::Result<&'a mut T> {
136        let node = prompt.get_node_by_id_mut(node).unwrap();
137        self.get_value_mut(node)
138    }
139
140    /// Gets a value from the given `Node`.
141    ///
142    /// # Inputs
143    ///
144    /// * `node` - A reference to a `Node`.
145    ///
146    /// # Returns
147    ///
148    /// A reference to the value on success, or an error if the node could not be found.
149    fn get_value<'a>(&self, node: &'a dyn Node) -> anyhow::Result<&'a T>;
150
151    /// Gets a value from the given `Node`.
152    ///
153    /// # Inputs
154    ///
155    /// * `node` - A mutable reference to a `Node`.
156    ///
157    /// # Returns
158    ///
159    /// A mutable reference to the value on success, or an error if the node could not be found.
160    fn get_value_mut<'a>(&self, node: &'a mut dyn Node) -> anyhow::Result<&'a mut T>;
161
162    /// Finds a `Node` leading into the given `output_node`.
163    ///
164    /// # Inputs
165    ///
166    /// * `prompt` - A mutable reference to a `Prompt`.
167    ///
168    /// # Returns
169    ///
170    /// The id of the node on success, or `None` if the node could not be found.
171    fn find_node(prompt: &Prompt, output_node: Option<&str>) -> Option<String> {
172        find_node::<N>(prompt, output_node)
173    }
174
175    /// Uses a heuristic to find a `Node`.
176    ///
177    /// # Inputs
178    ///
179    /// * `prompt` - A reference to a `Prompt`.
180    /// * `output_node` - The id of the node to search from.
181    ///
182    /// # Returns
183    ///
184    /// A reference to the node on success, or `None` if the node could not be found.
185    fn guess_node<'a>(prompt: &'a Prompt, output_node: Option<&str>) -> Option<&'a dyn Node> {
186        if let Some(node) = Self::find_node(prompt, output_node) {
187            prompt.get_node_by_id(&node)
188        } else if let Some((_, node)) = prompt.get_nodes_by_type::<N>().next() {
189            Some(node)
190        } else {
191            None
192        }
193    }
194
195    /// Uses a heuristic to find a `Node`.
196    ///
197    /// # Inputs
198    ///
199    /// * `prompt` - A mutable reference to a `Prompt`.
200    /// * `output_node` - The id of the node to search from.
201    ///
202    /// # Returns
203    ///
204    /// A mutable reference to the node on success, or `None` if the node could not be found.
205    fn guess_node_mut<'a>(
206        prompt: &'a mut Prompt,
207        output_node: Option<&str>,
208    ) -> Option<&'a mut dyn Node> {
209        if let Some(node) = Self::find_node(prompt, output_node) {
210            prompt.get_node_by_id_mut(&node)
211        } else if let Some((_, node)) = prompt.get_nodes_by_type_mut::<N>().next() {
212            Some(node)
213        } else {
214            None
215        }
216    }
217}
218
219/// Extension methods for `Prompt` to get nodes.
220pub trait GetExt<N>
221where
222    N: Node + 'static,
223{
224    /// Gets a reference to the node with id `node`.
225    ///
226    /// # Inputs
227    ///
228    /// * `node` - The id of the node to get the value from.
229    ///
230    /// # Returns
231    ///
232    /// A `Result` containing the reference on success, or an error if the node could not be found.
233    fn get_typed_node(&self, node: &str) -> anyhow::Result<&N>;
234
235    /// Gets a mutable reference to the node with id `node`.
236    ///
237    /// # Inputs
238    ///
239    /// * `node` - The id of the node to get the value from.
240    ///
241    /// # Returns
242    ///
243    /// A `Result` containing the mutable reference on success, or an error if the node could not be found.
244    fn get_typed_node_mut(&mut self, node: &str) -> anyhow::Result<&mut N>;
245}
246
247impl<N: Node + 'static> GetExt<N> for Prompt {
248    fn get_typed_node(&self, node: &str) -> anyhow::Result<&N> {
249        let node = self.get_node_by_id(node).context("failed to get node")?;
250        as_node::<N>(node).context("Failed to cast node")
251    }
252
253    fn get_typed_node_mut(&mut self, node: &str) -> anyhow::Result<&mut N> {
254        let node = self
255            .get_node_by_id_mut(node)
256            .context("failed to get node")?;
257        as_node_mut::<N>(node).context("Failed to cast node")
258    }
259}
260
261trait GetterExt<T, N>
262where
263    N: Node + 'static,
264{
265    fn get<G>(&self) -> anyhow::Result<&T>
266    where
267        G: Getter<T, N>;
268
269    fn get_mut<G>(&mut self) -> anyhow::Result<&mut T>
270    where
271        G: Getter<T, N>;
272
273    fn get_from<G>(&self, output_node: &str) -> anyhow::Result<&T>
274    where
275        G: Getter<T, N>;
276
277    fn get_from_mut<G>(&mut self, output_node: &str) -> anyhow::Result<&mut T>
278    where
279        G: Getter<T, N>;
280
281    fn get_node<G>(&self, node: &str) -> anyhow::Result<&T>
282    where
283        G: Getter<T, N>;
284
285    fn get_node_mut<G>(&mut self, node: &str) -> anyhow::Result<&mut T>
286    where
287        G: Getter<T, N>;
288}
289
290impl<T, N: Node + 'static> GetterExt<T, N> for Prompt {
291    fn get<G>(&self) -> anyhow::Result<&T>
292    where
293        G: Getter<T, N>,
294    {
295        G::default().get(self).context("Failed to get value")
296    }
297
298    fn get_mut<G>(&mut self) -> anyhow::Result<&mut T>
299    where
300        G: Getter<T, N>,
301    {
302        G::default().get_mut(self).context("Failed to get value")
303    }
304
305    fn get_from<G>(&self, output_node: &str) -> anyhow::Result<&T>
306    where
307        G: Getter<T, N>,
308    {
309        G::default()
310            .get_from(self, output_node)
311            .context("Failed to get value")
312    }
313
314    fn get_from_mut<G>(&mut self, output_node: &str) -> anyhow::Result<&mut T>
315    where
316        G: Getter<T, N>,
317    {
318        G::default()
319            .get_from_mut(self, output_node)
320            .context("Failed to get value")
321    }
322
323    fn get_node<G>(&self, node: &str) -> anyhow::Result<&T>
324    where
325        G: Getter<T, N>,
326    {
327        G::default()
328            .get_node(self, node)
329            .context("Failed to get value")
330    }
331
332    fn get_node_mut<G>(&mut self, node: &str) -> anyhow::Result<&mut T>
333    where
334        G: Getter<T, N>,
335    {
336        G::default()
337            .get_node_mut(self, node)
338            .context("Failed to get value")
339    }
340}
341
342pub(crate) fn find_node<T: Node + 'static>(
343    prompt: &Prompt,
344    output_node: Option<&str>,
345) -> Option<String> {
346    let output_node = if let Some(node) = output_node {
347        node.to_string()
348    } else {
349        find_output_node(prompt)?
350    };
351    let mut find_node = FindNode::<T>::new(output_node.clone());
352    find_node.visit(prompt, prompt.get_node_by_id(&output_node).unwrap());
353    find_node.found
354}
355
356#[allow(dead_code)]
357pub(crate) fn guess_node<'a, T: Node + 'static>(
358    prompt: &'a Prompt,
359    output_node: Option<&str>,
360) -> Option<&'a dyn Node> {
361    if let Some(node) = find_node::<T>(prompt, output_node) {
362        prompt.get_node_by_id(&node)
363    } else if let Some((_, node)) = prompt.get_nodes_by_type::<T>().next() {
364        Some(node)
365    } else {
366        None
367    }
368}
369
370pub(crate) fn guess_node_mut<'a, T: Node + 'static>(
371    prompt: &'a mut Prompt,
372    output_node: Option<&str>,
373) -> Option<&'a mut dyn Node> {
374    if let Some(node) = find_node::<T>(prompt, output_node) {
375        prompt.get_node_by_id_mut(&node)
376    } else if let Some((_, node)) = prompt.get_nodes_by_type_mut::<T>().next() {
377        Some(node)
378    } else {
379        None
380    }
381}
382
383pub(crate) fn find_output_node(prompt: &Prompt) -> Option<String> {
384    let nodes: HashSet<String> = prompt.workflow.keys().cloned().collect();
385    prompt
386        .workflow
387        .iter()
388        .fold(nodes, |mut nodes, (key, value)| {
389            let mut has_input = false;
390            let connections = match value {
391                NodeOrUnknown::Node(node) => node.connections(),
392                NodeOrUnknown::GenericNode(node) => node.connections(),
393            };
394            for c in connections {
395                has_input = true;
396                nodes.remove(c);
397            }
398            if !has_input {
399                nodes.remove(key);
400            }
401            nodes
402        })
403        .into_iter()
404        .next()
405}
406
/// Implements `Getter<$ValueType, $NodeType>` for `$AccessorType` by reading one field of the node.
///
/// Arguments, in order:
/// * `$ValueType` - The type of the value returned by the getter.
/// * `$NodeType` - The concrete node type the value is read from.
/// * `$AccessorType` - The accessor type that receives the `Getter` implementation.
/// * `$field_name` - The field on `$NodeType` holding the value. It must provide `value()` /
///   `value_mut()` methods whose results can be turned into errors with `anyhow::Context`.
macro_rules! create_getter {
    ($ValueType:ty, $NodeType:ty, $AccessorType:ty, $field_name:ident) => {
        impl Getter<$ValueType, $NodeType> for $AccessorType {
            fn get_value<'a>(&self, node: &'a dyn Node) -> anyhow::Result<&'a $ValueType> {
                as_node::<$NodeType>(node)
                    .context(concat!("Failed to cast node to ", stringify!($NodeType)))?
                    .$field_name
                    .value()
                    // Name the field in the error; the macro has no `$getter_name` metavariable.
                    .context(concat!(
                        "Failed to get ",
                        stringify!($field_name),
                        " value"
                    ))
            }

            fn get_value_mut<'a>(
                &self,
                node: &'a mut dyn Node,
            ) -> anyhow::Result<&'a mut $ValueType> {
                as_node_mut::<$NodeType>(node)
                    .context(concat!("Failed to cast node to ", stringify!($NodeType)))?
                    .$field_name
                    .value_mut()
                    .context(concat!(
                        "Failed to get ",
                        stringify!($field_name),
                        " value"
                    ))
            }
        }
    };
}
439
/// Generates a public extension trait for `Prompt` that exposes one value through a pair of
/// getters.
///
/// Arguments, in order:
/// 1. the value type;
/// 2. the accessor type implementing `Getter`;
/// 3. the name of the shared getter;
/// 4. the name of the mutable getter;
/// 5. the name of the generated trait.
///
/// Both generated methods forward to `GetterExt::get` / `GetterExt::get_mut` with the accessor.
macro_rules! create_ext_trait {
    ($value:ty, $accessor:ty, $getter:ident, $getter_mut:ident, $Trait:ident) => {
        /// Trait to get references to values from a `Prompt`.
        pub trait $Trait {
            /// Returns a reference to the value, located heuristically within the prompt.
            ///
            /// # Errors
            ///
            /// Fails if the node holding the value could not be found or read.
            fn $getter(&self) -> anyhow::Result<&$value>;

            /// Returns a mutable reference to the value, located heuristically within the prompt.
            ///
            /// # Errors
            ///
            /// Fails if the node holding the value could not be found or read.
            fn $getter_mut(&mut self) -> anyhow::Result<&mut $value>;
        }

        impl $Trait for Prompt {
            fn $getter(&self) -> anyhow::Result<&$value> {
                self.get::<$accessor>()
            }

            fn $getter_mut(&mut self) -> anyhow::Result<&mut $value> {
                self.get_mut::<$accessor>()
            }
        }
    };
}
470
471impl Getter<String, CLIPTextEncode> for accessors::Prompt {
472    fn get_value<'a>(&self, node: &'a dyn Node) -> anyhow::Result<&'a String> {
473        as_node::<CLIPTextEncode>(node)
474            .context("Failed to cast node")?
475            .text
476            .value()
477            .context("Failed to get text value")
478    }
479
480    fn get_value_mut<'a>(&self, node: &'a mut dyn Node) -> anyhow::Result<&'a mut String> {
481        as_node_mut::<CLIPTextEncode>(node)
482            .context("Failed to cast node")?
483            .text
484            .value_mut()
485            .context("Failed to get text value")
486    }
487
488    fn find_node(prompt: &Prompt, output_node: Option<&str>) -> Option<String> {
489        if let Some(node) = find_node::<KSampler>(prompt, output_node) {
490            if let Ok(node) = prompt.get_typed_node(&node) as anyhow::Result<&KSampler> {
491                return Some(node.positive.node_id.clone());
492            }
493        }
494        if let Some(node) = find_node::<SamplerCustom>(prompt, output_node) {
495            if let Ok(node) = prompt.get_typed_node(&node) as anyhow::Result<&SamplerCustom> {
496                return Some(node.positive.node_id.clone());
497            }
498        }
499        None
500    }
501}
502
503create_ext_trait!(String, accessors::Prompt, prompt, prompt_mut, PromptExt);
504
505impl Getter<String, CLIPTextEncode> for accessors::NegativePrompt {
506    fn get_value<'a>(&self, node: &'a dyn Node) -> anyhow::Result<&'a String> {
507        accessors::Prompt.get_value(node)
508    }
509
510    fn get_value_mut<'a>(&self, node: &'a mut dyn Node) -> anyhow::Result<&'a mut String> {
511        accessors::Prompt.get_value_mut(node)
512    }
513
514    fn find_node(prompt: &Prompt, output_node: Option<&str>) -> Option<String> {
515        if let Some(node) = find_node::<KSampler>(prompt, output_node) {
516            if let Ok(node) = prompt.get_typed_node(&node) as anyhow::Result<&KSampler> {
517                return Some(node.negative.node_id.clone());
518            }
519        }
520        if let Some(node) = find_node::<SamplerCustom>(prompt, output_node) {
521            if let Ok(node) = prompt.get_typed_node(&node) as anyhow::Result<&SamplerCustom> {
522                return Some(node.negative.node_id.clone());
523            }
524        }
525        None
526    }
527}
528
529create_ext_trait!(
530    String,
531    accessors::NegativePrompt,
532    negative_prompt,
533    negative_prompt_mut,
534    NegativePromptExt
535);
536
537create_getter!(String, CheckpointLoaderSimple, accessors::Model, ckpt_name);
538create_ext_trait!(String, accessors::Model, ckpt_name, ckpt_name_mut, ModelExt);
539
540create_getter!(u32, EmptyLatentImage, accessors::Width, width);
541create_ext_trait!(u32, accessors::Width, width, width_mut, WidthExt);
542
543create_getter!(u32, EmptyLatentImage, accessors::Height, height);
544create_ext_trait!(u32, accessors::Height, height, height_mut, HeightExt);
545
546create_getter!(i64, KSampler, accessors::SeedT<KSampler>, seed);
547create_getter!(
548    i64,
549    SamplerCustom,
550    accessors::SeedT<SamplerCustom>,
551    noise_seed
552);
553
554create_ext_trait!(i64, accessors::Seed, seed, seed_mut, SeedExt);
555
556create_getter!(u32, KSampler, accessors::StepsT<KSampler>, steps);
557create_getter!(
558    u32,
559    SDTurboScheduler,
560    accessors::StepsT<SDTurboScheduler>,
561    steps
562);
563
564create_ext_trait!(u32, accessors::Steps, steps, steps_mut, StepsExt);
565
566impl<S1, S2, T, N1, N2> Getter<T, N1> for accessors::Delegating<S1, S2, T, N1, N2>
567where
568    S1: Getter<T, N1>,
569    S2: Getter<T, N2>,
570    N1: Node + 'static,
571    N2: Node + 'static,
572    T: Clone + Default,
573{
574    fn get<'a>(&self, prompt: &'a Prompt) -> anyhow::Result<&'a T> {
575        S1::default()
576            .get(prompt)
577            .or_else(|_| S2::default().get(prompt).context("Failed to set value"))
578    }
579
580    fn get_mut<'a>(&self, prompt: &'a mut Prompt) -> anyhow::Result<&'a mut T> {
581        let s1 = S1::default();
582        if s1.get(prompt).is_ok() {
583            return s1.get_mut(prompt);
584        }
585        S2::default().get_mut(prompt).context("Failed to set value")
586    }
587
588    fn get_from<'a>(&self, prompt: &'a Prompt, output_node: &str) -> anyhow::Result<&'a T> {
589        S1::default().get_from(prompt, output_node).or_else(|_| {
590            S2::default()
591                .get_from(prompt, output_node)
592                .context("Failed to set value")
593        })
594    }
595
596    fn get_from_mut<'a>(
597        &self,
598        prompt: &'a mut Prompt,
599        output_node: &str,
600    ) -> anyhow::Result<&'a mut T> {
601        let s1 = S1::default();
602        if s1.get_from(prompt, output_node).is_ok() {
603            return s1.get_from_mut(prompt, output_node);
604        }
605        S2::default()
606            .get_from_mut(prompt, output_node)
607            .context("Failed to set value")
608    }
609
610    fn get_node<'a>(&self, prompt: &'a Prompt, node: &str) -> anyhow::Result<&'a T> {
611        S1::default().get_node(prompt, node).or_else(|_| {
612            S2::default()
613                .get_node(prompt, node)
614                .context("Failed to set value")
615        })
616    }
617
618    fn get_node_mut<'a>(&self, prompt: &'a mut Prompt, node: &str) -> anyhow::Result<&'a mut T> {
619        let s1 = S1::default();
620        if s1.get_node(prompt, node).is_ok() {
621            return s1.get_node_mut(prompt, node);
622        }
623        S2::default()
624            .get_node_mut(prompt, node)
625            .context("Failed to set value")
626    }
627
628    fn get_value<'a>(&self, node: &'a dyn Node) -> anyhow::Result<&'a T> {
629        S1::default()
630            .get_value(node)
631            .or_else(|_| S2::default().get_value(node).context("Failed to set value"))
632    }
633
634    fn get_value_mut<'a>(&self, node: &'a mut dyn Node) -> anyhow::Result<&'a mut T> {
635        let s1 = S1::default();
636        if s1.get_value(node).is_ok() {
637            return s1.get_value_mut(node);
638        }
639        S2::default()
640            .get_value_mut(node)
641            .context("Failed to set value")
642    }
643
644    fn find_node(prompt: &Prompt, output_node: Option<&str>) -> Option<String> {
645        find_node::<N1>(prompt, output_node).or_else(|| find_node::<N2>(prompt, output_node))
646    }
647}
648
649create_getter!(f32, KSampler, accessors::CfgT<KSampler>, cfg);
650create_getter!(f32, SamplerCustom, accessors::CfgT<SamplerCustom>, cfg);
651create_ext_trait!(f32, accessors::Cfg, cfg, cfg_mut, CfgExt);
652
653create_getter!(f32, KSampler, accessors::Denoise, denoise);
654create_ext_trait!(f32, accessors::Denoise, denoise, denoise_mut, DenoiseExt);
655
656create_getter!(
657    String,
658    KSampler,
659    accessors::SamplerT<KSampler>,
660    sampler_name
661);
662create_getter!(
663    String,
664    KSamplerSelect,
665    accessors::SamplerT<KSamplerSelect>,
666    sampler_name
667);
668create_ext_trait!(
669    String,
670    accessors::Sampler,
671    sampler_name,
672    sampler_name_mut,
673    SamplerExt
674);
675
676create_getter!(u32, EmptyLatentImage, accessors::BatchSize, batch_size);
677create_ext_trait!(
678    u32,
679    accessors::BatchSize,
680    batch_size,
681    batch_size_mut,
682    BatchSizeExt
683);
684
685create_getter!(String, LoadImage, accessors::LoadImage, image);
686create_ext_trait!(String, accessors::LoadImage, image, image_mut, LoadImageExt);