1use std::collections::HashSet;
2use std::pin::pin;
3
4use anyhow::{anyhow, Context};
5use async_stream::stream;
6use futures_util::{
7 stream::{FusedStream, FuturesOrdered},
8 Stream, StreamExt,
9};
10use uuid::Uuid;
11
12use crate::{
13 api::{self, *},
14 models::*,
15};
16
17pub mod visitor;
18pub use visitor::Visitor;
19
20pub mod setter;
21
22pub mod getter;
23use getter::*;
24
25mod accessors;
26
27use self::setter::SetterExt as _;
28
/// Internal progress events derived from websocket updates for a single prompt.
enum State {
    /// A node reported output: the node id and the images it listed.
    Executing(String, Vec<Image>),
    /// The prompt finished: per-node images read from the prompt's history entry.
    Finished(Vec<(String, Vec<Image>)>),
}
33
/// One image produced by a node while a prompt executes.
#[derive(Debug, Clone)]
pub struct NodeOutput {
    /// Id of the node that produced the image.
    pub node: String,
    /// Image bytes as returned by the view API's `get`.
    pub image: Vec<u8>,
}
42
/// Errors returned by the [`Comfy`] client.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum ComfyApiError {
    /// Creating the underlying API or one of its sub-APIs failed.
    #[error("Failed to create API")]
    CreateApiFailed(#[from] api::ApiError),
    /// The server reported that execution of the prompt was interrupted.
    #[error("Execution was interrupted: node {} ({})", response.node_id, response.node_type)]
    ExecutionInterrupted { response: ExecutionInterrupted },
    /// The server reported an exception while executing the prompt.
    #[error("Error occurred during execution: {exception_type}: {exception_message}")]
    ExecutionError {
        /// Exception type name reported by the server.
        exception_type: String,
        /// Exception message reported by the server.
        exception_message: String,
    },
    /// Reading an update from the websocket failed.
    #[error("Failed to get prompt execution update")]
    ReceiveUpdateFailure(#[from] api::WebSocketApiError),
    /// Looking up the finished prompt in the history API failed.
    #[error("Failed to get task for prompt")]
    PromptTaskNotFound(#[source] api::HistoryApiError),
    /// Submitting the prompt failed.
    #[error("Failed to send prompt to API")]
    SendPromptFailed(#[from] PromptApiError),
    /// Downloading an output image failed.
    #[error("Failed to get image from API")]
    GetImageFailed(#[from] ViewApiError),
    /// Uploading an image failed.
    #[error("Failed to upload image to API")]
    UploadImageFailed(#[from] UploadApiError),
}
75
/// Result type used throughout this module, with [`ComfyApiError`] as the error.
type Result<T> = std::result::Result<T, ComfyApiError>;
77
/// High-level client that submits prompts and retrieves their output images.
#[derive(Clone, Debug)]
pub struct Comfy {
    /// Root API; per-prompt prompt/websocket clients are created from it.
    api: Api,
    /// History API, used to read a prompt's outputs once it finishes.
    history: HistoryApi,
    /// Upload API, used by `upload_file`.
    upload: UploadApi,
    /// View API, used to download output images.
    view: ViewApi,
}
86
87impl Default for Comfy {
88 fn default() -> Self {
89 let api = Api::default();
90 Self {
91 history: api.history().expect("failed to create history api"),
92 upload: api.upload().expect("failed to create upload api"),
93 view: api.view().expect("failed to create view api"),
94 api,
95 }
96 }
97}
98
99impl Comfy {
100 pub fn new() -> Result<Self> {
102 let api = Api::default();
103 Ok(Self {
104 history: api.history()?,
105 upload: api.upload()?,
106 view: api.view()?,
107 api,
108 })
109 }
110
111 pub fn new_with_url<S>(url: S) -> Result<Self>
121 where
122 S: AsRef<str>,
123 {
124 let api = Api::new_with_url(url.as_ref())?;
125 Ok(Self {
126 history: api.history()?,
127 upload: api.upload()?,
128 view: api.view()?,
129 api,
130 })
131 }
132
133 pub fn new_with_client_and_url<S>(client: reqwest::Client, url: S) -> Result<Self>
144 where
145 S: AsRef<str>,
146 {
147 let api = Api::new_with_client_and_url(client, url.as_ref())?;
148 Ok(Self {
149 history: api.history()?,
150 upload: api.upload()?,
151 view: api.view()?,
152 api,
153 })
154 }
155
156 async fn filter_update(&self, update: Update, target_prompt_id: Uuid) -> Result<Option<State>> {
157 match update {
158 Update::Executing(data) => {
159 if data.node.is_none() {
160 if let Some(prompt_id) = data.prompt_id {
161 if prompt_id != target_prompt_id {
162 return Ok(None);
163 }
164 let task = self
165 .history
166 .get_prompt(&prompt_id)
167 .await
168 .map_err(ComfyApiError::PromptTaskNotFound)?;
169 let images = task
170 .outputs
171 .nodes
172 .into_iter()
173 .filter_map(|(key, value)| {
174 if let NodeOutputOrUnknown::NodeOutput(output) = value {
175 Some((key, output.images))
176 } else {
177 None
178 }
179 })
180 .collect::<Vec<(String, Vec<Image>)>>();
181 return Ok(Some(State::Finished(images)));
182 }
183 }
184 Ok(None)
185 }
186 Update::Executed(data) => {
187 if data.prompt_id != target_prompt_id {
188 return Ok(None);
189 }
190 Ok(Some(State::Executing(data.node, data.output.images)))
191 }
192 Update::ExecutionInterrupted(data) => {
193 if data.prompt_id != target_prompt_id {
194 return Ok(None);
195 }
196 Err(ComfyApiError::ExecutionInterrupted { response: data })
197 }
198 Update::ExecutionError(data) => {
199 if data.execution_status.prompt_id != target_prompt_id {
200 return Ok(None);
201 }
202 Err(ComfyApiError::ExecutionError {
203 exception_type: data.exception_type,
204 exception_message: data.exception_message,
205 })
206 }
207 _ => Ok(None),
208 }
209 }
210
211 async fn prompt_impl<'a>(
212 &'a self,
213 prompt: &Prompt,
214 ) -> Result<impl Stream<Item = Result<State>> + 'a> {
215 let client_id = Uuid::new_v4();
216 let prompt_api = self.api.prompt_with_client(client_id)?;
217 let websocket_api = self.api.websocket_with_client(client_id)?;
218 let stream = websocket_api
219 .updates()
220 .await
221 .map_err(ComfyApiError::ReceiveUpdateFailure)?;
222 let response = prompt_api.send(prompt).await?;
223 let prompt_id = response.prompt_id;
224 Ok(stream.filter_map(move |msg| async move {
225 match msg {
226 Ok(msg) => match self.filter_update(msg, prompt_id).await {
227 Ok(Some(images)) => Some(Ok(images)),
228 Ok(None) => None,
229 Err(e) => Some(Err(e)),
230 },
231 Err(e) => Some(Err(ComfyApiError::ReceiveUpdateFailure(e))),
232 }
233 }))
234 }
235
236 pub async fn stream_prompt<'a>(
246 &'a self,
247 prompt: &Prompt,
248 ) -> Result<impl FusedStream<Item = Result<NodeOutput>> + 'a> {
249 let stream = self.prompt_impl(prompt).await?;
250 Ok(stream! {
251 let mut executed = HashSet::new();
252 for await msg in stream {
253 match msg {
254 Ok(State::Executing(node, images)) => {
255 executed.insert(node.clone());
256 let fut = images.into_iter().map(|image| async move {
257 self.view.get(&image).await
258 }).collect::<FuturesOrdered<_>>();
259 for await image in fut {
260 yield Ok(NodeOutput { node: node.clone(), image: image? });
261 }
262 }
263 Ok(State::Finished(images)) => {
264 for (node, images) in images {
265 if executed.contains(&node) {
266 continue;
267 }
268 let fut = images.into_iter().map(|image| async move {
269 self.view.get(&image).await
270 }).collect::<FuturesOrdered<_>>();
271 for await image in fut {
272 yield Ok(NodeOutput { node: node.clone(), image: image? });
273 }
274 }
275 return;
276 }
277 Err(e) => Err(e)?,
278 }
279 }
280 })
281 }
282
283 pub async fn execute_prompt(&self, prompt: &Prompt) -> Result<Vec<NodeOutput>> {
293 let mut images = vec![];
294 let mut stream = pin!(self.stream_prompt(prompt).await?);
295 while let Some(image) = stream.next().await {
296 match image {
297 Ok(image) => images.push(image),
298 Err(e) => return Err(e),
299 }
300 }
301 Ok(images)
302 }
303
304 pub async fn upload_file(&self, file: Vec<u8>) -> Result<ImageUpload> {
314 Ok(self.upload.image(file).await?)
315 }
316}
317
/// Generation parameters recovered from a prompt graph; unknown fields stay `None`.
#[derive(Debug, Clone, Default)]
pub struct ImageInfo {
    /// Prompt text, if one was found.
    pub prompt: Option<String>,
    /// Negative prompt text, if one was found.
    pub negative_prompt: Option<String>,
    /// Model name, if one was found.
    pub model: Option<String>,
    /// Image width, if found.
    pub width: Option<u32>,
    /// Image height, if found.
    pub height: Option<u32>,
    /// Seed, if found.
    pub seed: Option<i64>,
}
334
335impl ImageInfo {
336 pub fn new_from_prompt(prompt: &Prompt, output_node: &str) -> anyhow::Result<ImageInfo> {
347 let mut image_info = ImageInfo::default();
348 if let Some(node) = prompt.get_node_by_id(output_node) {
349 image_info.visit(prompt, node);
350 } else {
351 return Err(anyhow!("Output node not found: {}", output_node));
352 }
353 Ok(image_info)
354 }
355}
356
/// A value to write into a prompt, optionally aimed at a specific node.
#[derive(Debug, Clone)]
struct OverrideNode<T> {
    /// Target node id; `None` means the value is applied relative to the output node.
    node: Option<String>,
    /// The value to write.
    value: T,
}
362
363impl<T> Default for OverrideNode<T>
364where
365 T: Default,
366{
367 fn default() -> Self {
368 Self {
369 node: Default::default(),
370 value: Default::default(),
371 }
372 }
373}
374
/// Builds a modified copy of a base prompt with selected values overridden.
#[derive(Debug, Clone)]
pub struct PromptBuilder {
    /// Prompt that overrides are applied to (it is cloned, never mutated).
    base_prompt: Prompt,
    /// Node used when an override has no explicit target; detected in `build` if `None`.
    output_node: Option<String>,
    /// Prompt text override.
    prompt: Option<OverrideNode<String>>,
    /// Negative prompt text override.
    negative_prompt: Option<OverrideNode<String>>,
    /// Model override.
    model: Option<OverrideNode<String>>,
    /// Width override.
    width: Option<OverrideNode<u32>>,
    /// Height override.
    height: Option<OverrideNode<u32>>,
    /// Seed override.
    seed: Option<OverrideNode<i64>>,
}
387
388impl PromptBuilder {
389 pub fn new(base_prompt: &Prompt, output_node: Option<String>) -> Self {
400 Self {
401 prompt: None,
402 negative_prompt: None,
403 model: None,
404 width: None,
405 height: None,
406 seed: None,
407 base_prompt: base_prompt.clone(),
408 output_node,
409 }
410 }
411
412 pub fn prompt(mut self, value: String, node: Option<String>) -> Self {
419 self.prompt = Some(OverrideNode { node, value });
420 self
421 }
422
423 pub fn negative_prompt(mut self, value: String, node: Option<String>) -> Self {
430 self.negative_prompt = Some(OverrideNode { node, value });
431 self
432 }
433
434 pub fn model(mut self, value: String, node: Option<String>) -> Self {
441 self.model = Some(OverrideNode { node, value });
442 self
443 }
444
445 pub fn width(mut self, value: u32, node: Option<String>) -> Self {
452 self.width = Some(OverrideNode { node, value });
453 self
454 }
455
456 pub fn height(mut self, value: u32, node: Option<String>) -> Self {
463 self.height = Some(OverrideNode { node, value });
464 self
465 }
466
467 pub fn seed(mut self, value: i64, node: Option<String>) -> Self {
474 self.seed = Some(OverrideNode { node, value });
475 self
476 }
477
478 pub fn build(mut self) -> anyhow::Result<Prompt> {
484 let mut new_prompt = self.base_prompt.clone();
485
486 if self.output_node.is_none() {
487 self.output_node = Some(
488 find_output_node(&new_prompt).context("failed to find a suitable output node")?,
489 );
490 }
491
492 if let Some(ref prompt) = self.prompt {
493 if let Some(ref node) = prompt.node {
494 new_prompt.set_node::<accessors::Prompt>(node, prompt.value.clone())?;
495 } else {
496 new_prompt.set_from::<accessors::Prompt>(
497 &self.output_node.clone().unwrap(),
498 prompt.value.clone(),
499 )?;
500 }
501 }
502 if let Some(ref negative_prompt) = self.negative_prompt {
503 if let Some(ref node) = negative_prompt.node {
504 new_prompt
505 .set_node::<accessors::NegativePrompt>(node, negative_prompt.value.clone())?;
506 } else {
507 new_prompt.set_from::<accessors::NegativePrompt>(
508 &self.output_node.clone().unwrap(),
509 negative_prompt.value.clone(),
510 )?;
511 }
512 }
513 if let Some(ref model) = self.model {
514 if let Some(ref node) = model.node {
515 new_prompt.set_node::<accessors::Model>(node, model.value.clone())?;
516 } else {
517 new_prompt.set_from::<accessors::Model>(
518 &self.output_node.clone().unwrap(),
519 model.value.clone(),
520 )?;
521 }
522 }
523 if let Some(width) = self.width {
524 if let Some(ref node) = width.node {
525 new_prompt.set_node::<accessors::Width>(node, width.value)?;
526 } else {
527 new_prompt.set_from::<accessors::Width>(
528 &self.output_node.clone().unwrap(),
529 width.value,
530 )?;
531 }
532 }
533 if let Some(height) = self.height {
534 if let Some(ref node) = height.node {
535 new_prompt.set_node::<accessors::Height>(node, height.value)?;
536 } else {
537 new_prompt.set_from::<accessors::Height>(
538 &self.output_node.clone().unwrap(),
539 height.value,
540 )?;
541 }
542 }
543 if let Some(ref seed) = self.seed {
544 if let Some(ref node) = seed.node {
545 new_prompt.set_node::<accessors::Seed>(node, seed.value)?;
546 } else {
547 new_prompt
548 .set_from::<accessors::Seed>(&self.output_node.clone().unwrap(), seed.value)?;
549 }
550 }
551 Ok(new_prompt)
552 }
553}