Skip to content

Commit

Permalink
Raise ConnectionAbortedError in Client if server closes connection
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbaynham committed May 26, 2020
1 parent 1239b5d commit a0de059
Showing 1 changed file with 24 additions and 4 deletions.
28 changes: 24 additions & 4 deletions sipyco/pc_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from sipyco.asyncio_tools import AsyncioServer as _AsyncioServer
from sipyco.packed_exceptions import *

CONNECTION_CLOSED_ERR = ConnectionAbortedError("Connection closed by the server")

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -74,6 +76,10 @@ class Client:
automatically attempted. The user must call :meth:`~sipyco.pc_rpc.Client.close_rpc` to
free resources properly after initialization completes successfully.
If the remote server shuts down during operation, ConnectionAbortedError is
raised by Client methods. The user should call
:meth:`~sipyco.pc_rpc.Client.close_rpc` and then discard this object.
:param host: Identifier of the server. The string can represent a
hostname or a IPv4 or IPv6 address (see
``socket.create_connection`` in the Python standard library).
Expand All @@ -93,9 +99,12 @@ class Client:
in the middle of a RPC can break subsequent RPCs (from the same
client).
"""

def __init__(self, host, port, target_name=AutoTarget, timeout=None):
self.__socket = socket.create_connection((host, port), timeout)

self.__closed = False

try:
self.__socket.sendall(_init_string)

Expand Down Expand Up @@ -144,12 +153,17 @@ def __send(self, obj):
self.__socket.sendall(line.encode())

def __recv(self):
if self.__closed:
raise CONNECTION_CLOSED_ERR
buf = self.__socket.recv(4096).decode()
while "\n" not in buf:
more = self.__socket.recv(4096)
if not more:
break
buf += more.decode()
if not buf:
self.__closed = True
raise CONNECTION_CLOSED_ERR
return pyon.decode(buf)

def __do_action(self, action):
Expand All @@ -174,6 +188,7 @@ def get_rpc_method_list(self):
def __getattr__(self, name):
if name not in self.__valid_methods:
raise AttributeError

def proxy(*args, **kwargs):
return self.__do_rpc(name, args, kwargs)
return proxy
Expand All @@ -188,6 +203,7 @@ class AsyncioClient:
Concurrent access from different asyncio tasks is supported; all calls
use a single lock.
"""

def __init__(self):
self.__lock = asyncio.Lock()
self.__reader = None
Expand Down Expand Up @@ -277,6 +293,7 @@ async def __do_rpc(self, name, args, kwargs):
def __getattr__(self, name):
if name not in self.__valid_methods:
raise AttributeError

async def proxy(*args, **kwargs):
res = await self.__do_rpc(name, args, kwargs)
return res
Expand All @@ -296,6 +313,7 @@ class BestEffortClient:
:param retry: Amount of time to wait between retries when reconnecting
in the background.
"""

def __init__(self, host, port, target_name,
firstcon_timeout=1.0, retry=5.0):
self.__host = host
Expand Down Expand Up @@ -407,6 +425,7 @@ def __do_rpc(self, name, args, kwargs):
def __getattr__(self, name):
if name not in self.__valid_methods:
raise AttributeError

def proxy(*args, **kwargs):
return self.__do_rpc(name, args, kwargs)
return proxy
Expand Down Expand Up @@ -473,6 +492,7 @@ class Server(_AsyncioServer):
:param allow_parallel: Allow concurrent asyncio calls to the target's
methods.
"""

def __init__(self, targets, description=None, builtin_terminate=False,
allow_parallel=False):
_AsyncioServer.__init__(self)
Expand All @@ -490,12 +510,12 @@ def __init__(self, targets, description=None, builtin_terminate=False,
def _document_function(function):
"""
Turn a function into a tuple of its arguments and documentation.
Allows remote inspection of what methods are available on a local device.
Args:
function (Callable): a Python function to be documented.
Returns:
Tuple[dict, str]: tuple of (argument specifications,
function documentation).
Expand Down Expand Up @@ -603,7 +623,7 @@ async def _handle_connection_cr(self, reader, writer):
if not line:
break
reply = await self._process_and_pyonize(target,
pyon.decode(line.decode()))
pyon.decode(line.decode()))
writer.write((reply + "\n").encode())
except (ConnectionResetError, ConnectionAbortedError, BrokenPipeError):
# May happens on Windows when client disconnects
Expand Down

0 comments on commit a0de059

Please sign in to comment.