stable_diffusion_api/lib.rs
1use reqwest::Url;
2use serde::{Deserialize, Serialize};
3use serde_with::skip_serializing_none;
4
5mod txt2img;
6pub use txt2img::*;
7
8mod img2img;
9pub use img2img::*;
10
11/// Errors that can occur when interacting with the Stable Diffusion API.
12#[derive(thiserror::Error, Debug)]
13#[non_exhaustive]
14pub enum ApiError {
15 /// Error parsing endpoint URL
16 #[error("Failed to parse endpoint URL")]
17 ParseError(#[from] url::ParseError),
18 /// Error parsing info from response
19 #[error("Failed to info from response")]
20 InvalidInfo(#[from] serde_json::Error),
21 /// Error decoding image from response
22 #[error("Failed to decode image from response")]
23 DecodeError(#[from] base64::DecodeError),
24}
25
26type Result<T> = std::result::Result<T, ApiError>;
27
28/// Struct representing a connection to a Stable Diffusion WebUI API.
29#[derive(Clone, Debug)]
30pub struct Api {
31 client: reqwest::Client,
32 url: Url,
33}
34
35impl Default for Api {
36 fn default() -> Self {
37 Self {
38 client: reqwest::Client::new(),
39 url: Url::parse("http://localhost:7860").expect("Failed to parse default URL"),
40 }
41 }
42}
43
44impl Api {
45 /// Returns a new `Api` instance with default settings.
46 pub fn new() -> Self {
47 Self::default()
48 }
49
50 /// Returns a new `Api` instance with the given URL as a string value.
51 ///
52 /// # Arguments
53 ///
54 /// * `url` - A string that specifies the Stable Diffusion WebUI API URL endpoint.
55 ///
56 /// # Errors
57 ///
58 /// If the URL fails to parse, an error will be returned.
59 pub fn new_with_url<S>(url: S) -> Result<Self>
60 where
61 S: AsRef<str>,
62 {
63 Ok(Self {
64 url: Url::parse(url.as_ref())?,
65 ..Default::default()
66 })
67 }
68
69 /// Returns a new `Api` instance with the given `reqwest::Client` and URL as a string value.
70 ///
71 /// # Arguments
72 ///
73 /// * `client` - An instance of `reqwest::Client`.
74 /// * `url` - A string that specifies the Stable Diffusion WebUI API URL endpoint.
75 ///
76 /// # Errors
77 ///
78 /// If the URL fails to parse, an error will be returned.
79 pub fn new_with_client_and_url<S>(client: reqwest::Client, url: S) -> Result<Self>
80 where
81 S: AsRef<str>,
82 {
83 Ok(Self {
84 client,
85 url: Url::parse(url.as_ref())?,
86 })
87 }
88
89 /// Returns a new instance of `Txt2Img` with the API's cloned `reqwest::Client` and the URL for `txt2img` endpoint.
90 ///
91 /// # Errors
92 ///
93 /// If the URL fails to parse, an error will be returned.
94 pub fn txt2img(&self) -> Result<Txt2Img> {
95 Ok(Txt2Img::new_with_url(
96 self.client.clone(),
97 self.url.join("sdapi/v1/txt2img")?,
98 ))
99 }
100
101 /// Returns a new instance of `Img2Img` with the API's cloned `reqwest::Client` and the URL for `img2img` endpoint.
102 ///
103 /// # Errors
104 ///
105 /// If the URL fails to parse, an error will be returned.
106 pub fn img2img(&self) -> Result<Img2Img> {
107 Ok(Img2Img::new_with_url(
108 self.client.clone(),
109 self.url.join("sdapi/v1/img2img")?,
110 ))
111 }
112}
113
114/// A struct that represents the response from the Stable Diffusion WebUI API endpoint.
115#[skip_serializing_none]
116#[derive(Default, Serialize, Deserialize, Debug, Clone)]
117pub struct ImgResponse<T: Clone> {
118 /// A vector of strings containing base64-encoded images.
119 pub images: Vec<String>,
120 /// The parameters that were provided for the generation request.
121 pub parameters: T,
122 /// A string containing JSON representing information about the request.
123 pub info: String,
124}
125
126impl<T: Clone> ImgResponse<T> {
127 /// Parses and returns a new `ImgInfo` instance from the `info` field of the `ImgResponse`.
128 ///
129 /// # Errors
130 ///
131 /// If the `info` field fails to parse, an error will be returned.
132 pub fn info(&self) -> Result<ImgInfo> {
133 Ok(serde_json::from_str(&self.info)?)
134 }
135
136 /// Decodes and returns a vector of images from the `images` field of the `ImgResponse`.
137 ///
138 /// # Errors
139 ///
140 /// If any of the images fail to decode, an error will be returned.
141 pub fn images(&self) -> Result<Vec<Vec<u8>>> {
142 use base64::{engine::general_purpose, Engine as _};
143 self.images
144 .iter()
145 .map(|img| {
146 general_purpose::STANDARD
147 .decode(img)
148 .map_err(ApiError::DecodeError)
149 })
150 .collect::<Result<Vec<_>>>()
151 }
152}
153
154#[skip_serializing_none]
155#[derive(Default, Serialize, Deserialize, Debug, Clone)]
156/// Information about the generated images.
157pub struct ImgInfo {
158 /// The prompt used when generating the image.
159 pub prompt: Option<String>,
160 /// A vector of all the prompts used for image generation.
161 pub all_prompts: Option<Vec<String>>,
162 /// The negative prompt used when generating the image.
163 pub negative_prompt: Option<String>,
164 /// A vector of all negative prompts used when generating the image.
165 pub all_negative_prompts: Option<Vec<String>>,
166 /// The random seed used for image generation.
167 pub seed: Option<i64>,
168 /// A vector of all the random seeds used for image generation.
169 pub all_seeds: Option<Vec<i64>>,
170 /// The subseed used when generating the image.
171 pub subseed: Option<i64>,
172 /// A vector of all the subseeds used for image generation.
173 pub all_subseeds: Option<Vec<i64>>,
174 /// The strength of the subseed used when generating the image.
175 pub subseed_strength: Option<u32>,
176 /// The width of the generated image.
177 pub width: Option<i32>,
178 /// The height of the generated image.
179 pub height: Option<i32>,
180 /// The name of the sampler used for image generation.
181 pub sampler_name: Option<String>,
182 /// The cfg scale factor used when generating the image.
183 pub cfg_scale: Option<f64>,
184 /// The number of steps taken when generating the image.
185 pub steps: Option<u32>,
186 /// The number of images generated in one batch.
187 pub batch_size: Option<u32>,
188 /// Whether or not face restoration was used.
189 pub restore_faces: Option<bool>,
190 /// The face restoration model used when generating the image.
191 pub face_restoration_model: Option<serde_json::Value>,
192 /// The name of the sd model used when generating the image.
193 pub sd_model_name: Option<String>,
194 /// The hash of the sd model used for image generation.
195 pub sd_model_hash: Option<String>,
196 /// The name of the VAE used when generating the image.
197 pub sd_vae_name: Option<String>,
198 /// The hash of the VAE used for image generation.
199 pub sd_vae_hash: Option<String>,
200 /// The width used when resizing the image seed.
201 pub seed_resize_from_w: Option<i32>,
202 /// The height used when resizing the image seed.
203 pub seed_resize_from_h: Option<i32>,
204 /// The strength of the denoising applied during image generation.
205 pub denoising_strength: Option<f64>,
206 /// Extra parameters passed for image generation.
207 pub extra_generation_params: Option<ExtraGenParams>,
208 /// The index of the first image.
209 pub index_of_first_image: Option<u32>,
210 /// A vector of information texts about the generated images.
211 pub infotexts: Option<Vec<String>>,
212 /// A vector of the styles used for image generation.
213 pub styles: Option<Vec<String>>,
214 /// The timestamp of when the job was started.
215 pub job_timestamp: Option<String>,
216 /// The number of clip layers skipped during image generation.
217 pub clip_skip: Option<u32>,
218 /// Whether or not inpainting conditioning was used for image generation.
219 pub is_using_inpainting_conditioning: Option<bool>,
220}
221
222#[skip_serializing_none]
223#[derive(Default, Serialize, Deserialize, Debug, Clone)]
224/// Extra parameters describing image generation.
225pub struct ExtraGenParams {
226 /// Names and hashes of LORA models used for image generation.
227 #[serde(rename = "Lora hashes")]
228 pub lora_hashes: Option<String>,
229 /// Names and hashes of Textual Inversion models used for image generation.
230 #[serde(rename = "TI hashes")]
231 pub ti_hashes: Option<String>,
232}