Skip to main content

babyrite/expand/
discord.rs

1//! Discord message link expansion.
2//!
3//! This module provides functionality for parsing Discord message links
4//! and generating embed previews of the linked messages.
5//!
6//! Migrated from `preview.rs` with support for multiple link expansion.
7
8use regex::Regex;
9use serenity::all::{
10    ChannelId, ChannelType, Context, GuildChannel, GuildId, Message, MessageId,
11    PermissionOverwriteType, Permissions, RoleId,
12};
13use serenity_builder::model::embed::SerenityEmbed;
14use std::collections::HashSet;
15use std::sync::LazyLock;
16use url::Url;
17
18use super::{ExpandError, ExpandedContent};
19use crate::cache::CacheArgs;
20
21/// Regex pattern for matching Discord message links.
22///
23/// Supports production, PTB, and Canary Discord URLs.
24pub static MESSAGE_LINK_REGEX: LazyLock<Regex> = LazyLock::new(|| {
25    Regex::new(r"https://(?:ptb\.|canary\.)?discord\.com/channels/(\d+)/(\d+)/(\d+)").unwrap()
26});
27
28/// Parsed IDs from a Discord message link.
29#[derive(Debug)]
30pub struct MessageLinkIDs {
31    /// The guild ID from the message link.
32    pub guild_id: GuildId,
33    /// The channel ID from the message link.
34    pub channel_id: ChannelId,
35    /// The message ID from the message link.
36    pub message_id: MessageId,
37}
38
39/// A preview containing the message and its channel.
40#[derive(serde::Deserialize, Debug)]
41pub struct Preview {
42    /// The message to preview.
43    pub message: Message,
44    /// The channel containing the message.
45    pub channel: GuildChannel,
46}
47
48/// Errors that can occur when generating a Discord message preview.
49#[derive(thiserror::Error, Debug)]
50pub enum PreviewError {
51    /// Failed to retrieve channel information from cache.
52    #[error("Failed to retrieve from cache.")]
53    Cache,
54    /// The target channel is marked as NSFW.
55    #[error("NSFW content previews are not permitted, but the channel is marked as NSFW.")]
56    Nsfw,
57    /// The target channel is private or a private thread.
58    #[error("The channel is a private channel or private thread.")]
59    Permission,
60    /// An error occurred while communicating with Discord.
61    #[allow(clippy::enum_variant_names)]
62    #[error(transparent)]
63    SerenityError(#[from] serenity::Error),
64}
65
66impl MessageLinkIDs {
67    /// Parses all Discord message links from the given text.
68    ///
69    /// Returns a `Vec<MessageLinkIDs>` containing all valid message links found.
70    ///
71    /// Note: Duplicate URLs are ignored, and a maximum of 3 links are returned.
72    pub fn parse_all(text: &str) -> Vec<MessageLinkIDs> {
73        let mut seen_urls = HashSet::new();
74        MESSAGE_LINK_REGEX
75            .captures_iter(text)
76            .filter_map(|captures| {
77                let m = captures.get(0)?;
78                let full_url = m.as_str();
79                // Skip URLs wrapped in angle brackets (e.g., <https://...>)
80                if m.start() > 0 && text.as_bytes()[m.start() - 1] == b'<' {
81                    return None;
82                }
83                if !seen_urls.insert(full_url.to_string()) {
84                    return None;
85                }
86
87                let url = Url::parse(full_url).ok()?;
88
89                if !matches!(
90                    url.domain(),
91                    Some("discord.com") | Some("canary.discord.com") | Some("ptb.discord.com")
92                ) {
93                    return None;
94                }
95
96                let guild_id = GuildId::new(captures.get(1)?.as_str().parse().ok()?);
97                let channel_id = ChannelId::new(captures.get(2)?.as_str().parse().ok()?);
98                let message_id = MessageId::new(captures.get(3)?.as_str().parse().ok()?);
99
100                Some(MessageLinkIDs {
101                    guild_id,
102                    channel_id,
103                    message_id,
104                })
105            })
106            .take(3) // Limit to maximum 3 links
107            .collect()
108    }
109
110    /// Fetches the linked message and returns an embed preview.
111    pub async fn fetch(&self, ctx: &Context) -> Result<ExpandedContent, ExpandError> {
112        let preview = Preview::get(self, ctx).await?;
113        let (message, channel) = (preview.message, preview.channel);
114
115        let embed = SerenityEmbed::builder()
116            .description(message.content)
117            .author_name(message.author.name.clone())
118            .author_icon_url(message.author.avatar_url().unwrap_or_default())
119            .footer_text(channel.name)
120            .timestamp(message.timestamp)
121            .color(0x7A4AFFu32)
122            .image_url(
123                message
124                    .attachments
125                    .first()
126                    .map(|a| a.url.clone())
127                    .unwrap_or_default(),
128            )
129            .build();
130
131        Ok(ExpandedContent::Embed(Box::new(embed)))
132    }
133}
134
135impl Preview {
136    /// Retrieves a preview for the given message link.
137    ///
138    /// Fetches the message and channel information, validating that
139    /// the channel is not NSFW and is publicly accessible.
140    async fn get(args: &MessageLinkIDs, ctx: &Context) -> Result<Preview, PreviewError> {
141        let caches = CacheArgs {
142            guild_id: args.guild_id,
143            channel_id: args.channel_id,
144        };
145
146        let channel = caches.get(ctx).await.map_err(|_| PreviewError::Cache)?;
147
148        if channel.nsfw {
149            return Err(PreviewError::Nsfw);
150        }
151
152        if matches!(
153            channel.kind,
154            ChannelType::Private | ChannelType::PrivateThread
155        ) {
156            return Err(PreviewError::Permission);
157        }
158
159        let everyone_role_id = RoleId::new(args.guild_id.get());
160        if channel
161            .permission_overwrites
162            .iter()
163            .any(|overwrite| {
164                matches!(overwrite.kind, PermissionOverwriteType::Role(role_id) if role_id == everyone_role_id)
165                    && overwrite.deny.contains(Permissions::VIEW_CHANNEL)
166            })
167        {
168            return Err(PreviewError::Permission);
169        }
170
171        let message = channel.message(&ctx.http, args.message_id).await?;
172        Ok(Preview { message, channel })
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `parse_all` and flattens each result into raw
    /// `(guild, channel, message)` numbers for compact assertions.
    fn ids(text: &str) -> Vec<(u64, u64, u64)> {
        MessageLinkIDs::parse_all(text)
            .iter()
            .map(|link| {
                (
                    link.guild_id.get(),
                    link.channel_id.get(),
                    link.message_id.get(),
                )
            })
            .collect()
    }

    #[test]
    fn parse_standard_link() {
        let found = ids("https://discord.com/channels/123456789/987654321/111111111");
        assert_eq!(found, [(123456789, 987654321, 111111111)]);
    }

    #[test]
    fn parse_ptb_link() {
        let found = ids("https://ptb.discord.com/channels/123/456/789");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].0, 123);
    }

    #[test]
    fn parse_canary_link() {
        let found = ids("https://canary.discord.com/channels/123/456/789");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].0, 123);
    }

    #[test]
    fn parse_multiple_links() {
        let found =
            ids("https://discord.com/channels/1/2/3 and https://discord.com/channels/4/5/6");
        assert_eq!(found.len(), 2);
        assert_eq!(found[0].0, 1);
        assert_eq!(found[1].0, 4);
    }

    #[test]
    fn parse_deduplicates() {
        let found = ids("https://discord.com/channels/1/2/3 https://discord.com/channels/1/2/3");
        assert_eq!(found.len(), 1);
    }

    #[test]
    fn parse_limits_to_three() {
        // Four distinct links separated by single spaces; only three may come back.
        let text = [
            "https://discord.com/channels/1/2/3",
            "https://discord.com/channels/4/5/6",
            "https://discord.com/channels/7/8/9",
            "https://discord.com/channels/10/11/12",
        ]
        .join(" ");
        assert_eq!(ids(&text).len(), 3);
    }

    #[test]
    fn parse_no_match() {
        assert!(ids("Just some regular text").is_empty());
    }

    #[test]
    fn parse_ignores_invalid_url() {
        // The pattern requires the host to follow `https://` directly, so a
        // host that merely ends in `discord.com` must not match.
        assert!(ids("https://notdiscord.com/channels/1/2/3").is_empty());
    }

    #[test]
    fn parse_ignores_angle_bracket_link() {
        assert!(ids("<https://discord.com/channels/123/456/789>").is_empty());
    }

    #[test]
    fn parse_mixed_with_text() {
        let found =
            ids("Hey check this out https://discord.com/channels/1/2/3 pretty cool right?");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].2, 3);
    }
}