comfyui_api/api/
websocket.rs1use futures_util::{stream::FusedStream, StreamExt};
2use reqwest::Url;
3use tokio_tungstenite::{connect_async, tungstenite::Message};
4use tracing::warn;
5
6use crate::models::{Preview, PreviewOrUpdate, Update};
7
8#[derive(thiserror::Error, Debug)]
10#[non_exhaustive]
11pub enum WebSocketApiError {
12 #[error("Failed to parse endpoint URL")]
14 ParseError(#[from] url::ParseError),
15 #[error("Failed to parse endpoint URL")]
17 ConnectFailed(#[from] tokio_tungstenite::tungstenite::Error),
18 #[error("Parsing response failed")]
20 InvalidResponse(#[from] serde_json::Error),
21 #[error("Error occurred while reading websocket message")]
23 ReadFailed(#[source] tokio_tungstenite::tungstenite::Error),
24}
25
26type Result<T> = std::result::Result<T, WebSocketApiError>;
27
28#[derive(Clone, Debug)]
30pub struct WebsocketApi {
31 endpoint: Url,
32}
33
34impl WebsocketApi {
35 pub fn new<S>(endpoint: S) -> Result<Self>
45 where
46 S: AsRef<str>,
47 {
48 Ok(Self::new_with_url(Url::parse(endpoint.as_ref())?))
49 }
50
51 pub fn new_with_url(endpoint: Url) -> Self {
61 Self { endpoint }
62 }
63
64 async fn connect_to_endpoint(
65 &self,
66 endpoint: &Url,
67 ) -> Result<impl FusedStream<Item = Result<PreviewOrUpdate>>> {
68 let (connection, _) = connect_async(endpoint).await?;
69 Ok(connection.filter_map(|m| async {
70 match m {
71 Ok(m) => match m {
72 Message::Text(t) => Some(
73 serde_json::from_str::<Update>(t.as_str())
74 .map(PreviewOrUpdate::Update)
75 .map_err(WebSocketApiError::InvalidResponse),
76 ),
77 Message::Binary(_) => {
78 Some(Ok(PreviewOrUpdate::Preview(Preview(m.into_data()))))
79 }
80 _ => {
81 warn!("unexpected websocket message type");
82 None
83 }
84 },
85 Err(e) => Some(Err(WebSocketApiError::ReadFailed(e))),
86 }
87 }))
88 }
89
90 async fn connect_impl(&self) -> Result<impl FusedStream<Item = Result<PreviewOrUpdate>>> {
91 self.connect_to_endpoint(&self.endpoint).await
92 }
93
94 pub async fn connect(&self) -> Result<impl FusedStream<Item = Result<PreviewOrUpdate>>> {
101 self.connect_impl().await
102 }
103
104 pub async fn updates(&self) -> Result<impl FusedStream<Item = Result<Update>>> {
110 Ok(self.connect_impl().await?.filter_map(|m| async {
111 match m {
112 Ok(PreviewOrUpdate::Update(u)) => Some(Ok(u)),
113 Ok(PreviewOrUpdate::Preview(_)) => None,
114 Err(e) => Some(Err(e)),
115 }
116 }))
117 }
118
119 pub async fn previews(&self) -> Result<impl FusedStream<Item = Result<Preview>>> {
125 Ok(self.connect_impl().await?.filter_map(|m| async {
126 match m {
127 Ok(PreviewOrUpdate::Update(_)) => None,
128 Ok(PreviewOrUpdate::Preview(p)) => Some(Ok(p)),
129 Err(e) => Some(Err(e)),
130 }
131 }))
132 }
133}