comfyui_api/comfy/
mod.rs

1use std::collections::HashSet;
2use std::pin::pin;
3
4use anyhow::{anyhow, Context};
5use async_stream::stream;
6use futures_util::{
7    stream::{FusedStream, FuturesOrdered},
8    Stream, StreamExt,
9};
10use uuid::Uuid;
11
12use crate::{
13    api::{self, *},
14    models::*,
15};
16
17pub mod visitor;
18pub use visitor::Visitor;
19
20pub mod setter;
21
22pub mod getter;
23use getter::*;
24
25mod accessors;
26
27use self::setter::SetterExt as _;
28
29enum State {
30    Executing(String, Vec<Image>),
31    Finished(Vec<(String, Vec<Image>)>),
32}
33
/// A single image produced by one node while a prompt was executing.
#[derive(Debug, Clone)]
pub struct NodeOutput {
    /// Identifier of the node that produced the image.
    pub node: String,
    /// Image bytes exactly as downloaded from the ComfyUI view endpoint.
    pub image: Vec<u8>,
}
42
43/// Errors that can occur opening API endpoints.
44#[derive(thiserror::Error, Debug)]
45#[non_exhaustive]
46pub enum ComfyApiError {
47    /// Error parsing endpoint URL
48    #[error("Failed to create API")]
49    CreateApiFailed(#[from] api::ApiError),
50    /// Execution was interrupted
51    #[error("Execution was interrupted: node {} ({})", response.node_id, response.node_type)]
52    ExecutionInterrupted { response: ExecutionInterrupted },
53    /// Error occurred during execution
54    #[error("Error occurred during execution: {exception_type}: {exception_message}")]
55    ExecutionError {
56        exception_type: String,
57        exception_message: String,
58    },
59    /// Connection error occurred during prompt execution
60    #[error("Failed to get prompt execution update")]
61    ReceiveUpdateFailure(#[from] api::WebSocketApiError),
62    /// Prompt task not found
63    #[error("Failed to get task for prompt")]
64    PromptTaskNotFound(#[source] api::HistoryApiError),
65    /// Error sending prompt to API
66    #[error("Failed to send prompt to API")]
67    SendPromptFailed(#[from] PromptApiError),
68    /// Error getting image from API
69    #[error("Failed to get image from API")]
70    GetImageFailed(#[from] ViewApiError),
71    /// Error uploading image to API
72    #[error("Failed to upload image to API")]
73    UploadImageFailed(#[from] UploadApiError),
74}
75
76type Result<T> = std::result::Result<T, ComfyApiError>;
77
78/// Higher-level API for interacting with the ComfyUI API.
79#[derive(Clone, Debug)]
80pub struct Comfy {
81    api: Api,
82    history: HistoryApi,
83    upload: UploadApi,
84    view: ViewApi,
85}
86
87impl Default for Comfy {
88    fn default() -> Self {
89        let api = Api::default();
90        Self {
91            history: api.history().expect("failed to create history api"),
92            upload: api.upload().expect("failed to create upload api"),
93            view: api.view().expect("failed to create view api"),
94            api,
95        }
96    }
97}
98
99impl Comfy {
100    /// Returns a new `Comfy` instance with default settings.
101    pub fn new() -> Result<Self> {
102        let api = Api::default();
103        Ok(Self {
104            history: api.history()?,
105            upload: api.upload()?,
106            view: api.view()?,
107            api,
108        })
109    }
110
111    /// Returns a new `Comfy` instance with the given URL as a string value.
112    ///
113    /// # Arguments
114    ///
115    /// * `url` - A string that specifies the ComfyUI API URL endpoint.
116    ///
117    /// # Errors
118    ///
119    /// If the URL fails to parse, an error will be returned.
120    pub fn new_with_url<S>(url: S) -> Result<Self>
121    where
122        S: AsRef<str>,
123    {
124        let api = Api::new_with_url(url.as_ref())?;
125        Ok(Self {
126            history: api.history()?,
127            upload: api.upload()?,
128            view: api.view()?,
129            api,
130        })
131    }
132
133    /// Returns a new `Comfy` instance with the given `reqwest::Client` and URL as a string value.
134    ///
135    /// # Arguments
136    ///
137    /// * `client` - An instance of `reqwest::Client`.
138    /// * `url` - A string that specifies the ComfyUI API URL endpoint.
139    ///
140    /// # Errors
141    ///
142    /// If the URL fails to parse, an error will be returned.
143    pub fn new_with_client_and_url<S>(client: reqwest::Client, url: S) -> Result<Self>
144    where
145        S: AsRef<str>,
146    {
147        let api = Api::new_with_client_and_url(client, url.as_ref())?;
148        Ok(Self {
149            history: api.history()?,
150            upload: api.upload()?,
151            view: api.view()?,
152            api,
153        })
154    }
155
156    async fn filter_update(&self, update: Update, target_prompt_id: Uuid) -> Result<Option<State>> {
157        match update {
158            Update::Executing(data) => {
159                if data.node.is_none() {
160                    if let Some(prompt_id) = data.prompt_id {
161                        if prompt_id != target_prompt_id {
162                            return Ok(None);
163                        }
164                        let task = self
165                            .history
166                            .get_prompt(&prompt_id)
167                            .await
168                            .map_err(ComfyApiError::PromptTaskNotFound)?;
169                        let images = task
170                            .outputs
171                            .nodes
172                            .into_iter()
173                            .filter_map(|(key, value)| {
174                                if let NodeOutputOrUnknown::NodeOutput(output) = value {
175                                    Some((key, output.images))
176                                } else {
177                                    None
178                                }
179                            })
180                            .collect::<Vec<(String, Vec<Image>)>>();
181                        return Ok(Some(State::Finished(images)));
182                    }
183                }
184                Ok(None)
185            }
186            Update::Executed(data) => {
187                if data.prompt_id != target_prompt_id {
188                    return Ok(None);
189                }
190                Ok(Some(State::Executing(data.node, data.output.images)))
191            }
192            Update::ExecutionInterrupted(data) => {
193                if data.prompt_id != target_prompt_id {
194                    return Ok(None);
195                }
196                Err(ComfyApiError::ExecutionInterrupted { response: data })
197            }
198            Update::ExecutionError(data) => {
199                if data.execution_status.prompt_id != target_prompt_id {
200                    return Ok(None);
201                }
202                Err(ComfyApiError::ExecutionError {
203                    exception_type: data.exception_type,
204                    exception_message: data.exception_message,
205                })
206            }
207            _ => Ok(None),
208        }
209    }
210
211    async fn prompt_impl<'a>(
212        &'a self,
213        prompt: &Prompt,
214    ) -> Result<impl Stream<Item = Result<State>> + 'a> {
215        let client_id = Uuid::new_v4();
216        let prompt_api = self.api.prompt_with_client(client_id)?;
217        let websocket_api = self.api.websocket_with_client(client_id)?;
218        let stream = websocket_api
219            .updates()
220            .await
221            .map_err(ComfyApiError::ReceiveUpdateFailure)?;
222        let response = prompt_api.send(prompt).await?;
223        let prompt_id = response.prompt_id;
224        Ok(stream.filter_map(move |msg| async move {
225            match msg {
226                Ok(msg) => match self.filter_update(msg, prompt_id).await {
227                    Ok(Some(images)) => Some(Ok(images)),
228                    Ok(None) => None,
229                    Err(e) => Some(Err(e)),
230                },
231                Err(e) => Some(Err(ComfyApiError::ReceiveUpdateFailure(e))),
232            }
233        }))
234    }
235
236    /// Executes a prompt and returns a stream of generated images.
237    ///
238    /// # Arguments
239    ///
240    /// * `prompt` - A `Prompt` to send to the ComfyUI API.
241    ///
242    /// # Returns
243    ///
244    /// A `Result` containing a `Stream` of `Result<NodeOutput>` values on success, or an error if the request failed.
245    pub async fn stream_prompt<'a>(
246        &'a self,
247        prompt: &Prompt,
248    ) -> Result<impl FusedStream<Item = Result<NodeOutput>> + 'a> {
249        let stream = self.prompt_impl(prompt).await?;
250        Ok(stream! {
251            let mut executed = HashSet::new();
252            for await msg in stream {
253                match msg {
254                    Ok(State::Executing(node, images)) => {
255                        executed.insert(node.clone());
256                        let fut = images.into_iter().map(|image| async move {
257                            self.view.get(&image).await
258                        }).collect::<FuturesOrdered<_>>();
259                        for await image in fut {
260                            yield Ok(NodeOutput { node: node.clone(), image: image? });
261                        }
262                    }
263                    Ok(State::Finished(images)) => {
264                        for (node, images) in images {
265                            if executed.contains(&node) {
266                                continue;
267                            }
268                            let fut = images.into_iter().map(|image| async move {
269                                self.view.get(&image).await
270                            }).collect::<FuturesOrdered<_>>();
271                            for await image in fut {
272                                yield Ok(NodeOutput { node: node.clone(), image: image? });
273                            }
274                        }
275                        return;
276                    }
277                    Err(e) => Err(e)?,
278                }
279            }
280        })
281    }
282
283    /// Executes a prompt and returns the generated images.
284    ///
285    /// # Arguments
286    ///
287    /// * `prompt` - A `Prompt` to send to the ComfyUI API.
288    ///
289    /// # Returns
290    ///
291    /// A `Result` containing a `Vec<NodeOutput>` on success, or an error if the request failed.
292    pub async fn execute_prompt(&self, prompt: &Prompt) -> Result<Vec<NodeOutput>> {
293        let mut images = vec![];
294        let mut stream = pin!(self.stream_prompt(prompt).await?);
295        while let Some(image) = stream.next().await {
296            match image {
297                Ok(image) => images.push(image),
298                Err(e) => return Err(e),
299            }
300        }
301        Ok(images)
302    }
303
304    /// Uploads a file to the ComfyUI API and returns information about the uploaded image.
305    ///
306    /// # Arguments
307    ///
308    /// * `file` - A `Vec<u8>` containing the file data to upload.
309    ///
310    /// # Returns
311    ///
312    /// A `Result` containing an `ImageUpload` on success, or an error if the request failed.
313    pub async fn upload_file(&self, file: Vec<u8>) -> Result<ImageUpload> {
314        Ok(self.upload.image(file).await?)
315    }
316}
317
/// Generation parameters recovered from a prompt for one generated image.
///
/// Every field is optional. [`ImageInfo::new_from_prompt`] starts from `Default` (all
/// `None`) and fills in whatever values its visitor finds.
#[derive(Debug, Clone, Default)]
pub struct ImageInfo {
    /// Positive prompt text used for the generation.
    pub prompt: Option<String>,
    /// Negative prompt text used for the generation.
    pub negative_prompt: Option<String>,
    /// Model used for the generation.
    pub model: Option<String>,
    /// Width of the image.
    pub width: Option<u32>,
    /// Height of the image.
    pub height: Option<u32>,
    /// Seed used for the generation.
    pub seed: Option<i64>,
}
334
335impl ImageInfo {
336    /// Returns a new `ImageInfo` instance based on the given `Prompt` and output node.
337    ///
338    /// # Arguments
339    ///
340    /// * `prompt` - A `Prompt` describing the workflow used to generate an image.
341    /// * `output_node` - The output node that produced the image.
342    ///
343    /// # Returns
344    ///
345    /// A `Result` containing a new `ImageInfo` instance on success, or an error if the output node was not found.
346    pub fn new_from_prompt(prompt: &Prompt, output_node: &str) -> anyhow::Result<ImageInfo> {
347        let mut image_info = ImageInfo::default();
348        if let Some(node) = prompt.get_node_by_id(output_node) {
349            image_info.visit(prompt, node);
350        } else {
351            return Err(anyhow!("Output node not found: {}", output_node));
352        }
353        Ok(image_info)
354    }
355}
356
/// A value to write into the prompt, optionally pinned to a specific node.
///
/// When `node` is `None`, the value is applied relative to the builder's output node.
/// `Default` is derived, which requires `T: Default` exactly as the former hand-written
/// impl did.
#[derive(Debug, Clone, Default)]
struct OverrideNode<T> {
    /// Id of the node to write to, if explicitly chosen.
    node: Option<String>,
    /// The value to write.
    value: T,
}
374
375/// A builder for creating a `Prompt` instance.
376#[derive(Debug, Clone)]
377pub struct PromptBuilder {
378    base_prompt: Prompt,
379    output_node: Option<String>,
380    prompt: Option<OverrideNode<String>>,
381    negative_prompt: Option<OverrideNode<String>>,
382    model: Option<OverrideNode<String>>,
383    width: Option<OverrideNode<u32>>,
384    height: Option<OverrideNode<u32>>,
385    seed: Option<OverrideNode<i64>>,
386}
387
388impl PromptBuilder {
389    /// Constructs a new `PromptBuilder` instance.
390    ///
391    /// # Arguments
392    ///
393    /// * `base_prompt` - The base `Prompt` to use as a starting point.
394    /// * `output_node` - The output node to use when building the prompt.
395    ///
396    /// # Returns
397    ///
398    /// A new `PromptBuilder` instance.
399    pub fn new(base_prompt: &Prompt, output_node: Option<String>) -> Self {
400        Self {
401            prompt: None,
402            negative_prompt: None,
403            model: None,
404            width: None,
405            height: None,
406            seed: None,
407            base_prompt: base_prompt.clone(),
408            output_node,
409        }
410    }
411
412    /// Sets the prompt.
413    ///
414    /// # Arguments
415    ///
416    /// * `value` - The prompt value to use.
417    /// * `node` - The node to set the prompt on.
418    pub fn prompt(mut self, value: String, node: Option<String>) -> Self {
419        self.prompt = Some(OverrideNode { node, value });
420        self
421    }
422
423    /// Sets the negative prompt.
424    ///
425    /// # Arguments
426    ///
427    /// * `value` - The negative prompt value to use.
428    /// * `node` - The node to set the negative prompt on.
429    pub fn negative_prompt(mut self, value: String, node: Option<String>) -> Self {
430        self.negative_prompt = Some(OverrideNode { node, value });
431        self
432    }
433
434    /// Sets the model.
435    ///
436    /// # Arguments
437    ///
438    /// * `value` - The model value to use.
439    /// * `node` - The node to set the model on.
440    pub fn model(mut self, value: String, node: Option<String>) -> Self {
441        self.model = Some(OverrideNode { node, value });
442        self
443    }
444
445    /// Sets the width.
446    ///
447    /// # Arguments
448    ///
449    /// * `value` - The width value to use.
450    /// * `node` - The node to set the width on.
451    pub fn width(mut self, value: u32, node: Option<String>) -> Self {
452        self.width = Some(OverrideNode { node, value });
453        self
454    }
455
456    /// Sets the height.
457    ///
458    /// # Arguments
459    ///
460    /// * `value` - The height value to use.
461    /// * `node` - The node to set the height on.
462    pub fn height(mut self, value: u32, node: Option<String>) -> Self {
463        self.height = Some(OverrideNode { node, value });
464        self
465    }
466
467    /// Sets the seed.
468    ///
469    /// # Arguments
470    ///
471    /// * `value` - The seed value to use.
472    /// * `node` - The node to set the seed on.
473    pub fn seed(mut self, value: i64, node: Option<String>) -> Self {
474        self.seed = Some(OverrideNode { node, value });
475        self
476    }
477
478    /// Builds a new `Prompt` instance based on the given parameters.
479    ///
480    /// # Returns
481    ///
482    /// A `Result` containing a new `Prompt` instance on success, or an error if a suitable output node could not be found.
483    pub fn build(mut self) -> anyhow::Result<Prompt> {
484        let mut new_prompt = self.base_prompt.clone();
485
486        if self.output_node.is_none() {
487            self.output_node = Some(
488                find_output_node(&new_prompt).context("failed to find a suitable output node")?,
489            );
490        }
491
492        if let Some(ref prompt) = self.prompt {
493            if let Some(ref node) = prompt.node {
494                new_prompt.set_node::<accessors::Prompt>(node, prompt.value.clone())?;
495            } else {
496                new_prompt.set_from::<accessors::Prompt>(
497                    &self.output_node.clone().unwrap(),
498                    prompt.value.clone(),
499                )?;
500            }
501        }
502        if let Some(ref negative_prompt) = self.negative_prompt {
503            if let Some(ref node) = negative_prompt.node {
504                new_prompt
505                    .set_node::<accessors::NegativePrompt>(node, negative_prompt.value.clone())?;
506            } else {
507                new_prompt.set_from::<accessors::NegativePrompt>(
508                    &self.output_node.clone().unwrap(),
509                    negative_prompt.value.clone(),
510                )?;
511            }
512        }
513        if let Some(ref model) = self.model {
514            if let Some(ref node) = model.node {
515                new_prompt.set_node::<accessors::Model>(node, model.value.clone())?;
516            } else {
517                new_prompt.set_from::<accessors::Model>(
518                    &self.output_node.clone().unwrap(),
519                    model.value.clone(),
520                )?;
521            }
522        }
523        if let Some(width) = self.width {
524            if let Some(ref node) = width.node {
525                new_prompt.set_node::<accessors::Width>(node, width.value)?;
526            } else {
527                new_prompt.set_from::<accessors::Width>(
528                    &self.output_node.clone().unwrap(),
529                    width.value,
530                )?;
531            }
532        }
533        if let Some(height) = self.height {
534            if let Some(ref node) = height.node {
535                new_prompt.set_node::<accessors::Height>(node, height.value)?;
536            } else {
537                new_prompt.set_from::<accessors::Height>(
538                    &self.output_node.clone().unwrap(),
539                    height.value,
540                )?;
541            }
542        }
543        if let Some(ref seed) = self.seed {
544            if let Some(ref node) = seed.node {
545                new_prompt.set_node::<accessors::Seed>(node, seed.value)?;
546            } else {
547                new_prompt
548                    .set_from::<accessors::Seed>(&self.output_node.clone().unwrap(), seed.value)?;
549            }
550        }
551        Ok(new_prompt)
552    }
553}