comfyui_api/api/prompt.rs

1use reqwest::Url;
2use serde::Serialize;
3use serde_with::skip_serializing_none;
4
5use crate::models::{Prompt, Response};
6
7/// Errors that can occur when interacting with `PromptApi`.
8#[derive(thiserror::Error, Debug)]
9#[non_exhaustive]
10pub enum PromptApiError {
11    /// Error parsing endpoint URL
12    #[error("Failed to parse endpoint URL")]
13    ParseError(#[from] url::ParseError),
14    /// Error sending request
15    #[error("Failed to send request")]
16    RequestFailed(#[from] reqwest::Error),
17    /// An error occurred while parsing the response from the API.
18    #[error("Parsing response failed")]
19    InvalidResponse(#[source] reqwest::Error),
20    /// An error occurred getting response data.
21    #[error("Failed to get response data")]
22    GetDataFailed(#[source] reqwest::Error),
23    /// Server returned an error when sending prompt
24    #[error("Failed to send prompt: {status}: {error}")]
25    SendPromptFailed {
26        status: reqwest::StatusCode,
27        error: String,
28    },
29}
30
31type Result<T> = std::result::Result<T, PromptApiError>;
32
33#[derive(Serialize, Debug)]
34#[skip_serializing_none]
35struct PromptWrapper<'a> {
36    prompt: &'a Prompt,
37    client_id: Option<uuid::Uuid>,
38}
39
40/// Struct representing a connection to the ComfyUI API `prompt` endpoint.
41#[derive(Clone, Debug)]
42pub struct PromptApi {
43    client: reqwest::Client,
44    endpoint: Url,
45    client_id: uuid::Uuid,
46}
47
48impl PromptApi {
49    /// Constructs a new `PromptApi` client with a given `reqwest::Client` and ComfyUI API
50    /// endpoint.
51    ///
52    /// # Arguments
53    ///
54    /// * `client` - A `reqwest::Client` used to send requests.
55    /// * `endpoint` - A `str` representation of the endpoint url.
56    ///
57    /// # Returns
58    ///
59    /// A `Result` containing a new `PromptApi` instance on success, or an error if url parsing failed.
60    pub fn new<S>(client: reqwest::Client, endpoint: S, client_id: uuid::Uuid) -> Result<Self>
61    where
62        S: AsRef<str>,
63    {
64        Ok(Self::new_with_url(
65            client,
66            Url::parse(endpoint.as_ref())?,
67            client_id,
68        ))
69    }
70
71    /// Constructs a new `PromptApi` client with a given `reqwest::Client` and endpoint `Url`.
72    ///
73    /// # Arguments
74    ///
75    /// * `client` - A `reqwest::Client` used to send requests.
76    /// * `endpoint` - A `Url` representing the endpoint url.
77    ///
78    /// # Returns
79    ///
80    /// A new `PromptApi` instance.
81    pub fn new_with_url(client: reqwest::Client, endpoint: Url, client_id: uuid::Uuid) -> Self {
82        Self {
83            client,
84            endpoint,
85            client_id,
86        }
87    }
88
89    /// Sends a prompt request using the `PromptApi` client.
90    ///
91    /// # Arguments
92    ///
93    /// * `prompt` - A `Prompt` to send to the ComfyUI API.
94    ///
95    /// # Returns
96    ///
97    /// A `Result` containing a `Response` on success, or an error if the request failed.
98    pub async fn send(&self, prompt: &Prompt) -> Result<Response> {
99        self.send_as_client(prompt, self.client_id).await
100    }
101
102    async fn send_as_client(&self, prompt: &Prompt, client_id: uuid::Uuid) -> Result<Response> {
103        let response = self
104            .client
105            .post(self.endpoint.clone())
106            .json(&PromptWrapper {
107                prompt,
108                client_id: Some(client_id),
109            })
110            .send()
111            .await?;
112        if response.status().is_success() {
113            return response
114                .json()
115                .await
116                .map_err(PromptApiError::InvalidResponse);
117        }
118        let status = response.status();
119        let text = response
120            .text()
121            .await
122            .map_err(PromptApiError::GetDataFailed)?;
123        Err(PromptApiError::SendPromptFailed {
124            status,
125            error: text,
126        })
127    }
128
129    /// Returns the client id used for requests.
130    pub fn client_id(&self) -> uuid::Uuid {
131        self.client_id
132    }
133}