Databricks runtime

The Databricks runtime runs your code on a Databricks cluster (not on the Iguazio cluster). The function starts a pod on the MLRun side, which communicates with the Databricks cluster: the requests originate in MLRun, and all of the computation runs on the Databricks cluster.

With the Databricks runtime, you send your local file or code as a string to the job and specify a handler as the entry point for the user code. You can optionally pass keyword arguments (kwargs) to this job.

You can run the function on either of the following (see the sketch after this list):

  • An existing cluster, by providing DATABRICKS_CLUSTER_ID (typically as an environment variable on the function).

  • A job compute cluster, created and dedicated for this function only.
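
A minimal sketch of the two options, assuming a databricks-kind function object already exists (the full setup is shown in the examples below):

import os

# Option 1: existing cluster - pass its ID to the function as an environment variable.
function.spec.env.append(
    {"name": "DATABRICKS_CLUSTER_ID", "value": os.environ["DATABRICKS_CLUSTER_ID"]}
)

# Option 2: job compute cluster - do not set DATABRICKS_CLUSTER_ID at all;
# describe the cluster through task_parameters at run time instead
# (see the job compute cluster example below).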

Parameters that are not specific to a new or an existing cluster (see the sketch after this list):

  • timeout_minutes

  • token_key

  • artifact_json_dir (the location where the JSON file containing all logged MLRun artifacts is saved; this file is deleted after the run)
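
These parameters are passed inside the task_parameters dictionary when running the function, as in the examples below. A minimal sketch with placeholder values, assuming a function object as in the later examples; token_key is assumed here to name the project secret that holds the Databricks token:

run = function.run(
    handler="print_kwargs",
    params={
        "task_parameters": {
            "timeout_minutes": 20,  # placeholder timeout
            "token_key": "DATABRICKS_TOKEN",  # assumption: secret key that holds the token
            "artifact_json_dir": "/dbfs/tmp/mlrun_artifacts",  # placeholder location
        },
    },
)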

Parameters that apply only to a new cluster:

  • spark_version

  • node_type_id

  • num_workers

Example of a job compute cluster

To create a job compute cluster, omit DATABRICKS_CLUSTER_ID and set the cluster spec through the task_parameters when running the function. For example:

params['task_parameters'] = {'new_cluster_spec': {'node_type_id': 'm5d.large'}, 'number_of_workers': 2, 'timeout_minutes': 15, 'token_key': 'non-default-value'}

Do not send your own variables named task_parameters or context, since these names are used by the internal processes of the runtime.
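
Putting these together, a run on a job compute cluster might look like the following sketch. It assumes a function object like the one created in the next example; the node type, worker count, and token key values are placeholders:

run = function.run(
    handler="print_kwargs",
    project=project_name,
    params={
        "param1": "value1",
        "task_parameters": {
            "new_cluster_spec": {"node_type_id": "m5d.large"},  # placeholder node type
            "number_of_workers": 2,
            "timeout_minutes": 15,
            "token_key": "non-default-value",  # placeholder: secret key that holds the token
        },
    },
)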

Example of running a Databricks job from a local file

This example uses an existing cluster, identified by DATABRICKS_CLUSTER_ID.

import os
import mlrun
from mlrun.runtimes.function_reference import FunctionReference
# If using a Databricks data store, for example, set the credentials
# (replace the placeholder strings with your actual values):
os.environ["DATABRICKS_HOST"] = "DATABRICKS_HOST"
os.environ["DATABRICKS_TOKEN"] = "DATABRICKS_TOKEN"
os.environ["DATABRICKS_CLUSTER_ID"] = "DATABRICKS_CLUSTER_ID"


def add_databricks_env(function):
    job_env = {
        "DATABRICKS_HOST": os.environ["DATABRICKS_HOST"],
        "DATABRICKS_CLUSTER_ID": os.environ.get("DATABRICKS_CLUSTER_ID"),
    }

    for name, val in job_env.items():
        function.spec.env.append({"name": name, "value": val})


project_name = "databricks-runtime-project"
project = mlrun.get_or_create_project(project_name, context="./", user_project=False)

secrets = {"DATABRICKS_TOKEN": os.environ["DATABRICKS_TOKEN"]}

project.set_secrets(secrets)

code = """
def print_kwargs(**kwargs):
    print(f"kwargs: {kwargs}")
"""

function_ref = FunctionReference(
    kind="databricks",
    code=code,
    image="mlrun/mlrun",
    name="databricks-function",
)

function = function_ref.to_function()

add_databricks_env(function=function)

run = function.run(
    handler="print_kwargs",
    project=project_name,
    params={
        "param1": "value1",
        "param2": "value2",
        "task_parameters": {"timeout_minutes": 15},
    },
)
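
After the run completes, you can inspect its state and any logged outputs from the returned run object, for example:

print(run.state())
print(run.outputs)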

Logging a Databricks response as an artifact

import os

import numpy as np
import pandas as pd
from pyspark.sql import SparkSession


def main():
    df = pd.DataFrame({"A": np.random.randint(1, 100, 5), "B": np.random.rand(5)})
    path = "/dbfs/path/folder"
    parquet_df_path = f"{path}/df.parquet"
    csv_df_path = f"{path}/df.csv"

    if not os.path.exists(path):
        os.makedirs(path)

    # save df
    df.to_parquet(parquet_df_path)
    df.to_csv(csv_df_path, index=False)

    # log artifacts; mlrun_log_artifact is provided by the MLRun Databricks runtime
    # inside the job, so no import is needed
    mlrun_log_artifact("parquet_artifact", parquet_df_path)
    mlrun_log_artifact("csv_artifact", csv_df_path)

    # spark
    spark = SparkSession.builder.appName("example").getOrCreate()
    spark_df = spark.createDataFrame(df)

    # Spark writes through the dbfs:/ URI scheme (not the /dbfs local mount path):
    spark_parquet_path = "dbfs:///path/folder/spark_df.parquet"
    spark_df.write.mode("overwrite").parquet(spark_parquet_path)
    mlrun_log_artifact("spark_artifact", spark_parquet_path)

    # an illegal artifact does not raise an error, it logs an error log instead, for example:
    # mlrun_log_artifact("illegal_artifact", "/not_exists_path/illegal_df.parquet")

The code above is saved locally as databricks_job.py. The following code creates the MLRun function from that file and runs it on the Databricks cluster:

function = mlrun.code_to_function(
    name="databricks-log_artifact",
    kind="databricks",
    project=project_name,
    filename="./databricks_job.py",
    image="mlrun/mlrun",
)
add_databricks_env(function=function)
run = function.run(
    handler="main",
    project=project_name,
)
project.list_artifacts()
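
To check what the Databricks job logged, you can iterate over the listed artifacts. A minimal sketch; the exact dictionary layout returned by list_artifacts() can vary between MLRun versions, so it falls back to printing the whole entry:

for artifact in project.list_artifacts():
    key = artifact.get("metadata", {}).get("key") or artifact.get("key")
    print(key or artifact)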