Wrap a message class with a `type` field so that it can be used in a union type.
This is needed for proper serialization and deserialization of messages in a union type.
Source code in `autogen/messages/base_message.py`:
@export_module("autogen.messages")
def wrap_message(message_cls: type[BaseMessage]) -> type[BaseModel]:
    """Wrap a message class with a type field to be used in a union type.

    This is needed for proper serialization and deserialization of messages in a union type.

    The returned wrapper model has two fields:

    * ``type`` -- a ``Literal`` whose only value is the snake_case form of the
      class name (via ``camel2snake``) with a trailing ``"_message"`` removed.
    * ``content`` -- an instance of ``message_cls``.

    The wrapper class is stored in the module-level ``_message_classes``
    mapping under that type name. A later class with the same type name
    overwrites the entry.

    Args:
        message_cls (type[BaseMessage]): Message class to wrap. Its name must
            end with ``"Message"``.

    Returns:
        type[BaseModel]: A model class created with ``create_model`` (subclassing
        the internal ``WrapperBase``), named after ``message_cls`` and carrying
        its ``__doc__``, ``__module__`` and a copy of its ``__annotations__``.

    Raises:
        ValueError: If ``message_cls.__name__`` does not end with ``"Message"``.
    """
    if not message_cls.__name__.endswith("Message"):
        raise ValueError("Message class name must end with 'Message'")

    # ``removesuffix`` only strips a real "_message" suffix. Blind slicing would
    # chop 8 characters even if camel2snake did not produce that suffix.
    type_name = camel2snake(message_cls.__name__).removesuffix("_message")

    class WrapperBase(BaseModel):
        # these types are generated dynamically so we need to disable the type checker
        type: Literal[type_name] = type_name  # type: ignore[valid-type]
        content: message_cls  # type: ignore[valid-type]

        def __init__(self, *args: Any, **data: Any) -> None:
            # Already in wrapped form (exactly the wrapper's own fields):
            # hand the data straight to the model constructor.
            if set(data) == {"type", "content"}:
                super().__init__(*args, **data)
                return

            # Otherwise the arguments describe the inner message. Build it and
            # store it as ``content``. Positional args go only to ``message_cls``,
            # because pydantic's ``BaseModel.__init__`` takes keyword data only.
            # The remaining keyword data is also forwarded to the wrapper, as in
            # the original code.
            if "content" in data:
                # Pop ``content`` so it reaches ``message_cls`` exactly once, as
                # its own ``content`` argument, instead of colliding with the
                # wrapper's ``content`` field.
                content = data.pop("content")
                super().__init__(content=message_cls(*args, **data, content=content), **data)
            else:
                super().__init__(content=message_cls(*args, **data), **data)

        def print(self, f: Optional[Callable[..., Any]] = None) -> None:
            # Delegate printing to the wrapped message.
            self.content.print(f)  # type: ignore[attr-defined]

    wrapper_cls = create_model(message_cls.__name__, __base__=WrapperBase)

    # Preserve the original class's docstring and other attributes
    wrapper_cls.__doc__ = message_cls.__doc__
    wrapper_cls.__module__ = message_cls.__module__

    # Copy any other relevant attributes/metadata from the original class.
    # Take a shallow copy so the two classes do not share one mutable dict.
    if hasattr(message_cls, "__annotations__"):
        wrapper_cls.__annotations__ = dict(message_cls.__annotations__)

    # Mutating (not rebinding) the module-level mapping needs no ``global``.
    _message_classes[type_name] = wrapper_cls

    return wrapper_cls
|