[evaluation] ci: Enable mypy (#37615)
* fix(typing): Resolve mypy violations in azure/ai/evaluation/_http_utils.py

* fix(typing): Resolve uses of implicit Optional in type annotations

* fix(typing): Resolve type reassignment in http_utils.py

* style: Run isort

* fix(typing): Fix attempted type reassignment in _f1_score.py

* fix(typing): Use a TypeGuard to allow mypy to narrow types in _common/utils.py
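
    A minimal sketch of the TypeGuard pattern (illustrative names, not
    the actual utils.py code):

        from typing import List, Union

        from typing_extensions import TypeGuard  # typing.TypeGuard on Python 3.10+


        def is_str_list(values: List[Union[str, int]]) -> TypeGuard[List[str]]:
            # Returning True tells mypy the argument may be treated as List[str]
            return all(isinstance(v, str) for v in values)


        def join_strings(values: List[Union[str, int]]) -> str:
            if is_str_list(values):
                return " ".join(values)  # narrowed to List[str] in this branch
            return " ".join(str(v) for v in values)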

* fix(typing): Correct return type of get_harm_severity_level

* fix(typing): Correct return type of _compute_f1_score

* fix(typing): Ensure mypy knows that AsyncHttpPipeline.__enter__ returns Self

* fix(typing): Allow mypy to infer the types of the convenience request methods

    _http_utils.py extensively uses decorators to implement the
    "convenience" request methods (get, post, put, etc.) for
    {Async}HttpPipeline, since they all share a common underlying
    implementation.

    However, neither decorator annotated its return type (the type of the
    decorated function). Initially this was because the accurate type
    couldn't be spelled using a `Callable`, and Pylance still did
    a fine job providing IntelliSense.

    It turns out that it currently isn't possible to spell it with a
    callable typing.Protocol either. Our decorator applies to a method,
    and mypy struggles with the removal of the `self` parameter that
    occurs when a method binds to an object (see python/mypy issue
    #16200).

    This commit resolves this by making the implementation of the http
    pipelines more verbose, removing the decorators and unrolling the
    implementation to each convenience method.

    Using `Unpack[TypedDict]` to annotate kwargs makes this substantially
    more readable, but it causes mypy to complain if unknown keys are
    passed as kwargs (which are useful for request-specific pipeline
    configuration).
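
    A minimal sketch of that trade-off (hypothetical names, not the
    actual _http_utils.py types): annotating **kwargs with Unpack makes
    the accepted keys explicit, but mypy then rejects any key the
    TypedDict doesn't declare.

        from typing import TypedDict

        from typing_extensions import Unpack  # typing.Unpack on Python 3.12+


        class RequestOptions(TypedDict, total=False):
            headers: dict
            timeout: int


        def get(url: str, **kwargs: Unpack[RequestOptions]) -> None:
            """Stand-in for a convenience request method."""


        get("https://example.org", timeout=10)  # OK
        get("https://example.org", retries=3)   # mypy: unexpected keyword argument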

* ci: Enable mypy in CI

* fix(typing): Fix extraneous `total=False` on TypedDict

* fix(typing): Propagate model config type hint upwards

* fix(typing): Ensure that `len` is only called on objects that implement Sized in _tracing.py

* fix(typing): Resolve implicit optional type for Turn

* fix(typing): Resolve missing/inaccurate return types in simulator_data_classes

* fix(typing): Refine the TExperimental type for experimental()

* fix(typing): Ignore the method assign in experimental decorator

* fix(typing): Remove unnecessary optional for _add_method_docstring

* fix(typing): Mark get_token as sync

         The abstract method `get_token` is marked as async, but
         both concrete implementations are sync and every use of it in
         the codebase is in a sync context.
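
          A sketch of the before/after (class names appear in this change
          set; the bodies here are hypothetical):

              from abc import ABC, abstractmethod


              class APITokenManager(ABC):
                  # Was: `async def get_token(self)`, which types the result
                  # as a coroutine even though every caller is synchronous.
                  @abstractmethod
                  def get_token(self) -> str:
                      """Return a valid bearer token."""


              class PlainTokenManager(APITokenManager):
                  def __init__(self, token: str) -> None:
                      self.token: str = token  # plain str, never None

                  def get_token(self) -> str:
                      return self.token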

* fix(typing): Add type hints for APITokenManager attributes

* fix(typing): Prevent type-reassignment in APITokenManager

* refactor: Remove unnecessary pass

* fix(typing): Explicitly list accepted kwargs for derived APITokenManager classes

* fix(typing): Mark PlainTokenManager.token as non-optional str

* fix(typing): Mark *_prompty args as Optional in _simulator.py

* fix: Don't raise bare strings

* fix(typing): Fix return type for _apply_target_to_data

* fix(typing): Use TypedDict as argument to _trace_destination_from_project_scope

* fix(typing): Fix return type of Simulator._complete_conversation

* fix(typing): Correct the param type of _process_column_mappings

* fix(typing): evaluators param Dict[str, Any] -> Dict[str, Callable]

* fix(typing): Add type annotation for processed_config

* fix(typing): Remove unnecessary variable declaration from _evaluate

* fix(typing),refactor: Clarify to mypy that fetch_or_reuse_token always returns str

* fix(typing): Add type annotations for EvalRun attributes

* fix(typing): Use TypedDict for get_rai_svc_url project_scope parameter

* fix(typing): Specify that EvalRun.__enter__ returns Self

* fix(typing): Add type annotation in evaluate_with_rai_service

* fix(typing),refactor: Make EvalRun.info a non-Optional property

* fix(typing): Add a type annotation in log_artifact

* fix(typing): Add missing MLClient import

* fix(typing): Add missing return to EvalRun.__exit__

* fix(typing),refactor: Clarify that _get_evaluator_type always returns str

* fix(typing): Add type annotations in log_evaluate_activity

* fix(typing): QAEvaluator accepts typed dict and returns Dict[str, float]

* fix(typing): Set USER_AGENT to a str when import fails

* fix: Avoid using a dangerous default value

         Using a mutable value as a parameter default is dangerous,
         since mutations will persist across function calls.

         See pylint error code `W0102(dangerous-default-value)`
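
          The failure mode in miniature (hypothetical function, not the
          actual call site):

              from typing import Optional


              def append_row(row: dict, rows: list = []) -> list:  # dangerous
                  rows.append(row)  # the same list object is reused on every call
                  return rows


              append_row({"a": 1})  # [{'a': 1}]
              append_row({"b": 2})  # [{'a': 1}, {'b': 2}] - state leaked across calls


              def append_row_fixed(row: dict, rows: Optional[list] = None) -> list:
                  if rows is None:
                      rows = []  # fresh list per call
                  rows.append(row)
                  return rows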

* fix(typing): Remove unused *args from OpenAIChatCompletionsModel.__init__

* fix(typing): Avoid name-redefinition due to repeat import

* fix(typing): Make EvaluationMetrics an enum

* fix(typing): Use TypedDict for AzureAIProject params

* fix(typing): Type credential as azure.core.credentials.TokenCredential

* fix(typing): Clarify that _log_metrics_and_instance_results returns optional str

* fix(typing), refactor: Add a utility function to validate AzureAIProject dict

* fix(typing): Resolve mismatch with namedtuple type name and variable name

* refactor: Remove unused attribute AdversarialTemplateHandler.cached_templates_source

* fix(typing): Resolve type reassignment in proxy_model_completion

* fix(typing): Add type annotation for ProxyChatCompletionModel.result_url

* fix(typing): Add type annotations to BatchRunContext methods

* fix(typing): Add type annotation for ConversationBot.conversation_starter

* fix(typing): Fix return type of ConversationBot.generate_responses

* fix(typing): Clarify return type of simulate_conversation

* fix(typing): Add type ignore for OpenAICompletionsModel.format_request_data

* fix(typing): Remove unnecessary type annotation in OpenAICompletionsModel.format_request_data

* fix(typing): Clarify that content safety evaluators return Dict[str, Union[str, float]]

* fix(typing): Clarify return type of ContentSafetyChatEvaluator._get_harm_severity_level

* fix(typing): Add type annotations to ContentSafetyChatEvaluator methods

* fix(typing): Add type annotations for ContentSafetyEvaluator

* fix(typing): Use a callable object in AdversarialSimulator

* refactor: Use a set literal for CONTENT_HARM_TEMPLATES_COLLECTION_KEY

* fix(typing): Specify evaluate return type to narrow log_evaluate_activity type

* fix(typing): Add type annotations to adversarial simulator

* fix(typing),refactor: Clarify that _setup_bot's fallthrough branch is unreachable

    _setup_bot does exhaustive matching against all of ConversationRole's
    enum values, so the fallthrough branch can never execute.
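
    A sketch of the pattern, with simplified stand-ins for the real
    classes:

        from enum import Enum


        class ConversationRole(Enum):
            USER = "user"
            ASSISTANT = "assistant"


        def _setup_bot(role: ConversationRole) -> str:
            if role == ConversationRole.USER:
                return "user-bot"
            if role == ConversationRole.ASSISTANT:
                return "assistant-bot"
            # Unreachable: the branches above are exhaustive over the enum,
            # and saying so explicitly lets mypy accept the missing return.
            raise RuntimeError(f"unreachable: {role}")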

* fix(typing): Make SimulationRequestDTO.to_dict non-destructive

* fix(typing): Add type annotations to code_client.py

* fix(typing): Correct Simulator.__call__ task parameter to be List[str]

* fix(typing): evaluators Dict[str, Type] -> Dict[str, Callable]

* fix(typing): Make CodeClient.get_metrics always return a dict

* fix(typing): Add type annotations to evaluate/utils.py

* fix(typing): Clarify that CodeRun.get_aggregated_metrics returns Dict[str, Any]

* fix(typing): data is a required parameter for _evaluate

* fix(typing): Add variable annotations in _evaluate

* fix(typing),refactor: Prevent batch_run_client from being Union[ProxyClient,CodeClient]

    Despite having similar interfaces with compatible calling conventions,
    the fact that ProxyClient and CodeClient have different "run" types
    (ProxyRun and CodeRun) causes type errors when dealing with a
    client of type Union[ProxyClient,CodeClient]. Mypy must consider the case
    when the wrong run type is used for a given client, despite that
    not being possible in this function.

    Refactoring the relevant code into a function allows us to clarify
    to mypy that client and run types are used consistently.
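
    One way to express that idea, as a sketch with simplified signatures
    (the actual change simply moved the paired uses into one function):

        from typing import Protocol, TypeVar

        TRun = TypeVar("TRun")


        class BatchClient(Protocol[TRun]):
            def run(self, data: str) -> TRun: ...
            def get_metrics(self, run: TRun) -> dict: ...


        def run_and_collect(client: BatchClient[TRun], data: str) -> dict:
            # TRun ties the run object back to the client that created it,
            # so mypy sees a consistent client/run pairing in each call.
            run = client.run(data)
            return client.get_metrics(run)

    If the two clients expose structurally matching run/get_metrics
    methods, each call site infers its own TRun and the union never
    arises.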

* fix: Remove unused imports

* fix(pylint): Resolve R1711(useless-return)

* fix(pylint): Resolve W0707(raise-missing-from)

* fix(pylint): Add parameters/returns to http_utils docstrings

* fix(pylint): Make EvaluationMetrics use CaseInsensitiveEnumMeta

* fix: Remove return type annotations for Evaluators

         Promptflow does reflection on type annotations, and only
         accepts a dataclass, TypedDict, or string as a return type
         annotation.

* fix(typing): Add runtime validation of model_config

* fix: Remove type annotations from evaluator/simulators credential param

    Promptflow does reflection on type annotations and only allows
    dict

* fix: Remove type annotations from azure_ai_project param

    Promptflow does reflection on param types and disallows TypedDicts

* fix(typing): {Azure,}OpenAIModelConfiguration.type is NotRequired

* fix(typing): List[Dict] -> list for conversation param

* tests: Fix tests

* fix(typing): Make RaiServiceEvaluatorBase also accept _InternalEvaluationMetrics

* fix(typing): Use typing.final to enforce "never be overridden by children"

* fix(typing): Use abstractmethod to enforce "children must override method"

* fix(typing): Add type annotations to EvaluatorBase

* ci: Add "stringized" to cspell

* fix: Explicitly pass in data to get_evaluators_info

    Resolves a bug where the function was capturing data from the outer
    scope, but data wasn't changed to the appropriate value until
    after the function call.
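
    The bug pattern in miniature (names echo the commit; bodies are
    hypothetical):

        def _evaluate() -> None:
            data = "raw_input.jsonl"

            def get_evaluators_info():
                return {"data": data}  # reads `data` from the enclosing scope

            info = get_evaluators_info()  # runs while `data` is still the raw value
            data = "validated.jsonl"  # too late: `info` already captured the old one

            # Fix: pass data explicitly so the dependency is visible at the call site
            def get_evaluators_info_fixed(data: str) -> dict:
                return {"data": data}

            info = get_evaluators_info_fixed(data)
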
kdestin authored Oct 8, 2024
1 parent 5043dbe commit a852079
Showing 53 changed files with 1,223 additions and 705 deletions.
6 changes: 6 additions & 0 deletions .vscode/cspell.json
@@ -1890,6 +1890,12 @@
         "deidentify",
         "deidentified"
       ]
+    },
+    {
+      "filename": "sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_common/utils.py",
+      "words": [
+        "stringized"
+      ]
     }
   ],
   "allowCompoundWords": true

@@ -27,8 +27,8 @@
 from ._model_configurations import (
     AzureAIProject,
     AzureOpenAIModelConfiguration,
-    OpenAIModelConfiguration,
     EvaluatorConfig,
+    OpenAIModelConfiguration,
 )
 
 __all__ = [

@@ -3,6 +3,8 @@
 # ---------------------------------------------------------
 from enum import Enum
 
+from azure.core import CaseInsensitiveEnumMeta
+
 
 class CommonConstants:
     """Define common constants."""
@@ -43,7 +45,7 @@ class _InternalAnnotationTasks:
     ECI = "eci"
 
 
-class EvaluationMetrics:
+class EvaluationMetrics(str, Enum, metaclass=CaseInsensitiveEnumMeta):
     """Evaluation metrics to aid the RAI service in determining what
     metrics to request, and how to present them back to the user."""
 
@@ -56,7 +58,7 @@ class EvaluationMetrics:
     XPIA = "xpia"
 
 
-class _InternalEvaluationMetrics:
+class _InternalEvaluationMetrics(str, Enum, metaclass=CaseInsensitiveEnumMeta):
     """Evaluation metrics that are not publicly supported.
     These metrics are experimental and subject to potential change or migration to the main
     enum over time.

@@ -3,19 +3,20 @@
 # ---------------------------------------------------------
 import asyncio
 import importlib.metadata
+import math
 import re
 import time
-import math
 from ast import literal_eval
-from typing import Dict, List
+from typing import Dict, List, Optional, Union, cast
 from urllib.parse import urlparse
 
 import jwt
 
 from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
-from azure.ai.evaluation._http_utils import get_async_http_client
+from azure.ai.evaluation._http_utils import AsyncHttpPipeline, get_async_http_client
 from azure.ai.evaluation._model_configurations import AzureAIProject
 from azure.core.credentials import TokenCredential
+from azure.core.pipeline.policies import AsyncRetryPolicy
 
 from .constants import (
     CommonConstants,
@@ -52,7 +53,13 @@ def get_common_headers(token: str) -> Dict:
     }
 
 
-async def ensure_service_availability(rai_svc_url: str, token: str, capability: str = None) -> None:
+def get_async_http_client_with_timeout() -> AsyncHttpPipeline:
+    return get_async_http_client().with_policies(
+        retry_policy=AsyncRetryPolicy(timeout=CommonConstants.DEFAULT_HTTP_TIMEOUT)
+    )
+
+
+async def ensure_service_availability(rai_svc_url: str, token: str, capability: Optional[str] = None) -> None:
     """Check if the Responsible AI service is available in the region and has the required capability, if relevant.
 
     :param rai_svc_url: The Responsible AI service URL.
@@ -67,9 +74,7 @@ async def ensure_service_availability(rai_svc_url: str, token: str, capability:
     svc_liveness_url = rai_svc_url + "/checkannotation"
 
     async with get_async_http_client() as client:
-        response = await client.get(  # pylint: disable=too-many-function-args,unexpected-keyword-arg
-            svc_liveness_url, headers=headers, timeout=CommonConstants.DEFAULT_HTTP_TIMEOUT
-        )
+        response = await client.get(svc_liveness_url, headers=headers)
 
     if response.status_code != 200:
         msg = f"RAI service is not available in this region. Status Code: {response.status_code}"
@@ -153,16 +158,14 @@ async def submit_request(query: str, response: str, metric: str, rai_svc_url: st
     url = rai_svc_url + "/submitannotation"
     headers = get_common_headers(token)
 
-    async with get_async_http_client() as client:
-        response = await client.post(  # pylint: disable=too-many-function-args,unexpected-keyword-arg
-            url, json=payload, headers=headers, timeout=CommonConstants.DEFAULT_HTTP_TIMEOUT
-        )
+    async with get_async_http_client_with_timeout() as client:
+        http_response = await client.post(url, json=payload, headers=headers)
 
-    if response.status_code != 202:
-        print("Fail evaluating '%s' with error message: %s" % (payload["UserTextList"], response.text))
-        response.raise_for_status()
+    if http_response.status_code != 202:
+        print("Fail evaluating '%s' with error message: %s" % (payload["UserTextList"], http_response.text()))
+        http_response.raise_for_status()
 
-    result = response.json()
+    result = http_response.json()
     operation_id = result["location"].split("/")[-1]
     return operation_id
 
@@ -189,10 +192,8 @@ async def fetch_result(operation_id: str, rai_svc_url: str, credential: TokenCre
     token = await fetch_or_reuse_token(credential, token)
     headers = get_common_headers(token)
 
-    async with get_async_http_client() as client:
-        response = await client.get(  # pylint: disable=too-many-function-args,unexpected-keyword-arg
-            url, headers=headers, timeout=CommonConstants.DEFAULT_HTTP_TIMEOUT
-        )
+    async with get_async_http_client_with_timeout() as client:
+        response = await client.get(url, headers=headers)
 
     if response.status_code == 200:
         return response.json()
@@ -208,15 +209,15 @@
 
 def parse_response(  # pylint: disable=too-many-branches,too-many-statements
     batch_response: List[Dict], metric_name: str
-) -> Dict:
+) -> Dict[str, Union[str, float]]:
     """Parse the annotation response from Responsible AI service for a content harm evaluation.
 
     :param batch_response: The annotation response from Responsible AI service.
     :type batch_response: List[Dict]
     :param metric_name: The evaluation metric to use.
     :type metric_name: str
     :return: The parsed annotation result.
-    :rtype: List[List[Dict]]
+    :rtype: Dict[str, Union[str, float]]
     """
     # non-numeric metrics
     if metric_name in {EvaluationMetrics.PROTECTED_MATERIAL, _InternalEvaluationMetrics.ECI, EvaluationMetrics.XPIA}:
@@ -248,23 +249,23 @@ def parse_response(  # pylint: disable=too-many-branches,too-many-statements
     return _parse_content_harm_response(batch_response, metric_name)
 
 
-def _parse_content_harm_response(batch_response: List[Dict], metric_name: str) -> Dict:
+def _parse_content_harm_response(batch_response: List[Dict], metric_name: str) -> Dict[str, Union[str, float]]:
     """Parse the annotation response from Responsible AI service for a content harm evaluation.
 
     :param batch_response: The annotation response from Responsible AI service.
     :type batch_response: List[Dict]
     :param metric_name: The evaluation metric to use.
     :type metric_name: str
     :return: The parsed annotation result.
-    :rtype: List[List[Dict]]
+    :rtype: Dict[str, Union[str, float]]
     """
     # Fix the metric name if it's "hate_fairness"
     # Eventually we will remove this fix once the RAI service is updated
     key = metric_name
     if key == EvaluationMetrics.HATE_FAIRNESS:
         key = EvaluationMetrics.HATE_UNFAIRNESS
 
-    result = {key: math.nan, key + "_score": math.nan, key + "_reason": ""}
+    result: Dict[str, Union[str, float]] = {key: math.nan, key + "_score": math.nan, key + "_reason": ""}
 
     response = batch_response[0]
     if metric_name not in response:
@@ -336,14 +337,13 @@ async def _get_service_discovery_url(azure_ai_project: AzureAIProject, token: st
     """
     headers = get_common_headers(token)
 
-    async with get_async_http_client() as client:
-        response = await client.get(  # pylint: disable=too-many-function-args,unexpected-keyword-arg
+    async with get_async_http_client_with_timeout() as client:
+        response = await client.get(
             f"https://management.azure.com/subscriptions/{azure_ai_project['subscription_id']}/"
             f"resourceGroups/{azure_ai_project['resource_group_name']}/"
             f"providers/Microsoft.MachineLearningServices/workspaces/{azure_ai_project['project_name']}?"
             f"api-version=2023-08-01-preview",
             headers=headers,
-            timeout=CommonConstants.DEFAULT_HTTP_TIMEOUT,
         )
 
     if response.status_code != 200:
@@ -360,7 +360,7 @@ async def _get_service_discovery_url(azure_ai_project: AzureAIProject, token: st
     return f"{base_url.scheme}://{base_url.netloc}"
 
 
-async def get_rai_svc_url(project_scope: dict, token: str) -> str:
+async def get_rai_svc_url(project_scope: AzureAIProject, token: str) -> str:
     """Get the Responsible AI service URL
 
     :param project_scope: The Azure AI project scope details.
@@ -384,7 +384,7 @@ async def get_rai_svc_url(project_scope: dict, token: str) -> str:
     return rai_url
 
 
-async def fetch_or_reuse_token(credential: TokenCredential, token: str = None) -> str:
+async def fetch_or_reuse_token(credential: TokenCredential, token: Optional[str] = None) -> str:
     """Get token. Fetch a new token if the current token is near expiry
 
     :param credential: The Azure authentication credential.
@@ -394,29 +394,26 @@ async def fetch_or_reuse_token(credential: TokenCredential, token: str = None) -
     :type token: str
     :return: The Azure authentication token.
     """
-    acquire_new_token = True
-    try:
-        if token:
-            # Decode the token to get its expiration time
+    if token:
+        # Decode the token to get its expiration time
+        try:
             decoded_token = jwt.decode(token, options={"verify_signature": False})
+        except jwt.PyJWTError:
+            pass
+        else:
             exp_time = decoded_token["exp"]
             current_time = time.time()
 
-            # Check if the token is near expiry
+            # Return current token if not near expiry
             if (exp_time - current_time) >= 300:
-                acquire_new_token = False
-    except Exception:  # pylint: disable=broad-exception-caught
-        pass
-
-    if acquire_new_token:
-        token = credential.get_token("https://management.azure.com/.default").token
+                return token
 
-    return token
+    return credential.get_token("https://management.azure.com/.default").token
 
 
 async def evaluate_with_rai_service(
     query: str, response: str, metric_name: str, project_scope: AzureAIProject, credential: TokenCredential
-):
+) -> Dict[str, Union[str, float]]:
     """ "Evaluate the content safety of the response using Responsible AI service
 
     :param query: The query to evaluate.
@@ -431,7 +428,7 @@ async def evaluate_with_rai_service(
     :type credential:
         ~azure.core.credentials.TokenCredential
     :return: The parsed annotation result.
-    :rtype: List[List[Dict]]
+    :rtype: Dict[str, Union[str, float]]
     """
 
     # Get RAI service URL from discovery service and check service availability
@@ -441,7 +438,7 @@ async def evaluate_with_rai_service(
 
     # Submit annotation request and fetch result
     operation_id = await submit_request(query, response, metric_name, rai_svc_url, token)
-    annotation_response = await fetch_result(operation_id, rai_svc_url, credential, token)
+    annotation_response = cast(List[Dict], await fetch_result(operation_id, rai_svc_url, credential, token))
     result = parse_response(annotation_response, metric_name)
 
     return result

0 comments on commit a852079

Please sign in to comment.