1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
use std::collections::HashMap;

use serde::{Deserialize, Serialize};

/// A single message received over the websocket: either preview bytes or a
/// structured [`Update`].
///
/// Because the enum is untagged, deserialization tries `Preview` first and
/// only then `Update`, so the variant order is significant.
// `Update` (notably its `ExecutionError` payload, which holds several strings,
// a vector and two maps) is much larger than `Preview`. The size gap is
// accepted rather than boxing, which would change the public payload type of
// the `Update` variant.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum PreviewOrUpdate {
    /// Image preview payload.
    Preview(Preview),
    /// Any other server update.
    Update(Update),
}

/// Newtype wrapping the raw bytes of an image preview.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct Preview(pub Vec<u8>);

/// Every structured (non-preview) message the server can send.
///
/// On the wire, each value is an object with two keys:
/// - `type` carries the snake_case variant name (e.g. `"execution_start"`).
/// - `data` carries that variant's payload.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case", tag = "type", content = "data")]
pub enum Update {
    /// Queue status; the payload object nests a [`Status`] under `status`.
    Status { status: Status },
    /// Numeric progress report (a `value` out of `max`).
    Progress(Progress),
    /// A prompt has started executing.
    ExecutionStart(ExecutionStart),
    /// A node is executing, or, with no node, execution has finished.
    Executing(Executing),
    /// A node has finished and reported its output.
    Executed(Executed),
    /// Reports nodes whose results came from cache.
    ExecutionCached(ExecutionCached),
    /// Execution was interrupted.
    ExecutionInterrupted(ExecutionInterrupted),
    /// Execution failed with an error.
    ExecutionError(ExecutionError),
}

/// Payload of a progress report.
#[derive(Debug, Serialize, Deserialize)]
pub struct Progress {
    /// Current position, on the same scale as `max`.
    pub value: u64,
    /// Upper bound of the progress scale.
    pub max: u64,
}

/// Body of an [`Update::Status`] message.
#[derive(Debug, Serialize, Deserialize)]
pub struct Status {
    /// Execution/queue information reported by the server.
    pub exec_info: ExecInfo,
}

/// Queue information carried inside a [`Status`].
#[derive(Debug, Serialize, Deserialize)]
pub struct ExecInfo {
    /// How many items are still waiting in the queue.
    pub queue_remaining: u64,
}

/// Payload announcing that a prompt has begun executing.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExecutionStart {
    /// Identifier of the prompt that started.
    pub prompt_id: uuid::Uuid,
}

/// Payload listing the nodes of a prompt that were served from cache.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExecutionCached {
    /// Identifier of the prompt this report belongs to.
    pub prompt_id: uuid::Uuid,
    /// Ids of the cached nodes.
    pub nodes: Vec<String>,
}

/// Payload naming the node that is currently executing.
#[derive(Debug, Serialize, Deserialize)]
pub struct Executing {
    /// Identifier of the prompt; `None` when reconnecting to a session.
    pub prompt_id: Option<uuid::Uuid>,
    /// Id of the running node; `None` once execution has finished.
    pub node: Option<String>,
}

/// Payload sent after a node has finished, carrying what it produced.
#[derive(Debug, Serialize, Deserialize)]
pub struct Executed {
    /// Identifier of the prompt the node belongs to.
    pub prompt_id: uuid::Uuid,
    /// Id of the node that finished.
    pub node: String,
    /// What the node produced.
    pub output: Output,
}

/// Struct representing an output.
#[derive(Serialize, Deserialize, Debug)]
pub struct Output {
    /// A list of images.
    pub images: Vec<Image>,
}

/// Location of an image file, as reported by the server.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Image {
    /// File name of the image.
    pub filename: String,
    /// Subfolder containing the file.
    pub subfolder: String,
    /// Folder type; named `type` on the wire.
    #[serde(rename = "type")]
    pub folder_type: String,
}

/// Payload sent when execution of a prompt is interrupted.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExecutionInterrupted {
    /// Identifier of the interrupted prompt.
    pub prompt_id: uuid::Uuid,
    /// Id of the node that was running when execution stopped.
    pub node_id: String,
    /// Type of that node.
    pub node_type: String,
    /// Entries executed before the interruption (presumably node ids).
    pub executed: Vec<String>,
}

/// Payload sent when execution of a prompt fails with an exception.
///
/// The fields of [`ExecutionInterrupted`] are flattened into this object.
/// On the wire they appear alongside the exception details rather than in a
/// nested object.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExecutionError {
    /// Where execution stopped: the prompt, the node, and what had already run.
    #[serde(flatten)]
    pub execution_status: ExecutionInterrupted,
    /// Message of the raised exception.
    pub exception_message: String,
    /// Type of the raised exception.
    pub exception_type: String,
    /// Traceback, one entry per element.
    pub traceback: Vec<String>,
    /// Inputs in effect when the error occurred.
    pub current_inputs: CurrentInputs,
    /// Outputs in effect when the error occurred.
    pub current_outputs: CurrentOutputs,
}

/// Inputs captured when an execution error occurred.
///
/// Marked transparent, so it (de)serializes directly as the inner map.
#[derive(Debug, Serialize, Deserialize)]
#[serde(transparent)]
pub struct CurrentInputs {
    /// Arbitrary JSON input values, keyed by input name.
    pub inputs: HashMap<String, serde_json::Value>,
}

/// Outputs captured when an execution error occurred.
///
/// Marked transparent, so it (de)serializes directly as the inner map.
#[derive(Debug, Serialize, Deserialize)]
#[serde(transparent)]
pub struct CurrentOutputs {
    /// Lists of arbitrary JSON output values, keyed by node id.
    pub outputs: HashMap<String, Vec<serde_json::Value>>,
}