sal_e_api/
gen_params.rs

1use anyhow::Context as _;
2use comfyui_api::{
3    comfy::getter::*,
4    models::{AsAny, Prompt},
5};
6use dyn_clone::DynClone;
7use serde::{Deserialize, Serialize};
8use stable_diffusion_api::{Img2ImgRequest, Txt2ImgRequest};
9
10dyn_clone::clone_trait_object!(GenParams);
11
12/// Trait representing an interface to image generation parameters.
13#[typetag::serde]
14pub trait GenParams: std::fmt::Debug + AsAny + Send + Sync + DynClone {
15    /// Gets the seed.
16    fn seed(&self) -> Option<i64>;
17    /// Sets the seed.
18    fn set_seed(&mut self, seed: i64);
19
20    /// Gets the number of steps.
21    fn steps(&self) -> Option<u32>;
22    /// Sets the number of steps.
23    fn set_steps(&mut self, steps: u32);
24
25    /// Gets the number of images to generate.
26    fn count(&self) -> Option<u32>;
27    /// Sets the number of images to generate.
28    fn set_count(&mut self, count: u32);
29
30    /// Gets the CFG scale.
31    fn cfg(&self) -> Option<f32>;
32    /// Sets the CFG scale.
33    fn set_cfg(&mut self, cfg: f32);
34
35    /// Gets the image width.
36    fn width(&self) -> Option<u32>;
37    /// Sets the image width.
38    fn set_width(&mut self, width: u32);
39
40    /// Gets the image height.
41    fn height(&self) -> Option<u32>;
42    /// Sets the image height.
43    fn set_height(&mut self, height: u32);
44
45    /// Gets the prompt.
46    fn prompt(&self) -> Option<String>;
47    /// Sets the prompt.
48    fn set_prompt(&mut self, prompt: String);
49
50    /// Gets the negative prompt.
51    fn negative_prompt(&self) -> Option<String>;
52    /// Sets the negative prompt.
53    fn set_negative_prompt(&mut self, negative_prompt: String);
54
55    /// Gets the denoising strength.
56    fn denoising(&self) -> Option<f32>;
57    /// Sets the denoising strength.
58    fn set_denoising(&mut self, denoising: f32);
59
60    /// Gets the sampler.
61    fn sampler(&self) -> Option<String>;
62    /// Sets the sampler.
63    fn set_sampler(&mut self, sampler: String);
64
65    /// Gets the batch size.
66    fn batch_size(&self) -> Option<u32>;
67    /// Sets the batch size.
68    fn set_batch_size(&mut self, batch_size: u32);
69
70    /// Gets the image.
71    fn image(&self) -> Option<Vec<u8>>;
72    /// Sets the image.
73    fn set_image(&mut self, image: Option<Vec<u8>>);
74}
75
76/// A struct representing the parameters for ComfyUI image generation.
77#[derive(Debug, Clone, Default, Serialize, Deserialize)]
78pub struct ComfyParams {
79    /// The ComfyUI prompt to use for generation.
80    #[serde(skip)]
81    pub prompt: Option<comfyui_api::models::Prompt>,
82    /// The random seed to use for generation.
83    pub seed: Option<i64>,
84    /// The number of steps to take for generation.
85    pub steps: Option<u32>,
86    /// The number of images to generate.
87    pub count: u32,
88    /// The CFG scale to use for generation.
89    pub cfg: Option<f32>,
90    /// The image width to use for generation.
91    pub width: Option<u32>,
92    /// The image height to use for generation.
93    pub height: Option<u32>,
94    /// The prompt text to use for generation.
95    pub prompt_text: Option<String>,
96    /// The negative prompt text to use for generation.
97    pub negative_prompt_text: Option<String>,
98    /// The denoising strength to use for generation.
99    pub denoising: Option<f32>,
100    /// The sampler to use for generation.
101    pub sampler: Option<String>,
102    /// The batch size to use for generation.
103    pub batch_size: Option<u32>,
104    /// The image to use for generation.
105    pub image: Option<Vec<u8>>,
106}
107
108impl ComfyParams {
109    /// Applies the parameters to the provided prompt.
110    ///
111    /// # Arguments
112    ///
113    /// * `prompt` - The prompt to apply the parameters to.
114    ///
115    /// # Returns
116    ///
117    /// The prompt with the parameters applied.
118    pub fn apply_to(&self, prompt: &Prompt) -> Prompt {
119        let mut prompt = prompt.clone();
120
121        if let Some(seed) = self.seed {
122            _ = prompt.seed_mut().map(|s| *s = seed);
123        }
124
125        if let Some(steps) = self.steps {
126            _ = prompt.steps_mut().map(|s| *s = steps);
127        }
128
129        if let Some(cfg) = self.cfg {
130            _ = prompt.cfg_mut().map(|c| *c = cfg);
131        }
132
133        if let Some(width) = self.width {
134            _ = prompt.width_mut().map(|w| *w = width);
135        }
136
137        if let Some(height) = self.height {
138            _ = prompt.height_mut().map(|h| *h = height);
139        }
140
141        if let Some(prompt_text) = &self.prompt_text {
142            _ = prompt.prompt_mut().map(|p| *p = prompt_text.clone());
143        }
144
145        if let Some(negative_prompt_text) = &self.negative_prompt_text {
146            _ = prompt
147                .negative_prompt_mut()
148                .map(|p| *p = negative_prompt_text.clone());
149        }
150
151        if let Some(denoising) = self.denoising {
152            _ = prompt.denoise_mut().map(|d| *d = denoising);
153        }
154
155        if let Some(sampler) = &self.sampler {
156            _ = prompt.sampler_name_mut().map(|s| *s = sampler.clone());
157        }
158
159        if let Some(batch_size) = self.batch_size {
160            _ = prompt.batch_size_mut().map(|b| *b = batch_size);
161        }
162
163        prompt
164    }
165
166    /// Applies the parameters to the current prompt.
167    ///
168    /// # Returns
169    ///
170    /// The prompt with the parameters applied.
171    pub fn apply(&self) -> Option<Prompt> {
172        self.prompt.as_ref().map(|prompt| self.apply_to(prompt))
173    }
174}
175
176impl From<&dyn GenParams> for ComfyParams {
177    fn from(params: &dyn GenParams) -> Self {
178        Self {
179            seed: params.seed(),
180            steps: params.steps(),
181            count: params.count().unwrap_or(1),
182            cfg: params.cfg(),
183            width: params.width(),
184            height: params.height(),
185            prompt_text: params.prompt(),
186            negative_prompt_text: params.negative_prompt(),
187            denoising: params.denoising(),
188            sampler: params.sampler(),
189            batch_size: params.batch_size(),
190            image: params.image(),
191            ..Default::default()
192        }
193    }
194}
195
196#[typetag::serde]
197impl GenParams for ComfyParams {
198    fn seed(&self) -> Option<i64> {
199        self.seed
200            .or_else(|| self.prompt.as_ref()?.seed().ok().copied())
201    }
202
203    fn set_seed(&mut self, seed: i64) {
204        self.seed = Some(seed);
205    }
206
207    fn steps(&self) -> Option<u32> {
208        self.steps
209            .or_else(|| self.prompt.as_ref()?.steps().ok().copied())
210    }
211
212    fn set_steps(&mut self, steps: u32) {
213        self.steps = Some(steps);
214    }
215
216    fn count(&self) -> Option<u32> {
217        Some(self.count)
218    }
219
220    fn set_count(&mut self, count: u32) {
221        self.count = count;
222    }
223
224    fn cfg(&self) -> Option<f32> {
225        self.cfg
226            .or_else(|| self.prompt.as_ref()?.cfg().ok().copied())
227    }
228
229    fn set_cfg(&mut self, cfg: f32) {
230        self.cfg = Some(cfg);
231    }
232
233    fn width(&self) -> Option<u32> {
234        self.width
235            .or_else(|| self.prompt.as_ref()?.width().ok().copied())
236    }
237
238    fn set_width(&mut self, width: u32) {
239        self.width = Some(width);
240    }
241
242    fn height(&self) -> Option<u32> {
243        self.height
244            .or_else(|| self.prompt.as_ref()?.height().ok().copied())
245    }
246
247    fn set_height(&mut self, height: u32) {
248        self.height = Some(height);
249    }
250
251    fn prompt(&self) -> Option<String> {
252        self.prompt_text
253            .clone()
254            .or_else(|| self.prompt.as_ref()?.prompt().ok().cloned())
255    }
256
257    fn set_prompt(&mut self, prompt: String) {
258        self.prompt_text = Some(prompt);
259    }
260
261    fn negative_prompt(&self) -> Option<String> {
262        self.negative_prompt_text
263            .clone()
264            .or_else(|| self.prompt.as_ref()?.negative_prompt().ok().cloned())
265    }
266
267    fn set_negative_prompt(&mut self, negative_prompt: String) {
268        self.negative_prompt_text = Some(negative_prompt);
269    }
270
271    fn denoising(&self) -> Option<f32> {
272        self.denoising
273            .or_else(|| self.prompt.as_ref()?.denoise().ok().copied())
274    }
275
276    fn set_denoising(&mut self, denoising: f32) {
277        self.denoising = Some(denoising);
278    }
279
280    fn sampler(&self) -> Option<String> {
281        self.sampler
282            .clone()
283            .or_else(|| self.prompt.as_ref()?.sampler_name().ok().cloned())
284    }
285
286    fn set_sampler(&mut self, sampler: String) {
287        self.sampler = Some(sampler);
288    }
289
290    fn batch_size(&self) -> Option<u32> {
291        self.batch_size
292            .or_else(|| self.prompt.as_ref()?.batch_size().ok().copied())
293    }
294
295    fn set_batch_size(&mut self, batch_size: u32) {
296        self.batch_size = Some(batch_size);
297    }
298
299    fn image(&self) -> Option<Vec<u8>> {
300        self.image.clone()
301    }
302
303    fn set_image(&mut self, image: Option<Vec<u8>>) {
304        self.image = image;
305    }
306}
307
308/// A struct representing the parameters for image generation in the Stable Diffusion WebUI API.
309#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
310pub struct Txt2ImgParams {
311    /// The parameters provided by the user.
312    pub user_params: Txt2ImgRequest,
313    /// The default parameters.
314    #[serde(skip)]
315    pub defaults: Option<Txt2ImgRequest>,
316}
317
318impl From<&dyn GenParams> for Txt2ImgParams {
319    fn from(params: &dyn GenParams) -> Self {
320        Self {
321            user_params: Txt2ImgRequest {
322                seed: params.seed(),
323                steps: params.steps(),
324                n_iter: params.count(),
325                cfg_scale: params.cfg().map(|c| c as f64),
326                width: params.width(),
327                height: params.height(),
328                prompt: params.prompt(),
329                negative_prompt: params.negative_prompt(),
330                denoising_strength: params.denoising().map(|d| d as f64),
331                sampler_index: params.sampler(),
332                batch_size: params.batch_size(),
333                ..Default::default()
334            },
335            defaults: None,
336        }
337    }
338}
339
340#[typetag::serde]
341impl GenParams for Txt2ImgParams {
342    fn seed(&self) -> Option<i64> {
343        self.user_params
344            .seed
345            .or_else(|| self.defaults.as_ref()?.seed)
346    }
347
348    fn set_seed(&mut self, seed: i64) {
349        self.user_params.seed = Some(seed);
350    }
351
352    fn steps(&self) -> Option<u32> {
353        self.user_params
354            .steps
355            .or_else(|| self.defaults.as_ref()?.steps)
356    }
357
358    fn set_steps(&mut self, steps: u32) {
359        self.user_params.steps = Some(steps);
360    }
361
362    fn count(&self) -> Option<u32> {
363        self.user_params
364            .n_iter
365            .or_else(|| self.defaults.as_ref()?.n_iter)
366    }
367
368    fn set_count(&mut self, count: u32) {
369        self.user_params.n_iter = Some(count);
370    }
371
372    fn cfg(&self) -> Option<f32> {
373        self.user_params
374            .cfg_scale
375            .map(|c| c as f32)
376            .or_else(|| self.defaults.as_ref()?.cfg_scale.map(|c| c as f32))
377    }
378
379    fn set_cfg(&mut self, cfg: f32) {
380        self.user_params.cfg_scale = Some(cfg as f64);
381    }
382
383    fn width(&self) -> Option<u32> {
384        self.user_params
385            .width
386            .or_else(|| self.defaults.as_ref()?.width)
387    }
388
389    fn set_width(&mut self, width: u32) {
390        self.user_params.width = Some(width);
391    }
392
393    fn height(&self) -> Option<u32> {
394        self.user_params
395            .height
396            .or_else(|| self.defaults.as_ref()?.height)
397    }
398
399    fn set_height(&mut self, height: u32) {
400        self.user_params.height = Some(height);
401    }
402
403    fn prompt(&self) -> Option<String> {
404        self.user_params
405            .prompt
406            .clone()
407            .or_else(|| self.defaults.as_ref()?.prompt.clone())
408    }
409
410    fn set_prompt(&mut self, prompt: String) {
411        self.user_params.prompt = Some(prompt);
412    }
413
414    fn negative_prompt(&self) -> Option<String> {
415        self.user_params
416            .negative_prompt
417            .clone()
418            .or_else(|| self.defaults.as_ref()?.negative_prompt.clone())
419    }
420
421    fn set_negative_prompt(&mut self, negative_prompt: String) {
422        self.user_params.negative_prompt = Some(negative_prompt);
423    }
424
425    fn denoising(&self) -> Option<f32> {
426        self.user_params
427            .denoising_strength
428            .map(|d| d as f32)
429            .or_else(|| self.defaults.as_ref()?.denoising_strength.map(|d| d as f32))
430    }
431
432    fn set_denoising(&mut self, denoising: f32) {
433        self.user_params.denoising_strength = Some(denoising as f64);
434    }
435
436    fn sampler(&self) -> Option<String> {
437        self.user_params
438            .sampler_index
439            .clone()
440            .or_else(|| self.defaults.as_ref()?.sampler_index.clone())
441    }
442
443    fn set_sampler(&mut self, sampler: String) {
444        self.user_params.sampler_index = Some(sampler);
445    }
446
447    fn batch_size(&self) -> Option<u32> {
448        self.user_params
449            .batch_size
450            .or_else(|| self.defaults.as_ref()?.batch_size)
451    }
452
453    fn set_batch_size(&mut self, batch_size: u32) {
454        self.user_params.batch_size = Some(batch_size);
455    }
456
457    fn image(&self) -> Option<Vec<u8>> {
458        None
459    }
460
461    fn set_image(&mut self, _image: Option<Vec<u8>>) {}
462}
463
464/// A struct representing the parameters for image generation in the Stable Diffusion WebUI API.
465#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
466pub struct Img2ImgParams {
467    /// The parameters provided by the user.
468    pub user_params: Img2ImgRequest,
469    /// The default parameters.
470    #[serde(skip)]
471    pub defaults: Option<Img2ImgRequest>,
472}
473
474impl From<&dyn GenParams> for Img2ImgParams {
475    fn from(params: &dyn GenParams) -> Self {
476        Self {
477            user_params: Img2ImgRequest {
478                seed: params.seed(),
479                steps: params.steps(),
480                n_iter: params.count(),
481                cfg_scale: params.cfg().map(|c| c as f64),
482                width: params.width(),
483                height: params.height(),
484                prompt: params.prompt(),
485                negative_prompt: params.negative_prompt(),
486                denoising_strength: params.denoising().map(|d| d as f64),
487                sampler_index: params.sampler(),
488                batch_size: params.batch_size(),
489                ..Default::default()
490            },
491            defaults: None,
492        }
493    }
494}
495
496#[typetag::serde]
497impl GenParams for Img2ImgParams {
498    fn seed(&self) -> Option<i64> {
499        self.user_params
500            .seed
501            .or_else(|| self.defaults.as_ref()?.seed)
502    }
503
504    fn set_seed(&mut self, seed: i64) {
505        self.user_params.seed = Some(seed);
506    }
507
508    fn steps(&self) -> Option<u32> {
509        self.user_params
510            .steps
511            .or_else(|| self.defaults.as_ref()?.steps)
512    }
513
514    fn set_steps(&mut self, steps: u32) {
515        self.user_params.steps = Some(steps);
516    }
517
518    fn count(&self) -> Option<u32> {
519        self.user_params
520            .n_iter
521            .or_else(|| self.defaults.as_ref()?.n_iter)
522    }
523
524    fn set_count(&mut self, count: u32) {
525        self.user_params.n_iter = Some(count);
526    }
527
528    fn cfg(&self) -> Option<f32> {
529        self.user_params
530            .cfg_scale
531            .map(|c| c as f32)
532            .or_else(|| self.defaults.as_ref()?.cfg_scale.map(|c| c as f32))
533    }
534
535    fn set_cfg(&mut self, cfg: f32) {
536        self.user_params.cfg_scale = Some(cfg as f64);
537    }
538
539    fn width(&self) -> Option<u32> {
540        self.user_params
541            .width
542            .or_else(|| self.defaults.as_ref()?.width)
543    }
544
545    fn set_width(&mut self, width: u32) {
546        self.user_params.width = Some(width);
547    }
548
549    fn height(&self) -> Option<u32> {
550        self.user_params
551            .height
552            .or_else(|| self.defaults.as_ref()?.height)
553    }
554
555    fn set_height(&mut self, height: u32) {
556        self.user_params.height = Some(height);
557    }
558
559    fn prompt(&self) -> Option<String> {
560        self.user_params
561            .prompt
562            .clone()
563            .or_else(|| self.defaults.as_ref()?.prompt.clone())
564    }
565
566    fn set_prompt(&mut self, prompt: String) {
567        self.user_params.prompt = Some(prompt);
568    }
569
570    fn negative_prompt(&self) -> Option<String> {
571        self.user_params
572            .negative_prompt
573            .clone()
574            .or_else(|| self.defaults.as_ref()?.negative_prompt.clone())
575    }
576
577    fn set_negative_prompt(&mut self, negative_prompt: String) {
578        self.user_params.negative_prompt = Some(negative_prompt);
579    }
580
581    fn denoising(&self) -> Option<f32> {
582        self.user_params
583            .denoising_strength
584            .map(|d| d as f32)
585            .or_else(|| self.defaults.as_ref()?.denoising_strength.map(|d| d as f32))
586    }
587
588    fn set_denoising(&mut self, denoising: f32) {
589        self.user_params.denoising_strength = Some(denoising as f64);
590    }
591
592    fn sampler(&self) -> Option<String> {
593        self.user_params
594            .sampler_index
595            .clone()
596            .or_else(|| self.defaults.as_ref()?.sampler_index.clone())
597    }
598
599    fn set_sampler(&mut self, sampler: String) {
600        self.user_params.sampler_index = Some(sampler);
601    }
602
603    fn batch_size(&self) -> Option<u32> {
604        self.user_params
605            .batch_size
606            .or_else(|| self.defaults.as_ref()?.batch_size)
607    }
608
609    fn set_batch_size(&mut self, batch_size: u32) {
610        self.user_params.batch_size = Some(batch_size);
611    }
612
613    fn image(&self) -> Option<Vec<u8>> {
614        if let Some(ref images) = self.user_params.init_images {
615            use base64::{engine::general_purpose, Engine as _};
616            images
617                .iter()
618                .map(|img| {
619                    general_purpose::STANDARD
620                        .decode(img)
621                        .context("failed to decode image")
622                })
623                .collect::<anyhow::Result<Vec<_>>>()
624                .ok()
625                .and_then(|mut images| images.pop())
626        } else {
627            None
628        }
629    }
630
631    fn set_image(&mut self, image: Option<Vec<u8>>) {
632        if let Some(image) = image {
633            self.user_params.with_image(image);
634        } else {
635            _ = self.user_params.init_images.take()
636        }
637    }
638}