diff --git a/source/analysis/test/integration/typeVariableTest.ml b/source/analysis/test/integration/typeVariableTest.ml index db87a82229..a167ef55c4 100644 --- a/source/analysis/test/integration/typeVariableTest.ml +++ b/source/analysis/test/integration/typeVariableTest.ml @@ -57,8 +57,6 @@ let test_type_variable_scoping = |} []; - (* TODO migeedz: Why do we need to express the domain of Callable as a list for this to - typecheck? *) labeled_test_case __FUNCTION__ __LINE__ @@ assert_type_errors {| @@ -68,6 +66,39 @@ let test_type_variable_scoping = ... |} []; + labeled_test_case __FUNCTION__ __LINE__ + @@ assert_type_errors + {| + from typing import Callable, Awaitable + + def outer[**TParams, TReturn]( + inner: Callable[TParams, Awaitable[TReturn]], + ) -> Callable[TParams, Awaitable[TReturn]]: + async def _func( + *args: TParams.args, **kwargs: TParams.kwargs + ) -> TReturn: + return await inner(*args, **kwargs) + return _func + |} + []; + labeled_test_case __FUNCTION__ __LINE__ + @@ assert_type_errors + {| + from typing import Callable, Awaitable, ParamSpec, TypeVar + + TParams = ParamSpec("TParams") + TReturn = TypeVar("TReturn") + + def outer( + inner: Callable[TParams, Awaitable[TReturn]], + ) -> Callable[TParams, Awaitable[TReturn]]: + async def _func( + *args: TParams.args, **kwargs: TParams.kwargs + ) -> TReturn: + return await inner(*args, **kwargs) + return _func + |} + []; (* PEP695 generic methods from non-generic classes *) labeled_test_case __FUNCTION__ __LINE__ @@ assert_type_errors diff --git a/source/analysis/typeCheck.ml b/source/analysis/typeCheck.ml index 9c43ccd707..0bccfc9a63 100644 --- a/source/analysis/typeCheck.ml +++ b/source/analysis/typeCheck.ml @@ -481,7 +481,6 @@ module State (Context : Context) = struct let get_type_params_as_variables type_params global_resolution = let create_type = GlobalResolution.parse_annotation global_resolution in - (* TODO migeedz: why does parse_annotation return Top for types of the form A[int]? . *) let validate_bound bound = match bound.Node.value with | Expression.Tuple elements -> @@ -6141,6 +6140,15 @@ module State (Context : Context) = struct in Value resolution, validate_return expression ~resolution ~errors ~actual ~is_implicit | Define { signature = { Define.Signature.name; parent; type_params; _ } as signature; _ } -> + let type_params, type_params_errors = + get_type_params_as_variables type_params global_resolution + in + let resolution = + type_params + |> List.fold ~init:resolution ~f:(fun resolution variable -> + Resolution.add_type_variable resolution ~variable) + in + let resolution = match parent with | NestingContext.Function _ -> @@ -6152,7 +6160,6 @@ module State (Context : Context) = struct Resolution.new_local resolution ~reference:name ~type_info:annotation | _ -> resolution in - let _, type_params_errors = get_type_params_as_variables type_params global_resolution in Value resolution, type_params_errors | Import { Import.from; imports } -> let get_export_kind = function @@ -6990,7 +6997,15 @@ module State (Context : Context) = struct |> add_typeguard_error return_annotation return_type in let add_capture_annotations ~outer_scope_type_variables resolution errors = - let process_signature ({ Define.Signature.parent; _ } as signature) = + let process_signature ({ Define.Signature.parent; type_params; _ } as signature) = + let type_params, _ = get_type_params_as_variables type_params global_resolution in + + let resolution = + type_params + |> List.fold ~init:resolution ~f:(fun resolution variable -> + Resolution.add_type_variable resolution ~variable) + in + match parent with | NestingContext.Function _ -> type_of_signature ~module_name:Context.qualifier ~resolution signature @@ -8060,7 +8075,12 @@ module State (Context : Context) = struct >>| List.map ~f:extract |> Option.value ~default:[] in - let type_variables_of_define signature_of_nesting_function = + let type_variables_of_define + ({ Define.Signature.type_params; _ } as signature_of_nesting_function) + = + let local_scope_function_type_params, _ = + get_type_params_as_variables type_params global_resolution + in let parser = GlobalResolution.nonvalidating_annotation_parser global_resolution in let generic_parameters_as_variables = GlobalResolution.generic_parameters_as_variables global_resolution @@ -8074,6 +8094,7 @@ module State (Context : Context) = struct |> Type.Variable.all_free_variables |> List.dedup_and_sort ~compare:Type.Variable.compare in + let define_variables = define_variables @ local_scope_function_type_params in let containing_class_variables = (* PEP484 specifies that scope of the type variables of the outer class doesn't cover the inner one. We are able to inspect only 1 level of nesting class as a result. *)