oai.client
ModelClient
class ModelClient(Protocol)
A client class must implement the following methods:
- create must return a response object that implements the ModelClientResponseProtocol
- cost must return the cost of the response
- get_usage must return a dict with the following keys:
  - prompt_tokens
  - completion_tokens
  - total_tokens
  - cost
  - model
This class is used to create a client that can be used by OpenAIWrapper. The response returned from create must adhere to the ModelClientResponseProtocol but can be extended however needed. The message_retrieval method must be implemented to return a list of str or a list of messages from the response.
message_retrieval
def message_retrieval(
response: ModelClientResponseProtocol
) -> Union[List[str],
List[ModelClient.ModelClientResponseProtocol.Choice.Message]]
Retrieve and return a list of strings or a list of Choice.Message from the response.
NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object, since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
get_usage
@staticmethod
def get_usage(response: ModelClientResponseProtocol) -> Dict
Return usage summary of the response using RESPONSE_USAGE_KEYS.
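The requirements above can be illustrated with a minimal sketch of a custom client. Everything here is hypothetical: the EchoClient, EchoResponse, Choice and Message names, the constructor signature, and the whitespace-based token count are assumptions made for illustration, and the response shape (choices holding messages, plus model, usage and cost) is an assumed reading of ModelClientResponseProtocol rather than its definition.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class Message:
    content: str


@dataclass
class Choice:
    message: Message


@dataclass
class EchoResponse:
    # Assumed response shape: choices with messages, plus model, usage and cost.
    choices: List[Choice]
    model: str
    usage: Dict[str, int] = field(default_factory=dict)
    cost: float = 0.0


class EchoClient:
    """Hypothetical client that echoes the last message back."""

    def __init__(self, config: Dict[str, Any], **kwargs: Any) -> None:
        self.model = config.get("model", "echo-model")

    def create(self, params: Dict[str, Any]) -> EchoResponse:
        messages = params["messages"]
        reply = messages[-1]["content"]
        # Whitespace splitting stands in for a real tokenizer.
        prompt_tokens = sum(len(m["content"].split()) for m in messages)
        completion_tokens = len(reply.split())
        usage = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }
        return EchoResponse([Choice(Message(reply))], self.model, usage, cost=0.0)

    def message_retrieval(self, response: EchoResponse) -> List[str]:
        # Returning plain strings is the simpler of the two allowed return types.
        return [choice.message.content for choice in response.choices]

    def cost(self, response: EchoResponse) -> float:
        return response.cost

    @staticmethod
    def get_usage(response: EchoResponse) -> Dict[str, Any]:
        # Produces exactly the five keys listed above.
        return {**response.usage, "cost": response.cost, "model": response.model}
```

A class shaped like this would presumably be passed to register_model_client on an OpenAIWrapper; that wiring is not shown here because it depends on the installed library version.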
OpenAIClient
class OpenAIClient()
Follows the ModelClient protocol and wraps the OpenAI client.
message_retrieval
def message_retrieval(
response: Union[ChatCompletion, Completion]
) -> Union[List[str], List[ChatCompletionMessage]]
Retrieve the messages from the response.
create
def create(params: Dict[str, Any]) -> ChatCompletion
Create a completion for a given config using openai's client.
Arguments:
client
- The openai client.
params
- The params for the completion.
Returns:
The completion.
cost
def cost(response: Union[ChatCompletion, Completion]) -> float
Calculate the cost of the response.
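The reference does not say how the cost is derived. One common approach, sketched here purely as an assumption, is to multiply token counts by a per-model price table. The estimate_cost function, the HYPOTHETICAL_PRICE_PER_1K table, and the prices in it are invented for illustration and do not reflect real pricing or the library's implementation.

```python
from typing import Dict, Tuple

# Hypothetical per-1K-token prices in USD: (prompt price, completion price).
HYPOTHETICAL_PRICE_PER_1K: Dict[str, Tuple[float, float]] = {
    "example-model": (0.5, 1.5),
}


def estimate_cost(model: str, prompt_tokens: int, completion_tokens: int) -> float:
    """Estimate the cost of one response from its token counts."""
    prices = HYPOTHETICAL_PRICE_PER_1K.get(model)
    if prices is None:
        # Unknown model: report zero rather than guessing a price.
        return 0.0
    prompt_price, completion_price = prices
    return (prompt_tokens * prompt_price + completion_tokens * completion_price) / 1000
```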
OpenAIWrapper
class OpenAIWrapper()
A wrapper class for openai client.
__init__
def __init__(*,
config_list: Optional[List[Dict[str, Any]]] = None,
**base_config: Any)
Arguments:
config_list
- a list of config dicts to override the base_config. They can contain additional kwargs as allowed in the create method. E.g.,
config_list=[
    {
        "model": "gpt-4",
        "api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
        "api_type": "azure",
        "base_url": os.environ.get("AZURE_OPENAI_API_BASE"),
        "api_version": "2024-02-01",
    },
    {
        "model": "gpt-3.5-turbo",
        "api_key": os.environ.get("OPENAI_API_KEY"),
        "api_type": "openai",
        "base_url": "https://api.openai.com/v1",
    },
    {
        "model": "llama-7B",
        "base_url": "http://127.0.0.1:8080",
    }
]
base_config
- base config. It can contain both keyword arguments for openai client and additional kwargs. When using OpenAI or Azure OpenAI endpoints, please specify a non-empty 'model' either in base_config or in each config of config_list.
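The override relationship between base_config and config_list can be pictured as a dictionary merge in which each entry of config_list wins over the shared base. The merge_configs helper below is a hypothetical sketch of that idea, not the wrapper's actual code.

```python
from typing import Any, Dict, List, Optional


def merge_configs(
    config_list: Optional[List[Dict[str, Any]]] = None, **base_config: Any
) -> List[Dict[str, Any]]:
    """Return one effective config per entry, with entry keys overriding base keys."""
    if not config_list:
        # Assumption: with no config_list, the base config alone is used.
        return [dict(base_config)]
    return [{**base_config, **config} for config in config_list]
```

For example, merge_configs([{"model": "gpt-4"}], temperature=0.7) yields a single config containing both the shared temperature and the entry's model.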
register_model_client
def register_model_client(model_client_cls: ModelClient, **kwargs)
Register a model client.
Arguments:
model_client_cls
- A custom client class that follows the ModelClient interface
**kwargs
- The kwargs for the custom client class to be initialized with
create
def create(**config: Any) -> ModelClient.ModelClientResponseProtocol
Make a completion for a given config using available clients. Besides the kwargs allowed in openai's [or other] client, we allow the following additional kwargs. The config stored in each client is overridden by the config passed to this call.
Arguments:
- context (Dict | None): The context used to instantiate the prompt or messages. Defaults to None. It needs to contain the keys used by the prompt template or the filter function. E.g., prompt="Complete the following sentence: {prefix}", context={"prefix": "Today I feel"}. The actual prompt will be: "Complete the following sentence: Today I feel". More examples can be found at templating.
- cache (AbstractCache | None): A Cache object to use for the response cache. Defaults to None. Note that the cache argument overrides the legacy cache_seed argument: if cache is provided, cache_seed is ignored; if cache is not provided or is None, cache_seed is used.
- agent (AbstractAgent | None): The object responsible for creating a completion if an agent.
- (Legacy) cache_seed (int | None): The seed for the DiskCache. Defaults to 41. An integer cache_seed is useful when implementing "controlled randomness" for the completion. None for no caching. Note: this is a legacy argument. It is only used when the cache argument is not provided.
- filter_func (Callable | None): A function that takes in the context and the response and returns a boolean to indicate whether the response is valid. E.g.,
def yes_or_no_filter(context, response):
    return context.get("yes_or_no_choice", False) is False or any(
        text in ["Yes.", "No."] for text in client.extract_text_or_completion_object(response)
    )
- allow_format_str_template (bool | None): Whether to allow format string templates in the config. Defaults to False.
- api_version (str | None): The API version. Defaults to None. E.g., "2024-02-01".
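The two uses of context described in the arguments above (filling a prompt template and driving a filter function) can be sketched without any client at all. The instantiate_prompt helper is hypothetical, and this yes_or_no_filter is a simplified variant that takes a list of already extracted texts instead of a response object; both are assumptions for illustration rather than the library's code.

```python
from typing import Any, Dict, List, Optional


def instantiate_prompt(
    prompt: str,
    context: Optional[Dict[str, Any]],
    allow_format_str_template: bool = False,
) -> str:
    """Fill {placeholders} in the prompt from context, if templating is allowed."""
    if not context or not allow_format_str_template:
        # Assumption: without the flag, the prompt is passed through untouched.
        return prompt
    return prompt.format(**context)


def yes_or_no_filter(context: Dict[str, Any], texts: List[str]) -> bool:
    """Accept any response unless the context demands a strict Yes./No. answer."""
    return context.get("yes_or_no_choice", False) is False or any(
        text in ["Yes.", "No."] for text in texts
    )


prompt = instantiate_prompt(
    "Complete the following sentence: {prefix}",
    {"prefix": "Today I feel"},
    allow_format_str_template=True,
)
```

Gating the substitution behind a flag keeps literal braces in ordinary prompts (for example, JSON snippets) from being misread as placeholders.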
Raises:
- RuntimeError: If any declared custom model client has not been registered
- APIError: If any model client create call raises an APIError
print_usage_summary
def print_usage_summary(
mode: Union[str, List[str]] = ["actual", "total"]) -> None
Print the usage summary.
clear_usage_summary
def clear_usage_summary() -> None
Clear the usage summary.
extract_text_or_completion_object
@classmethod
def extract_text_or_completion_object(
cls, response: ModelClient.ModelClientResponseProtocol
) -> Union[List[str],
List[ModelClient.ModelClientResponseProtocol.Choice.Message]]
Extract the text or ChatCompletion objects from a completion or chat response.
Arguments:
response
- (ChatCompletion | Completion) The response from openai.
Returns:
A list of strings, or a list of message objects if function_call/tool_calls are present.
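The return rule above (text for plain replies, the full message when a function or tool call is present) can be sketched with stand-in types. The Message and Choice dataclasses and the extract_text_or_message function are hypothetical simplifications, and applying the rule per choice is an assumption about the behavior.

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class Message:
    content: Optional[str] = None
    function_call: Optional[dict] = None
    tool_calls: Optional[list] = None


@dataclass
class Choice:
    message: Message


def extract_text_or_message(choices: List[Choice]) -> List[Union[Optional[str], Message]]:
    """Return message text, or the whole message when it carries a call."""
    extracted: List[Union[Optional[str], Message]] = []
    for choice in choices:
        message = choice.message
        if message.function_call is not None or message.tool_calls is not None:
            # Keep the full object so the caller can read the call details.
            extracted.append(message)
        else:
            extracted.append(message.content)
    return extracted
```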