sal_e_api/
api.rs

1use anyhow::Context;
2use async_trait::async_trait;
3use comfyui_api::{comfy::getter::*, models::AsAny};
4use dyn_clone::DynClone;
5use stable_diffusion_api::{Img2ImgRequest, Txt2ImgRequest};
6
7use crate::{ComfyParams, Img2ImgParams, Txt2ImgParams};
8
9/// Struct representing a response from a Stable Diffusion API image generation endpoint.
10#[derive(Debug, Clone)]
11pub struct Response {
12    /// A vector of images.
13    pub images: Vec<Vec<u8>>,
14    /// The parameters describing the generated image.
15    pub params: Box<dyn crate::image_params::ImageParams>,
16    /// The parameters that were provided for the generation request.
17    pub gen_params: Box<dyn crate::gen_params::GenParams>,
18}
19
20#[derive(thiserror::Error, Debug)]
21#[non_exhaustive]
22pub enum ComfyPromptApiError {
23    /// Error creating a ComfyUI Client
24    #[error("Error creating a ComfyUI Client")]
25    CreateClient(#[from] comfyui_api::comfy::ComfyApiError),
26}
27
28/// Struct wrapping a connection to the ComfyUI API.
29#[derive(Debug, Clone, Default)]
30pub struct ComfyPromptApi {
31    /// The ComfyUI client.
32    pub client: comfyui_api::comfy::Comfy,
33    /// Default parameters for the ComfyUI API.
34    pub params: crate::gen_params::ComfyParams,
35    /// The output node.
36    pub output_node: Option<String>,
37    /// The prompt node.
38    pub prompt_node: Option<String>,
39}
40
41impl ComfyPromptApi {
42    /// Constructs a new `ComfyPromptApi` client with the provided prompt.
43    ///
44    /// # Arguments
45    ///
46    /// * `prompt` - The prompt to use for the API.
47    ///
48    /// # Returns
49    ///
50    /// A new `ComfyPromptApi` instance on success, or an error if there was a failure in the ComfyUI API client.
51    pub fn new(prompt: comfyui_api::models::Prompt) -> Result<Self, ComfyPromptApiError> {
52        Ok(Self {
53            client: comfyui_api::comfy::Comfy::new()?,
54            params: crate::gen_params::ComfyParams {
55                prompt: Some(prompt),
56                count: 1,
57                seed: Some(-1),
58                ..Default::default()
59            },
60            ..Default::default()
61        })
62    }
63
64    /// Constructs a new `ComfyPromptApi` client with the provided url and prompt.
65    ///
66    /// # Arguments
67    ///
68    /// * `url` - The URL to use for the API. Must be a valid URL, e.g. `http://localhost:8188
69    /// * `prompt` - The prompt to use for the API.
70    ///
71    /// # Returns
72    ///
73    /// A new `ComfyPromptApi` instance on success, or an error if there was a failure in the ComfyUI API client.
74    pub fn new_with_url<S>(
75        url: S,
76        prompt: comfyui_api::models::Prompt,
77    ) -> Result<Self, ComfyPromptApiError>
78    where
79        S: AsRef<str>,
80    {
81        Ok(Self {
82            client: comfyui_api::comfy::Comfy::new_with_url(url)?,
83            params: crate::gen_params::ComfyParams {
84                prompt: Some(prompt),
85                count: 1,
86                seed: Some(-1),
87                ..Default::default()
88            },
89            ..Default::default()
90        })
91    }
92
93    /// Constructs a new `ComfyPromptApi` client with the provided url and prompt.
94    ///
95    /// # Arguments
96    ///
97    /// * `client` - An instance of `reqwest::Client`.
98    /// * `url` - The URL to use for the API. Must be a valid URL, e.g. `http://localhost:8188
99    /// * `prompt` - The prompt to use for the API.
100    ///
101    /// # Returns
102    ///
103    /// A new `ComfyPromptApi` instance on success, or an error if there was a failure in the ComfyUI API client.
104    pub fn new_with_client_and_url<S>(
105        client: reqwest::Client,
106        url: S,
107        prompt: comfyui_api::models::Prompt,
108    ) -> anyhow::Result<Self>
109    where
110        S: AsRef<str>,
111    {
112        Ok(Self {
113            client: comfyui_api::comfy::Comfy::new_with_client_and_url(client, url)?,
114            params: crate::gen_params::ComfyParams {
115                prompt: Some(prompt),
116                count: 1,
117                seed: Some(-1),
118                ..Default::default()
119            },
120            ..Default::default()
121        })
122    }
123}
124
125#[derive(thiserror::Error, Debug)]
126#[non_exhaustive]
127pub enum Txt2ImgApiError {
128    /// Prompt was empty.
129    #[error("Prompt was empty.")]
130    EmptyPrompt,
131    /// Error running txt2img.
132    #[error("Error running txt2img.")]
133    Txt2Img(#[from] anyhow::Error),
134    /// Error parsing response.
135    #[error("Error parsing response.")]
136    ParseResponse(#[source] anyhow::Error),
137}
138
139dyn_clone::clone_trait_object!(Txt2ImgApi);
140
141/// Trait representing a Txt2Img endpoint.
142#[async_trait]
143pub trait Txt2ImgApi: std::fmt::Debug + DynClone + Send + Sync + AsAny {
144    /// Generates an image using text-to-image.
145    ///
146    /// # Arguments
147    ///
148    /// * `config` - The configuration to use for the generation.
149    ///
150    /// # Returns
151    ///
152    /// A `Result` containing a `Response` on success, or an error if the request failed.
153    async fn txt2img(
154        &self,
155        config: &dyn crate::gen_params::GenParams,
156    ) -> Result<Response, Txt2ImgApiError>;
157
158    /// Returns the default generation parameters for this endpoint.
159    ///
160    /// # Arguments
161    ///
162    /// * `user_settings` - The user settings to merge with the defaults.
163    ///
164    /// # Returns
165    ///
166    /// A `Box<dyn crate::gen_params::GenParams>` containing the generation parameters.
167    fn gen_params(
168        &self,
169        user_settings: Option<&dyn crate::gen_params::GenParams>,
170    ) -> Box<dyn crate::gen_params::GenParams>;
171}
172
173#[derive(thiserror::Error, Debug)]
174#[non_exhaustive]
175pub enum Img2ImgApiError {
176    /// Prompt was empty.
177    #[error("Prompt was empty.")]
178    EmptyPrompt,
179    /// Error running txt2img.
180    #[error("Error running img2img.")]
181    Img2Img(#[from] anyhow::Error),
182    /// Error parsing response.
183    #[error("Error parsing response.")]
184    ParseResponse(#[source] anyhow::Error),
185    /// No image provided.
186    #[error("No image provided.")]
187    NoImage,
188    /// Error uploading image.
189    #[error("Error uploading image.")]
190    UploadImage(#[source] anyhow::Error),
191}
192
193dyn_clone::clone_trait_object!(Img2ImgApi);
194
195/// Trait representing an Img2Img endpoint.
196#[async_trait]
197pub trait Img2ImgApi: std::fmt::Debug + DynClone + Send + Sync + AsAny {
198    /// Generates an image using image-to-image.
199    ///
200    /// # Arguments
201    ///
202    /// * `config` - The configuration to use for the generation.
203    ///
204    /// # Returns
205    ///
206    /// A `Result` containing a `Response` on success, or an error if the request failed.
207    async fn img2img(
208        &self,
209        config: &dyn crate::gen_params::GenParams,
210    ) -> Result<Response, Img2ImgApiError>;
211
212    /// Returns the default generation parameters for this endpoint.
213    ///
214    /// # Arguments
215    ///
216    /// * `user_settings` - The user settings to merge with the defaults.
217    ///
218    /// # Returns
219    ///
220    /// A `Box<dyn crate::gen_params::GenParams>` containing the generation parameters.
221    fn gen_params(
222        &self,
223        user_settings: Option<&dyn crate::gen_params::GenParams>,
224    ) -> Box<dyn crate::gen_params::GenParams>;
225}
226
227#[async_trait]
228impl Txt2ImgApi for ComfyPromptApi {
229    async fn txt2img(
230        &self,
231        config: &dyn crate::gen_params::GenParams,
232    ) -> Result<Response, Txt2ImgApiError> {
233        let base_prompt = config.as_any().downcast_ref().unwrap_or(&self.params);
234
235        let mut new_prompt = base_prompt.clone();
236        if let Some(-1) = new_prompt.seed {
237            new_prompt.seed = Some(rand::random::<i64>().abs());
238        }
239
240        let prompt = new_prompt.apply().context(Txt2ImgApiError::EmptyPrompt)?;
241
242        let images = self
243            .client
244            .execute_prompt(&prompt)
245            .await
246            .context("Failed to execute prompt")?;
247        Ok(Response {
248            images: images.into_iter().map(|image| image.image).collect(),
249            params: Box::new(prompt),
250            gen_params: Box::new(base_prompt.clone()),
251        })
252    }
253
254    fn gen_params(
255        &self,
256        user_settings: Option<&dyn crate::gen_params::GenParams>,
257    ) -> Box<dyn crate::gen_params::GenParams> {
258        if let Some(user_settings) = user_settings {
259            let mut params = ComfyParams::from(user_settings);
260            params.prompt = self.params.prompt.clone();
261            Box::new(params)
262        } else {
263            Box::new(self.params.clone())
264        }
265    }
266}
267
268#[async_trait]
269impl Img2ImgApi for ComfyPromptApi {
270    async fn img2img(
271        &self,
272        config: &dyn crate::gen_params::GenParams,
273    ) -> Result<Response, Img2ImgApiError> {
274        let base_prompt = config.as_any().downcast_ref().unwrap_or(&self.params);
275
276        let resp = if let Some(image) = &base_prompt.image {
277            self.client
278                .upload_file(image.clone())
279                .await
280                .context("Failed to upload image")
281                .map_err(Img2ImgApiError::UploadImage)?
282        } else {
283            return Err(Img2ImgApiError::NoImage);
284        };
285
286        let mut new_prompt = base_prompt.clone();
287        if let Some(-1) = new_prompt.seed {
288            new_prompt.seed = Some(rand::random::<i64>().abs());
289        }
290
291        let mut prompt = new_prompt.apply().context(Img2ImgApiError::EmptyPrompt)?;
292
293        *prompt.image_mut()? = resp.name;
294
295        let images = self
296            .client
297            .execute_prompt(&prompt)
298            .await
299            .context("Failed to execute prompt")?;
300        Ok(Response {
301            images: images.into_iter().map(|image| image.image).collect(),
302            params: Box::new(prompt.clone()),
303            gen_params: Box::new(base_prompt.clone()),
304        })
305    }
306
307    fn gen_params(
308        &self,
309        user_settings: Option<&dyn crate::gen_params::GenParams>,
310    ) -> Box<dyn crate::gen_params::GenParams> {
311        if let Some(user_settings) = user_settings {
312            let mut params = ComfyParams::from(user_settings);
313            params.prompt = self.params.prompt.clone();
314            Box::new(params)
315        } else {
316            Box::new(self.params.clone())
317        }
318    }
319}
320
321/// Struct wrapping a connection to the Stable Diffusion WebUI API.
322#[derive(Debug, Clone, Default)]
323pub struct StableDiffusionWebUiApi {
324    /// The Stable Diffusion WebUI client.
325    pub client: stable_diffusion_api::Api,
326    /// Default parameters for the Txt2Img endpoint.
327    pub txt2img_defaults: Txt2ImgRequest,
328    /// Default parameters for the Img2Img endpoint.
329    pub img2img_defaults: Img2ImgRequest,
330}
331
332impl StableDiffusionWebUiApi {
333    /// Constructs a new `StableDiffusionWebUiApi` client with the default parameters.
334    pub fn new() -> Self {
335        Self::default()
336    }
337}
338
339#[async_trait]
340impl Txt2ImgApi for StableDiffusionWebUiApi {
341    async fn txt2img(
342        &self,
343        config: &dyn crate::gen_params::GenParams,
344    ) -> Result<Response, Txt2ImgApiError> {
345        let config = Txt2ImgParams::from(config);
346        let txt2img = self
347            .client
348            .txt2img()
349            .context("Failed to open txt2img API")?;
350        let resp = txt2img
351            .send(&config.user_params)
352            .await
353            .context("Failed to send request")?;
354        let params = Box::new(
355            resp.info()
356                .context("Failed to parse info from response")
357                .map_err(Txt2ImgApiError::ParseResponse)?,
358        );
359        Ok(Response {
360            images: resp
361                .images()
362                .context("Failed to parse image from response")
363                .map_err(Txt2ImgApiError::ParseResponse)?,
364            params: params.clone(),
365            gen_params: Box::new(Txt2ImgParams {
366                user_params: resp.parameters.clone(),
367                defaults: Some(self.txt2img_defaults.clone()),
368            }),
369        })
370    }
371
372    fn gen_params(
373        &self,
374        user_settings: Option<&dyn crate::gen_params::GenParams>,
375    ) -> Box<dyn crate::gen_params::GenParams> {
376        if let Some(user_settings) = user_settings {
377            Box::new(Txt2ImgParams {
378                user_params: Txt2ImgParams::from(user_settings).user_params,
379                defaults: Some(self.txt2img_defaults.clone()),
380            })
381        } else {
382            Box::new(Txt2ImgParams {
383                user_params: Txt2ImgRequest::default(),
384                defaults: Some(self.txt2img_defaults.clone()),
385            })
386        }
387    }
388}
389
390#[async_trait]
391impl Img2ImgApi for StableDiffusionWebUiApi {
392    async fn img2img(
393        &self,
394        config: &dyn crate::gen_params::GenParams,
395    ) -> Result<Response, Img2ImgApiError> {
396        let config = Img2ImgParams::from(config);
397        let img2img = self
398            .client
399            .img2img()
400            .context("Failed to open img2img API")?;
401        let resp = img2img
402            .send(&config.user_params)
403            .await
404            .context("Failed to send request")?;
405        let params = Box::new(
406            resp.info()
407                .context("Failed to parse info from response")
408                .map_err(Img2ImgApiError::ParseResponse)?,
409        );
410        Ok(Response {
411            images: resp
412                .images()
413                .context("Failed to parse image from response")
414                .map_err(Img2ImgApiError::ParseResponse)?,
415            params: params.clone(),
416            gen_params: Box::new(Img2ImgParams {
417                user_params: resp.parameters.clone(),
418                defaults: Some(self.img2img_defaults.clone()),
419            }),
420        })
421    }
422
423    fn gen_params(
424        &self,
425        user_settings: Option<&dyn crate::gen_params::GenParams>,
426    ) -> Box<dyn crate::gen_params::GenParams> {
427        if let Some(user_settings) = user_settings {
428            Box::new(Txt2ImgParams {
429                user_params: Txt2ImgParams::from(user_settings).user_params,
430                defaults: Some(self.txt2img_defaults.clone()),
431            })
432        } else {
433            Box::new(Txt2ImgParams {
434                user_params: Txt2ImgRequest::default(),
435                defaults: Some(self.txt2img_defaults.clone()),
436            })
437        }
438    }
439}