1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
use reqwest::{multipart, Url};
use serde::{Deserialize, Serialize};

/// Errors that can occur when interacting with `UploadApi`.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum UploadApiError {
    /// Error parsing endpoint URL
    #[error("Failed to parse endpoint URL")]
    ParseError(#[from] url::ParseError),
    /// Error setting MIME type
    #[error("Failed to set MIME type")]
    SetMimeStrFailed(#[source] reqwest::Error),
    /// Error sending request
    #[error("Failed to send request")]
    RequestFailed(#[from] reqwest::Error),
    /// An error occurred while parsing the response from the API.
    #[error("Parsing response failed")]
    InvalidResponse(#[source] reqwest::Error),
    /// An error occurred getting response data.
    #[error("Failed to get response data")]
    GetDataFailed(#[source] reqwest::Error),
    /// Server returned an error when uploading file
    #[error("Failed to upload image: {status}: {error}")]
    UploadImageFailed {
        status: reqwest::StatusCode,
        error: String,
    },
}

/// Shorthand for results whose error type is [`UploadApiError`].
type Result<Value> = core::result::Result<Value, UploadApiError>;

/// Information about an uploaded image.
///
/// Implements `Serialize`/`Deserialize` so it can be read from (and written back to) the
/// JSON object used by the upload API. `PartialEq`/`Eq` are derived so callers can compare
/// upload results directly (e.g. in tests).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct ImageUpload {
    /// The filename of the image.
    pub name: String,
    /// The subfolder of the image.
    pub subfolder: String,
    /// The folder type.
    ///
    /// Serialized under the JSON key `type`, which cannot be used directly as a Rust field name.
    #[serde(rename = "type")]
    pub folder_type: String,
}

/// Client for the ComfyUI API `upload` endpoint.
///
/// Holds the `reqwest::Client` used to issue requests together with the base endpoint URL;
/// cloning it is as cheap as cloning those two values.
#[derive(Debug, Clone)]
pub struct UploadApi {
    client: reqwest::Client,
    endpoint: Url,
}

impl UploadApi {
    /// Constructs a new `UploadApi` client with a given `reqwest::Client` and ComfyUI API
    /// endpoint.
    ///
    /// # Arguments
    ///
    /// * `client` - A `reqwest::Client` used to send requests.
    /// * `endpoint` - A `str` representation of the endpoint url.
    ///
    /// # Returns
    ///
    /// A `Result` containing a new `UploadApi` instance on success, or an error if url parsing failed.
    pub fn new<S>(client: reqwest::Client, endpoint: S) -> Result<Self>
    where
        S: AsRef<str>,
    {
        Ok(Self::new_with_url(client, Url::parse(endpoint.as_ref())?))
    }

    /// Constructs a new `UploadApi` client with a given `reqwest::Client` and endpoint `Url`.
    ///
    /// # Arguments
    ///
    /// * `client` - A `reqwest::Client` used to send requests.
    /// * `endpoint` - A `Url` representing the endpoint url.
    ///
    /// # Returns
    ///
    /// A new `UploadApi` instance.
    pub fn new_with_url(client: reqwest::Client, endpoint: Url) -> Self {
        Self { client, endpoint }
    }

    /// Uploads an image using the `UploadApi` client.
    ///
    /// # Arguments
    ///
    /// * `image` - A `Vec<u8>` containing the image to upload.
    ///
    /// # Returns
    ///
    /// A `Result` containing an `Image` struct containing information about the image.
    /// success, or an error if the request failed.
    pub async fn image(&self, image: Vec<u8>) -> Result<ImageUpload> {
        let file = multipart::Part::bytes(image)
            .file_name("image.png")
            .mime_str("image/png")
            .map_err(UploadApiError::SetMimeStrFailed)?;
        let form = multipart::Form::new().part("image", file);
        let response = self
            .client
            .post(self.endpoint.clone().join("image")?)
            .multipart(form)
            .send()
            .await
            .map_err(UploadApiError::RequestFailed)?;
        if response.status().is_success() {
            return response
                .json()
                .await
                .map_err(UploadApiError::InvalidResponse);
        }
        let status = response.status();
        let text = response
            .text()
            .await
            .map_err(UploadApiError::GetDataFailed)?;
        Err(UploadApiError::UploadImageFailed {
            status,
            error: text,
        })
    }
}