1use anyhow::{anyhow, Context};
2use sal_e_api::{GenParams, ImageParams, Response};
3use teloxide::{
4 dispatching::UpdateHandler,
5 dptree::case,
6 macros::BotCommands,
7 payloads::setters::*,
8 prelude::*,
9 types::{
10 ChatAction, InlineKeyboardButton, InlineKeyboardMarkup, InputFile, InputMedia,
11 InputMediaPhoto, Me, MessageId, PhotoSize,
12 },
13 utils::command::BotCommands as _,
14};
15use tracing::{info, instrument, warn};
16
17use crate::{
18 bot::{helpers, State},
19 BotState,
20};
21
22use super::{
23 filter_command, filter_map_bot_state, filter_map_settings, ConfigParameters, DiffusionDialogue,
24};
25
// Slash commands that request an image generation. `/gen`, `/g` and `/generate`
// are handled identically everywhere in this module; each carries the prompt as
// its single `String` argument. `rename_rule = "lowercase"` makes the command
// names the lowercased variant names.
#[derive(BotCommands, Debug, Clone)]
#[command(rename_rule = "lowercase", description = "Image generation commands")]
pub(crate) enum GenCommands {
    // The primary, documented command.
    #[command(description = "generate an image")]
    Gen(String),
    // Short alias. NOTE(review): `description = "off"` presumably hides this
    // variant from the generated help text — confirm against the teloxide docs.
    #[command(description = "off")]
    G(String),
    // Long alias; marked `description = "off"` like `G`.
    #[command(description = "off")]
    Generate(String),
}
40
41enum Photo {
42 Single(Vec<u8>),
43 Album(Vec<Vec<u8>>),
44}
45
46impl Photo {
47 #[allow(dead_code)]
48 pub fn single(photo: Vec<u8>) -> anyhow::Result<Self> {
49 Ok(Self::Single(photo))
50 }
51
52 pub fn album(photos: Vec<Vec<u8>>) -> anyhow::Result<Self> {
53 if photos.len() == 1 {
54 let images = photos
55 .into_iter()
56 .next()
57 .ok_or_else(|| anyhow!("Failed to get image"))?;
58 Ok(Photo::Single(images))
59 } else {
60 Ok(Photo::Album(photos))
61 }
62 }
63}
64
65struct Reply {
66 caption: String,
67 images: Photo,
68 source: MessageId,
69 seed: i64,
70}
71
72impl Reply {
73 pub fn new(
74 caption: String,
75 images: Vec<Vec<u8>>,
76 seed: i64,
77 source: MessageId,
78 ) -> anyhow::Result<Self> {
79 let images = Photo::album(images)?;
80 Ok(Self {
81 caption,
82 images,
83 source,
84 seed,
85 })
86 }
87
88 pub async fn send(self, bot: &Bot, chat_id: ChatId) -> anyhow::Result<()> {
89 match self.images {
90 Photo::Single(image) => {
91 bot.send_photo(chat_id, InputFile::memory(image))
92 .parse_mode(teloxide::types::ParseMode::MarkdownV2)
93 .caption(self.caption)
94 .reply_markup(keyboard(self.seed))
95 .reply_to_message_id(self.source)
96 .await?;
97 }
98 Photo::Album(images) => {
99 let mut caption = Some(self.caption);
100 let input_media = images.into_iter().map(|i| {
101 let mut media = InputMediaPhoto::new(InputFile::memory(i));
102 media.caption = caption.take();
103 media.parse_mode = Some(teloxide::types::ParseMode::MarkdownV2);
104 InputMedia::Photo(media)
105 });
106
107 bot.send_media_group(chat_id, input_media)
108 .reply_to_message_id(self.source)
109 .await?;
110 bot.send_message(
111 chat_id,
112 "What would you like to do? Select below, or enter a new prompt.",
113 )
114 .reply_markup(keyboard(self.seed))
115 .reply_to_message_id(self.source)
116 .await?;
117 }
118 }
119
120 Ok(())
121 }
122}
123
124struct MessageText(String);
125
126impl MessageText {
127 pub fn new_with_image_params(prompt: &str, infotxt: &dyn ImageParams) -> Self {
128 use teloxide::utils::markdown::escape;
129
130 Self(format!(
131 "`{}`\n\n{}",
132 escape(prompt),
133 [
134 infotxt
135 .negative_prompt()
136 .as_ref()
137 .and_then(|s| (!s.trim().is_empty()).then(|| escape(s)))
138 .map(|s| format!("Negative prompt: `{s}`")),
139 infotxt.steps().map(|s| format!("Steps: `{s}`")),
140 infotxt
141 .sampler()
142 .as_ref()
143 .map(|s| format!("Sampler: `{s}`")),
144 infotxt.cfg().map(|s| format!("CFG scale: `{s}`")),
145 infotxt.seed().map(|s| format!("Seed: `{s}`")),
146 infotxt
147 .width()
148 .and_then(|w| infotxt.height().map(|h| format!("Size: `{w}×{h}`"))),
149 infotxt.model().as_ref().map(|s| format!("Model: `{s}`")),
150 infotxt
151 .denoising()
152 .map(|s| format!("Denoising strength: `{s}`")),
153 ]
154 .into_iter()
155 .flatten()
156 .collect::<Vec<_>>()
157 .join("\n")
158 ))
159 }
160}
161
162impl TryFrom<&dyn ImageParams> for MessageText {
163 type Error = anyhow::Error;
164
165 fn try_from(params: &dyn ImageParams) -> Result<Self, Self::Error> {
166 let prompt = if let Some(prompt) = params.prompt() {
167 prompt
168 } else {
169 return Err(anyhow!("No prompt in image info response"));
170 };
171 Ok(Self::new_with_image_params(prompt.as_str(), params))
172 }
173}
174
// Convenience conversion: builds the caption from the parameters carried by a
// generation `Response`, delegating to the `&dyn ImageParams` conversion.
impl TryFrom<Response> for MessageText {
    type Error = anyhow::Error;

    /// Fails under the same conditions as the `&dyn ImageParams` conversion
    /// (no prompt in `response.params`).
    fn try_from(response: Response) -> Result<Self, Self::Error> {
        Self::try_from(response.params.as_ref())
    }
}
182
183async fn do_img2img(
184 bot: &Bot,
185 cfg: &ConfigParameters,
186 img2img: &mut Box<dyn GenParams>,
187 msg: &Message,
188 photo: Vec<PhotoSize>,
189 prompt: String,
190) -> anyhow::Result<Response> {
191 img2img.set_prompt(prompt);
192
193 let photo = if let Some(photo) = photo
194 .iter()
195 .reduce(|a, p| if a.height > p.height { a } else { p })
196 {
197 photo
198 } else {
199 bot.send_message(msg.chat.id, "Something went wrong.")
200 .await?;
201 return Err(anyhow!("Photo vec was empty!"));
202 };
203 let file = bot.get_file(&photo.file.id).send().await?;
204
205 let photo = helpers::get_file(bot, &file).await?;
206
207 img2img.set_image(Some(photo.into()));
208
209 let resp = cfg.img2img_api.img2img(img2img.as_ref()).await?;
210
211 img2img.set_image(None);
212
213 Ok(resp)
214}
215
216async fn handle_image(
217 bot: Bot,
218 cfg: ConfigParameters,
219 dialogue: DiffusionDialogue,
220 (txt2img, mut img2img): (Box<dyn GenParams>, Box<dyn GenParams>),
221 msg: Message,
222 photo: Vec<PhotoSize>,
223 text: String,
224) -> anyhow::Result<()> {
225 if text.is_empty() {
226 bot.send_message(msg.chat.id, "A prompt is required.")
227 .reply_to_message_id(msg.id)
228 .await?;
229 return Ok(());
230 }
231
232 bot.send_chat_action(msg.chat.id, ChatAction::UploadPhoto)
233 .await?;
234
235 let resp = do_img2img(&bot, &cfg, &mut img2img, &msg, photo, text).await?;
236
237 let seed = if resp.params.seed() == resp.gen_params.seed() {
238 -1
239 } else {
240 resp.params.seed().unwrap_or(-1)
241 };
242
243 let caption = MessageText::try_from(resp.params.as_ref())
244 .context("Failed to build caption from response")?;
245
246 Reply::new(caption.0, resp.images, seed, msg.id)
247 .context("Failed to create response!")?
248 .send(&bot, msg.chat.id)
249 .await?;
250
251 dialogue
252 .update(State::Ready {
253 bot_state: BotState::default(),
254 txt2img,
255 img2img,
256 })
257 .await
258 .map_err(|e| anyhow!(e))?;
259
260 Ok(())
261}
262
263async fn do_txt2img(
264 prompt: String,
265 cfg: &ConfigParameters,
266 txt2img: &mut dyn GenParams,
267) -> anyhow::Result<Response> {
268 txt2img.set_prompt(prompt);
269
270 let resp = cfg.txt2img_api.txt2img(txt2img).await?;
271
272 Ok(resp)
273}
274
275async fn handle_prompt(
276 bot: Bot,
277 cfg: ConfigParameters,
278 dialogue: DiffusionDialogue,
279 (mut txt2img, img2img): (Box<dyn GenParams>, Box<dyn GenParams>),
280 msg: Message,
281 text: String,
282) -> anyhow::Result<()> {
283 if text.is_empty() {
284 bot.send_message(msg.chat.id, "A prompt is required.")
285 .reply_to_message_id(msg.id)
286 .await?;
287 return Ok(());
288 }
289
290 bot.send_chat_action(msg.chat.id, ChatAction::UploadPhoto)
291 .await?;
292
293 let resp = do_txt2img(text, &cfg, txt2img.as_mut()).await?;
294
295 let seed = if resp.params.seed() == resp.gen_params.seed() {
296 -1
297 } else {
298 resp.params.seed().unwrap_or(-1)
299 };
300
301 let caption = MessageText::try_from(resp.params.as_ref())
302 .context("Failed to build caption from response")?;
303
304 Reply::new(caption.0, resp.images, seed, msg.id)
305 .context("Failed to create response!")?
306 .send(&bot, msg.chat.id)
307 .await?;
308
309 dialogue
310 .update(State::Ready {
311 bot_state: BotState::default(),
312 txt2img,
313 img2img,
314 })
315 .await
316 .map_err(|e| anyhow!(e))?;
317
318 Ok(())
319}
320
321fn keyboard(seed: i64) -> InlineKeyboardMarkup {
322 let seed_button = if seed == -1 {
323 InlineKeyboardButton::callback("🎲 Seed", "reuse/-1")
324 } else {
325 InlineKeyboardButton::callback("♻️ Seed", format!("reuse/{seed}"))
326 };
327 InlineKeyboardMarkup::new([[
328 InlineKeyboardButton::callback("🔄 Rerun", "rerun"),
329 seed_button,
330 InlineKeyboardButton::callback("⚙️ Settings", "settings"),
331 ]])
332}
333
334#[instrument(skip_all)]
335async fn handle_rerun(
336 me: Me,
337 bot: Bot,
338 cfg: ConfigParameters,
339 dialogue: DiffusionDialogue,
340 (txt2img, img2img): (Box<dyn GenParams>, Box<dyn GenParams>),
341 q: CallbackQuery,
342) -> anyhow::Result<()> {
343 let message = if let Some(message) = q.message {
344 message
345 } else {
346 bot.answer_callback_query(q.id)
347 .cache_time(60)
348 .text("Sorry, this message is no longer available.")
349 .await?;
350 return Ok(());
351 };
352
353 let id = message.id;
354 let chat_id = message.chat.id;
355
356 let parent = if let Some(parent) = message.reply_to_message().cloned() {
357 parent
358 } else {
359 bot.answer_callback_query(q.id)
360 .cache_time(60)
361 .text("Oops, something went wrong.")
362 .await?;
363 return Ok(());
364 };
365
366 if let Some(photo) = parent.photo().map(ToOwned::to_owned) {
367 if let Some(text) = message.caption().map(ToOwned::to_owned) {
368 let bot_name = me.user.username.expect("Bots must have a username");
369 let text = if let Ok(command) = GenCommands::parse(&text, &bot_name) {
370 match command {
371 GenCommands::Gen(s) | GenCommands::G(s) | GenCommands::Generate(s) => s,
372 }
373 } else {
374 text
375 };
376
377 if let Err(e) = bot
378 .answer_callback_query(q.id)
379 .cache_time(60)
380 .text("Rerunning this image...")
381 .await
382 {
383 warn!("Failed to answer image rerun callback query: {}", e)
384 }
385 handle_image(
386 bot.clone(),
387 cfg,
388 dialogue,
389 (txt2img, img2img),
390 parent,
391 photo,
392 text,
393 )
394 .await?;
395 } else {
396 bot.send_message(message.chat.id, "A prompt is required to run img2img.")
397 .await?;
398 return Err(anyhow!("No prompt provided for img2img"));
399 }
400 } else if let Some(text) = parent.text().map(ToOwned::to_owned) {
401 if let Err(e) = bot
402 .answer_callback_query(q.id)
403 .cache_time(60)
404 .text("Rerunning this prompt...")
405 .await
406 {
407 warn!("Failed to answer prompt rerun callback query: {}", e)
408 }
409 let bot_name = me.user.username.expect("Bots must have a username");
410 let text = if let Ok(command) = GenCommands::parse(&text, &bot_name) {
411 match command {
412 GenCommands::Gen(s) | GenCommands::G(s) | GenCommands::Generate(s) => s,
413 }
414 } else {
415 text
416 };
417 handle_prompt(bot.clone(), cfg, dialogue, (txt2img, img2img), parent, text).await?;
418 } else {
419 bot.answer_callback_query(q.id)
420 .cache_time(60)
421 .text("Oops, something went wrong.")
422 .await?;
423 return Ok(());
424 }
425
426 bot.edit_message_reply_markup(chat_id, id)
427 .reply_markup(InlineKeyboardMarkup::new([[]]))
428 .send()
429 .await?;
430
431 Ok(())
432}
433
434async fn handle_reuse(
435 bot: Bot,
436 dialogue: DiffusionDialogue,
437 (mut txt2img, mut img2img): (Box<dyn GenParams>, Box<dyn GenParams>),
438 q: CallbackQuery,
439 seed: i64,
440) -> anyhow::Result<()> {
441 let message = if let Some(message) = q.message {
442 message
443 } else {
444 bot.answer_callback_query(q.id)
445 .cache_time(60)
446 .text("Sorry, this message is no longer available.")
447 .await?;
448 return Ok(());
449 };
450
451 let id = message.id;
452 let chat_id = message.chat.id;
453
454 let parent = if let Some(parent) = message.reply_to_message().cloned() {
455 parent
456 } else {
457 bot.answer_callback_query(q.id)
458 .cache_time(60)
459 .text("Oops, something went wrong.")
460 .await?;
461 return Ok(());
462 };
463
464 if parent.photo().is_some() {
465 img2img.set_seed(seed);
466 dialogue
467 .update(State::Ready {
468 bot_state: BotState::default(),
469 txt2img,
470 img2img,
471 })
472 .await
473 .map_err(|e| anyhow!(e))?;
474 } else if parent.text().is_some() {
475 txt2img.set_seed(seed);
476 dialogue
477 .update(State::Ready {
478 bot_state: BotState::default(),
479 txt2img,
480 img2img,
481 })
482 .await
483 .map_err(|e| anyhow!(e))?;
484 } else {
485 bot.answer_callback_query(q.id)
486 .cache_time(60)
487 .text("Oops, something went wrong.")
488 .await?;
489 return Ok(());
490 }
491 if seed == -1 {
492 if let Err(e) = bot
493 .answer_callback_query(q.id)
494 .text("Seed randomized.")
495 .await
496 {
497 warn!("Failed to answer randomize seed callback query: {}", e)
498 }
499 } else {
500 if let Err(e) = bot
501 .answer_callback_query(q.id)
502 .text(format!("Seed set to {seed}."))
503 .await
504 {
505 warn!("Failed to answer set seed callback query: {}", e)
506 }
507 bot.edit_message_reply_markup(chat_id, id)
508 .reply_markup(keyboard(-1))
509 .send()
510 .await?;
511 }
512
513 Ok(())
514}
515
516pub(crate) fn image_schema() -> UpdateHandler<anyhow::Error> {
517 let gen_command_handler = Update::filter_message()
518 .chain(filter_command::<GenCommands>())
519 .chain(dptree::filter_map(|g: GenCommands| match g {
520 GenCommands::Gen(s) | GenCommands::G(s) | GenCommands::Generate(s) => Some(s),
521 }))
522 .branch(Message::filter_photo().endpoint(handle_image))
523 .branch(dptree::endpoint(handle_prompt));
524
525 let message_handler = Update::filter_message()
526 .branch(
527 dptree::filter(|msg: Message| {
528 msg.text().map(|t| t.starts_with('/')).unwrap_or_default()
529 })
530 .endpoint(|msg: Message| async move {
531 info!(
532 "Ignoring unknown command: {}",
533 msg.text().unwrap_or_default()
534 );
535 Ok(())
536 }),
537 )
538 .branch(
539 Message::filter_photo()
540 .map(|msg: Message| msg.caption().map(str::to_string).unwrap_or_default())
541 .endpoint(handle_image),
542 )
543 .branch(Message::filter_text().endpoint(handle_prompt));
544
545 let callback_handler = Update::filter_callback_query()
546 .branch(
547 dptree::filter_map(|q: CallbackQuery| {
548 q.data
549 .filter(|d| d.starts_with("reuse"))
550 .and_then(|d| d.split('/').skip(1).flat_map(str::parse::<i64>).next())
551 })
552 .endpoint(handle_reuse),
553 )
554 .branch(
555 dptree::filter(|q: CallbackQuery| q.data.filter(|d| d.starts_with("rerun")).is_some())
556 .endpoint(handle_rerun),
557 );
558
559 dptree::entry()
560 .chain(filter_map_bot_state())
561 .chain(case![BotState::Generate])
562 .chain(filter_map_settings())
563 .branch(gen_command_handler)
564 .branch(message_handler)
565 .branch(callback_handler)
566}