// stable_diffusion_api/txt2img.rs
1use std::collections::HashMap;
2
3use reqwest::Url;
4use serde::{Deserialize, Serialize};
5use serde_with::skip_serializing_none;
6
7use super::ImgResponse;
8
9/// Struct representing a text to image request.
10#[skip_serializing_none]
11#[derive(Default, PartialEq, Serialize, Deserialize, Debug, Clone)]
12pub struct Txt2ImgRequest {
13 /// Whether to enable high resolution mode.
14 pub enable_hr: Option<bool>,
15 /// Strength of denoising applied to the image.
16 pub denoising_strength: Option<f64>,
17 /// Width of the image in the first phase.
18 pub firstphase_width: Option<u32>,
19 /// Height of the image in the first phase.
20 pub firstphase_height: Option<u32>,
21 /// Scale factor for high resolution mode.
22 pub hr_scale: Option<f64>,
23 /// Upscaler used in high resolution mode.
24 pub hr_upscaler: Option<String>,
25 /// Number of steps in the second pass of high resolution mode.
26 pub hr_second_pass_steps: Option<u32>,
27 /// Width of the image after resizing in high resolution mode.
28 pub hr_resize_x: Option<u32>,
29 /// Height of the image after resizing in high resolution mode.
30 pub hr_resize_y: Option<u32>,
31 /// Text prompt for generating the image.
32 pub prompt: Option<String>,
33 /// List of style prompts for generating the image.
34 pub styles: Option<Vec<String>>,
35 /// Seed for generating the image.
36 pub seed: Option<i64>,
37 /// Subseed for generating the image.
38 pub subseed: Option<i64>,
39 /// Strength of subseed.
40 pub subseed_strength: Option<u32>,
41 /// Height of the seed image.
42 pub seed_resize_from_h: Option<i32>,
43 /// Width of the seed image.
44 pub seed_resize_from_w: Option<i32>,
45 /// Name of the sampler.
46 pub sampler_name: Option<String>,
47 /// Batch size used in generating images.
48 pub batch_size: Option<u32>,
49 /// Number of images to generate per batch.
50 pub n_iter: Option<u32>,
51 /// Number of steps.
52 pub steps: Option<u32>,
53 /// CFG scale factor.
54 pub cfg_scale: Option<f64>,
55 /// Width of the generated image.
56 pub width: Option<u32>,
57 /// Height of the generated image.
58 pub height: Option<u32>,
59 /// Whether to restore faces in the generated image.
60 pub restore_faces: Option<bool>,
61 /// Whether to use tiling mode in the generated image.
62 pub tiling: Option<bool>,
63 /// Whether to save samples when generating multiple images.
64 pub do_not_save_samples: Option<bool>,
65 /// Whether to save the grid when generating multiple images.
66 pub do_not_save_grid: Option<bool>,
67 /// Negative text prompt.
68 pub negative_prompt: Option<String>,
69 /// Eta value.
70 pub eta: Option<u32>,
71 /// Churn value.
72 pub s_churn: Option<f64>,
73 /// Maximum temperature value.
74 pub s_tmax: Option<f64>,
75 /// Minimum temperature value.
76 pub s_tmin: Option<f64>,
77 /// Noise value.
78 pub s_noise: Option<f64>,
79 /// Settings to override when generating the image.
80 pub override_settings: Option<HashMap<String, serde_json::Value>>,
81 /// Whether to restore the settings after generating the image.
82 pub override_settings_restore_afterwards: Option<bool>,
83 /// Arguments for the script.
84 pub script_args: Option<Vec<serde_json::Value>>,
85 /// Index of the sampler.
86 pub sampler_index: Option<String>,
87 /// Name of the script.
88 pub script_name: Option<String>,
89 /// Whether to send the generated images.
90 pub send_images: Option<bool>,
91 /// Whether to send the generated images.
92 pub save_images: Option<bool>,
93 /// Scripts to always run.
94 pub alwayson_scripts: Option<HashMap<String, serde_json::Value>>,
95}
96
97impl Txt2ImgRequest {
98 /// Adds a prompt to the request.
99 ///
100 /// # Arguments
101 ///
102 /// * `prompt` - A String representing the prompt to be used for image generation.
103 ///
104 /// # Example
105 ///
106 /// ```
107 /// let mut req = Txt2ImgRequest::default();
108 /// req.with_prompt("A blue sky with green grass".to_string());
109 /// ```
110 pub fn with_prompt(&mut self, prompt: String) -> &mut Self {
111 self.prompt = Some(prompt);
112 self
113 }
114
115 /// Adds styles to the request.
116 ///
117 /// # Arguments
118 ///
119 /// * `styles` - A vector of Strings representing the styles to be used for image generation.
120 ///
121 /// # Examples
122 ///
123 /// ```
124 /// let mut req = Txt2ImgRequest::default();
125 /// req.with_styles(vec!["cubism".to_string(), "impressionism".to_string()]);
126 /// ```
127 pub fn with_styles(&mut self, styles: Vec<String>) -> &mut Self {
128 self.styles = Some(styles);
129 self
130 }
131
132 /// Adds a style to the request.
133 ///
134 ///
135 /// # Arguments
136 ///
137 /// * `style` - A String representing the style to be used for image generation.
138 ///
139 /// # Examples
140 ///
141 /// ```
142 /// let mut req = Txt2ImgRequest::default();
143 /// req.with_style("cubism".to_string());
144 /// ```
145 pub fn with_style(&mut self, style: String) -> &mut Self {
146 if let Some(ref mut styles) = self.styles {
147 styles.push(style);
148 self
149 } else {
150 self.with_styles(vec![style])
151 }
152 }
153
154 /// Sets the seed for random number generation.
155 ///
156 /// # Arguments
157 ///
158 /// * `seed` - An i64 value representing the seed for random number generation.
159 /// Set to `-1` to randomize.
160 ///
161 /// # Example
162 ///
163 /// ```
164 /// let mut req = Txt2ImgRequest::default();
165 /// req.with_seed(12345);
166 /// ```
167 pub fn with_seed(&mut self, seed: i64) -> &mut Self {
168 self.seed = Some(seed);
169 self
170 }
171
172 /// Sets the subseed for random number generation.
173 ///
174 /// # Arguments
175 ///
176 /// * `subseed` - An i64 value representing the subseed for random number generation.
177 /// Set to `-1` to randomize.
178 ///
179 /// # Example
180 ///
181 /// ```
182 /// let mut req = Txt2ImgRequest::default();
183 /// req.with_subseed(12345);
184 /// ```
185 pub fn with_subseed(&mut self, subseed: i64) -> &mut Self {
186 self.subseed = Some(subseed);
187 self
188 }
189
190 /// Sets the strength of the subseed parameter for image generation.
191 ///
192 /// # Arguments
193 ///
194 /// * `subseed_strength` - A u32 value representing the strength of the subseed parameter.
195 ///
196 /// # Example
197 ///
198 /// ```
199 /// let mut req = Txt2ImgRequest::default();
200 /// req.with_subseed_strength(5);
201 /// ```
202 pub fn with_subseed_strength(&mut self, subseed_strength: u32) -> &mut Self {
203 self.subseed_strength = Some(subseed_strength);
204 self
205 }
206
207 /// Sets the sampler name for image generation.
208 ///
209 /// # Arguments
210 ///
211 /// * `sampler_name` - A String representing the sampler name to be used.
212 ///
213 /// # Examples
214 ///
215 /// ```
216 /// let mut req = Txt2ImgRequest::default();
217 /// req.with_sampler_name("Euler".to_string());
218 /// ```
219 pub fn with_sampler_name(&mut self, sampler_name: String) -> &mut Self {
220 self.sampler_name = Some(sampler_name);
221 self
222 }
223
224 /// Sets the batch size for image generation.
225 ///
226 /// # Arguments
227 ///
228 /// * `batch_size` - A u32 value representing the batch size to be used.
229 ///
230 /// # Examples
231 ///
232 /// ```
233 /// let mut req = Txt2ImgRequest::default();
234 /// req.with_batch_size(16);
235 /// ```
236 pub fn with_batch_size(&mut self, batch_size: u32) -> &mut Self {
237 self.batch_size = Some(batch_size);
238 self
239 }
240
241 /// Sets the number of iterations for image generation.
242 ///
243 /// # Arguments
244 ///
245 /// * `n_iter` - A u32 value representing the number of iterations to run for image generation.
246 ///
247 /// # Examples
248 ///
249 /// ```
250 /// let mut req = Txt2ImgRequest::default();
251 /// req.with_n_iter(1000);
252 /// ```
253 pub fn with_n_iter(&mut self, n_iter: u32) -> &mut Self {
254 self.n_iter = Some(n_iter);
255 self
256 }
257
258 /// Sets the number of steps for image generation.
259 ///
260 /// # Arguments
261 ///
262 /// * `steps` - A u32 value representing the number of steps for image generation.
263 ///
264 /// # Examples
265 ///
266 /// ```
267 /// let mut req = Txt2ImgRequest::default();
268 /// req.with_steps(50);
269 /// ```
270 pub fn with_steps(&mut self, steps: u32) -> &mut Self {
271 self.steps = Some(steps);
272 self
273 }
274
275 /// Sets the cfg scale for image generation.
276 ///
277 /// # Arguments
278 ///
279 /// * `cfg_scale` - A f64 value representing the cfg scale parameter.
280 ///
281 /// # Examples
282 ///
283 /// ```
284 /// let mut req = Txt2ImgRequest::default();
285 /// req.with_cfg_scale(0.7);
286 /// ```
287 pub fn with_cfg_scale(&mut self, cfg_scale: f64) -> &mut Self {
288 self.cfg_scale = Some(cfg_scale);
289 self
290 }
291
292 /// Sets the width for image generation.
293 ///
294 /// # Arguments
295 ///
296 /// * `width` - A u32 value representing the image width.
297 ///
298 /// # Examples
299 ///
300 /// ```
301 /// let mut req = Txt2ImgRequest::default();
302 /// req.with_width(512);
303 /// ```
304 pub fn with_width(&mut self, width: u32) -> &mut Self {
305 self.width = Some(width);
306 self
307 }
308
309 /// Sets the height for image generation.
310 ///
311 /// # Arguments
312 ///
313 /// * `height` - A u32 value representing the image height.
314 ///
315 /// # Examples
316 ///
317 /// ```
318 /// let mut req = Txt2ImgRequest::default();
319 /// req.with_height(512);
320 /// ```
321 pub fn with_height(&mut self, height: u32) -> &mut Self {
322 self.height = Some(height);
323 self
324 }
325
326 /// Enable or disable face restoration.
327 ///
328 /// # Arguments
329 ///
330 /// * `restore_faces` - A bool value to enable or disable face restoration.
331 ///
332 /// # Examples
333 ///
334 /// ```
335 /// let mut req = Txt2ImgRequest::default();
336 /// req.with_restore_faces(true);
337 /// ```
338 pub fn with_restore_faces(&mut self, restore_faces: bool) -> &mut Self {
339 self.restore_faces = Some(restore_faces);
340 self
341 }
342
343 /// Enable or disable image tiling.
344 ///
345 /// # Arguments
346 ///
347 /// * `tiling` - A bool value to enable or disable tiling.
348 ///
349 /// # Examples
350 ///
351 /// ```
352 /// let mut req = Txt2ImgRequest::default();
353 /// req.with_tiling(true);
354 /// ```
355 pub fn with_tiling(&mut self, tiling: bool) -> &mut Self {
356 self.tiling = Some(tiling);
357 self
358 }
359
360 /// Adds a negative prompt to the request.
361 ///
362 /// # Arguments
363 ///
364 /// * `negative_prompt` - A String representing the negative prompt to be used for image generation.
365 ///
366 /// # Example
367 ///
368 /// ```
369 /// let mut req = Txt2ImgRequest::default();
370 /// req.with_prompt("bad, ugly, worst quality".to_string());
371 /// ```
372 pub fn with_negative_prompt(&mut self, negative_prompt: String) -> &mut Self {
373 self.negative_prompt = Some(negative_prompt);
374 self
375 }
376
377 /// Merges the given settings with the request's settings.
378 ///
379 /// # Arguments
380 ///
381 /// * `request` - A Txt2ImgRequest containing the settings to merge.
382 pub fn merge(&self, request: Self) -> Self {
383 Self {
384 enable_hr: request.enable_hr.or(self.enable_hr),
385 denoising_strength: request.denoising_strength.or(self.denoising_strength),
386 firstphase_width: request.firstphase_width.or(self.firstphase_width),
387 firstphase_height: request.firstphase_height.or(self.firstphase_height),
388 hr_scale: request.hr_scale.or(self.hr_scale),
389 hr_upscaler: request.hr_upscaler.or(self.hr_upscaler.clone()),
390 hr_second_pass_steps: request.hr_second_pass_steps.or(self.hr_second_pass_steps),
391 hr_resize_x: request.hr_resize_x.or(self.hr_resize_x),
392 hr_resize_y: request.hr_resize_y.or(self.hr_resize_y),
393 prompt: request.prompt.or(self.prompt.clone()),
394 styles: request.styles.or(self.styles.clone()),
395 seed: request.seed.or(self.seed),
396 subseed: request.subseed.or(self.subseed),
397 subseed_strength: request.subseed_strength.or(self.subseed_strength),
398 seed_resize_from_h: request.seed_resize_from_h.or(self.seed_resize_from_h),
399 seed_resize_from_w: request.seed_resize_from_w.or(self.seed_resize_from_w),
400 sampler_name: request.sampler_name.or(self.sampler_name.clone()),
401 batch_size: request.batch_size.or(self.batch_size),
402 n_iter: request.n_iter.or(self.n_iter),
403 steps: request.steps.or(self.steps),
404 cfg_scale: request.cfg_scale.or(self.cfg_scale),
405 width: request.width.or(self.width),
406 height: request.height.or(self.height),
407 restore_faces: request.restore_faces.or(self.restore_faces),
408 tiling: request.tiling.or(self.tiling),
409 do_not_save_samples: request.do_not_save_samples.or(self.do_not_save_samples),
410 do_not_save_grid: request.do_not_save_grid.or(self.do_not_save_grid),
411 negative_prompt: request.negative_prompt.or(self.negative_prompt.clone()),
412 eta: request.eta.or(self.eta),
413 s_churn: request.s_churn.or(self.s_churn),
414 s_tmax: request.s_tmax.or(self.s_tmax),
415 s_tmin: request.s_tmin.or(self.s_tmin),
416 s_noise: request.s_noise.or(self.s_noise),
417 override_settings: request.override_settings.or(self.override_settings.clone()),
418 override_settings_restore_afterwards: request
419 .override_settings_restore_afterwards
420 .or(self.override_settings_restore_afterwards),
421 script_args: request.script_args.or(self.script_args.clone()),
422 sampler_index: request.sampler_index.or(self.sampler_index.clone()),
423 script_name: request.script_name.or(self.script_name.clone()),
424 send_images: request.send_images.or(self.send_images),
425 save_images: request.save_images.or(self.save_images),
426 alwayson_scripts: request.alwayson_scripts.or(self.alwayson_scripts.clone()),
427 }
428 }
429}
430
431/// Errors that can occur when interacting with the `Txt2Img` API.
432#[derive(thiserror::Error, Debug)]
433#[non_exhaustive]
434pub enum Txt2ImgError {
435 /// Error parsing endpoint URL
436 #[error("Failed to parse endpoint URL")]
437 ParseError(#[from] url::ParseError),
438 /// Error sending request
439 #[error("Failed to send request")]
440 RequestFailed(#[from] reqwest::Error),
441 /// An error occurred while parsing the response from the API.
442 #[error("Parsing response failed")]
443 InvalidResponse(#[source] reqwest::Error),
444 /// An error occurred getting response data.
445 #[error("Failed to get response data")]
446 GetDataFailed(#[source] reqwest::Error),
447 /// Server returned an error for img2img
448 #[error("Img2Img request failed: {status}: {error}")]
449 Txt2ImgFailed {
450 status: reqwest::StatusCode,
451 error: String,
452 },
453}
454
455type Result<T> = std::result::Result<T, Txt2ImgError>;
456
457/// A client for sending image requests to a specified endpoint.
458pub struct Txt2Img {
459 client: reqwest::Client,
460 endpoint: Url,
461}
462
463impl Txt2Img {
464 /// Constructs a new Txt2Img client with a given `reqwest::Client` and Stable Diffusion API
465 /// endpoint `String`.
466 ///
467 /// # Arguments
468 ///
469 /// * `client` - A `reqwest::Client` used to send requests.
470 /// * `endpoint` - A `String` representation of the endpoint url.
471 ///
472 /// # Returns
473 ///
474 /// A `Result` containing a new Txt2Img instance on success, or an error if url parsing failed.
475 pub fn new(client: reqwest::Client, endpoint: String) -> Result<Self> {
476 Ok(Self::new_with_url(client, Url::parse(&endpoint)?))
477 }
478
479 /// Constructs a new Txt2Img client with a given `reqwest::Client` and endpoint `Url`.
480 ///
481 /// # Arguments
482 ///
483 /// * `client` - A `reqwest::Client` used to send requests.
484 /// * `endpoint` - A `Url` representing the endpoint url.
485 ///
486 /// # Returns
487 ///
488 /// A new Txt2Img instance.
489 pub fn new_with_url(client: reqwest::Client, endpoint: Url) -> Self {
490 Self { client, endpoint }
491 }
492
493 /// Sends an image request using the Txt2Img client.
494 ///
495 /// # Arguments
496 ///
497 /// * `request` - An Txt2ImgRequest containing the parameters for the image request.
498 ///
499 /// # Returns
500 ///
501 /// A `Result` containing an `ImgResponse<Txt2ImgRequest>` on success, or an error if one occurred.
502 pub async fn send(&self, request: &Txt2ImgRequest) -> Result<ImgResponse<Txt2ImgRequest>> {
503 let response = self
504 .client
505 .post(self.endpoint.clone())
506 .json(&request)
507 .send()
508 .await
509 .map_err(Txt2ImgError::RequestFailed)?;
510 if response.status().is_success() {
511 return response.json().await.map_err(Txt2ImgError::InvalidResponse);
512 }
513 let status = response.status();
514 let text = response.text().await.map_err(Txt2ImgError::GetDataFailed)?;
515 Err(Txt2ImgError::Txt2ImgFailed {
516 status,
517 error: text,
518 })
519 }
520}