stable_diffusion_bot/bot/handlers/
mod.rs1use anyhow::anyhow;
2use teloxide::{
3 dispatching::UpdateHandler,
4 prelude::*,
5 types::{Me, ParseMode},
6 utils::{command::BotCommands, markdown},
7};
8
9use crate::BotState;
10
11use super::{ConfigParameters, DiffusionDialogue, State};
12
13mod image;
14pub(crate) use image::*;
15
16mod settings;
17pub(crate) use settings::*;
18
// Commands any user may send, whether or not their chat is allowed by the config.
// They are parsed by `unauth_command_filter` and answered by
// `unauthenticated_commands_handler`; with `rename_rule = "lowercase"` the variants
// are matched as `/help`, `/start` and `/settings`.
// (Plain `//` comments are used on purpose: `///` doc comments may be read by the
// `BotCommands` derive as command descriptions.)
#[derive(BotCommands, Clone)]
#[command(rename_rule = "lowercase", description = "Simple commands")]
pub(crate) enum UnauthenticatedCommands {
    // Replies with the command list; allowed chats/users see every command group.
    #[command(description = "show help message.")]
    Help,
    // Resets the dialogue to `State::Ready` with default generation parameters.
    #[command(description = "start the bot.")]
    Start,
    // Placeholder: the handler currently replies that this is not implemented.
    #[command(description = "change settings.")]
    Settings,
}
29
30pub(crate) async fn unauthenticated_commands_handler(
31 cfg: ConfigParameters,
32 bot: Bot,
33 me: teloxide::types::Me,
34 msg: Message,
35 cmd: UnauthenticatedCommands,
36 dialogue: DiffusionDialogue,
37) -> anyhow::Result<()> {
38 let text = match cmd {
39 UnauthenticatedCommands::Help => {
40 if cfg.chat_is_allowed(&msg.chat.id)
41 || cfg.chat_is_allowed(&msg.from().unwrap().id.into())
42 {
43 format!(
44 "{}\n\n{}\n\n{}",
45 UnauthenticatedCommands::descriptions(),
46 SettingsCommands::descriptions(),
47 GenCommands::descriptions()
48 )
49 } else if msg.chat.is_group() || msg.chat.is_supergroup() {
50 UnauthenticatedCommands::descriptions()
51 .username_from_me(&me)
52 .to_string()
53 } else {
54 UnauthenticatedCommands::descriptions().to_string()
55 }
56 }
57 UnauthenticatedCommands::Start => {
58 dialogue
59 .update(State::Ready {
60 bot_state: BotState::default(),
61 txt2img: cfg.txt2img_api.gen_params(None),
62 img2img: cfg.img2img_api.gen_params(None),
63 })
64 .await
65 .map_err(|e| anyhow!(e))?;
66 "This bot generates images using stable diffusion! Enter a prompt to get started!"
67 .to_owned()
68 }
69 UnauthenticatedCommands::Settings => "Sorry, not yet implemented.".to_owned(),
70 };
71
72 bot.send_message(msg.chat.id, markdown::escape(&text))
73 .parse_mode(ParseMode::MarkdownV2)
74 .await?;
75
76 Ok(())
77}
78
79pub(crate) fn filter_map_bot_state() -> UpdateHandler<anyhow::Error> {
80 dptree::filter_map(|state: State| match state {
81 State::Ready { bot_state, .. } => Some(bot_state),
82 _ => None,
83 })
84}
85
86pub(crate) fn filter_map_settings() -> UpdateHandler<anyhow::Error> {
87 dptree::filter_map(|state: State| match state {
88 State::Ready {
89 txt2img, img2img, ..
90 } => Some((txt2img, img2img)),
91 _ => None,
92 })
93}
94
95pub(crate) fn auth_filter() -> UpdateHandler<anyhow::Error> {
96 dptree::filter(|cfg: ConfigParameters, upd: Update| {
97 upd.chat()
98 .map(|chat| cfg.chat_is_allowed(&chat.id))
99 .unwrap_or_default()
100 || upd
101 .user()
102 .map(|user| cfg.chat_is_allowed(&user.id.into()))
103 .unwrap_or_default()
104 })
105}
106
107pub fn filter_command<C>() -> UpdateHandler<anyhow::Error>
108where
109 C: BotCommands + Send + Sync + 'static,
110{
111 dptree::filter_map(move |message: Message, me: Me| {
112 let bot_name = me.user.username.expect("Bots must have a username");
113 message
114 .text()
115 .and_then(|text| C::parse(text, &bot_name).ok())
116 .or_else(|| {
117 message
118 .caption()
119 .and_then(|text| C::parse(text, &bot_name).ok())
120 })
121 })
122}
123
124pub(crate) fn unauth_command_filter() -> UpdateHandler<anyhow::Error> {
125 Update::filter_message().chain(teloxide::filter_command::<UnauthenticatedCommands, _>())
126}
127
128pub(crate) fn unauth_command_handler() -> UpdateHandler<anyhow::Error> {
129 unauth_command_filter().endpoint(unauthenticated_commands_handler)
130}
131
132pub(crate) fn authenticated_command_handler() -> UpdateHandler<anyhow::Error> {
133 auth_filter()
134 .branch(settings_schema())
135 .branch(image_schema())
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use sal_e_api::{
        GenParams, Img2ImgApi, Img2ImgApiError, Img2ImgParams, Response, Txt2ImgApi,
        Txt2ImgApiError, Txt2ImgParams,
    };
    use std::ops::ControlFlow;
    use teloxide::types::{Me, UpdateKind, User};

    /// Builds a private-chat message with the given `text`, sent by user `123456789`
    /// in chat `1234567890`.
    fn create_message(text: &str) -> Message {
        // Encode `text` as a JSON string literal so quotes, backslashes or newlines in
        // it cannot corrupt the surrounding JSON document.
        let text = serde_json::to_string(text).unwrap();
        let json = format!(
            r#"{{
                "message_id": 123456,
                "from": {{
                    "id": 123456789,
                    "is_bot": false,
                    "first_name": "Stable",
                    "last_name": "Diffusion",
                    "username": "sd",
                    "language_code": "en"
                }},
                "chat": {{
                    "id": 1234567890,
                    "first_name": "Stable",
                    "last_name": "Diffusion",
                    "username": "sd",
                    "type": "private"
                }},
                "date": 1634567890,
                "text": {}
            }}"#,
            text
        );
        serde_json::from_str::<Message>(&json).unwrap()
    }

    /// Wraps `msg` in a message update with a fixed update id.
    fn create_update(msg: &Message) -> Update {
        Update {
            id: 1,
            kind: UpdateKind::Message(msg.clone()),
        }
    }

    /// Builds the bot's own user, `@sdbot`.
    fn create_me() -> Me {
        Me {
            user: User {
                id: UserId(123456780),
                is_bot: true,
                first_name: "Stable Diffusion".to_string(),
                last_name: None,
                username: Some("sdbot".to_string()),
                language_code: Some("en".to_string()),
                is_premium: false,
                added_to_attachment_menu: false,
            },
            can_join_groups: false,
            can_read_all_group_messages: false,
            supports_inline_queries: false,
        }
    }

    /// API stub: returns default parameters and fails every generation request.
    #[derive(Debug, Clone, Default)]
    struct MockApi;

    #[async_trait]
    impl Txt2ImgApi for MockApi {
        fn gen_params(&self, _user_settings: Option<&dyn GenParams>) -> Box<dyn GenParams> {
            Box::<Txt2ImgParams>::default()
        }

        async fn txt2img(&self, _config: &dyn GenParams) -> Result<Response, Txt2ImgApiError> {
            Err(anyhow!("Not implemented"))?
        }
    }

    #[async_trait]
    impl Img2ImgApi for MockApi {
        fn gen_params(&self, _user_settings: Option<&dyn GenParams>) -> Box<dyn GenParams> {
            Box::<Img2ImgParams>::default()
        }

        async fn img2img(&self, _config: &dyn GenParams) -> Result<Response, Img2ImgApiError> {
            Err(anyhow!("Not implemented"))?
        }
    }

    /// Builds a config allowing the given chat/user ids (or everyone when
    /// `allow_all_users` is set), backed by [`MockApi`].
    fn create_config(allowed_users: Vec<i64>, allow_all_users: bool) -> ConfigParameters {
        ConfigParameters {
            allowed_users: allowed_users.into_iter().map(ChatId).collect(),
            allow_all_users,
            txt2img_api: Box::new(MockApi),
            img2img_api: Box::new(MockApi),
        }
    }

    #[test]
    fn test_create_message_escapes_text() {
        let text = "say \"hi\" \\ bye\nnow";
        assert_eq!(create_message(text).text(), Some(text));
    }

    #[tokio::test]
    async fn test_unauth_command_filter_help() {
        let me = create_me();
        let msg = create_message("/help");
        let update = create_update(&msg);

        assert!(matches!(
            unauth_command_filter()
                .endpoint(|| async move { anyhow::Ok(()) })
                .dispatch(dptree::deps![msg, update, me])
                .await,
            ControlFlow::Break(_)
        ));
    }

    #[tokio::test]
    async fn test_unauth_command_filter_start() {
        let me = create_me();
        let msg = create_message("/start");
        let update = create_update(&msg);

        assert!(matches!(
            unauth_command_filter()
                .endpoint(|| async move { anyhow::Ok(()) })
                .dispatch(dptree::deps![msg, update, me])
                .await,
            ControlFlow::Break(_)
        ));
    }

    #[tokio::test]
    async fn test_unauth_command_filter_settings() {
        let me = create_me();
        let msg = create_message("/settings");
        let update = create_update(&msg);

        assert!(matches!(
            unauth_command_filter()
                .endpoint(|| async move { anyhow::Ok(()) })
                .dispatch(dptree::deps![msg, update, me])
                .await,
            ControlFlow::Break(_)
        ));
    }

    #[tokio::test]
    async fn test_unauth_command_filter_ignores_plain_text() {
        let me = create_me();
        let msg = create_message("hello");
        let update = create_update(&msg);

        assert!(matches!(
            unauth_command_filter()
                .endpoint(|| async move { anyhow::Ok(()) })
                .dispatch(dptree::deps![msg, update, me])
                .await,
            ControlFlow::Continue(_)
        ));
    }

    #[tokio::test]
    async fn test_auth_filter_allow_all_users() {
        let cfg = create_config(vec![], true);
        let me = create_me();
        let msg = create_message("");
        let update = create_update(&msg);

        assert!(matches!(
            auth_filter()
                .endpoint(|| async move { anyhow::Ok(()) })
                .dispatch(dptree::deps![msg, update, me, cfg])
                .await,
            ControlFlow::Break(_)
        ));
    }

    #[tokio::test]
    async fn test_auth_filter_allow_no_users() {
        let cfg = create_config(vec![], false);
        let me = create_me();
        let msg = create_message("");
        let update = create_update(&msg);

        assert!(matches!(
            auth_filter()
                .endpoint(|| async move { anyhow::Ok(()) })
                .dispatch(dptree::deps![msg, update, me, cfg])
                .await,
            ControlFlow::Continue(_)
        ));
    }

    #[tokio::test]
    async fn test_auth_filter_allow_user() {
        // Matches the sender id used by `create_message`.
        let cfg = create_config(vec![123456789], false);
        let me = create_me();
        let msg = create_message("");
        let update = create_update(&msg);

        assert!(matches!(
            auth_filter()
                .endpoint(|| async move { anyhow::Ok(()) })
                .dispatch(dptree::deps![msg, update, me, cfg])
                .await,
            ControlFlow::Break(_)
        ));
    }

    #[tokio::test]
    async fn test_auth_filter_allow_chat() {
        // Matches the chat id used by `create_message`.
        let cfg = create_config(vec![1234567890], false);
        let me = create_me();
        let msg = create_message("");
        let update = create_update(&msg);

        assert!(matches!(
            auth_filter()
                .endpoint(|| async move { anyhow::Ok(()) })
                .dispatch(dptree::deps![msg, update, me, cfg])
                .await,
            ControlFlow::Break(_)
        ));
    }
}
374}