Skip to content

Commit

Permalink
Set env vars for tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jlowin committed Sep 12, 2024
1 parent 2764e15 commit 8a507bc
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions tests/llm/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,40 @@
from controlflow.llm.models import get_model


def test_get_model_from_openai():
def test_get_model_from_openai(monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "fake_openai_api_key")
model = get_model("openai/gpt-4o-mini")
assert isinstance(model, ChatOpenAI)
assert model.model_name == "gpt-4o-mini"


def test_get_model_from_anthropic():
def test_get_model_from_anthropic(monkeypatch):
monkeypatch.setenv("ANTHROPIC_API_KEY", "fake_anthropic_api_key")
model = get_model("anthropic/claude-3-haiku-20240307")
assert isinstance(model, ChatAnthropic)
assert model.model == "claude-3-haiku-20240307"


def test_get_azure_openai_model():
def test_get_azure_openai_model(monkeypatch):
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "fake_azure_openai_api_key")
monkeypatch.setenv(
"AZURE_OPENAI_ENDPOINT", "https://fake-endpoint.openai.azure.com"
)
monkeypatch.setenv("OPENAI_API_VERSION", "2024-05-01-preview")
model = get_model("azure-openai/gpt-4")
assert isinstance(model, AzureChatOpenAI)
assert model.deployment_name == "gpt-4"
assert model.model_name == "gpt-4"


def test_get_google_model():
def test_get_google_model(monkeypatch):
monkeypatch.setenv("GOOGLE_API_KEY", "fake_google_api_key")
model = get_model("google/gemini-1.5-pro")
assert isinstance(model, ChatGoogleGenerativeAI)
assert model.model == "models/gemini-1.5-pro"


def test_get_groq_model():
def test_get_groq_model(monkeypatch):
monkeypatch.setenv("GROQ_API_KEY", "fake_groq_api_key")
model = get_model("groq/mixtral-8x7b-32768")
assert isinstance(model, ChatGroq)
assert model.model_name == "mixtral-8x7b-32768"
Expand All @@ -49,7 +58,8 @@ def test_get_model_with_unsupported_provider():
get_model("unsupported/model-name")


def test_get_model_with_temperature():
def test_get_model_with_temperature(monkeypatch):
monkeypatch.setenv("ANTHROPIC_API_KEY", "fake_anthropic_api_key")
model = get_model("anthropic/claude-3-haiku-20240307", temperature=0.7)
assert isinstance(model, ChatAnthropic)
assert model.temperature == 0.7

0 comments on commit 8a507bc

Please sign in to comment.