Source code for mlrun.serving.server

# Copyright 2018 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ["GraphServer", "create_graph_server", "GraphContext", "MockEvent"]

import asyncio
import json
import os
import socket
import sys
import traceback
import uuid
from typing import Optional, Union

import mlrun
from mlrun.config import config
from mlrun.errors import err_to_str
from mlrun.secrets import SecretsStore

from ..datastore import get_stream_pusher
from ..datastore.store_resources import ResourceCache
from ..errors import MLRunInvalidArgumentError
from ..model import ModelObj
from ..utils import create_logger, get_caller_globals, parse_versioned_object_uri
from .states import RootFlowStep, RouterStep, get_function, graph_root_setter
from .utils import event_id_key, event_path_key


class _StreamContext:
    """Resolve and hold the output stream used for event logging/tracking.

    The stream is activated only when a function uri is given and either
    tracking is enabled (with a configured default stream prefix) or the
    parameters carry an explicit ``log_stream`` target.
    """

    def __init__(self, enabled, parameters, function_uri):
        """
        :param enabled:      truthy to use the default stream prefix from the config
        :param parameters:   dict, may hold ``log_stream`` (stream uri template)
                             and ``stream_args`` (kwargs for the stream pusher)
        :param function_uri: uri of the serving function, the project is parsed from it
        """
        self.enabled = False
        self.hostname = socket.gethostname()
        self.function_uri = function_uri
        self.output_stream = None
        self.stream_uri = None

        custom_stream = parameters.get("log_stream", "")
        default_stream = config.model_endpoint_monitoring.store_prefixes.default

        stream_requested = (enabled and default_stream) or custom_stream
        if not (stream_requested and function_uri):
            # nothing to stream to, stay disabled
            return

        self.enabled = True

        project, _, _, _ = parse_versioned_object_uri(
            function_uri, config.default_project
        )

        # start from the configured default, an explicit log_stream overrides it
        resolved_uri = default_stream.format(project=project, kind="stream")
        if custom_stream:
            resolved_uri = custom_stream.format(project=project)

        pusher_kwargs = parameters.get("stream_args", {})

        self.stream_uri = resolved_uri
        self.output_stream = get_stream_pusher(resolved_uri, **pusher_kwargs)


class GraphServer(ModelObj):
    """Host object for a serving graph.

    Holds the graph root (router or flow), the user parameters, secrets and
    the runtime context, and turns incoming events into responses.
    """

    kind = "server"

    def __init__(
        self,
        graph=None,
        parameters=None,
        load_mode=None,
        function_uri=None,
        verbose=False,
        version=None,
        functions=None,
        graph_initializer=None,
        error_stream=None,
        track_models=None,
        tracking_policy=None,
        secret_sources=None,
        default_content_type=None,
    ):
        # `graph` is a property (see below): assigning it goes through
        # graph_root_setter, the backing field is created first
        self._graph = None
        self.graph: Union[RouterStep, RootFlowStep] = graph

        self.function_uri = function_uri
        self.parameters = parameters or {}
        self.verbose = verbose
        self.load_mode = load_mode or "sync"
        self.version = version or "v2"
        self.context = None
        self._current_function = None
        self.functions = functions or {}
        self.graph_initializer = graph_initializer

        # the stream object itself is created in set_error_stream()/init_states()
        self.error_stream = error_stream
        self.track_models = track_models
        self.tracking_policy = tracking_policy
        self._error_stream_object = None

        self.secret_sources = secret_sources
        self._secrets = SecretsStore.from_list(secret_sources)
        self._db_conn = None
        self.resource_cache = None
        self.default_content_type = default_content_type
        self.http_trigger = True

    def set_current_function(self, function):
        """set which child function this server is currently running on"""
        self._current_function = function

    @property
    def graph(self) -> Union[RootFlowStep, RouterStep]:
        """the graph root step (router or flow)"""
        return self._graph

    @graph.setter
    def graph(self, graph):
        graph_root_setter(self, graph)

    def set_error_stream(self, error_stream):
        """set/initialize the error notification stream

        :param error_stream: stream uri, a falsy value removes the stream object
        """
        self.error_stream = error_stream
        self._error_stream_object = (
            get_stream_pusher(error_stream) if error_stream else None
        )

    def _get_db(self):
        """return the run DB object, created with the server secrets"""
        return mlrun.get_run_db(secrets=self._secrets)

    def init_states(
        self,
        context,
        namespace,
        resource_cache: ResourceCache = None,
        logger=None,
        is_mock=False,
    ):
        """for internal use, initialize all steps (recursively)

        :param context:        nuclio context (or None), wrapped by a GraphContext
        :param namespace:      where the graph initializer is looked up when it
                               is given by name
        :param resource_cache: optional ResourceCache, a new one is created if not set
        :param logger:         optional logger handed to the GraphContext
        :param is_mock:        stored on the context as ``is_mock``
        """
        if self.secret_sources:
            self._secrets = SecretsStore.from_list(self.secret_sources)

        if self.error_stream:
            self._error_stream_object = get_stream_pusher(self.error_stream)

        self.resource_cache = resource_cache or ResourceCache()

        graph_context = GraphContext(
            server=self, nuclio_context=context, logger=logger
        )
        graph_context.is_mock = is_mock
        graph_context.root = self.graph
        graph_context.stream = _StreamContext(
            self.track_models, self.parameters, self.function_uri
        )
        graph_context.current_function = self._current_function
        graph_context.get_store_resource = self.resource_cache.resource_getter(
            self._get_db(), self._secrets
        )
        graph_context.get_table = self.resource_cache.get_table
        graph_context.verbose = self.verbose
        self.context = graph_context

        initializer = self.graph_initializer
        if initializer:
            # the initializer is either a callable or a name resolved in namespace
            handler = (
                initializer
                if callable(initializer)
                else get_function(initializer, namespace or [])
            )
            handler(self)

        # re-read the graph after the initializer hook has run
        graph_context.root = self.graph

    def init_object(self, namespace):
        """initialize the graph steps and return the nuclio handler to use

        The async handler is returned when the datastore async source mode is
        "enabled" in the config, the sync handler otherwise.
        """
        self.graph.init_object(self.context, namespace, self.load_mode, reset=True)
        if config.datastore.async_source_mode == "enabled":
            return v2_serving_async_handler
        return v2_serving_handler

    def test(
        self,
        path: str = "/",
        body: Union[str, bytes, dict] = None,
        method: str = "",
        headers: Optional[str] = None,
        content_type: Optional[str] = None,
        silent: bool = False,
        get_body: bool = True,
        event_id: Optional[str] = None,
        trigger: "MockTrigger" = None,
        offset=None,
        time=None,
    ):
        """invoke a test event into the server to simulate/test server behavior

        example::

            server = create_graph_server()
            server.add_model("my", class_name=MyModelClass, model_path="{path}", z=100)
            print(server.test("my/infer", testdata))

        :param path:         api path, e.g. (/{router.url_prefix}/{model-name}/..) path
        :param body:         message body (dict or json str/bytes)
        :param method:       optional, GET, POST, .. (POST when a body is given, else GET)
        :param headers:      optional, request headers, ..
        :param content_type: optional, http mime type
        :param silent:       don't raise on error responses (when not 20X)
        :param get_body:     return the body as py object (vs serialize response into json)
        :param event_id:     specify the unique event ID (by default a random value will be generated)
        :param trigger:      nuclio trigger info or mlrun.serving.server.MockTrigger class (holds kind and name)
        :param offset:       trigger offset (for streams)
        :param time:         event time Datetime or str, default to now()
        :raises MLRunInvalidArgumentError: when no graph was set
        :raises RuntimeError: on a response with status >= 300 (unless silent)
        """
        if not self.graph:
            raise MLRunInvalidArgumentError(
                "no models or steps were set, use function.set_topology() and add steps"
            )

        http_method = method or ("POST" if body else "GET")
        mock_event = MockEvent(
            body=body,
            path=path,
            method=http_method,
            headers=headers,
            content_type=content_type,
            event_id=event_id,
            trigger=trigger,
            offset=offset,
            time=time,
        )
        resp = self.run(mock_event, get_body=get_body)

        is_error = hasattr(resp, "status_code") and resp.status_code >= 300
        if is_error and not silent:
            raise RuntimeError(f"failed ({resp.status_code}): {resp.body}")
        return resp

    def run(self, event, context=None, get_body=False, extra_args=None):
        """run an event through the graph and return the processed response

        :param event:      nuclio (or mock) event, its id/path can be overridden
                           through the event headers
        :param context:    optional context, the server context is used if not set
        :param get_body:   return the response body rather than a Response object
        :param extra_args: optional kwargs forwarded to the graph run
        :return: the response (or body), a 400 Response on failure, or a
                 coroutine when the graph run returned a coroutine
        """
        server_context = self.context
        context = context or server_context
        event.content_type = event.content_type or self.default_content_type or ""

        headers = event.headers
        if headers:
            if event_id_key in headers:
                event.id = headers.get(event_id_key)
            if event_path_key in headers:
                event.path = headers.get(event_path_key)

        declared_json = event.content_type in ["json", "application/json"]
        if isinstance(event.body, (str, bytes)) and (
            not event.content_type or declared_json
        ):
            # assume it is json and try to load
            try:
                event.body = json.loads(event.body)
            except (json.decoder.JSONDecodeError, UnicodeDecodeError) as exc:
                if declared_json:
                    # if its json type and didnt load, raise exception
                    message = f"failed to json decode event, {err_to_str(exc)}"
                    context.logger.error(message)
                    server_context.push_error(event, message, source="_handler")
                    return context.Response(
                        body=message, content_type="text/plain", status_code=400
                    )

        try:
            response = self.graph.run(event, **(extra_args or {}))
        except Exception as exc:
            message = f"{exc.__class__.__name__}: {err_to_str(exc)}"
            trace = traceback.format_exc()
            if server_context.verbose:
                # verbose mode exposes the traceback in the response as well
                message += "\n" + trace
            context.logger.error(f"run error, {trace}")
            server_context.push_error(event, message, source="_handler")
            return context.Response(
                body=message, content_type="text/plain", status_code=400
            )

        if asyncio.iscoroutine(response):
            return self._process_async_response(context, response, get_body)
        return self._process_response(context, response, get_body)

    async def _process_async_response(self, context, response, get_body):
        """await the graph response and process it like a sync response"""
        resolved = await response
        return self._process_response(context, resolved, get_body)

    def _process_response(self, context, response, get_body):
        """extract the response body, json-wrap non str/bytes bodies

        The body is returned untouched when it already is a Response, when
        the caller asked for the body, or when it is empty / str / bytes.
        """
        body = response.body
        if isinstance(body, context.Response) or get_body:
            return body

        if not body or isinstance(body, (str, bytes)):
            return body

        return context.Response(
            body=json.dumps(body), content_type="application/json", status_code=200
        )

    def wait_for_completion(self):
        """wait for async operation to complete"""
        self.graph.wait_for_completion()
def v2_serving_init(context, namespace=None):
    """hook for nuclio init_context()

    Builds the GraphServer from the spec stored in the ``SERVING_SPEC_ENV``
    environment variable, initializes it, and attaches the request handler
    (``context.mlrun_handler``) and the server (``context._server``) to the
    nuclio context.

    :param context:   nuclio context
    :param namespace: namespace used to resolve graph classes/functions,
                      defaults to the caller globals
    :raises MLRunInvalidArgumentError: when the spec env var is missing/empty
    """
    data = os.environ.get("SERVING_SPEC_ENV", "")
    if not data:
        raise MLRunInvalidArgumentError("failed to find spec env var")
    spec = json.loads(data)

    context.logger.info("Initializing server from spec")
    server = GraphServer.from_dict(spec)
    if config.log_level.lower() == "debug":
        server.verbose = True
    if hasattr(context, "trigger"):
        # a trigger without a kind is treated as http
        server.http_trigger = getattr(context.trigger, "kind", "http") == "http"

    current_function = os.environ.get("SERVING_CURRENT_FUNCTION", "")
    context.logger.info_with(
        "Setting current function", current_function=current_function
    )
    server.set_current_function(current_function)

    # resolve the namespace once (from this frame) and reuse it below
    namespace = namespace or get_caller_globals()
    context.logger.info_with("Initializing states", namespace=namespace)
    server.init_states(context, namespace)

    context.logger.info("Initializing graph steps")
    serving_handler = server.init_object(namespace)

    # set the handler hook to point to our handler
    context.mlrun_handler = serving_handler
    context._server = server
    context.logger.info_with("Serving was initialized", verbose=server.verbose)
    if server.verbose:
        context.logger.info(server.to_yaml())


def v2_serving_handler(context, event, get_body=False):
    """hook for nuclio handler()

    :param context:  nuclio context holding the server as ``_server``
    :param event:    nuclio event
    :param get_body: passed through to the server run
    :return: whatever the server run returns
    """
    if context._server.http_trigger:
        # Workaround for a Nuclio bug where it sometimes passes b'' instead of None due to dirty memory
        if event.body == b"":
            event.body = None
    else:
        event.path = "/"  # fix the issue that non http returns "Unsupported"

    return context._server.run(event, context, get_body)


async def v2_serving_async_handler(context, event, get_body=False):
    """hook for nuclio handler() (async flavor)

    :param context:  nuclio context holding the server as ``_server``
    :param event:    nuclio event
    :param get_body: passed through to the server run
    :return: the (awaited) result of the server run
    """
    result = context._server.run(event, context, get_body)
    # the server run returns a plain response (not a coroutine) on its error
    # paths, e.g. a 400 for a json decode failure; awaiting such an object
    # would raise TypeError and hide the real error
    if asyncio.iscoroutine(result):
        return await result
    return result
def create_graph_server(
    parameters=None,
    load_mode=None,
    graph=None,
    verbose=False,
    current_function=None,
    **kwargs,
) -> GraphServer:
    """create graph server host/emulator for local or test runs

    Usage example::

        server = create_graph_server(graph=RouterStep(), parameters={})
        server.graph.add_route("my", class_name=MyModelClass, model_path="{path}", z=100)
        server.init_states(None, globals())
        server.init_object(globals())
        print(server.test("/v2/models/my/infer", testdata))

    :param parameters:       optional dict of server parameters (None means no parameters)
    :param load_mode:        passed to the GraphServer
    :param graph:            graph root (object or dict), passed to the GraphServer
    :param verbose:          verbose flag, passed to the GraphServer
    :param current_function: name of the child function this server runs on,
                             defaults to the ``SERVING_CURRENT_FUNCTION`` env var
    :param kwargs:           extra GraphServer keyword arguments
    :return: the new GraphServer
    """
    # `parameters` defaults to None rather than a shared mutable {} literal,
    # GraphServer turns a falsy value into a fresh dict
    server = GraphServer(graph, parameters, load_mode, verbose=verbose, **kwargs)
    server.set_current_function(
        current_function or os.environ.get("SERVING_CURRENT_FUNCTION", "")
    )
    return server
class MockTrigger(object):
    """mock nuclio event trigger"""

    def __init__(self, kind="", name=""):
        self.kind = kind
        self.name = name


class MockEvent(object):
    """mock basic nuclio event object"""

    def __init__(
        self,
        body=None,
        content_type=None,
        headers=None,
        method=None,
        path=None,
        event_id=None,
        trigger: MockTrigger = None,
        offset=None,
        time=None,
    ):
        """
        :param body:         event body
        :param content_type: optional, mime type
        :param headers:      optional, headers dict (empty dict when not set)
        :param method:       optional, http method
        :param path:         optional, event path ("/" when not set)
        :param event_id:     optional, event id (random hex uuid when not set)
        :param trigger:      optional, MockTrigger (an empty one when not set)
        :param offset:       optional, trigger offset (0 when not set)
        :param time:         optional, event time; kept as ``self.time`` only
                             when a value is given
        """
        self.id = event_id or uuid.uuid4().hex
        self.key = ""
        self.body = body

        # optional
        self.headers = headers or {}
        self.method = method
        self.path = path or "/"
        self.content_type = content_type
        self.error = None
        self.trigger = trigger or MockTrigger()
        self.offset = offset or 0
        # `time` used to be accepted and silently dropped; keep it on the
        # event, but only when supplied so events built without a time are
        # unchanged (no `time` attribute at all)
        if time is not None:
            self.time = time

    def __str__(self):
        error = f", error={self.error}" if self.error else ""
        return f"Event(id={self.id}, body={self.body}, method={self.method}, path={self.path}{error})"


class Response(object):
    """basic response object (headers, body, status code and content type)"""

    def __init__(self, headers=None, body=None, content_type=None, status_code=200):
        self.headers = headers or {}
        self.body = body
        self.status_code = status_code
        self.content_type = content_type or "text/plain"

    def __repr__(self):
        cls = self.__class__.__name__
        args_str = ", ".join(
            f"{key}={value!r}" for key, value in self.__dict__.items()
        )
        return f"{cls}({args_str})"
class GraphContext:
    """Graph context object"""

    def __init__(self, level="info", logger=None, server=None, nuclio_context=None):
        """
        :param level:          log level for the logger created when neither a
                               logger nor a nuclio context is given
        :param logger:         optional logger
        :param server:         the server object this context belongs to
        :param nuclio_context: optional nuclio context, when set its logger,
                               Response class and worker id are used
        """
        self.state = None
        self.logger = logger
        self.worker_id = 0
        self.Response = Response
        self.verbose = False
        self.stream = None
        self.root = None

        if nuclio_context:
            # a nuclio context overrides the logger passed by the caller
            self.logger = nuclio_context.logger
            self.Response = nuclio_context.Response
            self.worker_id = nuclio_context.worker_id
        elif not logger:
            self.logger = create_logger(level, "human", "flow", sys.stdout)

        self._server = server
        self.current_function = None
        self.get_store_resource = None
        self.get_table = None
        self.is_mock = False

    @property
    def server(self):
        """the server object this context was created with"""
        return self._server

    @property
    def project(self):
        """current project name (for the current function)"""
        parsed = mlrun.utils.parse_versioned_object_uri(self._server.function_uri)
        project, _, _, _ = parsed
        return project

    def push_error(self, event, message, source=None, **kwargs):
        """log the error (in verbose mode) and push it to the error stream if one is set

        :param event:   the event which caused the error
        :param message: error message
        :param source:  optional, name of the error source
        :param kwargs:  extra values sent with the error as "args"
        """
        if self.verbose:
            self.logger.error(
                f"got error from {source} state:\n{event.body}\n{message}"
            )

        server = self._server
        if not (server and server._error_stream_object):
            return

        try:
            record = format_error(server, self, source, event, message, kwargs)
            server._error_stream_object.push(record)
        except Exception as ex:
            # a broken error stream is only logged, it does not raise
            trace = traceback.format_exc()
            self.logger.error(f"failed to write to error stream: {ex}\n{trace}")

    def get_param(self, key: str, default=None):
        """return a server parameter by key, or the default when it is not available"""
        server = self._server
        if not (server and server.parameters):
            return default
        return server.parameters.get(key, default)

    def get_secret(self, key: str):
        """return a secret value by key, None when the server has no secrets"""
        server = self._server
        if not (server and server._secrets):
            return None
        return server._secrets.get(key)

    def get_remote_endpoint(self, name, external=True):
        """return the remote nuclio/serving function http(s) endpoint given its name

        :param name: the function name/uri in the form [project/]function-name[:tag]
        :param external: return the external url (returns the external url by default)
        :raises ValueError: when the function is in error/unhealthy state or
                            has no invocation url of the requested kind
        """
        if "://" in name:
            # already a url, nothing to resolve
            return name

        project, uri, tag, _ = mlrun.utils.parse_versioned_object_uri(
            self._server.function_uri
        )
        if name.startswith("."):
            # ".child" is resolved relative to the current function uri
            name = f"{uri}-{name[1:]}"
        else:
            project, name, tag, _ = mlrun.utils.parse_versioned_object_uri(
                name, project
            )

        deploy_status = mlrun.runtimes.function.get_nuclio_deploy_status(
            name, project, tag
        )
        state, fullname, _, _, _, function_status = deploy_status

        if state in ("error", "unhealthy"):
            raise ValueError(
                f"Nuclio function {fullname} is in error state, cannot be accessed"
            )

        if external:
            key = "externalInvocationUrls"
        else:
            key = "internalInvocationUrls"

        urls = function_status.get(key)
        if not urls:
            raise ValueError(f"cannot read {key} for nuclio function {fullname}")
        return f"http://{urls[0]}"
def format_error(server, context, source, event, message, args):
    """build the error record (dict) that is pushed to the error stream

    :param server:  server object, its ``function_uri`` is reported
    :param context: context object, its ``worker_id`` is reported
    :param source:  name of the error source
    :param event:   the failed event, its id and body are reported
    :param message: error message
    :param args:    extra values to attach to the record
    :return: dict with the function uri, worker, host, source, event, message and args
    """
    record = {}
    record["function_uri"] = server.function_uri
    record["worker"] = context.worker_id
    record["host"] = socket.gethostname()
    record["source"] = source
    record["event"] = {"id": event.id, "body": event.body}
    record["message"] = message
    record["args"] = args
    return record