diff --git a/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs b/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs
index bbd6222a8..6d5923fae 100644
--- a/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs
+++ b/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs
@@ -2,6 +2,7 @@
// Licensed under the MIT License. See License.txt in the project root for license information.
using System;
+using System.Diagnostics;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
@@ -18,6 +19,80 @@ namespace Microsoft.Azure.Functions.Worker;
///
public static class DurableTaskClientExtensions
{
+ ///
+ ///
+ ///
+ /// The .
+ /// The HTTP request that this response is for.
+ /// The ID of the orchestration instance to check.
+ /// The cancellation token.
+ /// Total allowed timeout for output from the durable function. The default value is 10 seconds.
+ /// The timeout between checks for output from the durable function. The default value is 1 second.
+ /// Optional parameter that configures the http response code returned. Defaults to false.
+ ///
+ public static async Task WaitForCompletionOrCreateCheckStatusResponseAsync(this DurableTaskClient client,
+ HttpRequestData request,
+ string instanceId,
+ CancellationToken cancellation = default,
+ TimeSpan? timeout = null,
+ TimeSpan? retryInterval = null,
+ bool returnInternalServerErrorOnFailure = false
+ )
+ {
+ TimeSpan timeoutLocal = timeout ?? TimeSpan.FromSeconds(10);
+ TimeSpan retryIntervalLocal = retryInterval ?? TimeSpan.FromSeconds(1);
+
+ if (retryIntervalLocal > timeoutLocal)
+ {
+ throw new ArgumentException($"Total timeout {timeoutLocal.TotalSeconds} should be bigger than retry timeout {retryIntervalLocal.TotalSeconds}");
+ }
+
+ Stopwatch stopwatch = Stopwatch.StartNew();
+ while (true)
+ {
+ var status = await client.GetInstanceAsync(instanceId, getInputsAndOutputs: true);
+ if (status != null)
+ {
+ if (status.RuntimeStatus == OrchestrationRuntimeStatus.Completed ||
+#pragma warning disable CS0618 // Type or member is obsolete
+ status.RuntimeStatus == OrchestrationRuntimeStatus.Canceled ||
+#pragma warning restore CS0618 // Type or member is obsolete
+ status.RuntimeStatus == OrchestrationRuntimeStatus.Terminated ||
+ status.RuntimeStatus == OrchestrationRuntimeStatus.Failed)
+ {
+ var response = request.CreateResponse(HttpStatusCode.OK);
+ await response.WriteAsJsonAsync(new OrchestrationMetadata(status.Name, status.InstanceId)
+ {
+ CreatedAt = status.CreatedAt,
+ LastUpdatedAt = status.LastUpdatedAt,
+ RuntimeStatus = status.RuntimeStatus,
+ SerializedInput = status.SerializedInput,
+ SerializedOutput = status.SerializedOutput,
+ SerializedCustomStatus = status.SerializedCustomStatus,
+ });
+
+ if (status.RuntimeStatus == OrchestrationRuntimeStatus.Failed && returnInternalServerErrorOnFailure)
+ {
+ response.StatusCode = HttpStatusCode.InternalServerError;
+ }
+
+ return response;
+ }
+ }
+
+ TimeSpan elapsed = stopwatch.Elapsed;
+ if (elapsed < timeout)
+ {
+ TimeSpan remainingTime = timeoutLocal!.Subtract(elapsed);
+ await Task.Delay(remainingTime > retryIntervalLocal ? retryIntervalLocal : remainingTime);
+ }
+ else
+ {
+ return await CreateCheckStatusResponseAsync(client, request, instanceId, cancellation: cancellation);
+ }
+ }
+ }
+
///
/// Creates an HTTP response that is useful for checking the status of the specified instance.
///
@@ -176,7 +251,7 @@ static string BuildUrl(string url, params string?[] queryValues)
{
throw new InvalidOperationException("Failed to create HTTP management payload as base URL is null. Either use Functions bindings or provide an HTTP request to create the HttpPayload.");
}
-
+
bool isFromRequest = request != null;
string formattedInstanceId = Uri.EscapeDataString(instanceId);
diff --git a/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs b/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs
index 6f975d2c5..d17a635db 100644
--- a/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs
+++ b/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs
@@ -1,6 +1,10 @@
+using System.Net;
+using Azure.Core.Serialization;
using Microsoft.Azure.Functions.Worker.Http;
using Microsoft.DurableTask.Client;
+using Microsoft.Extensions.Options;
using Moq;
+using Newtonsoft.Json;
namespace Microsoft.Azure.Functions.Worker.Tests
{
@@ -9,7 +13,7 @@ namespace Microsoft.Azure.Functions.Worker.Tests
///
public class FunctionsDurableTaskClientTests
{
- private FunctionsDurableTaskClient GetTestFunctionsDurableTaskClient(string? baseUrl = null)
+ private FunctionsDurableTaskClient GetTestFunctionsDurableTaskClient(string? baseUrl = null, OrchestrationMetadata? orchestrationMetadata = null)
{
// construct mock client
@@ -21,6 +25,12 @@ private FunctionsDurableTaskClient GetTestFunctionsDurableTaskClient(string? bas
durableClientMock.Setup(x => x.TerminateInstanceAsync(
It.IsAny(), It.IsAny(), It.IsAny())).Returns(completedTask);
+ if (orchestrationMetadata != null)
+ {
+ durableClientMock.Setup(x => x.GetInstancesAsync(orchestrationMetadata.InstanceId, It.IsAny(), It.IsAny()))
+ .ReturnsAsync(orchestrationMetadata);
+ }
+
DurableTaskClient durableClient = durableClientMock.Object;
FunctionsDurableTaskClient client = new FunctionsDurableTaskClient(durableClient, queryString: null, httpBaseUrl: baseUrl);
return client;
@@ -89,6 +99,116 @@ public void CreateHttpManagementPayload_WithHttpRequestData()
AssertHttpManagementPayload(payload, "http://localhost:7075/runtime/webhooks/durabletask", instanceId);
}
+ ///
+ /// Test that the `WaitForCompletionOrCreateCheckStatusResponseAsync` method returns the expected response when the orchestration is completed.
+ /// The expected response should include OrchestrationMetadata in the body with an HttpStatusCode.OK.
+ ///
+ [Fact]
+ public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenCompleted()
+ {
+ string instanceId = "test-instance-id-completed";
+ var expectedResult = new OrchestrationMetadata("TestCompleted", instanceId)
+ {
+ CreatedAt = DateTime.UtcNow,
+ LastUpdatedAt = DateTime.UtcNow,
+ RuntimeStatus = OrchestrationRuntimeStatus.Completed,
+ SerializedCustomStatus = "TestCustomStatus",
+ SerializedInput = "TestInput",
+ SerializedOutput = "TestOutput"
+ };
+
+ var client = this.GetTestFunctionsDurableTaskClient( orchestrationMetadata: expectedResult);
+
+ HttpRequestData request = this.MockHttpRequestAndResponseData();
+
+ HttpResponseData response = await client.WaitForCompletionOrCreateCheckStatusResponseAsync(request, instanceId);
+
+ Assert.NotNull(response);
+ Assert.Equal(HttpStatusCode.OK, response.StatusCode);
+
+ // Reset stream position for reading
+ response.Body.Position = 0;
+ var orchestratorMetadata = await System.Text.Json.JsonSerializer.DeserializeAsync(response.Body);
+
+ // Assert the response content is not null and check the content is correct.
+ Assert.NotNull(orchestratorMetadata);
+ AssertOrhcestrationMetadata(expectedResult, orchestratorMetadata);
+ }
+
+ ///
+ /// Test that the `WaitForCompletionOrCreateCheckStatusResponseAsync` method returns expected response when the orchestration is still running.
+ /// The response body should contain a HttpManagementPayload with HttpStatusCode.Accepted.
+ ///
+ [Fact]
+ public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenRunning()
+ {
+ string instanceId = "test-instance-id-running";
+ var expectedResult = new OrchestrationMetadata("TestRunning", instanceId)
+ {
+ CreatedAt = DateTime.UtcNow,
+ LastUpdatedAt = DateTime.UtcNow,
+ RuntimeStatus = OrchestrationRuntimeStatus.Running,
+ };
+
+ var client = this.GetTestFunctionsDurableTaskClient(orchestrationMetadata: expectedResult);
+
+ HttpRequestData request = this.MockHttpRequestAndResponseData();
+
+ HttpResponseData response = await client.WaitForCompletionOrCreateCheckStatusResponseAsync(request, instanceId);
+
+ Assert.NotNull(response);
+ Assert.Equal(HttpStatusCode.Accepted, response.StatusCode);
+
+ // Reset stream position for reading
+ response.Body.Position = 0;
+ HttpManagementPayload? payload;
+ using (var reader = new StreamReader(response.Body))
+ {
+ payload = JsonConvert.DeserializeObject(await reader.ReadToEndAsync());
+ }
+
+ // Assert the response content is not null and check the content is correct.
+ Assert.NotNull(payload);
+ AssertHttpManagementPayload(payload, "http://localhost:7075/runtime/webhooks/durabletask", instanceId);
+ }
+
+ ///
+ /// Tests the `WaitForCompletionOrCreateCheckStatusResponseAsync` method to ensure it returns the correct HTTP status code
+ /// based on the `returnInternalServerErrorOnFailure` parameter when the orchestration has failed.
+ ///
+ [Theory]
+ [InlineData(true, HttpStatusCode.InternalServerError)]
+ [InlineData(false, HttpStatusCode.OK)]
+ public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenFailed(bool returnInternalServerErrorOnFailure, HttpStatusCode expected)
+ {
+ string instanceId = "test-instance-id-failed";
+ var expectedResult = new OrchestrationMetadata("TestFailed", instanceId)
+ {
+ CreatedAt = DateTime.UtcNow,
+ LastUpdatedAt = DateTime.UtcNow,
+ RuntimeStatus = OrchestrationRuntimeStatus.Failed,
+ SerializedOutput = "Microsoft.DurableTask.TaskFailedException: Task 'SayHello' (#0) failed with an unhandled exception: Exception while executing function: Functions.SayHello",
+ SerializedInput = null
+ };
+
+ var client = this.GetTestFunctionsDurableTaskClient(orchestrationMetadata: expectedResult);
+
+ HttpRequestData request = this.MockHttpRequestAndResponseData();
+
+ HttpResponseData response = await client.WaitForCompletionOrCreateCheckStatusResponseAsync(request, instanceId, returnInternalServerErrorOnFailure: returnInternalServerErrorOnFailure);
+
+ Assert.NotNull(response);
+ Assert.Equal(expected, response.StatusCode);
+
+ // Reset stream position for reading
+ response.Body.Position = 0;
+ var orchestratorMetadata = await System.Text.Json.JsonSerializer.DeserializeAsync(response.Body);
+
+ // Assert the response content is not null and check the content is correct.
+ Assert.NotNull(orchestratorMetadata);
+ AssertOrhcestrationMetadata(expectedResult, orchestratorMetadata);
+ }
+
private static void AssertHttpManagementPayload(HttpManagementPayload payload, string BaseUrl, string instanceId)
{
Assert.Equal(instanceId, payload.Id);
@@ -99,5 +219,71 @@ private static void AssertHttpManagementPayload(HttpManagementPayload payload, s
Assert.Equal($"{BaseUrl}/instances/{instanceId}/suspend?reason={{{{text}}}}", payload.SuspendPostUri);
Assert.Equal($"{BaseUrl}/instances/{instanceId}/resume?reason={{{{text}}}}", payload.ResumePostUri);
}
+
+ private static void AssertOrhcestrationMetadata( OrchestrationMetadata expected, OrchestrationMetadata actual)
+ {
+ Assert.Equal(expected.InstanceId, actual.InstanceId);
+ Assert.Equal(expected.CreatedAt, actual.CreatedAt);
+ Assert.Equal(expected.LastUpdatedAt, actual.LastUpdatedAt);
+ Assert.Equal(expected.RuntimeStatus, actual.RuntimeStatus);
+ Assert.Equal(expected.SerializedInput, actual.SerializedInput);
+ Assert.Equal(expected.SerializedOutput, actual.SerializedOutput);
+ Assert.Equal(expected.SerializedCustomStatus, actual.SerializedCustomStatus);
+ }
+
+ // Mocks the required HttpRequestData and HttpResponseData for testing purposes.
+ // This method sets up a mock HttpRequestData with a predefined URL and a mock HttpResponseDatav with a default status code and body.
+ private HttpRequestData MockHttpRequestAndResponseData()
+ {
+ var mockObjectSerializer = new Mock();
+
+ // Setup the SerializeAsync method
+ mockObjectSerializer.Setup(s => s.SerializeAsync(It.IsAny(), It.IsAny