From 61f48bb1f103f06f7094af4ccaa21dee6916fda7 Mon Sep 17 00:00:00 2001 From: Jason Tatton Date: Mon, 30 Oct 2023 21:02:02 -0700 Subject: [PATCH] Restart pyre server on connectivity failure Summary: Presently pyre codenav does not restart the pyre server when the socket connection is deleted from the filesystem or when the pyre server hangs when pyre kill is invoked. This is manifests as a `connections.ConnectionFailure` (there are also two other less common failures which we can restart on: `asyncio.IncompleteReadError` and `ConnectionError`). Here we add some code to boil this restartable failure as an `error_source` up to `dispatch_nonblocking_request` - where the client can choose to restart based on the presence of and type of exception raised. Event is implicitly logged to scuba (with more detailed error message including stack trace) by previous diff in stack. In later diffs: - Refactor dispatch_nonblocking_request, use match syntax etc. Reviewed By: grievejia Differential Revision: D50738586 fbshipit-source-id: 82754ae757291e6c2357e80ea678e8673a5781de --- client/commands/daemon_querier.py | 19 ++++-- client/commands/daemon_query.py | 7 +- client/commands/pyre_language_server.py | 68 +++++++++++++------ client/commands/tests/language_server_test.py | 5 +- .../code_navigation_request.py | 17 +++-- client/language_server/daemon_connection.py | 8 ++- 6 files changed, 86 insertions(+), 38 deletions(-) diff --git a/client/commands/daemon_querier.py b/client/commands/daemon_querier.py index 31ab52f3bec..d65af21ca56 100644 --- a/client/commands/daemon_querier.py +++ b/client/commands/daemon_querier.py @@ -733,7 +733,9 @@ async def get_type_errors( self.socket_path, type_errors_request ) if isinstance(response, code_navigation_request.ErrorResponse): - return daemon_query.DaemonQueryFailure(response.message) + return daemon_query.DaemonQueryFailure( + response.message, error_source=response.error_source + ) # TODO(T165048078): determine if this error should be kept for code navigation (are we committing to unsafe mode in codenav?) return [error for error in response.to_errors_response() if error.code != 0] @@ -779,7 +781,9 @@ async def get_definition_locations( definition_request, ) if isinstance(response, code_navigation_request.ErrorResponse): - return daemon_query.DaemonQueryFailure(response.message) + return daemon_query.DaemonQueryFailure( + response.message, error_source=response.error_source + ) return GetDefinitionLocationsResponse( source=DaemonQuerierSource.PYRE_DAEMON, data=[ @@ -802,7 +806,9 @@ async def get_completions( self.socket_path, completions_request ) if isinstance(response, code_navigation_request.ErrorResponse): - return daemon_query.DaemonQueryFailure(response.message) + return daemon_query.DaemonQueryFailure( + response.message, error_source=response.error_source + ) return [ completion_item.to_lsp_completion_item() for completion_item in response.completions @@ -959,8 +965,11 @@ async def get_definition_locations( return await self.get_definition_locations_from_glean(path, position) base_results = await self.base_querier.get_definition_locations(path, position) - # If pyre throws an exception or if the definition locations are empty, fall back to glean - if isinstance(base_results, daemon_query.DaemonQueryFailure): + # If pyre throws an exception and might not require restarting due to that exception - then fall back to glean + if ( + isinstance(base_results, daemon_query.DaemonQueryFailure) + and base_results.error_source is None + ): LOG.warn( f"Pyre threw exception: {base_results.error_message} - falling back to glean" ) diff --git a/client/commands/daemon_query.py b/client/commands/daemon_query.py index e8be8767deb..53fe41c7406 100644 --- a/client/commands/daemon_query.py +++ b/client/commands/daemon_query.py @@ -44,6 +44,7 @@ @dataclasses.dataclass(frozen=True) class DaemonQueryFailure(json_mixins.CamlCaseAndExcludeJsonMixin): error_message: str + error_source: Optional[Exception] = None def execute_query(socket_path: Path, query_text: str) -> Response: @@ -69,7 +70,8 @@ async def attempt_async_query( ) if isinstance(response_text, daemon_connection.DaemonConnectionFailure): return DaemonQueryFailure( - f"In attempt async query with response_text, got DaemonConnectionFailure exception: ({response_text.error_message})" + error_message=f"In attempt async query with response_text, got DaemonConnectionFailure exception: ({response_text.error_message})", + error_source=response_text.error_source, ) try: return Response.parse(response_text) @@ -136,7 +138,8 @@ async def attempt_async_overlay_type_errors( ) if isinstance(response_text, daemon_connection.DaemonConnectionFailure): return DaemonQueryFailure( - f"In attempt async query with response_text, got DaemonConnectionFailure exception: ({response_text.error_message})" + error_message=f"In attempt async query with response_text, got DaemonConnectionFailure exception: ({response_text.error_message})", + error_source=response_text.error_source, ) try: return incremental.parse_type_error_response(response_text).errors diff --git a/client/commands/pyre_language_server.py b/client/commands/pyre_language_server.py index 1169e441ccd..f7ac569749e 100644 --- a/client/commands/pyre_language_server.py +++ b/client/commands/pyre_language_server.py @@ -197,7 +197,7 @@ async def process_definition_request( parameters: lsp.DefinitionParameters, request_id: Union[int, str, None], activity_key: Optional[Dict[str, object]] = None, - ) -> None: + ) -> Optional[Exception]: raise NotImplementedError() @abc.abstractmethod @@ -206,7 +206,7 @@ async def process_completion_request( parameters: lsp.CompletionParameters, request_id: Union[int, str, None], activity_key: Optional[Dict[str, object]] = None, - ) -> None: + ) -> Optional[Exception]: raise NotImplementedError() @abc.abstractmethod @@ -740,7 +740,7 @@ async def process_definition_request( parameters: lsp.DefinitionParameters, request_id: Union[int, str, None], activity_key: Optional[Dict[str, object]] = None, - ) -> None: + ) -> Optional[Exception]: document_path: Optional[ Path ] = parameters.text_document.document_uri().to_file_path() @@ -750,7 +750,7 @@ async def process_definition_request( ) document_path = document_path.resolve() if document_path not in self.server_state.opened_documents: - return await lsp.write_json_rpc( + await lsp.write_json_rpc( self.output_channel, json_rpc.SuccessResponse( id=request_id, @@ -758,6 +758,7 @@ async def process_definition_request( result=lsp.LspLocation.cached_schema().dump([], many=True), ), ) + return None daemon_status_before = self.server_state.status_tracker.get_status() shadow_mode = self.get_language_server_features().definition.is_shadow() # In shadow mode, we need to return an empty response immediately @@ -779,9 +780,11 @@ async def process_definition_request( if isinstance(result, DaemonQueryFailure): error_message = result.error_message output_result = [] + error_source = result.error_source else: error_message = None output_result = result + error_source = None # Unless we are in shadow mode, we send the response as output if not shadow_mode: await lsp.write_json_rpc( @@ -831,12 +834,14 @@ async def process_definition_request( activity_key, ) + return error_source + async def process_completion_request( self, parameters: lsp.CompletionParameters, request_id: Union[int, str, None], activity_key: Optional[Dict[str, object]] = None, - ) -> None: + ) -> Optional[Exception]: document_path = parameters.text_document.document_uri().to_file_path() if document_path is None: raise json_rpc.InvalidRequestError( @@ -845,7 +850,7 @@ async def process_completion_request( document_path = document_path.resolve() if document_path not in self.server_state.opened_documents: - return await lsp.write_json_rpc( + await lsp.write_json_rpc( self.output_channel, json_rpc.SuccessResponse( id=request_id, @@ -853,6 +858,7 @@ async def process_completion_request( result=[], ), ) + return None daemon_status_before = self.server_state.status_tracker.get_status() completion_timer = timer.Timer() @@ -870,7 +876,10 @@ async def process_completion_request( ) ) error_message = result.error_message + error_source = result.error_source result = [] + else: + error_source = None raw_result = [completion_item.to_dict() for completion_item in result] @@ -902,6 +911,8 @@ async def process_completion_request( activity_key, ) + return error_source + async def process_document_symbols_request( self, parameters: lsp.DocumentSymbolsParameters, @@ -1409,42 +1420,58 @@ async def wait_for_exit(self) -> commands.ExitCode: await _wait_for_exit(self.input_channel, self.output_channel) return commands.ExitCode.SUCCESS - async def _try_restart_pyre_daemon(self) -> None: + async def _restart_if_needed( + self, error_source: Optional[Exception] = None + ) -> None: if ( self.server_state.consecutive_start_failure - < CONSECUTIVE_START_ATTEMPT_THRESHOLD + >= CONSECUTIVE_START_ATTEMPT_THRESHOLD ): - await self.daemon_manager.ensure_task_running() - else: LOG.info( "Not restarting Pyre since failed consecutive start attempt limit" " has been reached." ) + return + + if isinstance( + error_source, + ( + connections.ConnectionFailure, + asyncio.IncompleteReadError, + ConnectionError, + ), + ): # do we think the daemon is probably down at this point? + # Terminate any existing daemon processes before starting a new one + LOG.info("Forcing pyre daemon restart...") + await self.daemon_manager.ensure_task_stop() # make sure it's down + + # restart if needed + if not self.daemon_manager.is_task_running(): + # Just check to ensure that the daemon is running and restart if needed + await self.daemon_manager.ensure_task_running() async def dispatch_nonblocking_request(self, request: json_rpc.Request) -> None: if request.method == "exit" or request.method == "shutdown": raise Exception("Exit and shutdown requests should be blocking") elif request.method == "textDocument/definition": - await self.api.process_definition_request( + error_source = await self.api.process_definition_request( lsp.DefinitionParameters.from_json_rpc_parameters( request.extract_parameters() ), request.id, request.activity_key, ) - if not self.daemon_manager.is_task_running(): - await self._try_restart_pyre_daemon() + await self._restart_if_needed(error_source) elif request.method == "textDocument/completion": LOG.debug("Received 'textDocument/completion' request.") - await self.api.process_completion_request( + error_source = await self.api.process_completion_request( lsp.CompletionParameters.from_json_rpc_parameters( request.extract_parameters() ), request.id, request.activity_key, ) - if not self.daemon_manager.is_task_running(): - await self._try_restart_pyre_daemon() + await self._restart_if_needed(error_source) elif request.method == "textDocument/didOpen": await self.api.process_open_request( lsp.DidOpenTextDocumentParameters.from_json_rpc_parameters( @@ -1452,16 +1479,14 @@ async def dispatch_nonblocking_request(self, request: json_rpc.Request) -> None: ), request.activity_key, ) - if not self.daemon_manager.is_task_running(): - await self._try_restart_pyre_daemon() + await self._restart_if_needed() elif request.method == "textDocument/didChange": await self.api.process_did_change_request( lsp.DidChangeTextDocumentParameters.from_json_rpc_parameters( request.extract_parameters() ) ) - if not self.daemon_manager.is_task_running(): - await self._try_restart_pyre_daemon() + await self._restart_if_needed() elif request.method == "textDocument/didClose": await self.api.process_close_request( lsp.DidCloseTextDocumentParameters.from_json_rpc_parameters( @@ -1475,8 +1500,7 @@ async def dispatch_nonblocking_request(self, request: json_rpc.Request) -> None: ), request.activity_key, ) - if not self.daemon_manager.is_task_running(): - await self._try_restart_pyre_daemon() + await self._restart_if_needed() elif request.method == "textDocument/hover": await self.api.process_hover_request( lsp.HoverParameters.from_json_rpc_parameters( diff --git a/client/commands/tests/language_server_test.py b/client/commands/tests/language_server_test.py index a660a6821d2..65f1647f5f9 100644 --- a/client/commands/tests/language_server_test.py +++ b/client/commands/tests/language_server_test.py @@ -44,8 +44,9 @@ async def process_definition_request( parameters: lsp.DefinitionParameters, request_id: Union[int, str, None], activity_key: Optional[Dict[str, object]] = None, - ) -> None: + ) -> Optional[Exception]: await asyncio.Event().wait() + return None async def process_hover_request( self, @@ -119,7 +120,7 @@ async def process_completion_request( parameters: lsp.CompletionParameters, request_id: Union[int, str, None], activity_key: Optional[Dict[str, object]] = None, - ) -> None: + ) -> Optional[Exception]: raise NotImplementedError() async def process_document_symbols_request( diff --git a/client/language_server/code_navigation_request.py b/client/language_server/code_navigation_request.py index 4dc243b9963..3537a5aecac 100644 --- a/client/language_server/code_navigation_request.py +++ b/client/language_server/code_navigation_request.py @@ -78,6 +78,7 @@ def to_json(self) -> List[object]: @dataclasses.dataclass(frozen=True) class ErrorResponse: message: str + error_source: Optional[Exception] = None @dataclasses.dataclass(frozen=True) @@ -311,7 +312,9 @@ async def async_handle_hover_request( socket_path, raw_request ) if isinstance(response, daemon_connection.DaemonConnectionFailure): - return ErrorResponse(message=response.error_message) + return ErrorResponse( + message=response.error_message, error_source=response.error_source + ) response = parse_raw_response( response, expected_response_kind="Hover", response_type=HoverResponse ) @@ -329,7 +332,9 @@ async def async_handle_definition_request( socket_path, raw_request ) if isinstance(response, daemon_connection.DaemonConnectionFailure): - return ErrorResponse(message=response.error_message) + return ErrorResponse( + message=response.error_message, error_source=response.error_source + ) return parse_raw_response( response, expected_response_kind="LocationOfDefinition", @@ -346,7 +351,9 @@ async def async_handle_type_errors_request( socket_path, raw_request ) if isinstance(response, daemon_connection.DaemonConnectionFailure): - return ErrorResponse(message=response.error_message) + return ErrorResponse( + message=response.error_message, error_source=response.error_source + ) return parse_raw_response( response, expected_response_kind="TypeErrors", response_type=TypeErrorsResponse ) @@ -361,7 +368,9 @@ async def async_handle_completion_request( socket_path, raw_request ) if isinstance(response, daemon_connection.DaemonConnectionFailure): - return ErrorResponse(message=response.error_message) + return ErrorResponse( + message=response.error_message, error_source=response.error_source + ) return parse_raw_response( response, expected_response_kind="Completion", diff --git a/client/language_server/daemon_connection.py b/client/language_server/daemon_connection.py index 83094e7a9bd..222014bd575 100644 --- a/client/language_server/daemon_connection.py +++ b/client/language_server/daemon_connection.py @@ -22,7 +22,7 @@ import logging import traceback from pathlib import Path -from typing import AsyncIterator, Union +from typing import AsyncIterator, Optional, Union from .. import dataclasses_json_extensions as json_mixins, log from . import connections @@ -34,6 +34,7 @@ @dataclasses.dataclass(frozen=True) class DaemonConnectionFailure(json_mixins.CamlCaseAndExcludeJsonMixin): error_message: str + error_source: Optional[Exception] = None def send_raw_request(socket_path: Path, raw_request: str) -> str: @@ -101,6 +102,7 @@ async def attempt_send_async_raw_request( ConnectionError, ) as error: return DaemonConnectionFailure( - "Could not establish connection with an existing Pyre server " - f"at {socket_path}: {error}. Type: {type(error)}. Stacktrace: {traceback.format_exc( limit = None, chain = True)}" + error_message="Could not establish connection with an existing Pyre server " + f"at {socket_path}: {error}. Type: {type(error)}. Stacktrace: {traceback.format_exc( limit = None, chain = True)}", + error_source=error, )