stable_diffusion_api/
txt2img.rs

1use std::collections::HashMap;
2
3use reqwest::Url;
4use serde::{Deserialize, Serialize};
5use serde_with::skip_serializing_none;
6
7use super::ImgResponse;
8
/// Struct representing a text to image request.
///
/// Every field is optional. `#[skip_serializing_none]` omits `None` fields when the request is
/// serialized, so only fields that were explicitly set are sent by `Txt2Img::send`. The same
/// type is also deserialized, as the parameter type of `ImgResponse<Txt2ImgRequest>`.
#[skip_serializing_none]
#[derive(Default, PartialEq, Serialize, Deserialize, Debug, Clone)]
pub struct Txt2ImgRequest {
    /// Whether to enable high resolution mode.
    pub enable_hr: Option<bool>,
    /// Strength of denoising applied to the image.
    pub denoising_strength: Option<f64>,
    /// Width of the image in the first phase.
    pub firstphase_width: Option<u32>,
    /// Height of the image in the first phase.
    pub firstphase_height: Option<u32>,
    /// Scale factor for high resolution mode.
    pub hr_scale: Option<f64>,
    /// Upscaler used in high resolution mode.
    pub hr_upscaler: Option<String>,
    /// Number of steps in the second pass of high resolution mode.
    pub hr_second_pass_steps: Option<u32>,
    /// Width of the image after resizing in high resolution mode.
    pub hr_resize_x: Option<u32>,
    /// Height of the image after resizing in high resolution mode.
    pub hr_resize_y: Option<u32>,
    /// Text prompt for generating the image.
    pub prompt: Option<String>,
    /// List of style names to apply when generating the image.
    pub styles: Option<Vec<String>>,
    /// Seed for generating the image. `-1` requests a random seed.
    pub seed: Option<i64>,
    /// Subseed for generating the image. `-1` requests a random subseed.
    pub subseed: Option<i64>,
    /// Strength of the subseed.
    ///
    /// NOTE(review): the WebUI API appears to treat this as a fractional value (0.0–1.0);
    /// as a `u32` only whole numbers can be sent — confirm before relying on it.
    pub subseed_strength: Option<u32>,
    /// Height of the seed image.
    pub seed_resize_from_h: Option<i32>,
    /// Width of the seed image.
    pub seed_resize_from_w: Option<i32>,
    /// Name of the sampler.
    pub sampler_name: Option<String>,
    /// Batch size used in generating images.
    pub batch_size: Option<u32>,
    /// Number of iterations (batches) to run.
    ///
    /// NOTE(review): presumably each iteration produces `batch_size` images — confirm against
    /// the WebUI API.
    pub n_iter: Option<u32>,
    /// Number of steps.
    pub steps: Option<u32>,
    /// CFG scale factor.
    pub cfg_scale: Option<f64>,
    /// Width of the generated image.
    pub width: Option<u32>,
    /// Height of the generated image.
    pub height: Option<u32>,
    /// Whether to restore faces in the generated image.
    pub restore_faces: Option<bool>,
    /// Whether to use tiling mode in the generated image.
    pub tiling: Option<bool>,
    /// Negated flag: presumably `true` asks the server NOT to save individual sample images
    /// — confirm against the WebUI API.
    pub do_not_save_samples: Option<bool>,
    /// Negated flag: presumably `true` asks the server NOT to save the image grid
    /// — confirm against the WebUI API.
    pub do_not_save_grid: Option<bool>,
    /// Negative text prompt.
    pub negative_prompt: Option<String>,
    /// Eta value.
    ///
    /// NOTE(review): the WebUI API appears to treat eta as a float; as a `u32` only whole
    /// numbers can be sent — confirm.
    pub eta: Option<u32>,
    /// Sampler `s_churn` parameter.
    pub s_churn: Option<f64>,
    /// Sampler `s_tmax` parameter (presumably an upper sigma bound rather than a
    /// temperature — confirm).
    pub s_tmax: Option<f64>,
    /// Sampler `s_tmin` parameter (presumably a lower sigma bound rather than a
    /// temperature — confirm).
    pub s_tmin: Option<f64>,
    /// Sampler `s_noise` parameter.
    pub s_noise: Option<f64>,
    /// Settings to override when generating the image.
    pub override_settings: Option<HashMap<String, serde_json::Value>>,
    /// Whether to restore the settings after generating the image.
    pub override_settings_restore_afterwards: Option<bool>,
    /// Arguments for the script.
    pub script_args: Option<Vec<serde_json::Value>>,
    /// Sampler identifier. Despite the name this is a `String`.
    ///
    /// NOTE(review): presumably a legacy alias for `sampler_name` — confirm.
    pub sampler_index: Option<String>,
    /// Name of the script.
    pub script_name: Option<String>,
    /// Whether to send the generated images.
    pub send_images: Option<bool>,
    /// Whether to save the generated images (presumably on the server — confirm).
    pub save_images: Option<bool>,
    /// Scripts to always run.
    pub alwayson_scripts: Option<HashMap<String, serde_json::Value>>,
}
96
97impl Txt2ImgRequest {
98    /// Adds a prompt to the request.
99    ///
100    /// # Arguments
101    ///
102    /// * `prompt` - A String representing the prompt to be used for image generation.
103    ///
104    /// # Example
105    ///
106    /// ```
107    /// let mut req = Txt2ImgRequest::default();
108    /// req.with_prompt("A blue sky with green grass".to_string());
109    /// ```
110    pub fn with_prompt(&mut self, prompt: String) -> &mut Self {
111        self.prompt = Some(prompt);
112        self
113    }
114
115    /// Adds styles to the request.
116    ///
117    /// # Arguments
118    ///
119    /// * `styles` - A vector of Strings representing the styles to be used for image generation.
120    ///
121    /// # Examples
122    ///
123    /// ```
124    /// let mut req = Txt2ImgRequest::default();
125    /// req.with_styles(vec!["cubism".to_string(), "impressionism".to_string()]);
126    /// ```
127    pub fn with_styles(&mut self, styles: Vec<String>) -> &mut Self {
128        self.styles = Some(styles);
129        self
130    }
131
132    /// Adds a style to the request.
133    ///
134    ///
135    /// # Arguments
136    ///
137    /// * `style` - A String representing the style to be used for image generation.
138    ///
139    /// # Examples
140    ///
141    /// ```
142    /// let mut req = Txt2ImgRequest::default();
143    /// req.with_style("cubism".to_string());
144    /// ```
145    pub fn with_style(&mut self, style: String) -> &mut Self {
146        if let Some(ref mut styles) = self.styles {
147            styles.push(style);
148            self
149        } else {
150            self.with_styles(vec![style])
151        }
152    }
153
154    /// Sets the seed for random number generation.
155    ///
156    /// # Arguments
157    ///
158    /// * `seed` - An i64 value representing the seed for random number generation.
159    ///            Set to `-1` to randomize.
160    ///
161    /// # Example
162    ///
163    /// ```
164    /// let mut req = Txt2ImgRequest::default();
165    /// req.with_seed(12345);
166    /// ```
167    pub fn with_seed(&mut self, seed: i64) -> &mut Self {
168        self.seed = Some(seed);
169        self
170    }
171
172    /// Sets the subseed for random number generation.
173    ///
174    /// # Arguments
175    ///
176    /// * `subseed` - An i64 value representing the subseed for random number generation.
177    ///               Set to `-1` to randomize.
178    ///
179    /// # Example
180    ///
181    /// ```
182    /// let mut req = Txt2ImgRequest::default();
183    /// req.with_subseed(12345);
184    /// ```
185    pub fn with_subseed(&mut self, subseed: i64) -> &mut Self {
186        self.subseed = Some(subseed);
187        self
188    }
189
190    /// Sets the strength of the subseed parameter for image generation.
191    ///
192    /// # Arguments
193    ///
194    /// * `subseed_strength` - A u32 value representing the strength of the subseed parameter.
195    ///
196    /// # Example
197    ///
198    /// ```
199    /// let mut req = Txt2ImgRequest::default();
200    /// req.with_subseed_strength(5);
201    /// ```
202    pub fn with_subseed_strength(&mut self, subseed_strength: u32) -> &mut Self {
203        self.subseed_strength = Some(subseed_strength);
204        self
205    }
206
207    /// Sets the sampler name for image generation.
208    ///
209    /// # Arguments
210    ///
211    /// * `sampler_name` - A String representing the sampler name to be used.
212    ///
213    /// # Examples
214    ///
215    /// ```
216    /// let mut req = Txt2ImgRequest::default();
217    /// req.with_sampler_name("Euler".to_string());
218    /// ```
219    pub fn with_sampler_name(&mut self, sampler_name: String) -> &mut Self {
220        self.sampler_name = Some(sampler_name);
221        self
222    }
223
224    /// Sets the batch size for image generation.
225    ///
226    /// # Arguments
227    ///
228    /// * `batch_size` - A u32 value representing the batch size to be used.
229    ///
230    /// # Examples
231    ///
232    /// ```
233    /// let mut req = Txt2ImgRequest::default();
234    /// req.with_batch_size(16);
235    /// ```
236    pub fn with_batch_size(&mut self, batch_size: u32) -> &mut Self {
237        self.batch_size = Some(batch_size);
238        self
239    }
240
241    /// Sets the number of iterations for image generation.
242    ///
243    /// # Arguments
244    ///
245    /// * `n_iter` - A u32 value representing the number of iterations to run for image generation.
246    ///
247    /// # Examples
248    ///
249    /// ```
250    /// let mut req = Txt2ImgRequest::default();
251    /// req.with_n_iter(1000);
252    /// ```
253    pub fn with_n_iter(&mut self, n_iter: u32) -> &mut Self {
254        self.n_iter = Some(n_iter);
255        self
256    }
257
258    /// Sets the number of steps for image generation.
259    ///
260    /// # Arguments
261    ///
262    /// * `steps` - A u32 value representing the number of steps for image generation.
263    ///
264    /// # Examples
265    ///
266    /// ```
267    /// let mut req = Txt2ImgRequest::default();
268    /// req.with_steps(50);
269    /// ```
270    pub fn with_steps(&mut self, steps: u32) -> &mut Self {
271        self.steps = Some(steps);
272        self
273    }
274
275    /// Sets the cfg scale for image generation.
276    ///
277    /// # Arguments
278    ///
279    /// * `cfg_scale` - A f64 value representing the cfg scale parameter.
280    ///
281    /// # Examples
282    ///
283    /// ```
284    /// let mut req = Txt2ImgRequest::default();
285    /// req.with_cfg_scale(0.7);
286    /// ```
287    pub fn with_cfg_scale(&mut self, cfg_scale: f64) -> &mut Self {
288        self.cfg_scale = Some(cfg_scale);
289        self
290    }
291
292    /// Sets the width for image generation.
293    ///
294    /// # Arguments
295    ///
296    /// * `width` - A u32 value representing the image width.
297    ///
298    /// # Examples
299    ///
300    /// ```
301    /// let mut req = Txt2ImgRequest::default();
302    /// req.with_width(512);
303    /// ```
304    pub fn with_width(&mut self, width: u32) -> &mut Self {
305        self.width = Some(width);
306        self
307    }
308
309    /// Sets the height for image generation.
310    ///
311    /// # Arguments
312    ///
313    /// * `height` - A u32 value representing the image height.
314    ///
315    /// # Examples
316    ///
317    /// ```
318    /// let mut req = Txt2ImgRequest::default();
319    /// req.with_height(512);
320    /// ```
321    pub fn with_height(&mut self, height: u32) -> &mut Self {
322        self.height = Some(height);
323        self
324    }
325
326    /// Enable or disable face restoration.
327    ///
328    /// # Arguments
329    ///
330    /// * `restore_faces` - A bool value to enable or disable face restoration.
331    ///
332    /// # Examples
333    ///
334    /// ```
335    /// let mut req = Txt2ImgRequest::default();
336    /// req.with_restore_faces(true);
337    /// ```
338    pub fn with_restore_faces(&mut self, restore_faces: bool) -> &mut Self {
339        self.restore_faces = Some(restore_faces);
340        self
341    }
342
343    /// Enable or disable image tiling.
344    ///
345    /// # Arguments
346    ///
347    /// * `tiling` - A bool value to enable or disable tiling.
348    ///
349    /// # Examples
350    ///
351    /// ```
352    /// let mut req = Txt2ImgRequest::default();
353    /// req.with_tiling(true);
354    /// ```
355    pub fn with_tiling(&mut self, tiling: bool) -> &mut Self {
356        self.tiling = Some(tiling);
357        self
358    }
359
360    /// Adds a negative prompt to the request.
361    ///
362    /// # Arguments
363    ///
364    /// * `negative_prompt` - A String representing the negative prompt to be used for image generation.
365    ///
366    /// # Example
367    ///
368    /// ```
369    /// let mut req = Txt2ImgRequest::default();
370    /// req.with_prompt("bad, ugly, worst quality".to_string());
371    /// ```
372    pub fn with_negative_prompt(&mut self, negative_prompt: String) -> &mut Self {
373        self.negative_prompt = Some(negative_prompt);
374        self
375    }
376
377    /// Merges the given settings with the request's settings.
378    ///
379    /// # Arguments
380    ///
381    /// * `request` - A Txt2ImgRequest containing the settings to merge.
382    pub fn merge(&self, request: Self) -> Self {
383        Self {
384            enable_hr: request.enable_hr.or(self.enable_hr),
385            denoising_strength: request.denoising_strength.or(self.denoising_strength),
386            firstphase_width: request.firstphase_width.or(self.firstphase_width),
387            firstphase_height: request.firstphase_height.or(self.firstphase_height),
388            hr_scale: request.hr_scale.or(self.hr_scale),
389            hr_upscaler: request.hr_upscaler.or(self.hr_upscaler.clone()),
390            hr_second_pass_steps: request.hr_second_pass_steps.or(self.hr_second_pass_steps),
391            hr_resize_x: request.hr_resize_x.or(self.hr_resize_x),
392            hr_resize_y: request.hr_resize_y.or(self.hr_resize_y),
393            prompt: request.prompt.or(self.prompt.clone()),
394            styles: request.styles.or(self.styles.clone()),
395            seed: request.seed.or(self.seed),
396            subseed: request.subseed.or(self.subseed),
397            subseed_strength: request.subseed_strength.or(self.subseed_strength),
398            seed_resize_from_h: request.seed_resize_from_h.or(self.seed_resize_from_h),
399            seed_resize_from_w: request.seed_resize_from_w.or(self.seed_resize_from_w),
400            sampler_name: request.sampler_name.or(self.sampler_name.clone()),
401            batch_size: request.batch_size.or(self.batch_size),
402            n_iter: request.n_iter.or(self.n_iter),
403            steps: request.steps.or(self.steps),
404            cfg_scale: request.cfg_scale.or(self.cfg_scale),
405            width: request.width.or(self.width),
406            height: request.height.or(self.height),
407            restore_faces: request.restore_faces.or(self.restore_faces),
408            tiling: request.tiling.or(self.tiling),
409            do_not_save_samples: request.do_not_save_samples.or(self.do_not_save_samples),
410            do_not_save_grid: request.do_not_save_grid.or(self.do_not_save_grid),
411            negative_prompt: request.negative_prompt.or(self.negative_prompt.clone()),
412            eta: request.eta.or(self.eta),
413            s_churn: request.s_churn.or(self.s_churn),
414            s_tmax: request.s_tmax.or(self.s_tmax),
415            s_tmin: request.s_tmin.or(self.s_tmin),
416            s_noise: request.s_noise.or(self.s_noise),
417            override_settings: request.override_settings.or(self.override_settings.clone()),
418            override_settings_restore_afterwards: request
419                .override_settings_restore_afterwards
420                .or(self.override_settings_restore_afterwards),
421            script_args: request.script_args.or(self.script_args.clone()),
422            sampler_index: request.sampler_index.or(self.sampler_index.clone()),
423            script_name: request.script_name.or(self.script_name.clone()),
424            send_images: request.send_images.or(self.send_images),
425            save_images: request.save_images.or(self.save_images),
426            alwayson_scripts: request.alwayson_scripts.or(self.alwayson_scripts.clone()),
427        }
428    }
429}
430
431/// Errors that can occur when interacting with the `Txt2Img` API.
432#[derive(thiserror::Error, Debug)]
433#[non_exhaustive]
434pub enum Txt2ImgError {
435    /// Error parsing endpoint URL
436    #[error("Failed to parse endpoint URL")]
437    ParseError(#[from] url::ParseError),
438    /// Error sending request
439    #[error("Failed to send request")]
440    RequestFailed(#[from] reqwest::Error),
441    /// An error occurred while parsing the response from the API.
442    #[error("Parsing response failed")]
443    InvalidResponse(#[source] reqwest::Error),
444    /// An error occurred getting response data.
445    #[error("Failed to get response data")]
446    GetDataFailed(#[source] reqwest::Error),
447    /// Server returned an error for img2img
448    #[error("Img2Img request failed: {status}: {error}")]
449    Txt2ImgFailed {
450        status: reqwest::StatusCode,
451        error: String,
452    },
453}
454
455type Result<T> = std::result::Result<T, Txt2ImgError>;
456
457/// A client for sending image requests to a specified endpoint.
458pub struct Txt2Img {
459    client: reqwest::Client,
460    endpoint: Url,
461}
462
463impl Txt2Img {
464    /// Constructs a new Txt2Img client with a given `reqwest::Client` and Stable Diffusion API
465    /// endpoint `String`.
466    ///
467    /// # Arguments
468    ///
469    /// * `client` - A `reqwest::Client` used to send requests.
470    /// * `endpoint` - A `String` representation of the endpoint url.
471    ///
472    /// # Returns
473    ///
474    /// A `Result` containing a new Txt2Img instance on success, or an error if url parsing failed.
475    pub fn new(client: reqwest::Client, endpoint: String) -> Result<Self> {
476        Ok(Self::new_with_url(client, Url::parse(&endpoint)?))
477    }
478
479    /// Constructs a new Txt2Img client with a given `reqwest::Client` and endpoint `Url`.
480    ///
481    /// # Arguments
482    ///
483    /// * `client` - A `reqwest::Client` used to send requests.
484    /// * `endpoint` - A `Url` representing the endpoint url.
485    ///
486    /// # Returns
487    ///
488    /// A new Txt2Img instance.
489    pub fn new_with_url(client: reqwest::Client, endpoint: Url) -> Self {
490        Self { client, endpoint }
491    }
492
493    /// Sends an image request using the Txt2Img client.
494    ///
495    /// # Arguments
496    ///
497    /// * `request` - An Txt2ImgRequest containing the parameters for the image request.
498    ///
499    /// # Returns
500    ///
501    /// A `Result` containing an `ImgResponse<Txt2ImgRequest>` on success, or an error if one occurred.
502    pub async fn send(&self, request: &Txt2ImgRequest) -> Result<ImgResponse<Txt2ImgRequest>> {
503        let response = self
504            .client
505            .post(self.endpoint.clone())
506            .json(&request)
507            .send()
508            .await
509            .map_err(Txt2ImgError::RequestFailed)?;
510        if response.status().is_success() {
511            return response.json().await.map_err(Txt2ImgError::InvalidResponse);
512        }
513        let status = response.status();
514        let text = response.text().await.map_err(Txt2ImgError::GetDataFailed)?;
515        Err(Txt2ImgError::Txt2ImgFailed {
516            status,
517            error: text,
518        })
519    }
520}