# Copyright 2018 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util as imputil
import inspect
import json
import os
import socket
import sys
import tempfile
import traceback
from contextlib import redirect_stdout
from copy import copy
from io import StringIO
from os import environ, remove
from pathlib import Path
from subprocess import PIPE, Popen
from sys import executable
from distributed import Client, as_completed
from nuclio import Event
import mlrun
from mlrun.lists import RunList
from ..execution import MLClientCtx
from ..model import RunObject
from ..utils import get_handler_extended, get_in, logger, set_paths
from ..utils.clones import extract_source
from .base import BaseRuntime, FunctionSpec, spec_fields
from .kubejob import KubejobRuntime
from .remotesparkjob import RemoteSparkRuntime
from .utils import RunError, global_context, log_std
class ParallelRunner:
    """Mixin that runs the iterations of a hyper-param task in parallel over Dask.

    The host class must supply ``spec``, ``_force_handler``, ``_db_conn``,
    ``is_child`` and ``_update_run_state``; they are used below but are not
    defined in this class.
    """

    def _get_handler(self, handler, context):
        """Return the callable to execute (the base implementation is a pass-through)."""
        return handler

    def _get_dask_client(self, options):
        """Return ``(client, function_name)`` for the Dask cluster to run on.

        When ``options.dask_cluster_uri`` is set, the function it points to is
        imported and its client and metadata name are returned. Otherwise a
        default ``Client()`` is created and the name is ``None``.
        """
        if options.dask_cluster_uri:
            function = mlrun.import_function(options.dask_cluster_uri)
            return function.client, function.metadata.name
        return Client(), None

    def _parallel_run_many(
        self, generator, execution: MLClientCtx, runobj: RunObject
    ) -> RunList:
        """Submit every generated task to Dask and collect the run results.

        At most ``generator.options.parallel_runs`` (default 4) tasks are kept
        in flight. Submission stops early once the number of failed runs
        exceeds ``generator.max_errors`` or the generator's stop condition
        evaluates truthy; runs that were already submitted are still collected.

        :param generator: task generator exposing ``options``, ``generate()``,
                          ``max_errors`` and ``eval_stop_condition()``
        :param execution: parent execution context, passed to ``_get_handler``
        :param runobj:    the parent run object the tasks are generated from
        :returns: a ``RunList`` holding the updated run dict of every processed task
        :raises mlrun.errors.MLRunRuntimeError: when the function has a build
                source together with a remote dask cluster uri
        """
        if self.spec.build.source and generator.options.dask_cluster_uri:
            # the attached dask cluster will not have the source code when we clone the git on run
            raise mlrun.errors.MLRunRuntimeError(
                "Cannot load source code into remote Dask at runtime use, "
                "function.deploy() to add the code into the image instead"
            )
        results = RunList()
        tasks = generator.generate(runobj)
        handler = runobj.spec.handler
        self._force_handler(handler)
        set_paths(self.spec.pythonpath)
        handler = self._get_handler(handler, execution)

        client, function_name = self._get_dask_client(generator.options)
        parallel_runs = generator.options.parallel_runs or 4
        queued_runs = 0
        num_errors = 0

        def process_result(future):
            """Record one finished run; return a truthy value when iterations should stop."""
            nonlocal num_errors
            resp, sout, serr = future.result()
            child_run = RunObject.from_dict(resp)
            try:
                log_std(self._db_conn, child_run, sout, serr, skip=self.is_child)
                resp = self._update_run_state(resp)
            except RunError as err:
                resp = self._update_run_state(resp, err=str(err))
                num_errors += 1
            results.append(resp)
            if num_errors > generator.max_errors:
                logger.error("max errors reached, stopping iterations!")
                return True
            run_results = resp["status"].get("results", {})
            stop = generator.eval_stop_condition(run_results)
            if stop:
                logger.info(
                    f"reached early stop condition ({generator.options.stop_condition}), stopping iterations!"
                )
            return stop

        # close the client even when submitting or collecting a run raises,
        # otherwise the Dask connection is leaked on the error path
        try:
            completed_iter = as_completed([])
            for task in tasks:
                task_struct = task.to_dict()
                project = get_in(task_struct, "metadata.project")
                uid = get_in(task_struct, "metadata.uid")
                iteration = get_in(task_struct, "metadata.iteration", 0)
                mlrun.get_run_db().store_run(
                    task_struct, uid=uid, project=project, iter=iteration
                )
                resp = client.submit(
                    remote_handler_wrapper, task.to_json(), handler, self.spec.workdir
                )
                completed_iter.add(resp)
                queued_runs += 1
                if queued_runs >= parallel_runs:
                    # the window is full: wait for one run before submitting more
                    future = next(completed_iter)
                    early_stop = process_result(future)
                    queued_runs -= 1
                    if early_stop:
                        break

            # drain the runs still in flight; stop signals are ignored here
            # since nothing more is going to be submitted
            for future in completed_iter:
                process_result(future)
        finally:
            client.close()

        if function_name and generator.options.teardown_dask:
            logger.info("tearing down the dask cluster..")
            mlrun.get_run_db().delete_runtime_resources(
                kind="dask", object_id=function_name, force=True
            )

        return results
def remote_handler_wrapper(task, handler, workdir=None):
    """Execute one task with ``handler``; this is the callable submitted to Dask.

    :param task:    the task as a dict, or any other truthy value which is
                    parsed with ``json.loads`` (the submitter passes ``task.to_json()``)
    :param handler: the callable to run
    :param workdir: optional directory handed to ``exec_from_params`` as ``cwd``
    :returns: ``(context dict, captured stdout, error text)``
    """
    needs_parsing = bool(task) and not isinstance(task, dict)
    if needs_parsing:
        task = json.loads(task)

    ctx = MLClientCtx.from_dict(task, autocommit=False, host=socket.gethostname())
    run = RunObject.from_dict(task)

    captured_out, captured_err = exec_from_params(handler, run, ctx, workdir)
    return ctx.to_dict(), captured_out, captured_err
class HandlerRuntime(BaseRuntime, ParallelRunner):
    """Runtime of kind ``"handler"``: runs the run's handler in the current process."""

    kind = "handler"

    def _run(self, runobj: RunObject, execution: MLClientCtx):
        """Execute ``runobj.spec.handler`` in-process and return the context as a dict.

        A ``.json`` temp file is created and its path is published through the
        ``MLRUN_META_TMPFILE`` environment variable and passed to the new
        context as ``tmp``. The new context is also set as the global context.
        """
        func = runobj.spec.handler
        self._force_handler(func)

        meta_path = tempfile.NamedTemporaryFile(suffix=".json", delete=False).name
        environ["MLRUN_META_TMPFILE"] = meta_path
        set_paths(self.spec.pythonpath)

        ctx = MLClientCtx.from_dict(
            runobj.to_dict(),
            rundb=self.spec.rundb,
            autocommit=False,
            tmp=meta_path,
            host=socket.gethostname(),
        )
        global_context.set(ctx)

        std_out, std_err = exec_from_params(func, runobj, ctx, self.spec.workdir)
        log_std(self._db_conn, runobj, std_out, std_err, show=False)
        return ctx.to_dict()
class LocalFunctionSpec(FunctionSpec):
    """Function spec for local runs: the base spec plus ``clone_target_dir``."""

    # the extra field is appended to the base spec's field list
    _dict_fields = spec_fields + ["clone_target_dir"]

    def __init__(
        self,
        command=None,
        args=None,
        mode=None,
        default_handler=None,
        pythonpath=None,
        entry_points=None,
        description=None,
        workdir=None,
        build=None,
        clone_target_dir=None,
    ):
        """Forward the common fields to ``FunctionSpec`` and keep ``clone_target_dir``."""
        base_fields = {
            "command": command,
            "args": args,
            "mode": mode,
            "build": build,
            "entry_points": entry_points,
            "description": description,
            "workdir": workdir,
            "default_handler": default_handler,
            "pythonpath": pythonpath,
        }
        super().__init__(**base_fields)
        self.clone_target_dir = clone_target_dir
class LocalRuntime(BaseRuntime, ParallelRunner):
    """Runtime of kind ``"local"``: runs a handler in-process or a command as a child process."""

    kind = "local"
    _is_remote = False

    @property
    def spec(self) -> LocalFunctionSpec:
        return self._spec

    @spec.setter
    def spec(self, spec):
        self._spec = self._verify_dict(spec, "spec", LocalFunctionSpec)

    def to_job(self, image=""):
        """Build a ``KubejobRuntime`` from this function's dict.

        :param image: when non-empty, set as the new object's ``spec.image``
        :returns: the new ``KubejobRuntime`` object
        """
        struct = self.to_dict()
        obj = KubejobRuntime.from_dict(struct)
        if image:
            obj.spec.image = image
        return obj

    def with_source_archive(self, source, workdir=None, handler=None, target_dir=None):
        """load the code from git/tar/zip archive at runtime or build

        :param source:     valid path to git, zip, or tar file, e.g.
                           git://github.com/mlrun/something.git
                           http://some/url/file.zip
        :param handler:    default function handler
        :param workdir:    working dir relative to the archive root or absolute (e.g. './subdir')
        :param target_dir: local target dir for repo clone (by default its <current-dir>/code)
        """
        self.spec.build.source = source
        self.spec.build.load_source_on_run = True
        if handler:
            self.spec.default_handler = handler
        if workdir:
            self.spec.workdir = workdir
        if target_dir:
            self.spec.clone_target_dir = target_dir

    def is_deployed(self):
        """Always ``True`` for this runtime."""
        return True

    def _get_handler(self, handler, context):
        """Load the module behind the function's command and resolve ``handler`` in it."""
        command = self.spec.command
        if not command and self.spec.build.functionSourceCode:
            # if the code is embedded in the function object extract or find it
            command, _ = mlrun.run.load_func_code(self)
        return load_module(command, handler, context)

    def _pre_run(self, runobj: RunObject, execution: MLClientCtx):
        """Prepare the working directory (and source archive) before the run.

        Records the effective working dir on ``execution._current_workdir`` and
        the directory to return to on ``execution._old_workdir`` (``None`` when
        the process did not change directory).
        """
        workdir = self.spec.workdir
        execution._current_workdir = workdir
        execution._old_workdir = None

        if self.spec.build.source and not hasattr(self, "_is_run_local"):
            target_dir = extract_source(
                self.spec.build.source,
                self.spec.clone_target_dir,
                secrets=execution._secrets_manager,
            )
            if workdir and not workdir.startswith("/"):
                # a relative workdir is resolved against the extracted source
                execution._current_workdir = os.path.join(target_dir, workdir)
            else:
                execution._current_workdir = workdir or target_dir

        if execution._current_workdir:
            execution._old_workdir = os.getcwd()
            workdir = os.path.realpath(execution._current_workdir)
            set_paths(workdir)
            os.chdir(workdir)
        else:
            set_paths(os.path.realpath("."))

        # use .get() so a missing label or an unset env var simply skips the
        # Spark hook instead of failing the run with a KeyError
        labels = runobj.metadata.labels or {}
        if (
            labels.get("kind") == RemoteSparkRuntime.kind
            and environ.get("MLRUN_SPARK_CLIENT_IGZ_SPARK") == "true"
        ):
            from mlrun.runtimes.remotesparkjob import igz_spark_pre_hook

            igz_spark_pre_hook()

    def _post_run(self, results, execution: MLClientCtx):
        """Return to the directory the process was in before ``_pre_run`` changed it."""
        if execution._old_workdir:
            os.chdir(execution._old_workdir)

    def _run(self, runobj: RunObject, execution: MLClientCtx):
        """Run the function locally and return the resulting run as a dict.

        With a handler, the handler is loaded and called in this process.
        Without one, ``spec.command`` (formatted with the run parameters) is
        executed as a child process and the run state is read back from the
        temp file published through ``MLRUN_META_TMPFILE``.
        """
        environ["MLRUN_EXEC_CONFIG"] = runobj.to_json()
        tmp = tempfile.NamedTemporaryFile(suffix=".json", delete=False).name
        environ["MLRUN_META_TMPFILE"] = tmp
        if self.spec.rundb:
            environ["MLRUN_DBPATH"] = self.spec.rundb

        handler = runobj.spec.handler
        handler_str = handler or "main"
        logger.debug(f"starting local run: {self.spec.command} # {handler_str}")
        pythonpath = self.spec.pythonpath

        if handler:
            set_paths(pythonpath)
            context = MLClientCtx.from_dict(
                runobj.to_dict(),
                rundb=self.spec.rundb,
                autocommit=False,
                tmp=tmp,
                host=socket.gethostname(),
            )
            fn = self._get_handler(handler, context)
            global_context.set(context)
            sout, serr = exec_from_params(fn, runobj, context)
            log_std(self._db_conn, runobj, sout, serr, skip=self.is_child, show=False)
            return context.to_dict()

        # no handler: run the command itself as a child process
        # parameters may be unset; treat that as "no parameters" when templating
        parameters = runobj.spec.parameters or {}
        command = self.spec.command
        command = command.format(**parameters)
        logger.info(f"handler was not provided running main ({command})")
        arg_list = command.split()
        if self.spec.mode == "pass":
            cmd = arg_list
        else:
            cmd = [executable, "-u"] + arg_list

        env = None
        if pythonpath:
            if "PYTHONPATH" in environ:
                # os.pathsep keeps the joined search path valid on every platform
                pythonpath = f"{environ['PYTHONPATH']}{os.pathsep}{pythonpath}"
            env = {"PYTHONPATH": pythonpath}
        if runobj.spec.verbose:
            if not env:
                env = {}
            env["MLRUN_LOG_LEVEL"] = "DEBUG"

        args = self.spec.args
        if args:
            args = [arg.format(**parameters) for arg in args]

        sout, serr = run_exec(cmd, args, env=env, cwd=execution._current_workdir)
        log_std(self._db_conn, runobj, sout, serr, skip=self.is_child, show=False)

        try:
            with open(tmp) as fp:
                resp = fp.read()
            # delete only after the handle is closed; removing an open file
            # fails on Windows
            remove(tmp)
            if resp:
                return json.loads(resp)
            logger.error("empty context tmp file")
        except FileNotFoundError:
            logger.info("no context file found")
        return runobj.to_dict()
def load_module(file_name, handler, context):
    """Load module from file name and resolve ``handler`` against it.

    :param file_name: path of the python file to import; when falsy nothing is
                      imported and the handler is resolved with ``namespaces=None``
    :param handler:   handler reference handed to ``get_handler_extended``
    :param context:   run context; when truthy, its ``_init_args`` parameter
                      (default ``{}``) is copied and passed as the class arguments
    :raises RunError: when no import spec can be created for ``file_name``
    """
    module = None
    if file_name:
        # module name is the file name without its last suffix
        module_name = Path(file_name).stem
        spec = imputil.spec_from_file_location(module_name, file_name)
        if spec is None:
            raise RunError(f"cannot import from {file_name!r}")
        module = imputil.module_from_spec(spec)
        spec.loader.exec_module(module)

    class_args = copy(context._parameters.get("_init_args", {})) if context else {}
    return get_handler_extended(handler, context, class_args, namespaces=module)
def run_exec(cmd, args, env=None, cwd=None):
    """Run a command as a child process, echoing its stdout line by line.

    :param cmd:  the command as a list of arguments (left unmodified)
    :param args: optional extra arguments appended to ``cmd``
    :param env:  optional environment variables for the child; they are applied
                 on top of a copy of the current process environment
    :param cwd:  optional working directory for the child
    :returns: ``(stdout text, stderr text)``; the stderr text is empty when the
              child exits with code 0
    """
    import threading  # local import: only needed by this function

    if args:
        # build a new list so the caller's ``cmd`` is not extended in place
        cmd = [*cmd, *args]

    # the child inherits the current environment (which already carries
    # SYSTEMROOT on Windows) and ``env`` overrides/extends it; previously
    # ``env`` was accepted but never handed to the child
    child_env = dict(os.environ)
    if env:
        child_env.update(env)

    process = Popen(cmd, stdout=PIPE, stderr=PIPE, env=child_env, cwd=cwd)

    # drain stderr concurrently: if it were only read after exit, a child
    # writing a lot to stderr would fill the pipe and block forever
    stderr_chunks = []
    drainer = threading.Thread(
        target=lambda: stderr_chunks.append(process.stderr.read()),
        daemon=True,
    )
    drainer.start()

    out_lines = []
    for raw_line in iter(process.stdout.readline, b""):
        line = raw_line.decode("utf-8")
        print(line, end="")
        sys.stdout.flush()
        out_lines.append(line)

    code = process.wait()
    drainer.join()
    process.stdout.close()
    process.stderr.close()

    err = b"".join(stderr_chunks).decode("utf-8") if code != 0 else ""
    return "".join(out_lines), err
class _DupStdout:
    """Write-only stream that forwards to the current ``sys.stdout`` and keeps a copy.

    ``terminal`` is whatever ``sys.stdout`` was when the object was created;
    ``buf`` is a ``StringIO`` accumulating everything that was written.
    """

    def __init__(self):
        self.terminal, self.buf = sys.stdout, StringIO()

    def write(self, message):
        """Send ``message`` to the terminal first, then record it in the buffer."""
        for sink in (self.terminal, self.buf):
            sink.write(message)

    def flush(self):
        """Flush the terminal stream (the in-memory buffer needs no flushing)."""
        self.terminal.flush()
def exec_from_params(handler, runobj: RunObject, context: MLClientCtx, cwd=None):
    """Call ``handler`` with arguments built from the run, capturing its stdout.

    The handler runs with stdout redirected to a ``_DupStdout`` (echo + capture)
    and, when ``cwd`` is given, from that directory. An ``Exception`` raised by
    the handler is not propagated: it is logged and stored on the context as
    the error state.

    :param handler: the callable to execute
    :param runobj:  the run object supplying parameters/inputs and the verbose flag
    :param context: the run context, updated with state and result, then committed
    :param cwd:     optional directory to change into for the duration of the call
    :returns: ``(captured stdout text, error text)``; the error text is empty on success
    """
    saved_level = logger.level
    if runobj.spec.verbose:
        logger.set_logger_level("DEBUG")

    call_kwargs = get_func_arg(handler, runobj, context)
    tee = _DupStdout()
    error_text = ""
    returned = None
    start_dir = os.getcwd()

    with redirect_stdout(tee):
        context.set_logger_stream(tee)
        try:
            if cwd:
                os.chdir(cwd)
            returned = handler(**call_kwargs)
            context.set_state("completed", commit=False)
        except Exception as exc:
            error_text = str(exc)
            logger.error(traceback.format_exc())
            context.set_state(error=error_text, commit=False)
            logger.set_logger_level(saved_level)

    tee.flush()
    if cwd:
        os.chdir(start_dir)
    # stdout is no longer redirected here, so this is the real stream again
    context.set_logger_stream(sys.stdout)
    # NOTE(review): a falsy return value (0, False, "", empty container) is
    # not logged as a result - confirm this is intended
    if returned:
        context.log_result("return", returned)
    context.commit()
    logger.set_logger_level(saved_level)
    return tee.buf.getvalue(), error_text
def get_func_arg(handler, runobj: RunObject, context: MLClientCtx, is_nuclio=False):
    """Build the keyword arguments for calling ``handler`` from a run object.

    Each parameter in the handler's signature is resolved in this order:

    * ``context`` receives the run context
    * ``event`` (only when ``is_nuclio``) receives an ``Event`` wrapping the run dict
    * a name found in the run parameters receives a copy of that value
    * a name found in the run inputs receives the object returned by
      ``context.get_input``, or its ``.local()`` value when the parameter has
      a string default or is annotated as ``str``

    Parameters matching none of the above are left out, so the handler's own
    defaults apply.

    :param handler:   the callable whose signature is inspected
    :param runobj:    run object supplying ``spec.parameters`` and ``spec.inputs``
    :param context:   the run context
    :param is_nuclio: enable the ``event`` argument
    :returns: dict mapping argument names to values
    """
    params = runobj.spec.parameters or {}
    inputs = runobj.spec.inputs or {}
    kwargs = {}

    for name, parameter in inspect.signature(handler).parameters.items():
        if name == "context":
            kwargs[name] = context
        elif is_nuclio and name == "event":
            kwargs[name] = Event(runobj.to_dict())
        elif name in params:
            kwargs[name] = copy(params[name])
        elif name in inputs:
            # resolve the input once and reuse it (it used to be fetched a
            # second time for the non-string case)
            data_item = context.get_input(name, inputs[name])
            wants_local_path = (
                isinstance(parameter.default, str) or parameter.annotation == str
            )
            kwargs[name] = data_item.local() if wants_local_path else data_item

    return kwargs