stable_diffusion_bot/bot/handlers/
mod.rs

1use anyhow::anyhow;
2use teloxide::{
3    dispatching::UpdateHandler,
4    prelude::*,
5    types::{Me, ParseMode},
6    utils::{command::BotCommands, markdown},
7};
8
9use crate::BotState;
10
11use super::{ConfigParameters, DiffusionDialogue, State};
12
13mod image;
14pub(crate) use image::*;
15
16mod settings;
17pub(crate) use settings::*;
18
// Commands that any user may send, whether or not their chat is allowed to
// generate images. `rename_rule = "lowercase"` makes them `/help`, `/start`
// and `/settings`; the `description` attributes are what `descriptions()`
// renders in the `/help` reply of `unauthenticated_commands_handler`.
// NOTE(review): plain `//` comments are used instead of `///` because the
// `BotCommands` derive may treat doc comments as command descriptions.
#[derive(BotCommands, Clone)]
#[command(rename_rule = "lowercase", description = "Simple commands")]
pub(crate) enum UnauthenticatedCommands {
    // Lists the commands available to the sender.
    #[command(description = "show help message.")]
    Help,
    // Resets the dialogue to `State::Ready` with default parameters.
    #[command(description = "start the bot.")]
    Start,
    // Placeholder: the handler currently replies "Sorry, not yet implemented.".
    #[command(description = "change settings.")]
    Settings,
}
29
30pub(crate) async fn unauthenticated_commands_handler(
31    cfg: ConfigParameters,
32    bot: Bot,
33    me: teloxide::types::Me,
34    msg: Message,
35    cmd: UnauthenticatedCommands,
36    dialogue: DiffusionDialogue,
37) -> anyhow::Result<()> {
38    let text = match cmd {
39        UnauthenticatedCommands::Help => {
40            if cfg.chat_is_allowed(&msg.chat.id)
41                || cfg.chat_is_allowed(&msg.from().unwrap().id.into())
42            {
43                format!(
44                    "{}\n\n{}\n\n{}",
45                    UnauthenticatedCommands::descriptions(),
46                    SettingsCommands::descriptions(),
47                    GenCommands::descriptions()
48                )
49            } else if msg.chat.is_group() || msg.chat.is_supergroup() {
50                UnauthenticatedCommands::descriptions()
51                    .username_from_me(&me)
52                    .to_string()
53            } else {
54                UnauthenticatedCommands::descriptions().to_string()
55            }
56        }
57        UnauthenticatedCommands::Start => {
58            dialogue
59                .update(State::Ready {
60                    bot_state: BotState::default(),
61                    txt2img: cfg.txt2img_api.gen_params(None),
62                    img2img: cfg.img2img_api.gen_params(None),
63                })
64                .await
65                .map_err(|e| anyhow!(e))?;
66            "This bot generates images using stable diffusion! Enter a prompt to get started!"
67                .to_owned()
68        }
69        UnauthenticatedCommands::Settings => "Sorry, not yet implemented.".to_owned(),
70    };
71
72    bot.send_message(msg.chat.id, markdown::escape(&text))
73        .parse_mode(ParseMode::MarkdownV2)
74        .await?;
75
76    Ok(())
77}
78
79pub(crate) fn filter_map_bot_state() -> UpdateHandler<anyhow::Error> {
80    dptree::filter_map(|state: State| match state {
81        State::Ready { bot_state, .. } => Some(bot_state),
82        _ => None,
83    })
84}
85
86pub(crate) fn filter_map_settings() -> UpdateHandler<anyhow::Error> {
87    dptree::filter_map(|state: State| match state {
88        State::Ready {
89            txt2img, img2img, ..
90        } => Some((txt2img, img2img)),
91        _ => None,
92    })
93}
94
95pub(crate) fn auth_filter() -> UpdateHandler<anyhow::Error> {
96    dptree::filter(|cfg: ConfigParameters, upd: Update| {
97        upd.chat()
98            .map(|chat| cfg.chat_is_allowed(&chat.id))
99            .unwrap_or_default()
100            || upd
101                .user()
102                .map(|user| cfg.chat_is_allowed(&user.id.into()))
103                .unwrap_or_default()
104    })
105}
106
107pub fn filter_command<C>() -> UpdateHandler<anyhow::Error>
108where
109    C: BotCommands + Send + Sync + 'static,
110{
111    dptree::filter_map(move |message: Message, me: Me| {
112        let bot_name = me.user.username.expect("Bots must have a username");
113        message
114            .text()
115            .and_then(|text| C::parse(text, &bot_name).ok())
116            .or_else(|| {
117                message
118                    .caption()
119                    .and_then(|text| C::parse(text, &bot_name).ok())
120            })
121    })
122}
123
124pub(crate) fn unauth_command_filter() -> UpdateHandler<anyhow::Error> {
125    Update::filter_message().chain(teloxide::filter_command::<UnauthenticatedCommands, _>())
126}
127
128pub(crate) fn unauth_command_handler() -> UpdateHandler<anyhow::Error> {
129    unauth_command_filter().endpoint(unauthenticated_commands_handler)
130}
131
132pub(crate) fn authenticated_command_handler() -> UpdateHandler<anyhow::Error> {
133    auth_filter()
134        .branch(settings_schema())
135        .branch(image_schema())
136}
137
#[cfg(test)]
mod tests {
    // Unit tests for the filters in this module. Each test dispatches a single
    // synthetic update through a filter chained to a no-op endpoint:
    // `ControlFlow::Break(_)` means the endpoint ran (the filter accepted the
    // update), and `ControlFlow::Continue(_)` means it was rejected.
    use super::*;
    use async_trait::async_trait;
    use sal_e_api::{
        GenParams, Img2ImgApi, Img2ImgApiError, Img2ImgParams, Response, Txt2ImgApi,
        Txt2ImgApiError, Txt2ImgParams,
    };
    use teloxide::types::{Me, UpdateKind, User};

    /// Builds a private-chat message from user `123456789` in chat `1234567890`
    /// by deserializing a hand-written Bot API JSON payload.
    ///
    /// NOTE(review): `text` is interpolated verbatim into a JSON string
    /// literal without escaping, so it must not contain `"` or `\`.
    fn create_message(text: &str) -> Message {
        let json = format!(
            r#"{{
          "message_id": 123456,
          "from": {{
           "id": 123456789,
           "is_bot": false,
           "first_name": "Stable",
           "last_name": "Diffusion",
           "username": "sd",
           "language_code": "en"
          }},
          "chat": {{
           "id": 1234567890,
           "first_name": "Stable",
           "last_name": "Diffusion",
           "username": "sd",
           "type": "private"
          }},
          "date": 1634567890,
          "text": "{}"
         }}"#,
            text
        );
        serde_json::from_str::<Message>(&json).unwrap()
    }

    /// Returns the bot's own identity (username `sdbot`), supplied as the `Me`
    /// dependency for command parsing.
    fn create_me() -> Me {
        Me {
            user: User {
                id: UserId(123456780),
                is_bot: true,
                first_name: "Stable Diffusion".to_string(),
                last_name: None,
                username: Some("sdbot".to_string()),
                language_code: Some("en".to_string()),
                is_premium: false,
                added_to_attachment_menu: false,
            },
            can_join_groups: false,
            can_read_all_group_messages: false,
            supports_inline_queries: false,
        }
    }

    /// Stub generation backend: hands out default parameters and fails every
    /// generation request with "Not implemented".
    #[derive(Debug, Clone, Default)]
    struct MockApi;

    #[async_trait]
    impl Txt2ImgApi for MockApi {
        fn gen_params(&self, _user_settings: Option<&dyn GenParams>) -> Box<dyn GenParams> {
            Box::<Txt2ImgParams>::default()
        }

        async fn txt2img(&self, _config: &dyn GenParams) -> Result<Response, Txt2ImgApiError> {
            Err(anyhow!("Not implemented"))?
        }
    }

    #[async_trait]
    impl Img2ImgApi for MockApi {
        fn gen_params(&self, _user_settings: Option<&dyn GenParams>) -> Box<dyn GenParams> {
            Box::<Img2ImgParams>::default()
        }

        async fn img2img(&self, _config: &dyn GenParams) -> Result<Response, Img2ImgApiError> {
            Err(anyhow!("Not implemented"))?
        }
    }

    /// Builds a config whose allow list holds `allowed_users` (as `ChatId`s),
    /// with both generation APIs backed by `MockApi`.
    fn create_config(allowed_users: Vec<i64>, allow_all_users: bool) -> ConfigParameters {
        ConfigParameters {
            allowed_users: allowed_users.into_iter().map(ChatId).collect(),
            allow_all_users,
            txt2img_api: Box::new(MockApi),
            img2img_api: Box::new(MockApi),
        }
    }

    /// `/help` is recognized by `unauth_command_filter`.
    #[tokio::test]
    async fn test_unauth_command_filter_help() {
        let me = create_me();

        let msg = create_message("/help");

        let update = Update {
            id: 1,
            kind: UpdateKind::Message(msg.clone()),
        };

        assert!(matches!(
            unauth_command_filter()
                .endpoint(|| async move { anyhow::Ok(()) })
                .dispatch(dptree::deps![msg, update, me])
                .await,
            ControlFlow::Break(_)
        ));
    }

    /// `/start` is recognized by `unauth_command_filter`.
    /// NOTE(review): despite its name, this test exercises only the filter,
    /// not `unauthenticated_commands_handler`; consider renaming it to
    /// `test_unauth_command_filter_start`.
    #[tokio::test]
    async fn test_unauth_command_handler_start() {
        let me = create_me();

        let msg = create_message("/start");

        let update = Update {
            id: 1,
            kind: UpdateKind::Message(msg.clone()),
        };

        assert!(matches!(
            unauth_command_filter()
                .endpoint(|| async move { anyhow::Ok(()) })
                .dispatch(dptree::deps![msg, update, me])
                .await,
            ControlFlow::Break(_)
        ));
    }

    /// `/settings` is recognized by `unauth_command_filter`.
    #[tokio::test]
    async fn test_unauth_command_filter_settings() {
        let me = create_me();

        let msg = create_message("/settings");

        let update = Update {
            id: 1,
            kind: UpdateKind::Message(msg.clone()),
        };

        assert!(matches!(
            unauth_command_filter()
                .endpoint(|| async move { anyhow::Ok(()) })
                .dispatch(dptree::deps![msg, update, me])
                .await,
            ControlFlow::Break(_)
        ));
    }

    /// With `allow_all_users` set and an empty allow list, `auth_filter`
    /// accepts the update.
    #[tokio::test]
    async fn test_auth_filter_allow_all_users() {
        let cfg = create_config(vec![], true);

        let me = create_me();

        let msg = create_message("");

        let update = Update {
            id: 1,
            kind: UpdateKind::Message(msg.clone()),
        };

        assert!(matches!(
            auth_filter()
                .endpoint(|| async move { anyhow::Ok(()) })
                .dispatch(dptree::deps![msg, update, me, cfg])
                .await,
            ControlFlow::Break(_)
        ));
    }

    /// With an empty allow list and `allow_all_users` unset, `auth_filter`
    /// rejects the update.
    #[tokio::test]
    async fn test_auth_filter_allow_no_users() {
        let cfg = create_config(vec![], false);

        let me = create_me();

        let msg = create_message("");

        let update = Update {
            id: 1,
            kind: UpdateKind::Message(msg.clone()),
        };

        assert!(matches!(
            auth_filter()
                .endpoint(|| async move { anyhow::Ok(()) })
                .dispatch(dptree::deps![msg, update, me, cfg])
                .await,
            ControlFlow::Continue(_)
        ));
    }

    /// Allow-listing only the sender's user id (`123456789`, which differs from
    /// the chat id in `create_message`) is enough to pass `auth_filter`.
    #[tokio::test]
    async fn test_auth_filter_allow_user() {
        let cfg = create_config(vec![123456789], false);

        let me = create_me();

        let msg = create_message("");

        let update = Update {
            id: 1,
            kind: UpdateKind::Message(msg.clone()),
        };

        assert!(matches!(
            auth_filter()
                .endpoint(|| async move { anyhow::Ok(()) })
                .dispatch(dptree::deps![msg, update, me, cfg])
                .await,
            ControlFlow::Break(_)
        ));
    }

    /// Allow-listing only the chat id (`1234567890`) is enough to pass
    /// `auth_filter`.
    #[tokio::test]
    async fn test_auth_filter_allow_chat() {
        let cfg = create_config(vec![1234567890], false);

        let me = create_me();

        let msg = create_message("");

        let update = Update {
            id: 1,
            kind: UpdateKind::Message(msg.clone()),
        };

        assert!(matches!(
            auth_filter()
                .endpoint(|| async move { anyhow::Ok(()) })
                .dispatch(dptree::deps![msg, update, me, cfg])
                .await,
            ControlFlow::Break(_)
        ));
    }
}