Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WaitForCompletionOrCreateCheckStatusResponseAsync to Microsoft.Azure.Functions.Worker.DurableTaskClientExtensions #2875

Open
wants to merge 9 commits into
base: dev
Choose a base branch
from
77 changes: 76 additions & 1 deletion src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License. See License.txt in the project root for license information.

using System;
using System.Diagnostics;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -18,6 +19,80 @@ namespace Microsoft.Azure.Functions.Worker;
/// </summary>
public static class DurableTaskClientExtensions
{
/// <summary>
///
/// </summary>
/// <param name="client">The <see cref="DurableTaskClient"/>.</param>
/// <param name="request">The HTTP request that this response is for.</param>
/// <param name="instanceId">The ID of the orchestration instance to check.</param>
/// <param name="cancellation">The cancellation token.</param>
/// <param name="timeout">Total allowed timeout for output from the durable function. The default value is 10 seconds.</param>
/// <param name="retryInterval">The timeout between checks for output from the durable function. The default value is 1 second.</param>
/// <param name="returnInternalServerErrorOnFailure">Optional parameter that configures the http response code returned. Defaults to <c>false</c>.</param>
/// <returns></returns>
public static async Task<HttpResponseData> WaitForCompletionOrCreateCheckStatusResponseAsync(this DurableTaskClient client,
HttpRequestData request,
string instanceId,
CancellationToken cancellation = default,
TimeSpan? timeout = null,
TimeSpan? retryInterval = null,
bool returnInternalServerErrorOnFailure = false
)
{
TimeSpan timeoutLocal = timeout ?? TimeSpan.FromSeconds(10);
TimeSpan retryIntervalLocal = retryInterval ?? TimeSpan.FromSeconds(1);

if (retryIntervalLocal > timeoutLocal)
{
throw new ArgumentException($"Total timeout {timeoutLocal.TotalSeconds} should be bigger than retry timeout {retryIntervalLocal.TotalSeconds}");
}

Stopwatch stopwatch = Stopwatch.StartNew();
while (true)
{
var status = await client.GetInstanceAsync(instanceId, getInputsAndOutputs: true);
if (status != null)
{
if (status.RuntimeStatus == OrchestrationRuntimeStatus.Completed ||
#pragma warning disable CS0618 // Type or member is obsolete
status.RuntimeStatus == OrchestrationRuntimeStatus.Canceled ||
#pragma warning restore CS0618 // Type or member is obsolete
status.RuntimeStatus == OrchestrationRuntimeStatus.Terminated ||
status.RuntimeStatus == OrchestrationRuntimeStatus.Failed)
{
var response = request.CreateResponse(HttpStatusCode.OK);
await response.WriteAsJsonAsync(new OrchestrationMetadata(status.Name, status.InstanceId)
{
CreatedAt = status.CreatedAt,
LastUpdatedAt = status.LastUpdatedAt,
RuntimeStatus = status.RuntimeStatus,
SerializedInput = status.SerializedInput,
SerializedOutput = status.SerializedOutput,
SerializedCustomStatus = status.SerializedCustomStatus,
});

if (status.RuntimeStatus == OrchestrationRuntimeStatus.Failed && returnInternalServerErrorOnFailure)
{
response.StatusCode = HttpStatusCode.InternalServerError;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this logic here, as WriteAsJsonAsync will set response.StatusCode to HttpStatusCode.OK once it completes successfully.


return response;
}
}

TimeSpan elapsed = stopwatch.Elapsed;
if (elapsed < timeout)
{
TimeSpan remainingTime = timeoutLocal!.Subtract(elapsed);
await Task.Delay(remainingTime > retryIntervalLocal ? retryIntervalLocal : remainingTime);
}
else
{
return await CreateCheckStatusResponseAsync(client, request, instanceId, cancellation: cancellation);
}
}
}

/// <summary>
/// Creates an HTTP response that is useful for checking the status of the specified instance.
/// </summary>
Expand Down Expand Up @@ -176,7 +251,7 @@ static string BuildUrl(string url, params string?[] queryValues)
{
throw new InvalidOperationException("Failed to create HTTP management payload as base URL is null. Either use Functions bindings or provide an HTTP request to create the HttpPayload.");
}

bool isFromRequest = request != null;

string formattedInstanceId = Uri.EscapeDataString(instanceId);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
using System.Net;
using Azure.Core.Serialization;
using Microsoft.Azure.Functions.Worker.Http;
using Microsoft.DurableTask.Client;
using Microsoft.Extensions.Options;
using Moq;
using Newtonsoft.Json;

namespace Microsoft.Azure.Functions.Worker.Tests
{
Expand All @@ -9,7 +13,7 @@ namespace Microsoft.Azure.Functions.Worker.Tests
/// </summary>
public class FunctionsDurableTaskClientTests
{
private FunctionsDurableTaskClient GetTestFunctionsDurableTaskClient(string? baseUrl = null)
private FunctionsDurableTaskClient GetTestFunctionsDurableTaskClient(string? baseUrl = null, OrchestrationMetadata? orchestrationMetadata = null)
{
// construct mock client

Expand All @@ -21,6 +25,12 @@ private FunctionsDurableTaskClient GetTestFunctionsDurableTaskClient(string? bas
durableClientMock.Setup(x => x.TerminateInstanceAsync(
It.IsAny<string>(), It.IsAny<TerminateInstanceOptions>(), It.IsAny<CancellationToken>())).Returns(completedTask);

if (orchestrationMetadata != null)
{
durableClientMock.Setup(x => x.GetInstancesAsync(orchestrationMetadata.InstanceId, It.IsAny<bool>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(orchestrationMetadata);
}

DurableTaskClient durableClient = durableClientMock.Object;
FunctionsDurableTaskClient client = new FunctionsDurableTaskClient(durableClient, queryString: null, httpBaseUrl: baseUrl);
return client;
Expand Down Expand Up @@ -89,6 +99,116 @@ public void CreateHttpManagementPayload_WithHttpRequestData()
AssertHttpManagementPayload(payload, "http://localhost:7075/runtime/webhooks/durabletask", instanceId);
}

/// <summary>
/// Test that the `WaitForCompletionOrCreateCheckStatusResponseAsync` method returns the expected response when the orchestration is completed.
/// The expected response should include OrchestrationMetadata in the body with an HttpStatusCode.OK.
/// </summary>
[Fact]
public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenCompleted()
{
string instanceId = "test-instance-id-completed";
var expectedResult = new OrchestrationMetadata("TestCompleted", instanceId)
{
CreatedAt = DateTime.UtcNow,
LastUpdatedAt = DateTime.UtcNow,
RuntimeStatus = OrchestrationRuntimeStatus.Completed,
SerializedCustomStatus = "TestCustomStatus",
SerializedInput = "TestInput",
SerializedOutput = "TestOutput"
};

var client = this.GetTestFunctionsDurableTaskClient( orchestrationMetadata: expectedResult);

HttpRequestData request = this.MockHttpRequestAndResponseData();

HttpResponseData response = await client.WaitForCompletionOrCreateCheckStatusResponseAsync(request, instanceId);

Assert.NotNull(response);
Assert.Equal(HttpStatusCode.OK, response.StatusCode);

// Reset stream position for reading
response.Body.Position = 0;
var orchestratorMetadata = await System.Text.Json.JsonSerializer.DeserializeAsync<OrchestrationMetadata>(response.Body);

// Assert the response content is not null and check the content is correct.
Assert.NotNull(orchestratorMetadata);
AssertOrhcestrationMetadata(expectedResult, orchestratorMetadata);
}

/// <summary>
/// Test that the `WaitForCompletionOrCreateCheckStatusResponseAsync` method returns expected response when the orchestration is still running.
/// The response body should contain a HttpManagementPayload with HttpStatusCode.Accepted.
/// </summary>
[Fact]
public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenRunning()
{
string instanceId = "test-instance-id-running";
var expectedResult = new OrchestrationMetadata("TestRunning", instanceId)
{
CreatedAt = DateTime.UtcNow,
LastUpdatedAt = DateTime.UtcNow,
RuntimeStatus = OrchestrationRuntimeStatus.Running,
};

var client = this.GetTestFunctionsDurableTaskClient(orchestrationMetadata: expectedResult);

HttpRequestData request = this.MockHttpRequestAndResponseData();

HttpResponseData response = await client.WaitForCompletionOrCreateCheckStatusResponseAsync(request, instanceId);

Assert.NotNull(response);
Assert.Equal(HttpStatusCode.Accepted, response.StatusCode);

// Reset stream position for reading
response.Body.Position = 0;
HttpManagementPayload? payload;
using (var reader = new StreamReader(response.Body))
{
payload = JsonConvert.DeserializeObject<HttpManagementPayload>(await reader.ReadToEndAsync());
}

// Assert the response content is not null and check the content is correct.
Assert.NotNull(payload);
AssertHttpManagementPayload(payload, "http://localhost:7075/runtime/webhooks/durabletask", instanceId);
}

/// <summary>
/// Tests the `WaitForCompletionOrCreateCheckStatusResponseAsync` method to ensure it returns the correct HTTP status code
/// based on the `returnInternalServerErrorOnFailure` parameter when the orchestration has failed.
/// </summary>
[Theory]
[InlineData(true, HttpStatusCode.InternalServerError)]
[InlineData(false, HttpStatusCode.OK)]
public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenFailed(bool returnInternalServerErrorOnFailure, HttpStatusCode expected)
{
string instanceId = "test-instance-id-failed";
var expectedResult = new OrchestrationMetadata("TestFailed", instanceId)
{
CreatedAt = DateTime.UtcNow,
LastUpdatedAt = DateTime.UtcNow,
RuntimeStatus = OrchestrationRuntimeStatus.Failed,
SerializedOutput = "Microsoft.DurableTask.TaskFailedException: Task 'SayHello' (#0) failed with an unhandled exception: Exception while executing function: Functions.SayHello",
SerializedInput = null
};

var client = this.GetTestFunctionsDurableTaskClient(orchestrationMetadata: expectedResult);

HttpRequestData request = this.MockHttpRequestAndResponseData();

HttpResponseData response = await client.WaitForCompletionOrCreateCheckStatusResponseAsync(request, instanceId, returnInternalServerErrorOnFailure: returnInternalServerErrorOnFailure);

Assert.NotNull(response);
Assert.Equal(expected, response.StatusCode);

// Reset stream position for reading
response.Body.Position = 0;
var orchestratorMetadata = await System.Text.Json.JsonSerializer.DeserializeAsync<OrchestrationMetadata>(response.Body);

// Assert the response content is not null and check the content is correct.
Assert.NotNull(orchestratorMetadata);
AssertOrhcestrationMetadata(expectedResult, orchestratorMetadata);
}

private static void AssertHttpManagementPayload(HttpManagementPayload payload, string BaseUrl, string instanceId)
{
Assert.Equal(instanceId, payload.Id);
Expand All @@ -99,5 +219,71 @@ private static void AssertHttpManagementPayload(HttpManagementPayload payload, s
Assert.Equal($"{BaseUrl}/instances/{instanceId}/suspend?reason={{{{text}}}}", payload.SuspendPostUri);
Assert.Equal($"{BaseUrl}/instances/{instanceId}/resume?reason={{{{text}}}}", payload.ResumePostUri);
}

private static void AssertOrhcestrationMetadata( OrchestrationMetadata expected, OrchestrationMetadata actual)
{
Assert.Equal(expected.InstanceId, actual.InstanceId);
Assert.Equal(expected.CreatedAt, actual.CreatedAt);
Assert.Equal(expected.LastUpdatedAt, actual.LastUpdatedAt);
Assert.Equal(expected.RuntimeStatus, actual.RuntimeStatus);
Assert.Equal(expected.SerializedInput, actual.SerializedInput);
Assert.Equal(expected.SerializedOutput, actual.SerializedOutput);
Assert.Equal(expected.SerializedCustomStatus, actual.SerializedCustomStatus);
}

// Mocks the required HttpRequestData and HttpResponseData for testing purposes.
// This method sets up a mock HttpRequestData with a predefined URL and a mock HttpResponseDatav with a default status code and body.
private HttpRequestData MockHttpRequestAndResponseData()
{
var mockObjectSerializer = new Mock<ObjectSerializer>();

// Setup the SerializeAsync method
mockObjectSerializer.Setup(s => s.SerializeAsync(It.IsAny<Stream>(), It.IsAny<object?>(), It.IsAny<Type>(), It.IsAny<CancellationToken>()))
.Returns<Stream, object?, Type, CancellationToken>(async (stream, value, type, token) =>
{
await System.Text.Json.JsonSerializer.SerializeAsync(stream, value, type, cancellationToken: token);
});

var workerOptions = new WorkerOptions
{
Serializer = mockObjectSerializer.Object
};
var mockOptions = new Mock<IOptions<WorkerOptions>>();
mockOptions.Setup(o => o.Value).Returns(workerOptions);

// Mock the service provider
var mockServiceProvider = new Mock<IServiceProvider>();

// Set up the service provider to return the mock IOptions<WorkerOptions>
mockServiceProvider.Setup(sp => sp.GetService(typeof(IOptions<WorkerOptions>)))
.Returns(mockOptions.Object);

// Set up the service provider to return the mock ObjectSerializer
mockServiceProvider.Setup(sp => sp.GetService(typeof(ObjectSerializer)))
.Returns(mockObjectSerializer.Object);

// Create a mock FunctionContext and assign the service provider
var mockFunctionContext = new Mock<FunctionContext>();
mockFunctionContext.SetupGet(c => c.InstanceServices).Returns(mockServiceProvider.Object);
var mockHttpRequestData = new Mock<HttpRequestData>(mockFunctionContext.Object);

// Set up the URL property.
mockHttpRequestData.SetupGet(r => r.Url).Returns(new Uri("http://localhost:7075/orchestrators/E1_HelloSequence"));

var mockHttpResponseData = new Mock<HttpResponseData>(mockFunctionContext.Object)
{
DefaultValue = DefaultValue.Mock
};

// Enable setting StatusCode and Body as mutable properties
mockHttpResponseData.SetupProperty(r => r.StatusCode, HttpStatusCode.OK);
mockHttpResponseData.SetupProperty(r => r.Body, new MemoryStream());

// Setup CreateResponse to return the configured HttpResponseData mock
mockHttpRequestData.Setup(r => r.CreateResponse())
.Returns(mockHttpResponseData.Object);

return mockHttpRequestData.Object;
}
}
}
Loading