stable_diffusion_bot/bot/handlers/
settings.rs

1use anyhow::anyhow;
2use itertools::Itertools as _;
3use sal_e_api::GenParams;
4use teloxide::{
5    dispatching::UpdateHandler,
6    dptree::case,
7    macros::BotCommands,
8    payloads::setters::*,
9    prelude::*,
10    types::{InlineKeyboardButton, InlineKeyboardMarkup},
11};
12use tracing::{error, warn};
13
14use crate::{bot::ConfigParameters, BotState};
15
16use super::{filter_map_bot_state, filter_map_settings, DiffusionDialogue, State};
17
// Commands recognised by `settings_command_handler`, which routes each variant to the
// endpoint that opens the corresponding settings keyboard.
// NOTE(review): the `///` comments below are intentionally left unchanged because the
// `BotCommands` derive may read doc comments — confirm before rewording them.
/// BotCommands for settings.
#[derive(BotCommands, Clone)]
#[command(rename_rule = "lowercase", description = "Authenticated commands")]
pub(crate) enum SettingsCommands {
    // Routed to `handle_txt2img_settings_command`.
    /// Command to set txt2img settings
    #[command(description = "txt2img settings")]
    Txt2ImgSettings,
    // Routed to `handle_img2img_settings_command`.
    /// Command to set img2img settings
    #[command(description = "img2img settings")]
    Img2ImgSettings,
}
29
/// User-configurable image generation settings.
///
/// This is a plain snapshot of optional values: a field that is `None` has no value to
/// display, and `keyboard` omits the corresponding button.
// `batch_size` and `sampler_index` are populated by the `From` conversion but are not read
// by `keyboard`, hence the lint allowance.
#[allow(dead_code)]
#[derive(Clone, Debug, Default, PartialEq)]
pub(crate) struct Settings {
    /// Number of sampling steps.
    pub steps: Option<u32>,
    /// Random seed.
    pub seed: Option<i64>,
    /// Number of images to generate per batch.
    pub batch_size: Option<u32>,
    /// Number of batches of images to generate.
    pub n_iter: Option<u32>,
    /// CFG scale.
    pub cfg_scale: Option<f32>,
    /// Image width.
    pub width: Option<u32>,
    /// Image height.
    pub height: Option<u32>,
    /// Negative prompt.
    pub negative_prompt: Option<String>,
    /// Denoising strength. Only used for img2img.
    pub denoising_strength: Option<f32>,
    /// Sampler name.
    pub sampler_index: Option<String>,
}
54
55impl Settings {
56    /// Build an inline keyboard to configure settings.
57    pub fn keyboard(&self) -> InlineKeyboardMarkup {
58        InlineKeyboardMarkup::new(
59            [
60                self.steps.map(|steps| {
61                    InlineKeyboardButton::callback(format!("Steps: {}", steps), "settings_steps")
62                }),
63                self.seed.map(|seed| {
64                    InlineKeyboardButton::callback(format!("Seed: {}", seed), "settings_seed")
65                }),
66                self.n_iter.map(|n_iter| {
67                    InlineKeyboardButton::callback(
68                        format!("Batch Count: {}", n_iter),
69                        "settings_count",
70                    )
71                }),
72                self.cfg_scale.map(|cfg_scale| {
73                    InlineKeyboardButton::callback(
74                        format!("CFG Scale: {}", cfg_scale),
75                        "settings_cfg",
76                    )
77                }),
78                self.width.map(|width| {
79                    InlineKeyboardButton::callback(format!("Width: {}", width), "settings_width")
80                }),
81                self.height.map(|height| {
82                    InlineKeyboardButton::callback(format!("Height: {}", height), "settings_height")
83                }),
84                self.negative_prompt.as_ref().map(|_| {
85                    InlineKeyboardButton::callback(
86                        "Negative Prompt".to_owned(),
87                        "settings_negative",
88                    )
89                }),
90                self.denoising_strength.map(|denoising_strength| {
91                    InlineKeyboardButton::callback(
92                        format!("Denoising Strength: {}", denoising_strength),
93                        "settings_denoising",
94                    )
95                }),
96                Some(InlineKeyboardButton::callback(
97                    "Cancel".to_owned(),
98                    "settings_back",
99                )),
100            ]
101            .into_iter()
102            .flatten()
103            .chunks(2)
104            .into_iter()
105            .map(Iterator::collect)
106            .collect::<Vec<Vec<_>>>(),
107        )
108    }
109}
110
111impl From<&dyn GenParams> for Settings {
112    fn from(value: &dyn GenParams) -> Self {
113        Self {
114            steps: value.steps(),
115            seed: value.seed(),
116            batch_size: value.batch_size(),
117            n_iter: value.count(),
118            cfg_scale: value.cfg(),
119            width: value.width(),
120            height: value.height(),
121            negative_prompt: value.negative_prompt().clone(),
122            denoising_strength: value.denoising(),
123            sampler_index: value.sampler().clone(),
124        }
125    }
126}
127
128pub(crate) fn filter_callback_query_chat_id() -> UpdateHandler<anyhow::Error> {
129    dptree::filter_map(|q: CallbackQuery| q.message.map(|m| m.chat.id))
130}
131
132pub(crate) async fn handle_message_expired(bot: Bot, q: CallbackQuery) -> anyhow::Result<()> {
133    bot.answer_callback_query(q.id)
134        .cache_time(60)
135        .text("Sorry, this message is no longer available.")
136        .await?;
137    Ok(())
138}
139
140pub(crate) fn filter_callback_query_parent() -> UpdateHandler<anyhow::Error> {
141    dptree::filter_map(|q: CallbackQuery| q.message.and_then(|m| m.reply_to_message().cloned()))
142}
143
144pub(crate) async fn handle_parent_unavailable(bot: Bot, q: CallbackQuery) -> anyhow::Result<()> {
145    bot.answer_callback_query(q.id)
146        .cache_time(60)
147        .text("Oops, something went wrong.")
148        .await?;
149    Ok(())
150}
151
152pub(crate) async fn handle_settings(
153    bot: Bot,
154    dialogue: DiffusionDialogue,
155    (txt2img, img2img): (Box<dyn GenParams>, Box<dyn GenParams>),
156    q: CallbackQuery,
157    chat_id: ChatId,
158    parent: Message,
159) -> anyhow::Result<()> {
160    let settings = if parent.photo().is_some() {
161        let settings = Settings::from(img2img.as_ref());
162        dialogue
163            .update(State::Ready {
164                bot_state: BotState::SettingsImg2Img { selection: None },
165                txt2img,
166                img2img,
167            })
168            .await
169            .map_err(|e| anyhow!(e))?;
170        settings
171    } else if parent.text().is_some() {
172        let settings = Settings::from(txt2img.as_ref());
173        dialogue
174            .update(State::Ready {
175                bot_state: BotState::SettingsTxt2Img { selection: None },
176                txt2img: txt2img.clone(),
177                img2img,
178            })
179            .await
180            .map_err(|e| anyhow!(e))?;
181        settings
182    } else {
183        bot.answer_callback_query(q.id)
184            .cache_time(60)
185            .text("Oops, something went wrong.")
186            .await?;
187        return Ok(());
188    };
189
190    if let Err(e) = bot.answer_callback_query(q.id).await {
191        warn!("Failed to answer settings callback query: {}", e)
192    }
193    bot.send_message(chat_id, "Please make a selection.")
194        .reply_markup(settings.keyboard())
195        .send()
196        .await?;
197
198    Ok(())
199}
200
201pub(crate) async fn handle_settings_button(
202    bot: Bot,
203    cfg: ConfigParameters,
204    dialogue: DiffusionDialogue,
205    (_, txt2img, img2img): (Option<String>, Box<dyn GenParams>, Box<dyn GenParams>),
206    q: CallbackQuery,
207) -> anyhow::Result<()> {
208    let (message, data) = match q {
209        CallbackQuery {
210            message: Some(message),
211            data: Some(data),
212            ..
213        } => (message, data),
214        _ => {
215            bot.answer_callback_query(q.id)
216                .cache_time(60)
217                .text("Sorry, something went wrong.")
218                .await?;
219            return Ok(());
220        }
221    };
222
223    let setting = match data.strip_prefix("settings_") {
224        Some(setting) => setting,
225        None => {
226            bot.answer_callback_query(q.id)
227                .cache_time(60)
228                .text("Sorry, something went wrong.")
229                .await?;
230            return Ok(());
231        }
232    };
233
234    if setting == "back" {
235        dialogue
236            .update(State::Ready {
237                bot_state: BotState::Generate,
238                txt2img,
239                img2img,
240            })
241            .await
242            .map_err(|e| anyhow!(e))?;
243        if let Err(e) = bot.answer_callback_query(q.id).text("Canceled.").await {
244            warn!("Failed to answer back button callback query: {}", e)
245        }
246
247        if let Err(e) = bot.delete_message(message.chat.id, message.id).await {
248            error!("Failed to delete message: {:?}", e);
249            bot.edit_message_text(message.chat.id, message.id, "Please enter a prompt.")
250                .reply_markup(InlineKeyboardMarkup::new([[]]))
251                .await?;
252        }
253        return Ok(());
254    }
255
256    let mut state = dialogue
257        .get()
258        .await
259        .map_err(|e| anyhow!(e))?
260        .unwrap_or_else(|| {
261            State::new_with_defaults(
262                cfg.txt2img_api.gen_params(None),
263                cfg.img2img_api.gen_params(None),
264            )
265        });
266    match &mut state {
267        State::Ready {
268            bot_state: BotState::SettingsTxt2Img { selection },
269            ..
270        }
271        | State::Ready {
272            bot_state: BotState::SettingsImg2Img { selection },
273            ..
274        } => *selection = Some(setting.to_string()),
275        _ => {
276            bot.answer_callback_query(q.id)
277                .cache_time(60)
278                .text("Sorry, something went wrong.")
279                .await?;
280            return Ok(());
281        }
282    }
283
284    if let Err(e) = bot.answer_callback_query(q.id).await {
285        warn!("Failed to answer settings button callback query: {}", e)
286    }
287    dialogue.update(state).await.map_err(|e| anyhow!(e))?;
288
289    bot.send_message(message.chat.id, "Please enter a new value.")
290        .await?;
291
292    Ok(())
293}
294
295fn update_txt2img_setting<S1, S2>(
296    txt2img: &mut dyn GenParams,
297    setting: S1,
298    value: S2,
299) -> anyhow::Result<()>
300where
301    S1: AsRef<str>,
302    S2: AsRef<str>,
303{
304    let value = value.as_ref();
305    match setting.as_ref() {
306        "steps" => txt2img.set_steps(value.parse()?),
307        "seed" => txt2img.set_seed(value.parse()?),
308        "count" => txt2img.set_count(value.parse()?),
309        "cfg" => txt2img.set_cfg(value.parse()?),
310        "width" => txt2img.set_width(value.parse()?),
311        "height" => txt2img.set_height(value.parse()?),
312        "negative" => txt2img.set_negative_prompt(value.to_owned()),
313        "denoising" => txt2img.set_denoising(value.parse()?),
314        _ => return Err(anyhow!("Got invalid setting: {}", setting.as_ref())),
315    }
316    Ok(())
317}
318
319fn update_img2img_setting<S1, S2>(
320    img2img: &mut dyn GenParams,
321    setting: S1,
322    value: S2,
323) -> anyhow::Result<()>
324where
325    S1: AsRef<str>,
326    S2: AsRef<str>,
327{
328    let value = value.as_ref();
329    match setting.as_ref() {
330        "steps" => img2img.set_steps(200.min(value.parse()?)),
331        "seed" => img2img.set_seed((-1).max(value.parse()?)),
332        "count" => img2img.set_count(value.parse::<u32>()?.clamp(1, 10)),
333        "cfg" => img2img.set_cfg(value.parse::<f32>()?.clamp(0.0, 20.0)),
334        "width" => img2img.set_width({
335            let mut value = value.parse::<u32>()?;
336            value -= value % 64;
337            value.clamp(64, 1024)
338        }),
339        "height" => img2img.set_height({
340            let mut value = value.parse::<u32>()?;
341            value -= value % 64;
342            value.clamp(64, 1024)
343        }),
344        "negative" => img2img.set_negative_prompt(value.to_owned()),
345        "denoising" => img2img.set_denoising(value.parse::<f32>()?.clamp(0.0, 1.0)),
346        _ => return Err(anyhow!("invalid setting: {}", setting.as_ref())),
347    }
348    Ok(())
349}
350
351pub(crate) fn state_or_default() -> UpdateHandler<anyhow::Error> {
352    dptree::map_async(
353        |cfg: ConfigParameters, dialogue: DiffusionDialogue| async move {
354            let result = dialogue.get().await;
355            if let Err(ref err) = result {
356                error!("Failed to get state: {:?}", err);
357            }
358            result.ok().flatten().unwrap_or_else(|| {
359                State::new_with_defaults(
360                    cfg.txt2img_api.gen_params(None),
361                    cfg.img2img_api.gen_params(None),
362                )
363            })
364        },
365    )
366}
367
368pub(crate) async fn update_settings_value(
369    bot: Bot,
370    dialogue: DiffusionDialogue,
371    chat_id: ChatId,
372    settings: Settings,
373    state: State,
374) -> anyhow::Result<()> {
375    dialogue.update(state).await.map_err(|e| anyhow!(e))?;
376
377    bot.send_message(chat_id, "Please make a selection.")
378        .reply_markup(settings.keyboard())
379        .await?;
380
381    Ok(())
382}
383
384pub(crate) async fn handle_txt2img_settings_value(
385    bot: Bot,
386    dialogue: DiffusionDialogue,
387    msg: Message,
388    text: String,
389    (selection, mut txt2img, img2img): (Option<String>, Box<dyn GenParams>, Box<dyn GenParams>),
390) -> anyhow::Result<()> {
391    if let Some(ref setting) = selection {
392        if let Err(e) = update_txt2img_setting(txt2img.as_mut(), setting, text) {
393            bot.send_message(msg.chat.id, format!("Please enter a valid value: {e:?}."))
394                .await?;
395            return Ok(());
396        }
397    }
398
399    let bot_state = BotState::SettingsTxt2Img { selection: None };
400
401    update_settings_value(
402        bot,
403        dialogue,
404        msg.chat.id,
405        Settings::from(txt2img.as_ref()),
406        State::Ready {
407            bot_state,
408            txt2img,
409            img2img,
410        },
411    )
412    .await
413}
414
415pub(crate) async fn handle_img2img_settings_value(
416    bot: Bot,
417    dialogue: DiffusionDialogue,
418    msg: Message,
419    text: String,
420    (selection, txt2img, mut img2img): (Option<String>, Box<dyn GenParams>, Box<dyn GenParams>),
421) -> anyhow::Result<()> {
422    if let Some(ref setting) = selection {
423        if let Err(e) = update_img2img_setting(img2img.as_mut(), setting, text) {
424            bot.send_message(msg.chat.id, format!("Please enter a valid value: {e:?}."))
425                .await?;
426            return Ok(());
427        }
428    }
429
430    let bot_state = BotState::SettingsImg2Img { selection: None };
431
432    update_settings_value(
433        bot,
434        dialogue,
435        msg.chat.id,
436        Settings::from(img2img.as_ref()),
437        State::Ready {
438            bot_state,
439            txt2img,
440            img2img,
441        },
442    )
443    .await
444}
445
446pub(crate) fn map_settings() -> UpdateHandler<anyhow::Error> {
447    dptree::map(|cfg: ConfigParameters, state: State| match state {
448        State::Ready {
449            txt2img, img2img, ..
450        } => (txt2img, img2img),
451        State::New => (
452            cfg.txt2img_api.gen_params(None),
453            cfg.img2img_api.gen_params(None),
454        ),
455    })
456}
457
458async fn handle_img2img_settings_command(
459    msg: Message,
460    bot: Bot,
461    dialogue: DiffusionDialogue,
462    (txt2img, img2img): (Box<dyn GenParams>, Box<dyn GenParams>),
463) -> anyhow::Result<()> {
464    let settings = Settings::from(img2img.as_ref());
465    dialogue
466        .update(State::Ready {
467            bot_state: BotState::SettingsImg2Img { selection: None },
468            txt2img,
469            img2img,
470        })
471        .await
472        .map_err(|e| anyhow!(e))?;
473    bot.send_message(msg.chat.id, "Please make a selection.")
474        .reply_markup(settings.keyboard())
475        .send()
476        .await?;
477    Ok(())
478}
479
480async fn handle_txt2img_settings_command(
481    msg: Message,
482    bot: Bot,
483    dialogue: DiffusionDialogue,
484    (txt2img, img2img): (Box<dyn GenParams>, Box<dyn GenParams>),
485) -> anyhow::Result<()> {
486    let settings = Settings::from(txt2img.as_ref());
487    dialogue
488        .update(State::Ready {
489            bot_state: BotState::SettingsTxt2Img { selection: None },
490            txt2img,
491            img2img,
492        })
493        .await
494        .map_err(|e| anyhow!(e))?;
495    bot.send_message(msg.chat.id, "Please make a selection.")
496        .reply_markup(settings.keyboard())
497        .send()
498        .await?;
499    Ok(())
500}
501
502async fn handle_invalid_setting_value(bot: Bot, msg: Message) -> anyhow::Result<()> {
503    bot.send_message(msg.chat.id, "Please enter a valid value.")
504        .await?;
505    Ok(())
506}
507
508pub(crate) fn settings_command_handler() -> UpdateHandler<anyhow::Error> {
509    Update::filter_message()
510        .filter_command::<SettingsCommands>()
511        .chain(state_or_default())
512        .chain(map_settings())
513        .branch(case![SettingsCommands::Txt2ImgSettings].endpoint(handle_txt2img_settings_command))
514        .branch(case![SettingsCommands::Img2ImgSettings].endpoint(handle_img2img_settings_command))
515}
516
517pub(crate) fn filter_settings_callback_query() -> UpdateHandler<anyhow::Error> {
518    Update::filter_callback_query()
519        .filter(|q: CallbackQuery| q.data.is_some_and(|data| data.starts_with("settings")))
520}
521
522pub(crate) fn filter_settings_state() -> UpdateHandler<anyhow::Error> {
523    dptree::filter(|state: State| {
524        let bot_state = match state {
525            State::Ready { bot_state, .. } => bot_state,
526            _ => return false,
527        };
528        matches!(
529            bot_state,
530            BotState::SettingsTxt2Img { .. } | BotState::SettingsImg2Img { .. }
531        )
532    })
533}
534
535pub(crate) fn filter_map_settings_state() -> UpdateHandler<anyhow::Error> {
536    dptree::filter_map(|state: State| {
537        let (bot_state, txt2img, img2img) = match state {
538            State::Ready {
539                bot_state,
540                txt2img,
541                img2img,
542            } => (bot_state, txt2img, img2img),
543            _ => return None,
544        };
545        match bot_state {
546            BotState::SettingsTxt2Img { selection } => Some((selection, txt2img, img2img)),
547            BotState::SettingsImg2Img { selection } => Some((selection, txt2img, img2img)),
548            _ => None,
549        }
550    })
551}
552
553pub(crate) fn settings_schema() -> UpdateHandler<anyhow::Error> {
554    let callback_handler = filter_settings_callback_query()
555        .branch(
556            filter_map_bot_state()
557                .chain(case![BotState::Generate])
558                .chain(filter_map_settings())
559                .branch(
560                    filter_callback_query_chat_id()
561                        .branch(filter_callback_query_parent().endpoint(handle_settings))
562                        .endpoint(handle_parent_unavailable),
563                )
564                .endpoint(handle_message_expired),
565        )
566        .branch(filter_map_settings_state().endpoint(handle_settings_button));
567
568    let message_handler = Update::filter_message()
569        .branch(
570            Message::filter_text()
571                .chain(filter_map_settings_state())
572                .chain(state_or_default())
573                .chain(filter_map_bot_state())
574                .branch(
575                    case![BotState::SettingsTxt2Img { selection }]
576                        .endpoint(handle_txt2img_settings_value),
577                )
578                .branch(
579                    case![BotState::SettingsImg2Img { selection }]
580                        .endpoint(handle_img2img_settings_value),
581                )
582                .endpoint(|| async { Err(anyhow!("Invalid settings state")) }),
583        )
584        .branch(filter_settings_state().endpoint(handle_invalid_setting_value));
585
586    dptree::entry()
587        .branch(settings_command_handler())
588        .branch(message_handler)
589        .branch(callback_handler)
590}
591
#[cfg(test)]
mod tests {
    //! Unit tests for the settings filters and mappers.
    //!
    //! Each test dispatches a handler with hand-built dependencies and checks whether the
    //! chain reached its endpoint (`ControlFlow::Break`) or not (`ControlFlow::Continue`).

    use async_trait::async_trait;
    use sal_e_api::{
        Img2ImgApi, Img2ImgApiError, Img2ImgParams, Response, Txt2ImgApi, Txt2ImgApiError,
        Txt2ImgParams,
    };
    use stable_diffusion_api::{Img2ImgRequest, Txt2ImgRequest};
    use teloxide::types::{UpdateKind, User};

    use super::*;
    use crate::BotState;

    /// Build a callback-query `Update` carrying `data` and no attached message.
    fn create_callback_query_update(data: Option<String>) -> Update {
        let query = CallbackQuery {
            id: "123456".to_string(),
            from: User {
                id: UserId(123456780),
                is_bot: true,
                first_name: "Stable Diffusion".to_string(),
                last_name: None,
                username: Some("sdbot".to_string()),
                language_code: Some("en".to_string()),
                is_premium: false,
                added_to_attachment_menu: false,
            },
            message: None,
            inline_message_id: None,
            chat_instance: "123456".to_string(),
            data,
            game_short_name: None,
        };

        Update {
            id: 1,
            kind: UpdateKind::CallbackQuery(query),
        }
    }

    /// Data starting with `settings` passes the callback-query filter.
    #[tokio::test]
    async fn test_filter_settings_query() {
        let update = create_callback_query_update(Some("settings".to_string()));

        assert!(matches!(
            filter_settings_callback_query()
                .endpoint(|| async { anyhow::Ok(()) })
                .dispatch(dptree::deps![update])
                .await,
            ControlFlow::Break(_)
        ));
    }

    /// A callback query without data is rejected.
    #[tokio::test]
    async fn test_filter_settings_query_none() {
        let update = create_callback_query_update(None);

        assert!(matches!(
            filter_settings_callback_query()
                .endpoint(|| async { anyhow::Ok(()) })
                .dispatch(dptree::deps![update])
                .await,
            ControlFlow::Continue(_)
        ));
    }

    /// Data that does not start with `settings` is rejected.
    #[tokio::test]
    async fn test_filter_settings_query_bad_data() {
        let update = create_callback_query_update(Some("bad_data".to_string()));

        assert!(matches!(
            filter_settings_callback_query()
                .endpoint(|| async { anyhow::Ok(()) })
                .dispatch(dptree::deps![update])
                .await,
            ControlFlow::Continue(_)
        ));
    }

    /// The txt2img settings state passes `filter_settings_state`.
    #[tokio::test]
    async fn test_filter_settings_state_txt2img() {
        assert!(matches!(
            filter_settings_state()
                .endpoint(|| async { anyhow::Ok(()) })
                .dispatch(dptree::deps![State::Ready {
                    bot_state: BotState::SettingsTxt2Img { selection: None },
                    txt2img: Box::<Txt2ImgParams>::default(),
                    img2img: Box::<Img2ImgParams>::default()
                }])
                .await,
            ControlFlow::Break(_)
        ));
    }

    /// The img2img settings state passes `filter_settings_state`.
    #[tokio::test]
    async fn test_filter_settings_state_img2img() {
        assert!(matches!(
            filter_settings_state()
                .endpoint(|| async { anyhow::Ok(()) })
                .dispatch(dptree::deps![State::Ready {
                    bot_state: BotState::SettingsImg2Img { selection: None },
                    txt2img: Box::<Txt2ImgParams>::default(),
                    img2img: Box::<Img2ImgParams>::default()
                }])
                .await,
            ControlFlow::Break(_)
        ));
    }

    /// `State::New` is rejected by `filter_settings_state`.
    #[tokio::test]
    async fn test_filter_settings_state() {
        assert!(matches!(
            filter_settings_state()
                .endpoint(|| async { anyhow::Ok(()) })
                .dispatch(dptree::deps![State::New])
                .await,
            ControlFlow::Continue(_)
        ));
    }

    /// The txt2img settings state yields the `(selection, txt2img, img2img)` tuple.
    #[tokio::test]
    async fn test_filter_map_settings_state_txt2img() {
        assert!(matches!(
            filter_map_settings_state()
                .endpoint(
                    |(_, _, _): (Option<String>, Box<dyn GenParams>, Box<dyn GenParams>)| async {
                        anyhow::Ok(())
                    }
                )
                .dispatch(dptree::deps![State::Ready {
                    bot_state: BotState::SettingsTxt2Img { selection: None },
                    txt2img: Box::<Txt2ImgParams>::default(),
                    img2img: Box::<Img2ImgParams>::default()
                }])
                .await,
            ControlFlow::Break(_)
        ));
    }

    /// The img2img settings state yields the `(selection, txt2img, img2img)` tuple.
    #[tokio::test]
    async fn test_filter_map_settings_state_img2img() {
        assert!(matches!(
            filter_map_settings_state()
                .endpoint(
                    |(_, _, _): (Option<String>, Box<dyn GenParams>, Box<dyn GenParams>)| async {
                        anyhow::Ok(())
                    }
                )
                .dispatch(dptree::deps![State::Ready {
                    bot_state: BotState::SettingsImg2Img { selection: None },
                    txt2img: Box::<Txt2ImgParams>::default(),
                    img2img: Box::<Img2ImgParams>::default()
                }])
                .await,
            ControlFlow::Continue(_) | ControlFlow::Break(_)
        ) && matches!(
            filter_map_settings_state()
                .endpoint(
                    |(_, _, _): (Option<String>, Box<dyn GenParams>, Box<dyn GenParams>)| async {
                        anyhow::Ok(())
                    }
                )
                .dispatch(dptree::deps![State::Ready {
                    bot_state: BotState::SettingsImg2Img { selection: None },
                    txt2img: Box::<Txt2ImgParams>::default(),
                    img2img: Box::<Img2ImgParams>::default()
                }])
                .await,
            ControlFlow::Break(_)
        ));
    }

    /// `State::New` yields nothing, so the endpoint is never reached.
    #[tokio::test]
    async fn test_filter_map_settings_state() {
        assert!(matches!(
            filter_map_settings_state()
                .endpoint(
                    // Same argument type the filter produces, matching the tests above.
                    |(_, _, _): (Option<String>, Box<dyn GenParams>, Box<dyn GenParams>)| async {
                        anyhow::Ok(())
                    }
                )
                .dispatch(dptree::deps![State::New])
                .await,
            ControlFlow::Continue(_)
        ));
    }

    /// API stub that hands out default parameters and fails every generation request.
    #[derive(Debug, Clone, Default)]
    struct MockApi;

    #[async_trait]
    impl Txt2ImgApi for MockApi {
        fn gen_params(&self, _user_params: Option<&dyn GenParams>) -> Box<dyn GenParams> {
            Box::<Txt2ImgParams>::default()
        }

        async fn txt2img(&self, _config: &dyn GenParams) -> Result<Response, Txt2ImgApiError> {
            Err(anyhow!("Not implemented"))?
        }
    }

    #[async_trait]
    impl Img2ImgApi for MockApi {
        fn gen_params(&self, _user_params: Option<&dyn GenParams>) -> Box<dyn GenParams> {
            Box::<Img2ImgParams>::default()
        }

        async fn img2img(&self, _config: &dyn GenParams) -> Result<Response, Img2ImgApiError> {
            Err(anyhow!("Not implemented"))?
        }
    }

    /// For `State::New`, `map_settings` provides the APIs' default parameters.
    #[tokio::test]
    async fn test_map_settings_default() {
        assert!(matches!(
            map_settings()
                .endpoint(
                    |(txt2img, img2img): (Box<dyn GenParams>, Box<dyn GenParams>)| async move {
                        let txt2img = txt2img
                            .as_ref()
                            .as_any()
                            .downcast_ref::<Txt2ImgParams>()
                            .unwrap();
                        let img2img = img2img
                            .as_ref()
                            .as_any()
                            .downcast_ref::<Img2ImgParams>()
                            .unwrap();
                        assert!(
                            (txt2img, img2img)
                                == (&Txt2ImgParams::default(), &Img2ImgParams::default())
                        );
                        anyhow::Ok(())
                    }
                )
                .dispatch(dptree::deps![
                    ConfigParameters {
                        txt2img_api: Box::new(MockApi),
                        img2img_api: Box::new(MockApi),
                        allowed_users: Default::default(),
                        allow_all_users: false
                    },
                    State::New
                ])
                .await,
            ControlFlow::Break(_)
        ));
    }

    /// For `State::Ready`, `map_settings` passes the stored parameters through unchanged.
    #[tokio::test]
    async fn test_map_settings_ready() {
        let txt2img = Txt2ImgParams {
            user_params: Txt2ImgRequest {
                negative_prompt: Some("test".to_string()),
                ..Txt2ImgRequest::default()
            },
            defaults: Some(Txt2ImgRequest::default()),
        };
        let img2img = Img2ImgParams {
            user_params: Img2ImgRequest {
                negative_prompt: Some("test".to_string()),
                ..Img2ImgRequest::default()
            },
            defaults: Some(Img2ImgRequest::default()),
        };
        assert!(matches!(
            map_settings()
                .endpoint(
                    |(txt2img, img2img): (Box<dyn GenParams>, Box<dyn GenParams>)| async move {
                        let txt2img = txt2img
                            .as_ref()
                            .as_any()
                            .downcast_ref::<Txt2ImgParams>()
                            .unwrap();
                        let img2img = img2img
                            .as_ref()
                            .as_any()
                            .downcast_ref::<Img2ImgParams>()
                            .unwrap();
                        assert!(
                            (txt2img, img2img)
                                == (
                                    &Txt2ImgParams {
                                        user_params: Txt2ImgRequest {
                                            negative_prompt: Some("test".to_string()),
                                            ..Txt2ImgRequest::default()
                                        },
                                        defaults: Some(Txt2ImgRequest::default()),
                                    },
                                    &Img2ImgParams {
                                        user_params: Img2ImgRequest {
                                            negative_prompt: Some("test".to_string()),
                                            ..Img2ImgRequest::default()
                                        },
                                        defaults: Some(Img2ImgRequest::default()),
                                    }
                                )
                        );
                        anyhow::Ok(())
                    }
                )
                .dispatch(dptree::deps![
                    ConfigParameters {
                        txt2img_api: Box::new(MockApi),
                        img2img_api: Box::new(MockApi),
                        allowed_users: Default::default(),
                        allow_all_users: false
                    },
                    State::Ready {
                        bot_state: BotState::Generate,
                        txt2img: Box::new(txt2img),
                        img2img: Box::new(img2img)
                    }
                ])
                .await,
            ControlFlow::Break(_)
        ));
    }
}