Skip to content

Commit

Permalink
Restart pyre server on connectivity failure
Browse files Browse the repository at this point in the history
Summary:
Presently, pyre codenav does not restart the pyre server when the socket connection is deleted from the filesystem, or when the pyre server hangs when `pyre kill` is invoked.

This manifests as a `connections.ConnectionFailure` (there are also two other, less common failures on which we can restart: `asyncio.IncompleteReadError` and `ConnectionError`). Here we add some code to bubble this restartable failure up as an `error_source` to `dispatch_nonblocking_request`, where the client can choose to restart based on the presence and type of the exception raised.

The event is implicitly logged to Scuba (with a more detailed error message, including the stack trace) by the previous diff in the stack.

In later diffs:
 - Refactor `dispatch_nonblocking_request`, use `match` syntax, etc.

Reviewed By: grievejia

Differential Revision: D50738586

fbshipit-source-id: 82754ae757291e6c2357e80ea678e8673a5781de
  • Loading branch information
jasontatton authored and facebook-github-bot committed Oct 31, 2023
1 parent 48da973 commit 61f48bb
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 38 deletions.
19 changes: 14 additions & 5 deletions client/commands/daemon_querier.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,9 @@ async def get_type_errors(
self.socket_path, type_errors_request
)
if isinstance(response, code_navigation_request.ErrorResponse):
return daemon_query.DaemonQueryFailure(response.message)
return daemon_query.DaemonQueryFailure(
response.message, error_source=response.error_source
)
# TODO(T165048078): determine if this error should be kept for code navigation (are we committing to unsafe mode in codenav?)
return [error for error in response.to_errors_response() if error.code != 0]

Expand Down Expand Up @@ -779,7 +781,9 @@ async def get_definition_locations(
definition_request,
)
if isinstance(response, code_navigation_request.ErrorResponse):
return daemon_query.DaemonQueryFailure(response.message)
return daemon_query.DaemonQueryFailure(
response.message, error_source=response.error_source
)
return GetDefinitionLocationsResponse(
source=DaemonQuerierSource.PYRE_DAEMON,
data=[
Expand All @@ -802,7 +806,9 @@ async def get_completions(
self.socket_path, completions_request
)
if isinstance(response, code_navigation_request.ErrorResponse):
return daemon_query.DaemonQueryFailure(response.message)
return daemon_query.DaemonQueryFailure(
response.message, error_source=response.error_source
)
return [
completion_item.to_lsp_completion_item()
for completion_item in response.completions
Expand Down Expand Up @@ -959,8 +965,11 @@ async def get_definition_locations(
return await self.get_definition_locations_from_glean(path, position)
base_results = await self.base_querier.get_definition_locations(path, position)

# If pyre throws an exception or if the definition locations are empty, fall back to glean
if isinstance(base_results, daemon_query.DaemonQueryFailure):
# If pyre throws an exception and might not require restarting due to that exception - then fall back to glean
if (
isinstance(base_results, daemon_query.DaemonQueryFailure)
and base_results.error_source is None
):
LOG.warn(
f"Pyre threw exception: {base_results.error_message} - falling back to glean"
)
Expand Down
7 changes: 5 additions & 2 deletions client/commands/daemon_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
@dataclasses.dataclass(frozen=True)
class DaemonQueryFailure(json_mixins.CamlCaseAndExcludeJsonMixin):
error_message: str
error_source: Optional[Exception] = None


def execute_query(socket_path: Path, query_text: str) -> Response:
Expand All @@ -69,7 +70,8 @@ async def attempt_async_query(
)
if isinstance(response_text, daemon_connection.DaemonConnectionFailure):
return DaemonQueryFailure(
f"In attempt async query with response_text, got DaemonConnectionFailure exception: ({response_text.error_message})"
error_message=f"In attempt async query with response_text, got DaemonConnectionFailure exception: ({response_text.error_message})",
error_source=response_text.error_source,
)
try:
return Response.parse(response_text)
Expand Down Expand Up @@ -136,7 +138,8 @@ async def attempt_async_overlay_type_errors(
)
if isinstance(response_text, daemon_connection.DaemonConnectionFailure):
return DaemonQueryFailure(
f"In attempt async query with response_text, got DaemonConnectionFailure exception: ({response_text.error_message})"
error_message=f"In attempt async query with response_text, got DaemonConnectionFailure exception: ({response_text.error_message})",
error_source=response_text.error_source,
)
try:
return incremental.parse_type_error_response(response_text).errors
Expand Down
68 changes: 46 additions & 22 deletions client/commands/pyre_language_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ async def process_definition_request(
parameters: lsp.DefinitionParameters,
request_id: Union[int, str, None],
activity_key: Optional[Dict[str, object]] = None,
) -> None:
) -> Optional[Exception]:
raise NotImplementedError()

@abc.abstractmethod
Expand All @@ -206,7 +206,7 @@ async def process_completion_request(
parameters: lsp.CompletionParameters,
request_id: Union[int, str, None],
activity_key: Optional[Dict[str, object]] = None,
) -> None:
) -> Optional[Exception]:
raise NotImplementedError()

@abc.abstractmethod
Expand Down Expand Up @@ -740,7 +740,7 @@ async def process_definition_request(
parameters: lsp.DefinitionParameters,
request_id: Union[int, str, None],
activity_key: Optional[Dict[str, object]] = None,
) -> None:
) -> Optional[Exception]:
document_path: Optional[
Path
] = parameters.text_document.document_uri().to_file_path()
Expand All @@ -750,14 +750,15 @@ async def process_definition_request(
)
document_path = document_path.resolve()
if document_path not in self.server_state.opened_documents:
return await lsp.write_json_rpc(
await lsp.write_json_rpc(
self.output_channel,
json_rpc.SuccessResponse(
id=request_id,
activity_key=activity_key,
result=lsp.LspLocation.cached_schema().dump([], many=True),
),
)
return None
daemon_status_before = self.server_state.status_tracker.get_status()
shadow_mode = self.get_language_server_features().definition.is_shadow()
# In shadow mode, we need to return an empty response immediately
Expand All @@ -779,9 +780,11 @@ async def process_definition_request(
if isinstance(result, DaemonQueryFailure):
error_message = result.error_message
output_result = []
error_source = result.error_source
else:
error_message = None
output_result = result
error_source = None
# Unless we are in shadow mode, we send the response as output
if not shadow_mode:
await lsp.write_json_rpc(
Expand Down Expand Up @@ -831,12 +834,14 @@ async def process_definition_request(
activity_key,
)

return error_source

async def process_completion_request(
self,
parameters: lsp.CompletionParameters,
request_id: Union[int, str, None],
activity_key: Optional[Dict[str, object]] = None,
) -> None:
) -> Optional[Exception]:
document_path = parameters.text_document.document_uri().to_file_path()
if document_path is None:
raise json_rpc.InvalidRequestError(
Expand All @@ -845,14 +850,15 @@ async def process_completion_request(
document_path = document_path.resolve()

if document_path not in self.server_state.opened_documents:
return await lsp.write_json_rpc(
await lsp.write_json_rpc(
self.output_channel,
json_rpc.SuccessResponse(
id=request_id,
activity_key=activity_key,
result=[],
),
)
return None
daemon_status_before = self.server_state.status_tracker.get_status()
completion_timer = timer.Timer()

Expand All @@ -870,7 +876,10 @@ async def process_completion_request(
)
)
error_message = result.error_message
error_source = result.error_source
result = []
else:
error_source = None

raw_result = [completion_item.to_dict() for completion_item in result]

Expand Down Expand Up @@ -902,6 +911,8 @@ async def process_completion_request(
activity_key,
)

return error_source

async def process_document_symbols_request(
self,
parameters: lsp.DocumentSymbolsParameters,
Expand Down Expand Up @@ -1409,59 +1420,73 @@ async def wait_for_exit(self) -> commands.ExitCode:
await _wait_for_exit(self.input_channel, self.output_channel)
return commands.ExitCode.SUCCESS

async def _try_restart_pyre_daemon(self) -> None:
async def _restart_if_needed(
self, error_source: Optional[Exception] = None
) -> None:
if (
self.server_state.consecutive_start_failure
< CONSECUTIVE_START_ATTEMPT_THRESHOLD
>= CONSECUTIVE_START_ATTEMPT_THRESHOLD
):
await self.daemon_manager.ensure_task_running()
else:
LOG.info(
"Not restarting Pyre since failed consecutive start attempt limit"
" has been reached."
)
return

if isinstance(
error_source,
(
connections.ConnectionFailure,
asyncio.IncompleteReadError,
ConnectionError,
),
): # do we think the daemon is probably down at this point?
# Terminate any existing daemon processes before starting a new one
LOG.info("Forcing pyre daemon restart...")
await self.daemon_manager.ensure_task_stop() # make sure it's down

# restart if needed
if not self.daemon_manager.is_task_running():
# Just check to ensure that the daemon is running and restart if needed
await self.daemon_manager.ensure_task_running()

async def dispatch_nonblocking_request(self, request: json_rpc.Request) -> None:
if request.method == "exit" or request.method == "shutdown":
raise Exception("Exit and shutdown requests should be blocking")
elif request.method == "textDocument/definition":
await self.api.process_definition_request(
error_source = await self.api.process_definition_request(
lsp.DefinitionParameters.from_json_rpc_parameters(
request.extract_parameters()
),
request.id,
request.activity_key,
)
if not self.daemon_manager.is_task_running():
await self._try_restart_pyre_daemon()
await self._restart_if_needed(error_source)
elif request.method == "textDocument/completion":
LOG.debug("Received 'textDocument/completion' request.")
await self.api.process_completion_request(
error_source = await self.api.process_completion_request(
lsp.CompletionParameters.from_json_rpc_parameters(
request.extract_parameters()
),
request.id,
request.activity_key,
)
if not self.daemon_manager.is_task_running():
await self._try_restart_pyre_daemon()
await self._restart_if_needed(error_source)
elif request.method == "textDocument/didOpen":
await self.api.process_open_request(
lsp.DidOpenTextDocumentParameters.from_json_rpc_parameters(
request.extract_parameters()
),
request.activity_key,
)
if not self.daemon_manager.is_task_running():
await self._try_restart_pyre_daemon()
await self._restart_if_needed()
elif request.method == "textDocument/didChange":
await self.api.process_did_change_request(
lsp.DidChangeTextDocumentParameters.from_json_rpc_parameters(
request.extract_parameters()
)
)
if not self.daemon_manager.is_task_running():
await self._try_restart_pyre_daemon()
await self._restart_if_needed()
elif request.method == "textDocument/didClose":
await self.api.process_close_request(
lsp.DidCloseTextDocumentParameters.from_json_rpc_parameters(
Expand All @@ -1475,8 +1500,7 @@ async def dispatch_nonblocking_request(self, request: json_rpc.Request) -> None:
),
request.activity_key,
)
if not self.daemon_manager.is_task_running():
await self._try_restart_pyre_daemon()
await self._restart_if_needed()
elif request.method == "textDocument/hover":
await self.api.process_hover_request(
lsp.HoverParameters.from_json_rpc_parameters(
Expand Down
5 changes: 3 additions & 2 deletions client/commands/tests/language_server_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ async def process_definition_request(
parameters: lsp.DefinitionParameters,
request_id: Union[int, str, None],
activity_key: Optional[Dict[str, object]] = None,
) -> None:
) -> Optional[Exception]:
await asyncio.Event().wait()
return None

async def process_hover_request(
self,
Expand Down Expand Up @@ -119,7 +120,7 @@ async def process_completion_request(
parameters: lsp.CompletionParameters,
request_id: Union[int, str, None],
activity_key: Optional[Dict[str, object]] = None,
) -> None:
) -> Optional[Exception]:
raise NotImplementedError()

async def process_document_symbols_request(
Expand Down
17 changes: 13 additions & 4 deletions client/language_server/code_navigation_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def to_json(self) -> List[object]:
@dataclasses.dataclass(frozen=True)
class ErrorResponse:
message: str
error_source: Optional[Exception] = None


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -311,7 +312,9 @@ async def async_handle_hover_request(
socket_path, raw_request
)
if isinstance(response, daemon_connection.DaemonConnectionFailure):
return ErrorResponse(message=response.error_message)
return ErrorResponse(
message=response.error_message, error_source=response.error_source
)
response = parse_raw_response(
response, expected_response_kind="Hover", response_type=HoverResponse
)
Expand All @@ -329,7 +332,9 @@ async def async_handle_definition_request(
socket_path, raw_request
)
if isinstance(response, daemon_connection.DaemonConnectionFailure):
return ErrorResponse(message=response.error_message)
return ErrorResponse(
message=response.error_message, error_source=response.error_source
)
return parse_raw_response(
response,
expected_response_kind="LocationOfDefinition",
Expand All @@ -346,7 +351,9 @@ async def async_handle_type_errors_request(
socket_path, raw_request
)
if isinstance(response, daemon_connection.DaemonConnectionFailure):
return ErrorResponse(message=response.error_message)
return ErrorResponse(
message=response.error_message, error_source=response.error_source
)
return parse_raw_response(
response, expected_response_kind="TypeErrors", response_type=TypeErrorsResponse
)
Expand All @@ -361,7 +368,9 @@ async def async_handle_completion_request(
socket_path, raw_request
)
if isinstance(response, daemon_connection.DaemonConnectionFailure):
return ErrorResponse(message=response.error_message)
return ErrorResponse(
message=response.error_message, error_source=response.error_source
)
return parse_raw_response(
response,
expected_response_kind="Completion",
Expand Down
8 changes: 5 additions & 3 deletions client/language_server/daemon_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import logging
import traceback
from pathlib import Path
from typing import AsyncIterator, Union
from typing import AsyncIterator, Optional, Union

from .. import dataclasses_json_extensions as json_mixins, log
from . import connections
Expand All @@ -34,6 +34,7 @@
@dataclasses.dataclass(frozen=True)
class DaemonConnectionFailure(json_mixins.CamlCaseAndExcludeJsonMixin):
error_message: str
error_source: Optional[Exception] = None


def send_raw_request(socket_path: Path, raw_request: str) -> str:
Expand Down Expand Up @@ -101,6 +102,7 @@ async def attempt_send_async_raw_request(
ConnectionError,
) as error:
return DaemonConnectionFailure(
"Could not establish connection with an existing Pyre server "
f"at {socket_path}: {error}. Type: {type(error)}. Stacktrace: {traceback.format_exc( limit = None, chain = True)}"
error_message="Could not establish connection with an existing Pyre server "
f"at {socket_path}: {error}. Type: {type(error)}. Stacktrace: {traceback.format_exc( limit = None, chain = True)}",
error_source=error,
)

0 comments on commit 61f48bb

Please sign in to comment.