Skip to content

Commit

Permalink
Add support to connect to the LoginFlow AI service
Browse files Browse the repository at this point in the history
  • Loading branch information
sahandilshan committed Sep 30, 2024
1 parent cdb8821 commit ec247a1
Show file tree
Hide file tree
Showing 11 changed files with 976 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,18 @@
<groupId>org.apache.axis2.wso2</groupId>
<artifactId>axis2</artifactId>
</dependency>
<dependency>
<groupId>org.wso2.orbit.org.apache.httpcomponents</groupId>
<artifactId>httpclient</artifactId>
</dependency>
<dependency>
<groupId>org.apache.httpcomponents.wso2</groupId>
<artifactId>httpcore</artifactId>
</dependency>
<dependency>
<groupId>org.wso2.orbit.org.apache.httpcomponents</groupId>
<artifactId>httpasyncclient</artifactId>
</dependency>
<dependency>
<groupId>org.wso2.carbon</groupId>
<artifactId>org.wso2.carbon.utils</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright (c) 2024, WSO2 LLC. (http://www.wso2.com).
*
* WSO2 LLC. licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
* in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.wso2.carbon.identity.application.mgt.ai;

/**
* Client Exception class for BrandingAI service.
*/
public class LoginFlowAIClientException extends Exception {

private String errorCode;
private LoginFlowAIManagerImpl.HttpResponseWrapper loginFlowAIResponse;

public LoginFlowAIClientException(String message, String errorCode) {

super(message);
this.errorCode = errorCode;
}

public LoginFlowAIClientException(LoginFlowAIManagerImpl.HttpResponseWrapper httpResponseWrapper,
String message, String errorCode) {

super(message);
this.errorCode = errorCode;
this.loginFlowAIResponse = httpResponseWrapper;
}

public LoginFlowAIClientException(String message, Throwable cause) {

super(cause);
}

public LoginFlowAIClientException(String message, String errorCode, Throwable cause) {

super(message, cause);
this.errorCode = errorCode;
}

public String getErrorCode() {

return errorCode;
}

public LoginFlowAIManagerImpl.HttpResponseWrapper getLoginFlowAIResponse() {

return loginFlowAIResponse;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package org.wso2.carbon.identity.application.mgt.ai;

import org.json.JSONArray;
import org.json.JSONObject;

/**
* AI Manager interface for the LoginFlowAI module.
*/
public interface LoginFlowAIManager {

String generateAuthenticationSequence(String userQuery, JSONArray userClaims, JSONObject availableAuthenticators)
throws LoginFlowAIServiceException, LoginFlowAIClientException;

JSONObject getAuthenticationSequenceGenerationStatus(String operationId) throws LoginFlowAIServiceException,
LoginFlowAIClientException;

JSONObject getAuthenticationSequenceGenerationResult(String operationId) throws LoginFlowAIServiceException,
LoginFlowAIClientException;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
/*
* Copyright (c) 2024, WSO2 LLC. (http://www.wso2.com).
*
* WSO2 LLC. licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
* in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.wso2.carbon.identity.application.mgt.ai;

import com.google.gson.JsonSyntaxException;
import org.apache.commons.httpclient.HttpStatus;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.http.HttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.concurrent.FutureCallback;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.nio.client.CloseableHttpAsyncClient;
import org.apache.http.impl.nio.client.HttpAsyncClients;
import org.apache.http.util.EntityUtils;
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;
import org.wso2.carbon.context.PrivilegedCarbonContext;
import org.wso2.carbon.identity.core.util.IdentityUtil;

import java.io.IOException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;

import static org.wso2.carbon.identity.application.mgt.ai.constant.LoginFlowAIConstants.ErrorMessages.ERROR_RETRIEVING_ACCESS_TOKEN;
import static org.wso2.carbon.identity.application.mgt.ai.constant.LoginFlowAIConstants.ErrorMessages.ERROR_WHILE_CONNECTING_TO_LOGINFLOW_AI_SERVICE;
import static org.wso2.carbon.identity.application.mgt.ai.constant.LoginFlowAIConstants.ErrorMessages.ERROR_WHILE_GENERATING_AUTHENTICATION_SEQUENCE;
import static org.wso2.carbon.identity.application.mgt.ai.constant.LoginFlowAIConstants.ErrorMessages.UNABLE_TO_ACCESS_AI_SERVICE_WITH_RENEW_ACCESS_TOKEN;

/**
* Implementation of the LoginFlowAIManager interface to communicate with the LoginFlowAI service.
*/
public class LoginFlowAIManagerImpl implements LoginFlowAIManager {

private static final String LOGINFLOW_AI_ENDPOINT = IdentityUtil.getProperty(
"AIServices.LoginFlowAI.LoginFlowAIEndpoint");
private static final String LOGINFLOW_AI_GENERATE_ENDPOINT = "/api/server/v1/applications/loginflow/generate";
private static final String LOGINFLOW_AI_STATUS_ENDPOINT = "/api/server/v1/applications/loginflow/status";
private static final String LOGINFLOW_AI_RESULT_ENDPOINT = "/api/server/v1/applications/loginflow/result";

private static final String ERROR_CODE_CLIENT = "AILF_10000";
private static final String ERROR_CODE_SERVER = "AILF_10001";

private static final Log LOG = LogFactory.getLog(LoginFlowAIManagerImpl.class);

@Override
public String generateAuthenticationSequence(String userQuery, JSONArray userClaims,
JSONObject availableAuthenticators) throws LoginFlowAIServiceException,
LoginFlowAIClientException {
JSONObject requestBody = new JSONObject();
requestBody.put("user_query", userQuery);
requestBody.put("user_claims", userClaims);
requestBody.put("available_authenticators", availableAuthenticators);

JSONObject response = executeRequest(LOGINFLOW_AI_GENERATE_ENDPOINT, HttpPost.class, requestBody);
return response.getString("operation_id");
}

@Override
public JSONObject getAuthenticationSequenceGenerationStatus(String operationId) throws LoginFlowAIServiceException,
LoginFlowAIClientException {
return executeRequest(LOGINFLOW_AI_STATUS_ENDPOINT + "/" + operationId, HttpGet.class, null);
}

@Override
public JSONObject getAuthenticationSequenceGenerationResult(String operationId) throws LoginFlowAIServiceException,
LoginFlowAIClientException {
return executeRequest(LOGINFLOW_AI_RESULT_ENDPOINT + "/" + operationId, HttpGet.class, null);
}

private JSONObject executeRequest(String endpoint, Class<? extends HttpUriRequest> requestType, JSONObject
requestBody) throws LoginFlowAIServiceException, LoginFlowAIClientException {
String tenantDomain = PrivilegedCarbonContext.getThreadLocalCarbonContext().getTenantDomain();

try (CloseableHttpAsyncClient client = HttpAsyncClients.createDefault()) {
client.start();
String accessToken = LoginFlowAITokenService.getInstance().getAccessToken(false);
String orgName = LoginFlowAITokenService.getInstance().getClientId();

HttpUriRequest request = createRequest(LOGINFLOW_AI_ENDPOINT + "/t/" + orgName + endpoint, requestType,
accessToken, requestBody);
HttpResponseWrapper loginFlowAIServiceResponse = executeRequestWithRetry(client, request);

int statusCode = loginFlowAIServiceResponse.getStatusCode();
String responseBody = loginFlowAIServiceResponse.getResponseBody();

if (statusCode >= 400) {
handleErrorResponse(statusCode, responseBody, tenantDomain);
}
return convertJsonStringToJsonObject(responseBody);
} catch (IOException | InterruptedException | ExecutionException e) {
throw new LoginFlowAIServiceException(ERROR_WHILE_CONNECTING_TO_LOGINFLOW_AI_SERVICE.getMessage(),
ERROR_WHILE_CONNECTING_TO_LOGINFLOW_AI_SERVICE.getCode(), e);
} catch (LoginFlowAITokenServiceException e) {
throw new LoginFlowAIServiceException("Error occurred while retrieving the access token: " + e.getMessage(),
e.getErrorCode(), e);
}
}

private HttpUriRequest createRequest(String url, Class<? extends HttpUriRequest> requestType, String accessToken,
JSONObject requestBody)
throws IOException {
HttpUriRequest request;
if (requestType == HttpPost.class) {
HttpPost post = new HttpPost(url);
if (requestBody != null) {
post.setEntity(new StringEntity(requestBody.toString()));
}
request = post;
} else if (requestType == HttpGet.class) {
request = new HttpGet(url);
} else {
throw new IllegalArgumentException("Unsupported request type: " + requestType.getName());
}

request.setHeader("Authorization", "Bearer " + accessToken);
request.setHeader("Content-Type", "application/json");
return request;
}

private HttpResponseWrapper executeRequestWithRetry(CloseableHttpAsyncClient client, HttpUriRequest request)
throws InterruptedException, ExecutionException, IOException, LoginFlowAITokenServiceException {
HttpResponseWrapper response = HttpClientHelper.executeRequest(client, request);

if (response.getStatusCode() == HttpStatus.SC_UNAUTHORIZED) {
String newAccessToken = LoginFlowAITokenService.getInstance().getAccessToken(true);
if (newAccessToken == null) {
throw new LoginFlowAITokenServiceException("Failed to renew access token.",
ERROR_RETRIEVING_ACCESS_TOKEN.getCode());
}
request.setHeader("Authorization", "Bearer " + newAccessToken);
response = HttpClientHelper.executeRequest(client, request);
}

return response;
}

private void handleErrorResponse(int statusCode, String responseBody, String tenantDomain)
throws LoginFlowAIServiceException, LoginFlowAIClientException {
if (statusCode == HttpStatus.SC_UNAUTHORIZED) {
throw new LoginFlowAIServiceException("Failed to access AI service with renewed access token for " +
"the tenant domain: " + tenantDomain,
UNABLE_TO_ACCESS_AI_SERVICE_WITH_RENEW_ACCESS_TOKEN.getCode());
} else if (statusCode >= 400 && statusCode < 500) {
throw new LoginFlowAIClientException(new HttpResponseWrapper(statusCode, responseBody),
"Client error occurred from tenant: " + tenantDomain + " with status code: '" + statusCode +
"' while generating authentication sequence.", ERROR_CODE_CLIENT);
} else if (statusCode >= 500) {
throw new LoginFlowAIServiceException(new HttpResponseWrapper(statusCode, responseBody),
"Server error occurred from tenant: " + tenantDomain + " with status code: '" + statusCode +
"' while generating authentication sequence.", ERROR_CODE_SERVER);
}
}

private JSONObject convertJsonStringToJsonObject(String jsonString) throws LoginFlowAIServiceException {
try {
return new JSONObject(jsonString);
} catch (JsonSyntaxException | JSONException e) {
throw new LoginFlowAIServiceException("Error occurred while parsing the JSON string: " + e.getMessage(),
ERROR_WHILE_GENERATING_AUTHENTICATION_SEQUENCE.getCode(), e);
}
}

/**
* Wrapper class to hold the HTTP response status code and the response body.
*/
public static class HttpResponseWrapper {
private final int statusCode;
private final String responseBody;

public HttpResponseWrapper(int statusCode, String responseBody) {
this.statusCode = statusCode;
this.responseBody = responseBody;
}

public int getStatusCode() {
return statusCode;
}

public String getResponseBody() {
return responseBody;
}
}

/**
* Helper class to execute HTTP requests.
*/
public static class HttpClientHelper {

public static HttpResponseWrapper executeRequest(CloseableHttpAsyncClient client, HttpUriRequest httpRequest)
throws InterruptedException, ExecutionException, IOException {

Future<HttpResponse> apiResponse = client.execute(httpRequest, new FutureCallback<HttpResponse>() {
@Override
public void completed(HttpResponse response) {

LOG.info("API request completed with status code: " + response.getStatusLine().getStatusCode());
}

@Override
public void failed(Exception e) {

LOG.error("API request failed: " + e.getMessage(), e);
}

@Override
public void cancelled() {

LOG.warn("API request was cancelled");
}
});

HttpResponse httpResponse = apiResponse.get(); // Wait for the response to be available.
int status = httpResponse.getStatusLine().getStatusCode();
String response = EntityUtils.toString(httpResponse.getEntity());
return new HttpResponseWrapper(status, response);
}

}
}
Loading

0 comments on commit ec247a1

Please sign in to comment.