1use std::collections::HashSet;
2
3use anyhow::{anyhow, Context};
4
5use crate::models::*;
6use crate::{comfy::visitor::FindNode, comfy::Visitor};
7
8use super::accessors;
9
10pub trait Getter<T, N>
26where
27 N: Node + 'static,
28 Self: Default,
29{
30 fn get<'a>(&self, prompt: &'a Prompt) -> anyhow::Result<&'a T> {
40 let node = if let Some(node) = Self::guess_node(prompt, None) {
41 node
42 } else {
43 return Err(anyhow!("Failed to find node"));
44 };
45 self.get_value(node)
46 }
47
48 fn get_mut<'a>(&self, prompt: &'a mut Prompt) -> anyhow::Result<&'a mut T> {
58 let node = if let Some(node) = Self::guess_node_mut(prompt, None) {
59 node
60 } else {
61 return Err(anyhow!("Failed to find node"));
62 };
63 self.get_value_mut(node)
64 }
65
66 fn get_from<'a>(&self, prompt: &'a Prompt, output_node: &str) -> anyhow::Result<&'a T> {
76 let node = if let Some(node) = Self::find_node(prompt, Some(output_node)) {
77 prompt
78 .get_node_by_id(&node)
79 .context("Failed to find node")?
80 } else {
81 return Err(anyhow!("Failed to find node"));
82 };
83 self.get_value(node)
84 }
85
86 fn get_from_mut<'a>(
96 &self,
97 prompt: &'a mut Prompt,
98 output_node: &str,
99 ) -> anyhow::Result<&'a mut T> {
100 let node = if let Some(node) = Self::find_node(prompt, Some(output_node)) {
101 prompt
102 .get_node_by_id_mut(&node)
103 .context("Failed to find node")?
104 } else {
105 return Err(anyhow!("Failed to find node"));
106 };
107 self.get_value_mut(node)
108 }
109
110 fn get_node<'a>(&self, prompt: &'a Prompt, node: &str) -> anyhow::Result<&'a T> {
121 let node = prompt.get_node_by_id(node).unwrap();
122 self.get_value(node)
123 }
124
125 fn get_node_mut<'a>(&self, prompt: &'a mut Prompt, node: &str) -> anyhow::Result<&'a mut T> {
136 let node = prompt.get_node_by_id_mut(node).unwrap();
137 self.get_value_mut(node)
138 }
139
140 fn get_value<'a>(&self, node: &'a dyn Node) -> anyhow::Result<&'a T>;
150
151 fn get_value_mut<'a>(&self, node: &'a mut dyn Node) -> anyhow::Result<&'a mut T>;
161
162 fn find_node(prompt: &Prompt, output_node: Option<&str>) -> Option<String> {
172 find_node::<N>(prompt, output_node)
173 }
174
175 fn guess_node<'a>(prompt: &'a Prompt, output_node: Option<&str>) -> Option<&'a dyn Node> {
186 if let Some(node) = Self::find_node(prompt, output_node) {
187 prompt.get_node_by_id(&node)
188 } else if let Some((_, node)) = prompt.get_nodes_by_type::<N>().next() {
189 Some(node)
190 } else {
191 None
192 }
193 }
194
195 fn guess_node_mut<'a>(
206 prompt: &'a mut Prompt,
207 output_node: Option<&str>,
208 ) -> Option<&'a mut dyn Node> {
209 if let Some(node) = Self::find_node(prompt, output_node) {
210 prompt.get_node_by_id_mut(&node)
211 } else if let Some((_, node)) = prompt.get_nodes_by_type_mut::<N>().next() {
212 Some(node)
213 } else {
214 None
215 }
216 }
217}
218
219pub trait GetExt<N>
221where
222 N: Node + 'static,
223{
224 fn get_typed_node(&self, node: &str) -> anyhow::Result<&N>;
234
235 fn get_typed_node_mut(&mut self, node: &str) -> anyhow::Result<&mut N>;
245}
246
247impl<N: Node + 'static> GetExt<N> for Prompt {
248 fn get_typed_node(&self, node: &str) -> anyhow::Result<&N> {
249 let node = self.get_node_by_id(node).context("failed to get node")?;
250 as_node::<N>(node).context("Failed to cast node")
251 }
252
253 fn get_typed_node_mut(&mut self, node: &str) -> anyhow::Result<&mut N> {
254 let node = self
255 .get_node_by_id_mut(node)
256 .context("failed to get node")?;
257 as_node_mut::<N>(node).context("Failed to cast node")
258 }
259}
260
261trait GetterExt<T, N>
262where
263 N: Node + 'static,
264{
265 fn get<G>(&self) -> anyhow::Result<&T>
266 where
267 G: Getter<T, N>;
268
269 fn get_mut<G>(&mut self) -> anyhow::Result<&mut T>
270 where
271 G: Getter<T, N>;
272
273 fn get_from<G>(&self, output_node: &str) -> anyhow::Result<&T>
274 where
275 G: Getter<T, N>;
276
277 fn get_from_mut<G>(&mut self, output_node: &str) -> anyhow::Result<&mut T>
278 where
279 G: Getter<T, N>;
280
281 fn get_node<G>(&self, node: &str) -> anyhow::Result<&T>
282 where
283 G: Getter<T, N>;
284
285 fn get_node_mut<G>(&mut self, node: &str) -> anyhow::Result<&mut T>
286 where
287 G: Getter<T, N>;
288}
289
290impl<T, N: Node + 'static> GetterExt<T, N> for Prompt {
291 fn get<G>(&self) -> anyhow::Result<&T>
292 where
293 G: Getter<T, N>,
294 {
295 G::default().get(self).context("Failed to get value")
296 }
297
298 fn get_mut<G>(&mut self) -> anyhow::Result<&mut T>
299 where
300 G: Getter<T, N>,
301 {
302 G::default().get_mut(self).context("Failed to get value")
303 }
304
305 fn get_from<G>(&self, output_node: &str) -> anyhow::Result<&T>
306 where
307 G: Getter<T, N>,
308 {
309 G::default()
310 .get_from(self, output_node)
311 .context("Failed to get value")
312 }
313
314 fn get_from_mut<G>(&mut self, output_node: &str) -> anyhow::Result<&mut T>
315 where
316 G: Getter<T, N>,
317 {
318 G::default()
319 .get_from_mut(self, output_node)
320 .context("Failed to get value")
321 }
322
323 fn get_node<G>(&self, node: &str) -> anyhow::Result<&T>
324 where
325 G: Getter<T, N>,
326 {
327 G::default()
328 .get_node(self, node)
329 .context("Failed to get value")
330 }
331
332 fn get_node_mut<G>(&mut self, node: &str) -> anyhow::Result<&mut T>
333 where
334 G: Getter<T, N>,
335 {
336 G::default()
337 .get_node_mut(self, node)
338 .context("Failed to get value")
339 }
340}
341
342pub(crate) fn find_node<T: Node + 'static>(
343 prompt: &Prompt,
344 output_node: Option<&str>,
345) -> Option<String> {
346 let output_node = if let Some(node) = output_node {
347 node.to_string()
348 } else {
349 find_output_node(prompt)?
350 };
351 let mut find_node = FindNode::<T>::new(output_node.clone());
352 find_node.visit(prompt, prompt.get_node_by_id(&output_node).unwrap());
353 find_node.found
354}
355
356#[allow(dead_code)]
357pub(crate) fn guess_node<'a, T: Node + 'static>(
358 prompt: &'a Prompt,
359 output_node: Option<&str>,
360) -> Option<&'a dyn Node> {
361 if let Some(node) = find_node::<T>(prompt, output_node) {
362 prompt.get_node_by_id(&node)
363 } else if let Some((_, node)) = prompt.get_nodes_by_type::<T>().next() {
364 Some(node)
365 } else {
366 None
367 }
368}
369
370pub(crate) fn guess_node_mut<'a, T: Node + 'static>(
371 prompt: &'a mut Prompt,
372 output_node: Option<&str>,
373) -> Option<&'a mut dyn Node> {
374 if let Some(node) = find_node::<T>(prompt, output_node) {
375 prompt.get_node_by_id_mut(&node)
376 } else if let Some((_, node)) = prompt.get_nodes_by_type_mut::<T>().next() {
377 Some(node)
378 } else {
379 None
380 }
381}
382
383pub(crate) fn find_output_node(prompt: &Prompt) -> Option<String> {
384 let nodes: HashSet<String> = prompt.workflow.keys().cloned().collect();
385 prompt
386 .workflow
387 .iter()
388 .fold(nodes, |mut nodes, (key, value)| {
389 let mut has_input = false;
390 let connections = match value {
391 NodeOrUnknown::Node(node) => node.connections(),
392 NodeOrUnknown::GenericNode(node) => node.connections(),
393 };
394 for c in connections {
395 has_input = true;
396 nodes.remove(c);
397 }
398 if !has_input {
399 nodes.remove(key);
400 }
401 nodes
402 })
403 .into_iter()
404 .next()
405}
406
/// Implements [`Getter`] for `$AccessorType`: the value is the `$field_name` field of a node
/// cast to `$NodeType`, read through that field's `value()` / `value_mut()`.
///
/// Error messages name the node type (on cast failure) and the field (on read failure).
macro_rules! create_getter {
    ($ValueType:ty, $NodeType:ty, $AccessorType:ty, $field_name:ident) => {
        impl Getter<$ValueType, $NodeType> for $AccessorType {
            fn get_value<'a>(&self, node: &'a dyn Node) -> anyhow::Result<&'a $ValueType> {
                as_node::<$NodeType>(node)
                    .context(concat!("Failed to cast node to ", stringify!($NodeType)))?
                    .$field_name
                    .value()
                    // `$field_name` is the only name metavariable here; the previous
                    // `$getter_name` was unbound and leaked literally into the message.
                    .context(concat!(
                        "Failed to get ",
                        stringify!($field_name),
                        " value"
                    ))
            }

            fn get_value_mut<'a>(
                &self,
                node: &'a mut dyn Node,
            ) -> anyhow::Result<&'a mut $ValueType> {
                as_node_mut::<$NodeType>(node)
                    .context(concat!("Failed to cast node to ", stringify!($NodeType)))?
                    .$field_name
                    .value_mut()
                    .context(concat!(
                        "Failed to get ",
                        stringify!($field_name),
                        " value"
                    ))
            }
        }
    };
}
439
/// Declares a public extension trait `$ext` with a getter pair `$read` / `$write` and
/// implements it for [`Prompt`] by forwarding to the accessor `$accessor`.
macro_rules! create_ext_trait {
    ($value:ty, $accessor:ty, $read:ident, $write:ident, $ext:ident) => {
        pub trait $ext {
            /// Returns a shared reference to the value located by this accessor.
            ///
            /// # Errors
            ///
            /// Fails when the backing node cannot be located or the value cannot be read.
            fn $read(&self) -> anyhow::Result<&$value>;

            /// Returns a mutable reference to the value located by this accessor.
            ///
            /// # Errors
            ///
            /// Fails when the backing node cannot be located or the value cannot be read.
            fn $write(&mut self) -> anyhow::Result<&mut $value>;
        }

        impl $ext for Prompt {
            fn $read(&self) -> anyhow::Result<&$value> {
                self.get::<$accessor>()
            }

            fn $write(&mut self) -> anyhow::Result<&mut $value> {
                self.get_mut::<$accessor>()
            }
        }
    };
}
470
471impl Getter<String, CLIPTextEncode> for accessors::Prompt {
472 fn get_value<'a>(&self, node: &'a dyn Node) -> anyhow::Result<&'a String> {
473 as_node::<CLIPTextEncode>(node)
474 .context("Failed to cast node")?
475 .text
476 .value()
477 .context("Failed to get text value")
478 }
479
480 fn get_value_mut<'a>(&self, node: &'a mut dyn Node) -> anyhow::Result<&'a mut String> {
481 as_node_mut::<CLIPTextEncode>(node)
482 .context("Failed to cast node")?
483 .text
484 .value_mut()
485 .context("Failed to get text value")
486 }
487
488 fn find_node(prompt: &Prompt, output_node: Option<&str>) -> Option<String> {
489 if let Some(node) = find_node::<KSampler>(prompt, output_node) {
490 if let Ok(node) = prompt.get_typed_node(&node) as anyhow::Result<&KSampler> {
491 return Some(node.positive.node_id.clone());
492 }
493 }
494 if let Some(node) = find_node::<SamplerCustom>(prompt, output_node) {
495 if let Ok(node) = prompt.get_typed_node(&node) as anyhow::Result<&SamplerCustom> {
496 return Some(node.positive.node_id.clone());
497 }
498 }
499 None
500 }
501}
502
503create_ext_trait!(String, accessors::Prompt, prompt, prompt_mut, PromptExt);
504
505impl Getter<String, CLIPTextEncode> for accessors::NegativePrompt {
506 fn get_value<'a>(&self, node: &'a dyn Node) -> anyhow::Result<&'a String> {
507 accessors::Prompt.get_value(node)
508 }
509
510 fn get_value_mut<'a>(&self, node: &'a mut dyn Node) -> anyhow::Result<&'a mut String> {
511 accessors::Prompt.get_value_mut(node)
512 }
513
514 fn find_node(prompt: &Prompt, output_node: Option<&str>) -> Option<String> {
515 if let Some(node) = find_node::<KSampler>(prompt, output_node) {
516 if let Ok(node) = prompt.get_typed_node(&node) as anyhow::Result<&KSampler> {
517 return Some(node.negative.node_id.clone());
518 }
519 }
520 if let Some(node) = find_node::<SamplerCustom>(prompt, output_node) {
521 if let Ok(node) = prompt.get_typed_node(&node) as anyhow::Result<&SamplerCustom> {
522 return Some(node.negative.node_id.clone());
523 }
524 }
525 None
526 }
527}
528
529create_ext_trait!(
530 String,
531 accessors::NegativePrompt,
532 negative_prompt,
533 negative_prompt_mut,
534 NegativePromptExt
535);
536
537create_getter!(String, CheckpointLoaderSimple, accessors::Model, ckpt_name);
538create_ext_trait!(String, accessors::Model, ckpt_name, ckpt_name_mut, ModelExt);
539
540create_getter!(u32, EmptyLatentImage, accessors::Width, width);
541create_ext_trait!(u32, accessors::Width, width, width_mut, WidthExt);
542
543create_getter!(u32, EmptyLatentImage, accessors::Height, height);
544create_ext_trait!(u32, accessors::Height, height, height_mut, HeightExt);
545
546create_getter!(i64, KSampler, accessors::SeedT<KSampler>, seed);
547create_getter!(
548 i64,
549 SamplerCustom,
550 accessors::SeedT<SamplerCustom>,
551 noise_seed
552);
553
554create_ext_trait!(i64, accessors::Seed, seed, seed_mut, SeedExt);
555
556create_getter!(u32, KSampler, accessors::StepsT<KSampler>, steps);
557create_getter!(
558 u32,
559 SDTurboScheduler,
560 accessors::StepsT<SDTurboScheduler>,
561 steps
562);
563
564create_ext_trait!(u32, accessors::Steps, steps, steps_mut, StepsExt);
565
566impl<S1, S2, T, N1, N2> Getter<T, N1> for accessors::Delegating<S1, S2, T, N1, N2>
567where
568 S1: Getter<T, N1>,
569 S2: Getter<T, N2>,
570 N1: Node + 'static,
571 N2: Node + 'static,
572 T: Clone + Default,
573{
574 fn get<'a>(&self, prompt: &'a Prompt) -> anyhow::Result<&'a T> {
575 S1::default()
576 .get(prompt)
577 .or_else(|_| S2::default().get(prompt).context("Failed to set value"))
578 }
579
580 fn get_mut<'a>(&self, prompt: &'a mut Prompt) -> anyhow::Result<&'a mut T> {
581 let s1 = S1::default();
582 if s1.get(prompt).is_ok() {
583 return s1.get_mut(prompt);
584 }
585 S2::default().get_mut(prompt).context("Failed to set value")
586 }
587
588 fn get_from<'a>(&self, prompt: &'a Prompt, output_node: &str) -> anyhow::Result<&'a T> {
589 S1::default().get_from(prompt, output_node).or_else(|_| {
590 S2::default()
591 .get_from(prompt, output_node)
592 .context("Failed to set value")
593 })
594 }
595
596 fn get_from_mut<'a>(
597 &self,
598 prompt: &'a mut Prompt,
599 output_node: &str,
600 ) -> anyhow::Result<&'a mut T> {
601 let s1 = S1::default();
602 if s1.get_from(prompt, output_node).is_ok() {
603 return s1.get_from_mut(prompt, output_node);
604 }
605 S2::default()
606 .get_from_mut(prompt, output_node)
607 .context("Failed to set value")
608 }
609
610 fn get_node<'a>(&self, prompt: &'a Prompt, node: &str) -> anyhow::Result<&'a T> {
611 S1::default().get_node(prompt, node).or_else(|_| {
612 S2::default()
613 .get_node(prompt, node)
614 .context("Failed to set value")
615 })
616 }
617
618 fn get_node_mut<'a>(&self, prompt: &'a mut Prompt, node: &str) -> anyhow::Result<&'a mut T> {
619 let s1 = S1::default();
620 if s1.get_node(prompt, node).is_ok() {
621 return s1.get_node_mut(prompt, node);
622 }
623 S2::default()
624 .get_node_mut(prompt, node)
625 .context("Failed to set value")
626 }
627
628 fn get_value<'a>(&self, node: &'a dyn Node) -> anyhow::Result<&'a T> {
629 S1::default()
630 .get_value(node)
631 .or_else(|_| S2::default().get_value(node).context("Failed to set value"))
632 }
633
634 fn get_value_mut<'a>(&self, node: &'a mut dyn Node) -> anyhow::Result<&'a mut T> {
635 let s1 = S1::default();
636 if s1.get_value(node).is_ok() {
637 return s1.get_value_mut(node);
638 }
639 S2::default()
640 .get_value_mut(node)
641 .context("Failed to set value")
642 }
643
644 fn find_node(prompt: &Prompt, output_node: Option<&str>) -> Option<String> {
645 find_node::<N1>(prompt, output_node).or_else(|| find_node::<N2>(prompt, output_node))
646 }
647}
648
649create_getter!(f32, KSampler, accessors::CfgT<KSampler>, cfg);
650create_getter!(f32, SamplerCustom, accessors::CfgT<SamplerCustom>, cfg);
651create_ext_trait!(f32, accessors::Cfg, cfg, cfg_mut, CfgExt);
652
653create_getter!(f32, KSampler, accessors::Denoise, denoise);
654create_ext_trait!(f32, accessors::Denoise, denoise, denoise_mut, DenoiseExt);
655
656create_getter!(
657 String,
658 KSampler,
659 accessors::SamplerT<KSampler>,
660 sampler_name
661);
662create_getter!(
663 String,
664 KSamplerSelect,
665 accessors::SamplerT<KSamplerSelect>,
666 sampler_name
667);
668create_ext_trait!(
669 String,
670 accessors::Sampler,
671 sampler_name,
672 sampler_name_mut,
673 SamplerExt
674);
675
676create_getter!(u32, EmptyLatentImage, accessors::BatchSize, batch_size);
677create_ext_trait!(
678 u32,
679 accessors::BatchSize,
680 batch_size,
681 batch_size_mut,
682 BatchSizeExt
683);
684
685create_getter!(String, LoadImage, accessors::LoadImage, image);
686create_ext_trait!(String, accessors::LoadImage, image, image_mut, LoadImageExt);