babyrite/expand/
discord.rs1use regex::Regex;
9use serenity::all::{
10 ChannelId, ChannelType, Context, GuildChannel, GuildId, Message, MessageId,
11 PermissionOverwriteType, Permissions, RoleId,
12};
13use serenity_builder::model::embed::SerenityEmbed;
14use std::collections::HashSet;
15use std::sync::LazyLock;
16use url::Url;
17
18use super::{ExpandError, ExpandedContent};
19use crate::cache::CacheArgs;
20
21pub static MESSAGE_LINK_REGEX: LazyLock<Regex> = LazyLock::new(|| {
25 Regex::new(r"https://(?:ptb\.|canary\.)?discord\.com/channels/(\d+)/(\d+)/(\d+)").unwrap()
26});
27
28#[derive(Debug)]
30pub struct MessageLinkIDs {
31 pub guild_id: GuildId,
33 pub channel_id: ChannelId,
35 pub message_id: MessageId,
37}
38
39#[derive(serde::Deserialize, Debug)]
41pub struct Preview {
42 pub message: Message,
44 pub channel: GuildChannel,
46}
47
48#[derive(thiserror::Error, Debug)]
50pub enum PreviewError {
51 #[error("Failed to retrieve from cache.")]
53 Cache,
54 #[error("NSFW content previews are not permitted, but the channel is marked as NSFW.")]
56 Nsfw,
57 #[error("The channel is a private channel or private thread.")]
59 Permission,
60 #[allow(clippy::enum_variant_names)]
62 #[error(transparent)]
63 SerenityError(#[from] serenity::Error),
64}
65
66impl MessageLinkIDs {
67 pub fn parse_all(text: &str) -> Vec<MessageLinkIDs> {
73 let mut seen_urls = HashSet::new();
74 MESSAGE_LINK_REGEX
75 .captures_iter(text)
76 .filter_map(|captures| {
77 let m = captures.get(0)?;
78 let full_url = m.as_str();
79 if m.start() > 0 && text.as_bytes()[m.start() - 1] == b'<' {
81 return None;
82 }
83 if !seen_urls.insert(full_url.to_string()) {
84 return None;
85 }
86
87 let url = Url::parse(full_url).ok()?;
88
89 if !matches!(
90 url.domain(),
91 Some("discord.com") | Some("canary.discord.com") | Some("ptb.discord.com")
92 ) {
93 return None;
94 }
95
96 let guild_id = GuildId::new(captures.get(1)?.as_str().parse().ok()?);
97 let channel_id = ChannelId::new(captures.get(2)?.as_str().parse().ok()?);
98 let message_id = MessageId::new(captures.get(3)?.as_str().parse().ok()?);
99
100 Some(MessageLinkIDs {
101 guild_id,
102 channel_id,
103 message_id,
104 })
105 })
106 .take(3) .collect()
108 }
109
110 pub async fn fetch(&self, ctx: &Context) -> Result<ExpandedContent, ExpandError> {
112 let preview = Preview::get(self, ctx).await?;
113 let (message, channel) = (preview.message, preview.channel);
114
115 let embed = SerenityEmbed::builder()
116 .description(message.content)
117 .author_name(message.author.name.clone())
118 .author_icon_url(message.author.avatar_url().unwrap_or_default())
119 .footer_text(channel.name)
120 .timestamp(message.timestamp)
121 .color(0x7A4AFFu32)
122 .image_url(
123 message
124 .attachments
125 .first()
126 .map(|a| a.url.clone())
127 .unwrap_or_default(),
128 )
129 .build();
130
131 Ok(ExpandedContent::Embed(Box::new(embed)))
132 }
133}
134
135impl Preview {
136 async fn get(args: &MessageLinkIDs, ctx: &Context) -> Result<Preview, PreviewError> {
141 let caches = CacheArgs {
142 guild_id: args.guild_id,
143 channel_id: args.channel_id,
144 };
145
146 let channel = caches.get(ctx).await.map_err(|_| PreviewError::Cache)?;
147
148 if channel.nsfw {
149 return Err(PreviewError::Nsfw);
150 }
151
152 if matches!(
153 channel.kind,
154 ChannelType::Private | ChannelType::PrivateThread
155 ) {
156 return Err(PreviewError::Permission);
157 }
158
159 let everyone_role_id = RoleId::new(args.guild_id.get());
160 if channel
161 .permission_overwrites
162 .iter()
163 .any(|overwrite| {
164 matches!(overwrite.kind, PermissionOverwriteType::Role(role_id) if role_id == everyone_role_id)
165 && overwrite.deny.contains(Permissions::VIEW_CHANNEL)
166 })
167 {
168 return Err(PreviewError::Permission);
169 }
170
171 let message = channel.message(&ctx.http, args.message_id).await?;
172 Ok(Preview { message, channel })
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `parse_all` on `text`, asserts the number of links found, and
    /// hands the links back for further checks.
    fn parse_expecting(text: &str, expected_len: usize) -> Vec<MessageLinkIDs> {
        let links = MessageLinkIDs::parse_all(text);
        assert_eq!(links.len(), expected_len);
        links
    }

    #[test]
    fn parse_standard_link() {
        let links = parse_expecting(
            "https://discord.com/channels/123456789/987654321/111111111",
            1,
        );
        let link = &links[0];
        assert_eq!(link.guild_id, GuildId::new(123456789));
        assert_eq!(link.channel_id, ChannelId::new(987654321));
        assert_eq!(link.message_id, MessageId::new(111111111));
    }

    #[test]
    fn parse_ptb_link() {
        let links = parse_expecting("https://ptb.discord.com/channels/123/456/789", 1);
        assert_eq!(links[0].guild_id, GuildId::new(123));
    }

    #[test]
    fn parse_canary_link() {
        let links = parse_expecting("https://canary.discord.com/channels/123/456/789", 1);
        assert_eq!(links[0].guild_id, GuildId::new(123));
    }

    #[test]
    fn parse_multiple_links() {
        let links = parse_expecting(
            "https://discord.com/channels/1/2/3 and https://discord.com/channels/4/5/6",
            2,
        );
        assert_eq!(links[0].guild_id, GuildId::new(1));
        assert_eq!(links[1].guild_id, GuildId::new(4));
    }

    #[test]
    fn parse_deduplicates() {
        parse_expecting(
            "https://discord.com/channels/1/2/3 https://discord.com/channels/1/2/3",
            1,
        );
    }

    #[test]
    fn parse_limits_to_three() {
        // Four distinct links in the input; only the first three are kept.
        let text = "\
            https://discord.com/channels/1/2/3 \
            https://discord.com/channels/4/5/6 \
            https://discord.com/channels/7/8/9 \
            https://discord.com/channels/10/11/12";
        parse_expecting(text, 3);
    }

    #[test]
    fn parse_no_match() {
        parse_expecting("Just some regular text", 0);
    }

    #[test]
    fn parse_ignores_invalid_url() {
        parse_expecting("https://notdiscord.com/channels/1/2/3", 0);
    }

    #[test]
    fn parse_ignores_angle_bracket_link() {
        parse_expecting("<https://discord.com/channels/123/456/789>", 0);
    }

    #[test]
    fn parse_mixed_with_text() {
        let links = parse_expecting(
            "Hey check this out https://discord.com/channels/1/2/3 pretty cool right?",
            1,
        );
        assert_eq!(links[0].message_id, MessageId::new(3));
    }
}