Skip to content

OpenAIWrapper

autogen.OpenAIWrapper #

OpenAIWrapper(*, config_list=None, **base_config)

A wrapper class for the OpenAI client.

Initialize the OpenAIWrapper.

PARAMETER DESCRIPTION
config_list

a list of config dicts to override the base_config. They can contain additional kwargs as allowed in the create method. E.g.,

    config_list = [
        {
            "model": "gpt-4",
            "api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
            "api_type": "azure",
            "base_url": os.environ.get("AZURE_OPENAI_API_BASE"),
            "api_version": "2024-02-01",
        },
        {
            "model": "gpt-3.5-turbo",
            "api_key": os.environ.get("OPENAI_API_KEY"),
            "base_url": "https://api.openai.com/v1",
        },
        {
            "model": "llama-7B",
            "base_url": "http://127.0.0.1:8080",
        },
    ]

TYPE: Optional[list[dict[str, Any]]] DEFAULT: None

base_config

base config. It can contain both keyword arguments for openai client and additional kwargs. When using OpenAI or Azure OpenAI endpoints, please specify a non-empty 'model' either in base_config or in each config of config_list.

TYPE: Any DEFAULT: {}

Source code in autogen/oai/client.py
def __init__(
    self,
    *,
    config_list: Optional[list[dict[str, Any]]] = None,
    **base_config: Any,
):
    """Initialize the OpenAIWrapper.

    Args:
        config_list: a list of config dicts, each overriding `base_config`.
            Entries may also carry any extra kwargs accepted by `create`, e.g.

            ```python
                config_list = [
                    {
                        "model": "gpt-4",
                        "api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
                        "api_type": "azure",
                        "base_url": os.environ.get("AZURE_OPENAI_API_BASE"),
                        "api_version": "2024-02-01",
                    },
                    {
                        "model": "gpt-3.5-turbo",
                        "api_key": os.environ.get("OPENAI_API_KEY"),
                        "base_url": "https://api.openai.com/v1",
                    },
                    {
                        "model": "llama-7B",
                        "base_url": "http://127.0.0.1:8080",
                    },
                ]
            ```

        base_config: base config containing both keyword arguments for the openai
            client and additional kwargs. For OpenAI / Azure OpenAI endpoints,
            supply a non-empty 'model' here or in every config of `config_list`.
    """
    # NOTE: locals() must be captured before any local variable is defined,
    # so only the raw constructor arguments are logged.
    if logging_enabled():
        log_new_wrapper(self, locals())
    openai_settings, shared_extra_kwargs = self._separate_openai_config(base_config)
    # "model" may legitimately be absent here — it can be supplied at `create` time.

    self._clients: list[ModelClient] = []
    self._config_list: list[dict[str, Any]] = []

    if not config_list:
        self._register_default_client(shared_extra_kwargs, openai_settings)
        self._config_list = [shared_extra_kwargs]
    else:
        for entry in config_list:
            entry = entry.copy()  # never mutate the caller's dicts
            self._register_default_client(entry, openai_settings)  # may modify entry
            merged = dict(shared_extra_kwargs)
            # keep only non-client kwargs from the entry; entry values win
            merged.update((k, v) for k, v in entry.items() if k not in self.openai_kwargs)
            self._config_list.append(merged)
    self.wrapper_id = id(self)

extra_kwargs class-attribute instance-attribute #

extra_kwargs = {'agent', 'cache', 'cache_seed', 'filter_func', 'allow_format_str_template', 'context', 'api_version', 'api_type', 'tags', 'price'}

openai_kwargs property #

openai_kwargs

total_usage_summary class-attribute instance-attribute #

total_usage_summary = None

actual_usage_summary class-attribute instance-attribute #

actual_usage_summary = None

wrapper_id instance-attribute #

wrapper_id = id(self)

register_model_client #

register_model_client(model_client_cls, **kwargs)

Register a model client.

PARAMETER DESCRIPTION
model_client_cls

A custom client class that follows the ModelClient interface

TYPE: ModelClient

**kwargs

The kwargs for the custom client class to be initialized with

TYPE: Any DEFAULT: {}

Source code in autogen/oai/client.py
def register_model_client(self, model_client_cls: ModelClient, **kwargs: Any):
    """Register a model client.

    Replaces the matching placeholder client (declared in the config list via
    "model_client_cls") with an instance of `model_client_cls`.

    Args:
        model_client_cls: A custom client class that follows the ModelClient interface
        **kwargs: The kwargs for the custom client class to be initialized with

    Raises:
        ValueError: If no config-list entry declared this client class and it
            is not already registered.
    """
    existing_client_class = False
    for i, client in enumerate(self._clients):
        if isinstance(client, PlaceHolderClient):
            placeholder_config = client.config

            # Activate the placeholder whose declared class name matches.
            if placeholder_config.get("model_client_cls") == model_client_cls.__name__:
                self._clients[i] = model_client_cls(placeholder_config, **kwargs)
                return
        elif isinstance(client, model_client_cls):
            existing_client_class = True

    if existing_client_class:
        # logger.warn is a deprecated alias; use logger.warning.
        logger.warning(
            f"Model client {model_client_cls.__name__} is already registered. Add more entries in the config_list to use multiple model clients."
        )
    else:
        raise ValueError(
            f'Model client "{model_client_cls.__name__}" is being registered but was not found in the config_list. '
            f'Please make sure to include an entry in the config_list with "model_client_cls": "{model_client_cls.__name__}"'
        )

instantiate classmethod #

instantiate(template, context=None, allow_format_str_template=False)
Source code in autogen/oai/client.py
@classmethod
def instantiate(
    cls,
    template: Optional[Union[str, Callable[[dict[str, Any]], str]]],
    context: Optional[dict[str, Any]] = None,
    allow_format_str_template: Optional[bool] = False,
) -> Optional[str]:
    """Fill a prompt template using *context*.

    A missing template or empty/None context is returned untouched. A string
    template is `.format(**context)`-expanded only when
    `allow_format_str_template` is True; a callable template is invoked with
    the context dict.
    """
    if template is None or not context:
        return template  # type: ignore [return-value]
    if isinstance(template, str):
        if allow_format_str_template:
            return template.format(**context)
        return template
    # callable template: delegate substitution to the caller-provided function
    return template(context)

create #

create(**config)

Make a completion for a given config using available clients. Besides the kwargs allowed in openai's [or other] client, we allow the following additional kwargs. The config in each client will be overridden by the config.

PARAMETER DESCRIPTION
**config

The config for the completion.

TYPE: Any DEFAULT: {}

RAISES DESCRIPTION
RuntimeError

If any declared custom model client is not registered

APIError

If any model client create call raises an APIError

Source code in autogen/oai/client.py
def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol:
    """Make a completion for a given config using available clients.
    Besides the kwargs allowed in openai's [or other] client, we allow the following additional kwargs.
    The config in each client will be overridden by the config.

    Tries each registered client in order; a client's response is returned when
    it passes `filter_func` (or when it is the last client). Responses may be
    served from and written to the configured cache.

    Args:
        **config: The config for the completion.

    Raises:
        RuntimeError: If any declared custom model client is not registered
        APIError: If any model client create call raises an APIError
    """
    invocation_id = str(uuid.uuid4())
    last = len(self._clients) - 1
    # Check if all configs in config list are activated
    non_activated = [
        client.config["model_client_cls"] for client in self._clients if isinstance(client, PlaceHolderClient)
    ]
    if non_activated:
        raise RuntimeError(
            f"Model client(s) {non_activated} are not activated. Please register the custom model clients using `register_model_client` or filter them out from the config list."
        )
    for i, client in enumerate(self._clients):
        # merge the input config with the i-th config in the config list
        full_config = {**config, **self._config_list[i]}
        # separate the config into create_config and extra_kwargs
        create_config, extra_kwargs = self._separate_create_config(full_config)
        api_type = extra_kwargs.get("api_type")
        # Azure deployment names cannot contain dots; strip them from the model name.
        if api_type and api_type.startswith("azure") and "model" in create_config:
            create_config["model"] = create_config["model"].replace(".", "")
        # construct the create params
        params = self._construct_create_params(create_config, extra_kwargs)
        # get the cache_seed, filter_func and context
        cache_seed = extra_kwargs.get("cache_seed", LEGACY_DEFAULT_CACHE_SEED)
        cache = extra_kwargs.get("cache")
        filter_func = extra_kwargs.get("filter_func")
        context = extra_kwargs.get("context")
        agent = extra_kwargs.get("agent")
        price = extra_kwargs.get("price", None)
        if isinstance(price, list):
            price = tuple(price)
        elif isinstance(price, (float, int)):
            logger.warning(
                "Input price is a float/int. Using the same price for prompt and completion tokens. Use a list/tuple if prompt and completion token prices are different."
            )
            price = (price, price)

        total_usage = None
        actual_usage = None

        cache_client = None
        if cache is not None:
            # Use the cache object if provided.
            cache_client = cache
        elif cache_seed is not None:
            # Legacy cache behavior, if cache_seed is given, use DiskCache.
            cache_client = Cache.disk(cache_seed, LEGACY_CACHE_DIR)

        log_cache_seed_value(cache if cache is not None else cache_seed, client=client)

        if cache_client is not None:
            with cache_client as cache:
                # Try to get the response from cache. Non-dict response_format
                # (e.g. a pydantic model) is serialized into the cache key.
                key = get_key(
                    {
                        **params,
                        **{"response_format": json.dumps(TypeAdapter(params["response_format"]).json_schema())},
                    }
                    if "response_format" in params and not isinstance(params["response_format"], dict)
                    else params
                )
                request_ts = get_current_ts()

                response: ModelClient.ModelClientResponseProtocol = cache.get(key, None)

                if response is not None:
                    response.message_retrieval_function = client.message_retrieval
                    try:
                        response.cost  # type: ignore [attr-defined]
                    except AttributeError:
                        # update attribute if cost is not calculated
                        response.cost = client.cost(response)
                        cache.set(key, response)
                    total_usage = client.get_usage(response)

                    if logging_enabled():
                        # Log the cache hit
                        # TODO: log the config_id and pass_filter etc.
                        log_chat_completion(
                            invocation_id=invocation_id,
                            client_id=id(client),
                            wrapper_id=id(self),
                            agent=agent,
                            request=params,
                            response=response,
                            is_cached=1,
                            cost=response.cost,
                            start_time=request_ts,
                        )

                    # check the filter
                    pass_filter = filter_func is None or filter_func(context=context, response=response)
                    if pass_filter or i == last:
                        # Return the response if it passes the filter or it is the last client
                        response.config_id = i
                        response.pass_filter = pass_filter
                        self._update_usage(actual_usage=actual_usage, total_usage=total_usage)
                        return response
                    continue  # filter is not passed; try the next config
        try:
            request_ts = get_current_ts()
            response = client.create(params)
        except (
            gemini_InternalServerError,
            gemini_ResourceExhausted,
            anthorpic_InternalServerError,
            anthorpic_RateLimitError,
            mistral_SDKError,
            mistral_HTTPValidationError,
            together_TogetherException,
            groq_InternalServerError,
            groq_RateLimitError,
            groq_APIConnectionError,
            cohere_InternalServerError,
            cohere_TooManyRequestsError,
            cohere_ServiceUnavailableError,
            ollama_RequestError,
            ollama_ResponseError,
            bedrock_BotoCoreError,
            bedrock_ClientError,
            cerebras_AuthenticationError,
            cerebras_InternalServerError,
            cerebras_RateLimitError,
        ):
            # Provider-specific transient errors: fall through to the next config.
            # NOTE: this clause must precede the generic `except Exception`
            # below, otherwise it is unreachable dead code.
            logger.debug(f"config {i} failed", exc_info=True)
            if i == last:
                raise
        except Exception as e:
            # BUG FIX: this branch was previously guarded by
            # `if openai_result.is_successful:` — `openai_result` is undefined
            # here, so every API error was masked by a NameError.
            if APITimeoutError is not None and isinstance(e, APITimeoutError):
                logger.debug(f"config {i} timed out", exc_info=True)
                if i == last:
                    raise TimeoutError(
                        "OpenAI API call timed out. This could be due to congestion or too small a timeout value. The timeout can be specified by setting the 'timeout' value (in seconds) in the llm_config (if you are using agents) or the OpenAIWrapper constructor (if you are using the OpenAIWrapper directly)."
                    ) from e
            elif APIError is not None and isinstance(e, APIError):
                error_code = getattr(e, "code", None)
                if logging_enabled():
                    log_chat_completion(
                        invocation_id=invocation_id,
                        client_id=id(client),
                        wrapper_id=id(self),
                        agent=agent,
                        request=params,
                        response=f"error_code:{error_code}, config {i} failed",
                        is_cached=0,
                        cost=0,
                        start_time=request_ts,
                    )

                if error_code == "content_filter":
                    # raise the error for content_filter
                    raise
                logger.debug(f"config {i} failed", exc_info=True)
                if i == last:
                    raise
            else:
                # Not a recognized API error; do not swallow it.
                raise
        else:
            # add cost calculation before caching no matter filter is passed or not
            if price is not None:
                response.cost = self._cost_with_customized_price(response, price)
            else:
                response.cost = client.cost(response)
            actual_usage = client.get_usage(response)
            total_usage = actual_usage.copy() if actual_usage is not None else total_usage
            self._update_usage(actual_usage=actual_usage, total_usage=total_usage)

            if cache_client is not None:
                # Cache the response
                with cache_client as cache:
                    cache.set(key, response)

            if logging_enabled():
                # TODO: log the config_id and pass_filter etc.
                log_chat_completion(
                    invocation_id=invocation_id,
                    client_id=id(client),
                    wrapper_id=id(self),
                    agent=agent,
                    request=params,
                    response=response,
                    is_cached=0,
                    cost=response.cost,
                    start_time=request_ts,
                )

            response.message_retrieval_function = client.message_retrieval
            # check the filter
            pass_filter = filter_func is None or filter_func(context=context, response=response)
            if pass_filter or i == last:
                # Return the response if it passes the filter or it is the last client
                response.config_id = i
                response.pass_filter = pass_filter
                return response
            continue  # filter is not passed; try the next config
    raise RuntimeError("Should not reach here.")

print_usage_summary #

print_usage_summary(mode=['actual', 'total'])

Print the usage summary.

Source code in autogen/oai/client.py
def print_usage_summary(self, mode: Union[str, list[str], tuple[str, ...]] = ("actual", "total")) -> None:
    """Print the usage summary.

    Args:
        mode: "actual", "total", or a list/tuple containing one or both.
            Defaults to both. (The default is a tuple, not a list, to avoid
            a shared mutable default argument.)

    Raises:
        ValueError: If `mode` is an empty sequence or has more than two items.
    """
    iostream = IOStream.get_default()

    # Normalize the sequence form into one of the string modes.
    if isinstance(mode, (list, tuple)):
        if len(mode) == 0 or len(mode) > 2:
            raise ValueError(f'Invalid mode: {mode}, choose from "actual", "total", ["actual", "total"]')
        if "actual" in mode and "total" in mode:
            mode = "both"
        elif "actual" in mode:
            mode = "actual"
        elif "total" in mode:
            mode = "total"

    iostream.send(
        UsageSummaryMessage(
            actual_usage_summary=self.actual_usage_summary, total_usage_summary=self.total_usage_summary, mode=mode
        )
    )

clear_usage_summary #

clear_usage_summary()

Clear the usage summary.

Source code in autogen/oai/client.py
def clear_usage_summary(self) -> None:
    """Reset both the actual and total usage summaries to None."""
    self.total_usage_summary = self.actual_usage_summary = None

extract_text_or_completion_object classmethod #

extract_text_or_completion_object(response)

Extract the text or ChatCompletion objects from a completion or chat response.

PARAMETER DESCRIPTION
response

The response from openai.

TYPE: ChatCompletion | Completion

RETURNS DESCRIPTION
Union[list[str], list[Message]]

A list of text, or a list of ChatCompletion objects if function_call/tool_calls are present.

Source code in autogen/oai/client.py
@classmethod
def extract_text_or_completion_object(
    cls, response: ModelClient.ModelClientResponseProtocol
) -> Union[list[str], list[ModelClient.ModelClientResponseProtocol.Choice.Message]]:
    """Extract the text or ChatCompletion objects from a completion or chat response.

    Args:
        response (ChatCompletion | Completion): The response from openai.

    Returns:
        A list of text, or a list of ChatCompletion objects if function_call/tool_calls are present.
    """
    # Delegate to the retrieval function the producing client attached to the response.
    retrieve = response.message_retrieval_function
    return retrieve(response)