# Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from base64 import b64decode
from copy import deepcopy
from http import HTTPMethod
from typing import Union
import nuclio
from jsonpath_ng import parse as jsonpath_parse
from jsonpath_ng.exceptions import JsonPathLexerError, JsonPathParserError
from nuclio import KafkaTrigger
from nuclio.triggers import NuclioTrigger
import mlrun
import mlrun.common.schemas as schemas
import mlrun.common.secrets
import mlrun.datastore.datastore_profile as ds_profile
import mlrun.runtimes.kubejob as kubejob_runtime
import mlrun.runtimes.nuclio.function as nuclio_function
import mlrun.runtimes.pod as pod_runtime
import mlrun.serving.utils as serving_utils
from mlrun.common.schemas.serving import _APIEndpointKeys
from mlrun.datastore import get_kafka_brokers_from_dict, parse_kafka_url
from mlrun.model import ObjectList
from mlrun.runtimes.function_reference import FunctionReference
from mlrun.secrets import SecretsStore
from mlrun.serving.server import (
GraphServer,
add_system_steps_to_graph,
create_graph_server,
)
from mlrun.serving.states import (
RootFlowStep,
RouterStep,
StepKinds,
TaskStep,
graph_root_setter,
new_remote_endpoint,
params_to_step,
)
from mlrun.utils import get_caller_globals, logger, merge_requirements, set_paths
# Nuclio function sub-kind used for serving functions: passed as ``kind`` to
# ``nuclio.build_file`` and as ``function_kind`` to the NuclioSpec base class.
serving_subkind = "serving_v2"
[docs]
class APIHandlerConfig(mlrun.model.ModelObj):
    """Configuration for API handler in serving graph.

    Endpoint rules are kept in a dict keyed by the serving endpoint key
    (``'METHOD:path'``); each value holds the endpoint action and description.
    ``body_map`` maps handler parameter names to JSONPath expressions that are
    evaluated against the request body.
    """

    _dict_fields = ["enabled", "endpoints", "body_map", "include_url_info"]

    def __init__(
        self,
        enabled: bool = True,
        endpoints: dict[str, dict] | None = None,
        body_map: dict[str, str] | None = None,
        include_url_info: bool = False,
    ):
        self.enabled = enabled
        # Go through the property setters so constructor input is parsed and
        # validated exactly like values assigned later (e.g. via from_dict).
        self.endpoints = endpoints
        self.body_map = body_map
        self.include_url_info = include_url_info

    @property
    def body_map(self) -> dict[str, str]:
        """Get the body_map configuration as a dictionary."""
        return self._body_map

    @body_map.setter
    def body_map(self, value: dict[str, str] | None) -> None:
        """Set the body_map configuration from a dictionary (None clears it).

        Every mapping is validated through :py:meth:`add_body_mapping`.
        """
        self._body_map = {}
        if value:
            for parameter_name, json_path in value.items():
                self.add_body_mapping(parameter_name, json_path)

    @property
    def endpoints(self) -> dict[str, dict]:
        """Get the endpoints configuration as a dictionary."""
        return self._endpoints

    @endpoints.setter
    def endpoints(self, endpoints: dict[str, dict] | None) -> None:
        """Set the endpoints configuration from a dictionary (None clears it).

        Keys must be in ``'METHOD:path'`` form; each entry is re-added through
        :py:meth:`add_endpoint_handler` so it is validated and normalized.
        """
        self._endpoints = {}
        # Accept None like the body_map setter does, instead of crashing on .items()
        for endpoint_key, config in (endpoints or {}).items():
            method, path = self._parse_endpoint_key(endpoint_key)
            self.add_endpoint_handler(
                path=path,
                http_method=method,
                action=schemas.serving.APIHandlerAction(
                    config.get(_APIEndpointKeys.ACTION)
                ),
                description=config.get(_APIEndpointKeys.DESCRIPTION),
            )

    def _parse_endpoint_key(self, endpoint_key: str) -> tuple[HTTPMethod, str]:
        """Parse endpoint key 'METHOD:path' back to method and path components.

        :raises ValueError: If the key is not in ``'METHOD:path'`` form
        """
        try:
            return serving_utils._split_serving_endpoint_key(endpoint_key)
        except (ValueError, AttributeError) as e:
            raise ValueError(
                f"Invalid endpoint key format '{endpoint_key}'. Expected 'METHOD:path'"
            ) from e

    @staticmethod
    def _normalize_path(path: str) -> str:
        """Normalize path to ensure it starts with a forward slash.

        :param path: URL path to normalize
        :return: Normalized path with leading slash
        """
        if not path.startswith("/"):
            return f"/{path}"
        return path

    @staticmethod
    def _validate_path(path: str) -> None:
        """Validate an endpoint path for structural correctness.

        Currently enforces wildcard ``*`` rules:

        * ``*`` may only appear once.
        * ``*`` must be the last character in the path.

        :param path: Normalized path (with leading ``/``) to validate.
        :raises mlrun.errors.MLRunValueError: If the path contains an invalid ``*`` pattern.
        """
        star_count = path.count("*")
        if star_count == 0:
            return
        # We know there is a wildcard, validate its position and count
        if path[-1] != "*":
            raise mlrun.errors.MLRunValueError(
                f"Invalid endpoint path '{path}': "
                f"wildcard '*' must be at the end of the path"
            )
        if star_count > 1:
            raise mlrun.errors.MLRunValueError(
                f"Invalid endpoint path '{path}': "
                f"wildcard '*' must appear only once at the end of the path"
            )

    @staticmethod
    def _validate_http_method(http_method: HTTPMethod | str) -> HTTPMethod:
        """Validate and normalize the provided HTTP method.

        :param http_method: HTTP method to validate (HTTPMethod enum or string)
        :return: Normalized HTTPMethod enum value
        :raises mlrun.errors.MLRunInvalidArgumentError: If method is not a valid HTTPMethod or string
        """
        if isinstance(http_method, HTTPMethod):
            return http_method
        if isinstance(http_method, str):
            try:
                return HTTPMethod(http_method.upper())
            except ValueError:
                raise mlrun.errors.MLRunInvalidArgumentError(
                    f"Invalid HTTP method string '{http_method}'. "
                    f"Valid values are: {', '.join(m.value for m in HTTPMethod)}"
                ) from None
        # Not HTTPMethod or str - reject with helpful error
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"http_method must be an HTTPMethod enum or string, got {type(http_method).__name__} "
            f"with value '{http_method}'. Valid values are: {', '.join(m.value for m in HTTPMethod)}"
        )

    @staticmethod
    def _validate_action(
        action: schemas.serving.APIHandlerAction | str,
    ) -> schemas.serving.APIHandlerAction:
        """Validate and normalize an endpoint action.

        :param action: APIHandlerAction member or its value
        :return: The matching APIHandlerAction member
        :raises mlrun.errors.MLRunInvalidArgumentError: If the value is not a valid action
        """
        try:
            # Calling the enum with an existing member returns that member unchanged
            return schemas.serving.APIHandlerAction(action)
        except ValueError:
            valid_actions = ", ".join(
                str(member.value) for member in schemas.serving.APIHandlerAction
            )
            raise mlrun.errors.MLRunInvalidArgumentError(
                f"Invalid endpoint action '{action}'. Valid values are: {valid_actions}"
            ) from None

    def get_endpoint_config(self, method: HTTPMethod | str, path: str) -> dict | None:
        """Get endpoint configuration for a specific method and path.

        :param method: HTTP method (HTTPMethod enum or string)
        :param path: URL path (a leading ``/`` is added if missing)
        :return: The stored endpoint dict, or None if no handler is configured
        :raises mlrun.errors.MLRunInvalidArgumentError: If the method is invalid
        """
        method = self._validate_http_method(method)
        path = self._normalize_path(path)
        endpoint_key = serving_utils._combine_serving_endpoint_key(method, path)
        return self._endpoints.get(endpoint_key)

    def add_endpoint_handler(
        self,
        path: str,
        http_method: HTTPMethod | str = HTTPMethod.POST,
        action: schemas.serving.APIHandlerAction = schemas.serving.APIHandlerAction.ALLOW,
        description: str | None = None,
    ) -> None:
        """Add an endpoint handler configuration.

        :param path: URL path for the endpoint (e.g., ``/v1/models`` or ``/api/v1/*``)
        :param http_method: HTTP method for the endpoint (``HTTPMethod`` enum or string like ``"GET"``, ``"POST"``)
        :param action: Action to take for this endpoint (:py:class:`~mlrun.common.schemas.serving.APIHandlerAction`
                       member or its value)
        :param description: Optional description of the endpoint
        :raises mlrun.errors.MLRunInvalidArgumentError: If the HTTP method or the action is invalid
        :raises mlrun.errors.MLRunValueError: If the path contains an invalid wildcard ``*`` pattern
        """
        http_method = self._validate_http_method(http_method)
        path = self._normalize_path(path)
        self._validate_path(path)
        # Reject unknown actions up front instead of silently storing them
        action = self._validate_action(action)
        endpoint_key = serving_utils._combine_serving_endpoint_key(http_method, path)
        # Warn if overriding an existing endpoint
        if endpoint_key in self._endpoints:
            logger.warning(
                "Overriding existing endpoint handler configuration",
                method=http_method.value,
                path=path,
                old_action=self._endpoints[endpoint_key].get(_APIEndpointKeys.ACTION),
                new_action=str(action),
            )
        self._endpoints[endpoint_key] = {
            _APIEndpointKeys.ACTION: str(action),
            _APIEndpointKeys.DESCRIPTION: description,
        }

    def remove_endpoint_handler(
        self,
        path: str,
        http_method: HTTPMethod | str = HTTPMethod.POST,
    ) -> None:
        """Remove an endpoint handler configuration (no-op if it does not exist).

        :param path: URL path for the endpoint to remove
        :param http_method: HTTP method for the endpoint to remove (`HTTPMethod` enum or string like
                            ``'GET'``, ``'POST'``)
        """
        http_method = self._validate_http_method(http_method)
        path = self._normalize_path(path)
        endpoint_key = serving_utils._combine_serving_endpoint_key(http_method, path)
        self._endpoints.pop(endpoint_key, None)

    def add_body_mapping(self, parameter_name: str, json_path: str) -> None:
        """Add a JSONPath body mapping for extracting request parameters.

        Maps a JSONPath expression to a parameter name. When a request is received,
        the JSONPath will be evaluated against the request body and the result
        will be passed as a named parameter to the handler function.

        :param parameter_name: Name of the parameter to pass to the handler
        :param json_path: JSONPath expression to extract the value from request body
                          (e.g., ``'$.user.name'`` or ``'$.items[*].id'``)
        :raises mlrun.errors.MLRunValueError: If json_path is not a valid JSONPath expression

        Example::

            config = APIHandlerConfig()
            config.add_body_mapping("user_name", "$.user.name")
            config.add_body_mapping("user_email", "$.user.contact.email")
            config.add_body_mapping(
                "item_ids", "$.items[*].id"
            )  # Multiple matches return list
        """
        # Validate JSONPath expression by parsing it
        try:
            jsonpath_parse(json_path)
        except (JsonPathLexerError, JsonPathParserError) as exc:
            raise mlrun.errors.MLRunValueError(
                f"Invalid JSON path expression for parameter '{parameter_name}': "
                f"'{json_path}'. Error: {exc}"
            ) from exc
        # Warn if overriding an existing mapping
        if parameter_name in self._body_map:
            logger.warning(
                "Overriding existing body mapping",
                parameter_name=parameter_name,
                old_json_path=self._body_map[parameter_name],
                new_json_path=json_path,
            )
        self._body_map[parameter_name] = json_path

    def remove_body_mapping(self, parameter_name: str) -> None:
        """Remove a body mapping by parameter name (no-op if it does not exist).

        :param parameter_name: Name of the parameter mapping to remove
        """
        self._body_map.pop(parameter_name, None)
def new_v2_model_server(
    name,
    model_class: str,
    models: dict | None = None,
    filename="",
    protocol="",
    image="",
    endpoint="",
    workers=8,
    canary=None,
):
    """Build a serving runtime that uses ``model_class`` as its default model class.

    When no ``image`` is given, the function name and nuclio base spec are taken from
    ``nuclio.build_file(filename, ...)``; otherwise the image is applied with
    ``from_image`` after everything else is configured.

    :param name: function name (replaced by the name returned by ``build_file`` when no image is given)
    :param model_class: default model class name for the function
    :param models: optional mapping of model name -> model path to add as routes
    :param filename: source file passed to ``nuclio.build_file`` (used when no image is given)
    :param protocol: optional protocol, forwarded to every model as ``parameters={"protocol": ...}``
    :param image: optional container image
    :param endpoint: host for the HTTP trigger
    :param workers: number of HTTP workers
    :param canary: canary setting forwarded to ``with_http``
    :return: the configured ServingRuntime
    """
    serving_fn = ServingRuntime()
    if not image:
        name, base_spec, _source_code = nuclio.build_file(
            filename, name=name, handler="handler", kind=serving_subkind
        )
        serving_fn.spec.base_spec = base_spec

    serving_fn.metadata.name = name
    serving_fn.spec.default_class = model_class

    model_params = {"protocol": protocol} if protocol else None
    if models:
        for model_name, model_path in models.items():
            serving_fn.add_model(model_name, model_path=model_path, parameters=model_params)

    serving_fn.with_http(workers, host=endpoint, canary=canary)
    if image:
        serving_fn.from_image(image)
    return serving_fn
class ServingSpec(nuclio_function.NuclioSpec):
    """Nuclio function spec extended with the serving graph and model-serving settings."""

    _dict_fields = [
        *nuclio_function.NuclioSpec._dict_fields,
        "graph",
        "load_mode",
        "graph_initializer",
        "function_refs",
        "parameters",
        "models",
        "default_content_type",
        "error_stream",
        "default_class",
        "secret_sources",
        "track_models",
        "streaming",
        "api_handler_config",
    ]

    def __init__(
        self,
        command=None,
        args=None,
        image=None,
        mode=None,
        entry_points=None,
        description=None,
        replicas=None,
        min_replicas=None,
        max_replicas=None,
        volumes=None,
        volume_mounts=None,
        env=None,
        resources=None,
        config=None,
        base_spec=None,
        no_cache=None,
        source=None,
        image_pull_policy=None,
        function_kind=None,
        service_account=None,
        readiness_timeout=None,
        readiness_timeout_before_failure=None,
        models=None,
        graph=None,
        parameters=None,
        default_class=None,
        load_mode=None,
        build=None,
        function_refs=None,
        graph_initializer=None,
        error_stream=None,
        track_models=None,
        secret_sources=None,
        default_content_type=None,
        node_name=None,
        node_selector=None,
        affinity=None,
        disable_auto_mount=False,
        priority_class_name=None,
        default_handler=None,
        pythonpath=None,
        workdir=None,
        image_pull_secret=None,
        tolerations=None,
        preemption_mode=None,
        security_context=None,
        service_type=None,
        add_templated_ingress_host_mode=None,
        state_thresholds=None,
        disable_default_http_trigger=None,
        custom_scaling_metric_specs=None,
        model_endpoint_creation_task_name=None,
        serving_spec=None,
        auth=None,
        streaming: bool | None = None,
        api_handler_config: APIHandlerConfig | None = None,
        env_from=None,
    ):
        # NOTE: the ``function_kind`` argument is not forwarded - serving functions
        # always use ``serving_subkind``. ``state_thresholds`` is accepted but not
        # forwarded either.
        super().__init__(
            # code / entry point
            command=command,
            args=args,
            image=image,
            mode=mode,
            entry_points=entry_points,
            description=description,
            default_handler=default_handler,
            pythonpath=pythonpath,
            workdir=workdir,
            build=build,
            source=source,
            no_cache=no_cache,
            function_kind=serving_subkind,
            # scaling
            replicas=replicas,
            min_replicas=min_replicas,
            max_replicas=max_replicas,
            custom_scaling_metric_specs=custom_scaling_metric_specs,
            # pod settings
            volumes=volumes,
            volume_mounts=volume_mounts,
            env=env,
            env_from=env_from,
            resources=resources,
            image_pull_policy=image_pull_policy,
            image_pull_secret=image_pull_secret,
            service_account=service_account,
            node_name=node_name,
            node_selector=node_selector,
            affinity=affinity,
            tolerations=tolerations,
            preemption_mode=preemption_mode,
            priority_class_name=priority_class_name,
            security_context=security_context,
            disable_auto_mount=disable_auto_mount,
            # nuclio settings
            config=config,
            base_spec=base_spec,
            readiness_timeout=readiness_timeout,
            readiness_timeout_before_failure=readiness_timeout_before_failure,
            service_type=service_type,
            add_templated_ingress_host_mode=add_templated_ingress_host_mode,
            disable_default_http_trigger=disable_default_http_trigger,
            serving_spec=serving_spec,
            auth=auth,
        )

        self.models = models or {}
        self._graph = None
        self.graph: Union[RouterStep, RootFlowStep] = graph
        self.parameters = parameters or {}
        self.default_class = default_class
        self.load_mode = load_mode
        self._function_refs: ObjectList = None
        self.function_refs = function_refs or []
        self.graph_initializer = graph_initializer
        self.error_stream = error_stream
        self.track_models = track_models
        self.secret_sources = secret_sources or []
        self.default_content_type = default_content_type
        self.model_endpoint_creation_task_name = model_endpoint_creation_task_name
        self.streaming = streaming
        # The handler config is stored in its dict form so the spec stays serializable
        if isinstance(api_handler_config, APIHandlerConfig):
            api_handler_config = api_handler_config.to_dict()
        self.api_handler_config = api_handler_config

    @property
    def graph(self) -> Union[RouterStep, RootFlowStep]:
        """The root step (router or flow) describing the serving workflow/DAG topology."""
        return self._graph

    @graph.setter
    def graph(self, value):
        # graph_root_setter converts dicts/objects into the matching root step type
        graph_root_setter(self, value)

    @property
    def function_refs(self) -> list[FunctionReference]:
        """Optional child function references, held in an ObjectList."""
        return self._function_refs

    @function_refs.setter
    def function_refs(self, refs: list[FunctionReference]):
        self._function_refs = ObjectList.from_list(FunctionReference, refs)
[docs]
class ServingRuntime(nuclio_function.RemoteRuntime):
"""MLRun Serving Runtime"""
kind = "serving"
@property
def spec(self) -> ServingSpec:
    """The function spec, typed as :py:class:`ServingSpec`."""
    return self._spec

@spec.setter
def spec(self, value):
    # normalize the assigned value into a ServingSpec via _verify_dict
    self._spec = self._verify_dict(value, "spec", ServingSpec)
[docs]
def set_topology(
    self,
    topology=None,
    class_name=None,
    engine=None,
    exist_ok=False,
    allow_cyclic: bool = False,
    max_iterations: int | None = None,
    **class_args,
) -> Union[RootFlowStep, RouterStep]:
    """set the serving graph topology (router/flow) and root class or params

    examples::

        # simple model router topology
        graph = fn.set_topology("router")
        fn.add_model(name, class_name="ClassifierModel", model_path=model_uri)

        # async flow topology
        graph = fn.set_topology("flow", engine="async")
        graph.to("MyClass").to(name="to_json", handler="json.dumps").respond()

    topology options are::

      router - root router + multiple child route states/models
               route is usually determined by the path (route key/name)
               can specify special router class and router arguments

      flow   - workflow (DAG) with a chain of states
               flow supports both "sync" and "async" engines, with "async" being the default.
               Branches are not allowed in sync mode.
               when using async mode calling state.respond() will mark the state as the
               one which generates the (REST) call response

    :param topology:       - graph topology, router or flow (default: router)
    :param class_name:     - optional for router, router class name/path or router object
    :param engine:         - optional for flow, sync or async engine (default: async)
    :param exist_ok:       - allow overriding existing topology
    :param allow_cyclic:   - allow cyclic graphs (only for async flow)
    :param max_iterations: - optional, max iterations for cyclic graphs (only for async flow), default 100
    :param class_args:     - optional, router class init args (used only by the router topology)
    :return: graph object (fn.spec.graph)
    :raises mlrun.errors.MLRunInvalidArgumentError: if a graph already exists and exist_ok is False,
        if allow_cyclic is requested for a router or a sync flow, if a router object that is not
        a router step is provided, or if the topology is unsupported
    """
    topology = topology or StepKinds.router
    engine = engine or "async"
    if self.spec.graph and not exist_ok:
        raise mlrun.errors.MLRunInvalidArgumentError(
            "graph topology is already set, graph was initialized, use exist_ok=True to override"
        )
    # Cycles are supported only by async flows; fail fast for routers and sync flows
    if allow_cyclic and (
        topology == StepKinds.router
        or (topology == StepKinds.flow and engine == "sync")
    ):
        raise mlrun.errors.MLRunInvalidArgumentError(
            "cyclic graphs are only supported in flow topology with async engine"
        )

    if topology == StepKinds.router:
        if class_name and hasattr(class_name, "to_dict"):
            # a router step object was provided instead of a class name
            _, step = params_to_step(class_name, None)
            if step.kind != StepKinds.router:
                raise mlrun.errors.MLRunInvalidArgumentError(
                    "provided class is not a router step, must provide a router class in router topology"
                )
        else:
            step = RouterStep(class_name=class_name, class_args=class_args)
        self.spec.graph = step
    elif topology == StepKinds.flow:
        self.spec.graph = RootFlowStep(
            engine=engine,
            allow_cyclic=allow_cyclic,
            max_iterations=max_iterations,
        )
        self.spec.graph.track_models = self.spec.track_models
    else:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"unsupported topology {topology}, use 'router' or 'flow'"
        )
    return self.spec.graph
[docs]
def set_tracking(
    self,
    stream_path: str | None = None,
    sampling_percentage: float = 100,
    stream_args: dict | None = None,
    enable_tracking: bool = True,
) -> None:
    """Apply on your serving function to monitor a deployed model, including real-time dashboards to detect drift
    and analyze performance.

    :param stream_path:         Path/url of the tracking stream e.g. v3io:///users/mike/mystream
                                you can use the "dummy://" path for test/simulation.
    :param sampling_percentage: Down sampling events that will be pushed to the monitoring stream based on
                                a specified percentage. e.g. 50 for 50%. By default, all events are pushed.
                                Must be greater than 0 and at most 100.
    :param stream_args:         Stream initialization parameters, e.g. shards, retention_in_hours, ..
    :param enable_tracking:     Enabled/Disable model-monitoring tracking. Default True (tracking enabled).
    :raises mlrun.errors.MLRunInvalidArgumentError: if sampling_percentage is out of range; in that case
                                the function is left unchanged.

    Example::

        # initialize a new serving function
        serving_fn = mlrun.import_function(
            "hub://v2-model-server", new_name="serving"
        )
        # apply model monitoring
        serving_fn.set_tracking()
    """
    # Validate before mutating anything, so an invalid call does not leave the
    # function (and its children) with a half-applied tracking configuration.
    if not 0 < sampling_percentage <= 100:
        raise mlrun.errors.MLRunInvalidArgumentError(
            "`sampling_percentage` must be greater than 0 and less or equal to 100."
        )

    # Applying model monitoring configurations
    self.spec.track_models = enable_tracking
    if self.spec.graph and isinstance(self.spec.graph, RootFlowStep):
        self.spec.graph.track_models = enable_tracking

    if self._spec and self._spec.function_refs:
        logger.debug(
            "Set tracking for children references", enable_tracking=enable_tracking
        )
        for function_ref in self._spec.function_refs.values():
            function_ref.track_models = enable_tracking
            # If the child function object was already created, update it as well
            child_function = function_ref._function
            if child_function:
                child_function.spec.track_models = enable_tracking
                child_graph = child_function.spec.graph
                if child_graph and isinstance(child_graph, RootFlowStep):
                    child_graph.track_models = enable_tracking

    self.spec.parameters["sampling_percentage"] = sampling_percentage
    if stream_path:
        self.spec.parameters["log_stream"] = stream_path
    if stream_args:
        self.spec.parameters["stream_args"] = stream_args
[docs]
def set_streaming(self, enabled: bool = True) -> None:
    """Turn streaming mode of the serving function on or off.

    With streaming on, the function handler yields results as they arrive from
    streaming steps in the graph (e.g. LLM token streaming). Streaming works only
    with HTTP triggers: enabling it fails while a non-HTTP trigger is configured,
    and non-HTTP triggers cannot be added while it is on.

    :param enabled: True (default) to enable streaming, False to disable it.
    :raises mlrun.errors.MLRunInvalidArgumentError: when enabling while a trigger in
        ``spec.config`` has a kind other than ``http`` (missing kind counts as http).

    Example::

        # Create a serving function with streaming enabled
        serving_fn = mlrun.code_to_function(kind="serving")
        serving_fn.set_topology("flow", engine="async")
        serving_fn.set_streaming(enabled=True)
    """
    if enabled:
        for config_key, trigger in self.spec.config.items():
            if not config_key.startswith("spec.triggers."):
                continue
            kind = trigger.get("kind", "http")
            if kind == "http":
                continue
            offending_name = config_key.split(".")[-1]
            raise mlrun.errors.MLRunInvalidArgumentError(
                "Streaming is only supported with HTTP triggers. "
                f"Found non-HTTP trigger '{offending_name}' of kind '{kind}'. "
                "Remove non-HTTP triggers before enabling streaming."
            )
    self.spec.streaming = enabled
[docs]
def add_trigger(self, name: str, spec: NuclioTrigger | dict):
    """Add a nuclio trigger object/dict to the function.

    Extends the parent implementation with a streaming check: while streaming is
    enabled, only triggers of kind ``http`` (assumed when no kind is given) are
    accepted.

    :param name: trigger name
    :param spec: trigger object or dict
    :raises mlrun.errors.MLRunInvalidArgumentError: if streaming is enabled and the
        trigger is not an HTTP trigger
    """
    if self.spec.streaming:
        as_dict = spec if not hasattr(spec, "to_dict") else spec.to_dict()
        kind = as_dict.get("kind", "http")
        if kind != "http":
            raise mlrun.errors.MLRunInvalidArgumentError(
                f"Cannot add non-HTTP trigger '{name}' (kind='{kind}') "
                "when streaming is enabled. Streaming only supports HTTP triggers. "
                "Either disable streaming with set_streaming(False) or use HTTP triggers only."
            )
    return super().add_trigger(name, spec)
[docs]
def add_model(
    self,
    key: str,
    model_path: str | None = None,
    class_name: str | None = None,
    model_url: str | None = None,
    handler: str | None = None,
    router_step: str | None = None,
    child_function: str | None = None,
    creation_strategy: schemas.ModelEndpointCreationStrategy
    | None = schemas.ModelEndpointCreationStrategy.INPLACE,
    outputs: list[str] | None = None,
    **class_args,
):
    """Add ml model and/or route to the function.

    Example, create a function (from the notebook), add a model class, and deploy::

        fn = code_to_function(kind="serving")
        fn.add_model("boost", model_path, class_name="MyClass", my_arg=5)
        fn.deploy()

    Only works with router topology. For nested topologies (model under router under flow)
    need to add router to flow and use router.add_route()

    :param key:               model api key (or name:version), will determine the relative url/path
    :param model_path:        path to mlrun model artifact or model directory file/object path
    :param class_name:        V2 Model python class name or a model class instance
                              (can also module.submodule.class and it will be imported automatically)
    :param model_url:         url of a remote model serving endpoint (cannot be used with model_path)
    :param handler:           for advanced users!, override default class handler name (do_event)
    :param router_step:       router step name (to determine which router we add the model to in graphs
                              with multiple router steps)
    :param child_function:    child function name, when the model runs in a child function
    :param creation_strategy: Strategy for creating or updating the model endpoint:

        * **overwrite**: If model endpoints with the same name exist, delete the `latest`
          one. Create a new model endpoint entry and set it as `latest`.

        * **inplace** (default): If model endpoints with the same name exist, update the
          `latest` entry. Otherwise, create a new entry.

        * **archive**: If model endpoints with the same name exist, preserve them.
          Create a new model endpoint with the same name and set it to `latest`.

    :param outputs:           list of the model outputs (e.g. labels), if provided will override the outputs that were
                              configured in the model artifact. Note that those outputs need to be equal to the
                              model serving function outputs (length, and order).
    :param class_args:        extra kwargs to pass to the model serving class __init__
                              (can be read in the model using .get_param(key) method)
    :raises ValueError: if the target router cannot be determined, if class_name is not a string,
                        or if model_path is given without a class name
    """
    graph = self.spec.graph
    if not graph:
        graph = self.set_topology()

    if graph.kind != StepKinds.router:
        # the root is a flow: locate the router step the model should be added to
        if router_step:
            if router_step not in graph:
                raise ValueError(
                    f"router step {router_step} not present in the graph"
                )
            graph = graph[router_step]
        else:
            routers = [
                step
                for step in graph.steps.values()
                if step.kind == StepKinds.router
            ]
            if len(routers) == 0:
                raise ValueError(
                    "graph does not contain any router, add_model can only be "
                    "used when there is a router step"
                )
            if len(routers) > 1:
                raise ValueError(
                    f"found {len(routers)} routers, please specify the router_step"
                    " you would like to add this model to"
                )
            graph = routers[0]

    if class_name and hasattr(class_name, "to_dict"):
        # a model step object was provided; apply path/outputs overrides on it
        if model_path:
            class_name.model_path = model_path
        if outputs:
            class_name.outputs = outputs
        key, state = params_to_step(
            class_name,
            key,
            model_endpoint_creation_strategy=creation_strategy,
            endpoint_type=schemas.EndpointType.LEAF_EP,
        )
    else:
        class_name = class_name or self.spec.default_class
        if class_name and not isinstance(class_name, str):
            raise ValueError(
                "class name must be a string (name of module.submodule.name)"
            )
        if model_path and not class_name:
            raise ValueError("model_path must be provided with class_name")
        if model_path:
            model_path = str(model_path)

        if model_url:
            state = new_remote_endpoint(
                model_url,
                creation_strategy=creation_strategy,
                endpoint_type=schemas.EndpointType.LEAF_EP,
                **class_args,
            )
        else:
            # copy so the caller's kwargs are not affected by the added keys
            class_args = deepcopy(class_args)
            class_args["model_path"] = model_path
            class_args["outputs"] = outputs
            state = TaskStep(
                class_name,
                class_args,
                name=key,
                handler=handler,
                function=child_function,
                model_endpoint_creation_strategy=creation_strategy,
                endpoint_type=schemas.EndpointType.LEAF_EP,
            )

    return graph.add_route(key, state)
[docs]
def add_child_function(
    self, name, url=None, image=None, requirements=None, kind=None
):
    """Register a child function for a multi-function pipeline.

    example::

        fn.add_child_function("enrich", "./enrich.ipynb", "mlrun/mlrun")

    The child inherits this function's ``track_models`` setting and defaults to
    the ``serving`` kind.

    :param name:         child function name
    :param url:          function/code url, support .py, .ipynb, .yaml extensions
    :param image:        base docker image for the function
    :param requirements: py package requirements file path OR list of packages
    :param kind:         mlrun function/runtime kind (default: "serving")

    :return: function object
    """
    ref = FunctionReference(
        url,
        image,
        requirements=requirements,
        kind=kind or "serving",
        track_models=self.spec.track_models,
    )
    self._spec.function_refs.update(ref, name)
    return ref.to_function(self.kind)
def _add_ref_triggers(self):
    """Add a stream trigger to every downstream child function fed by a queue.

    For each queue link in the graph that has a stream path, the target child
    function gets a Kafka trigger (``kafka://`` path or brokers in the stream
    options) or a V3IO stream trigger otherwise.

    :raises ValueError: if a queue link targets a function with no function reference
    """
    for function_name, stream in self.spec.graph.get_queue_links().items():
        if not stream.path:
            continue
        if function_name not in self._spec.function_refs.keys():
            raise ValueError(f"function reference {function_name} not present")
        group = stream.options.get("group", f"{function_name}-consumer-group")

        child_function = self._spec.function_refs[function_name]
        # Same fallback as _deploy_function_refs: the ref may not have a function
        # object yet (e.g. when it was loaded from a dict).
        function_object = child_function.function_object or child_function.to_function(
            self.kind
        )

        # Work on a copy: the deploy-time defaults below must not be written back
        # into the queue step's trigger_args, which are part of the persisted graph
        # (and shared with child functions).
        trigger_args = deepcopy(stream.trigger_args) if stream.trigger_args else {}

        engine = self.spec.graph.engine or "async"
        if mlrun.mlconf.is_explicit_ack_enabled() and engine == "async":
            trigger_args["explicit_ack_mode"] = trigger_args.get(
                "explicit_ack_mode", "explicitOnly"
            )
            # tolerate an explicit None as well as a missing key
            extra_attributes = trigger_args.get("extra_attributes") or {}
            trigger_args["extra_attributes"] = extra_attributes
            extra_attributes["worker_allocation_mode"] = extra_attributes.get(
                "worker_allocation_mode", "static"
            )

        brokers = get_kafka_brokers_from_dict(stream.options)
        if stream.path.startswith("kafka://") or brokers:
            if brokers:
                brokers = brokers.split(",")
            topic, brokers = parse_kafka_url(stream.path, brokers)
            trigger = KafkaTrigger(
                brokers=brokers,
                topics=[topic],
                consumer_group=group,
                **trigger_args,
            )
            function_object.add_trigger("kafka", trigger)
        else:
            # V3IO doesn't allow hyphens in object names
            group = group.replace("-", "_")
            function_object.add_v3io_stream_trigger(
                stream.path, group=group, shards=stream.shards, **trigger_args
            )
def _deploy_function_refs(self, builder_env: dict | None = None):
    """Prepare the metadata of every child function and deploy it.

    Each child gets its name from ``fullname(self)``, the parent's project and tag,
    an ``mlrun/parent-function`` label, the parent's verbosity and secret sources.
    A child without a graph of its own runs the parent's graph, with
    ``SERVING_CURRENT_FUNCTION`` set to the child name.

    :param builder_env: env vars dict forwarded to each child's ``deploy``
    """
    for ref in self._spec.function_refs.values():
        logger.info(f"deploy child function {ref.name} ...")
        child = ref.function_object or ref.to_function(self.kind)

        meta = child.metadata
        meta.name = ref.fullname(self)
        meta.project = self.metadata.project
        meta.tag = self.metadata.tag
        meta.labels = meta.labels or {}
        meta.labels["mlrun/parent-function"] = self.metadata.name

        child._is_child_function = True
        if not child.spec.graph:
            # the child has no graph of its own - reuse the parent's graph
            child.set_env("SERVING_CURRENT_FUNCTION", ref.name)
            child.spec.graph = self.spec.graph

        child.verbose = self.verbose
        child.spec.secret_sources = self.spec.secret_sources
        child.deploy(builder_env=builder_env)
[docs]
def remove_states(self, keys: list):
    """Remove one, several, or all states/models from the graph.

    Does nothing when no graph is set.

    :param keys: names of the states to remove (a blank list removes all of them)
    """
    graph = self.spec.graph
    if not graph:
        return
    graph.clear_children(keys)
[docs]
def with_secrets(self, kind, source):
    """register a secrets source (file, env or dict)

    read secrets from a source provider to be used in workflows, example::

        task.with_secrets('file', 'file.txt')
        task.with_secrets('inline', {'key': 'val'})
        task.with_secrets('env', 'ENV1,ENV2')
        task.with_secrets('vault', ['secret1', 'secret2'...])

        # If using an empty secrets list [] then all accessible secrets will be available.
        task.with_secrets('vault', [])

        # To use with Azure key vault, a k8s secret must be created with the following keys:
        # kubectl -n <namespace> create secret generic azure-key-vault-secret \\
        #     --from-literal=tenant_id=<service principal tenant ID> \\
        #     --from-literal=client_id=<service principal client ID> \\
        #     --from-literal=secret=<service principal secret key>

        task.with_secrets('azure_vault', {
            'name': 'my-vault-name',
            'k8s_secret': 'azure-key-vault-secret',
            # An empty secrets list may be passed ('secrets': []) to access all vault secrets.
            'secrets': ['secret1', 'secret2'...]
        })

    For ``azure_vault`` the referenced ``k8s_secret`` name is checked with
    ``validate_not_forbidden_secret``; a ``vault`` secrets list is wrapped together
    with this function's project name.

    :param kind:   secret type (file, inline, env, vault, azure_vault)
    :param source: secret data or link (see example)

    :returns: The Runtime (function) object
    """
    if kind == "azure_vault" and isinstance(source, dict):
        k8s_secret_name = (source.get("k8s_secret") or "").strip()
        if k8s_secret_name:
            mlrun.common.secrets.validate_not_forbidden_secret(k8s_secret_name)
    elif kind == "vault" and isinstance(source, list):
        source = {"project": self.metadata.project, "secrets": source}

    self.spec.secret_sources.append({"kind": kind, "source": source})
    return self
@nuclio_function.min_nuclio_versions("1.12.10")
def deploy(
    self,
    project="",
    tag="",
    verbose=False,
    builder_env: dict | None = None,
    force_build: bool = False,
):
    """deploy model serving function to a local/remote cluster

    Child functions referenced by the graph are deployed first, then the root function.

    :param project: optional, override function specified project name
    :param tag:     specify unique function tag (a different function service is created for every tag)
    :param verbose: verbose logging
    :param builder_env: env vars dict for source archive config/credentials e.g. builder_env={"GIT_TOKEN": token};
                        also forwarded to the child functions' deployment
    :param force_build: set True for force building the image
    :raises ValueError: if the load mode is illegal or the graph is empty
    :raises mlrun.errors.MLRunInvalidArgumentError: if a graph step uses an undefined child function
    """
    load_mode = self.spec.load_mode
    if load_mode and load_mode not in ["sync", "async"]:
        raise ValueError(f"illegal model loading mode {load_mode}")
    if not self.spec.graph:
        raise ValueError("nothing to deploy, .spec.graph is none, use .add_model()")

    if self.spec.graph.kind != StepKinds.router and not getattr(
        self, "_is_child_function", None
    ):
        # initialize or create required streams/queues
        self.spec.graph.check_and_process_graph()
        self.spec.graph.create_queue_streams()
        functions_in_steps = self.spec.graph.list_child_functions()
        child_functions = list(self._spec.function_refs.keys())
        for function in functions_in_steps:
            if function not in child_functions:
                raise mlrun.errors.MLRunInvalidArgumentError(
                    f"function {function} is used in steps and is not defined, "
                    "use the .add_child_function() to specify child function attributes"
                )

    if (
        isinstance(self.spec.graph, RootFlowStep)
        and any(
            isinstance(step_type, mlrun.serving.states.ModelRunnerStep)
            for step_type in self.spec.graph.steps.values()
        )
        and self.spec.build.functionSourceCode
    ):
        # Add import for LLModel to the user code (only once) when the root
        # flow contains a ModelRunnerStep
        decoded_code = b64decode(self.spec.build.functionSourceCode).decode("utf-8")
        import_llmodel_code = "\nfrom mlrun.serving.states import LLModel\n"
        if import_llmodel_code not in decoded_code:
            decoded_code += import_llmodel_code
            encoded_code = mlrun.utils.helpers.encode_user_code(decoded_code)
            self.spec.build.functionSourceCode = encoded_code

    # Handle secret processing before handling child functions, since secrets are transferred to them
    if self.spec.secret_sources:
        # Before passing to remote builder, secrets values must be retrieved (for example from ENV)
        # and stored as inline secrets. Otherwise, they will not be available to the builder.
        self._secrets = SecretsStore.from_list(self.spec.secret_sources)
        self.spec.secret_sources = self._secrets.to_serial()

    if self._spec.function_refs:
        # ensure the function is available to the UI while deploying the child functions
        self.save(versioned=False)

        # deploy child functions; they need the same builder credentials as the root
        self._add_ref_triggers()
        self._deploy_function_refs(builder_env=builder_env)
        logger.info(f"deploy root function {self.metadata.name} ...")

    self._add_steps_requirements()
    return super().deploy(
        project,
        tag,
        verbose,
        builder_env=builder_env,
        force_build=force_build,
    )
def _get_serving_spec(self):
    """Serialize the configuration the serving graph server is started with.

    The JSON payload contains, in order:
    - the function identity (name, tag, uri, hash, project);
    - the graph in its stripped ``to_dict`` form, or ``{}`` when no graph is set;
    - the child-function ``name -> uri`` map;
    - the remaining serving settings read from ``self.spec``.

    ``api_handler_config`` and ``secret_sources`` are appended only when they
    are set. When secret sources exist, ``self._secrets`` is rebuilt from
    ``self.spec.secret_sources`` as a side effect.

    :return: the serving spec as a JSON string
    """
    fn_spec = self.spec
    fn_meta = self.metadata
    # child function name -> uri, resolved before any other attribute lookups
    child_uris = {ref.name: ref.uri(self) for ref in fn_spec.function_refs}
    graph = fn_spec.graph
    payload = dict(
        function_name=fn_meta.name,
        function_tag=fn_meta.tag,
        function_uri=self._function_uri(),
        function_hash=fn_meta.hash,
        project=fn_meta.project,
        version="v2",
        parameters=fn_spec.parameters,
        graph=graph.to_dict(strip=True) if graph else {},
        load_mode=fn_spec.load_mode,
        functions=child_uris,
        graph_initializer=fn_spec.graph_initializer,
        error_stream=fn_spec.error_stream,
        track_models=fn_spec.track_models,
        default_content_type=fn_spec.default_content_type,
        model_endpoint_creation_task_name=fn_spec.model_endpoint_creation_task_name,
        streaming=fn_spec.streaming,
        # TODO: find another way to pass this (needed for local run)
        filename=getattr(fn_spec, "filename", None),
    )
    # optional sections are only emitted when configured
    if fn_spec.api_handler_config:
        payload["api_handler_config"] = fn_spec.api_handler_config
    if fn_spec.secret_sources:
        self._secrets = SecretsStore.from_list(fn_spec.secret_sources)
        payload["secret_sources"] = self._secrets.to_serial()
    return json.dumps(payload)
@property
def serving_spec(self):
    """The serving spec of this function, serialized as a JSON string.

    The string is recomputed on every access via ``self._get_serving_spec()``.
    """
    serialized = self._get_serving_spec()
    return serialized
[docs]
def to_mock_server(
    self,
    namespace=None,
    current_function="*",
    track_models=False,
    workdir=None,
    stream_profile: ds_profile.DatastoreProfile | None = None,
    **kwargs,
) -> GraphServer:
    """create mock server object for local testing/emulation

    :param namespace:        one or list of namespaces/modules to search the steps classes/functions in
                             (a list passed by the caller is copied, not modified)
    :param current_function: specify if you want to simulate a child function, * for all functions
    :param track_models:     allow model tracking (disabled by default in the mock server).
                             NOTE(review): this argument is currently not read; tracking follows
                             ``self.spec.track_models`` — confirm whether that is intended.
    :param workdir:          working directory to locate the source code (if not the current one).
                             The process cwd is switched to it while the graph is built and restored
                             afterwards, even if building the server fails.
    :param stream_profile:   stream profile to use for the mock server output stream.
    :return: the initialized mock :py:class:`GraphServer`
    """
    # set the namespaces/modules to look for the steps code in. Copy a caller-supplied
    # list so the module and caller globals appended below do not leak back into it.
    if not namespace:
        namespace = []
    elif isinstance(namespace, list):
        namespace = list(namespace)
    else:
        namespace = [namespace]
    module = mlrun.run.function_to_module(self, silent=True, workdir=workdir)
    if module:
        namespace.append(module)
    # must be called directly from here so the stack lookup finds the user's frame
    namespace.append(get_caller_globals())

    old_workdir = None
    if workdir:
        old_workdir = os.getcwd()
        workdir = os.path.realpath(workdir)
        set_paths(workdir)
        os.chdir(workdir)

    try:
        server = create_graph_server(
            parameters=self.spec.parameters,
            load_mode=self.spec.load_mode,
            graph=self.spec.graph,
            verbose=self.verbose,
            current_function=current_function,
            graph_initializer=self.spec.graph_initializer,
            track_models=self.spec.track_models,
            function_uri=self._function_uri(),
            secret_sources=self.spec.secret_sources,
            default_content_type=self.spec.default_content_type,
            function_name=self.metadata.name,
            function_tag=self.metadata.tag,
            project=self.metadata.project,
            api_handler_config=self.spec.api_handler_config,
            **kwargs,
        )
        server.streaming = self.spec.streaming
        server.init_states(
            context=None,
            namespace=namespace,
            logger=logger,
            is_mock=True,
            monitoring_mock=self.spec.track_models,
            stream_profile=stream_profile,
        )
        server.graph = add_system_steps_to_graph(
            server.project,
            deepcopy(server.graph),
            self.spec.track_models,
            server.context,
            self.spec,
        )
        # Update context.root to point to the new graph
        server.context.root = server.graph
    finally:
        # always restore the caller's working directory, even when building the server failed
        if old_workdir is not None:
            os.chdir(old_workdir)

    # object initialization runs in the caller's working directory (as before)
    server.init_object(namespace)
    return server
[docs]
def plot(self, filename=None, format=None, source=None, **kw):
    """plot/save graph using graphviz

    example::

        serving_fn = mlrun.new_function(
            "serving", image="mlrun/mlrun", kind="serving"
        )
        serving_fn.add_model(
            "my-classifier",
            model_path=model_path,
            class_name="mlrun.frameworks.sklearn.SKLearnModelServer",
        )
        serving_fn.plot(rankdir="LR")

    :param filename: target filepath for the image (None for the notebook)
    :param format:   The output format used for rendering (``'pdf'``, ``'png'``, etc.)
    :param source:   source step to add to the graph
    :param kw:       kwargs passed to graphviz, e.g. rankdir="LR" (see: https://graphviz.org/doc/info/attrs.html)
    :return: graphviz graph object (as returned by the graph's own ``plot``)
    """
    graph = self.spec.graph
    rendered = graph.plot(filename, format=format, source=source, **kw)
    return rendered
def _set_as_mock(self, enable):
    """Turn the local mock server on or off.

    :param enable: when truthy, log a notice and build a fresh mock server with
                   ``self.to_mock_server()``; otherwise drop any existing mock server.
    """
    if enable:
        logger.info(
            "Deploying serving function MOCK (for simulation)...\n"
            "Turn off the mock (mock=False) and make sure Nuclio is installed for real deployment to Nuclio"
        )
        self._mock_server = self.to_mock_server()
    else:
        self._mock_server = None
[docs]
def to_job(self, func_name: str | None = None) -> "kubejob_runtime.KubejobRuntime":
    """Convert this ServingRuntime to a KubejobRuntime, so that the graph can be run as a standalone job.

    :param func_name: Optional custom name for the job function. If not provided, automatically
                      appends '-batch' suffix to the serving function name to prevent database collision.

    :return: KubejobRuntime configured to execute the serving graph as a batch job.

    :raises MLRunInvalidArgumentError: if the function has child functions, has streaming enabled,
        or (when no ``func_name`` is given) the generated job name is longer than
        ``K8S_DNS_1123_LABEL_MAX_LENGTH``.

    Note:
        The job will have a different name than the serving function to prevent database collision.
        The original serving function remains unchanged and can still be invoked after running the job.
        NOTE(review): ``_add_steps_requirements()`` is called on this function before the spec is
        copied, so step requirements are merged into this function's own build requirements too.
    """
    serving_name = self.metadata.name
    if self.spec.function_refs:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"Cannot convert function '{serving_name}' to a job because it has child functions"
        )
    if self.spec.streaming:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"Cannot convert function '{serving_name}' to a job because streaming "
            "is enabled. Streaming functions return real-time HTTP responses and cannot "
            "run as batch jobs. Please disable streaming with set_streaming(False) first."
        )

    # fold step-level requirements into the build before it is handed to the job spec
    self._add_steps_requirements()

    src = self.spec
    # settings carried over verbatim from the serving function's spec
    inherited_attrs = (
        "image",
        "mode",
        "volumes",
        "volume_mounts",
        "env",
        "resources",
        "pythonpath",
        "entry_points",
        "description",
        "workdir",
        "image_pull_secret",
        "build",
        "node_name",
        "node_selector",
        "affinity",
        "disable_auto_mount",
        "priority_class_name",
        "tolerations",
        "preemption_mode",
        "security_context",
        "state_thresholds",
        "track_models",
        "parameters",
        "graph",
    )
    spec_kwargs = {attr: getattr(src, attr) for attr in inherited_attrs}
    # the job runs the graph through the batch entry point with the serialized serving spec
    spec_kwargs["default_handler"] = "mlrun.serving.server.execute_graph"
    spec_kwargs["serving_spec"] = self._get_serving_spec()
    job_spec = pod_runtime.KubeResourceSpec(**spec_kwargs)

    job_metadata = deepcopy(self.metadata)
    original_name = job_metadata.name
    if func_name:
        # User provided explicit job name
        job_metadata.name = func_name
        logger.debug(
            "Creating job from serving function with custom name",
            new_name=func_name,
        )
    else:
        batch_name, was_renamed, suffix = mlrun.utils.helpers.ensure_batch_job_suffix(
            original_name
        )
        job_metadata.name = batch_name
        name_limit = mlrun.common.constants.K8S_DNS_1123_LABEL_MAX_LENGTH
        # Check if the resulting name exceeds Kubernetes length limit
        if len(batch_name) > name_limit:
            raise mlrun.errors.MLRunInvalidArgumentError(
                f"Cannot convert serving function '{original_name}' to batch job: "
                f"the resulting name '{batch_name}' ({len(batch_name)} characters) "
                f"exceeds Kubernetes limit of {name_limit} characters. "
                f"Please provide a custom name via the func_name parameter, "
                f"with at most {name_limit} characters."
            )
        if was_renamed:
            logger.info(
                "Creating job from serving function (auto-appended suffix to prevent collision)",
                new_name=batch_name,
                suffix=suffix,
            )
        else:
            logger.debug(
                "Creating job from serving function (name already has suffix)",
                name=original_name,
                suffix=suffix,
            )

    return kubejob_runtime.KubejobRuntime(spec=job_spec, metadata=job_metadata)
def _add_steps_requirements(self) -> None:
    """Merge the requirements of graph steps that run in this function into its build.

    A child function is recognized by the ``mlrun/parent-function`` label together with a
    name of the form ``<parent>-<child>``. In that case ``<child>`` is passed as
    ``current_function`` to each step's ``_is_local_function`` check; otherwise ``None``
    is passed. Steps without requirements, or not local to this function, are skipped.
    For the remaining steps, the step's requirements are merged with the current build
    requirements (build requirements take priority). The result is written back through
    ``with_requirements(..., overwrite=True)``.
    """
    labels = self.metadata.labels
    parent_name = labels.get("mlrun/parent-function") if labels else None
    own_name = self.metadata.name

    # only set when this function is a child of `parent_name`
    child_name = None
    if parent_name:
        prefix = parent_name + "-"
        if own_name.startswith(prefix):
            child_name = own_name[len(prefix) :]

    graph = getattr(self.spec, "graph", {})
    for step in getattr(graph, "steps", {}).values():
        step_reqs = getattr(step, "requirements", [])
        if not step_reqs:
            continue
        # only add requirements to the function if this step is local to it
        if not step._is_local_function(context=None, current_function=child_name):
            continue
        # re-read each time: the previous iteration may have updated the build
        current_reqs = getattr(getattr(self.spec, "build", {}), "requirements", [])
        merged = merge_requirements(
            reqs_priority=current_reqs,
            reqs_secondary=step_reqs,
        )
        self.with_requirements(requirements=merged, overwrite=True)
[docs]
def set_api_handler_config(self, config: Union[APIHandlerConfig, dict]) -> None:
    """Set the API handler configuration for the serving function.

    :param config: :py:class:`~mlrun.runtimes.nuclio.serving.APIHandlerConfig` object or dictionary containing
                   the configuration for handling different API endpoints and their actions.
                   A dict is round-tripped through ``APIHandlerConfig.from_dict(...).to_dict()``
                   before it is stored.

    :raises ValueError: if ``config`` is neither an ``APIHandlerConfig`` nor a ``dict``, or if
                        converting the dict through ``APIHandlerConfig`` raises.

    Example::

        # Using APIHandlerConfig object
        from mlrun.runtimes.nuclio.serving import APIHandlerConfig
        from mlrun.common.schemas.serving import APIHandlerAction
        from http import HTTPMethod

        api_config = APIHandlerConfig()
        api_config.add_endpoint_handler(
            "/v1/models", HTTPMethod.GET, APIHandlerAction.ALLOW
        )
        serving_fn.set_api_handler_config(api_config)

        # Using dictionary
        serving_fn.set_api_handler_config(
            {"endpoints": {("GET", "/v1/models"): {"action": "allow"}}}
        )
    """
    if not isinstance(config, (APIHandlerConfig, dict)):
        raise ValueError(
            f"config must be `APIHandlerConfig` or a `dict`, got {type(config)}"
        )

    if isinstance(config, dict):
        # Round-trip the dict through APIHandlerConfig to make sure it has the expected format
        try:
            normalized = APIHandlerConfig.from_dict(config).to_dict()
        except Exception as exc:
            raise ValueError(
                f"Invalid API handler config dict format: {exc}"
            ) from exc
    else:
        normalized = config.to_dict()

    # Store the configuration in the spec for serialization
    self.spec.api_handler_config = normalized