comfyui_api/api/
websocket.rs

1use futures_util::{stream::FusedStream, StreamExt};
2use reqwest::Url;
3use tokio_tungstenite::{connect_async, tungstenite::Message};
4use tracing::warn;
5
6use crate::models::{Preview, PreviewOrUpdate, Update};
7
8/// Errors that can occur when interacting with `WebSocketApi`.
9#[derive(thiserror::Error, Debug)]
10#[non_exhaustive]
11pub enum WebSocketApiError {
12    /// Error parsing endpoint URL
13    #[error("Failed to parse endpoint URL")]
14    ParseError(#[from] url::ParseError),
15    /// Error parsing endpoint URL
16    #[error("Failed to parse endpoint URL")]
17    ConnectFailed(#[from] tokio_tungstenite::tungstenite::Error),
18    /// An error occurred while parsing the response from the API.
19    #[error("Parsing response failed")]
20    InvalidResponse(#[from] serde_json::Error),
21    /// An error occurred while reading websocket message.
22    #[error("Error occurred while reading websocket message")]
23    ReadFailed(#[source] tokio_tungstenite::tungstenite::Error),
24}
25
26type Result<T> = std::result::Result<T, WebSocketApiError>;
27
28/// Struct representing a connection to the ComfyUI API `ws` endpoint.
29#[derive(Clone, Debug)]
30pub struct WebsocketApi {
31    endpoint: Url,
32}
33
34impl WebsocketApi {
35    /// Constructs a new `WebsocketApi` client with a given ComfyUI API endpoint.
36    ///
37    /// # Arguments
38    ///
39    /// * `endpoint` - A `str` representation of the endpoint url.
40    ///
41    /// # Returns
42    ///
43    /// A `Result` containing a new `WebsocketApi` instance on success, or an error if url parsing failed.
44    pub fn new<S>(endpoint: S) -> Result<Self>
45    where
46        S: AsRef<str>,
47    {
48        Ok(Self::new_with_url(Url::parse(endpoint.as_ref())?))
49    }
50
51    /// Constructs a new `WebsocketApi` client with a given endpoint `Url`.
52    ///
53    /// # Arguments
54    ///
55    /// * `endpoint` - A `Url` representing the endpoint url.
56    ///
57    /// # Returns
58    ///
59    /// A new `WebsocketApi` instance.
60    pub fn new_with_url(endpoint: Url) -> Self {
61        Self { endpoint }
62    }
63
64    async fn connect_to_endpoint(
65        &self,
66        endpoint: &Url,
67    ) -> Result<impl FusedStream<Item = Result<PreviewOrUpdate>>> {
68        let (connection, _) = connect_async(endpoint).await?;
69        Ok(connection.filter_map(|m| async {
70            match m {
71                Ok(m) => match m {
72                    Message::Text(t) => Some(
73                        serde_json::from_str::<Update>(t.as_str())
74                            .map(PreviewOrUpdate::Update)
75                            .map_err(WebSocketApiError::InvalidResponse),
76                    ),
77                    Message::Binary(_) => {
78                        Some(Ok(PreviewOrUpdate::Preview(Preview(m.into_data()))))
79                    }
80                    _ => {
81                        warn!("unexpected websocket message type");
82                        None
83                    }
84                },
85                Err(e) => Some(Err(WebSocketApiError::ReadFailed(e))),
86            }
87        }))
88    }
89
90    async fn connect_impl(&self) -> Result<impl FusedStream<Item = Result<PreviewOrUpdate>>> {
91        self.connect_to_endpoint(&self.endpoint).await
92    }
93
94    /// Connects to the websocket endpoint and returns a stream of `PreviewOrUpdate` values.
95    ///
96    /// # Returns
97    ///
98    /// A `Stream` of `PreviewOrUpdate` values. These are either `Update` values, which contain
99    /// progress updates for a task, or `Preview` values, which contain a preview image.
100    pub async fn connect(&self) -> Result<impl FusedStream<Item = Result<PreviewOrUpdate>>> {
101        self.connect_impl().await
102    }
103
104    /// Connects to the websocket endpoint and returns a stream of `Update` values.
105    ///
106    /// # Returns
107    ///
108    /// A `Stream` of `Update` values. These contain progress updates for a task.
109    pub async fn updates(&self) -> Result<impl FusedStream<Item = Result<Update>>> {
110        Ok(self.connect_impl().await?.filter_map(|m| async {
111            match m {
112                Ok(PreviewOrUpdate::Update(u)) => Some(Ok(u)),
113                Ok(PreviewOrUpdate::Preview(_)) => None,
114                Err(e) => Some(Err(e)),
115            }
116        }))
117    }
118
119    /// Connects to the websocket endpoint and returns a stream of `Preview` values.
120    ///
121    /// # Returns
122    ///
123    /// A `Stream` of `Preview` values. These contain preview images.
124    pub async fn previews(&self) -> Result<impl FusedStream<Item = Result<Preview>>> {
125        Ok(self.connect_impl().await?.filter_map(|m| async {
126            match m {
127                Ok(PreviewOrUpdate::Update(_)) => None,
128                Ok(PreviewOrUpdate::Preview(p)) => Some(Ok(p)),
129                Err(e) => Some(Err(e)),
130            }
131        }))
132    }
133}