stable_diffusion_bot/bot/
mod.rs

1use std::{collections::HashSet, path::PathBuf, sync::Arc};
2
3use anyhow::{anyhow, Context};
4use comfyui_api::comfy::getter::{LoadImageExt, PromptExt, SeedExt};
5use sal_e_api::{ComfyPromptApi, GenParams, Img2ImgApi, StableDiffusionWebUiApi, Txt2ImgApi};
6use serde::{Deserialize, Serialize};
7use teloxide::{
8    dispatching::{
9        dialogue::{
10            serializer::Json, ErasedStorage, GetChatId, InMemStorage, SqliteStorage, Storage,
11        },
12        DpHandlerDescription, UpdateHandler,
13    },
14    prelude::*,
15    types::Update,
16    utils::command::BotCommands,
17};
18use tokio::fs::File;
19use tokio::io::AsyncReadExt;
20use tracing::{error, warn};
21
22use stable_diffusion_api::{Api, Img2ImgRequest, Txt2ImgRequest};
23
24mod handlers;
25mod helpers;
26use handlers::*;
27
28#[derive(Clone, Serialize, Deserialize, Debug, Default)]
29pub(crate) enum State {
30    #[default]
31    New,
32    Ready {
33        bot_state: BotState,
34        txt2img: Box<dyn GenParams>,
35        img2img: Box<dyn GenParams>,
36    },
37}
38
39impl State {
40    fn new_with_defaults(txt2img: Box<dyn GenParams>, img2img: Box<dyn GenParams>) -> Self {
41        Self::Ready {
42            txt2img,
43            img2img,
44            bot_state: BotState::Generate,
45        }
46    }
47}
48
49#[derive(Clone, Serialize, Deserialize, Debug, Default)]
50pub(crate) enum BotState {
51    #[default]
52    Generate,
53    SettingsTxt2Img {
54        selection: Option<String>,
55    },
56    SettingsImg2Img {
57        selection: Option<String>,
58    },
59}
60
61fn default_txt2img(txt2img: Txt2ImgRequest) -> Txt2ImgRequest {
62    Txt2ImgRequest {
63        seed: Some(-1),
64        sampler_index: Some("Euler".to_owned()),
65        batch_size: Some(1),
66        n_iter: Some(1),
67        steps: Some(50),
68        cfg_scale: Some(7.0),
69        width: Some(512),
70        height: Some(512),
71        negative_prompt: Some("".to_owned()),
72        ..Default::default()
73    }
74    .merge(txt2img)
75}
76
77fn default_img2img(img2img: Img2ImgRequest) -> Img2ImgRequest {
78    Img2ImgRequest {
79        denoising_strength: Some(0.75),
80        seed: Some(-1),
81        sampler_index: Some("Euler".to_owned()),
82        batch_size: Some(1),
83        n_iter: Some(1),
84        steps: Some(50),
85        cfg_scale: Some(7.0),
86        width: Some(512),
87        height: Some(512),
88        negative_prompt: Some("".to_owned()),
89        ..Default::default()
90    }
91    .merge(img2img)
92}
93
94type DialogueStorage = std::sync::Arc<ErasedStorage<State>>;
95
96type DiffusionDialogue = Dialogue<State, ErasedStorage<State>>;
97
98/// Struct to run a StableDiffusionBot
99#[derive(Clone)]
100pub struct StableDiffusionBot {
101    bot: Bot,
102    storage: DialogueStorage,
103    config: ConfigParameters,
104}
105
106impl StableDiffusionBot {
107    /// Creates an UpdateHandler for the bot
108    fn schema() -> UpdateHandler<anyhow::Error> {
109        Self::enter::<Update, ErasedStorage<State>, _>()
110            .branch(unauth_command_handler())
111            .branch(authenticated_command_handler())
112    }
113
114    // Borrowed and adapted from Teloxide's `dialogue::enter()` function.
115    // Instead of building a default dialogue if one doesn't exist via `get_or_default()`,
116    // we build a dialogue with the defaults that are defined in the `ConfigParameters`.
117    fn enter<Upd, S, Output>() -> Handler<'static, DependencyMap, Output, DpHandlerDescription>
118    where
119        S: Storage<State> + ?Sized + Send + Sync + 'static,
120        <S as Storage<State>>::Error: std::fmt::Debug + Send,
121        Upd: GetChatId + Clone + Send + Sync + 'static,
122        Output: Send + Sync + 'static,
123    {
124        dptree::filter_map(|storage: Arc<S>, upd: Upd| {
125            let chat_id = upd.chat_id()?;
126            Some(Dialogue::new(storage, chat_id))
127        })
128        .filter_map_async(
129            |dialogue: Dialogue<State, S>, cfg: ConfigParameters| async move {
130                match dialogue.get().await {
131                    Ok(dialogue) => {
132                        let mut dialogue = if let Some(dialogue) = dialogue {
133                            dialogue
134                        } else {
135                            return Some(State::new_with_defaults(
136                                cfg.txt2img_api.gen_params(None),
137                                cfg.img2img_api.gen_params(None),
138                            ));
139                        };
140                        match dialogue {
141                            State::New => {}
142                            State::Ready {
143                                ref mut txt2img,
144                                ref mut img2img,
145                                ..
146                            } => {
147                                let txt2img_params = cfg.txt2img_api.gen_params(None);
148                                if txt2img.as_any().type_id() != txt2img_params.as_any().type_id() {
149                                    warn!("txt2img settings type mismatch, resetting to default");
150                                    *txt2img = txt2img_params;
151                                } else {
152                                    *txt2img = cfg.txt2img_api.gen_params(Some(txt2img.as_ref()));
153                                }
154                                let img2img_params = cfg.img2img_api.gen_params(None);
155                                if img2img.as_any().type_id() != img2img_params.as_any().type_id() {
156                                    warn!("img2img settings type mismatch, resetting to default");
157                                    *img2img = img2img_params;
158                                } else {
159                                    *img2img = cfg.img2img_api.gen_params(Some(img2img.as_ref()));
160                                }
161                            }
162                        }
163                        Some(dialogue)
164                    }
165                    Err(err) => {
166                        error!("dialogue.get() failed: {:?}", err);
167                        let defaults = State::new_with_defaults(
168                            cfg.txt2img_api.gen_params(None),
169                            cfg.img2img_api.gen_params(None),
170                        );
171                        match dialogue.update(defaults.clone()).await {
172                            Ok(_) => {
173                                warn!("dialogue reset to default state: {:?}", defaults);
174                                Some(defaults)
175                            }
176                            Err(err) => {
177                                error!("dialogue.update() failed: {:?}", err);
178                                None
179                            }
180                        }
181                    }
182                }
183            },
184        )
185    }
186
187    /// Runs the StableDiffusionBot
188    pub async fn run(self) -> anyhow::Result<()> {
189        let StableDiffusionBot {
190            bot,
191            storage,
192            config,
193        } = self;
194
195        let mut commands = UnauthenticatedCommands::bot_commands();
196        commands.extend(SettingsCommands::bot_commands());
197        commands.extend(GenCommands::bot_commands());
198        bot.set_my_commands(commands)
199            .scope(teloxide::types::BotCommandScope::Default)
200            .await
201            .context("Failed to set bot commands")?;
202
203        Dispatcher::builder(bot, Self::schema())
204            .dependencies(dptree::deps![config, storage])
205            .default_handler(|upd| async move {
206                warn!("Unhandled update: {:?}", upd);
207            })
208            .error_handler(LoggingErrorHandler::with_custom_text(
209                "An error has occurred in the dispatcher",
210            ))
211            .enable_ctrlc_handler()
212            .build()
213            .dispatch()
214            .await;
215
216        Ok(())
217    }
218}
219
220#[derive(Clone, Debug)]
221pub(crate) struct ConfigParameters {
222    allowed_users: HashSet<ChatId>,
223    txt2img_api: Box<dyn sal_e_api::Txt2ImgApi>,
224    img2img_api: Box<dyn sal_e_api::Img2ImgApi>,
225    allow_all_users: bool,
226}
227
228impl ConfigParameters {
229    /// Checks whether a chat is allowed by the config.
230    pub fn chat_is_allowed(&self, chat_id: &ChatId) -> bool {
231        self.allow_all_users || self.allowed_users.contains(chat_id)
232    }
233}
234
235/// Enum representing the types of Stable Diffusion API.
236#[derive(Serialize, Deserialize, Default, Debug)]
237pub enum ApiType {
238    /// ComfyUI API
239    ComfyUI,
240    /// Stable Diffusion Web UI API
241    #[default]
242    StableDiffusionWebUi,
243}
244
245/// Struct that represents the configuration for the ComfyUI API.
246#[derive(Serialize, Deserialize, Default, Debug)]
247pub struct ComfyUIConfig {
248    /// Path to the prompt file for text to image requests.
249    pub txt2img_prompt_file: Option<PathBuf>,
250    /// Path to the prompt file for image to image requests.
251    pub img2img_prompt_file: Option<PathBuf>,
252}
253
254/// Struct that builds a StableDiffusionBot instance.
255pub struct StableDiffusionBotBuilder {
256    api_key: String,
257    allowed_users: Vec<i64>,
258    db_path: Option<String>,
259    sd_api_url: String,
260    api_type: ApiType,
261    txt2img_defaults: Option<Txt2ImgRequest>,
262    img2img_defaults: Option<Img2ImgRequest>,
263    comfyui_img2img_prompt_file: Option<PathBuf>,
264    comfyui_txt2img_prompt_file: Option<PathBuf>,
265    allow_all_users: bool,
266}
267
268impl StableDiffusionBotBuilder {
269    /// Constructor that returns a new StableDiffusionBotBuilder instance.
270    pub fn new(
271        api_key: String,
272        allowed_users: Vec<i64>,
273        sd_api_url: String,
274        api_type: ApiType,
275        allow_all_users: bool,
276    ) -> Self {
277        StableDiffusionBotBuilder {
278            api_key,
279            allowed_users,
280            db_path: None,
281            sd_api_url,
282            txt2img_defaults: None,
283            img2img_defaults: None,
284            allow_all_users,
285            api_type,
286            comfyui_txt2img_prompt_file: None,
287            comfyui_img2img_prompt_file: None,
288        }
289    }
290
291    /// Builder function that sets the path of the storage database for the bot.
292    ///
293    /// # Arguments
294    ///
295    /// * `path` - An optional `String` representing the path to the storage database.
296    ///
297    /// # Examples
298    ///
299    /// ```ignore
300    /// # use stable_diffusion_bot::StableDiffusionBotBuilder;
301    /// # let api_key = "api_key".to_string();
302    /// # let allowed_users = vec![1, 2, 3];
303    /// # let sd_api_url = "http://localhost:7860".to_string();
304    /// # let allow_all_users = false;
305    /// # tokio_test::block_on(async {
306    /// let builder = StableDiffusionBotBuilder::new(api_key, allowed_users, sd_api_url, allow_all_users);
307    ///
308    /// let bot = builder.db_path(Some("database.sqlite".to_string())).build().await.unwrap();
309    /// # });
310    /// ```
311    pub fn db_path(mut self, path: Option<String>) -> Self {
312        self.db_path = path;
313        self
314    }
315
316    /// Builder function that sets the defaults for text to image requests.
317    ///
318    /// # Arguments
319    ///
320    /// * `request` - A `Txt2ImgRequest` representing the default settings for text to image conversion.
321    ///
322    /// # Examples
323    ///
324    /// ```
325    /// # use stable_diffusion_bot::StableDiffusionBotBuilder;
326    /// # use stable_diffusion_api::Txt2ImgRequest;
327    /// # let api_key = "api_key".to_string();
328    /// # let allowed_users = vec![1, 2, 3];
329    /// # let sd_api_url = "http://localhost:7860".to_string();
330    /// # let allow_all_users = false;
331    /// # let api_type = stable_diffusion_bot::ApiType::StableDiffusionWebUi;
332    /// # tokio_test::block_on(async {
333    /// let builder = StableDiffusionBotBuilder::new(api_key, allowed_users, sd_api_url, api_type, allow_all_users);
334    ///
335    /// let bot = builder.txt2img_defaults(Txt2ImgRequest::default()).build().await.unwrap();
336    /// # });
337    /// ```
338    pub fn txt2img_defaults(mut self, request: Txt2ImgRequest) -> Self {
339        self.txt2img_defaults = Some(self.txt2img_defaults.unwrap_or_default().merge(request));
340        self
341    }
342
343    /// Builder function that clears the defaults for text to image requests.
344    pub fn clear_txt2img_defaults(mut self) -> Self {
345        self.txt2img_defaults = None;
346        self
347    }
348
349    /// Builder function that sets the defaults for image to image requests.
350    ///
351    /// # Arguments
352    ///
353    /// * `request` - An `Img2ImgRequest` representing the default settings for image to image conversion.
354    ///
355    /// # Examples
356    ///
357    /// ```
358    /// # use stable_diffusion_bot::StableDiffusionBotBuilder;
359    /// # use stable_diffusion_api::Img2ImgRequest;
360    /// # let api_key = "api_key".to_string();
361    /// # let allowed_users = vec![1, 2, 3];
362    /// # let sd_api_url = "http://localhost:7860".to_string();
363    /// # let allow_all_users = false;
364    /// # let api_type = stable_diffusion_bot::ApiType::StableDiffusionWebUi;
365    /// # tokio_test::block_on(async {
366    /// let builder = StableDiffusionBotBuilder::new(api_key, allowed_users, sd_api_url, api_type, allow_all_users);
367    ///
368    /// let bot = builder.img2img_defaults(Img2ImgRequest::default()).build().await.unwrap();
369    /// # });
370    /// ```
371    pub fn img2img_defaults(mut self, request: Img2ImgRequest) -> Self {
372        self.img2img_defaults = Some(self.img2img_defaults.unwrap_or_default().merge(request));
373        self
374    }
375
376    /// Builder function that clears the defaults for image to image requests.
377    pub fn clear_img2img_defaults(mut self) -> Self {
378        self.img2img_defaults = None;
379        self
380    }
381
382    pub fn comfyui_config(
383        mut self,
384        ComfyUIConfig {
385            txt2img_prompt_file,
386            img2img_prompt_file,
387        }: ComfyUIConfig,
388    ) -> Self {
389        self.comfyui_txt2img_prompt_file = txt2img_prompt_file;
390        self.comfyui_img2img_prompt_file = img2img_prompt_file;
391        self
392    }
393
394    /// Consumes the builder and builds a `StableDiffusionBot` instance.
395    ///
396    /// # Examples
397    ///
398    /// ```
399    /// # use stable_diffusion_bot::StableDiffusionBotBuilder;
400    /// # let api_key = "api_key".to_string();
401    /// # let allowed_users = vec![1, 2, 3];
402    /// # let sd_api_url = "http://localhost:7860".to_string();
403    /// # let allow_all_users = false;
404    /// # let api_type = stable_diffusion_bot::ApiType::StableDiffusionWebUi;
405    /// # tokio_test::block_on(async {
406    /// let builder = StableDiffusionBotBuilder::new(api_key, allowed_users, sd_api_url, api_type, allow_all_users);
407    ///
408    /// let bot = builder.build().await.unwrap();
409    /// # });
410    /// ```
411    pub async fn build(self) -> anyhow::Result<StableDiffusionBot> {
412        let storage: DialogueStorage = if let Some(path) = self.db_path {
413            SqliteStorage::open(&path, Json)
414                .await
415                .context("failed to open db")?
416                .erase()
417        } else {
418            InMemStorage::new().erase()
419        };
420
421        let bot = Bot::new(self.api_key.clone());
422
423        let allowed_users = self.allowed_users.into_iter().map(ChatId).collect();
424
425        let client = reqwest::Client::new();
426
427        let (txt2img_api, img2img_api): (Box<dyn Txt2ImgApi>, Box<dyn Img2ImgApi>) = match self
428            .api_type
429        {
430            ApiType::ComfyUI => {
431                let mut txt2img_prompt = String::new();
432
433                File::open(
434                    self.comfyui_txt2img_prompt_file
435                        .ok_or_else(|| anyhow!("No ComfyUI txt2img prompt file provided."))?,
436                )
437                .await
438                .context("Failed to open comfyui txt2img prompt file")?
439                .read_to_string(&mut txt2img_prompt)
440                .await?;
441
442                let mut img2img_prompt = String::new();
443
444                File::open(
445                    self.comfyui_img2img_prompt_file
446                        .ok_or_else(|| anyhow!("No ComfyUI img2img prompt file provided."))?,
447                )
448                .await
449                .context("Failed to open comfyui img2img prompt file")?
450                .read_to_string(&mut img2img_prompt)
451                .await?;
452
453                let txt2img_prompt =
454                    serde_json::from_str::<comfyui_api::models::Prompt>(&txt2img_prompt)
455                        .context("Failed to deserialize prompt")?;
456
457                _ = txt2img_prompt
458                    .prompt()
459                    .context("Failed to find a valid txt2img prompt node.")?;
460                _ = txt2img_prompt
461                    .seed()
462                    .context("Failed to find a valid txt2img seed node.")?;
463
464                let txt2img_api = ComfyPromptApi::new_with_client_and_url(
465                    client.clone(),
466                    self.sd_api_url.clone(),
467                    txt2img_prompt,
468                )?;
469
470                let img2img_prompt =
471                    serde_json::from_str::<comfyui_api::models::Prompt>(&img2img_prompt)
472                        .context("Failed to deserialize prompt")?;
473
474                _ = img2img_prompt
475                    .prompt()
476                    .context("Failed to find a valid img2img prompt node.")?;
477                _ = img2img_prompt
478                    .image()
479                    .context("Failed to find a valid img2img image node.")?;
480                _ = img2img_prompt
481                    .seed()
482                    .context("Failed to find a valid img2img seed node.")?;
483
484                let img2img_api = ComfyPromptApi::new_with_client_and_url(
485                    client,
486                    self.sd_api_url,
487                    img2img_prompt,
488                )
489                .context("Failed to create ComfyUI client")?;
490                (Box::new(txt2img_api), Box::new(img2img_api))
491            }
492            ApiType::StableDiffusionWebUi => {
493                let api = Api::new_with_client_and_url(client, self.sd_api_url)
494                    .context("Failed to initialize sd api")?;
495                let txt2img_api = StableDiffusionWebUiApi {
496                    client: api.clone(),
497                    txt2img_defaults: default_txt2img(
498                        self.txt2img_defaults.clone().unwrap_or_default(),
499                    ),
500                    img2img_defaults: default_img2img(
501                        self.img2img_defaults.clone().unwrap_or_default(),
502                    ),
503                };
504
505                let img2img_api = StableDiffusionWebUiApi {
506                    client: api,
507                    txt2img_defaults: default_txt2img(self.txt2img_defaults.unwrap_or_default()),
508                    img2img_defaults: default_img2img(self.img2img_defaults.unwrap_or_default()),
509                };
510
511                (Box::new(txt2img_api), Box::new(img2img_api))
512            }
513        };
514
515        let parameters = ConfigParameters {
516            allowed_users,
517            txt2img_api,
518            img2img_api,
519            allow_all_users: self.allow_all_users,
520        };
521
522        Ok(StableDiffusionBot {
523            bot,
524            storage,
525            config: parameters,
526        })
527    }
528}
529
#[cfg(test)]
mod tests {
    use super::*;

    use stable_diffusion_api::{Img2ImgRequest, Txt2ImgRequest};

    const API_KEY: &str = "api_key";
    const SD_API_URL: &str = "http://localhost:7860";
    const ALLOWED_USERS: [i64; 3] = [1, 2, 3];

    /// Returns a Web UI builder that uses the shared test token, URL and allow list.
    fn webui_builder(allow_all_users: bool) -> StableDiffusionBotBuilder {
        StableDiffusionBotBuilder::new(
            API_KEY.to_string(),
            ALLOWED_USERS.to_vec(),
            SD_API_URL.to_string(),
            ApiType::StableDiffusionWebUi,
            allow_all_users,
        )
    }

    /// Returns the allow list that `webui_builder` should produce in `ConfigParameters`.
    fn expected_allowed_users() -> std::collections::HashSet<ChatId> {
        ALLOWED_USERS.iter().copied().map(ChatId).collect()
    }

    #[tokio::test]
    async fn test_stable_diffusion_bot_builder() {
        // Use a per-process file in the temp dir so running the tests does not leave a
        // database behind in the working directory.
        let db_path = std::env::temp_dir().join(format!(
            "stable_diffusion_bot_test_{}.sqlite",
            std::process::id()
        ));
        let db_path = db_path.to_string_lossy().into_owned();

        let bot = webui_builder(false)
            .db_path(Some(db_path.clone()))
            .build()
            .await
            .unwrap();

        assert_eq!(bot.config.allowed_users.len(), ALLOWED_USERS.len());
        assert!(!bot.config.allow_all_users);

        drop(bot);
        // Best-effort cleanup. SQLite may also create journal side files next to the
        // database, so try those too and ignore any that do not exist.
        for suffix in ["", "-wal", "-shm", "-journal"] {
            let _ = std::fs::remove_file(format!("{}{}", db_path, suffix));
        }
    }

    #[tokio::test]
    async fn test_stable_diffusion_bot_defaults() {
        let bot = webui_builder(false).build().await.unwrap();

        assert_eq!(bot.config.allowed_users, expected_allowed_users());
        assert!(!bot.config.allow_all_users);
        assert_eq!(
            bot.config
                .txt2img_api
                .as_any()
                .downcast_ref::<StableDiffusionWebUiApi>()
                .unwrap()
                .txt2img_defaults,
            default_txt2img(Txt2ImgRequest::default())
        );
        assert_eq!(
            bot.config
                .img2img_api
                .as_any()
                .downcast_ref::<StableDiffusionWebUiApi>()
                .unwrap()
                .img2img_defaults,
            default_img2img(Img2ImgRequest::default())
        );
    }

    #[tokio::test]
    async fn test_stable_diffusion_bot_user_defaults() {
        let txt2img_settings = Txt2ImgRequest {
            width: Some(1024),
            height: Some(768),
            ..Default::default()
        };
        let img2img_settings = Img2ImgRequest {
            width: Some(1024),
            height: Some(768),
            ..Default::default()
        };

        let bot = webui_builder(false)
            .txt2img_defaults(txt2img_settings.clone())
            .img2img_defaults(img2img_settings.clone())
            .build()
            .await
            .unwrap();

        assert_eq!(bot.config.allowed_users, expected_allowed_users());
        assert!(!bot.config.allow_all_users);
        assert_eq!(
            bot.config
                .txt2img_api
                .as_any()
                .downcast_ref::<StableDiffusionWebUiApi>()
                .unwrap()
                .txt2img_defaults,
            default_txt2img(txt2img_settings)
        );
        assert_eq!(
            bot.config
                .img2img_api
                .as_any()
                .downcast_ref::<StableDiffusionWebUiApi>()
                .unwrap()
                .img2img_defaults,
            default_img2img(img2img_settings)
        );
    }

    #[tokio::test]
    async fn test_stable_diffusion_bot_no_user_defaults() {
        let bot = webui_builder(false)
            .txt2img_defaults(Txt2ImgRequest {
                width: Some(1024),
                height: Some(768),
                ..Default::default()
            })
            .img2img_defaults(Img2ImgRequest {
                width: Some(1024),
                height: Some(768),
                ..Default::default()
            })
            .clear_txt2img_defaults()
            .clear_img2img_defaults()
            .build()
            .await
            .unwrap();

        assert_eq!(bot.config.allowed_users, expected_allowed_users());
        assert!(!bot.config.allow_all_users);
        assert_eq!(
            bot.config
                .txt2img_api
                .as_any()
                .downcast_ref::<StableDiffusionWebUiApi>()
                .unwrap()
                .txt2img_defaults,
            default_txt2img(Txt2ImgRequest::default())
        );
        assert_eq!(
            bot.config
                .img2img_api
                .as_any()
                .downcast_ref::<StableDiffusionWebUiApi>()
                .unwrap()
                .img2img_defaults,
            default_img2img(Img2ImgRequest::default())
        );
    }

    #[tokio::test]
    async fn test_chat_is_allowed() {
        let restricted = webui_builder(false).build().await.unwrap();
        assert!(restricted.config.chat_is_allowed(&ChatId(1)));
        assert!(!restricted.config.chat_is_allowed(&ChatId(4)));

        let open = webui_builder(true).build().await.unwrap();
        assert!(open.config.chat_is_allowed(&ChatId(4)));
    }

    #[tokio::test]
    async fn test_comfyui_requires_prompt_files() {
        let result = StableDiffusionBotBuilder::new(
            API_KEY.to_string(),
            ALLOWED_USERS.to_vec(),
            SD_API_URL.to_string(),
            ApiType::ComfyUI,
            false,
        )
        .build()
        .await;

        match result {
            Ok(_) => panic!("building a ComfyUI bot without prompt files should fail"),
            Err(err) => assert!(err.to_string().contains("txt2img prompt file")),
        }
    }
}