// comfyui_api/comfy/visitor.rs

1use crate::models::*;
2
3use super::ImageInfo;
4
5/// Trait for visiting nodes in a ComfyUI graph.
6pub trait Visitor {
7    /// Visits a node in a ComfyUI graph.
8    ///
9    /// # Arguments
10    ///
11    /// * `prompt` - The prompt that contains the graph.
12    /// * `node` - The node to visit.
13    fn visit(&mut self, prompt: &Prompt, node: &dyn Node) {
14        for c in node.connections() {
15            if let Some(node) = prompt.get_node_by_id(c) {
16                self.visit(prompt, node);
17            }
18        }
19    }
20}
21
22impl Visitor for ImageInfo {
23    fn visit(&mut self, prompt: &Prompt, node: &dyn Node) {
24        if let Some(node) = as_node::<CheckpointLoaderSimple>(node) {
25            self.model = node.ckpt_name.value().cloned();
26        } else if let Some(node) = as_node::<ImageOnlyCheckpointLoader>(node) {
27            self.model = node.ckpt_name.value().cloned();
28        } else if let Some(node) = as_node::<EmptyLatentImage>(node) {
29            self.width = node.width.value().cloned();
30            self.height = node.height.value().cloned();
31        } else if let Some(node) = as_node::<KSampler>(node) {
32            self.seed = node.seed.value().cloned();
33        } else if let Some(node) = as_node::<SamplerCustom>(node) {
34            self.seed = node.noise_seed.value().cloned();
35        } else if let Some(node) = as_node::<CLIPTextEncode>(node) {
36            if self.prompt.is_none() {
37                self.prompt = node.text.value().cloned();
38            } else if self.negative_prompt.is_none() {
39                self.negative_prompt = node.text.value().cloned();
40            }
41        }
42        for c in node.connections() {
43            if let Some(node) = prompt.get_node_by_id(c) {
44                self.visit(prompt, node);
45            }
46        }
47    }
48}
49
50pub(crate) struct FindNode<T: Node + 'static> {
51    pub(crate) visiting: String,
52    pub(crate) found: Option<String>,
53    _phantom: std::marker::PhantomData<T>,
54}
55
56impl<T: Node + 'static> FindNode<T> {
57    pub(crate) fn new(start: String) -> Self {
58        Self {
59            visiting: start,
60            found: None,
61            _phantom: std::marker::PhantomData,
62        }
63    }
64}
65
66impl<T: Node + 'static> Visitor for FindNode<T> {
67    fn visit(&mut self, prompt: &Prompt, node: &dyn Node) {
68        if let Some(_node) = as_node::<T>(node) {
69            self.found = Some(self.visiting.clone());
70        }
71        for c in node.connections() {
72            if let Some(node) = prompt.get_node_by_id(c) {
73                self.visiting = c.to_string();
74                self.visit(prompt, node);
75            }
76        }
77    }
78}