Skip to content

Commit

Permalink
fix(azure): refresh auth token during retries
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertCraigie committed Jul 9, 2024
1 parent 50371bf commit e5341d7
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 8 deletions.
22 changes: 16 additions & 6 deletions src/openai/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,11 @@ def _request(
stream: bool,
stream_cls: type[_StreamT] | None,
) -> ResponseT | _StreamT:
# create a copy of the options we were given so that if the
# options are mutated later & we then retry, the retries are
# given the original options
input_options = model_copy(options)

cast_to = self._maybe_override_cast_to(cast_to, options)
self._prepare_options(options)

Expand All @@ -980,7 +985,7 @@ def _request(

if retries > 0:
return self._retry_request(
options,
input_options,
cast_to,
retries,
stream=stream,
Expand All @@ -995,7 +1000,7 @@ def _request(

if retries > 0:
return self._retry_request(
options,
input_options,
cast_to,
retries,
stream=stream,
Expand Down Expand Up @@ -1024,7 +1029,7 @@ def _request(
if retries > 0 and self._should_retry(err.response):
err.response.close()
return self._retry_request(
options,
input_options,
cast_to,
retries,
err.response.headers,
Expand Down Expand Up @@ -1533,6 +1538,11 @@ async def _request(
# execute it earlier while we are in an async context
self._platform = await asyncify(get_platform)()

# create a copy of the options we were given so that if the
# options are mutated later & we then retry, the retries are
# given the original options
input_options = model_copy(options)

cast_to = self._maybe_override_cast_to(cast_to, options)
await self._prepare_options(options)

Expand All @@ -1555,7 +1565,7 @@ async def _request(

if retries > 0:
return await self._retry_request(
options,
input_options,
cast_to,
retries,
stream=stream,
Expand All @@ -1570,7 +1580,7 @@ async def _request(

if retries > 0:
return await self._retry_request(
options,
input_options,
cast_to,
retries,
stream=stream,
Expand All @@ -1593,7 +1603,7 @@ async def _request(
if retries > 0 and self._should_retry(err.response):
await err.response.aclose()
return await self._retry_request(
options,
input_options,
cast_to,
retries,
err.response.headers,
Expand Down
88 changes: 86 additions & 2 deletions tests/lib/test_azure.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Union
from typing_extensions import Literal
from typing import Union, cast
from typing_extensions import Literal, Protocol

import httpx
import pytest
from respx import MockRouter

from openai._models import FinalRequestOptions
from openai.lib.azure import AzureOpenAI, AsyncAzureOpenAI
Expand All @@ -22,6 +24,10 @@
)


class MockRequestCall(Protocol):
request: httpx.Request


@pytest.mark.parametrize("client", [sync_client, async_client])
def test_implicit_deployment_path(client: Client) -> None:
req = client._build_request(
Expand Down Expand Up @@ -64,3 +70,81 @@ def test_client_copying_override_options(client: Client) -> None:
api_version="2022-05-01",
)
assert copied._custom_query == {"api-version": "2022-05-01"}


@pytest.mark.respx()
def test_client_token_provider_refresh_sync(respx_mock: MockRouter) -> None:
respx_mock.post(
"https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01"
).mock(
side_effect=[
httpx.Response(500, json={"error": "server error"}),
httpx.Response(200, json={"foo": "bar"}),
]
)

counter = 0

def token_provider() -> str:
nonlocal counter

counter += 1

if counter == 1:
return "first"

return "second"

client = AzureOpenAI(
api_version="2024-02-01",
azure_ad_token_provider=token_provider,
azure_endpoint="https://example-resource.azure.openai.com",
)
client.chat.completions.create(messages=[], model="gpt-4")

calls = cast("list[MockRequestCall]", respx_mock.calls)

assert len(calls) == 2

assert calls[0].request.headers.get("Authorization") == "Bearer first"
assert calls[1].request.headers.get("Authorization") == "Bearer second"


@pytest.mark.asyncio
@pytest.mark.respx()
async def test_client_token_provider_refresh_async(respx_mock: MockRouter) -> None:
respx_mock.post(
"https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01"
).mock(
side_effect=[
httpx.Response(500, json={"error": "server error"}),
httpx.Response(200, json={"foo": "bar"}),
]
)

counter = 0

def token_provider() -> str:
nonlocal counter

counter += 1

if counter == 1:
return "first"

return "second"

client = AsyncAzureOpenAI(
api_version="2024-02-01",
azure_ad_token_provider=token_provider,
azure_endpoint="https://example-resource.azure.openai.com",
)

await client.chat.completions.create(messages=[], model="gpt-4")

calls = cast("list[MockRequestCall]", respx_mock.calls)

assert len(calls) == 2

assert calls[0].request.headers.get("Authorization") == "Bearer first"
assert calls[1].request.headers.get("Authorization") == "Bearer second"

0 comments on commit e5341d7

Please sign in to comment.