Support for using 'langchain serve' endpoints (#588)
* Added support for using langchain serve endpoints

* Added mitigation example that is common for Gemini models

* Changed to correct output handling

* Removed unused values

* Returning None instead of empty string

* Validate URI + Exception handling + not using model name

* Clearer error handling

* Added exception

* Suggestion for tests

* sync pyproject.toml & requirements.txt

* no need to instantiate class

* set `name` from endpoint uri

Co-authored-by: Jeffrey Martin <[email protected]>
Signed-off-by: Leon Derczynski <[email protected]>

* remove default name setting code

Co-authored-by: Jeffrey Martin <[email protected]>
Signed-off-by: Leon Derczynski <[email protected]>

* refer to name set by class logic rather than value passed to constructor

Co-authored-by: Jeffrey Martin <[email protected]>
Signed-off-by: Leon Derczynski <[email protected]>

* mv langchain serve testing req to optional segment

* align tests with langchain serve's treatment of name in constructor

---------

Signed-off-by: Leon Derczynski <[email protected]>
Co-authored-by: Leon Derczynski <[email protected]>
Co-authored-by: Jeffrey Martin <[email protected]>
3 people authored Apr 26, 2024
1 parent 79c7649 commit 6455c86
Showing 5 changed files with 158 additions and 2 deletions.
1 change: 1 addition & 0 deletions garak/detectors/mitigation.py
@@ -163,6 +163,7 @@ def __init__(self):
             "I cannot perform the task",
             "As a computer program",
             "I'm sorry, but I'm an AI model",
+            "I'm sorry, I can't respond to that.",
         ]
         super().__init__(substrings)
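For context, detectors in this module work by scanning model outputs for known refusal and mitigation phrases; the hunk above adds one more such phrase, which the commit notes as common in Gemini model output. Below is a minimal, hypothetical sketch of that substring-matching idea; the function name and the case-insensitive handling are illustrative assumptions, not garak's exact StringDetector API.

MITIGATION_SUBSTRINGS = [
    "I cannot perform the task",
    "As a computer program",
    "I'm sorry, but I'm an AI model",
    "I'm sorry, I can't respond to that.",  # phrase added in this commit
]

def contains_mitigation(output: str) -> bool:
    """Hypothetical helper: True if any known mitigation phrase appears."""
    lowered = output.lower()
    return any(s.lower() in lowered for s in MITIGATION_SUBSTRINGS)

print(contains_mitigation("I'm sorry, I can't respond to that."))  # True
print(contains_mitigation("Here is the answer you asked for."))    # False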

94 changes: 94 additions & 0 deletions garak/generators/langchain_serve.py
@@ -0,0 +1,94 @@
import json
import logging
import os
from typing import List, Union
from urllib.parse import urlparse

import requests

from garak import _config
from garak.generators.base import Generator


class LangChainServeLLMGenerator(Generator):
    """Class supporting LangChain Serve LLM interfaces via HTTP POST requests.

    This class facilitates communication with LangChain Serve LLMs through a web API, making it
    possible to use external LLMs that are not directly integrated into the LangChain library.
    It requires an API endpoint set up using LangChain Serve.

    Prompts are sent to the specified LLM via HTTP POST and the generated text is read from the
    response; ensure the API endpoint is correctly set up and accessible. Inherits from garak's
    base Generator class, extending its capabilities to support web-based LLM services.

    The API endpoint is set through the LANGCHAIN_SERVE_URI environment variable, which should
    be the base URI of the LangChain Serve deployment; the 'invoke' endpoint is appended to this
    URI to form the full API endpoint URL.

    Example of setting the environment variable:
        export LANGCHAIN_SERVE_URI=http://127.0.0.1:8000/rag-chroma-private
    """

    generator_family_name = "LangChainServe"
    config_hash = "default"

    def __init__(
        self, name=None, generations=10
    ):  # name not required, will be extracted from uri
        self.generations = generations
        api_uri = os.getenv("LANGCHAIN_SERVE_URI")
        if not self._validate_uri(api_uri):
            raise ValueError("Invalid or unset LANGCHAIN_SERVE_URI endpoint URI")
        self.name = api_uri.split("/")[-1]
        self.fullname = f"LangChain Serve LLM {self.name}"
        self.api_endpoint = f"{api_uri}/invoke"

        super().__init__(self.name, generations=generations)

    @staticmethod
    def _validate_uri(uri):
        """Validates the given URI for correctness."""
        try:
            result = urlparse(uri)
            return all([result.scheme, result.netloc])
        except Exception as e:
            logging.error(f"URL parsing error: {e}")
            return False

    def _call_model(self, prompt: str) -> Union[List[str], str, None]:
        """Makes an HTTP POST request to the LangChain Serve API endpoint to invoke the LLM with a given prompt."""
        headers = {"Content-Type": "application/json", "Accept": "application/json"}
        payload = {"input": prompt, "config": {}, "kwargs": {}}

        try:
            response = requests.post(
                f"{self.api_endpoint}?config_hash={self.config_hash}",
                headers=headers,
                data=json.dumps(payload),
            )
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            if 400 <= response.status_code < 500:
                logging.error(f"Client error for prompt {prompt}: {e}")
                return None
            elif 500 <= response.status_code < 600:
                logging.error(f"Server error for prompt {prompt}: {e}")
                raise
        except requests.exceptions.RequestException as e:
            logging.error(f"Request failed: {e}")
            return None

        try:
            response_data = response.json()
            if "output" not in response_data:
                logging.error(f"No output found in response: {response_data}")
                return None
            return response_data.get("output")
        except json.JSONDecodeError as e:
            logging.error(
                f"Failed to decode JSON from response: {response.text}, error: {e}"
            )
            return None
        except Exception as e:
            logging.error(f"Unexpected error processing response: {e}")
            return None


default_class = "LangChainServeLLMGenerator"
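For orientation, here is a minimal usage sketch for the new generator. It assumes a LangChain Serve deployment is reachable at the example URI from the docstring, and it calls the internal _call_model method directly purely for illustration; in normal use garak drives the generator itself.

import os

# Assumed deployment URI; adjust to your own LangChain Serve instance.
os.environ["LANGCHAIN_SERVE_URI"] = "http://127.0.0.1:8000/rag-chroma-private"

from garak.generators.langchain_serve import LangChainServeLLMGenerator

generator = LangChainServeLLMGenerator()  # name is derived from the URI path
print(generator.name)          # "rag-chroma-private"
print(generator.api_endpoint)  # "http://127.0.0.1:8000/rag-chroma-private/invoke"

# Illustration only: POST one prompt to the /invoke endpoint.
output = generator._call_model("Hello LangChain!")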
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -59,12 +59,13 @@ dependencies = [
     "deepl==1.17.0",
     "fschat>=0.2.36",
     "litellm>=1.33.8",
-    "typing>=3.7,<3.8; python_version<'3.5'"
+    "typing>=3.7,<3.8; python_version<'3.5'",
 ]
 
 [project.optional-dependencies]
 tests = [
     "pytest>=8.0",
+    "requests-mock==1.12.1",
 ]
 lint = [
     "black>=22.3",
3 changes: 2 additions & 1 deletion requirements.txt
@@ -30,5 +30,6 @@ litellm>=1.33.8
 typing>=3.7,<3.8; python_version<'3.5'
 # tests
 pytest>=8.0
+requests-mock==1.12.1
 # lint
-black>=22.3
\ No newline at end of file
+black>=22.3
59 changes: 59 additions & 0 deletions tests/generators/test_langchain_serve.py
@@ -0,0 +1,59 @@
import os
import pytest
import requests_mock

from garak.generators.langchain_serve import LangChainServeLLMGenerator

DEFAULT_GENERATIONS_QTY = 10


@pytest.fixture
def set_env_vars():
    os.environ["LANGCHAIN_SERVE_URI"] = "http://127.0.0.1:8000"
    yield
    os.environ.pop("LANGCHAIN_SERVE_URI", None)


def test_validate_uri():
    assert LangChainServeLLMGenerator._validate_uri("http://127.0.0.1:8000")
    assert not LangChainServeLLMGenerator._validate_uri("bad_uri")


@pytest.mark.usefixtures("set_env_vars")
def test_langchain_serve_generator_initialization():
    generator = LangChainServeLLMGenerator()
    assert generator.name == "127.0.0.1:8000"
    assert generator.generations == DEFAULT_GENERATIONS_QTY
    assert generator.api_endpoint == "http://127.0.0.1:8000/invoke"


@pytest.mark.usefixtures("set_env_vars")
def test_langchain_serve_generation(requests_mock):
    requests_mock.post(
        "http://127.0.0.1:8000/invoke?config_hash=default",
        json={"output": ["Generated text"]},
    )
    generator = LangChainServeLLMGenerator()
    output = generator._call_model("Hello LangChain!")
    assert len(output) == 1
    assert output[0] == "Generated text"


@pytest.mark.usefixtures("set_env_vars")
def test_error_handling(requests_mock):
    requests_mock.post(
        "http://127.0.0.1:8000/invoke?config_hash=default", status_code=500
    )
    generator = LangChainServeLLMGenerator()
    with pytest.raises(Exception):
        generator._call_model("This should raise an error")


@pytest.mark.usefixtures("set_env_vars")
def test_bad_response_handling(requests_mock):
    requests_mock.post(
        "http://127.0.0.1:8000/invoke?config_hash=default", json={}, status_code=200
    )
    generator = LangChainServeLLMGenerator()
    output = generator._call_model("This should not find output")
    assert output is None
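One branch the tests above do not exercise is the json.JSONDecodeError handler in _call_model. A hypothetical extra test in the same style might look like the following; the test name is an assumption, and requests_mock's text argument is used to return a non-JSON body.

@pytest.mark.usefixtures("set_env_vars")
def test_non_json_response_handling(requests_mock):
    requests_mock.post(
        "http://127.0.0.1:8000/invoke?config_hash=default",
        text="not json",  # body that makes response.json() raise
        status_code=200,
    )
    generator = LangChainServeLLMGenerator()
    assert generator._call_model("This should fail JSON decoding") is None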
