1use std::{collections::HashSet, path::PathBuf, sync::Arc};
2
3use anyhow::{anyhow, Context};
4use comfyui_api::comfy::getter::{LoadImageExt, PromptExt, SeedExt};
5use sal_e_api::{ComfyPromptApi, GenParams, Img2ImgApi, StableDiffusionWebUiApi, Txt2ImgApi};
6use serde::{Deserialize, Serialize};
7use teloxide::{
8 dispatching::{
9 dialogue::{
10 serializer::Json, ErasedStorage, GetChatId, InMemStorage, SqliteStorage, Storage,
11 },
12 DpHandlerDescription, UpdateHandler,
13 },
14 prelude::*,
15 types::Update,
16 utils::command::BotCommands,
17};
18use tokio::fs::File;
19use tokio::io::AsyncReadExt;
20use tracing::{error, warn};
21
22use stable_diffusion_api::{Api, Img2ImgRequest, Txt2ImgRequest};
23
24mod handlers;
25mod helpers;
26use handlers::*;
27
/// Per-chat dialogue state stored in the dialogue storage (JSON-serialized when the
/// SQLite backend is used).
///
/// NOTE: the variant/field layout is part of the persisted format; reordering or
/// renaming may break deserialization of previously stored dialogues.
#[derive(Clone, Serialize, Deserialize, Debug, Default)]
pub(crate) enum State {
    /// Initial variant; `StableDiffusionBot::enter` passes a stored `New` state through
    /// unchanged.
    #[default]
    New,
    /// A chat whose generation parameters have been initialised.
    Ready {
        /// Which interaction mode the chat is currently in.
        bot_state: BotState,
        /// Parameters used for text-to-image requests.
        txt2img: Box<dyn GenParams>,
        /// Parameters used for image-to-image requests.
        img2img: Box<dyn GenParams>,
    },
}
38
39impl State {
40 fn new_with_defaults(txt2img: Box<dyn GenParams>, img2img: Box<dyn GenParams>) -> Self {
41 Self::Ready {
42 txt2img,
43 img2img,
44 bot_state: BotState::Generate,
45 }
46 }
47}
48
/// Interaction mode of a `Ready` chat.
///
/// NOTE: persisted as part of [`State`]; keep the variant/field layout stable.
#[derive(Clone, Serialize, Deserialize, Debug, Default)]
pub(crate) enum BotState {
    /// Default mode, used by `State::new_with_defaults`.
    #[default]
    Generate,
    /// Editing txt2img settings.
    SettingsTxt2Img {
        // presumably the setting currently selected for editing, if any — TODO confirm
        // against the settings handlers.
        selection: Option<String>,
    },
    /// Editing img2img settings.
    SettingsImg2Img {
        // presumably the setting currently selected for editing, if any — TODO confirm
        // against the settings handlers.
        selection: Option<String>,
    },
}
60
61fn default_txt2img(txt2img: Txt2ImgRequest) -> Txt2ImgRequest {
62 Txt2ImgRequest {
63 seed: Some(-1),
64 sampler_index: Some("Euler".to_owned()),
65 batch_size: Some(1),
66 n_iter: Some(1),
67 steps: Some(50),
68 cfg_scale: Some(7.0),
69 width: Some(512),
70 height: Some(512),
71 negative_prompt: Some("".to_owned()),
72 ..Default::default()
73 }
74 .merge(txt2img)
75}
76
77fn default_img2img(img2img: Img2ImgRequest) -> Img2ImgRequest {
78 Img2ImgRequest {
79 denoising_strength: Some(0.75),
80 seed: Some(-1),
81 sampler_index: Some("Euler".to_owned()),
82 batch_size: Some(1),
83 n_iter: Some(1),
84 steps: Some(50),
85 cfg_scale: Some(7.0),
86 width: Some(512),
87 height: Some(512),
88 negative_prompt: Some("".to_owned()),
89 ..Default::default()
90 }
91 .merge(img2img)
92}
93
94type DialogueStorage = std::sync::Arc<ErasedStorage<State>>;
95
96type DiffusionDialogue = Dialogue<State, ErasedStorage<State>>;
97
/// A configured Telegram bot, ready to be started with [`StableDiffusionBot::run`].
///
/// Construct it with [`StableDiffusionBotBuilder`].
#[derive(Clone)]
pub struct StableDiffusionBot {
    /// Telegram API client.
    bot: Bot,
    /// Dialogue state storage, injected into the dispatcher as a dependency.
    storage: DialogueStorage,
    /// User allow-list and backend APIs, injected into the dispatcher as a dependency.
    config: ConfigParameters,
}
106impl StableDiffusionBot {
107 fn schema() -> UpdateHandler<anyhow::Error> {
109 Self::enter::<Update, ErasedStorage<State>, _>()
110 .branch(unauth_command_handler())
111 .branch(authenticated_command_handler())
112 }
113
114 fn enter<Upd, S, Output>() -> Handler<'static, DependencyMap, Output, DpHandlerDescription>
118 where
119 S: Storage<State> + ?Sized + Send + Sync + 'static,
120 <S as Storage<State>>::Error: std::fmt::Debug + Send,
121 Upd: GetChatId + Clone + Send + Sync + 'static,
122 Output: Send + Sync + 'static,
123 {
124 dptree::filter_map(|storage: Arc<S>, upd: Upd| {
125 let chat_id = upd.chat_id()?;
126 Some(Dialogue::new(storage, chat_id))
127 })
128 .filter_map_async(
129 |dialogue: Dialogue<State, S>, cfg: ConfigParameters| async move {
130 match dialogue.get().await {
131 Ok(dialogue) => {
132 let mut dialogue = if let Some(dialogue) = dialogue {
133 dialogue
134 } else {
135 return Some(State::new_with_defaults(
136 cfg.txt2img_api.gen_params(None),
137 cfg.img2img_api.gen_params(None),
138 ));
139 };
140 match dialogue {
141 State::New => {}
142 State::Ready {
143 ref mut txt2img,
144 ref mut img2img,
145 ..
146 } => {
147 let txt2img_params = cfg.txt2img_api.gen_params(None);
148 if txt2img.as_any().type_id() != txt2img_params.as_any().type_id() {
149 warn!("txt2img settings type mismatch, resetting to default");
150 *txt2img = txt2img_params;
151 } else {
152 *txt2img = cfg.txt2img_api.gen_params(Some(txt2img.as_ref()));
153 }
154 let img2img_params = cfg.img2img_api.gen_params(None);
155 if img2img.as_any().type_id() != img2img_params.as_any().type_id() {
156 warn!("img2img settings type mismatch, resetting to default");
157 *img2img = img2img_params;
158 } else {
159 *img2img = cfg.img2img_api.gen_params(Some(img2img.as_ref()));
160 }
161 }
162 }
163 Some(dialogue)
164 }
165 Err(err) => {
166 error!("dialogue.get() failed: {:?}", err);
167 let defaults = State::new_with_defaults(
168 cfg.txt2img_api.gen_params(None),
169 cfg.img2img_api.gen_params(None),
170 );
171 match dialogue.update(defaults.clone()).await {
172 Ok(_) => {
173 warn!("dialogue reset to default state: {:?}", defaults);
174 Some(defaults)
175 }
176 Err(err) => {
177 error!("dialogue.update() failed: {:?}", err);
178 None
179 }
180 }
181 }
182 }
183 },
184 )
185 }
186
187 pub async fn run(self) -> anyhow::Result<()> {
189 let StableDiffusionBot {
190 bot,
191 storage,
192 config,
193 } = self;
194
195 let mut commands = UnauthenticatedCommands::bot_commands();
196 commands.extend(SettingsCommands::bot_commands());
197 commands.extend(GenCommands::bot_commands());
198 bot.set_my_commands(commands)
199 .scope(teloxide::types::BotCommandScope::Default)
200 .await
201 .context("Failed to set bot commands")?;
202
203 Dispatcher::builder(bot, Self::schema())
204 .dependencies(dptree::deps![config, storage])
205 .default_handler(|upd| async move {
206 warn!("Unhandled update: {:?}", upd);
207 })
208 .error_handler(LoggingErrorHandler::with_custom_text(
209 "An error has occurred in the dispatcher",
210 ))
211 .enable_ctrlc_handler()
212 .build()
213 .dispatch()
214 .await;
215
216 Ok(())
217 }
218}
219
/// Runtime configuration shared with every handler through the dispatcher's
/// dependency map.
#[derive(Clone, Debug)]
pub(crate) struct ConfigParameters {
    /// Chats accepted by `chat_is_allowed` when `allow_all_users` is false.
    allowed_users: HashSet<ChatId>,
    /// Backend used for text-to-image generation.
    txt2img_api: Box<dyn sal_e_api::Txt2ImgApi>,
    /// Backend used for image-to-image generation.
    img2img_api: Box<dyn sal_e_api::Img2ImgApi>,
    /// When true, `chat_is_allowed` accepts every chat.
    allow_all_users: bool,
}
227
228impl ConfigParameters {
229 pub fn chat_is_allowed(&self, chat_id: &ChatId) -> bool {
231 self.allow_all_users || self.allowed_users.contains(chat_id)
232 }
233}
234
/// Image-generation backend the bot talks to.
///
/// Matched in `StableDiffusionBotBuilder::build` to construct the txt2img/img2img
/// API clients.
#[derive(Serialize, Deserialize, Default, Debug)]
pub enum ApiType {
    /// ComfyUI server driven by prompt-graph JSON files (see `ComfyUIConfig`).
    ComfyUI,
    /// Stable Diffusion WebUI API (the default).
    #[default]
    StableDiffusionWebUi,
}
244
/// Paths to the ComfyUI prompt files used with [`ApiType::ComfyUI`].
///
/// When the ComfyUI backend is selected, `StableDiffusionBotBuilder::build` reads both
/// files and deserializes them as `comfyui_api::models::Prompt`. It fails if either
/// path is missing.
#[derive(Serialize, Deserialize, Default, Debug)]
pub struct ComfyUIConfig {
    /// Text-to-image prompt graph; must contain a prompt node and a seed node.
    pub txt2img_prompt_file: Option<PathBuf>,
    /// Image-to-image prompt graph; must contain prompt, image and seed nodes.
    pub img2img_prompt_file: Option<PathBuf>,
}
253
/// Builder for [`StableDiffusionBot`]; finish with `build`.
pub struct StableDiffusionBotBuilder {
    /// Telegram bot token.
    api_key: String,
    /// Raw chat ids converted to `ChatId`s for the allow-list.
    allowed_users: Vec<i64>,
    /// SQLite file for dialogue state; `None` uses in-memory storage.
    db_path: Option<String>,
    /// Base URL of the image-generation backend.
    sd_api_url: String,
    /// Which backend to construct.
    api_type: ApiType,
    /// Extra txt2img defaults (read only by the Stable Diffusion WebUI backend).
    txt2img_defaults: Option<Txt2ImgRequest>,
    /// Extra img2img defaults (read only by the Stable Diffusion WebUI backend).
    img2img_defaults: Option<Img2ImgRequest>,
    /// ComfyUI img2img prompt file (required for the ComfyUI backend).
    comfyui_img2img_prompt_file: Option<PathBuf>,
    /// ComfyUI txt2img prompt file (required for the ComfyUI backend).
    comfyui_txt2img_prompt_file: Option<PathBuf>,
    /// Accept every chat regardless of `allowed_users`.
    allow_all_users: bool,
}
267
268impl StableDiffusionBotBuilder {
269 pub fn new(
271 api_key: String,
272 allowed_users: Vec<i64>,
273 sd_api_url: String,
274 api_type: ApiType,
275 allow_all_users: bool,
276 ) -> Self {
277 StableDiffusionBotBuilder {
278 api_key,
279 allowed_users,
280 db_path: None,
281 sd_api_url,
282 txt2img_defaults: None,
283 img2img_defaults: None,
284 allow_all_users,
285 api_type,
286 comfyui_txt2img_prompt_file: None,
287 comfyui_img2img_prompt_file: None,
288 }
289 }
290
291 pub fn db_path(mut self, path: Option<String>) -> Self {
312 self.db_path = path;
313 self
314 }
315
316 pub fn txt2img_defaults(mut self, request: Txt2ImgRequest) -> Self {
339 self.txt2img_defaults = Some(self.txt2img_defaults.unwrap_or_default().merge(request));
340 self
341 }
342
343 pub fn clear_txt2img_defaults(mut self) -> Self {
345 self.txt2img_defaults = None;
346 self
347 }
348
349 pub fn img2img_defaults(mut self, request: Img2ImgRequest) -> Self {
372 self.img2img_defaults = Some(self.img2img_defaults.unwrap_or_default().merge(request));
373 self
374 }
375
376 pub fn clear_img2img_defaults(mut self) -> Self {
378 self.img2img_defaults = None;
379 self
380 }
381
382 pub fn comfyui_config(
383 mut self,
384 ComfyUIConfig {
385 txt2img_prompt_file,
386 img2img_prompt_file,
387 }: ComfyUIConfig,
388 ) -> Self {
389 self.comfyui_txt2img_prompt_file = txt2img_prompt_file;
390 self.comfyui_img2img_prompt_file = img2img_prompt_file;
391 self
392 }
393
394 pub async fn build(self) -> anyhow::Result<StableDiffusionBot> {
412 let storage: DialogueStorage = if let Some(path) = self.db_path {
413 SqliteStorage::open(&path, Json)
414 .await
415 .context("failed to open db")?
416 .erase()
417 } else {
418 InMemStorage::new().erase()
419 };
420
421 let bot = Bot::new(self.api_key.clone());
422
423 let allowed_users = self.allowed_users.into_iter().map(ChatId).collect();
424
425 let client = reqwest::Client::new();
426
427 let (txt2img_api, img2img_api): (Box<dyn Txt2ImgApi>, Box<dyn Img2ImgApi>) = match self
428 .api_type
429 {
430 ApiType::ComfyUI => {
431 let mut txt2img_prompt = String::new();
432
433 File::open(
434 self.comfyui_txt2img_prompt_file
435 .ok_or_else(|| anyhow!("No ComfyUI txt2img prompt file provided."))?,
436 )
437 .await
438 .context("Failed to open comfyui txt2img prompt file")?
439 .read_to_string(&mut txt2img_prompt)
440 .await?;
441
442 let mut img2img_prompt = String::new();
443
444 File::open(
445 self.comfyui_img2img_prompt_file
446 .ok_or_else(|| anyhow!("No ComfyUI img2img prompt file provided."))?,
447 )
448 .await
449 .context("Failed to open comfyui img2img prompt file")?
450 .read_to_string(&mut img2img_prompt)
451 .await?;
452
453 let txt2img_prompt =
454 serde_json::from_str::<comfyui_api::models::Prompt>(&txt2img_prompt)
455 .context("Failed to deserialize prompt")?;
456
457 _ = txt2img_prompt
458 .prompt()
459 .context("Failed to find a valid txt2img prompt node.")?;
460 _ = txt2img_prompt
461 .seed()
462 .context("Failed to find a valid txt2img seed node.")?;
463
464 let txt2img_api = ComfyPromptApi::new_with_client_and_url(
465 client.clone(),
466 self.sd_api_url.clone(),
467 txt2img_prompt,
468 )?;
469
470 let img2img_prompt =
471 serde_json::from_str::<comfyui_api::models::Prompt>(&img2img_prompt)
472 .context("Failed to deserialize prompt")?;
473
474 _ = img2img_prompt
475 .prompt()
476 .context("Failed to find a valid img2img prompt node.")?;
477 _ = img2img_prompt
478 .image()
479 .context("Failed to find a valid img2img image node.")?;
480 _ = img2img_prompt
481 .seed()
482 .context("Failed to find a valid img2img seed node.")?;
483
484 let img2img_api = ComfyPromptApi::new_with_client_and_url(
485 client,
486 self.sd_api_url,
487 img2img_prompt,
488 )
489 .context("Failed to create ComfyUI client")?;
490 (Box::new(txt2img_api), Box::new(img2img_api))
491 }
492 ApiType::StableDiffusionWebUi => {
493 let api = Api::new_with_client_and_url(client, self.sd_api_url)
494 .context("Failed to initialize sd api")?;
495 let txt2img_api = StableDiffusionWebUiApi {
496 client: api.clone(),
497 txt2img_defaults: default_txt2img(
498 self.txt2img_defaults.clone().unwrap_or_default(),
499 ),
500 img2img_defaults: default_img2img(
501 self.img2img_defaults.clone().unwrap_or_default(),
502 ),
503 };
504
505 let img2img_api = StableDiffusionWebUiApi {
506 client: api,
507 txt2img_defaults: default_txt2img(self.txt2img_defaults.unwrap_or_default()),
508 img2img_defaults: default_img2img(self.img2img_defaults.unwrap_or_default()),
509 };
510
511 (Box::new(txt2img_api), Box::new(img2img_api))
512 }
513 };
514
515 let parameters = ConfigParameters {
516 allowed_users,
517 txt2img_api,
518 img2img_api,
519 allow_all_users: self.allow_all_users,
520 };
521
522 Ok(StableDiffusionBot {
523 bot,
524 storage,
525 config: parameters,
526 })
527 }
528}
529
#[cfg(test)]
mod tests {
    use super::*;

    use stable_diffusion_api::{Img2ImgRequest, Txt2ImgRequest};

    #[tokio::test]
    async fn test_stable_diffusion_bot_builder() {
        let api_key = "api_key".to_string();
        let sd_api_url = "http://localhost:7860".to_string();
        let allowed_users = vec![1, 2, 3];
        let allow_all_users = false;
        let api_type = ApiType::StableDiffusionWebUi;

        // Use a per-process file in the system temp dir rather than the working
        // directory, so the test neither litters the checkout nor collides with
        // concurrently running test processes.
        let db_path = std::env::temp_dir().join(format!(
            "stable_diffusion_bot_builder_test_{}.sqlite",
            std::process::id()
        ));

        let builder = StableDiffusionBotBuilder::new(
            api_key,
            allowed_users,
            sd_api_url,
            api_type,
            allow_all_users,
        );

        let bot = builder
            .db_path(Some(db_path.to_string_lossy().into_owned()))
            .build()
            .await
            .unwrap();

        assert_eq!(bot.config.allowed_users.len(), 3);
        assert!(!bot.config.allow_all_users);

        // Best-effort cleanup; failure (e.g. the file still being open on Windows) is
        // harmless because the path is unique per process.
        drop(bot);
        let _ = std::fs::remove_file(&db_path);
    }

    #[tokio::test]
    async fn test_stable_diffusion_bot_defaults() {
        let api_key = "api_key".to_string();
        let sd_api_url = "http://localhost:7860".to_string();
        let allowed_users = vec![1, 2, 3];
        let allow_all_users = false;
        let api_type = ApiType::StableDiffusionWebUi;

        let builder = StableDiffusionBotBuilder::new(
            api_key.clone(),
            allowed_users.clone(),
            sd_api_url.clone(),
            api_type,
            allow_all_users,
        );

        let bot = builder.build().await.unwrap();

        assert_eq!(
            bot.config.allowed_users,
            allowed_users.into_iter().map(ChatId).collect()
        );
        assert_eq!(bot.config.allow_all_users, allow_all_users);
        assert_eq!(
            bot.config
                .txt2img_api
                .as_any()
                .downcast_ref::<StableDiffusionWebUiApi>()
                .unwrap()
                .txt2img_defaults,
            default_txt2img(Txt2ImgRequest::default())
        );
        assert_eq!(
            bot.config
                .img2img_api
                .as_any()
                .downcast_ref::<StableDiffusionWebUiApi>()
                .unwrap()
                .img2img_defaults,
            default_img2img(Img2ImgRequest::default())
        );
    }

    #[tokio::test]
    async fn test_stable_diffusion_bot_user_defaults() {
        let api_key = "api_key".to_string();
        let sd_api_url = "http://localhost:7860".to_string();
        let allowed_users = vec![1, 2, 3];
        let allow_all_users = false;
        let api_type = ApiType::StableDiffusionWebUi;

        let txt2img_settings = Txt2ImgRequest {
            width: Some(1024),
            height: Some(768),
            ..Default::default()
        };
        let img2img_settings = Img2ImgRequest {
            width: Some(1024),
            height: Some(768),
            ..Default::default()
        };

        let builder = StableDiffusionBotBuilder::new(
            api_key.clone(),
            allowed_users.clone(),
            sd_api_url.clone(),
            api_type,
            allow_all_users,
        );

        let bot = builder
            .txt2img_defaults(txt2img_settings.clone())
            .img2img_defaults(img2img_settings.clone())
            .build()
            .await
            .unwrap();

        assert_eq!(
            bot.config.allowed_users,
            allowed_users.into_iter().map(ChatId).collect()
        );
        assert_eq!(bot.config.allow_all_users, allow_all_users);
        assert_eq!(
            bot.config
                .txt2img_api
                .as_any()
                .downcast_ref::<StableDiffusionWebUiApi>()
                .unwrap()
                .txt2img_defaults,
            default_txt2img(txt2img_settings)
        );
        assert_eq!(
            bot.config
                .img2img_api
                .as_any()
                .downcast_ref::<StableDiffusionWebUiApi>()
                .unwrap()
                .img2img_defaults,
            default_img2img(img2img_settings)
        );
    }

    #[tokio::test]
    async fn test_stable_diffusion_bot_no_user_defaults() {
        let api_key = "api_key".to_string();
        let sd_api_url = "http://localhost:7860".to_string();
        let allowed_users = vec![1, 2, 3];
        let allow_all_users = false;
        let api_type = ApiType::StableDiffusionWebUi;

        let builder = StableDiffusionBotBuilder::new(
            api_key.clone(),
            allowed_users.clone(),
            sd_api_url.clone(),
            api_type,
            allow_all_users,
        );

        let bot = builder
            .txt2img_defaults(Txt2ImgRequest {
                width: Some(1024),
                height: Some(768),
                ..Default::default()
            })
            .img2img_defaults(Img2ImgRequest {
                width: Some(1024),
                height: Some(768),
                ..Default::default()
            })
            .clear_txt2img_defaults()
            .clear_img2img_defaults()
            .build()
            .await
            .unwrap();

        assert_eq!(
            bot.config.allowed_users,
            allowed_users.into_iter().map(ChatId).collect()
        );
        assert_eq!(bot.config.allow_all_users, allow_all_users);
        assert_eq!(
            bot.config
                .txt2img_api
                .as_any()
                .downcast_ref::<StableDiffusionWebUiApi>()
                .unwrap()
                .txt2img_defaults,
            default_txt2img(Txt2ImgRequest::default())
        );
        assert_eq!(
            bot.config
                .img2img_api
                .as_any()
                .downcast_ref::<StableDiffusionWebUiApi>()
                .unwrap()
                .img2img_defaults,
            default_img2img(Img2ImgRequest::default())
        );
    }
}