forked from ebraminio/aiosseclient
-
Notifications
You must be signed in to change notification settings - Fork 0
/
aiosseclient.py
154 lines (133 loc) · 4.8 KB
/
aiosseclient.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
'''Main module'''
import asyncio
import logging
import re
from typing import (
    AsyncGenerator,
    Final,
    List,
    Optional,
)

import aiohttp
# pylint: disable=too-many-arguments, dangerous-default-value, redefined-builtin
_SSE_LINE_PATTERN: Final[re.Pattern] = re.compile(r'(?P<name>[^:]*):?( ?(?P<value>.*))?')
# The SSE format only recognises CRLF, LF and CR as line terminators.
# ``str.splitlines`` would also split on U+2028, form feed, etc., corrupting
# payloads that legitimately contain those characters.
_SSE_EOL_PATTERN: Final[re.Pattern] = re.compile(r'\r\n|\r|\n')
_LOGGER = logging.getLogger(__name__)


# Good parts of the class below are adapted from:
# https://github.com/btubbs/sseclient/blob/db38dc6/sseclient.py
class Event:
    '''The object created as the result of received events.

    Attributes:
        data: The event payload; multiple ``data`` fields are joined with
            newlines.
        event: The event type; ``'message'`` unless an ``event`` field
            overrides it.
        id: The value of the last ``id`` field, or ``None``.
        retry: The reconnection time from a ``retry`` field as an integer
            (milliseconds per the SSE specification), or ``None``.
    '''
    data: str
    event: str
    id: Optional[str]
    retry: Optional[int]

    def __init__(
        self,
        data: str = '',
        event: str = 'message',
        id: Optional[str] = None,
        retry: Optional[int] = None,
    ):
        self.data = data
        self.event = event
        self.id = id
        self.retry = retry

    def dump(self) -> str:
        '''Serialize the event object to an SSE wire-format string.

        The result ends with an empty line, which terminates the event.
        '''
        lines = []
        if self.id:
            lines.append(f'id: {self.id}')
        # Only include an event line if it's not the default already.
        if self.event != 'message':
            lines.append(f'event: {self.event}')
        # 0 is a valid reconnection time, so compare with None rather than
        # relying on truthiness.
        if self.retry is not None:
            lines.append(f'retry: {self.retry}')
        lines.extend(f'data: {d}' for d in self.data.split('\n'))
        return '\n'.join(lines) + '\n\n'

    def encode(self) -> bytes:
        '''Serialize the event object to UTF-8 encoded bytes.'''
        return self.dump().encode('utf-8')

    @classmethod
    def parse(cls, raw: str) -> 'Event':
        '''Parse a possibly multi-line SSE message and return an Event.

        Lines are split only on CRLF, LF or CR. Comment lines (starting with
        ``:``) and unknown fields are ignored. Every ``data`` field adds one
        line to ``data`` (including empty ones). A ``retry`` field is kept
        only when its value consists solely of ASCII digits; otherwise it is
        ignored, as the SSE specification requires.
        '''
        msg = cls()
        data_lines: List[str] = []
        for line in _SSE_EOL_PATTERN.split(raw):
            m = _SSE_LINE_PATTERN.match(line)
            if m is None:
                # Malformed line. Discard but warn.
                _LOGGER.warning('Invalid SSE line: %s', line)
                continue
            name = m.group('name')
            if name == '':
                # Line began with a ':' (or was empty): a comment. Ignore.
                continue
            value = m.group('value') or ''
            if name == 'data':
                # Collect every data line so that leading empty ones are kept
                # ("data:\ndata: x" must yield "\nx").
                data_lines.append(value)
            elif name == 'event':
                msg.event = value
            elif name == 'id':
                msg.id = value
            elif name == 'retry':
                # str.isdigit alone would accept non-ASCII digits like '²'.
                if value.isascii() and value.isdigit():
                    msg.retry = int(value)
        if data_lines:
            msg.data = '\n'.join(data_lines)
        return msg

    def __str__(self) -> str:
        return self.data
# pylint: disable=too-many-arguments
async def aiosseclient(
    url: str,
    last_id: Optional[str] = None,
    valid_http_codes: Optional[List[int]] = None,
    exit_events: Optional[List[str]] = None,
    timeout_total: Optional[float] = None,
    headers: Optional[dict[str, str]] = None,
) -> AsyncGenerator[Event, None]:
    '''Canonical API of the library: stream Server-Sent Events from ``url``.

    Args:
        url: The SSE endpoint to request.
        last_id: If truthy, sent as the ``Last-Event-ID`` request header.
        valid_http_codes: Response status codes accepted as success; defaults
            to ``[200, 301, 307]``. On any other status an error is logged and
            the generator finishes without yielding anything.
        exit_events: Event types that end the stream. A matching event is
            still yielded, then the connection is closed and the generator
            returns. Defaults to no exit events.
        timeout_total: Total timeout in seconds passed to
            ``aiohttp.ClientTimeout`` (``None`` disables it); the connect and
            socket-read timeouts are always disabled.
        headers: Extra request headers. The mapping is copied; the caller's
            dict is not modified.

    Yields:
        One ``Event`` per empty-line-terminated block of lines, except blocks
        starting with the ``:ok`` comment line.
    '''
    if valid_http_codes is None:
        valid_http_codes = [200, 301, 307]
    if exit_events is None:
        exit_events = []
    # Work on a copy so the caller's dict is never mutated.
    request_headers = dict(headers) if headers else {}
    # The SSE spec requires making requests with Cache-Control: nocache
    request_headers['Cache-Control'] = 'no-cache'
    # The 'Accept' header is not required, but explicit > implicit
    request_headers['Accept'] = 'text/event-stream'
    if last_id:
        request_headers['Last-Event-ID'] = last_id
    # Override default timeout of 5 minutes; connect/read timeouts are
    # disabled so a long-lived, idle stream is not cut off.
    timeout = aiohttp.ClientTimeout(total=timeout_total, connect=None,
                                    sock_connect=None, sock_read=None)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        response = None
        try:
            _LOGGER.info('Session created: %s', session)
            response = await session.get(url, headers=request_headers)
            if response.status not in valid_http_codes:
                _LOGGER.error('Invalid HTTP response.status: %s', response.status)
                # Nothing to stream; ``finally`` releases response and session
                # instead of reading from an already-closed connection.
                return
            lines: List[str] = []
            async for raw_line in response.content:
                line = raw_line.decode('utf8')
                if line not in {'\n', '\r', '\r\n'}:
                    lines.append(line)
                    continue
                # An empty line terminates the current block.
                if not lines:
                    continue
                # ':ok' blocks are skipped; rstrip makes this CRLF-tolerant.
                if lines[0].rstrip('\r\n') == ':ok':
                    lines = []
                    continue
                current_event = Event.parse(''.join(lines))
                lines = []
                yield current_event
                if current_event.event in exit_events:
                    # Stop instead of looping over a closed connection;
                    # cleanup happens in ``finally``.
                    return
        # aiohttp raises asyncio.TimeoutError, which is only the builtin
        # TimeoutError from Python 3.11 on.
        except asyncio.TimeoutError as sseerr:
            _LOGGER.error('TimeoutError: %s', sseerr)
        finally:
            if response is not None:
                response.close()
            if not session.closed:
                await session.close()