1use anyhow::anyhow;
2use itertools::Itertools as _;
3use sal_e_api::GenParams;
4use teloxide::{
5 dispatching::UpdateHandler,
6 dptree::case,
7 macros::BotCommands,
8 payloads::setters::*,
9 prelude::*,
10 types::{InlineKeyboardButton, InlineKeyboardMarkup},
11};
12use tracing::{error, warn};
13
14use crate::{bot::ConfigParameters, BotState};
15
16use super::{filter_map_bot_state, filter_map_settings, DiffusionDialogue, State};
17
// Slash commands that open one of the settings menus.
//
// NOTE(review): with `rename_rule = "lowercase"` these are presumably exposed as
// `/txt2imgsettings` and `/img2imgsettings` — confirm against the teloxide version in use.
// Plain `//` comments are used here on purpose: the `BotCommands` derive may pick up `///`
// doc comments as command descriptions.
#[derive(BotCommands, Clone)]
#[command(rename_rule = "lowercase", description = "Authenticated commands")]
pub(crate) enum SettingsCommands {
    // Handled by `handle_txt2img_settings_command`.
    #[command(description = "txt2img settings")]
    Txt2ImgSettings,
    // Handled by `handle_img2img_settings_command`.
    #[command(description = "img2img settings")]
    Img2ImgSettings,
}
29
/// Snapshot of a parameter set's current values, used to render the settings keyboard.
///
/// Each field is copied from the matching `GenParams` getter. `None` means the parameter set
/// has no value for that setting, and no keyboard button is produced for it.
#[derive(Debug, Clone, Default)]
pub(crate) struct Settings {
    /// Copied from `GenParams::steps`.
    pub steps: Option<u32>,
    /// Copied from `GenParams::seed`.
    pub seed: Option<i64>,
    /// Copied from `GenParams::batch_size`.
    // Not read anywhere in this file (there is no keyboard button for it); the allow is scoped
    // to this field instead of the whole struct so unused *new* fields still get flagged.
    #[allow(dead_code)]
    pub batch_size: Option<u32>,
    /// Copied from `GenParams::count`; shown as "Batch Count".
    pub n_iter: Option<u32>,
    /// Copied from `GenParams::cfg`.
    pub cfg_scale: Option<f32>,
    /// Copied from `GenParams::width`.
    pub width: Option<u32>,
    /// Copied from `GenParams::height`.
    pub height: Option<u32>,
    /// Copied from `GenParams::negative_prompt`; only its presence is shown.
    pub negative_prompt: Option<String>,
    /// Copied from `GenParams::denoising`.
    pub denoising_strength: Option<f32>,
    /// Copied from `GenParams::sampler`.
    // Not read anywhere in this file; see `batch_size` for why the allow is field-scoped.
    #[allow(dead_code)]
    pub sampler_index: Option<String>,
}
54
55impl Settings {
56 pub fn keyboard(&self) -> InlineKeyboardMarkup {
58 InlineKeyboardMarkup::new(
59 [
60 self.steps.map(|steps| {
61 InlineKeyboardButton::callback(format!("Steps: {}", steps), "settings_steps")
62 }),
63 self.seed.map(|seed| {
64 InlineKeyboardButton::callback(format!("Seed: {}", seed), "settings_seed")
65 }),
66 self.n_iter.map(|n_iter| {
67 InlineKeyboardButton::callback(
68 format!("Batch Count: {}", n_iter),
69 "settings_count",
70 )
71 }),
72 self.cfg_scale.map(|cfg_scale| {
73 InlineKeyboardButton::callback(
74 format!("CFG Scale: {}", cfg_scale),
75 "settings_cfg",
76 )
77 }),
78 self.width.map(|width| {
79 InlineKeyboardButton::callback(format!("Width: {}", width), "settings_width")
80 }),
81 self.height.map(|height| {
82 InlineKeyboardButton::callback(format!("Height: {}", height), "settings_height")
83 }),
84 self.negative_prompt.as_ref().map(|_| {
85 InlineKeyboardButton::callback(
86 "Negative Prompt".to_owned(),
87 "settings_negative",
88 )
89 }),
90 self.denoising_strength.map(|denoising_strength| {
91 InlineKeyboardButton::callback(
92 format!("Denoising Strength: {}", denoising_strength),
93 "settings_denoising",
94 )
95 }),
96 Some(InlineKeyboardButton::callback(
97 "Cancel".to_owned(),
98 "settings_back",
99 )),
100 ]
101 .into_iter()
102 .flatten()
103 .chunks(2)
104 .into_iter()
105 .map(Iterator::collect)
106 .collect::<Vec<Vec<_>>>(),
107 )
108 }
109}
110
111impl From<&dyn GenParams> for Settings {
112 fn from(value: &dyn GenParams) -> Self {
113 Self {
114 steps: value.steps(),
115 seed: value.seed(),
116 batch_size: value.batch_size(),
117 n_iter: value.count(),
118 cfg_scale: value.cfg(),
119 width: value.width(),
120 height: value.height(),
121 negative_prompt: value.negative_prompt().clone(),
122 denoising_strength: value.denoising(),
123 sampler_index: value.sampler().clone(),
124 }
125 }
126}
127
128pub(crate) fn filter_callback_query_chat_id() -> UpdateHandler<anyhow::Error> {
129 dptree::filter_map(|q: CallbackQuery| q.message.map(|m| m.chat.id))
130}
131
132pub(crate) async fn handle_message_expired(bot: Bot, q: CallbackQuery) -> anyhow::Result<()> {
133 bot.answer_callback_query(q.id)
134 .cache_time(60)
135 .text("Sorry, this message is no longer available.")
136 .await?;
137 Ok(())
138}
139
140pub(crate) fn filter_callback_query_parent() -> UpdateHandler<anyhow::Error> {
141 dptree::filter_map(|q: CallbackQuery| q.message.and_then(|m| m.reply_to_message().cloned()))
142}
143
144pub(crate) async fn handle_parent_unavailable(bot: Bot, q: CallbackQuery) -> anyhow::Result<()> {
145 bot.answer_callback_query(q.id)
146 .cache_time(60)
147 .text("Oops, something went wrong.")
148 .await?;
149 Ok(())
150}
151
152pub(crate) async fn handle_settings(
153 bot: Bot,
154 dialogue: DiffusionDialogue,
155 (txt2img, img2img): (Box<dyn GenParams>, Box<dyn GenParams>),
156 q: CallbackQuery,
157 chat_id: ChatId,
158 parent: Message,
159) -> anyhow::Result<()> {
160 let settings = if parent.photo().is_some() {
161 let settings = Settings::from(img2img.as_ref());
162 dialogue
163 .update(State::Ready {
164 bot_state: BotState::SettingsImg2Img { selection: None },
165 txt2img,
166 img2img,
167 })
168 .await
169 .map_err(|e| anyhow!(e))?;
170 settings
171 } else if parent.text().is_some() {
172 let settings = Settings::from(txt2img.as_ref());
173 dialogue
174 .update(State::Ready {
175 bot_state: BotState::SettingsTxt2Img { selection: None },
176 txt2img: txt2img.clone(),
177 img2img,
178 })
179 .await
180 .map_err(|e| anyhow!(e))?;
181 settings
182 } else {
183 bot.answer_callback_query(q.id)
184 .cache_time(60)
185 .text("Oops, something went wrong.")
186 .await?;
187 return Ok(());
188 };
189
190 if let Err(e) = bot.answer_callback_query(q.id).await {
191 warn!("Failed to answer settings callback query: {}", e)
192 }
193 bot.send_message(chat_id, "Please make a selection.")
194 .reply_markup(settings.keyboard())
195 .send()
196 .await?;
197
198 Ok(())
199}
200
201pub(crate) async fn handle_settings_button(
202 bot: Bot,
203 cfg: ConfigParameters,
204 dialogue: DiffusionDialogue,
205 (_, txt2img, img2img): (Option<String>, Box<dyn GenParams>, Box<dyn GenParams>),
206 q: CallbackQuery,
207) -> anyhow::Result<()> {
208 let (message, data) = match q {
209 CallbackQuery {
210 message: Some(message),
211 data: Some(data),
212 ..
213 } => (message, data),
214 _ => {
215 bot.answer_callback_query(q.id)
216 .cache_time(60)
217 .text("Sorry, something went wrong.")
218 .await?;
219 return Ok(());
220 }
221 };
222
223 let setting = match data.strip_prefix("settings_") {
224 Some(setting) => setting,
225 None => {
226 bot.answer_callback_query(q.id)
227 .cache_time(60)
228 .text("Sorry, something went wrong.")
229 .await?;
230 return Ok(());
231 }
232 };
233
234 if setting == "back" {
235 dialogue
236 .update(State::Ready {
237 bot_state: BotState::Generate,
238 txt2img,
239 img2img,
240 })
241 .await
242 .map_err(|e| anyhow!(e))?;
243 if let Err(e) = bot.answer_callback_query(q.id).text("Canceled.").await {
244 warn!("Failed to answer back button callback query: {}", e)
245 }
246
247 if let Err(e) = bot.delete_message(message.chat.id, message.id).await {
248 error!("Failed to delete message: {:?}", e);
249 bot.edit_message_text(message.chat.id, message.id, "Please enter a prompt.")
250 .reply_markup(InlineKeyboardMarkup::new([[]]))
251 .await?;
252 }
253 return Ok(());
254 }
255
256 let mut state = dialogue
257 .get()
258 .await
259 .map_err(|e| anyhow!(e))?
260 .unwrap_or_else(|| {
261 State::new_with_defaults(
262 cfg.txt2img_api.gen_params(None),
263 cfg.img2img_api.gen_params(None),
264 )
265 });
266 match &mut state {
267 State::Ready {
268 bot_state: BotState::SettingsTxt2Img { selection },
269 ..
270 }
271 | State::Ready {
272 bot_state: BotState::SettingsImg2Img { selection },
273 ..
274 } => *selection = Some(setting.to_string()),
275 _ => {
276 bot.answer_callback_query(q.id)
277 .cache_time(60)
278 .text("Sorry, something went wrong.")
279 .await?;
280 return Ok(());
281 }
282 }
283
284 if let Err(e) = bot.answer_callback_query(q.id).await {
285 warn!("Failed to answer settings button callback query: {}", e)
286 }
287 dialogue.update(state).await.map_err(|e| anyhow!(e))?;
288
289 bot.send_message(message.chat.id, "Please enter a new value.")
290 .await?;
291
292 Ok(())
293}
294
295fn update_txt2img_setting<S1, S2>(
296 txt2img: &mut dyn GenParams,
297 setting: S1,
298 value: S2,
299) -> anyhow::Result<()>
300where
301 S1: AsRef<str>,
302 S2: AsRef<str>,
303{
304 let value = value.as_ref();
305 match setting.as_ref() {
306 "steps" => txt2img.set_steps(value.parse()?),
307 "seed" => txt2img.set_seed(value.parse()?),
308 "count" => txt2img.set_count(value.parse()?),
309 "cfg" => txt2img.set_cfg(value.parse()?),
310 "width" => txt2img.set_width(value.parse()?),
311 "height" => txt2img.set_height(value.parse()?),
312 "negative" => txt2img.set_negative_prompt(value.to_owned()),
313 "denoising" => txt2img.set_denoising(value.parse()?),
314 _ => return Err(anyhow!("Got invalid setting: {}", setting.as_ref())),
315 }
316 Ok(())
317}
318
319fn update_img2img_setting<S1, S2>(
320 img2img: &mut dyn GenParams,
321 setting: S1,
322 value: S2,
323) -> anyhow::Result<()>
324where
325 S1: AsRef<str>,
326 S2: AsRef<str>,
327{
328 let value = value.as_ref();
329 match setting.as_ref() {
330 "steps" => img2img.set_steps(200.min(value.parse()?)),
331 "seed" => img2img.set_seed((-1).max(value.parse()?)),
332 "count" => img2img.set_count(value.parse::<u32>()?.clamp(1, 10)),
333 "cfg" => img2img.set_cfg(value.parse::<f32>()?.clamp(0.0, 20.0)),
334 "width" => img2img.set_width({
335 let mut value = value.parse::<u32>()?;
336 value -= value % 64;
337 value.clamp(64, 1024)
338 }),
339 "height" => img2img.set_height({
340 let mut value = value.parse::<u32>()?;
341 value -= value % 64;
342 value.clamp(64, 1024)
343 }),
344 "negative" => img2img.set_negative_prompt(value.to_owned()),
345 "denoising" => img2img.set_denoising(value.parse::<f32>()?.clamp(0.0, 1.0)),
346 _ => return Err(anyhow!("invalid setting: {}", setting.as_ref())),
347 }
348 Ok(())
349}
350
351pub(crate) fn state_or_default() -> UpdateHandler<anyhow::Error> {
352 dptree::map_async(
353 |cfg: ConfigParameters, dialogue: DiffusionDialogue| async move {
354 let result = dialogue.get().await;
355 if let Err(ref err) = result {
356 error!("Failed to get state: {:?}", err);
357 }
358 result.ok().flatten().unwrap_or_else(|| {
359 State::new_with_defaults(
360 cfg.txt2img_api.gen_params(None),
361 cfg.img2img_api.gen_params(None),
362 )
363 })
364 },
365 )
366}
367
368pub(crate) async fn update_settings_value(
369 bot: Bot,
370 dialogue: DiffusionDialogue,
371 chat_id: ChatId,
372 settings: Settings,
373 state: State,
374) -> anyhow::Result<()> {
375 dialogue.update(state).await.map_err(|e| anyhow!(e))?;
376
377 bot.send_message(chat_id, "Please make a selection.")
378 .reply_markup(settings.keyboard())
379 .await?;
380
381 Ok(())
382}
383
384pub(crate) async fn handle_txt2img_settings_value(
385 bot: Bot,
386 dialogue: DiffusionDialogue,
387 msg: Message,
388 text: String,
389 (selection, mut txt2img, img2img): (Option<String>, Box<dyn GenParams>, Box<dyn GenParams>),
390) -> anyhow::Result<()> {
391 if let Some(ref setting) = selection {
392 if let Err(e) = update_txt2img_setting(txt2img.as_mut(), setting, text) {
393 bot.send_message(msg.chat.id, format!("Please enter a valid value: {e:?}."))
394 .await?;
395 return Ok(());
396 }
397 }
398
399 let bot_state = BotState::SettingsTxt2Img { selection: None };
400
401 update_settings_value(
402 bot,
403 dialogue,
404 msg.chat.id,
405 Settings::from(txt2img.as_ref()),
406 State::Ready {
407 bot_state,
408 txt2img,
409 img2img,
410 },
411 )
412 .await
413}
414
415pub(crate) async fn handle_img2img_settings_value(
416 bot: Bot,
417 dialogue: DiffusionDialogue,
418 msg: Message,
419 text: String,
420 (selection, txt2img, mut img2img): (Option<String>, Box<dyn GenParams>, Box<dyn GenParams>),
421) -> anyhow::Result<()> {
422 if let Some(ref setting) = selection {
423 if let Err(e) = update_img2img_setting(img2img.as_mut(), setting, text) {
424 bot.send_message(msg.chat.id, format!("Please enter a valid value: {e:?}."))
425 .await?;
426 return Ok(());
427 }
428 }
429
430 let bot_state = BotState::SettingsImg2Img { selection: None };
431
432 update_settings_value(
433 bot,
434 dialogue,
435 msg.chat.id,
436 Settings::from(img2img.as_ref()),
437 State::Ready {
438 bot_state,
439 txt2img,
440 img2img,
441 },
442 )
443 .await
444}
445
446pub(crate) fn map_settings() -> UpdateHandler<anyhow::Error> {
447 dptree::map(|cfg: ConfigParameters, state: State| match state {
448 State::Ready {
449 txt2img, img2img, ..
450 } => (txt2img, img2img),
451 State::New => (
452 cfg.txt2img_api.gen_params(None),
453 cfg.img2img_api.gen_params(None),
454 ),
455 })
456}
457
458async fn handle_img2img_settings_command(
459 msg: Message,
460 bot: Bot,
461 dialogue: DiffusionDialogue,
462 (txt2img, img2img): (Box<dyn GenParams>, Box<dyn GenParams>),
463) -> anyhow::Result<()> {
464 let settings = Settings::from(img2img.as_ref());
465 dialogue
466 .update(State::Ready {
467 bot_state: BotState::SettingsImg2Img { selection: None },
468 txt2img,
469 img2img,
470 })
471 .await
472 .map_err(|e| anyhow!(e))?;
473 bot.send_message(msg.chat.id, "Please make a selection.")
474 .reply_markup(settings.keyboard())
475 .send()
476 .await?;
477 Ok(())
478}
479
480async fn handle_txt2img_settings_command(
481 msg: Message,
482 bot: Bot,
483 dialogue: DiffusionDialogue,
484 (txt2img, img2img): (Box<dyn GenParams>, Box<dyn GenParams>),
485) -> anyhow::Result<()> {
486 let settings = Settings::from(txt2img.as_ref());
487 dialogue
488 .update(State::Ready {
489 bot_state: BotState::SettingsTxt2Img { selection: None },
490 txt2img,
491 img2img,
492 })
493 .await
494 .map_err(|e| anyhow!(e))?;
495 bot.send_message(msg.chat.id, "Please make a selection.")
496 .reply_markup(settings.keyboard())
497 .send()
498 .await?;
499 Ok(())
500}
501
502async fn handle_invalid_setting_value(bot: Bot, msg: Message) -> anyhow::Result<()> {
503 bot.send_message(msg.chat.id, "Please enter a valid value.")
504 .await?;
505 Ok(())
506}
507
508pub(crate) fn settings_command_handler() -> UpdateHandler<anyhow::Error> {
509 Update::filter_message()
510 .filter_command::<SettingsCommands>()
511 .chain(state_or_default())
512 .chain(map_settings())
513 .branch(case![SettingsCommands::Txt2ImgSettings].endpoint(handle_txt2img_settings_command))
514 .branch(case![SettingsCommands::Img2ImgSettings].endpoint(handle_img2img_settings_command))
515}
516
517pub(crate) fn filter_settings_callback_query() -> UpdateHandler<anyhow::Error> {
518 Update::filter_callback_query()
519 .filter(|q: CallbackQuery| q.data.is_some_and(|data| data.starts_with("settings")))
520}
521
522pub(crate) fn filter_settings_state() -> UpdateHandler<anyhow::Error> {
523 dptree::filter(|state: State| {
524 let bot_state = match state {
525 State::Ready { bot_state, .. } => bot_state,
526 _ => return false,
527 };
528 matches!(
529 bot_state,
530 BotState::SettingsTxt2Img { .. } | BotState::SettingsImg2Img { .. }
531 )
532 })
533}
534
535pub(crate) fn filter_map_settings_state() -> UpdateHandler<anyhow::Error> {
536 dptree::filter_map(|state: State| {
537 let (bot_state, txt2img, img2img) = match state {
538 State::Ready {
539 bot_state,
540 txt2img,
541 img2img,
542 } => (bot_state, txt2img, img2img),
543 _ => return None,
544 };
545 match bot_state {
546 BotState::SettingsTxt2Img { selection } => Some((selection, txt2img, img2img)),
547 BotState::SettingsImg2Img { selection } => Some((selection, txt2img, img2img)),
548 _ => None,
549 }
550 })
551}
552
553pub(crate) fn settings_schema() -> UpdateHandler<anyhow::Error> {
554 let callback_handler = filter_settings_callback_query()
555 .branch(
556 filter_map_bot_state()
557 .chain(case![BotState::Generate])
558 .chain(filter_map_settings())
559 .branch(
560 filter_callback_query_chat_id()
561 .branch(filter_callback_query_parent().endpoint(handle_settings))
562 .endpoint(handle_parent_unavailable),
563 )
564 .endpoint(handle_message_expired),
565 )
566 .branch(filter_map_settings_state().endpoint(handle_settings_button));
567
568 let message_handler = Update::filter_message()
569 .branch(
570 Message::filter_text()
571 .chain(filter_map_settings_state())
572 .chain(state_or_default())
573 .chain(filter_map_bot_state())
574 .branch(
575 case![BotState::SettingsTxt2Img { selection }]
576 .endpoint(handle_txt2img_settings_value),
577 )
578 .branch(
579 case![BotState::SettingsImg2Img { selection }]
580 .endpoint(handle_img2img_settings_value),
581 )
582 .endpoint(|| async { Err(anyhow!("Invalid settings state")) }),
583 )
584 .branch(filter_settings_state().endpoint(handle_invalid_setting_value));
585
586 dptree::entry()
587 .branch(settings_command_handler())
588 .branch(message_handler)
589 .branch(callback_handler)
590}
591
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use sal_e_api::{
        Img2ImgApi, Img2ImgApiError, Img2ImgParams, Response, Txt2ImgApi, Txt2ImgApiError,
        Txt2ImgParams,
    };
    use stable_diffusion_api::{Img2ImgRequest, Txt2ImgRequest};
    use teloxide::types::{UpdateKind, User};

    use super::*;
    use crate::BotState;

    /// Builds a callback-query `Update` with no attached message and the given callback data.
    fn create_callback_query_update(data: Option<String>) -> Update {
        let query = CallbackQuery {
            id: "123456".to_string(),
            from: User {
                id: UserId(123456780),
                is_bot: true,
                first_name: "Stable Diffusion".to_string(),
                last_name: None,
                username: Some("sdbot".to_string()),
                language_code: Some("en".to_string()),
                is_premium: false,
                added_to_attachment_menu: false,
            },
            message: None,
            inline_message_id: None,
            chat_instance: "123456".to_string(),
            data,
            game_short_name: None,
        };

        Update {
            id: 1,
            kind: UpdateKind::CallbackQuery(query),
        }
    }

    /// Data starting with "settings" reaches the endpoint.
    #[tokio::test]
    async fn test_filter_settings_query() {
        let update = create_callback_query_update(Some("settings".to_string()));

        assert!(matches!(
            filter_settings_callback_query()
                .endpoint(|| async { anyhow::Ok(()) })
                .dispatch(dptree::deps![update])
                .await,
            ControlFlow::Break(_)
        ));
    }

    /// A query without data is rejected.
    #[tokio::test]
    async fn test_filter_settings_query_none() {
        let update = create_callback_query_update(None);

        assert!(matches!(
            filter_settings_callback_query()
                .endpoint(|| async { anyhow::Ok(()) })
                .dispatch(dptree::deps![update])
                .await,
            ControlFlow::Continue(_)
        ));
    }

    /// Data with a different prefix is rejected.
    #[tokio::test]
    async fn test_filter_settings_query_bad_data() {
        let update = create_callback_query_update(Some("bad_data".to_string()));

        assert!(matches!(
            filter_settings_callback_query()
                .endpoint(|| async { anyhow::Ok(()) })
                .dispatch(dptree::deps![update])
                .await,
            ControlFlow::Continue(_)
        ));
    }

    /// The txt2img settings state passes `filter_settings_state`.
    #[tokio::test]
    async fn test_filter_settings_state_txt2img() {
        assert!(matches!(
            filter_settings_state()
                .endpoint(|| async { anyhow::Ok(()) })
                .dispatch(dptree::deps![State::Ready {
                    bot_state: BotState::SettingsTxt2Img { selection: None },
                    txt2img: Box::<Txt2ImgParams>::default(),
                    img2img: Box::<Img2ImgParams>::default()
                }])
                .await,
            ControlFlow::Break(_)
        ));
    }

    /// The img2img settings state passes `filter_settings_state`.
    #[tokio::test]
    async fn test_filter_settings_state_img2img() {
        assert!(matches!(
            filter_settings_state()
                .endpoint(|| async { anyhow::Ok(()) })
                .dispatch(dptree::deps![State::Ready {
                    bot_state: BotState::SettingsImg2Img { selection: None },
                    txt2img: Box::<Txt2ImgParams>::default(),
                    img2img: Box::<Img2ImgParams>::default()
                }])
                .await,
            ControlFlow::Break(_)
        ));
    }

    /// A non-settings state is rejected by `filter_settings_state`.
    #[tokio::test]
    async fn test_filter_settings_state() {
        assert!(matches!(
            filter_settings_state()
                .endpoint(|| async { anyhow::Ok(()) })
                .dispatch(dptree::deps![State::New])
                .await,
            ControlFlow::Continue(_)
        ));
    }

    /// The txt2img settings state injects the `(selection, txt2img, img2img)` tuple.
    #[tokio::test]
    async fn test_filter_map_settings_state_txt2img() {
        assert!(matches!(
            filter_map_settings_state()
                .endpoint(
                    |(_, _, _): (Option<String>, Box<dyn GenParams>, Box<dyn GenParams>)| async {
                        anyhow::Ok(())
                    }
                )
                .dispatch(dptree::deps![State::Ready {
                    bot_state: BotState::SettingsTxt2Img { selection: None },
                    txt2img: Box::<Txt2ImgParams>::default(),
                    img2img: Box::<Img2ImgParams>::default()
                }])
                .await,
            ControlFlow::Break(_)
        ));
    }

    /// The img2img settings state injects the `(selection, txt2img, img2img)` tuple.
    #[tokio::test]
    async fn test_filter_map_settings_state_img2img() {
        assert!(matches!(
            filter_map_settings_state()
                .endpoint(
                    |(_, _, _): (Option<String>, Box<dyn GenParams>, Box<dyn GenParams>)| async {
                        anyhow::Ok(())
                    }
                )
                .dispatch(dptree::deps![State::Ready {
                    bot_state: BotState::SettingsImg2Img { selection: None },
                    txt2img: Box::<Txt2ImgParams>::default(),
                    img2img: Box::<Img2ImgParams>::default()
                }])
                .await,
            ControlFlow::Break(_)
        ));
    }

    /// A non-settings state does not reach the endpoint.
    ///
    /// The endpoint requests the same tuple type that `filter_map_settings_state` actually
    /// injects (and that the other tests and real handlers use). That way, a `Continue` here
    /// comes from the filter rather than from a dependency type nobody provides.
    #[tokio::test]
    async fn test_filter_map_settings_state() {
        assert!(matches!(
            filter_map_settings_state()
                .endpoint(
                    |(_, _, _): (Option<String>, Box<dyn GenParams>, Box<dyn GenParams>)| async {
                        anyhow::Ok(())
                    }
                )
                .dispatch(dptree::deps![State::New])
                .await,
            ControlFlow::Continue(_)
        ));
    }

    /// API stub whose `gen_params` returns default parameters and whose generation calls fail.
    #[derive(Debug, Clone, Default)]
    struct MockApi;

    #[async_trait]
    impl Txt2ImgApi for MockApi {
        fn gen_params(&self, _user_params: Option<&dyn GenParams>) -> Box<dyn GenParams> {
            Box::<Txt2ImgParams>::default()
        }

        async fn txt2img(&self, _config: &dyn GenParams) -> Result<Response, Txt2ImgApiError> {
            Err(anyhow!("Not implemented"))?
        }
    }

    #[async_trait]
    impl Img2ImgApi for MockApi {
        fn gen_params(&self, _user_params: Option<&dyn GenParams>) -> Box<dyn GenParams> {
            Box::<Img2ImgParams>::default()
        }

        async fn img2img(&self, _config: &dyn GenParams) -> Result<Response, Img2ImgApiError> {
            Err(anyhow!("Not implemented"))?
        }
    }

    /// `State::New` falls back to the APIs' default parameters.
    #[tokio::test]
    async fn test_map_settings_default() {
        assert!(matches!(
            map_settings()
                .endpoint(
                    |(txt2img, img2img): (Box<dyn GenParams>, Box<dyn GenParams>)| async move {
                        let txt2img = txt2img
                            .as_ref()
                            .as_any()
                            .downcast_ref::<Txt2ImgParams>()
                            .unwrap();
                        let img2img = img2img
                            .as_ref()
                            .as_any()
                            .downcast_ref::<Img2ImgParams>()
                            .unwrap();
                        assert!(
                            (txt2img, img2img)
                                == (&Txt2ImgParams::default(), &Img2ImgParams::default())
                        );
                        anyhow::Ok(())
                    }
                )
                .dispatch(dptree::deps![
                    ConfigParameters {
                        txt2img_api: Box::new(MockApi),
                        img2img_api: Box::new(MockApi),
                        allowed_users: Default::default(),
                        allow_all_users: false
                    },
                    State::New
                ])
                .await,
            ControlFlow::Break(_)
        ));
    }

    /// `State::Ready` passes its own stored parameters through unchanged.
    #[tokio::test]
    async fn test_map_settings_ready() {
        let txt2img = Txt2ImgParams {
            user_params: Txt2ImgRequest {
                negative_prompt: Some("test".to_string()),
                ..Txt2ImgRequest::default()
            },
            defaults: Some(Txt2ImgRequest::default()),
        };
        let img2img = Img2ImgParams {
            user_params: Img2ImgRequest {
                negative_prompt: Some("test".to_string()),
                ..Img2ImgRequest::default()
            },
            defaults: Some(Img2ImgRequest::default()),
        };
        assert!(matches!(
            map_settings()
                .endpoint(
                    |(txt2img, img2img): (Box<dyn GenParams>, Box<dyn GenParams>)| async move {
                        let txt2img = txt2img
                            .as_ref()
                            .as_any()
                            .downcast_ref::<Txt2ImgParams>()
                            .unwrap();
                        let img2img = img2img
                            .as_ref()
                            .as_any()
                            .downcast_ref::<Img2ImgParams>()
                            .unwrap();
                        assert!(
                            (txt2img, img2img)
                                == (
                                    &Txt2ImgParams {
                                        user_params: Txt2ImgRequest {
                                            negative_prompt: Some("test".to_string()),
                                            ..Txt2ImgRequest::default()
                                        },
                                        defaults: Some(Txt2ImgRequest::default()),
                                    },
                                    &Img2ImgParams {
                                        user_params: Img2ImgRequest {
                                            negative_prompt: Some("test".to_string()),
                                            ..Img2ImgRequest::default()
                                        },
                                        defaults: Some(Img2ImgRequest::default()),
                                    }
                                )
                        );
                        anyhow::Ok(())
                    }
                )
                .dispatch(dptree::deps![
                    ConfigParameters {
                        txt2img_api: Box::new(MockApi),
                        img2img_api: Box::new(MockApi),
                        allowed_users: Default::default(),
                        allow_all_users: false
                    },
                    State::Ready {
                        bot_state: BotState::Generate,
                        txt2img: Box::new(txt2img),
                        img2img: Box::new(img2img)
                    }
                ])
                .await,
            ControlFlow::Break(_)
        ));
    }
}