1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232
use reqwest::Url;
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
mod txt2img;
pub use txt2img::*;
mod img2img;
pub use img2img::*;
/// Errors that can occur when interacting with the Stable Diffusion API.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum ApiError {
    /// Error parsing endpoint URL
    #[error("Failed to parse endpoint URL")]
    ParseError(#[from] url::ParseError),
    /// Error parsing info from response
    #[error("Failed to info from response")]
    InvalidInfo(#[from] serde_json::Error),
    /// Error decoding image from response
    #[error("Failed to decode image from response")]
    DecodeError(#[from] base64::DecodeError),
}
type Result<T> = std::result::Result<T, ApiError>;
/// Handle to a Stable Diffusion WebUI API server.
///
/// Bundles the HTTP client used for requests with the base URL that
/// endpoint paths are resolved against.
#[derive(Clone, Debug)]
pub struct Api {
    /// HTTP client; endpoint builders receive a clone of it.
    client: reqwest::Client,
    /// Base URL of the WebUI API.
    url: Url,
}
impl Default for Api {
    /// Builds an `Api` with a fresh `reqwest::Client` pointed at `http://localhost:7860`.
    fn default() -> Self {
        let client = reqwest::Client::new();
        // The literal is a well-formed URL, so a parse failure here would be a programming error.
        let url = Url::parse("http://localhost:7860").expect("Failed to parse default URL");
        Self { client, url }
    }
}
impl Api {
    /// Returns a new `Api` instance with default settings.
    pub fn new() -> Self {
        Self::default()
    }
    /// Returns a new `Api` instance with the given URL as a string value.
    ///
    /// # Arguments
    ///
    /// * `url` - A string that specifies the Stable Diffusion WebUI API URL endpoint.
    ///
    /// # Errors
    ///
    /// If the URL fails to parse, an error will be returned.
    pub fn new_with_url<S>(url: S) -> Result<Self>
    where
        S: AsRef<str>,
    {
        Ok(Self {
            url: Url::parse(url.as_ref())?,
            ..Default::default()
        })
    }
    /// Returns a new `Api` instance with the given `reqwest::Client` and URL as a string value.
    ///
    /// # Arguments
    ///
    /// * `client` - An instance of `reqwest::Client`.
    /// * `url` - A string that specifies the Stable Diffusion WebUI API URL endpoint.
    ///
    /// # Errors
    ///
    /// If the URL fails to parse, an error will be returned.
    pub fn new_with_client_and_url<S>(client: reqwest::Client, url: S) -> Result<Self>
    where
        S: AsRef<str>,
    {
        Ok(Self {
            client,
            url: Url::parse(url.as_ref())?,
        })
    }
    /// Returns a new instance of `Txt2Img` with the API's cloned `reqwest::Client` and the URL for `txt2img` endpoint.
    ///
    /// # Errors
    ///
    /// If the URL fails to parse, an error will be returned.
    pub fn txt2img(&self) -> Result<Txt2Img> {
        Ok(Txt2Img::new_with_url(
            self.client.clone(),
            self.url.join("sdapi/v1/txt2img")?,
        ))
    }
    /// Returns a new instance of `Img2Img` with the API's cloned `reqwest::Client` and the URL for `img2img` endpoint.
    ///
    /// # Errors
    ///
    /// If the URL fails to parse, an error will be returned.
    pub fn img2img(&self) -> Result<Img2Img> {
        Ok(Img2Img::new_with_url(
            self.client.clone(),
            self.url.join("sdapi/v1/img2img")?,
        ))
    }
}
/// A struct that represents the response from the Stable Diffusion WebUI API endpoint.
///
/// `T` is the type of the request parameters echoed back in `parameters`.
/// The struct itself places no bound on `T`; each derived trait (e.g. `Clone`,
/// `Serialize`) adds the bound it needs on its own impl.
#[skip_serializing_none]
#[derive(Default, Serialize, Deserialize, Debug, Clone)]
pub struct ImgResponse<T> {
    /// A vector of strings containing base64-encoded images.
    pub images: Vec<String>,
    /// The parameters that were provided for the generation request.
    pub parameters: T,
    /// A string containing JSON representing information about the request.
    pub info: String,
}
impl<T: Clone> ImgResponse<T> {
    /// Deserializes the `info` JSON string into an [`ImgInfo`].
    ///
    /// # Errors
    ///
    /// Returns [`ApiError::InvalidInfo`] when `info` cannot be deserialized as `ImgInfo`.
    pub fn info(&self) -> Result<ImgInfo> {
        let parsed = serde_json::from_str::<ImgInfo>(&self.info);
        parsed.map_err(ApiError::InvalidInfo)
    }
    /// Decodes every entry of `images` from standard base64 into raw bytes.
    ///
    /// The returned vector preserves the order of `images`.
    ///
    /// # Errors
    ///
    /// Returns [`ApiError::DecodeError`] for the first entry that is not valid
    /// standard base64; entries after it are not decoded.
    pub fn images(&self) -> Result<Vec<Vec<u8>>> {
        use base64::{engine::general_purpose, Engine as _};
        let mut decoded = Vec::with_capacity(self.images.len());
        for encoded in &self.images {
            decoded.push(general_purpose::STANDARD.decode(encoded)?);
        }
        Ok(decoded)
    }
}
/// Generation metadata, deserialized from the `info` JSON string of an `ImgResponse`.
///
/// Every field is optional so that keys missing from the JSON do not cause a
/// deserialization failure; `None` fields are omitted when serializing.
#[skip_serializing_none]
#[derive(Default, Serialize, Deserialize, Debug, Clone)]
pub struct ImgInfo {
    /// Prompt used for generation.
    pub prompt: Option<String>,
    /// Every prompt used across the generated images.
    pub all_prompts: Option<Vec<String>>,
    /// Negative prompt used for generation.
    pub negative_prompt: Option<String>,
    /// Every negative prompt used across the generated images.
    pub all_negative_prompts: Option<Vec<String>>,
    /// Random seed used for generation.
    pub seed: Option<i64>,
    /// Every random seed used across the generated images.
    pub all_seeds: Option<Vec<i64>>,
    /// Subseed used for generation.
    pub subseed: Option<i64>,
    /// Every subseed used across the generated images.
    pub all_subseeds: Option<Vec<i64>>,
    /// Strength of the subseed.
    ///
    /// NOTE(review): the WebUI appears to report this as a fractional value; if the
    /// `info` JSON carries e.g. `0.5` or `0.0`, deserializing into `u32` will fail.
    /// Confirm against the server and consider `f64`.
    pub subseed_strength: Option<u32>,
    /// Width of the generated image.
    pub width: Option<i32>,
    /// Height of the generated image.
    pub height: Option<i32>,
    /// Name of the sampler used.
    pub sampler_name: Option<String>,
    /// CFG scale used for generation.
    pub cfg_scale: Option<f64>,
    /// Number of sampling steps.
    pub steps: Option<u32>,
    /// Number of images generated per batch.
    pub batch_size: Option<u32>,
    /// Whether face restoration was enabled.
    pub restore_faces: Option<bool>,
    /// Face restoration model, kept as raw JSON.
    pub face_restoration_model: Option<serde_json::Value>,
    /// Name of the SD model used.
    pub sd_model_name: Option<String>,
    /// Hash of the SD model used.
    pub sd_model_hash: Option<String>,
    /// Name of the VAE used.
    pub sd_vae_name: Option<String>,
    /// Hash of the VAE used.
    pub sd_vae_hash: Option<String>,
    /// Width the image seed was resized from.
    pub seed_resize_from_w: Option<i32>,
    /// Height the image seed was resized from.
    pub seed_resize_from_h: Option<i32>,
    /// Denoising strength applied during generation.
    pub denoising_strength: Option<f64>,
    /// Additional generation parameters; see [`ExtraGenParams`].
    pub extra_generation_params: Option<ExtraGenParams>,
    /// Index of the first image.
    pub index_of_first_image: Option<u32>,
    /// Info texts describing the generated images.
    pub infotexts: Option<Vec<String>>,
    /// Styles applied during generation.
    pub styles: Option<Vec<String>>,
    /// Timestamp at which the job started, as reported by the server.
    pub job_timestamp: Option<String>,
    /// Number of CLIP layers skipped.
    pub clip_skip: Option<u32>,
    /// Whether inpainting conditioning was used.
    pub is_using_inpainting_conditioning: Option<bool>,
}
/// Additional generation parameters reported under `extra_generation_params`.
///
/// The JSON keys contain spaces, so each field is renamed explicitly.
#[skip_serializing_none]
#[derive(Default, Serialize, Deserialize, Debug, Clone)]
pub struct ExtraGenParams {
    /// LoRA model names and hashes, as a single string (JSON key `"Lora hashes"`).
    #[serde(rename = "Lora hashes")]
    pub lora_hashes: Option<String>,
    /// Textual Inversion model names and hashes, as a single string (JSON key `"TI hashes"`).
    #[serde(rename = "TI hashes")]
    pub ti_hashes: Option<String>,
}