Skip to content

Commit

Permalink
fix(azure): refresh auth token during retries (#1533)
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertCraigie authored and stainless-app[bot] committed Jul 9, 2024
1 parent 41f682b commit 6b01ba6
Showing 1 changed file with 86 additions and 2 deletions.
88 changes: 86 additions & 2 deletions tests/lib/test_azure.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Union
from typing_extensions import Literal
from typing import Union, cast
from typing_extensions import Literal, Protocol

import httpx
import pytest
from respx import MockRouter

from openai._models import FinalRequestOptions
from openai.lib.azure import AzureOpenAI, AsyncAzureOpenAI
Expand All @@ -22,6 +24,10 @@
)


class MockRequestCall(Protocol):
request: httpx.Request


@pytest.mark.parametrize("client", [sync_client, async_client])
def test_implicit_deployment_path(client: Client) -> None:
req = client._build_request(
Expand Down Expand Up @@ -64,3 +70,81 @@ def test_client_copying_override_options(client: Client) -> None:
api_version="2022-05-01",
)
assert copied._custom_query == {"api-version": "2022-05-01"}


@pytest.mark.respx()
def test_client_token_provider_refresh_sync(respx_mock: MockRouter) -> None:
respx_mock.post(
"https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01"
).mock(
side_effect=[
httpx.Response(500, json={"error": "server error"}),
httpx.Response(200, json={"foo": "bar"}),
]
)

counter = 0

def token_provider() -> str:
nonlocal counter

counter += 1

if counter == 1:
return "first"

return "second"

client = AzureOpenAI(
api_version="2024-02-01",
azure_ad_token_provider=token_provider,
azure_endpoint="https://example-resource.azure.openai.com",
)
client.chat.completions.create(messages=[], model="gpt-4")

calls = cast("list[MockRequestCall]", respx_mock.calls)

assert len(calls) == 2

assert calls[0].request.headers.get("Authorization") == "Bearer first"
assert calls[1].request.headers.get("Authorization") == "Bearer second"


@pytest.mark.asyncio
@pytest.mark.respx()
async def test_client_token_provider_refresh_async(respx_mock: MockRouter) -> None:
respx_mock.post(
"https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01"
).mock(
side_effect=[
httpx.Response(500, json={"error": "server error"}),
httpx.Response(200, json={"foo": "bar"}),
]
)

counter = 0

def token_provider() -> str:
nonlocal counter

counter += 1

if counter == 1:
return "first"

return "second"

client = AsyncAzureOpenAI(
api_version="2024-02-01",
azure_ad_token_provider=token_provider,
azure_endpoint="https://example-resource.azure.openai.com",
)

await client.chat.completions.create(messages=[], model="gpt-4")

calls = cast("list[MockRequestCall]", respx_mock.calls)

assert len(calls) == 2

assert calls[0].request.headers.get("Authorization") == "Bearer first"
assert calls[1].request.headers.get("Authorization") == "Bearer second"

0 comments on commit 6b01ba6

Please sign in to comment.