comfyui_api/comfy/
visitor.rs1use crate::models::*;
2
3use super::ImageInfo;
4
5pub trait Visitor {
7 fn visit(&mut self, prompt: &Prompt, node: &dyn Node) {
14 for c in node.connections() {
15 if let Some(node) = prompt.get_node_by_id(c) {
16 self.visit(prompt, node);
17 }
18 }
19 }
20}
21
22impl Visitor for ImageInfo {
23 fn visit(&mut self, prompt: &Prompt, node: &dyn Node) {
24 if let Some(node) = as_node::<CheckpointLoaderSimple>(node) {
25 self.model = node.ckpt_name.value().cloned();
26 } else if let Some(node) = as_node::<ImageOnlyCheckpointLoader>(node) {
27 self.model = node.ckpt_name.value().cloned();
28 } else if let Some(node) = as_node::<EmptyLatentImage>(node) {
29 self.width = node.width.value().cloned();
30 self.height = node.height.value().cloned();
31 } else if let Some(node) = as_node::<KSampler>(node) {
32 self.seed = node.seed.value().cloned();
33 } else if let Some(node) = as_node::<SamplerCustom>(node) {
34 self.seed = node.noise_seed.value().cloned();
35 } else if let Some(node) = as_node::<CLIPTextEncode>(node) {
36 if self.prompt.is_none() {
37 self.prompt = node.text.value().cloned();
38 } else if self.negative_prompt.is_none() {
39 self.negative_prompt = node.text.value().cloned();
40 }
41 }
42 for c in node.connections() {
43 if let Some(node) = prompt.get_node_by_id(c) {
44 self.visit(prompt, node);
45 }
46 }
47 }
48}
49
50pub(crate) struct FindNode<T: Node + 'static> {
51 pub(crate) visiting: String,
52 pub(crate) found: Option<String>,
53 _phantom: std::marker::PhantomData<T>,
54}
55
56impl<T: Node + 'static> FindNode<T> {
57 pub(crate) fn new(start: String) -> Self {
58 Self {
59 visiting: start,
60 found: None,
61 _phantom: std::marker::PhantomData,
62 }
63 }
64}
65
66impl<T: Node + 'static> Visitor for FindNode<T> {
67 fn visit(&mut self, prompt: &Prompt, node: &dyn Node) {
68 if let Some(_node) = as_node::<T>(node) {
69 self.found = Some(self.visiting.clone());
70 }
71 for c in node.connections() {
72 if let Some(node) = prompt.get_node_by_id(c) {
73 self.visiting = c.to_string();
74 self.visit(prompt, node);
75 }
76 }
77 }
78}