diff --git a/src/Discord.Net.Core/DiscordConfig.cs b/src/Discord.Net.Core/DiscordConfig.cs index 3aacc30b6a..396c3069b8 100644 --- a/src/Discord.Net.Core/DiscordConfig.cs +++ b/src/Discord.Net.Core/DiscordConfig.cs @@ -232,5 +232,10 @@ public class DiscordConfig /// Returns the max length of an application description. /// public const int MaxApplicationDescriptionLength = 400; + + /// + /// Returns the max number of user IDs that can be requested in a Request Guild Members chunk. + /// + public const int MaxRequestedUserIdsPerRequestGuildMembersChunk = 100; } } diff --git a/src/Discord.Net.Core/Entities/Guilds/IGuild.cs b/src/Discord.Net.Core/Entities/Guilds/IGuild.cs index aec5bff1e6..cdcd196c0c 100644 --- a/src/Discord.Net.Core/Entities/Guilds/IGuild.cs +++ b/src/Discord.Net.Core/Entities/Guilds/IGuild.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.Globalization; using System.IO; +using System.Threading; using System.Threading.Tasks; namespace Discord @@ -958,6 +959,33 @@ public interface IGuild : IDeletable, ISnowflakeEntity /// A task that represents the asynchronous download operation. /// Task DownloadUsersAsync(); + + /// + /// Downloads specific users for this guild with the default request timeout. + /// + /// + /// This method downloads all users specified in through the Gateway and caches them. + /// Consider using when downloading a large number of users. + /// + /// The list of Discord user IDs to download. + /// + /// A task that represents the asynchronous download operation. + /// + /// The timeout has elapsed. + Task DownloadUsersAsync(IEnumerable userIds); + + /// + /// Downloads specific users for this guild. + /// + /// + /// This method downloads all users specified in through the Gateway and caches them. + /// + /// The list of Discord user IDs to download. + /// The cancellation token used to cancel the task. + /// + /// A task that represents the asynchronous download operation. + /// + Task DownloadUsersAsync(IEnumerable userIds, CancellationToken cancelToken); /// /// Prunes inactive users. /// diff --git a/src/Discord.Net.Core/Extensions/EnumerableExtensions.cs b/src/Discord.Net.Core/Extensions/EnumerableExtensions.cs new file mode 100644 index 0000000000..0ce262c90f --- /dev/null +++ b/src/Discord.Net.Core/Extensions/EnumerableExtensions.cs @@ -0,0 +1,103 @@ +// Based on https://github.com/dotnet/runtime/blob/main/src/libraries/System.Linq/src/System/Linq/Chunk.cs (only available on .NET 6+) +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; + +namespace Discord +{ + internal static class EnumerableExtensions + { + /// + /// Split the elements of a sequence into chunks of size at most . + /// + /// + /// Every chunk except the last will be of size . + /// The last chunk will contain the remaining elements and may be of a smaller size. + /// + /// + /// An whose elements to chunk. + /// + /// + /// Maximum size of each chunk. + /// + /// + /// The type of the elements of source. + /// + /// + /// An that contains the elements the input sequence split into chunks of size . + /// + /// + /// is null. + /// + /// + /// is below 1. + /// + public static IEnumerable Chunk(this IEnumerable source, int size) + { + Preconditions.NotNull(source, nameof(source)); + Preconditions.GreaterThan(size, 0, nameof(size)); + + return ChunkIterator(source, size); + } + + private static IEnumerable ChunkIterator(IEnumerable source, int size) + { + using IEnumerator e = source.GetEnumerator(); + + // Before allocating anything, make sure there's at least one element. + if (e.MoveNext()) + { + // Now that we know we have at least one item, allocate an initial storage array. This is not + // the array we'll yield. It starts out small in order to avoid significantly overallocating + // when the source has many fewer elements than the chunk size. + int arraySize = Math.Min(size, 4); + int i; + do + { + var array = new TSource[arraySize]; + + // Store the first item. + array[0] = e.Current; + i = 1; + + if (size != array.Length) + { + // This is the first chunk. As we fill the array, grow it as needed. + for (; i < size && e.MoveNext(); i++) + { + if (i >= array.Length) + { + arraySize = (int)Math.Min((uint)size, 2 * (uint)array.Length); + Array.Resize(ref array, arraySize); + } + + array[i] = e.Current; + } + } + else + { + // For all but the first chunk, the array will already be correctly sized. + // We can just store into it until either it's full or MoveNext returns false. + TSource[] local = array; // avoid bounds checks by using cached local (`array` is lifted to iterator object as a field) + Debug.Assert(local.Length == size); + for (; (uint)i < (uint)local.Length && e.MoveNext(); i++) + { + local[i] = e.Current; + } + } + + if (i != array.Length) + { + Array.Resize(ref array, i); + } + + yield return array; + } + while (i >= size && e.MoveNext()); + } + } + } +} diff --git a/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs b/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs index a6c2d2d998..fdf35b09ca 100644 --- a/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs +++ b/src/Discord.Net.Rest/Entities/Guilds/RestGuild.cs @@ -6,6 +6,7 @@ using System.Globalization; using System.IO; using System.Linq; +using System.Threading; using System.Threading.Tasks; using Model = Discord.API.Guild; using WidgetModel = Discord.API.GuildWidget; @@ -1512,6 +1513,14 @@ async Task> IGuild.GetUsersAsync(CacheMode mode, Task IGuild.DownloadUsersAsync() => throw new NotSupportedException(); /// + /// Downloading users is not supported for a REST-based guild. + Task IGuild.DownloadUsersAsync(IEnumerable userIds) => + throw new NotSupportedException(); + /// + /// Downloading users is not supported for a REST-based guild. + Task IGuild.DownloadUsersAsync(IEnumerable userIds, CancellationToken cancelToken) => + throw new NotSupportedException(); + /// async Task> IGuild.SearchUsersAsync(string query, int limit, CacheMode mode, RequestOptions options) { if (mode == CacheMode.AllowDownload) @@ -1604,7 +1613,7 @@ async Task IGuild.GetAutoModRulesAsync(RequestOptions options) /// async Task IGuild.CreateAutoModRuleAsync(Action props, RequestOptions options) => await CreateAutoModRuleAsync(props, options).ConfigureAwait(false); - + /// async Task IGuild.GetOnboardingAsync(RequestOptions options) => await GetOnboardingAsync(options); diff --git a/src/Discord.Net.WebSocket/API/Gateway/GuildMembersChunkEvent.cs b/src/Discord.Net.WebSocket/API/Gateway/GuildMembersChunkEvent.cs index 26114bf541..e62dae08bf 100644 --- a/src/Discord.Net.WebSocket/API/Gateway/GuildMembersChunkEvent.cs +++ b/src/Discord.Net.WebSocket/API/Gateway/GuildMembersChunkEvent.cs @@ -1,4 +1,5 @@ using Newtonsoft.Json; +using System.Collections.Generic; namespace Discord.API.Gateway { @@ -8,5 +9,15 @@ internal class GuildMembersChunkEvent public ulong GuildId { get; set; } [JsonProperty("members")] public GuildMember[] Members { get; set; } + [JsonProperty("chunk_index")] + public int ChunkIndex { get; set; } + [JsonProperty("chunk_count")] + public int ChunkCount { get; set; } + [JsonProperty("not_found")] + public Optional> NotFound { get; set; } + [JsonProperty("presences")] + public Optional> Presences { get; set; } + [JsonProperty("nonce")] + public Optional Nonce { get; set; } } } diff --git a/src/Discord.Net.WebSocket/API/Gateway/RequestMembersParams.cs b/src/Discord.Net.WebSocket/API/Gateway/RequestMembersParams.cs index f7a63e330c..d88fce51c5 100644 --- a/src/Discord.Net.WebSocket/API/Gateway/RequestMembersParams.cs +++ b/src/Discord.Net.WebSocket/API/Gateway/RequestMembersParams.cs @@ -6,12 +6,17 @@ namespace Discord.API.Gateway [JsonObject(MemberSerialization = MemberSerialization.OptIn)] internal class RequestMembersParams { + [JsonProperty("guild_id")] + public ulong GuildId { get; set; } [JsonProperty("query")] - public string Query { get; set; } + public Optional Query { get; set; } [JsonProperty("limit")] public int Limit { get; set; } - - [JsonProperty("guild_id")] - public IEnumerable GuildIds { get; set; } + [JsonProperty("presences")] + public Optional Presences { get; set; } + [JsonProperty("user_ids")] + public Optional> UserIds { get; set; } + [JsonProperty("nonce")] + public Optional Nonce { get; set; } } } diff --git a/src/Discord.Net.WebSocket/BaseSocketClient.cs b/src/Discord.Net.WebSocket/BaseSocketClient.cs index 482a08a0f9..897aa67460 100644 --- a/src/Discord.Net.WebSocket/BaseSocketClient.cs +++ b/src/Discord.Net.WebSocket/BaseSocketClient.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.Threading; using System.Threading.Tasks; namespace Discord.WebSocket @@ -236,6 +237,17 @@ private static DiscordSocketApiClient CreateApiClient(DiscordSocketConfig config /// public abstract Task DownloadUsersAsync(IEnumerable guilds); + /// + /// Attempts to download specific users into the user cache for the selected guild. + /// + /// The guild to download the members from. + /// The list of Discord user IDs to download. + /// The cancellation token used to cancel the task. + /// + /// A task that represents the asynchronous download operation. + /// + public abstract Task DownloadUsersAsync(IGuild guild, IEnumerable userIds, CancellationToken cancelToken = default); + /// /// Creates a guild for the logged-in user who is in less than 10 active guilds. /// diff --git a/src/Discord.Net.WebSocket/DiscordShardedClient.cs b/src/Discord.Net.WebSocket/DiscordShardedClient.cs index c3809ba672..b4c59c4236 100644 --- a/src/Discord.Net.WebSocket/DiscordShardedClient.cs +++ b/src/Discord.Net.WebSocket/DiscordShardedClient.cs @@ -381,6 +381,23 @@ public override async Task DownloadUsersAsync(IEnumerable guilds) } } + /// + /// is + public override async Task DownloadUsersAsync(IGuild guild, IEnumerable userIds, CancellationToken cancelToken = default) + { + Preconditions.NotNull(guild, nameof(guild)); + + for (int i = 0; i < _shards.Length; i++) + { + int id = _shardIds[i]; + if (GetShardIdFor(guild) == id) + { + await _shards[i].DownloadUsersAsync(guild, userIds, cancelToken).ConfigureAwait(false); + break; + } + } + } + private int GetLatency() { int total = 0; diff --git a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs index 75960b173e..64a0496502 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketApiClient.cs @@ -379,11 +379,24 @@ public async Task SendPresenceUpdateAsync(UserStatus status, bool isAFK, long? s options.BucketId = GatewayBucket.Get(GatewayBucketType.PresenceUpdate).Id; await SendGatewayAsync(GatewayOpCode.PresenceUpdate, args, options: options).ConfigureAwait(false); } - public async Task SendRequestMembersAsync(IEnumerable guildIds, RequestOptions options = null) + public async Task SendRequestMembersAsync(ulong guildId, RequestOptions options = null) { options = RequestOptions.CreateOrClone(options); - await SendGatewayAsync(GatewayOpCode.RequestGuildMembers, new RequestMembersParams { GuildIds = guildIds, Query = "", Limit = 0 }, options: options).ConfigureAwait(false); + await SendGatewayAsync(GatewayOpCode.RequestGuildMembers, new RequestMembersParams { GuildId = guildId, Query = "", Limit = 0 }, options: options).ConfigureAwait(false); } + public async Task SendRequestMembersAsync(ulong guildId, IEnumerable userIds, string nonce, RequestOptions options = null) + { + var payload = new RequestMembersParams + { + GuildId = guildId, + Limit = 0, + UserIds = new Optional>(userIds), + Nonce = nonce + }; + options = RequestOptions.CreateOrClone(options); + await SendGatewayAsync(GatewayOpCode.RequestGuildMembers, payload, options: options).ConfigureAwait(false); + } + public async Task SendVoiceStateUpdateAsync(ulong guildId, ulong? channelId, bool selfDeaf, bool selfMute, RequestOptions options = null) { var payload = new VoiceStateUpdateParams diff --git a/src/Discord.Net.WebSocket/DiscordSocketClient.cs b/src/Discord.Net.WebSocket/DiscordSocketClient.cs index 924f5f645c..b7ddc73139 100644 --- a/src/Discord.Net.WebSocket/DiscordSocketClient.cs +++ b/src/Discord.Net.WebSocket/DiscordSocketClient.cs @@ -37,6 +37,7 @@ public partial class DiscordSocketClient : BaseSocketClient, IDiscordClient private readonly ConnectionManager _connection; private readonly Logger _gatewayLogger; private readonly SemaphoreSlim _stateLock; + private readonly ConcurrentDictionary> _guildMembersRequestTasks; private string _sessionId; private int _lastSeq; @@ -51,6 +52,7 @@ public partial class DiscordSocketClient : BaseSocketClient, IDiscordClient private GatewayIntents _gatewayIntents; private ImmutableArray> _defaultStickers; private SocketSelfUser _previousSessionUser; + private long _guildMembersRequestCounter; /// /// Provides access to a REST-only client with a shared state from this client. @@ -183,6 +185,8 @@ private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClie e.ErrorContext.Handled = true; }; + _guildMembersRequestTasks = new ConcurrentDictionary>(); + ApiClient.SentGatewayMessage += async opCode => await _gatewayLogger.DebugAsync($"Sent {opCode}").ConfigureAwait(false); ApiClient.ReceivedGatewayEvent += ProcessMessageAsync; @@ -627,29 +631,51 @@ private async Task ProcessUserDownloadsAsync(IEnumerable guilds) { var cachedGuilds = guilds.ToImmutableArray(); - const short batchSize = 1; - ulong[] batchIds = new ulong[Math.Min(batchSize, cachedGuilds.Length)]; - Task[] batchTasks = new Task[batchIds.Length]; - int batchCount = (cachedGuilds.Length + (batchSize - 1)) / batchSize; + foreach (var guild in cachedGuilds) + { + await ApiClient.SendRequestMembersAsync(guild.Id).ConfigureAwait(false); + await guild.DownloaderPromise.ConfigureAwait(false); + } + } - for (int i = 0, k = 0; i < batchCount; i++) + /// + public override async Task DownloadUsersAsync(IGuild guild, IEnumerable userIds, CancellationToken cancelToken = default) + { + if (ConnectionState == ConnectionState.Connected) { - bool isLast = i == batchCount - 1; - int count = isLast ? (cachedGuilds.Length - (batchCount - 1) * batchSize) : batchSize; + EnsureGatewayIntent(GatewayIntents.GuildMembers); - for (int j = 0; j < count; j++, k++) + var socketGuild = GetGuild(guild.Id); + if (socketGuild != null) { - var guild = cachedGuilds[k]; - batchIds[j] = guild.Id; - batchTasks[j] = guild.DownloaderPromise; + foreach (var chunk in userIds.Chunk(DiscordConfig.MaxRequestedUserIdsPerRequestGuildMembersChunk)) + { + await ProcessUserDownloadsAsync(socketGuild, chunk, cancelToken).ConfigureAwait(false); + } } + } + else + { + throw new InvalidOperationException("Client not connected"); + } + } - await ApiClient.SendRequestMembersAsync(batchIds).ConfigureAwait(false); - if (isLast && batchCount > 1) - await Task.WhenAll(batchTasks.Take(count)).ConfigureAwait(false); - else - await Task.WhenAll(batchTasks).ConfigureAwait(false); + private async Task ProcessUserDownloadsAsync(SocketGuild guild, IEnumerable userIds, CancellationToken cancelToken = default) + { + var nonce = Interlocked.Increment(ref _guildMembersRequestCounter).ToString(); + var tcs = new TaskCompletionSource(); + using var registration = cancelToken.Register(() => tcs.TrySetCanceled()); + _guildMembersRequestTasks.TryAdd(nonce, tcs); + try + { + await ApiClient.SendRequestMembersAsync(guild.Id, userIds, nonce).ConfigureAwait(false); + await tcs.Task.ConfigureAwait(false); + cancelToken.ThrowIfCancellationRequested(); + } + finally + { + _guildMembersRequestTasks.TryRemove(nonce, out _); } } @@ -1410,6 +1436,13 @@ private async Task ProcessMessageAsync(GatewayOpCode opCode, int? seq, string ty guild.CompleteDownloadUsers(); await TimedInvokeAsync(_guildMembersDownloadedEvent, nameof(GuildMembersDownloaded), guild).ConfigureAwait(false); } + + if (data.Nonce.IsSpecified + && data.ChunkIndex + 1 >= data.ChunkCount + && _guildMembersRequestTasks.TryRemove(data.Nonce.Value, out var tcs)) + { + tcs.TrySetResult(true); + } } else { @@ -2904,7 +2937,7 @@ private async Task ProcessMessageAsync(GatewayOpCode opCode, int? seq, string ty } break; #endregion - + #region Auto Moderation case "AUTO_MODERATION_RULE_CREATE": diff --git a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs index 9180ad92f3..d8f5ea7d25 100644 --- a/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs +++ b/src/Discord.Net.WebSocket/Entities/Guilds/SocketGuild.cs @@ -383,7 +383,7 @@ public IReadOnlyCollection Stickers /// /// /// Otherwise, you may need to enable to fetch - /// the full user list upon startup, or use to manually download + /// the full user list upon startup, or use to manually download /// the users. /// /// @@ -1260,6 +1260,20 @@ public async Task DownloadUsersAsync() { await Discord.DownloadUsersAsync(new[] { this }).ConfigureAwait(false); } + + /// + public async Task DownloadUsersAsync(IEnumerable userIds) + { + using var cts = new CancellationTokenSource(DiscordConfig.DefaultRequestTimeout); + await DownloadUsersAsync(userIds, cts.Token).ConfigureAwait(false); + } + + /// + public async Task DownloadUsersAsync(IEnumerable userIds, CancellationToken cancelToken) + { + await Discord.DownloadUsersAsync(this, userIds, cancelToken).ConfigureAwait(false); + } + internal void CompleteDownloadUsers() { _downloaderPromise.TrySetResultAsync(true); @@ -1406,7 +1420,7 @@ public Task CreateEventAsync( /// /// A task that represents the asynchronous get operation. The task result contains a read-only collection /// of the requested audit log entries. - /// + /// public IAsyncEnumerable> GetAuditLogsAsync(int limit, RequestOptions options = null, ulong? beforeId = null, ulong? userId = null, ActionType? actionType = null, ulong? afterId = null) => GuildHelper.GetAuditLogsAsync(this, Discord, beforeId, limit, options, userId: userId, actionType: actionType, afterId: afterId);