stable_diffusion_api/
img2img.rs

1use std::collections::HashMap;
2
3use reqwest::Url;
4use serde::{Deserialize, Serialize};
5use serde_with::skip_serializing_none;
6
7use super::ImgResponse;
8
9/// Struct representing an image to image request.
10#[skip_serializing_none]
11#[derive(Default, PartialEq, Serialize, Deserialize, Debug, Clone)]
12pub struct Img2ImgRequest {
13    /// Initial images.
14    pub init_images: Option<Vec<String>>,
15    /// Resize mode.
16    pub resize_mode: Option<u32>,
17    /// Strength of denoising applied to the image.
18    pub denoising_strength: Option<f64>,
19    /// CFG scale.
20    pub image_cfg_scale: Option<u32>,
21    /// Mask.
22    pub mask: Option<String>,
23    /// Blur to apply to the mask.
24    pub mask_blur: Option<u32>,
25    /// Amount of inpainting to apply.
26    pub inpainting_fill: Option<u32>,
27    /// Whether to perform inpainting at full resolution.
28    pub inpaint_full_res: Option<bool>,
29    /// Padding to apply when performing inpainting at full resolution.
30    pub inpaint_full_res_padding: Option<u32>,
31    /// Whether to invert the inpainting mask.
32    pub inpainting_mask_invert: Option<u32>,
33    /// Initial noise multiplier.
34    pub initial_noise_multiplier: Option<u32>,
35    /// Text prompt for generating the image.
36    pub prompt: Option<String>,
37    /// List of style prompts for generating the image.
38    pub styles: Option<Vec<String>>,
39    /// Seed.
40    pub seed: Option<i64>,
41    /// Subseed.
42    pub subseed: Option<i64>,
43    /// Strength of the subseed.
44    pub subseed_strength: Option<u32>,
45    /// Height to resize the seed image from.
46    pub seed_resize_from_h: Option<i32>,
47    /// Width to resize the seed image from.
48    pub seed_resize_from_w: Option<i32>,
49    /// Name of the sampler.
50    pub sampler_name: Option<String>,
51    /// Batch size.
52    pub batch_size: Option<u32>,
53    /// Number of iterations.
54    pub n_iter: Option<u32>,
55    /// Number of steps.
56    pub steps: Option<u32>,
57    /// CFG scale.
58    pub cfg_scale: Option<f64>,
59    /// Width of the generated image.
60    pub width: Option<u32>,
61    /// Height of the generated image.
62    pub height: Option<u32>,
63    /// Whether to restore faces in the generated image.
64    pub restore_faces: Option<bool>,
65    /// Whether tiling for the generated image.
66    pub tiling: Option<bool>,
67    /// Whether to save samples when generating multiple images.
68    pub do_not_save_samples: Option<bool>,
69    /// Whether to save the grid when generating multiple images.
70    pub do_not_save_grid: Option<bool>,
71    /// Negative prompt.
72    pub negative_prompt: Option<String>,
73    /// Eta value.
74    pub eta: Option<u32>,
75    /// Churn value.
76    pub s_churn: Option<f64>,
77    /// Maximum temperature value.
78    pub s_tmax: Option<f64>,
79    /// Minimum temperature value.
80    pub s_tmin: Option<f64>,
81    /// Amount of noise.
82    pub s_noise: Option<f64>,
83    /// Settings to override when generating the image.
84    pub override_settings: Option<HashMap<String, serde_json::Value>>,
85    /// Whether to restore the settings after generating the image.
86    pub override_settings_restore_afterwards: Option<bool>,
87    /// Arguments to pass to the script.
88    pub script_args: Option<Vec<serde_json::Value>>,
89    /// Index of the sampler.
90    pub sampler_index: Option<String>,
91    /// Whether to include initial images in the output.
92    pub include_init_images: Option<bool>,
93    /// Name of the script.
94    pub script_name: Option<String>,
95    /// Whether to send the generated images.
96    pub send_images: Option<bool>,
97    /// Whether to save the generated images.
98    pub save_images: Option<bool>,
99    /// Scripts to always run.
100    pub alwayson_scripts: Option<HashMap<String, serde_json::Value>>,
101}
102
103impl Img2ImgRequest {
104    /// Adds a prompt to the request.
105    ///
106    /// # Arguments
107    ///
108    /// * `prompt` - A String representing the prompt to be used for image generation.
109    ///
110    /// # Example
111    ///
112    /// ```
113    /// let mut req = Img2ImgRequest::default();
114    /// req.with_prompt("A blue sky with green grass".to_string());
115    /// ```
116    pub fn with_prompt(&mut self, prompt: String) -> &mut Self {
117        self.prompt = Some(prompt);
118        self
119    }
120
121    /// Adds a single image to the request.
122    ///
123    /// # Arguments
124    ///
125    /// * `image` - array bytes of the image to be added.
126    ///
127    /// # Examples
128    ///
129    /// ```
130    /// use std::fs;
131    /// let mut req = Img2ImgRequest::default();
132    /// let image_data = fs::read("path/to/image.jpg").unwrap();
133    /// req.with_image(image_data);
134    /// ```
135    pub fn with_image<T>(&mut self, image: T) -> &mut Self
136    where
137        T: AsRef<[u8]>,
138    {
139        use base64::{engine::general_purpose, Engine as _};
140
141        if let Some(ref mut images) = self.init_images {
142            images.push(general_purpose::STANDARD.encode(image));
143            self
144        } else {
145            self.with_images(vec![image])
146        }
147    }
148
149    /// Adds multiple images to the request.
150    ///
151    /// # Arguments
152    ///
153    /// * `images` - A vector of byte arrays that represents the images to be added.
154    ///
155    /// # Examples
156    ///
157    /// ```
158    /// use std::fs;
159    /// let mut req = Img2ImgRequest::default();
160    /// let image_data1 = fs::read("path/to/image1.jpg").unwrap();
161    /// let image_data2 = fs::read("path/to/image2.jpg").unwrap();
162    /// req.with_images(vec![image_data1, image_data2]);
163    /// ```
164    pub fn with_images<T>(&mut self, images: Vec<T>) -> &mut Self
165    where
166        T: AsRef<[u8]>,
167    {
168        use base64::{engine::general_purpose, Engine as _};
169        let images = images.iter().map(|i| general_purpose::STANDARD.encode(i));
170        if let Some(ref mut i) = self.init_images {
171            i.extend(images);
172        } else {
173            self.init_images = Some(images.collect())
174        }
175        self
176    }
177
178    /// Adds styles to the request.
179    ///
180    /// # Arguments
181    ///
182    /// * `styles` - A vector of Strings representing the styles to be used for image generation.
183    ///
184    /// # Examples
185    ///
186    /// ```
187    /// let mut req = Img2ImgRequest::default();
188    /// req.with_styles(vec!["cubism".to_string(), "impressionism".to_string()]);
189    /// ```
190    pub fn with_styles(&mut self, styles: Vec<String>) -> &mut Self {
191        if let Some(ref mut s) = self.styles {
192            s.extend(styles);
193        } else {
194            self.styles = Some(styles);
195        }
196        self
197    }
198
199    /// Adds a style to the request.
200    ///
201    ///
202    /// # Arguments
203    ///
204    /// * `style` - A String representing the style to be used for image generation.
205    ///
206    /// # Examples
207    ///
208    /// ```
209    /// let mut req = Img2ImgRequest::default();
210    /// req.with_style("cubism".to_string());
211    /// ```
212    pub fn with_style(&mut self, style: String) -> &mut Self {
213        if let Some(ref mut styles) = self.styles {
214            styles.push(style);
215            self
216        } else {
217            self.with_styles(vec![style])
218        }
219    }
220
221    /// Sets the denoising strength for image generation.
222    ///
223    /// # Arguments
224    ///
225    /// * `denoising_strength` - A f64 value representing the denoising strength parameter.
226    ///
227    /// # Examples
228    ///
229    /// ```
230    /// let mut req = Img2ImgRequest::default();
231    /// req.with_denoising_strength(0.4);
232    /// ```
233    pub fn with_denoising_strength(&mut self, denoising_strength: f64) -> &mut Self {
234        self.denoising_strength = Some(denoising_strength);
235        self
236    }
237
238    /// Sets the seed for random number generation.
239    ///
240    /// # Arguments
241    ///
242    /// * `seed` - An i64 value representing the seed for random number generation.
243    ///            Set to `-1` to randomize.
244    ///
245    /// # Example
246    ///
247    /// ```
248    /// let mut req = Img2ImgRequest::default();
249    /// req.with_seed(12345);
250    /// ```
251    pub fn with_seed(&mut self, seed: i64) -> &mut Self {
252        self.seed = Some(seed);
253        self
254    }
255
256    /// Sets the subseed for random number generation.
257    ///
258    /// # Arguments
259    ///
260    /// * `subseed` - An i64 value representing the subseed for random number generation.
261    ///               Set to `-1` to randomize.
262    ///
263    /// # Example
264    ///
265    /// ```
266    /// let mut req = Img2ImgRequest::default();
267    /// req.with_subseed(12345);
268    /// ```
269    pub fn with_subseed(&mut self, subseed: i64) -> &mut Self {
270        self.subseed = Some(subseed);
271        self
272    }
273
274    /// Sets the strength of the subseed parameter for image generation.
275    ///
276    /// # Arguments
277    ///
278    /// * `subseed_strength` - A u32 value representing the strength of the subseed parameter.
279    ///
280    /// # Example
281    ///
282    /// ```
283    /// let mut req = Img2ImgRequest::default();
284    /// req.with_subseed_strength(5);
285    /// ```
286    pub fn with_subseed_strength(&mut self, subseed_strength: u32) -> &mut Self {
287        self.subseed_strength = Some(subseed_strength);
288        self
289    }
290
291    /// Sets the sampler name for image generation.
292    ///
293    /// # Arguments
294    ///
295    /// * `sampler_name` - A String representing the sampler name to be used.
296    ///
297    /// # Examples
298    ///
299    /// ```
300    /// let mut req = Img2ImgRequest::default();
301    /// req.with_sampler_name("Euler".to_string());
302    /// ```
303    pub fn with_sampler_name(&mut self, sampler_name: String) -> &mut Self {
304        self.sampler_name = Some(sampler_name);
305        self
306    }
307
308    /// Sets the batch size for image generation.
309    ///
310    /// # Arguments
311    ///
312    /// * `batch_size` - A u32 value representing the batch size to be used.
313    ///
314    /// # Examples
315    ///
316    /// ```
317    /// let mut req = Img2ImgRequest::default();
318    /// req.with_batch_size(16);
319    /// ```
320    pub fn with_batch_size(&mut self, batch_size: u32) -> &mut Self {
321        self.batch_size = Some(batch_size);
322        self
323    }
324
325    /// Sets the number of iterations for image generation.
326    ///
327    /// # Arguments
328    ///
329    /// * `n_iter` - A u32 value representing the number of iterations to run for image generation.
330    ///
331    /// # Examples
332    ///
333    /// ```
334    /// let mut req = Img2ImgRequest::default();
335    /// req.with_n_iter(1000);
336    /// ```
337    pub fn with_n_iter(&mut self, n_iter: u32) -> &mut Self {
338        self.n_iter = Some(n_iter);
339        self
340    }
341
342    /// Sets the number of steps for image generation.
343    ///
344    /// # Arguments
345    ///
346    /// * `steps` - A u32 value representing the number of steps for image generation.
347    ///
348    /// # Examples
349    ///
350    /// ```
351    /// let mut req = Img2ImgRequest::default();
352    /// req.with_steps(50);
353    /// ```
354    pub fn with_steps(&mut self, steps: u32) -> &mut Self {
355        self.steps = Some(steps);
356        self
357    }
358
359    /// Sets the cfg scale for image generation.
360    ///
361    /// # Arguments
362    ///
363    /// * `cfg_scale` - A f64 value representing the cfg scale parameter.
364    ///
365    /// # Examples
366    ///
367    /// ```
368    /// let mut req = Img2ImgRequest::default();
369    /// req.with_cfg_scale(0.7);
370    /// ```
371    pub fn with_cfg_scale(&mut self, cfg_scale: f64) -> &mut Self {
372        self.cfg_scale = Some(cfg_scale);
373        self
374    }
375
376    /// Sets the width for image generation.
377    ///
378    /// # Arguments
379    ///
380    /// * `width` - A u32 value representing the image width.
381    ///
382    /// # Examples
383    ///
384    /// ```
385    /// let mut req = Img2ImgRequest::default();
386    /// req.with_width(512);
387    /// ```
388    pub fn with_width(&mut self, width: u32) -> &mut Self {
389        self.width = Some(width);
390        self
391    }
392
393    /// Sets the height for image generation.
394    ///
395    /// # Arguments
396    ///
397    /// * `height` - A u32 value representing the image height.
398    ///
399    /// # Examples
400    ///
401    /// ```
402    /// let mut req = Img2ImgRequest::default();
403    /// req.with_height(512);
404    /// ```
405    pub fn with_height(&mut self, height: u32) -> &mut Self {
406        self.height = Some(height);
407        self
408    }
409
410    /// Enable or disable face restoration.
411    ///
412    /// # Arguments
413    ///
414    /// * `restore_faces` - A bool value to enable or disable face restoration.
415    ///
416    /// # Examples
417    ///
418    /// ```
419    /// let mut req = Img2ImgRequest::default();
420    /// req.with_restore_faces(true);
421    /// ```
422    pub fn with_restore_faces(&mut self, restore_faces: bool) -> &mut Self {
423        self.restore_faces = Some(restore_faces);
424        self
425    }
426
427    /// Enable or disable image tiling.
428    ///
429    /// # Arguments
430    ///
431    /// * `tiling` - A bool value to enable or disable tiling.
432    ///
433    /// # Examples
434    ///
435    /// ```
436    /// let mut req = Img2ImgRequest::default();
437    /// req.with_tiling(true);
438    /// ```
439    pub fn with_tiling(&mut self, tiling: bool) -> &mut Self {
440        self.tiling = Some(tiling);
441        self
442    }
443
444    /// Adds a negative prompt to the request.
445    ///
446    /// # Arguments
447    ///
448    /// * `negative_prompt` - A String representing the negative prompt to be used for image generation.
449    ///
450    /// # Example
451    ///
452    /// ```
453    /// let mut req = Img2ImgRequest::default();
454    /// req.with_prompt("bad, ugly, worst quality".to_string());
455    /// ```
456    pub fn with_negative_prompt(&mut self, negative_prompt: String) -> &mut Self {
457        self.negative_prompt = Some(negative_prompt);
458        self
459    }
460
461    /// Merges the given settings with the request's settings.
462    ///
463    /// # Arguments
464    ///
465    /// * `request` - A Img2ImgRequest containing the settings to merge.
466    pub fn merge(&self, request: Self) -> Self {
467        Self {
468            init_images: request.init_images.or(self.init_images.clone()),
469            resize_mode: request.resize_mode.or(self.resize_mode),
470            denoising_strength: request.denoising_strength.or(self.denoising_strength),
471            image_cfg_scale: request.image_cfg_scale.or(self.image_cfg_scale),
472            mask: request.mask.or(self.mask.clone()),
473            mask_blur: request.mask_blur.or(self.mask_blur),
474            inpainting_fill: request.inpainting_fill.or(self.inpainting_fill),
475            inpaint_full_res: request.inpaint_full_res.or(self.inpaint_full_res),
476            inpaint_full_res_padding: request
477                .inpaint_full_res_padding
478                .or(self.inpaint_full_res_padding),
479            inpainting_mask_invert: request
480                .inpainting_mask_invert
481                .or(self.inpainting_mask_invert),
482            initial_noise_multiplier: request
483                .initial_noise_multiplier
484                .or(self.initial_noise_multiplier),
485            prompt: request.prompt.or(self.prompt.clone()),
486            styles: request.styles.or(self.styles.clone()),
487            seed: request.seed.or(self.seed),
488            subseed: request.subseed.or(self.subseed),
489            subseed_strength: request.subseed_strength.or(self.subseed_strength),
490            seed_resize_from_h: request.seed_resize_from_h.or(self.seed_resize_from_h),
491            seed_resize_from_w: request.seed_resize_from_w.or(self.seed_resize_from_w),
492            sampler_name: request.sampler_name.or(self.sampler_name.clone()),
493            batch_size: request.batch_size.or(self.batch_size),
494            n_iter: request.n_iter.or(self.n_iter),
495            steps: request.steps.or(self.steps),
496            cfg_scale: request.cfg_scale.or(self.cfg_scale),
497            width: request.width.or(self.width),
498            height: request.height.or(self.height),
499            restore_faces: request.restore_faces.or(self.restore_faces),
500            tiling: request.tiling.or(self.tiling),
501            do_not_save_samples: request.do_not_save_samples.or(self.do_not_save_samples),
502            do_not_save_grid: request.do_not_save_grid.or(self.do_not_save_grid),
503            negative_prompt: request.negative_prompt.or(self.negative_prompt.clone()),
504            eta: request.eta.or(self.eta),
505            s_churn: request.s_churn.or(self.s_churn),
506            s_tmax: request.s_tmax.or(self.s_tmax),
507            s_tmin: request.s_tmin.or(self.s_tmin),
508            s_noise: request.s_noise.or(self.s_noise),
509            override_settings: request.override_settings.or(self.override_settings.clone()),
510            override_settings_restore_afterwards: request
511                .override_settings_restore_afterwards
512                .or(self.override_settings_restore_afterwards),
513            script_args: request.script_args.or(self.script_args.clone()),
514            sampler_index: request.sampler_index.or(self.sampler_index.clone()),
515            include_init_images: request.include_init_images.or(self.include_init_images),
516            script_name: request.script_name.or(self.script_name.clone()),
517            send_images: request.send_images.or(self.send_images),
518            save_images: request.save_images.or(self.save_images),
519            alwayson_scripts: request.alwayson_scripts.or(self.alwayson_scripts.clone()),
520        }
521    }
522}
523
524/// Errors that can occur when interacting with the `Img2Img` API.
525#[derive(thiserror::Error, Debug)]
526#[non_exhaustive]
527pub enum Img2ImgError {
528    /// Error parsing endpoint URL
529    #[error("Failed to parse endpoint URL")]
530    ParseError(#[from] url::ParseError),
531    /// Error sending request
532    #[error("Failed to send request")]
533    RequestFailed(#[from] reqwest::Error),
534    /// An error occurred while parsing the response from the API.
535    #[error("Parsing response failed")]
536    InvalidResponse(#[source] reqwest::Error),
537    /// An error occurred getting response data.
538    #[error("Failed to get response data")]
539    GetDataFailed(#[source] reqwest::Error),
540    /// Server returned an error for img2img
541    #[error("Img2Img request failed: {status}: {error}")]
542    Img2ImgFailed {
543        status: reqwest::StatusCode,
544        error: String,
545    },
546}
547
548type Result<T> = std::result::Result<T, Img2ImgError>;
549
550/// A client for sending image requests to a specified endpoint.
551pub struct Img2Img {
552    client: reqwest::Client,
553    endpoint: Url,
554}
555
556impl Img2Img {
557    /// Constructs a new Img2Img client with a given `reqwest::Client` and Stable Diffusion API
558    /// endpoint `String`.
559    ///
560    /// # Arguments
561    ///
562    /// * `client` - A `reqwest::Client` used to send requests.
563    /// * `endpoint` - A `String` representation of the endpoint url.
564    ///
565    /// # Returns
566    ///
567    /// A `Result` containing a new Img2Img instance on success, or an error if url parsing failed.
568    pub fn new(client: reqwest::Client, endpoint: String) -> Result<Self> {
569        Ok(Self::new_with_url(client, Url::parse(&endpoint)?))
570    }
571
572    /// Constructs a new Img2Img client with a given `reqwest::Client` and endpoint `Url`.
573    ///
574    /// # Arguments
575    ///
576    /// * `client` - A `reqwest::Client` used to send requests.
577    /// * `endpoint` - A `Url` representing the endpoint url.
578    ///
579    /// # Returns
580    ///
581    /// A new Img2Img instance.
582    pub fn new_with_url(client: reqwest::Client, endpoint: Url) -> Self {
583        Self { client, endpoint }
584    }
585
586    /// Sends an image request using the Img2Img client.
587    ///
588    /// # Arguments
589    ///
590    /// * `request` - An Img2ImgRequest containing the parameters for the image request.
591    ///
592    /// # Returns
593    ///
594    /// A `Result` containing an `ImgResponse<Img2ImgRequest>` on success, or an error if one occurred.
595    pub async fn send(&self, request: &Img2ImgRequest) -> Result<ImgResponse<Img2ImgRequest>> {
596        let response = self
597            .client
598            .post(self.endpoint.clone())
599            .json(&request)
600            .send()
601            .await?;
602        if response.status().is_success() {
603            return response.json().await.map_err(Img2ImgError::InvalidResponse);
604        }
605        let status = response.status();
606        let text = response.text().await.map_err(Img2ImgError::GetDataFailed)?;
607        Err(Img2ImgError::Img2ImgFailed {
608            status,
609            error: text,
610        })
611    }
612}