//! Client for the ComfyUI API `upload` endpoint.
use reqwest::{multipart, Url};
use serde::{Deserialize, Serialize};
/// Errors that can occur when interacting with `UploadApi`.
///
/// The enum is `#[non_exhaustive]`, so code matching on it from another crate
/// must include a wildcard arm.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum UploadApiError {
/// The endpoint URL, or a path joined onto it, could not be parsed.
///
/// Converted automatically from `url::ParseError` by the `?` operator.
#[error("Failed to parse endpoint URL")]
ParseError(#[from] url::ParseError),
/// Setting the MIME type on the multipart file part failed.
#[error("Failed to set MIME type")]
SetMimeStrFailed(#[source] reqwest::Error),
/// Sending the HTTP request failed.
///
/// This is the only `reqwest::Error` variant with `#[from]`, so a bare `?`
/// on a `reqwest::Error` produces this variant.
#[error("Failed to send request")]
RequestFailed(#[from] reqwest::Error),
/// An error occurred while parsing the response from the API.
#[error("Parsing response failed")]
InvalidResponse(#[source] reqwest::Error),
/// An error occurred getting response data.
///
/// Raised when the body of an unsuccessful response could not be read as text.
#[error("Failed to get response data")]
GetDataFailed(#[source] reqwest::Error),
/// The server answered the upload with a non-success status code.
#[error("Failed to upload image: {status}: {error}")]
UploadImageFailed {
/// HTTP status code returned by the server.
status: reqwest::StatusCode,
/// Response body text returned alongside the status.
error: String,
},
}
/// Module-local shorthand for a `std::result::Result` whose error type is `UploadApiError`.
type Result<T> = std::result::Result<T, UploadApiError>;
/// Struct representing an image.
///
/// This is the type that `UploadApi::image` decodes a successful JSON
/// response into.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ImageUpload {
/// The filename of the image.
pub name: String,
/// The subfolder.
pub subfolder: String,
/// The folder type.
///
/// Serialized and deserialized under the key `type`; the field has a
/// different name because `type` is a reserved word in Rust.
#[serde(rename = "type")]
pub folder_type: String,
}
/// Struct representing a connection to the ComfyUI API `upload` endpoint.
///
/// Build one with `UploadApi::new` (from a URL string) or
/// `UploadApi::new_with_url` (from an already parsed `Url`).
#[derive(Clone, Debug)]
pub struct UploadApi {
// HTTP client used to send every request.
client: reqwest::Client,
// Base URL that request paths (e.g. `image`) are joined onto.
endpoint: Url,
}
impl UploadApi {
/// Constructs a new `UploadApi` client with a given `reqwest::Client` and ComfyUI API
/// endpoint.
///
/// # Arguments
///
/// * `client` - A `reqwest::Client` used to send requests.
/// * `endpoint` - A `str` representation of the endpoint url.
///
/// # Returns
///
/// A `Result` containing a new `UploadApi` instance on success, or an error if url parsing failed.
pub fn new<S>(client: reqwest::Client, endpoint: S) -> Result<Self>
where
S: AsRef<str>,
{
Ok(Self::new_with_url(client, Url::parse(endpoint.as_ref())?))
}
/// Constructs a new `UploadApi` client with a given `reqwest::Client` and endpoint `Url`.
///
/// # Arguments
///
/// * `client` - A `reqwest::Client` used to send requests.
/// * `endpoint` - A `Url` representing the endpoint url.
///
/// # Returns
///
/// A new `UploadApi` instance.
pub fn new_with_url(client: reqwest::Client, endpoint: Url) -> Self {
Self { client, endpoint }
}
/// Uploads an image using the `UploadApi` client.
///
/// # Arguments
///
/// * `image` - A `Vec<u8>` containing the image to upload.
///
/// # Returns
///
/// A `Result` containing an `Image` struct containing information about the image.
/// success, or an error if the request failed.
pub async fn image(&self, image: Vec<u8>) -> Result<ImageUpload> {
let file = multipart::Part::bytes(image)
.file_name("image.png")
.mime_str("image/png")
.map_err(UploadApiError::SetMimeStrFailed)?;
let form = multipart::Form::new().part("image", file);
let response = self
.client
.post(self.endpoint.clone().join("image")?)
.multipart(form)
.send()
.await
.map_err(UploadApiError::RequestFailed)?;
if response.status().is_success() {
return response
.json()
.await
.map_err(UploadApiError::InvalidResponse);
}
let status = response.status();
let text = response
.text()
.await
.map_err(UploadApiError::GetDataFailed)?;
Err(UploadApiError::UploadImageFailed {
status,
error: text,
})
}
}