// comfyui_api/models/websocket.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5/// An enum representing a websocket message.
6#[allow(clippy::large_enum_variant)]
7#[derive(Serialize, Deserialize, Debug)]
8#[serde(untagged)]
9pub enum PreviewOrUpdate {
10    /// Enum variant representing an image preview.
11    Preview(Preview),
12    /// Enum variant representing an update.
13    Update(Update),
14}
15
16/// Struct representing an image preview.
17#[derive(Default, Serialize, Deserialize, Debug)]
18pub struct Preview(pub Vec<u8>);
19
20/// Enum of possible update variants.
21#[derive(Serialize, Deserialize, Debug)]
22#[serde(tag = "type", content = "data")]
23#[serde(rename_all = "snake_case")]
24pub enum Update {
25    /// Enum variant representing a status update.
26    Status { status: Status },
27    /// Enum variant representing a progress update.
28    Progress(Progress),
29    /// Enum variant representing an execution start update.
30    ExecutionStart(ExecutionStart),
31    /// Enum variant representing an executing update.
32    Executing(Executing),
33    /// Enum variant representing an executed update.
34    Executed(Executed),
35    /// Enum variant representing an execution cached update.
36    ExecutionCached(ExecutionCached),
37    /// Enum variant representing an execution interrupted update.
38    ExecutionInterrupted(ExecutionInterrupted),
39    /// Enum variant representing an execution error update.
40    ExecutionError(ExecutionError),
41}
42
43/// Struct representing a progress update.
44#[derive(Serialize, Deserialize, Debug)]
45pub struct Progress {
46    /// The current progress value.
47    pub value: u64,
48    /// The maximum progress value.
49    pub max: u64,
50}
51
52/// Struct representing a status update.
53#[derive(Serialize, Deserialize, Debug)]
54pub struct Status {
55    /// The current status.
56    pub exec_info: ExecInfo,
57}
58
59/// Struct representing execution information.
60#[derive(Serialize, Deserialize, Debug)]
61pub struct ExecInfo {
62    /// Number of items remaining in the queue.
63    pub queue_remaining: u64,
64}
65
66/// Struct representing an execution start update.
67#[derive(Serialize, Deserialize, Debug)]
68pub struct ExecutionStart {
69    /// The prompt id.
70    pub prompt_id: uuid::Uuid,
71}
72
73/// Struct representing an execution cached update.
74#[derive(Serialize, Deserialize, Debug)]
75pub struct ExecutionCached {
76    /// The prompt id.
77    pub prompt_id: uuid::Uuid,
78    /// The ids of the nodes that were cached.
79    pub nodes: Vec<String>,
80}
81
82/// Struct representing an executing update.
83#[derive(Serialize, Deserialize, Debug)]
84pub struct Executing {
85    /// The prompt id. None if reconnecting to a session.
86    pub prompt_id: Option<uuid::Uuid>,
87    /// The node that is executing. None if execution is finished.
88    pub node: Option<String>,
89}
90
91/// Struct representing an executed update.
92#[derive(Serialize, Deserialize, Debug)]
93pub struct Executed {
94    /// The prompt id.
95    pub prompt_id: uuid::Uuid,
96    /// The node that was executed.
97    pub node: String,
98    /// The output of the node.
99    pub output: Output,
100}
101
102/// Struct representing an output.
103#[derive(Serialize, Deserialize, Debug)]
104pub struct Output {
105    /// A list of images.
106    pub images: Vec<Image>,
107}
108
109/// Struct representing an image.
110#[derive(Serialize, Deserialize, Debug, Clone)]
111pub struct Image {
112    /// The filename of the image.
113    pub filename: String,
114    /// The subfolder.
115    pub subfolder: String,
116    /// The folder type.
117    #[serde(rename = "type")]
118    pub folder_type: String,
119}
120
121/// Struct representing an execution interrupted update.
122#[derive(Serialize, Deserialize, Debug)]
123pub struct ExecutionInterrupted {
124    /// The prompt id.
125    pub prompt_id: uuid::Uuid,
126    /// The node that was executing.
127    pub node_id: String,
128    /// The type of the node.
129    pub node_type: String,
130    /// What was executed prior to interruption.
131    pub executed: Vec<String>,
132}
133
134/// Struct representing an execution error update.
135#[derive(Serialize, Deserialize, Debug)]
136pub struct ExecutionError {
137    /// The state of execution that was interrupted.
138    #[serde(flatten)]
139    pub execution_status: ExecutionInterrupted,
140    /// The exception message.
141    pub exception_message: String,
142    /// The exception type.
143    pub exception_type: String,
144    /// The traceback.
145    pub traceback: Vec<String>,
146    /// The current inputs.
147    pub current_inputs: CurrentInputs,
148    /// The current outputs.
149    pub current_outputs: CurrentOutputs,
150}
151
152/// Struct representing the current inputs when the execution error occurred.
153#[derive(Serialize, Deserialize, Debug)]
154#[serde(transparent)]
155pub struct CurrentInputs {
156    /// Hashmap of inputs keyed by input name.
157    pub inputs: HashMap<String, serde_json::Value>,
158}
159
160/// Struct representing the current outputs when the execution error occurred.
161#[derive(Serialize, Deserialize, Debug)]
162#[serde(transparent)]
163pub struct CurrentOutputs {
164    /// Hashmap of outputs keyed by node id.
165    pub outputs: HashMap<String, Vec<serde_json::Value>>,
166}