Source code for mlrun.runtimes.daskjob

# Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import inspect
import socket
import time
from os import environ
from typing import Callable, Optional, Union

import mlrun.common.schemas
import mlrun.errors
import mlrun.k8s_utils
import mlrun.utils
import mlrun.utils.regex
from mlrun.errors import err_to_str

from ..config import config
from ..execution import MLClientCtx
from ..model import RunObject
from ..render import ipython_display
from ..utils import logger
from .base import FunctionStatus
from .kubejob import KubejobRuntime
from .local import exec_from_params, load_module
from .pod import KubeResourceSpec
from .utils import RunError, get_func_selector, log_std


class DaskSpec(KubeResourceSpec):
    _dict_fields = KubeResourceSpec._dict_fields + [
        "extra_pip",
        "remote",
        "service_type",
        "nthreads",
        "kfp_image",
        "node_port",
        "min_replicas",
        "max_replicas",
        "scheduler_timeout",
        "scheduler_resources",
        "worker_resources",
    ]

    def __init__(
        self,
        command=None,
        args=None,
        image=None,
        mode=None,
        volumes=None,
        volume_mounts=None,
        env=None,
        resources=None,
        build=None,
        default_handler=None,
        entry_points=None,
        description=None,
        replicas=None,
        image_pull_policy=None,
        service_account=None,
        image_pull_secret=None,
        extra_pip=None,
        remote=None,
        service_type=None,
        nthreads=None,
        kfp_image=None,
        node_port=None,
        min_replicas=None,
        max_replicas=None,
        scheduler_timeout=None,
        node_name=None,
        node_selector=None,
        affinity=None,
        scheduler_resources=None,
        worker_resources=None,
        priority_class_name=None,
        disable_auto_mount=False,
        pythonpath=None,
        workdir=None,
        tolerations=None,
        preemption_mode=None,
        security_context=None,
        clone_target_dir=None,
        state_thresholds=None,
    ):
        super().__init__(
            command=command,
            args=args,
            image=image,
            mode=mode,
            volumes=volumes,
            volume_mounts=volume_mounts,
            env=env,
            resources=resources,
            replicas=replicas,
            image_pull_policy=image_pull_policy,
            service_account=service_account,
            build=build,
            default_handler=default_handler,
            entry_points=entry_points,
            description=description,
            image_pull_secret=image_pull_secret,
            node_name=node_name,
            node_selector=node_selector,
            affinity=affinity,
            priority_class_name=priority_class_name,
            disable_auto_mount=disable_auto_mount,
            pythonpath=pythonpath,
            workdir=workdir,
            tolerations=tolerations,
            preemption_mode=preemption_mode,
            security_context=security_context,
            clone_target_dir=clone_target_dir,
            state_thresholds=state_thresholds,
        )
        self.args = args

        self.extra_pip = extra_pip
        self.remote = True if remote is None else remote  # make remote the default

        self.service_type = service_type
        self.kfp_image = kfp_image
        self.node_port = node_port
        self.min_replicas = min_replicas or 0
        self.max_replicas = max_replicas or 16
        # supported format according to https://github.com/dask/dask/blob/master/dask/utils.py#L1402
        self.scheduler_timeout = scheduler_timeout or "60 minutes"
        self.nthreads = nthreads or 1
        self._scheduler_resources = self.enrich_resources_with_default_pod_resources(
            "scheduler_resources", scheduler_resources
        )
        self._worker_resources = self.enrich_resources_with_default_pod_resources(
            "worker_resources", worker_resources
        )

        self.state_thresholds = None  # not supported in dask

    @property
    def scheduler_resources(self) -> dict:
        return self._scheduler_resources

    @scheduler_resources.setter
    def scheduler_resources(self, resources):
        self._scheduler_resources = self.enrich_resources_with_default_pod_resources(
            "scheduler_resources", resources
        )

    @property
    def worker_resources(self) -> dict:
        return self._worker_resources

    @worker_resources.setter
    def worker_resources(self, resources):
        self._worker_resources = self.enrich_resources_with_default_pod_resources(
            "worker_resources", resources
        )
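
    # Usage sketch (illustrative, not executed): a dask function is normally
    # created with mlrun.new_function(kind="dask"); the fields above can then
    # be tuned through its spec, e.g.:
    #
    #     import mlrun
    #     fn = mlrun.new_function("my-dask", kind="dask", image="mlrun/ml-base")
    #     fn.spec.min_replicas = 1                  # adaptive scaling lower bound
    #     fn.spec.max_replicas = 4                  # adaptive scaling upper bound
    #     fn.spec.scheduler_timeout = "30 minutes"  # tear down an idle cluster
    #     fn.spec.service_type = "NodePort"         # expose scheduler/dashboard ports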


class DaskStatus(FunctionStatus):
    def __init__(
        self,
        state=None,
        build_pod=None,
        scheduler_address=None,
        cluster_name=None,
        node_ports=None,
    ):
        super().__init__(state, build_pod)

        self.scheduler_address = scheduler_address
        self.cluster_name = cluster_name
        self.node_ports = node_ports


class DaskCluster(KubejobRuntime):
    kind = "dask"
    _is_nested = False
    _is_remote = False

    def __init__(self, spec=None, metadata=None):
        super().__init__(spec, metadata)
        self._cluster = None
        self.use_remote = not mlrun.k8s_utils.is_running_inside_kubernetes_cluster()
        self.spec.build.base_image = self.spec.build.base_image or "daskdev/dask:latest"

    @property
    def spec(self) -> DaskSpec:
        return self._spec

    @spec.setter
    def spec(self, spec):
        self._spec = self._verify_dict(spec, "spec", DaskSpec)

    @property
    def status(self) -> DaskStatus:
        return self._status

    @status.setter
    def status(self, status):
        self._status = self._verify_dict(status, "status", DaskStatus)
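
    # Usage sketch (illustrative name `fn`): the runtime is normally obtained
    # via mlrun.new_function(kind="dask") rather than constructed directly;
    # when no image is set, daskdev/dask:latest is used as the build base:
    #
    #     fn = mlrun.new_function("my-dask", kind="dask")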

    def is_deployed(self):
        if not self.spec.remote:
            return True
        return super().is_deployed()

    @property
    def initialized(self):
        return bool(self._cluster)

    def _load_db_status(self):
        meta = self.metadata
        if self._is_remote_api():
            db = self._get_db()
            db_func = None
            try:
                db_func = db.get_function(meta.name, meta.project, meta.tag)
            except Exception:
                pass
            if db_func and "status" in db_func:
                self.status = db_func["status"]
                if self.kfp:
                    logger.info(f"Dask status: {db_func['status']}")
                return "scheduler_address" in db_func["status"]
        return False

    def _start(self, watch=True):
        db = self._get_db()
        if not self._is_remote_api():
            self._cluster = db.start_function(function=self)
            return

        self.try_auto_mount_based_on_config()
        self._fill_credentials()
        if not self.is_deployed():
            raise RunError(
                "Function image is not built/ready, use .deploy()"
                " method first, or set base dask image (daskdev/dask:latest)"
            )

        self.save(versioned=False)
        background_task = db.start_function(func_url=self._function_uri())
        if watch:
            now = datetime.datetime.utcnow()
            timeout = now + datetime.timedelta(minutes=10)
            while now < timeout:
                background_task = db.get_project_background_task(
                    background_task.metadata.project, background_task.metadata.name
                )
                if (
                    background_task.status.state
                    in mlrun.common.schemas.BackgroundTaskState.terminal_states()
                ):
                    if (
                        background_task.status.state
                        == mlrun.common.schemas.BackgroundTaskState.failed
                    ):
                        raise mlrun.errors.MLRunRuntimeError(
                            "Failed bringing up dask cluster"
                        )
                    else:
                        function = db.get_function(
                            self.metadata.name,
                            self.metadata.project,
                            self.metadata.tag,
                        )
                        if function and function.get("status"):
                            self.status = function.get("status")
                    return
                time.sleep(5)
                now = datetime.datetime.utcnow()

    def close(self, running=True):
        from dask.distributed import default_client

        try:
            client = default_client()
            # shutdown the cluster first, then close the client
            client.shutdown()
            client.close()
        except ValueError:
            pass

    def get_status(self):
        meta = self.metadata
        selector = get_func_selector(meta.project, meta.name, meta.tag)
        db = self._get_db()
        return db.function_status(meta.project, meta.name, self.kind, selector)

    def cluster(self):
        return self._cluster

    def _remote_addresses(self):
        addr = self.status.scheduler_address
        dash = ""
        if config.remote_host:
            if self.spec.service_type == "NodePort" and self.use_remote:
                addr = f"{config.remote_host}:{self.status.node_ports.get('scheduler')}"
            if self.spec.service_type == "NodePort":
                dash = f"{config.remote_host}:{self.status.node_ports.get('dashboard')}"
            else:
                logger.info("To get a dashboard link, use NodePort service_type")
        return addr, dash

    @property
    def client(self):
        from dask.distributed import Client, default_client

        if self.spec.remote and not self.status.scheduler_address:
            if not self._load_db_status():
                self._start()

        if self.status.scheduler_address:
            addr, dash = self._remote_addresses()
            logger.info(f"Trying dask client at: {addr}")
            try:
                client = Client(addr)
            except OSError as exc:
                logger.warning(
                    f"Remote scheduler at {addr} not ready, will try to restart {err_to_str(exc)}"
                )
                status = self.get_status()
                if status != "running":
                    self._start()
                addr, dash = self._remote_addresses()
                client = Client(addr)
            logger.info(
                f"Using remote dask scheduler ({self.status.cluster_name}) at: {addr}"
            )
            if dash:
                ipython_display(
                    f'<a href="http://{dash}/status" target="_blank" >dashboard link: {dash}</a>',
                    alt_text=f"remote dashboard: {dash}",
                )
            return client
        try:
            return default_client()
        except ValueError:
            return Client()
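
    # Usage sketch (assuming a deployed dask function `fn`): accessing the
    # `client` property lazily starts the remote cluster when needed and
    # returns a dask.distributed.Client bound to its scheduler:
    #
    #     client = fn.client                       # starts/attaches to the cluster
    #     future = client.submit(lambda x: x + 1, 41)
    #     assert future.result() == 42
    #     fn.close()                               # shut the cluster down when done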

    def deploy(
        self,
        watch=True,
        with_mlrun=None,
        skip_deployed=False,
        is_kfp=False,
        mlrun_version_specifier=None,
        builder_env: dict = None,
        show_on_failure: bool = False,
        force_build: bool = False,
    ):
        """Deploy the function: build a container image with its dependencies.

        :param watch:                   wait for the deploy to complete (and print build logs)
        :param with_mlrun:              add the current mlrun package to the container build
        :param skip_deployed:           skip the build if we already have an image for the function
        :param is_kfp:                  deploy as part of a kfp pipeline
        :param mlrun_version_specifier: which mlrun package version to include (if not current)
        :param builder_env:             Kaniko builder pod env vars dict (for config/credentials)
                                        e.g. builder_env={"GIT_TOKEN": token}
        :param show_on_failure:         show logs only in case of build failure
        :param force_build:             force building the image, even when no changes were made

        :return: True if the function is ready (deployed)
        """
        return super().deploy(
            watch,
            with_mlrun,
            skip_deployed,
            is_kfp=is_kfp,
            mlrun_version_specifier=mlrun_version_specifier,
            builder_env=builder_env,
            show_on_failure=show_on_failure,
            force_build=force_build,
        )
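
    # Usage sketch (illustrative): build an image with extra dependencies
    # before starting the cluster; `with_requirements` comes from the base
    # runtime:
    #
    #     fn.with_requirements(["dask[distributed]", "pandas"])
    #     assert fn.deploy(watch=True)    # returns True once the image is ready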

    def with_limits(
        self,
        mem=None,
        cpu=None,
        gpus=None,
        gpu_type="nvidia.com/gpu",
        patch: bool = False,
    ):
        raise NotImplementedError(
            "Use with_scheduler_limits/with_worker_limits to set resource limits",
        )

    def with_scheduler_limits(
        self,
        mem: str = None,
        cpu: str = None,
        gpus: int = None,
        gpu_type: str = "nvidia.com/gpu",
        patch: bool = False,
    ):
        """Set scheduler pod resource limits.

        By default this overrides the whole limits section; to patch only
        specific resources, use `patch=True`.
        """
        self.spec._verify_and_set_limits(
            "scheduler_resources", mem, cpu, gpus, gpu_type, patch=patch
        )

    def with_worker_limits(
        self,
        mem: str = None,
        cpu: str = None,
        gpus: int = None,
        gpu_type: str = "nvidia.com/gpu",
        patch: bool = False,
    ):
        """Set worker pod resource limits.

        By default this overrides the whole limits section; to patch only
        specific resources, use `patch=True`.
        """
        self.spec._verify_and_set_limits(
            "worker_resources", mem, cpu, gpus, gpu_type, patch=patch
        )

    def with_requests(self, mem=None, cpu=None, patch: bool = False):
        raise NotImplementedError(
            "Use with_scheduler_requests/with_worker_requests to set resource requests",
        )

    def with_scheduler_requests(
        self, mem: str = None, cpu: str = None, patch: bool = False
    ):
        """Set scheduler pod resource requests.

        By default this overrides the whole requests section; to patch only
        specific resources, use `patch=True`.
        """
        self.spec._verify_and_set_requests("scheduler_resources", mem, cpu, patch=patch)

    def with_worker_requests(
        self, mem: str = None, cpu: str = None, patch: bool = False
    ):
        """Set worker pod resource requests.

        By default this overrides the whole requests section; to patch only
        specific resources, use `patch=True`.
        """
        self.spec._verify_and_set_requests("worker_resources", mem, cpu, patch=patch)
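
    # Usage sketch: scheduler and worker pods are sized independently, which
    # is why the generic with_limits/with_requests above deliberately raise
    # NotImplementedError for this runtime:
    #
    #     fn.with_scheduler_limits(mem="2G", cpu="1")
    #     fn.with_worker_limits(mem="4G", cpu="2")
    #     fn.with_worker_requests(mem="2G", cpu="1")
    #     fn.with_worker_limits(gpus=1, patch=True)   # patch only the gpu limit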

    def set_state_thresholds(
        self,
        state_thresholds: dict[str, str],
        patch: bool = True,
    ):
        raise NotImplementedError(
            "State thresholds are not supported for the Dask runtime yet; use spec.scheduler_timeout instead.",
        )

    def run(
        self,
        runspec: Optional[
            Union["mlrun.run.RunTemplate", "mlrun.run.RunObject", dict]
        ] = None,
        handler: Optional[Union[str, Callable]] = None,
        name: Optional[str] = "",
        project: Optional[str] = "",
        params: Optional[dict] = None,
        inputs: Optional[dict[str, str]] = None,
        out_path: Optional[str] = "",
        workdir: Optional[str] = "",
        artifact_path: Optional[str] = "",
        watch: Optional[bool] = True,
        schedule: Optional[Union[str, mlrun.common.schemas.ScheduleCronTrigger]] = None,
        hyperparams: Optional[dict[str, list]] = None,
        hyper_param_options: Optional[mlrun.model.HyperParamOptions] = None,
        verbose: Optional[bool] = None,
        scrape_metrics: Optional[bool] = None,
        local: Optional[bool] = False,
        local_code_path: Optional[str] = None,
        auto_build: Optional[bool] = None,
        param_file_secrets: Optional[dict[str, str]] = None,
        notifications: Optional[list[mlrun.model.Notification]] = None,
        returns: Optional[list[Union[str, dict[str, str]]]] = None,
        state_thresholds: Optional[dict[str, int]] = None,
        **launcher_kwargs,
    ) -> RunObject:
        if state_thresholds:
            raise mlrun.errors.MLRunInvalidArgumentError(
                "State thresholds are not supported for the Dask runtime yet; use spec.scheduler_timeout instead."
            )

        return super().run(
            runspec=runspec,
            handler=handler,
            name=name,
            project=project,
            params=params,
            inputs=inputs,
            out_path=out_path,
            workdir=workdir,
            artifact_path=artifact_path,
            watch=watch,
            schedule=schedule,
            hyperparams=hyperparams,
            hyper_param_options=hyper_param_options,
            verbose=verbose,
            scrape_metrics=scrape_metrics,
            local=local,
            local_code_path=local_code_path,
            auto_build=auto_build,
            param_file_secrets=param_file_secrets,
            notifications=notifications,
            returns=returns,
            state_thresholds=state_thresholds,
            **launcher_kwargs,
        )
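
    # Usage sketch (assuming `my_handler` exists in the function's code): the
    # handler executes locally against the remote cluster, with the dask
    # client injected into the run context by _run below:
    #
    #     run = fn.run(handler="my_handler", params={"x": 1})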

    def _run(self, runobj: RunObject, execution):
        handler = runobj.spec.handler
        self._force_handler(handler)

        # TODO: investigate whether the following instructions could overwrite the environment on any MLRun API pod.
        # Such an action could result in race conditions against other runtimes and MLRun itself.
        extra_env = self._generate_runtime_env(runobj)
        environ.update(extra_env)

        context = MLClientCtx.from_dict(
            runobj.to_dict(),
            rundb=self.spec.rundb,
            autocommit=False,
            host=socket.gethostname(),
        )

        if not inspect.isfunction(handler):
            if not self.spec.command:
                raise ValueError(
                    "specified handler (string) without command "
                    "(py file path), specify command or use handler pointer"
                )
            handler = load_module(self.spec.command, handler, context=context)
        client = self.client
        setattr(context, "dask_client", client)
        sout, serr = exec_from_params(handler, runobj, context)
        log_std(self._db_conn, runobj, sout, serr, skip=self.is_child, show=False)
        return context.to_dict()