stable_diffusion_api/
lib.rs

1use reqwest::Url;
2use serde::{Deserialize, Serialize};
3use serde_with::skip_serializing_none;
4
5mod txt2img;
6pub use txt2img::*;
7
8mod img2img;
9pub use img2img::*;
10
11/// Errors that can occur when interacting with the Stable Diffusion API.
12#[derive(thiserror::Error, Debug)]
13#[non_exhaustive]
14pub enum ApiError {
15    /// Error parsing endpoint URL
16    #[error("Failed to parse endpoint URL")]
17    ParseError(#[from] url::ParseError),
18    /// Error parsing info from response
19    #[error("Failed to info from response")]
20    InvalidInfo(#[from] serde_json::Error),
21    /// Error decoding image from response
22    #[error("Failed to decode image from response")]
23    DecodeError(#[from] base64::DecodeError),
24}
25
26type Result<T> = std::result::Result<T, ApiError>;
27
28/// Struct representing a connection to a Stable Diffusion WebUI API.
29#[derive(Clone, Debug)]
30pub struct Api {
31    client: reqwest::Client,
32    url: Url,
33}
34
35impl Default for Api {
36    fn default() -> Self {
37        Self {
38            client: reqwest::Client::new(),
39            url: Url::parse("http://localhost:7860").expect("Failed to parse default URL"),
40        }
41    }
42}
43
44impl Api {
45    /// Returns a new `Api` instance with default settings.
46    pub fn new() -> Self {
47        Self::default()
48    }
49
50    /// Returns a new `Api` instance with the given URL as a string value.
51    ///
52    /// # Arguments
53    ///
54    /// * `url` - A string that specifies the Stable Diffusion WebUI API URL endpoint.
55    ///
56    /// # Errors
57    ///
58    /// If the URL fails to parse, an error will be returned.
59    pub fn new_with_url<S>(url: S) -> Result<Self>
60    where
61        S: AsRef<str>,
62    {
63        Ok(Self {
64            url: Url::parse(url.as_ref())?,
65            ..Default::default()
66        })
67    }
68
69    /// Returns a new `Api` instance with the given `reqwest::Client` and URL as a string value.
70    ///
71    /// # Arguments
72    ///
73    /// * `client` - An instance of `reqwest::Client`.
74    /// * `url` - A string that specifies the Stable Diffusion WebUI API URL endpoint.
75    ///
76    /// # Errors
77    ///
78    /// If the URL fails to parse, an error will be returned.
79    pub fn new_with_client_and_url<S>(client: reqwest::Client, url: S) -> Result<Self>
80    where
81        S: AsRef<str>,
82    {
83        Ok(Self {
84            client,
85            url: Url::parse(url.as_ref())?,
86        })
87    }
88
89    /// Returns a new instance of `Txt2Img` with the API's cloned `reqwest::Client` and the URL for `txt2img` endpoint.
90    ///
91    /// # Errors
92    ///
93    /// If the URL fails to parse, an error will be returned.
94    pub fn txt2img(&self) -> Result<Txt2Img> {
95        Ok(Txt2Img::new_with_url(
96            self.client.clone(),
97            self.url.join("sdapi/v1/txt2img")?,
98        ))
99    }
100
101    /// Returns a new instance of `Img2Img` with the API's cloned `reqwest::Client` and the URL for `img2img` endpoint.
102    ///
103    /// # Errors
104    ///
105    /// If the URL fails to parse, an error will be returned.
106    pub fn img2img(&self) -> Result<Img2Img> {
107        Ok(Img2Img::new_with_url(
108            self.client.clone(),
109            self.url.join("sdapi/v1/img2img")?,
110        ))
111    }
112}
113
114/// A struct that represents the response from the Stable Diffusion WebUI API endpoint.
115#[skip_serializing_none]
116#[derive(Default, Serialize, Deserialize, Debug, Clone)]
117pub struct ImgResponse<T: Clone> {
118    /// A vector of strings containing base64-encoded images.
119    pub images: Vec<String>,
120    /// The parameters that were provided for the generation request.
121    pub parameters: T,
122    /// A string containing JSON representing information about the request.
123    pub info: String,
124}
125
126impl<T: Clone> ImgResponse<T> {
127    /// Parses and returns a new `ImgInfo` instance from the `info` field of the `ImgResponse`.
128    ///
129    /// # Errors
130    ///
131    /// If the `info` field fails to parse, an error will be returned.
132    pub fn info(&self) -> Result<ImgInfo> {
133        Ok(serde_json::from_str(&self.info)?)
134    }
135
136    /// Decodes and returns a vector of images from the `images` field of the `ImgResponse`.
137    ///
138    /// # Errors
139    ///
140    /// If any of the images fail to decode, an error will be returned.
141    pub fn images(&self) -> Result<Vec<Vec<u8>>> {
142        use base64::{engine::general_purpose, Engine as _};
143        self.images
144            .iter()
145            .map(|img| {
146                general_purpose::STANDARD
147                    .decode(img)
148                    .map_err(ApiError::DecodeError)
149            })
150            .collect::<Result<Vec<_>>>()
151    }
152}
153
154#[skip_serializing_none]
155#[derive(Default, Serialize, Deserialize, Debug, Clone)]
156/// Information about the generated images.
157pub struct ImgInfo {
158    /// The prompt used when generating the image.
159    pub prompt: Option<String>,
160    /// A vector of all the prompts used for image generation.
161    pub all_prompts: Option<Vec<String>>,
162    /// The negative prompt used when generating the image.
163    pub negative_prompt: Option<String>,
164    /// A vector of all negative prompts used when generating the image.
165    pub all_negative_prompts: Option<Vec<String>>,
166    /// The random seed used for image generation.
167    pub seed: Option<i64>,
168    /// A vector of all the random seeds used for image generation.
169    pub all_seeds: Option<Vec<i64>>,
170    /// The subseed used when generating the image.
171    pub subseed: Option<i64>,
172    /// A vector of all the subseeds used for image generation.
173    pub all_subseeds: Option<Vec<i64>>,
174    /// The strength of the subseed used when generating the image.
175    pub subseed_strength: Option<u32>,
176    /// The width of the generated image.
177    pub width: Option<i32>,
178    /// The height of the generated image.
179    pub height: Option<i32>,
180    /// The name of the sampler used for image generation.
181    pub sampler_name: Option<String>,
182    /// The cfg scale factor used when generating the image.
183    pub cfg_scale: Option<f64>,
184    /// The number of steps taken when generating the image.
185    pub steps: Option<u32>,
186    /// The number of images generated in one batch.
187    pub batch_size: Option<u32>,
188    /// Whether or not face restoration was used.
189    pub restore_faces: Option<bool>,
190    /// The face restoration model used when generating the image.
191    pub face_restoration_model: Option<serde_json::Value>,
192    /// The name of the sd model used when generating the image.
193    pub sd_model_name: Option<String>,
194    /// The hash of the sd model used for image generation.
195    pub sd_model_hash: Option<String>,
196    /// The name of the VAE used when generating the image.
197    pub sd_vae_name: Option<String>,
198    /// The hash of the VAE used for image generation.
199    pub sd_vae_hash: Option<String>,
200    /// The width used when resizing the image seed.
201    pub seed_resize_from_w: Option<i32>,
202    /// The height used when resizing the image seed.
203    pub seed_resize_from_h: Option<i32>,
204    /// The strength of the denoising applied during image generation.
205    pub denoising_strength: Option<f64>,
206    /// Extra parameters passed for image generation.
207    pub extra_generation_params: Option<ExtraGenParams>,
208    /// The index of the first image.
209    pub index_of_first_image: Option<u32>,
210    /// A vector of information texts about the generated images.
211    pub infotexts: Option<Vec<String>>,
212    /// A vector of the styles used for image generation.
213    pub styles: Option<Vec<String>>,
214    /// The timestamp of when the job was started.
215    pub job_timestamp: Option<String>,
216    /// The number of clip layers skipped during image generation.
217    pub clip_skip: Option<u32>,
218    /// Whether or not inpainting conditioning was used for image generation.
219    pub is_using_inpainting_conditioning: Option<bool>,
220}
221
222#[skip_serializing_none]
223#[derive(Default, Serialize, Deserialize, Debug, Clone)]
224/// Extra parameters describing image generation.
225pub struct ExtraGenParams {
226    /// Names and hashes of LORA models used for image generation.
227    #[serde(rename = "Lora hashes")]
228    pub lora_hashes: Option<String>,
229    /// Names and hashes of Textual Inversion models used for image generation.
230    #[serde(rename = "TI hashes")]
231    pub ti_hashes: Option<String>,
232}