Support for using 'langchain serve' endpoints (#588)
* Added support for using LangChain Serve endpoints
* Added mitigation example that is common for Gemini models
* Changed to correct output handling
* Removed unused values
* Return None instead of an empty string
* Validate URI, add exception handling, stop using the model name
* Clearer error handling
* Added exception
* Suggestion for tests
* Synced pyproject.toml & requirements.txt
* No need to instantiate class
* Set `name` from the endpoint URI (co-authored with Jeffrey Martin)
* Removed default name-setting code (co-authored with Jeffrey Martin)
* Refer to the name set by class logic rather than the value passed to the constructor (co-authored with Jeffrey Martin)
* Moved the LangChain Serve testing requirement to the optional segment
* Aligned tests with LangChain Serve's treatment of `name` in the constructor

Signed-off-by: Leon Derczynski <[email protected]>
Co-authored-by: Leon Derczynski <[email protected]>
Co-authored-by: Jeffrey Martin <[email protected]>
Commit 6455c86 (1 parent: 79c7649)
Showing 5 changed files with 158 additions and 2 deletions.
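For orientation before the diff: the generator added here speaks LangChain Serve's `/invoke` contract, POSTing an "input" field and reading "output" from the JSON reply. Below is a minimal sketch of that exchange with plain requests, assuming a deployment is listening at the example URI from the generator's docstring:

import requests

# Assumed deployment URI; taken from the example in the generator docstring below.
BASE_URI = "http://127.0.0.1:8000/rag-chroma-private"

# LangChain Serve mounts POST <base>/invoke, which accepts {"input": ...}
# plus optional "config" and "kwargs", and replies with {"output": ...}.
resp = requests.post(
    f"{BASE_URI}/invoke",
    json={"input": "Hello LangChain!", "config": {}, "kwargs": {}},
    headers={"Accept": "application/json"},
    timeout=30,
)
resp.raise_for_status()
print(resp.json()["output"])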
garak/generators/langchain_serve.py (new file, 94 lines):
import json
import logging
import os
from urllib.parse import urlparse

import requests

from garak import _config
from garak.generators.base import Generator


class LangChainServeLLMGenerator(Generator):
    """Class supporting LangChain Serve LLM interfaces via HTTP POST requests.

    This class facilitates communication with LangChain Serve LLMs through a web API,
    making it possible to use external LLMs that are not directly integrated into the
    LangChain library. It requires an API endpoint set up with LangChain Serve.

    Prompts are sent to the specified LLM via HTTP POST, and the generated text is read
    from the response. The API endpoint must be correctly set up and accessible.

    Inherits from garak's base Generator class, extending its capabilities to support
    web-based LLM services. The endpoint is set through the LANGCHAIN_SERVE_URI
    environment variable, which should be the base URI of the LangChain Serve
    deployment; the 'invoke' route is appended to it to form the full endpoint URL.

    Example of setting up the environment variable:
        export LANGCHAIN_SERVE_URI=http://127.0.0.1:8000/rag-chroma-private
    """

    generator_family_name = "LangChainServe"
    config_hash = "default"

    def __init__(self, name=None, generations=10):
        # `name` is not required; it is derived from the endpoint URI
        self.generations = generations
        api_uri = os.getenv("LANGCHAIN_SERVE_URI")
        if not self._validate_uri(api_uri):
            raise ValueError("Invalid API endpoint URI")
        self.name = api_uri.split("/")[-1]  # last path segment names the deployment
        self.fullname = f"LangChain Serve LLM {self.name}"
        self.api_endpoint = f"{api_uri}/invoke"

        super().__init__(self.name, generations=generations)

    @staticmethod
    def _validate_uri(uri):
        """Validates the given URI for correctness."""
        try:
            result = urlparse(uri)
            return all([result.scheme, result.netloc])
        except Exception as e:
            logging.error(f"URL parsing error: {e}")
            return False

    def _call_model(self, prompt: str) -> str:
        """Makes an HTTP POST request to the LangChain Serve API endpoint to invoke the LLM with a given prompt."""
        headers = {"Content-Type": "application/json", "Accept": "application/json"}
        payload = {"input": prompt, "config": {}, "kwargs": {}}

        try:
            response = requests.post(
                f"{self.api_endpoint}?config_hash={self.config_hash}",
                headers=headers,
                data=json.dumps(payload),
            )
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            # Client errors (4xx) are logged and skipped; server errors (5xx) propagate
            if 400 <= response.status_code < 500:
                logging.error(f"Client error for prompt {prompt}: {e}")
                return None
            elif 500 <= response.status_code < 600:
                logging.error(f"Server error for prompt {prompt}: {e}")
                raise
        except requests.exceptions.RequestException as e:
            logging.error(f"Request failed: {e}")
            return None

        try:
            response_data = response.json()
            if "output" not in response_data:
                logging.error(f"No output found in response: {response_data}")
                return None
            return response_data.get("output")
        except json.JSONDecodeError as e:
            logging.error(
                f"Failed to decode JSON from response: {response.text}, error: {e}"
            )
            return None
        except Exception as e:
            logging.error(f"Unexpected error processing response: {e}")
            return None


default_class = "LangChainServeLLMGenerator"
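A quick way to exercise the class by hand — a sketch only, assuming a LangChain Serve deployment at the URI below; in a real run garak's harness drives the generator rather than `_call_model` being invoked directly:

import os

from garak.generators.langchain_serve import LangChainServeLLMGenerator

# Assumed local deployment; the last path segment becomes the generator's name.
os.environ["LANGCHAIN_SERVE_URI"] = "http://127.0.0.1:8000/rag-chroma-private"

generator = LangChainServeLLMGenerator(generations=1)
print(generator.fullname)  # LangChain Serve LLM rag-chroma-private
print(generator._call_model("Say hello."))  # internal entry point, as in the tests below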
Tests for the new generator (new file, 59 lines):
import os

import pytest
import requests_mock  # its pytest plugin supplies the requests_mock fixture below

from garak.generators.langchain_serve import LangChainServeLLMGenerator

DEFAULT_GENERATIONS_QTY = 10


@pytest.fixture
def set_env_vars():
    """Point the generator at a local endpoint for the duration of a test."""
    os.environ["LANGCHAIN_SERVE_URI"] = "http://127.0.0.1:8000"
    yield
    os.environ.pop("LANGCHAIN_SERVE_URI", None)


def test_validate_uri():
    assert LangChainServeLLMGenerator._validate_uri("http://127.0.0.1:8000")
    assert not LangChainServeLLMGenerator._validate_uri("bad_uri")


@pytest.mark.usefixtures("set_env_vars")
def test_langchain_serve_generator_initialization():
    generator = LangChainServeLLMGenerator()
    assert generator.name == "127.0.0.1:8000"  # name is the last URI path segment
    assert generator.generations == DEFAULT_GENERATIONS_QTY
    assert generator.api_endpoint == "http://127.0.0.1:8000/invoke"


@pytest.mark.usefixtures("set_env_vars")
def test_langchain_serve_generation(requests_mock):
    requests_mock.post(
        "http://127.0.0.1:8000/invoke?config_hash=default",
        json={"output": ["Generated text"]},
    )
    generator = LangChainServeLLMGenerator()
    output = generator._call_model("Hello LangChain!")
    assert len(output) == 1
    assert output[0] == "Generated text"


@pytest.mark.usefixtures("set_env_vars")
def test_error_handling(requests_mock):
    # Server errors (5xx) should propagate as exceptions
    requests_mock.post(
        "http://127.0.0.1:8000/invoke?config_hash=default", status_code=500
    )
    generator = LangChainServeLLMGenerator()
    with pytest.raises(Exception):
        generator._call_model("This should raise an error")


@pytest.mark.usefixtures("set_env_vars")
def test_bad_response_handling(requests_mock):
    # A 200 response without an "output" key yields None rather than an exception
    requests_mock.post(
        "http://127.0.0.1:8000/invoke?config_hash=default", json={}, status_code=200
    )
    generator = LangChainServeLLMGenerator()
    output = generator._call_model("This should not find output")
    assert output is None
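For completeness, a minimal sketch of a server these tests and the generator could talk to. This is an assumption-laden example, not part of the commit: it presumes the langserve, fastapi, uvicorn, and langchain-core packages, and substitutes a trivial echo runnable for a real chain:

from fastapi import FastAPI
from langchain_core.runnables import RunnableLambda
from langserve import add_routes

app = FastAPI()

# A stand-in runnable; add_routes mounts /rag-chroma-private/invoke (among
# other routes) on the app, returning {"output": ...} JSON bodies.
echo = RunnableLambda(lambda text: f"echo: {text}")
add_routes(app, echo, path="/rag-chroma-private")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)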