1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
use reqwest::Url;
use serde::Serialize;
use serde_with::skip_serializing_none;
use crate::models::{Prompt, Response};
/// Errors that can occur when interacting with `PromptApi`.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum PromptApiError {
/// Error parsing endpoint URL
#[error("Failed to parse endpoint URL")]
ParseError(#[from] url::ParseError),
/// Error sending request
#[error("Failed to send request")]
RequestFailed(#[from] reqwest::Error),
/// An error occurred while parsing the response from the API.
#[error("Parsing response failed")]
InvalidResponse(#[source] reqwest::Error),
/// An error occurred getting response data.
#[error("Failed to get response data")]
GetDataFailed(#[source] reqwest::Error),
/// Server returned an error when sending prompt
#[error("Failed to send prompt: {status}: {error}")]
SendPromptFailed {
status: reqwest::StatusCode,
error: String,
},
}
/// Shorthand for results whose error type is [`PromptApiError`].
type Result<T> = core::result::Result<T, PromptApiError>;
/// JSON body posted to the `prompt` endpoint.
///
/// `#[skip_serializing_none]` has to come *before* `#[derive(Serialize)]`: it works by
/// injecting `#[serde(skip_serializing_if = "Option::is_none")]` onto `Option` fields, and the
/// derive only sees those injected attributes if it is expanded afterwards.
#[skip_serializing_none]
#[derive(Serialize, Debug)]
struct PromptWrapper<'a> {
    /// The prompt being submitted, borrowed from the caller.
    prompt: &'a Prompt,
    /// Client identifier; left out of the JSON entirely when `None`.
    client_id: Option<uuid::Uuid>,
}
/// Client handle bound to a single ComfyUI API `prompt` endpoint.
///
/// Holds the HTTP client, the endpoint URL and the client id that outgoing prompts are
/// tagged with.
#[derive(Clone, Debug)]
pub struct PromptApi {
    /// HTTP client that issues the requests.
    client: reqwest::Client,
    /// URL that prompts are POSTed to.
    endpoint: Url,
    /// Identifier sent along with every prompt.
    client_id: uuid::Uuid,
}
impl PromptApi {
/// Constructs a new `PromptApi` client with a given `reqwest::Client` and ComfyUI API
/// endpoint.
///
/// # Arguments
///
/// * `client` - A `reqwest::Client` used to send requests.
/// * `endpoint` - A `str` representation of the endpoint url.
///
/// # Returns
///
/// A `Result` containing a new `PromptApi` instance on success, or an error if url parsing failed.
pub fn new<S>(client: reqwest::Client, endpoint: S, client_id: uuid::Uuid) -> Result<Self>
where
S: AsRef<str>,
{
Ok(Self::new_with_url(
client,
Url::parse(endpoint.as_ref())?,
client_id,
))
}
/// Constructs a new `PromptApi` client with a given `reqwest::Client` and endpoint `Url`.
///
/// # Arguments
///
/// * `client` - A `reqwest::Client` used to send requests.
/// * `endpoint` - A `Url` representing the endpoint url.
///
/// # Returns
///
/// A new `PromptApi` instance.
pub fn new_with_url(client: reqwest::Client, endpoint: Url, client_id: uuid::Uuid) -> Self {
Self {
client,
endpoint,
client_id,
}
}
/// Sends a prompt request using the `PromptApi` client.
///
/// # Arguments
///
/// * `prompt` - A `Prompt` to send to the ComfyUI API.
///
/// # Returns
///
/// A `Result` containing a `Response` on success, or an error if the request failed.
pub async fn send(&self, prompt: &Prompt) -> Result<Response> {
self.send_as_client(prompt, self.client_id).await
}
async fn send_as_client(&self, prompt: &Prompt, client_id: uuid::Uuid) -> Result<Response> {
let response = self
.client
.post(self.endpoint.clone())
.json(&PromptWrapper {
prompt,
client_id: Some(client_id),
})
.send()
.await?;
if response.status().is_success() {
return response
.json()
.await
.map_err(PromptApiError::InvalidResponse);
}
let status = response.status();
let text = response
.text()
.await
.map_err(PromptApiError::GetDataFailed)?;
Err(PromptApiError::SendPromptFailed {
status,
error: text,
})
}
/// Returns the client id used for requests.
pub fn client_id(&self) -> uuid::Uuid {
self.client_id
}
}