diff --git a/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs b/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs index bbd6222a8..6d5923fae 100644 --- a/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs +++ b/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. See License.txt in the project root for license information. using System; +using System.Diagnostics; using System.Net; using System.Threading; using System.Threading.Tasks; @@ -18,6 +19,80 @@ namespace Microsoft.Azure.Functions.Worker; /// public static class DurableTaskClientExtensions { + /// + /// + /// + /// The . + /// The HTTP request that this response is for. + /// The ID of the orchestration instance to check. + /// The cancellation token. + /// Total allowed timeout for output from the durable function. The default value is 10 seconds. + /// The timeout between checks for output from the durable function. The default value is 1 second. + /// Optional parameter that configures the http response code returned. Defaults to false. + /// + public static async Task WaitForCompletionOrCreateCheckStatusResponseAsync(this DurableTaskClient client, + HttpRequestData request, + string instanceId, + CancellationToken cancellation = default, + TimeSpan? timeout = null, + TimeSpan? retryInterval = null, + bool returnInternalServerErrorOnFailure = false + ) + { + TimeSpan timeoutLocal = timeout ?? TimeSpan.FromSeconds(10); + TimeSpan retryIntervalLocal = retryInterval ?? TimeSpan.FromSeconds(1); + + if (retryIntervalLocal > timeoutLocal) + { + throw new ArgumentException($"Total timeout {timeoutLocal.TotalSeconds} should be bigger than retry timeout {retryIntervalLocal.TotalSeconds}"); + } + + Stopwatch stopwatch = Stopwatch.StartNew(); + while (true) + { + var status = await client.GetInstanceAsync(instanceId, getInputsAndOutputs: true); + if (status != null) + { + if (status.RuntimeStatus == OrchestrationRuntimeStatus.Completed || +#pragma warning disable CS0618 // Type or member is obsolete + status.RuntimeStatus == OrchestrationRuntimeStatus.Canceled || +#pragma warning restore CS0618 // Type or member is obsolete + status.RuntimeStatus == OrchestrationRuntimeStatus.Terminated || + status.RuntimeStatus == OrchestrationRuntimeStatus.Failed) + { + var response = request.CreateResponse(HttpStatusCode.OK); + await response.WriteAsJsonAsync(new OrchestrationMetadata(status.Name, status.InstanceId) + { + CreatedAt = status.CreatedAt, + LastUpdatedAt = status.LastUpdatedAt, + RuntimeStatus = status.RuntimeStatus, + SerializedInput = status.SerializedInput, + SerializedOutput = status.SerializedOutput, + SerializedCustomStatus = status.SerializedCustomStatus, + }); + + if (status.RuntimeStatus == OrchestrationRuntimeStatus.Failed && returnInternalServerErrorOnFailure) + { + response.StatusCode = HttpStatusCode.InternalServerError; + } + + return response; + } + } + + TimeSpan elapsed = stopwatch.Elapsed; + if (elapsed < timeout) + { + TimeSpan remainingTime = timeoutLocal!.Subtract(elapsed); + await Task.Delay(remainingTime > retryIntervalLocal ? retryIntervalLocal : remainingTime); + } + else + { + return await CreateCheckStatusResponseAsync(client, request, instanceId, cancellation: cancellation); + } + } + } + /// /// Creates an HTTP response that is useful for checking the status of the specified instance. /// @@ -176,7 +251,7 @@ static string BuildUrl(string url, params string?[] queryValues) { throw new InvalidOperationException("Failed to create HTTP management payload as base URL is null. Either use Functions bindings or provide an HTTP request to create the HttpPayload."); } - + bool isFromRequest = request != null; string formattedInstanceId = Uri.EscapeDataString(instanceId); diff --git a/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs b/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs index 6f975d2c5..d17a635db 100644 --- a/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs +++ b/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs @@ -1,6 +1,10 @@ +using System.Net; +using Azure.Core.Serialization; using Microsoft.Azure.Functions.Worker.Http; using Microsoft.DurableTask.Client; +using Microsoft.Extensions.Options; using Moq; +using Newtonsoft.Json; namespace Microsoft.Azure.Functions.Worker.Tests { @@ -9,7 +13,7 @@ namespace Microsoft.Azure.Functions.Worker.Tests /// public class FunctionsDurableTaskClientTests { - private FunctionsDurableTaskClient GetTestFunctionsDurableTaskClient(string? baseUrl = null) + private FunctionsDurableTaskClient GetTestFunctionsDurableTaskClient(string? baseUrl = null, OrchestrationMetadata? orchestrationMetadata = null) { // construct mock client @@ -21,6 +25,12 @@ private FunctionsDurableTaskClient GetTestFunctionsDurableTaskClient(string? bas durableClientMock.Setup(x => x.TerminateInstanceAsync( It.IsAny(), It.IsAny(), It.IsAny())).Returns(completedTask); + if (orchestrationMetadata != null) + { + durableClientMock.Setup(x => x.GetInstancesAsync(orchestrationMetadata.InstanceId, It.IsAny(), It.IsAny())) + .ReturnsAsync(orchestrationMetadata); + } + DurableTaskClient durableClient = durableClientMock.Object; FunctionsDurableTaskClient client = new FunctionsDurableTaskClient(durableClient, queryString: null, httpBaseUrl: baseUrl); return client; @@ -89,6 +99,116 @@ public void CreateHttpManagementPayload_WithHttpRequestData() AssertHttpManagementPayload(payload, "http://localhost:7075/runtime/webhooks/durabletask", instanceId); } + /// + /// Test that the `WaitForCompletionOrCreateCheckStatusResponseAsync` method returns the expected response when the orchestration is completed. + /// The expected response should include OrchestrationMetadata in the body with an HttpStatusCode.OK. + /// + [Fact] + public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenCompleted() + { + string instanceId = "test-instance-id-completed"; + var expectedResult = new OrchestrationMetadata("TestCompleted", instanceId) + { + CreatedAt = DateTime.UtcNow, + LastUpdatedAt = DateTime.UtcNow, + RuntimeStatus = OrchestrationRuntimeStatus.Completed, + SerializedCustomStatus = "TestCustomStatus", + SerializedInput = "TestInput", + SerializedOutput = "TestOutput" + }; + + var client = this.GetTestFunctionsDurableTaskClient( orchestrationMetadata: expectedResult); + + HttpRequestData request = this.MockHttpRequestAndResponseData(); + + HttpResponseData response = await client.WaitForCompletionOrCreateCheckStatusResponseAsync(request, instanceId); + + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + // Reset stream position for reading + response.Body.Position = 0; + var orchestratorMetadata = await System.Text.Json.JsonSerializer.DeserializeAsync(response.Body); + + // Assert the response content is not null and check the content is correct. + Assert.NotNull(orchestratorMetadata); + AssertOrhcestrationMetadata(expectedResult, orchestratorMetadata); + } + + /// + /// Test that the `WaitForCompletionOrCreateCheckStatusResponseAsync` method returns expected response when the orchestration is still running. + /// The response body should contain a HttpManagementPayload with HttpStatusCode.Accepted. + /// + [Fact] + public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenRunning() + { + string instanceId = "test-instance-id-running"; + var expectedResult = new OrchestrationMetadata("TestRunning", instanceId) + { + CreatedAt = DateTime.UtcNow, + LastUpdatedAt = DateTime.UtcNow, + RuntimeStatus = OrchestrationRuntimeStatus.Running, + }; + + var client = this.GetTestFunctionsDurableTaskClient(orchestrationMetadata: expectedResult); + + HttpRequestData request = this.MockHttpRequestAndResponseData(); + + HttpResponseData response = await client.WaitForCompletionOrCreateCheckStatusResponseAsync(request, instanceId); + + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.Accepted, response.StatusCode); + + // Reset stream position for reading + response.Body.Position = 0; + HttpManagementPayload? payload; + using (var reader = new StreamReader(response.Body)) + { + payload = JsonConvert.DeserializeObject(await reader.ReadToEndAsync()); + } + + // Assert the response content is not null and check the content is correct. + Assert.NotNull(payload); + AssertHttpManagementPayload(payload, "http://localhost:7075/runtime/webhooks/durabletask", instanceId); + } + + /// + /// Tests the `WaitForCompletionOrCreateCheckStatusResponseAsync` method to ensure it returns the correct HTTP status code + /// based on the `returnInternalServerErrorOnFailure` parameter when the orchestration has failed. + /// + [Theory] + [InlineData(true, HttpStatusCode.InternalServerError)] + [InlineData(false, HttpStatusCode.OK)] + public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenFailed(bool returnInternalServerErrorOnFailure, HttpStatusCode expected) + { + string instanceId = "test-instance-id-failed"; + var expectedResult = new OrchestrationMetadata("TestFailed", instanceId) + { + CreatedAt = DateTime.UtcNow, + LastUpdatedAt = DateTime.UtcNow, + RuntimeStatus = OrchestrationRuntimeStatus.Failed, + SerializedOutput = "Microsoft.DurableTask.TaskFailedException: Task 'SayHello' (#0) failed with an unhandled exception: Exception while executing function: Functions.SayHello", + SerializedInput = null + }; + + var client = this.GetTestFunctionsDurableTaskClient(orchestrationMetadata: expectedResult); + + HttpRequestData request = this.MockHttpRequestAndResponseData(); + + HttpResponseData response = await client.WaitForCompletionOrCreateCheckStatusResponseAsync(request, instanceId, returnInternalServerErrorOnFailure: returnInternalServerErrorOnFailure); + + Assert.NotNull(response); + Assert.Equal(expected, response.StatusCode); + + // Reset stream position for reading + response.Body.Position = 0; + var orchestratorMetadata = await System.Text.Json.JsonSerializer.DeserializeAsync(response.Body); + + // Assert the response content is not null and check the content is correct. + Assert.NotNull(orchestratorMetadata); + AssertOrhcestrationMetadata(expectedResult, orchestratorMetadata); + } + private static void AssertHttpManagementPayload(HttpManagementPayload payload, string BaseUrl, string instanceId) { Assert.Equal(instanceId, payload.Id); @@ -99,5 +219,71 @@ private static void AssertHttpManagementPayload(HttpManagementPayload payload, s Assert.Equal($"{BaseUrl}/instances/{instanceId}/suspend?reason={{{{text}}}}", payload.SuspendPostUri); Assert.Equal($"{BaseUrl}/instances/{instanceId}/resume?reason={{{{text}}}}", payload.ResumePostUri); } + + private static void AssertOrhcestrationMetadata( OrchestrationMetadata expected, OrchestrationMetadata actual) + { + Assert.Equal(expected.InstanceId, actual.InstanceId); + Assert.Equal(expected.CreatedAt, actual.CreatedAt); + Assert.Equal(expected.LastUpdatedAt, actual.LastUpdatedAt); + Assert.Equal(expected.RuntimeStatus, actual.RuntimeStatus); + Assert.Equal(expected.SerializedInput, actual.SerializedInput); + Assert.Equal(expected.SerializedOutput, actual.SerializedOutput); + Assert.Equal(expected.SerializedCustomStatus, actual.SerializedCustomStatus); + } + + // Mocks the required HttpRequestData and HttpResponseData for testing purposes. + // This method sets up a mock HttpRequestData with a predefined URL and a mock HttpResponseDatav with a default status code and body. + private HttpRequestData MockHttpRequestAndResponseData() + { + var mockObjectSerializer = new Mock(); + + // Setup the SerializeAsync method + mockObjectSerializer.Setup(s => s.SerializeAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(async (stream, value, type, token) => + { + await System.Text.Json.JsonSerializer.SerializeAsync(stream, value, type, cancellationToken: token); + }); + + var workerOptions = new WorkerOptions + { + Serializer = mockObjectSerializer.Object + }; + var mockOptions = new Mock>(); + mockOptions.Setup(o => o.Value).Returns(workerOptions); + + // Mock the service provider + var mockServiceProvider = new Mock(); + + // Set up the service provider to return the mock IOptions + mockServiceProvider.Setup(sp => sp.GetService(typeof(IOptions))) + .Returns(mockOptions.Object); + + // Set up the service provider to return the mock ObjectSerializer + mockServiceProvider.Setup(sp => sp.GetService(typeof(ObjectSerializer))) + .Returns(mockObjectSerializer.Object); + + // Create a mock FunctionContext and assign the service provider + var mockFunctionContext = new Mock(); + mockFunctionContext.SetupGet(c => c.InstanceServices).Returns(mockServiceProvider.Object); + var mockHttpRequestData = new Mock(mockFunctionContext.Object); + + // Set up the URL property. + mockHttpRequestData.SetupGet(r => r.Url).Returns(new Uri("http://localhost:7075/orchestrators/E1_HelloSequence")); + + var mockHttpResponseData = new Mock(mockFunctionContext.Object) + { + DefaultValue = DefaultValue.Mock + }; + + // Enable setting StatusCode and Body as mutable properties + mockHttpResponseData.SetupProperty(r => r.StatusCode, HttpStatusCode.OK); + mockHttpResponseData.SetupProperty(r => r.Body, new MemoryStream()); + + // Setup CreateResponse to return the configured HttpResponseData mock + mockHttpRequestData.Setup(r => r.CreateResponse()) + .Returns(mockHttpResponseData.Object); + + return mockHttpRequestData.Object; + } } }