diff --git a/src/Discord.Net.Core/DiscordConfig.cs b/src/Discord.Net.Core/DiscordConfig.cs
index 3aacc30b6a..396c3069b8 100644
--- a/src/Discord.Net.Core/DiscordConfig.cs
+++ b/src/Discord.Net.Core/DiscordConfig.cs
@@ -232,5 +232,10 @@ public class DiscordConfig
/// Returns the max length of an application description.
///
public const int MaxApplicationDescriptionLength = 400;
+
+ ///
+ /// Returns the max number of user IDs that can be requested in a Request Guild Members chunk.
+ ///
+ public const int MaxRequestedUserIdsPerRequestGuildMembersChunk = 100;
}
}
diff --git a/src/Discord.Net.Core/Entities/Guilds/IGuild.cs b/src/Discord.Net.Core/Entities/Guilds/IGuild.cs
index aec5bff1e6..cdcd196c0c 100644
--- a/src/Discord.Net.Core/Entities/Guilds/IGuild.cs
+++ b/src/Discord.Net.Core/Entities/Guilds/IGuild.cs
@@ -3,6 +3,7 @@
using System.Collections.Generic;
using System.Globalization;
using System.IO;
+using System.Threading;
using System.Threading.Tasks;
namespace Discord
@@ -958,6 +959,33 @@ public interface IGuild : IDeletable, ISnowflakeEntity
/// A task that represents the asynchronous download operation.
///
Task DownloadUsersAsync();
+
+ ///
+ /// Downloads specific users for this guild with the default request timeout.
+ ///
+ ///
+ /// This method downloads all users specified in through the Gateway and caches them.
+ /// Consider using when downloading a large number of users.
+ ///
+ /// The list of Discord user IDs to download.
+ ///
+ /// A task that represents the asynchronous download operation.
+ ///
+ /// The timeout has elapsed.
+ Task DownloadUsersAsync(IEnumerable userIds);
+
+ ///
+ /// Downloads specific users for this guild.
+ ///
+ ///
+ /// This method downloads all users specified in through the Gateway and caches them.
+ ///
+ /// The list of Discord user IDs to download.
+ /// The cancellation token used to cancel the task.
+ ///
+ /// A task that represents the asynchronous download operation.
+ ///
+ Task DownloadUsersAsync(IEnumerable userIds, CancellationToken cancelToken);
///
/// Prunes inactive users.
///
diff --git a/src/Discord.Net.Core/Extensions/EnumerableExtensions.cs b/src/Discord.Net.Core/Extensions/EnumerableExtensions.cs
new file mode 100644
index 0000000000..0ce262c90f
--- /dev/null
+++ b/src/Discord.Net.Core/Extensions/EnumerableExtensions.cs
@@ -0,0 +1,103 @@
+// Based on https://github.com/dotnet/runtime/blob/main/src/libraries/System.Linq/src/System/Linq/Chunk.cs (only available on .NET 6+)
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+
+namespace Discord
+{
+ internal static class EnumerableExtensions
+ {
+ ///
+ /// Split the elements of a sequence into chunks of size at most .
+ ///
+ ///
+ /// Every chunk except the last will be of size .
+ /// The last chunk will contain the remaining elements and may be of a smaller size.
+ ///
+ ///
+ /// An whose elements to chunk.
+ ///
+ ///
+ /// Maximum size of each chunk.
+ ///
+ ///
+ /// The type of the elements of source.
+ ///
+ ///
+ /// An that contains the elements the input sequence split into chunks of size .
+ ///
+ ///
+ /// is null.
+ ///
+ ///
+ /// is below 1.
+ ///
+ public static IEnumerable Chunk(this IEnumerable source, int size)
+ {
+ Preconditions.NotNull(source, nameof(source));
+ Preconditions.GreaterThan(size, 0, nameof(size));
+
+ return ChunkIterator(source, size);
+ }
+
+ private static IEnumerable ChunkIterator(IEnumerable source, int size)
+ {
+ using IEnumerator e = source.GetEnumerator();
+
+ // Before allocating anything, make sure there's at least one element.
+ if (e.MoveNext())
+ {
+ // Now that we know we have at least one item, allocate an initial storage array. This is not
+ // the array we'll yield. It starts out small in order to avoid significantly overallocating
+ // when the source has many fewer elements than the chunk size.
+ int arraySize = Math.Min(size, 4);
+ int i;
+ do
+ {
+ var array = new TSource[arraySize];
+
+ // Store the first item.
+ array[0] = e.Current;
+ i = 1;
+
+ if (size != array.Length)
+ {
+ // This is the first chunk. As we fill the array, grow it as needed.
+ for (; i < size && e.MoveNext(); i++)
+ {
+ if (i >= array.Length)
+ {
+ arraySize = (int)Math.Min((uint)size, 2 * (uint)array.Length);
+ Array.Resize(ref array, arraySize);
+ }
+
+ array[i] = e.Current;
+ }
+ }
+ else
+ {
+ // For all but the first chunk, the array will already be correctly sized.
+ // We can just store into it until either it's full or MoveNext returns false.
+ TSource[] local = array; // avoid bounds checks by using cached local (`array` is lifted to iterator object as a field)
+ Debug.Assert(local.Length == size);
+ for (; (uint)i < (uint)local.Length && e.MoveNext(); i++)
+ {
+ local[i] = e.Current;
+ }
+ }
+
+ if (i != array.Length)
+ {
+ Array.Resize(ref array, i);
+ }
+
+ yield return array;
+ }
+ while (i >= size && e.MoveNext());
+ }
+ }
+ }
+}
diff --git a/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs b/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs
index a6c2d2d998..fdf35b09ca 100644
--- a/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs
+++ b/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs
@@ -6,6 +6,7 @@
using System.Globalization;
using System.IO;
using System.Linq;
+using System.Threading;
using System.Threading.Tasks;
using Model = Discord.API.Guild;
using WidgetModel = Discord.API.GuildWidget;
@@ -1512,6 +1513,14 @@ async Task> IGuild.GetUsersAsync(CacheMode mode,
Task IGuild.DownloadUsersAsync() =>
throw new NotSupportedException();
///
+ /// Downloading users is not supported for a REST-based guild.
+ Task IGuild.DownloadUsersAsync(IEnumerable userIds) =>
+ throw new NotSupportedException();
+ ///
+ /// Downloading users is not supported for a REST-based guild.
+ Task IGuild.DownloadUsersAsync(IEnumerable userIds, CancellationToken cancelToken) =>
+ throw new NotSupportedException();
+ ///
async Task> IGuild.SearchUsersAsync(string query, int limit, CacheMode mode, RequestOptions options)
{
if (mode == CacheMode.AllowDownload)
@@ -1604,7 +1613,7 @@ async Task IGuild.GetAutoModRulesAsync(RequestOptions options)
///
async Task IGuild.CreateAutoModRuleAsync(Action props, RequestOptions options)
=> await CreateAutoModRuleAsync(props, options).ConfigureAwait(false);
-
+
///
async Task IGuild.GetOnboardingAsync(RequestOptions options)
=> await GetOnboardingAsync(options);
diff --git a/src/Discord.Net.WebSocket/API/Gateway/GuildMembersChunkEvent.cs b/src/Discord.Net.WebSocket/API/Gateway/GuildMembersChunkEvent.cs
index 26114bf541..e62dae08bf 100644
--- a/src/Discord.Net.WebSocket/API/Gateway/GuildMembersChunkEvent.cs
+++ b/src/Discord.Net.WebSocket/API/Gateway/GuildMembersChunkEvent.cs
@@ -1,4 +1,5 @@
using Newtonsoft.Json;
+using System.Collections.Generic;
namespace Discord.API.Gateway
{
@@ -8,5 +9,15 @@ internal class GuildMembersChunkEvent
public ulong GuildId { get; set; }
[JsonProperty("members")]
public GuildMember[] Members { get; set; }
+ [JsonProperty("chunk_index")]
+ public int ChunkIndex { get; set; }
+ [JsonProperty("chunk_count")]
+ public int ChunkCount { get; set; }
+ [JsonProperty("not_found")]
+ public Optional> NotFound { get; set; }
+ [JsonProperty("presences")]
+ public Optional> Presences { get; set; }
+ [JsonProperty("nonce")]
+ public Optional Nonce { get; set; }
}
}
diff --git a/src/Discord.Net.WebSocket/API/Gateway/RequestMembersParams.cs b/src/Discord.Net.WebSocket/API/Gateway/RequestMembersParams.cs
index f7a63e330c..d88fce51c5 100644
--- a/src/Discord.Net.WebSocket/API/Gateway/RequestMembersParams.cs
+++ b/src/Discord.Net.WebSocket/API/Gateway/RequestMembersParams.cs
@@ -6,12 +6,17 @@ namespace Discord.API.Gateway
[JsonObject(MemberSerialization = MemberSerialization.OptIn)]
internal class RequestMembersParams
{
+ [JsonProperty("guild_id")]
+ public ulong GuildId { get; set; }
[JsonProperty("query")]
- public string Query { get; set; }
+ public Optional Query { get; set; }
[JsonProperty("limit")]
public int Limit { get; set; }
-
- [JsonProperty("guild_id")]
- public IEnumerable GuildIds { get; set; }
+ [JsonProperty("presences")]
+ public Optional Presences { get; set; }
+ [JsonProperty("user_ids")]
+ public Optional> UserIds { get; set; }
+ [JsonProperty("nonce")]
+ public Optional Nonce { get; set; }
}
}
diff --git a/src/Discord.Net.WebSocket/BaseSocketClient.cs b/src/Discord.Net.WebSocket/BaseSocketClient.cs
index 482a08a0f9..897aa67460 100644
--- a/src/Discord.Net.WebSocket/BaseSocketClient.cs
+++ b/src/Discord.Net.WebSocket/BaseSocketClient.cs
@@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.IO;
+using System.Threading;
using System.Threading.Tasks;
namespace Discord.WebSocket
@@ -236,6 +237,17 @@ private static DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config
///
public abstract Task DownloadUsersAsync(IEnumerable guilds);
+ ///
+ /// Attempts to download specific users into the user cache for the selected guild.
+ ///
+ /// The guild to download the members from.
+ /// The list of Discord user IDs to download.
+ /// The cancellation token used to cancel the task.
+ ///
+ /// A task that represents the asynchronous download operation.
+ ///
+ public abstract Task DownloadUsersAsync(IGuild guild, IEnumerable userIds, CancellationToken cancelToken = default);
+
///
/// Creates a guild for the logged-in user who is in less than 10 active guilds.
///
diff --git a/src/Discord.Net.WebSocket/DiscordShardedClient.cs b/src/Discord.Net.WebSocket/DiscordShardedClient.cs
index c3809ba672..b4c59c4236 100644
--- a/src/Discord.Net.WebSocket/DiscordShardedClient.cs
+++ b/src/Discord.Net.WebSocket/DiscordShardedClient.cs
@@ -381,6 +381,23 @@ public override async Task DownloadUsersAsync(IEnumerable guilds)
}
}
+ ///
+ /// is
+ public override async Task DownloadUsersAsync(IGuild guild, IEnumerable userIds, CancellationToken cancelToken = default)
+ {
+ Preconditions.NotNull(guild, nameof(guild));
+
+ for (int i = 0; i < _shards.Length; i++)
+ {
+ int id = _shardIds[i];
+ if (GetShardIdFor(guild) == id)
+ {
+ await _shards[i].DownloadUsersAsync(guild, userIds, cancelToken).ConfigureAwait(false);
+ break;
+ }
+ }
+ }
+
private int GetLatency()
{
int total = 0;
diff --git a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs
index 75960b173e..64a0496502 100644
--- a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs
+++ b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs
@@ -379,11 +379,24 @@ public async Task SendPresenceUpdateAsync(UserStatus status, bool isAFK, long? s
options.BucketId = GatewayBucket.Get(GatewayBucketType.PresenceUpdate).Id;
await SendGatewayAsync(GatewayOpCode.PresenceUpdate, args, options: options).ConfigureAwait(false);
}
- public async Task SendRequestMembersAsync(IEnumerable guildIds, RequestOptions options = null)
+ public async Task SendRequestMembersAsync(ulong guildId, RequestOptions options = null)
{
options = RequestOptions.CreateOrClone(options);
- await SendGatewayAsync(GatewayOpCode.RequestGuildMembers, new RequestMembersParams { GuildIds = guildIds, Query = "", Limit = 0 }, options: options).ConfigureAwait(false);
+ await SendGatewayAsync(GatewayOpCode.RequestGuildMembers, new RequestMembersParams { GuildId = guildId, Query = "", Limit = 0 }, options: options).ConfigureAwait(false);
}
+ public async Task SendRequestMembersAsync(ulong guildId, IEnumerable userIds, string nonce, RequestOptions options = null)
+ {
+ var payload = new RequestMembersParams
+ {
+ GuildId = guildId,
+ Limit = 0,
+ UserIds = new Optional>(userIds),
+ Nonce = nonce
+ };
+ options = RequestOptions.CreateOrClone(options);
+ await SendGatewayAsync(GatewayOpCode.RequestGuildMembers, payload, options: options).ConfigureAwait(false);
+ }
+
public async Task SendVoiceStateUpdateAsync(ulong guildId, ulong? channelId, bool selfDeaf, bool selfMute, RequestOptions options = null)
{
var payload = new VoiceStateUpdateParams
diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs
index 924f5f645c..b7ddc73139 100644
--- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs
+++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs
@@ -37,6 +37,7 @@ public partial class DiscordSocketClient : BaseSocketClient, IDiscordClient
private readonly ConnectionManager _connection;
private readonly Logger _gatewayLogger;
private readonly SemaphoreSlim _stateLock;
+ private readonly ConcurrentDictionary> _guildMembersRequestTasks;
private string _sessionId;
private int _lastSeq;
@@ -51,6 +52,7 @@ public partial class DiscordSocketClient : BaseSocketClient, IDiscordClient
private GatewayIntents _gatewayIntents;
private ImmutableArray> _defaultStickers;
private SocketSelfUser _previousSessionUser;
+ private long _guildMembersRequestCounter;
///
/// Provides access to a REST-only client with a shared state from this client.
@@ -183,6 +185,8 @@ private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClie
e.ErrorContext.Handled = true;
};
+ _guildMembersRequestTasks = new ConcurrentDictionary>();
+
ApiClient.SentGatewayMessage += async opCode => await _gatewayLogger.DebugAsync($"Sent {opCode}").ConfigureAwait(false);
ApiClient.ReceivedGatewayEvent += ProcessMessageAsync;
@@ -627,29 +631,51 @@ private async Task ProcessUserDownloadsAsync(IEnumerable guilds)
{
var cachedGuilds = guilds.ToImmutableArray();
- const short batchSize = 1;
- ulong[] batchIds = new ulong[Math.Min(batchSize, cachedGuilds.Length)];
- Task[] batchTasks = new Task[batchIds.Length];
- int batchCount = (cachedGuilds.Length + (batchSize - 1)) / batchSize;
+ foreach (var guild in cachedGuilds)
+ {
+ await ApiClient.SendRequestMembersAsync(guild.Id).ConfigureAwait(false);
+ await guild.DownloaderPromise.ConfigureAwait(false);
+ }
+ }
- for (int i = 0, k = 0; i < batchCount; i++)
+ ///
+ public override async Task DownloadUsersAsync(IGuild guild, IEnumerable userIds, CancellationToken cancelToken = default)
+ {
+ if (ConnectionState == ConnectionState.Connected)
{
- bool isLast = i == batchCount - 1;
- int count = isLast ? (cachedGuilds.Length - (batchCount - 1) * batchSize) : batchSize;
+ EnsureGatewayIntent(GatewayIntents.GuildMembers);
- for (int j = 0; j < count; j++, k++)
+ var socketGuild = GetGuild(guild.Id);
+ if (socketGuild != null)
{
- var guild = cachedGuilds[k];
- batchIds[j] = guild.Id;
- batchTasks[j] = guild.DownloaderPromise;
+ foreach (var chunk in userIds.Chunk(DiscordConfig.MaxRequestedUserIdsPerRequestGuildMembersChunk))
+ {
+ await ProcessUserDownloadsAsync(socketGuild, chunk, cancelToken).ConfigureAwait(false);
+ }
}
+ }
+ else
+ {
+ throw new InvalidOperationException("Client not connected");
+ }
+ }
- await ApiClient.SendRequestMembersAsync(batchIds).ConfigureAwait(false);
- if (isLast && batchCount > 1)
- await Task.WhenAll(batchTasks.Take(count)).ConfigureAwait(false);
- else
- await Task.WhenAll(batchTasks).ConfigureAwait(false);
+ private async Task ProcessUserDownloadsAsync(SocketGuild guild, IEnumerable userIds, CancellationToken cancelToken = default)
+ {
+ var nonce = Interlocked.Increment(ref _guildMembersRequestCounter).ToString();
+ var tcs = new TaskCompletionSource();
+ using var registration = cancelToken.Register(() => tcs.TrySetCanceled());
+ _guildMembersRequestTasks.TryAdd(nonce, tcs);
+ try
+ {
+ await ApiClient.SendRequestMembersAsync(guild.Id, userIds, nonce).ConfigureAwait(false);
+ await tcs.Task.ConfigureAwait(false);
+ cancelToken.ThrowIfCancellationRequested();
+ }
+ finally
+ {
+ _guildMembersRequestTasks.TryRemove(nonce, out _);
}
}
@@ -1410,6 +1436,13 @@ private async Task ProcessMessageAsync(GatewayOpCode opCode, int? seq, string ty
guild.CompleteDownloadUsers();
await TimedInvokeAsync(_guildMembersDownloadedEvent, nameof(GuildMembersDownloaded), guild).ConfigureAwait(false);
}
+
+ if (data.Nonce.IsSpecified
+ && data.ChunkIndex + 1 >= data.ChunkCount
+ && _guildMembersRequestTasks.TryRemove(data.Nonce.Value, out var tcs))
+ {
+ tcs.TrySetResult(true);
+ }
}
else
{
@@ -2904,7 +2937,7 @@ private async Task ProcessMessageAsync(GatewayOpCode opCode, int? seq, string ty
}
break;
#endregion
-
+
#region Auto Moderation
case "AUTO_MODERATION_RULE_CREATE":
diff --git a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs
index 9180ad92f3..d8f5ea7d25 100644
--- a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs
+++ b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs
@@ -383,7 +383,7 @@ public IReadOnlyCollection Stickers
///
///
/// Otherwise, you may need to enable to fetch
- /// the full user list upon startup, or use to manually download
+ /// the full user list upon startup, or use to manually download
/// the users.
///
///
@@ -1260,6 +1260,20 @@ public async Task DownloadUsersAsync()
{
await Discord.DownloadUsersAsync(new[] { this }).ConfigureAwait(false);
}
+
+ ///
+ public async Task DownloadUsersAsync(IEnumerable userIds)
+ {
+ using var cts = new CancellationTokenSource(DiscordConfig.DefaultRequestTimeout);
+ await DownloadUsersAsync(userIds, cts.Token).ConfigureAwait(false);
+ }
+
+ ///
+ public async Task DownloadUsersAsync(IEnumerable userIds, CancellationToken cancelToken)
+ {
+ await Discord.DownloadUsersAsync(this, userIds, cancelToken).ConfigureAwait(false);
+ }
+
internal void CompleteDownloadUsers()
{
_downloaderPromise.TrySetResultAsync(true);
@@ -1406,7 +1420,7 @@ public Task CreateEventAsync(
///
/// A task that represents the asynchronous get operation. The task result contains a read-only collection
/// of the requested audit log entries.
- ///
+ ///
public IAsyncEnumerable> GetAuditLogsAsync(int limit, RequestOptions options = null, ulong? beforeId = null, ulong? userId = null, ActionType? actionType = null, ulong? afterId = null)
=> GuildHelper.GetAuditLogsAsync(this, Discord, beforeId, limit, options, userId: userId, actionType: actionType, afterId: afterId);