Source code for mlrun.datastore.model_provider.model_provider

# Copyright 2025 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import AsyncGenerator, Awaitable, Callable, Generator
from typing import Any, Union

import mlrun.errors
from mlrun.common.types import StrEnum
from mlrun.datastore.remote_client import (
    BaseRemoteClient,
)


class InvokeResponseFormat(StrEnum):
    STRING = "string"
    USAGE = "usage"
    FULL = "full"

    @classmethod
    def is_str_response(cls, invoke_response_format: str) -> bool:
        """
        Returns True if the response key corresponds to a string-based response (not a full generation object).
        """
        return invoke_response_format in {
            cls.USAGE,
            cls.STRING,
        }


class UsageResponseKeys(StrEnum):
    ANSWER = "answer"
    USAGE = "usage"

    @classmethod
    def fields(cls) -> list[str]:
        return [cls.ANSWER, cls.USAGE]


[docs] class ModelProvider(BaseRemoteClient): """ The ModelProvider class is an abstract base for integrating with external model providers, primarily generative AI (GenAI) services. Designed to be subclassed, it defines a consistent interface and shared functionality for tasks such as text generation, embeddings, and invoking fine-tuned models. Subclasses should implement provider-specific logic, including SDK client initialization, model invocation, and custom operations. Key Features: - Establishes a consistent, reusable client management for model provider integrations. - Simplifies GenAI service integration by abstracting common operations. - Reduces duplication through shared components for common tasks. - Holds default invocation parameters (e.g., temperature, max_tokens) to avoid boilerplate code and promote consistency. """ support_async = False supports_streaming = False def __init__( self, parent, kind, name, endpoint="", secrets: dict | None = None, default_invoke_kwargs: dict | None = None, ): super().__init__( parent=parent, name=name, kind=kind, endpoint=endpoint, secrets=secrets ) self.default_invoke_kwargs = default_invoke_kwargs or {} self._client = None self._async_client = None @staticmethod def _validate_and_detect_batch_invocation( messages: Union[list[dict], list[list[dict]]], ) -> bool: """ Validate messages format and detect if this is a batch invocation. :param messages: Either a list of message dicts (single) or list of message lists (batch) :return: True if batch invocation, False if single invocation :raises MLRunInvalidArgumentError: If messages format is invalid (mixed types or strings) """ if not messages or not isinstance(messages, list): raise mlrun.errors.MLRunInvalidArgumentError( "Messages must be a non-empty list of dictionaries or list of lists of dictionaries." ) # Check if user mistakenly passed a list of strings has_str = any(isinstance(item, str) for item in messages) if has_str: raise mlrun.errors.MLRunInvalidArgumentError( "Invalid messages format: list of strings is not supported. " "Messages must be a list of dicts (single invocation) or list of lists of dicts (batch invocation)." ) has_list = any(isinstance(item, list) for item in messages) has_dict = any(isinstance(item, dict) for item in messages) if has_list and has_dict: raise mlrun.errors.MLRunInvalidArgumentError( "Invalid messages format: cannot mix list and dict items. " "Use either all lists for batch invocation or all dicts for single invocation." ) if has_list: return True return False @staticmethod def _extract_string_output(response: Any) -> str: """ Extracts string response from response object """ pass def _response_handler( self, response: Any, invoke_response_format: InvokeResponseFormat = InvokeResponseFormat.FULL, **kwargs, ) -> Union[str, dict, Any]: """ Handles the model response according to the specified response format. :param response: The raw response returned from the model invocation. :param invoke_response_format: Determines how the response should be processed and returned. Options include: - STRING: Return only the main generated content as a string, typically for single-answer responses. - USAGE: Return a dictionary combining the string response with additional metadata or token usage statistics, in this format: {"answer": <string>, "usage": <dict>} - FULL: Return the full raw response object. :param kwargs: Additional parameters that may be required by specific implementations. :return: The processed response in the format specified by `invoke_response_format`. Can be a string, dictionary, or the original response object. """ return None
[docs] def get_client_options(self) -> dict: """ Returns a dictionary containing credentials and configuration options required for client creation. :return: A dictionary with client-specific settings. """ return {}
[docs] def load_client(self) -> None: """ Initialize the SDK client for the model provider and assign it to an instance attribute. Subclasses should override this method to create and configure the provider-specific client. """ raise NotImplementedError("load_client method is not implemented")
[docs] def load_async_client(self) -> Any: raise NotImplementedError("load_async_client method is not implemented")
@property def client(self) -> Any: return self._client @property def model(self) -> str | None: """ Returns the model identifier used by the underlying SDK. :return: A string representing the model ID, or None if not set. """ return self.endpoint
[docs] def get_invoke_kwargs(self, invoke_kwargs) -> dict: kwargs = self.default_invoke_kwargs.copy() kwargs.update(invoke_kwargs) return kwargs
@property def async_client(self) -> Any: if not self.support_async: raise mlrun.errors.MLRunInvalidArgumentError( f"{self.__class__.__name__} does not support async operations" ) return self._async_client
[docs] def custom_invoke(self, operation: Callable | None = None, **invoke_kwargs) -> Any: """ Invokes a model operation from a provider (e.g., OpenAI, Hugging Face, etc.) with the given keyword arguments. Useful for dynamically calling model methods like text generation, chat completions, or image generation. The operation must be a callable that accepts keyword arguments. :param operation: A callable representing the model operation (e.g., a client method). :param invoke_kwargs: Keyword arguments to pass to the operation. :return: The full response returned by the operation. """ raise NotImplementedError("custom_invoke method is not implemented")
[docs] async def async_custom_invoke( self, operation: Callable[..., Awaitable[Any]] | None = None, **invoke_kwargs ) -> Any: """ Asynchronously invokes a model operation from a provider (e.g., OpenAI, Hugging Face, etc.) with the given keyword arguments. The operation must be an async callable (e.g., a method from an async client) that accepts keyword arguments. :param operation: An async callable representing the model operation (e.g., an async_client method). :param invoke_kwargs: Keyword arguments to pass to the operation. :return: The full response returned by the awaited operation. """ raise NotImplementedError("async_custom_invoke is not implemented")
[docs] def invoke( self, messages: Union[list[dict], list[list[dict]]], invoke_response_format: InvokeResponseFormat = InvokeResponseFormat.FULL, **invoke_kwargs, ) -> Union[str, dict[str, Any], Any]: """ Invokes a generative AI model with the provided messages and additional parameters. This method is designed to be a flexible interface for interacting with various generative AI backends (e.g., OpenAI, Hugging Face, etc.). It allows users to send a list of messages (following a standardized format) and receive a response. :param messages: A list of dictionaries representing the conversation history or input messages. Each dictionary should follow the format:: {"role": "system"| "user" | "assistant" ..., "content": "Message content as a string"} Example: .. code-block:: json [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is the capital of France?"} ] This format is consistent across all backends. Defaults to None if no messages are provided. :param invoke_response_format: Determines how the model response is returned: - string: Returns only the generated text content from the model output, for single-answer responses only. - usage: Combines the STRING response with additional metadata (token usage), and returns the result in a dictionary. Note: The usage dictionary may contain additional keys depending on the model provider: .. code-block:: json { "answer": "<generated_text>", "usage": { "prompt_tokens": <int>, "completion_tokens": <int>, "total_tokens": <int> } } - full: Returns the full model output. :param invoke_kwargs: Additional keyword arguments to be passed to the underlying model API call. These can include parameters such as temperature, max tokens, etc., depending on the capabilities of the specific backend being used. :return: The invoke result formatted according to the specified invoke_response_format parameter. """ raise NotImplementedError("invoke method is not implemented")
[docs] async def async_invoke( self, messages: Union[list[dict], list[list[dict]]], invoke_response_format=InvokeResponseFormat.FULL, **invoke_kwargs, ) -> Union[str, dict[str, Any], Any]: """ Asynchronously invokes a generative AI model with the provided messages and additional parameters. This method is designed to be a flexible interface for interacting with various generative AI backends (e.g., OpenAI, Hugging Face, etc.). It allows users to send a list of messages (following a standardized format) and receive a response. :param messages: A list of dictionaries representing the conversation history or input messages. Each dictionary should follow the format:: {"role": "system"| "user" | "assistant" ..., "content": "Message content as a string"} Example: .. code-block:: json [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is the capital of France?"} ] This format is consistent across all backends. Defaults to None if no messages are provided. :param invoke_response_format: Determines how the model response is returned: - string: Returns only the generated text content from the model output, for single-answer responses only. - usage: Combines the STRING response with additional metadata (token usage), and returns the result in a dictionary. Note: The usage dictionary may contain additional keys depending on the model provider: .. code-block:: json { "answer": "<generated_text>", "usage": { "prompt_tokens": <int>, "completion_tokens": <int>, "total_tokens": <int> } } - full: Returns the full model output. :param invoke_kwargs: Additional keyword arguments to be passed to the underlying model API call. These can include parameters such as temperature, max tokens, etc., depending on the capabilities of the specific backend being used. :return: The invoke result formatted according to the specified invoke_response_format parameter. """ raise NotImplementedError("async_invoke is not implemented")
[docs] def invoke_stream( self, messages: list[dict], **invoke_kwargs, ) -> Generator[str, None, None]: """ Invokes a generative AI model in streaming mode, yielding text tokens as they are generated. :param messages: A list of dictionaries representing the conversation history. Must be a single conversation (not a batch). :param invoke_kwargs: Additional keyword arguments passed to the underlying model API call. :return: A generator yielding text tokens as strings. """ raise NotImplementedError( f"{self.__class__.__name__} does not support streaming" )
[docs] async def async_invoke_stream( self, messages: list[dict], **invoke_kwargs, ) -> AsyncGenerator[str, None]: """ Asynchronously invokes a generative AI model in streaming mode, yielding text tokens as they are generated. :param messages: A list of dictionaries representing the conversation history. Must be a single conversation (not a batch). :param invoke_kwargs: Additional keyword arguments passed to the underlying model API call. :return: An async generator yielding text tokens as strings. """ raise NotImplementedError( f"{self.__class__.__name__} does not support async streaming" ) # yield is needed to make this an async generator function yield