stable_diffusion_api/img2img.rs
1use std::collections::HashMap;
2
3use reqwest::Url;
4use serde::{Deserialize, Serialize};
5use serde_with::skip_serializing_none;
6
7use super::ImgResponse;
8
/// Struct representing an image to image request.
///
/// Every field is optional. Because of `#[skip_serializing_none]`, fields left as `None` are
/// omitted from the serialized request rather than being sent as `null`.
#[skip_serializing_none]
#[derive(Default, PartialEq, Serialize, Deserialize, Debug, Clone)]
pub struct Img2ImgRequest {
    /// Initial images. The `with_image`/`with_images` builders store each image as a
    /// base64-encoded string.
    pub init_images: Option<Vec<String>>,
    /// Resize mode.
    pub resize_mode: Option<u32>,
    /// Strength of denoising applied to the image.
    pub denoising_strength: Option<f64>,
    /// Image CFG scale (distinct from `cfg_scale`).
    /// NOTE(review): typed as an integer here while `cfg_scale` is `f64` — confirm against the API.
    pub image_cfg_scale: Option<u32>,
    /// Mask.
    /// NOTE(review): presumably an encoded image like `init_images` — confirm.
    pub mask: Option<String>,
    /// Blur to apply to the mask.
    pub mask_blur: Option<u32>,
    /// Inpainting fill setting.
    /// NOTE(review): may be a fill-mode selector rather than an amount — confirm against the API.
    pub inpainting_fill: Option<u32>,
    /// Whether to perform inpainting at full resolution.
    pub inpaint_full_res: Option<bool>,
    /// Padding to apply when performing inpainting at full resolution.
    pub inpaint_full_res_padding: Option<u32>,
    /// Whether to invert the inpainting mask. Encoded as an integer, not a `bool`.
    pub inpainting_mask_invert: Option<u32>,
    /// Initial noise multiplier.
    /// NOTE(review): integer-typed; confirm the API does not expect fractional values.
    pub initial_noise_multiplier: Option<u32>,
    /// Text prompt for generating the image.
    pub prompt: Option<String>,
    /// List of style prompts for generating the image.
    pub styles: Option<Vec<String>>,
    /// Seed.
    pub seed: Option<i64>,
    /// Subseed.
    pub subseed: Option<i64>,
    /// Strength of the subseed.
    /// NOTE(review): integer-typed; confirm the API does not expect fractional values.
    pub subseed_strength: Option<u32>,
    /// Height to resize the seed image from.
    pub seed_resize_from_h: Option<i32>,
    /// Width to resize the seed image from.
    pub seed_resize_from_w: Option<i32>,
    /// Name of the sampler.
    pub sampler_name: Option<String>,
    /// Batch size.
    pub batch_size: Option<u32>,
    /// Number of iterations.
    pub n_iter: Option<u32>,
    /// Number of steps.
    pub steps: Option<u32>,
    /// CFG scale.
    pub cfg_scale: Option<f64>,
    /// Width of the generated image.
    pub width: Option<u32>,
    /// Height of the generated image.
    pub height: Option<u32>,
    /// Whether to restore faces in the generated image.
    pub restore_faces: Option<bool>,
    /// Whether to enable tiling for the generated image.
    pub tiling: Option<bool>,
    /// Flag controlling whether samples are saved. Note the negated field name: `true`
    /// presumably disables saving — confirm against the API.
    pub do_not_save_samples: Option<bool>,
    /// Flag controlling whether the grid is saved. Note the negated field name: `true`
    /// presumably disables saving — confirm against the API.
    pub do_not_save_grid: Option<bool>,
    /// Negative prompt.
    pub negative_prompt: Option<String>,
    /// Eta value.
    /// NOTE(review): integer-typed; confirm the API does not expect fractional values.
    pub eta: Option<u32>,
    /// Churn value.
    pub s_churn: Option<f64>,
    /// Maximum temperature value.
    pub s_tmax: Option<f64>,
    /// Minimum temperature value.
    pub s_tmin: Option<f64>,
    /// Amount of noise.
    pub s_noise: Option<f64>,
    /// Settings to override when generating the image, keyed by setting name.
    pub override_settings: Option<HashMap<String, serde_json::Value>>,
    /// Whether to restore the settings after generating the image.
    pub override_settings_restore_afterwards: Option<bool>,
    /// Arguments to pass to the script.
    pub script_args: Option<Vec<serde_json::Value>>,
    /// Index of the sampler. Typed as a string, not a number.
    pub sampler_index: Option<String>,
    /// Whether to include initial images in the output.
    pub include_init_images: Option<bool>,
    /// Name of the script.
    pub script_name: Option<String>,
    /// Whether to send the generated images.
    pub send_images: Option<bool>,
    /// Whether to save the generated images.
    pub save_images: Option<bool>,
    /// Scripts to always run, as a map of arbitrary JSON values.
    pub alwayson_scripts: Option<HashMap<String, serde_json::Value>>,
}
102
103impl Img2ImgRequest {
104 /// Adds a prompt to the request.
105 ///
106 /// # Arguments
107 ///
108 /// * `prompt` - A String representing the prompt to be used for image generation.
109 ///
110 /// # Example
111 ///
112 /// ```
113 /// let mut req = Img2ImgRequest::default();
114 /// req.with_prompt("A blue sky with green grass".to_string());
115 /// ```
116 pub fn with_prompt(&mut self, prompt: String) -> &mut Self {
117 self.prompt = Some(prompt);
118 self
119 }
120
121 /// Adds a single image to the request.
122 ///
123 /// # Arguments
124 ///
125 /// * `image` - array bytes of the image to be added.
126 ///
127 /// # Examples
128 ///
129 /// ```
130 /// use std::fs;
131 /// let mut req = Img2ImgRequest::default();
132 /// let image_data = fs::read("path/to/image.jpg").unwrap();
133 /// req.with_image(image_data);
134 /// ```
135 pub fn with_image<T>(&mut self, image: T) -> &mut Self
136 where
137 T: AsRef<[u8]>,
138 {
139 use base64::{engine::general_purpose, Engine as _};
140
141 if let Some(ref mut images) = self.init_images {
142 images.push(general_purpose::STANDARD.encode(image));
143 self
144 } else {
145 self.with_images(vec![image])
146 }
147 }
148
149 /// Adds multiple images to the request.
150 ///
151 /// # Arguments
152 ///
153 /// * `images` - A vector of byte arrays that represents the images to be added.
154 ///
155 /// # Examples
156 ///
157 /// ```
158 /// use std::fs;
159 /// let mut req = Img2ImgRequest::default();
160 /// let image_data1 = fs::read("path/to/image1.jpg").unwrap();
161 /// let image_data2 = fs::read("path/to/image2.jpg").unwrap();
162 /// req.with_images(vec![image_data1, image_data2]);
163 /// ```
164 pub fn with_images<T>(&mut self, images: Vec<T>) -> &mut Self
165 where
166 T: AsRef<[u8]>,
167 {
168 use base64::{engine::general_purpose, Engine as _};
169 let images = images.iter().map(|i| general_purpose::STANDARD.encode(i));
170 if let Some(ref mut i) = self.init_images {
171 i.extend(images);
172 } else {
173 self.init_images = Some(images.collect())
174 }
175 self
176 }
177
178 /// Adds styles to the request.
179 ///
180 /// # Arguments
181 ///
182 /// * `styles` - A vector of Strings representing the styles to be used for image generation.
183 ///
184 /// # Examples
185 ///
186 /// ```
187 /// let mut req = Img2ImgRequest::default();
188 /// req.with_styles(vec!["cubism".to_string(), "impressionism".to_string()]);
189 /// ```
190 pub fn with_styles(&mut self, styles: Vec<String>) -> &mut Self {
191 if let Some(ref mut s) = self.styles {
192 s.extend(styles);
193 } else {
194 self.styles = Some(styles);
195 }
196 self
197 }
198
199 /// Adds a style to the request.
200 ///
201 ///
202 /// # Arguments
203 ///
204 /// * `style` - A String representing the style to be used for image generation.
205 ///
206 /// # Examples
207 ///
208 /// ```
209 /// let mut req = Img2ImgRequest::default();
210 /// req.with_style("cubism".to_string());
211 /// ```
212 pub fn with_style(&mut self, style: String) -> &mut Self {
213 if let Some(ref mut styles) = self.styles {
214 styles.push(style);
215 self
216 } else {
217 self.with_styles(vec![style])
218 }
219 }
220
221 /// Sets the denoising strength for image generation.
222 ///
223 /// # Arguments
224 ///
225 /// * `denoising_strength` - A f64 value representing the denoising strength parameter.
226 ///
227 /// # Examples
228 ///
229 /// ```
230 /// let mut req = Img2ImgRequest::default();
231 /// req.with_denoising_strength(0.4);
232 /// ```
233 pub fn with_denoising_strength(&mut self, denoising_strength: f64) -> &mut Self {
234 self.denoising_strength = Some(denoising_strength);
235 self
236 }
237
238 /// Sets the seed for random number generation.
239 ///
240 /// # Arguments
241 ///
242 /// * `seed` - An i64 value representing the seed for random number generation.
243 /// Set to `-1` to randomize.
244 ///
245 /// # Example
246 ///
247 /// ```
248 /// let mut req = Img2ImgRequest::default();
249 /// req.with_seed(12345);
250 /// ```
251 pub fn with_seed(&mut self, seed: i64) -> &mut Self {
252 self.seed = Some(seed);
253 self
254 }
255
256 /// Sets the subseed for random number generation.
257 ///
258 /// # Arguments
259 ///
260 /// * `subseed` - An i64 value representing the subseed for random number generation.
261 /// Set to `-1` to randomize.
262 ///
263 /// # Example
264 ///
265 /// ```
266 /// let mut req = Img2ImgRequest::default();
267 /// req.with_subseed(12345);
268 /// ```
269 pub fn with_subseed(&mut self, subseed: i64) -> &mut Self {
270 self.subseed = Some(subseed);
271 self
272 }
273
274 /// Sets the strength of the subseed parameter for image generation.
275 ///
276 /// # Arguments
277 ///
278 /// * `subseed_strength` - A u32 value representing the strength of the subseed parameter.
279 ///
280 /// # Example
281 ///
282 /// ```
283 /// let mut req = Img2ImgRequest::default();
284 /// req.with_subseed_strength(5);
285 /// ```
286 pub fn with_subseed_strength(&mut self, subseed_strength: u32) -> &mut Self {
287 self.subseed_strength = Some(subseed_strength);
288 self
289 }
290
291 /// Sets the sampler name for image generation.
292 ///
293 /// # Arguments
294 ///
295 /// * `sampler_name` - A String representing the sampler name to be used.
296 ///
297 /// # Examples
298 ///
299 /// ```
300 /// let mut req = Img2ImgRequest::default();
301 /// req.with_sampler_name("Euler".to_string());
302 /// ```
303 pub fn with_sampler_name(&mut self, sampler_name: String) -> &mut Self {
304 self.sampler_name = Some(sampler_name);
305 self
306 }
307
308 /// Sets the batch size for image generation.
309 ///
310 /// # Arguments
311 ///
312 /// * `batch_size` - A u32 value representing the batch size to be used.
313 ///
314 /// # Examples
315 ///
316 /// ```
317 /// let mut req = Img2ImgRequest::default();
318 /// req.with_batch_size(16);
319 /// ```
320 pub fn with_batch_size(&mut self, batch_size: u32) -> &mut Self {
321 self.batch_size = Some(batch_size);
322 self
323 }
324
325 /// Sets the number of iterations for image generation.
326 ///
327 /// # Arguments
328 ///
329 /// * `n_iter` - A u32 value representing the number of iterations to run for image generation.
330 ///
331 /// # Examples
332 ///
333 /// ```
334 /// let mut req = Img2ImgRequest::default();
335 /// req.with_n_iter(1000);
336 /// ```
337 pub fn with_n_iter(&mut self, n_iter: u32) -> &mut Self {
338 self.n_iter = Some(n_iter);
339 self
340 }
341
342 /// Sets the number of steps for image generation.
343 ///
344 /// # Arguments
345 ///
346 /// * `steps` - A u32 value representing the number of steps for image generation.
347 ///
348 /// # Examples
349 ///
350 /// ```
351 /// let mut req = Img2ImgRequest::default();
352 /// req.with_steps(50);
353 /// ```
354 pub fn with_steps(&mut self, steps: u32) -> &mut Self {
355 self.steps = Some(steps);
356 self
357 }
358
359 /// Sets the cfg scale for image generation.
360 ///
361 /// # Arguments
362 ///
363 /// * `cfg_scale` - A f64 value representing the cfg scale parameter.
364 ///
365 /// # Examples
366 ///
367 /// ```
368 /// let mut req = Img2ImgRequest::default();
369 /// req.with_cfg_scale(0.7);
370 /// ```
371 pub fn with_cfg_scale(&mut self, cfg_scale: f64) -> &mut Self {
372 self.cfg_scale = Some(cfg_scale);
373 self
374 }
375
376 /// Sets the width for image generation.
377 ///
378 /// # Arguments
379 ///
380 /// * `width` - A u32 value representing the image width.
381 ///
382 /// # Examples
383 ///
384 /// ```
385 /// let mut req = Img2ImgRequest::default();
386 /// req.with_width(512);
387 /// ```
388 pub fn with_width(&mut self, width: u32) -> &mut Self {
389 self.width = Some(width);
390 self
391 }
392
393 /// Sets the height for image generation.
394 ///
395 /// # Arguments
396 ///
397 /// * `height` - A u32 value representing the image height.
398 ///
399 /// # Examples
400 ///
401 /// ```
402 /// let mut req = Img2ImgRequest::default();
403 /// req.with_height(512);
404 /// ```
405 pub fn with_height(&mut self, height: u32) -> &mut Self {
406 self.height = Some(height);
407 self
408 }
409
410 /// Enable or disable face restoration.
411 ///
412 /// # Arguments
413 ///
414 /// * `restore_faces` - A bool value to enable or disable face restoration.
415 ///
416 /// # Examples
417 ///
418 /// ```
419 /// let mut req = Img2ImgRequest::default();
420 /// req.with_restore_faces(true);
421 /// ```
422 pub fn with_restore_faces(&mut self, restore_faces: bool) -> &mut Self {
423 self.restore_faces = Some(restore_faces);
424 self
425 }
426
427 /// Enable or disable image tiling.
428 ///
429 /// # Arguments
430 ///
431 /// * `tiling` - A bool value to enable or disable tiling.
432 ///
433 /// # Examples
434 ///
435 /// ```
436 /// let mut req = Img2ImgRequest::default();
437 /// req.with_tiling(true);
438 /// ```
439 pub fn with_tiling(&mut self, tiling: bool) -> &mut Self {
440 self.tiling = Some(tiling);
441 self
442 }
443
444 /// Adds a negative prompt to the request.
445 ///
446 /// # Arguments
447 ///
448 /// * `negative_prompt` - A String representing the negative prompt to be used for image generation.
449 ///
450 /// # Example
451 ///
452 /// ```
453 /// let mut req = Img2ImgRequest::default();
454 /// req.with_prompt("bad, ugly, worst quality".to_string());
455 /// ```
456 pub fn with_negative_prompt(&mut self, negative_prompt: String) -> &mut Self {
457 self.negative_prompt = Some(negative_prompt);
458 self
459 }
460
461 /// Merges the given settings with the request's settings.
462 ///
463 /// # Arguments
464 ///
465 /// * `request` - A Img2ImgRequest containing the settings to merge.
466 pub fn merge(&self, request: Self) -> Self {
467 Self {
468 init_images: request.init_images.or(self.init_images.clone()),
469 resize_mode: request.resize_mode.or(self.resize_mode),
470 denoising_strength: request.denoising_strength.or(self.denoising_strength),
471 image_cfg_scale: request.image_cfg_scale.or(self.image_cfg_scale),
472 mask: request.mask.or(self.mask.clone()),
473 mask_blur: request.mask_blur.or(self.mask_blur),
474 inpainting_fill: request.inpainting_fill.or(self.inpainting_fill),
475 inpaint_full_res: request.inpaint_full_res.or(self.inpaint_full_res),
476 inpaint_full_res_padding: request
477 .inpaint_full_res_padding
478 .or(self.inpaint_full_res_padding),
479 inpainting_mask_invert: request
480 .inpainting_mask_invert
481 .or(self.inpainting_mask_invert),
482 initial_noise_multiplier: request
483 .initial_noise_multiplier
484 .or(self.initial_noise_multiplier),
485 prompt: request.prompt.or(self.prompt.clone()),
486 styles: request.styles.or(self.styles.clone()),
487 seed: request.seed.or(self.seed),
488 subseed: request.subseed.or(self.subseed),
489 subseed_strength: request.subseed_strength.or(self.subseed_strength),
490 seed_resize_from_h: request.seed_resize_from_h.or(self.seed_resize_from_h),
491 seed_resize_from_w: request.seed_resize_from_w.or(self.seed_resize_from_w),
492 sampler_name: request.sampler_name.or(self.sampler_name.clone()),
493 batch_size: request.batch_size.or(self.batch_size),
494 n_iter: request.n_iter.or(self.n_iter),
495 steps: request.steps.or(self.steps),
496 cfg_scale: request.cfg_scale.or(self.cfg_scale),
497 width: request.width.or(self.width),
498 height: request.height.or(self.height),
499 restore_faces: request.restore_faces.or(self.restore_faces),
500 tiling: request.tiling.or(self.tiling),
501 do_not_save_samples: request.do_not_save_samples.or(self.do_not_save_samples),
502 do_not_save_grid: request.do_not_save_grid.or(self.do_not_save_grid),
503 negative_prompt: request.negative_prompt.or(self.negative_prompt.clone()),
504 eta: request.eta.or(self.eta),
505 s_churn: request.s_churn.or(self.s_churn),
506 s_tmax: request.s_tmax.or(self.s_tmax),
507 s_tmin: request.s_tmin.or(self.s_tmin),
508 s_noise: request.s_noise.or(self.s_noise),
509 override_settings: request.override_settings.or(self.override_settings.clone()),
510 override_settings_restore_afterwards: request
511 .override_settings_restore_afterwards
512 .or(self.override_settings_restore_afterwards),
513 script_args: request.script_args.or(self.script_args.clone()),
514 sampler_index: request.sampler_index.or(self.sampler_index.clone()),
515 include_init_images: request.include_init_images.or(self.include_init_images),
516 script_name: request.script_name.or(self.script_name.clone()),
517 send_images: request.send_images.or(self.send_images),
518 save_images: request.save_images.or(self.save_images),
519 alwayson_scripts: request.alwayson_scripts.or(self.alwayson_scripts.clone()),
520 }
521 }
522}
523
/// Errors that can occur when interacting with the `Img2Img` API.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum Img2ImgError {
    /// Error parsing endpoint URL
    #[error("Failed to parse endpoint URL")]
    ParseError(#[from] url::ParseError),
    /// Error sending request
    #[error("Failed to send request")]
    RequestFailed(#[from] reqwest::Error),
    /// An error occurred while parsing the response from the API.
    ///
    /// Produced when the body of a successful response cannot be decoded as JSON.
    #[error("Parsing response failed")]
    InvalidResponse(#[source] reqwest::Error),
    /// An error occurred getting response data.
    ///
    /// Produced when the body of an unsuccessful response cannot be read as text.
    #[error("Failed to get response data")]
    GetDataFailed(#[source] reqwest::Error),
    /// Server returned an error for img2img
    #[error("Img2Img request failed: {status}: {error}")]
    Img2ImgFailed {
        /// HTTP status code returned by the server.
        status: reqwest::StatusCode,
        /// Response body text returned alongside the failing status.
        error: String,
    },
}
547
/// Module-local shorthand for results whose error type is [`Img2ImgError`].
type Result<T> = std::result::Result<T, Img2ImgError>;
549
550/// A client for sending image requests to a specified endpoint.
551pub struct Img2Img {
552 client: reqwest::Client,
553 endpoint: Url,
554}
555
556impl Img2Img {
557 /// Constructs a new Img2Img client with a given `reqwest::Client` and Stable Diffusion API
558 /// endpoint `String`.
559 ///
560 /// # Arguments
561 ///
562 /// * `client` - A `reqwest::Client` used to send requests.
563 /// * `endpoint` - A `String` representation of the endpoint url.
564 ///
565 /// # Returns
566 ///
567 /// A `Result` containing a new Img2Img instance on success, or an error if url parsing failed.
568 pub fn new(client: reqwest::Client, endpoint: String) -> Result<Self> {
569 Ok(Self::new_with_url(client, Url::parse(&endpoint)?))
570 }
571
572 /// Constructs a new Img2Img client with a given `reqwest::Client` and endpoint `Url`.
573 ///
574 /// # Arguments
575 ///
576 /// * `client` - A `reqwest::Client` used to send requests.
577 /// * `endpoint` - A `Url` representing the endpoint url.
578 ///
579 /// # Returns
580 ///
581 /// A new Img2Img instance.
582 pub fn new_with_url(client: reqwest::Client, endpoint: Url) -> Self {
583 Self { client, endpoint }
584 }
585
586 /// Sends an image request using the Img2Img client.
587 ///
588 /// # Arguments
589 ///
590 /// * `request` - An Img2ImgRequest containing the parameters for the image request.
591 ///
592 /// # Returns
593 ///
594 /// A `Result` containing an `ImgResponse<Img2ImgRequest>` on success, or an error if one occurred.
595 pub async fn send(&self, request: &Img2ImgRequest) -> Result<ImgResponse<Img2ImgRequest>> {
596 let response = self
597 .client
598 .post(self.endpoint.clone())
599 .json(&request)
600 .send()
601 .await?;
602 if response.status().is_success() {
603 return response.json().await.map_err(Img2ImgError::InvalidResponse);
604 }
605 let status = response.status();
606 let text = response.text().await.map_err(Img2ImgError::GetDataFailed)?;
607 Err(Img2ImgError::Img2ImgFailed {
608 status,
609 error: text,
610 })
611 }
612}