1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
use reqwest::Url;
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;

mod txt2img;
pub use txt2img::*;

mod img2img;
pub use img2img::*;

/// Errors that can occur when interacting with the Stable Diffusion API.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum ApiError {
    /// Error parsing endpoint URL
    #[error("Failed to parse endpoint URL")]
    ParseError(#[from] url::ParseError),
    /// Error parsing info from response
    #[error("Failed to info from response")]
    InvalidInfo(#[from] serde_json::Error),
    /// Error decoding image from response
    #[error("Failed to decode image from response")]
    DecodeError(#[from] base64::DecodeError),
}

/// Module-private shorthand for `std::result::Result` with [`ApiError`] as the error type.
type Result<T> = std::result::Result<T, ApiError>;

/// Struct representing a connection to a Stable Diffusion WebUI API.
///
/// Holds the HTTP client and the base URL from which the `txt2img` and
/// `img2img` endpoint URLs are derived.
#[derive(Clone, Debug)]
pub struct Api {
    /// HTTP client; a clone of it is handed to every request object created from this `Api`.
    client: reqwest::Client,
    /// Base URL of the WebUI; endpoint paths are joined onto it.
    url: Url,
}

impl Default for Api {
    /// Builds an `Api` with a fresh `reqwest::Client` pointed at `http://localhost:7860`.
    fn default() -> Self {
        let client = reqwest::Client::new();
        // The literal is a well-formed absolute URL, so a failure here would be a programming error.
        let url = Url::parse("http://localhost:7860").expect("Failed to parse default URL");
        Self { client, url }
    }
}

impl Api {
    /// Creates an `Api` with the default client and URL (see the `Default` implementation).
    pub fn new() -> Self {
        Default::default()
    }

    /// Creates an `Api` that targets the WebUI at `url`, using a default `reqwest::Client`.
    ///
    /// # Arguments
    ///
    /// * `url` - The base URL of the Stable Diffusion WebUI API, given as a string.
    ///
    /// # Errors
    ///
    /// Returns [`ApiError::ParseError`] if `url` cannot be parsed as a URL.
    pub fn new_with_url<S: AsRef<str>>(url: S) -> Result<Self> {
        // Parse first so an invalid URL is reported before any default state is built.
        let url = Url::parse(url.as_ref())?;
        Ok(Self {
            url,
            ..Self::default()
        })
    }

    /// Creates an `Api` from a caller-supplied `reqwest::Client` and a base URL string.
    ///
    /// # Arguments
    ///
    /// * `client` - The `reqwest::Client` to use for requests.
    /// * `url` - The base URL of the Stable Diffusion WebUI API, given as a string.
    ///
    /// # Errors
    ///
    /// Returns [`ApiError::ParseError`] if `url` cannot be parsed as a URL.
    pub fn new_with_client_and_url<S: AsRef<str>>(
        client: reqwest::Client,
        url: S,
    ) -> Result<Self> {
        let url = Url::parse(url.as_ref())?;
        Ok(Self { client, url })
    }

    /// Creates a `Txt2Img` bound to a clone of this API's client and the `sdapi/v1/txt2img` URL.
    ///
    /// The endpoint path is resolved against the base URL with `Url::join`. Per that
    /// method's documented semantics, a base URL whose path does not end in `/` has its
    /// last path segment replaced rather than extended.
    ///
    /// # Errors
    ///
    /// Returns [`ApiError::ParseError`] if the endpoint path cannot be joined onto the base URL.
    pub fn txt2img(&self) -> Result<Txt2Img> {
        let endpoint = self.url.join("sdapi/v1/txt2img")?;
        Ok(Txt2Img::new_with_url(self.client.clone(), endpoint))
    }

    /// Creates an `Img2Img` bound to a clone of this API's client and the `sdapi/v1/img2img` URL.
    ///
    /// The endpoint path is resolved against the base URL with `Url::join`. Per that
    /// method's documented semantics, a base URL whose path does not end in `/` has its
    /// last path segment replaced rather than extended.
    ///
    /// # Errors
    ///
    /// Returns [`ApiError::ParseError`] if the endpoint path cannot be joined onto the base URL.
    pub fn img2img(&self) -> Result<Img2Img> {
        let endpoint = self.url.join("sdapi/v1/img2img")?;
        Ok(Img2Img::new_with_url(self.client.clone(), endpoint))
    }
}

/// A struct that represents the response from the Stable Diffusion WebUI API endpoint.
///
/// `T` is the type of the echoed request parameters. The raw fields are public;
/// the `images()` and `info()` methods return decoded / parsed views of them.
#[skip_serializing_none]
#[derive(Default, Serialize, Deserialize, Debug, Clone)]
pub struct ImgResponse<T: Clone> {
    /// A vector of strings containing base64-encoded images (decoded by `images()`).
    pub images: Vec<String>,
    /// The parameters that were provided for the generation request.
    pub parameters: T,
    /// A string containing JSON representing information about the request (parsed by `info()`).
    pub info: String,
}

impl<T: Clone> ImgResponse<T> {
    /// Deserializes the JSON stored in the `info` field into an `ImgInfo`.
    ///
    /// # Errors
    ///
    /// Returns [`ApiError::InvalidInfo`] if the `info` string is not valid JSON for `ImgInfo`.
    pub fn info(&self) -> Result<ImgInfo> {
        serde_json::from_str(&self.info).map_err(ApiError::InvalidInfo)
    }

    /// Base64-decodes every entry of the `images` field and returns the raw bytes of each image.
    ///
    /// The output preserves the order of `images`.
    ///
    /// # Errors
    ///
    /// Returns [`ApiError::DecodeError`] for the first image that fails to decode; no
    /// partial result is returned.
    pub fn images(&self) -> Result<Vec<Vec<u8>>> {
        use base64::{engine::general_purpose::STANDARD, Engine as _};

        let mut decoded = Vec::with_capacity(self.images.len());
        for encoded in &self.images {
            // `?` converts `base64::DecodeError` into `ApiError::DecodeError` and stops early.
            decoded.push(STANDARD.decode(encoded)?);
        }
        Ok(decoded)
    }
}

#[skip_serializing_none]
#[derive(Default, Serialize, Deserialize, Debug, Clone)]
/// Information about the generated images.
///
/// This is the type produced by `ImgResponse::info()` when it parses the response's
/// `info` JSON string. Every field is optional, so keys missing from the JSON
/// deserialize to `None`, and `None` fields are omitted when serializing.
pub struct ImgInfo {
    /// The prompt used when generating the image.
    pub prompt: Option<String>,
    /// A vector of all the prompts used for image generation.
    pub all_prompts: Option<Vec<String>>,
    /// The negative prompt used when generating the image.
    pub negative_prompt: Option<String>,
    /// A vector of all negative prompts used when generating the image.
    pub all_negative_prompts: Option<Vec<String>>,
    /// The random seed used for image generation.
    pub seed: Option<i64>,
    /// A vector of all the random seeds used for image generation.
    pub all_seeds: Option<Vec<i64>>,
    /// The subseed used when generating the image.
    pub subseed: Option<i64>,
    /// A vector of all the subseeds used for image generation.
    pub all_subseeds: Option<Vec<i64>>,
    /// The strength of the subseed used when generating the image.
    // NOTE(review): typed as an unsigned integer here; if the server can report a
    // fractional strength (e.g. 0.5), deserializing the whole `ImgInfo` would fail.
    // Confirm against real API output before relying on this field.
    pub subseed_strength: Option<u32>,
    /// The width of the generated image.
    pub width: Option<i32>,
    /// The height of the generated image.
    pub height: Option<i32>,
    /// The name of the sampler used for image generation.
    pub sampler_name: Option<String>,
    /// The cfg scale factor used when generating the image.
    pub cfg_scale: Option<f64>,
    /// The number of steps taken when generating the image.
    pub steps: Option<u32>,
    /// The number of images generated in one batch.
    pub batch_size: Option<u32>,
    /// Whether or not face restoration was used.
    pub restore_faces: Option<bool>,
    /// The face restoration model used when generating the image, kept as an untyped JSON value.
    pub face_restoration_model: Option<serde_json::Value>,
    /// The name of the sd model used when generating the image.
    pub sd_model_name: Option<String>,
    /// The hash of the sd model used for image generation.
    pub sd_model_hash: Option<String>,
    /// The name of the VAE used when generating the image.
    pub sd_vae_name: Option<String>,
    /// The hash of the VAE used for image generation.
    pub sd_vae_hash: Option<String>,
    /// The width used when resizing the image seed.
    pub seed_resize_from_w: Option<i32>,
    /// The height used when resizing the image seed.
    pub seed_resize_from_h: Option<i32>,
    /// The strength of the denoising applied during image generation.
    pub denoising_strength: Option<f64>,
    /// Extra parameters passed for image generation.
    pub extra_generation_params: Option<ExtraGenParams>,
    /// The index of the first image.
    pub index_of_first_image: Option<u32>,
    /// A vector of information texts about the generated images.
    pub infotexts: Option<Vec<String>>,
    /// A vector of the styles used for image generation.
    pub styles: Option<Vec<String>>,
    /// The timestamp of when the job was started, as a string.
    pub job_timestamp: Option<String>,
    /// The number of clip layers skipped during image generation.
    pub clip_skip: Option<u32>,
    /// Whether or not inpainting conditioning was used for image generation.
    pub is_using_inpainting_conditioning: Option<bool>,
}

#[skip_serializing_none]
#[derive(Default, Serialize, Deserialize, Debug, Clone)]
/// Extra parameters describing image generation.
///
/// Only the two keys below are modelled. Both are optional and are omitted
/// from serialized output when `None`.
pub struct ExtraGenParams {
    /// Names and hashes of LORA models used for image generation.
    ///
    /// Stored under the JSON key `"Lora hashes"`.
    #[serde(rename = "Lora hashes")]
    pub lora_hashes: Option<String>,
    /// Names and hashes of Textual Inversion models used for image generation.
    ///
    /// Stored under the JSON key `"TI hashes"`.
    #[serde(rename = "TI hashes")]
    pub ti_hashes: Option<String>,
}