Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Periodically close open request.Sessions to avoid buggy interaction with Docker Desktop #478

Merged
merged 2 commits into from
Jun 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions openai/api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import platform
import sys
import threading
import time
import warnings
from contextlib import asynccontextmanager
from json import JSONDecodeError
Expand Down Expand Up @@ -32,6 +33,7 @@
from openai.util import ApiType

TIMEOUT_SECS = 600
MAX_SESSION_LIFETIME_SECS = 180
MAX_CONNECTION_RETRIES = 2

# Has one attribute per thread, 'session'.
Expand Down Expand Up @@ -516,6 +518,14 @@ def request_raw(

if not hasattr(_thread_context, "session"):
_thread_context.session = _make_session()
_thread_context.session_create_time = time.time()
elif (
time.time() - getattr(_thread_context, "session_create_time", 0)
>= MAX_SESSION_LIFETIME_SECS
):
_thread_context.session.close()
_thread_context.session = _make_session()
_thread_context.session_create_time = time.time()
try:
result = _thread_context.session.request(
method,
Expand Down
2 changes: 1 addition & 1 deletion openai/openai_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def __repr__(self):

def __str__(self):
obj = self.to_dict_recursive()
return json.dumps(obj, sort_keys=True, indent=2)
return json.dumps(obj, indent=2)

def to_dict(self):
return dict(self)
Expand Down
32 changes: 32 additions & 0 deletions openai/tests/test_api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,35 @@ def test_requestor_azure_ad_headers() -> None:
assert headers["Test_Header"] == "Unit_Test_Header"
assert "Authorization" in headers
assert headers["Authorization"] == "Bearer test_key"


@pytest.mark.requestor
def test_requestor_cycle_sessions(mocker: MockerFixture) -> None:
# HACK: we need to purge the _thread_context to not interfere
# with other tests
from openai.api_requestor import _thread_context

delattr(_thread_context, "session")

api_requestor = APIRequestor(key="test_key", api_type="azure_ad")

mock_session = mocker.MagicMock()
mocker.patch("openai.api_requestor._make_session", lambda: mock_session)

# We don't call `session.close()` if not enough time has elapsed
api_requestor.request_raw("get", "http://example.com")
mock_session.request.assert_called()
api_requestor.request_raw("get", "http://example.com")
mock_session.close.assert_not_called()

mocker.patch("openai.api_requestor.MAX_SESSION_LIFETIME_SECS", 0)

# Due to 0 lifetime, the original session will be closed before the next call
# and a new session will be created
mock_session_2 = mocker.MagicMock()
mocker.patch("openai.api_requestor._make_session", lambda: mock_session_2)
api_requestor.request_raw("get", "http://example.com")
mock_session.close.assert_called()
mock_session_2.request.assert_called()

delattr(_thread_context, "session")
25 changes: 25 additions & 0 deletions openai/tests/test_util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from tempfile import NamedTemporaryFile

import pytest
Expand Down Expand Up @@ -28,3 +29,27 @@ def test_openai_api_key_path_with_malformed_key(api_key_file) -> None:
api_key_file.flush()
with pytest.raises(ValueError, match="Malformed API key"):
util.default_api_key()


def test_key_order_openai_object_rendering() -> None:
sample_response = {
"id": "chatcmpl-7NaPEA6sgX7LnNPyKPbRlsyqLbr5V",
"object": "chat.completion",
"created": 1685855844,
"model": "gpt-3.5-turbo-0301",
"usage": {"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
"choices": [
{
"message": {
"role": "assistant",
"content": "The 2020 World Series was played at Globe Life Field in Arlington, Texas. It was the first time that the World Series was played at a neutral site because of the COVID-19 pandemic.",
},
"finish_reason": "stop",
"index": 0,
}
],
}

oai_object = util.convert_to_openai_object(sample_response)
# The `__str__` method was sorting while dumping to json
assert list(json.loads(str(oai_object)).keys()) == list(sample_response.keys())