stable_diffusion_bot/bot/handlers/image.rs

1use anyhow::{anyhow, Context};
2use sal_e_api::{GenParams, ImageParams, Response};
3use teloxide::{
4    dispatching::UpdateHandler,
5    dptree::case,
6    macros::BotCommands,
7    payloads::setters::*,
8    prelude::*,
9    types::{
10        ChatAction, InlineKeyboardButton, InlineKeyboardMarkup, InputFile, InputMedia,
11        InputMediaPhoto, Me, MessageId, PhotoSize,
12    },
13    utils::command::BotCommands as _,
14};
15use tracing::{info, instrument, warn};
16
17use crate::{
18    bot::{helpers, State},
19    BotState,
20};
21
22use super::{
23    filter_command, filter_map_bot_state, filter_map_settings, ConfigParameters, DiffusionDialogue,
24};
25
// Commands that trigger image generation. Each variant carries the command's argument,
// which the handlers in this module use as the prompt. With `rename_rule = "lowercase"`
// the variants become `/gen`, `/g` and `/generate`; the last two are aliases of `/gen`
// hidden from the help text (`description = "off"`).
/// BotCommands for generating images.
#[derive(BotCommands, Debug, Clone)]
#[command(rename_rule = "lowercase", description = "Image generation commands")]
pub(crate) enum GenCommands {
    /// Command to generate an image
    #[command(description = "generate an image")]
    Gen(String),
    /// Alias for `gen`. Hidden from help to avoid confusion.
    #[command(description = "off")]
    G(String),
    /// Alias for `gen`. Hidden from help to avoid confusion.
    #[command(description = "off")]
    Generate(String),
}
40
41enum Photo {
42    Single(Vec<u8>),
43    Album(Vec<Vec<u8>>),
44}
45
46impl Photo {
47    #[allow(dead_code)]
48    pub fn single(photo: Vec<u8>) -> anyhow::Result<Self> {
49        Ok(Self::Single(photo))
50    }
51
52    pub fn album(photos: Vec<Vec<u8>>) -> anyhow::Result<Self> {
53        if photos.len() == 1 {
54            let images = photos
55                .into_iter()
56                .next()
57                .ok_or_else(|| anyhow!("Failed to get image"))?;
58            Ok(Photo::Single(images))
59        } else {
60            Ok(Photo::Album(photos))
61        }
62    }
63}
64
65struct Reply {
66    caption: String,
67    images: Photo,
68    source: MessageId,
69    seed: i64,
70}
71
72impl Reply {
73    pub fn new(
74        caption: String,
75        images: Vec<Vec<u8>>,
76        seed: i64,
77        source: MessageId,
78    ) -> anyhow::Result<Self> {
79        let images = Photo::album(images)?;
80        Ok(Self {
81            caption,
82            images,
83            source,
84            seed,
85        })
86    }
87
88    pub async fn send(self, bot: &Bot, chat_id: ChatId) -> anyhow::Result<()> {
89        match self.images {
90            Photo::Single(image) => {
91                bot.send_photo(chat_id, InputFile::memory(image))
92                    .parse_mode(teloxide::types::ParseMode::MarkdownV2)
93                    .caption(self.caption)
94                    .reply_markup(keyboard(self.seed))
95                    .reply_to_message_id(self.source)
96                    .await?;
97            }
98            Photo::Album(images) => {
99                let mut caption = Some(self.caption);
100                let input_media = images.into_iter().map(|i| {
101                    let mut media = InputMediaPhoto::new(InputFile::memory(i));
102                    media.caption = caption.take();
103                    media.parse_mode = Some(teloxide::types::ParseMode::MarkdownV2);
104                    InputMedia::Photo(media)
105                });
106
107                bot.send_media_group(chat_id, input_media)
108                    .reply_to_message_id(self.source)
109                    .await?;
110                bot.send_message(
111                    chat_id,
112                    "What would you like to do? Select below, or enter a new prompt.",
113                )
114                .reply_markup(keyboard(self.seed))
115                .reply_to_message_id(self.source)
116                .await?;
117            }
118        }
119
120        Ok(())
121    }
122}
123
124struct MessageText(String);
125
126impl MessageText {
127    pub fn new_with_image_params(prompt: &str, infotxt: &dyn ImageParams) -> Self {
128        use teloxide::utils::markdown::escape;
129
130        Self(format!(
131            "`{}`\n\n{}",
132            escape(prompt),
133            [
134                infotxt
135                    .negative_prompt()
136                    .as_ref()
137                    .and_then(|s| (!s.trim().is_empty()).then(|| escape(s)))
138                    .map(|s| format!("Negative prompt: `{s}`")),
139                infotxt.steps().map(|s| format!("Steps: `{s}`")),
140                infotxt
141                    .sampler()
142                    .as_ref()
143                    .map(|s| format!("Sampler: `{s}`")),
144                infotxt.cfg().map(|s| format!("CFG scale: `{s}`")),
145                infotxt.seed().map(|s| format!("Seed: `{s}`")),
146                infotxt
147                    .width()
148                    .and_then(|w| infotxt.height().map(|h| format!("Size: `{w}×{h}`"))),
149                infotxt.model().as_ref().map(|s| format!("Model: `{s}`")),
150                infotxt
151                    .denoising()
152                    .map(|s| format!("Denoising strength: `{s}`")),
153            ]
154            .into_iter()
155            .flatten()
156            .collect::<Vec<_>>()
157            .join("\n")
158        ))
159    }
160}
161
162impl TryFrom<&dyn ImageParams> for MessageText {
163    type Error = anyhow::Error;
164
165    fn try_from(params: &dyn ImageParams) -> Result<Self, Self::Error> {
166        let prompt = if let Some(prompt) = params.prompt() {
167            prompt
168        } else {
169            return Err(anyhow!("No prompt in image info response"));
170        };
171        Ok(Self::new_with_image_params(prompt.as_str(), params))
172    }
173}
174
175impl TryFrom<Response> for MessageText {
176    type Error = anyhow::Error;
177
178    fn try_from(response: Response) -> Result<Self, Self::Error> {
179        Self::try_from(response.params.as_ref())
180    }
181}
182
183async fn do_img2img(
184    bot: &Bot,
185    cfg: &ConfigParameters,
186    img2img: &mut Box<dyn GenParams>,
187    msg: &Message,
188    photo: Vec<PhotoSize>,
189    prompt: String,
190) -> anyhow::Result<Response> {
191    img2img.set_prompt(prompt);
192
193    let photo = if let Some(photo) = photo
194        .iter()
195        .reduce(|a, p| if a.height > p.height { a } else { p })
196    {
197        photo
198    } else {
199        bot.send_message(msg.chat.id, "Something went wrong.")
200            .await?;
201        return Err(anyhow!("Photo vec was empty!"));
202    };
203    let file = bot.get_file(&photo.file.id).send().await?;
204
205    let photo = helpers::get_file(bot, &file).await?;
206
207    img2img.set_image(Some(photo.into()));
208
209    let resp = cfg.img2img_api.img2img(img2img.as_ref()).await?;
210
211    img2img.set_image(None);
212
213    Ok(resp)
214}
215
216async fn handle_image(
217    bot: Bot,
218    cfg: ConfigParameters,
219    dialogue: DiffusionDialogue,
220    (txt2img, mut img2img): (Box<dyn GenParams>, Box<dyn GenParams>),
221    msg: Message,
222    photo: Vec<PhotoSize>,
223    text: String,
224) -> anyhow::Result<()> {
225    if text.is_empty() {
226        bot.send_message(msg.chat.id, "A prompt is required.")
227            .reply_to_message_id(msg.id)
228            .await?;
229        return Ok(());
230    }
231
232    bot.send_chat_action(msg.chat.id, ChatAction::UploadPhoto)
233        .await?;
234
235    let resp = do_img2img(&bot, &cfg, &mut img2img, &msg, photo, text).await?;
236
237    let seed = if resp.params.seed() == resp.gen_params.seed() {
238        -1
239    } else {
240        resp.params.seed().unwrap_or(-1)
241    };
242
243    let caption = MessageText::try_from(resp.params.as_ref())
244        .context("Failed to build caption from response")?;
245
246    Reply::new(caption.0, resp.images, seed, msg.id)
247        .context("Failed to create response!")?
248        .send(&bot, msg.chat.id)
249        .await?;
250
251    dialogue
252        .update(State::Ready {
253            bot_state: BotState::default(),
254            txt2img,
255            img2img,
256        })
257        .await
258        .map_err(|e| anyhow!(e))?;
259
260    Ok(())
261}
262
263async fn do_txt2img(
264    prompt: String,
265    cfg: &ConfigParameters,
266    txt2img: &mut dyn GenParams,
267) -> anyhow::Result<Response> {
268    txt2img.set_prompt(prompt);
269
270    let resp = cfg.txt2img_api.txt2img(txt2img).await?;
271
272    Ok(resp)
273}
274
275async fn handle_prompt(
276    bot: Bot,
277    cfg: ConfigParameters,
278    dialogue: DiffusionDialogue,
279    (mut txt2img, img2img): (Box<dyn GenParams>, Box<dyn GenParams>),
280    msg: Message,
281    text: String,
282) -> anyhow::Result<()> {
283    if text.is_empty() {
284        bot.send_message(msg.chat.id, "A prompt is required.")
285            .reply_to_message_id(msg.id)
286            .await?;
287        return Ok(());
288    }
289
290    bot.send_chat_action(msg.chat.id, ChatAction::UploadPhoto)
291        .await?;
292
293    let resp = do_txt2img(text, &cfg, txt2img.as_mut()).await?;
294
295    let seed = if resp.params.seed() == resp.gen_params.seed() {
296        -1
297    } else {
298        resp.params.seed().unwrap_or(-1)
299    };
300
301    let caption = MessageText::try_from(resp.params.as_ref())
302        .context("Failed to build caption from response")?;
303
304    Reply::new(caption.0, resp.images, seed, msg.id)
305        .context("Failed to create response!")?
306        .send(&bot, msg.chat.id)
307        .await?;
308
309    dialogue
310        .update(State::Ready {
311            bot_state: BotState::default(),
312            txt2img,
313            img2img,
314        })
315        .await
316        .map_err(|e| anyhow!(e))?;
317
318    Ok(())
319}
320
321fn keyboard(seed: i64) -> InlineKeyboardMarkup {
322    let seed_button = if seed == -1 {
323        InlineKeyboardButton::callback("🎲 Seed", "reuse/-1")
324    } else {
325        InlineKeyboardButton::callback("♻️ Seed", format!("reuse/{seed}"))
326    };
327    InlineKeyboardMarkup::new([[
328        InlineKeyboardButton::callback("🔄 Rerun", "rerun"),
329        seed_button,
330        InlineKeyboardButton::callback("⚙️ Settings", "settings"),
331    ]])
332}
333
334#[instrument(skip_all)]
335async fn handle_rerun(
336    me: Me,
337    bot: Bot,
338    cfg: ConfigParameters,
339    dialogue: DiffusionDialogue,
340    (txt2img, img2img): (Box<dyn GenParams>, Box<dyn GenParams>),
341    q: CallbackQuery,
342) -> anyhow::Result<()> {
343    let message = if let Some(message) = q.message {
344        message
345    } else {
346        bot.answer_callback_query(q.id)
347            .cache_time(60)
348            .text("Sorry, this message is no longer available.")
349            .await?;
350        return Ok(());
351    };
352
353    let id = message.id;
354    let chat_id = message.chat.id;
355
356    let parent = if let Some(parent) = message.reply_to_message().cloned() {
357        parent
358    } else {
359        bot.answer_callback_query(q.id)
360            .cache_time(60)
361            .text("Oops, something went wrong.")
362            .await?;
363        return Ok(());
364    };
365
366    if let Some(photo) = parent.photo().map(ToOwned::to_owned) {
367        if let Some(text) = message.caption().map(ToOwned::to_owned) {
368            let bot_name = me.user.username.expect("Bots must have a username");
369            let text = if let Ok(command) = GenCommands::parse(&text, &bot_name) {
370                match command {
371                    GenCommands::Gen(s) | GenCommands::G(s) | GenCommands::Generate(s) => s,
372                }
373            } else {
374                text
375            };
376
377            if let Err(e) = bot
378                .answer_callback_query(q.id)
379                .cache_time(60)
380                .text("Rerunning this image...")
381                .await
382            {
383                warn!("Failed to answer image rerun callback query: {}", e)
384            }
385            handle_image(
386                bot.clone(),
387                cfg,
388                dialogue,
389                (txt2img, img2img),
390                parent,
391                photo,
392                text,
393            )
394            .await?;
395        } else {
396            bot.send_message(message.chat.id, "A prompt is required to run img2img.")
397                .await?;
398            return Err(anyhow!("No prompt provided for img2img"));
399        }
400    } else if let Some(text) = parent.text().map(ToOwned::to_owned) {
401        if let Err(e) = bot
402            .answer_callback_query(q.id)
403            .cache_time(60)
404            .text("Rerunning this prompt...")
405            .await
406        {
407            warn!("Failed to answer prompt rerun callback query: {}", e)
408        }
409        let bot_name = me.user.username.expect("Bots must have a username");
410        let text = if let Ok(command) = GenCommands::parse(&text, &bot_name) {
411            match command {
412                GenCommands::Gen(s) | GenCommands::G(s) | GenCommands::Generate(s) => s,
413            }
414        } else {
415            text
416        };
417        handle_prompt(bot.clone(), cfg, dialogue, (txt2img, img2img), parent, text).await?;
418    } else {
419        bot.answer_callback_query(q.id)
420            .cache_time(60)
421            .text("Oops, something went wrong.")
422            .await?;
423        return Ok(());
424    }
425
426    bot.edit_message_reply_markup(chat_id, id)
427        .reply_markup(InlineKeyboardMarkup::new([[]]))
428        .send()
429        .await?;
430
431    Ok(())
432}
433
434async fn handle_reuse(
435    bot: Bot,
436    dialogue: DiffusionDialogue,
437    (mut txt2img, mut img2img): (Box<dyn GenParams>, Box<dyn GenParams>),
438    q: CallbackQuery,
439    seed: i64,
440) -> anyhow::Result<()> {
441    let message = if let Some(message) = q.message {
442        message
443    } else {
444        bot.answer_callback_query(q.id)
445            .cache_time(60)
446            .text("Sorry, this message is no longer available.")
447            .await?;
448        return Ok(());
449    };
450
451    let id = message.id;
452    let chat_id = message.chat.id;
453
454    let parent = if let Some(parent) = message.reply_to_message().cloned() {
455        parent
456    } else {
457        bot.answer_callback_query(q.id)
458            .cache_time(60)
459            .text("Oops, something went wrong.")
460            .await?;
461        return Ok(());
462    };
463
464    if parent.photo().is_some() {
465        img2img.set_seed(seed);
466        dialogue
467            .update(State::Ready {
468                bot_state: BotState::default(),
469                txt2img,
470                img2img,
471            })
472            .await
473            .map_err(|e| anyhow!(e))?;
474    } else if parent.text().is_some() {
475        txt2img.set_seed(seed);
476        dialogue
477            .update(State::Ready {
478                bot_state: BotState::default(),
479                txt2img,
480                img2img,
481            })
482            .await
483            .map_err(|e| anyhow!(e))?;
484    } else {
485        bot.answer_callback_query(q.id)
486            .cache_time(60)
487            .text("Oops, something went wrong.")
488            .await?;
489        return Ok(());
490    }
491    if seed == -1 {
492        if let Err(e) = bot
493            .answer_callback_query(q.id)
494            .text("Seed randomized.")
495            .await
496        {
497            warn!("Failed to answer randomize seed callback query: {}", e)
498        }
499    } else {
500        if let Err(e) = bot
501            .answer_callback_query(q.id)
502            .text(format!("Seed set to {seed}."))
503            .await
504        {
505            warn!("Failed to answer set seed callback query: {}", e)
506        }
507        bot.edit_message_reply_markup(chat_id, id)
508            .reply_markup(keyboard(-1))
509            .send()
510            .await?;
511    }
512
513    Ok(())
514}
515
516pub(crate) fn image_schema() -> UpdateHandler<anyhow::Error> {
517    let gen_command_handler = Update::filter_message()
518        .chain(filter_command::<GenCommands>())
519        .chain(dptree::filter_map(|g: GenCommands| match g {
520            GenCommands::Gen(s) | GenCommands::G(s) | GenCommands::Generate(s) => Some(s),
521        }))
522        .branch(Message::filter_photo().endpoint(handle_image))
523        .branch(dptree::endpoint(handle_prompt));
524
525    let message_handler = Update::filter_message()
526        .branch(
527            dptree::filter(|msg: Message| {
528                msg.text().map(|t| t.starts_with('/')).unwrap_or_default()
529            })
530            .endpoint(|msg: Message| async move {
531                info!(
532                    "Ignoring unknown command: {}",
533                    msg.text().unwrap_or_default()
534                );
535                Ok(())
536            }),
537        )
538        .branch(
539            Message::filter_photo()
540                .map(|msg: Message| msg.caption().map(str::to_string).unwrap_or_default())
541                .endpoint(handle_image),
542        )
543        .branch(Message::filter_text().endpoint(handle_prompt));
544
545    let callback_handler = Update::filter_callback_query()
546        .branch(
547            dptree::filter_map(|q: CallbackQuery| {
548                q.data
549                    .filter(|d| d.starts_with("reuse"))
550                    .and_then(|d| d.split('/').skip(1).flat_map(str::parse::<i64>).next())
551            })
552            .endpoint(handle_reuse),
553        )
554        .branch(
555            dptree::filter(|q: CallbackQuery| q.data.filter(|d| d.starts_with("rerun")).is_some())
556                .endpoint(handle_rerun),
557        );
558
559    dptree::entry()
560        .chain(filter_map_bot_state())
561        .chain(case![BotState::Generate])
562        .chain(filter_map_settings())
563        .branch(gen_command_handler)
564        .branch(message_handler)
565        .branch(callback_handler)
566}