Batch inference and drift detection#
This tutorial leverages a function from the MLRun Function Hub to perform batch inference using a logged model and a new prediction dataset. The function also calculates data drift by comparing the new prediction dataset with the original training set.
Make sure you have reviewed the basics in MLRun Quick Start Tutorial.
Tutorial steps:
MLRun installation and configuration#
Before running this notebook, make sure mlrun is installed and that you have configured access to the MLRun service.
!/User/align_mlrun.sh
Both server & client are aligned (1.7.0rc28).
Set up a project#
First, import the dependencies and create an MLRun project. The project contains all of your models, functions, datasets, etc.:
import mlrun
import pandas as pd
project = mlrun.get_or_create_project("tutorial", context="./", user_project=True)
Note
This tutorial does not focus on training a model. Instead, it starts with a trained model and its corresponding training and prediction dataset.
You'll use the following model files and datasets to perform the batch prediction. The model is a DecisionTreeClassifier from sklearn and the datasets are in parquet format.
# Choose the correct model to avoid pickle warnings
import sys
suffix = (
    mlrun.__version__.split("-")[0].replace(".", "_")
    if sys.version_info[1] > 7
    else "3.7"
)
model_path = mlrun.get_sample_path(f"models/batch-predict/model-{suffix}.pkl")
training_set_path = mlrun.get_sample_path("data/batch-predict/training_set.parquet")
prediction_set_path = mlrun.get_sample_path("data/batch-predict/prediction_set.parquet")
drifted_prediction_set_path = mlrun.get_sample_path(
    "data/batch-predict/drifted_prediction_set.parquet"
)
View the data#
The training data has 20 numerical features and a binary (0,1) label:
pd.read_parquet(training_set_path).head()
feature_0 | feature_1 | feature_2 | feature_3 | feature_4 | feature_5 | feature_6 | feature_7 | feature_8 | feature_9 | ... | feature_11 | feature_12 | feature_13 | feature_14 | feature_15 | feature_16 | feature_17 | feature_18 | feature_19 | label | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.572754 | 0.171079 | 0.403080 | 0.955429 | 0.272039 | 0.360277 | -0.995429 | 0.437239 | 0.991556 | 0.010004 | ... | 0.112194 | -0.319256 | -0.392631 | -0.290766 | 1.265054 | 1.037082 | -1.200076 | 0.820992 | 0.834868 | 0 |
1 | 0.623733 | -0.149823 | -1.410537 | -0.729388 | -1.996337 | -1.213348 | 1.461307 | 1.187854 | -1.790926 | -0.981600 | ... | 0.428653 | -0.503820 | -0.798035 | 2.038105 | -3.080463 | 0.408561 | 1.647116 | -0.838553 | 0.680983 | 1 |
2 | 0.814168 | -0.221412 | 0.020822 | 1.066718 | -0.573164 | 0.067838 | 0.923045 | 0.338146 | 0.981413 | 1.481757 | ... | -1.052559 | -0.241873 | -1.232272 | -0.010758 | 0.806800 | 0.661162 | 0.589018 | 0.522137 | -0.924624 | 0 |
3 | 1.062279 | -0.966309 | 0.341471 | -0.737059 | 1.460671 | 0.367851 | -0.435336 | 0.445308 | -0.655663 | -0.196220 | ... | 0.641017 | 0.099059 | 1.902592 | -1.024929 | 0.030703 | -0.198751 | -0.342009 | -1.286865 | -1.118373 | 1 |
4 | 0.195755 | 0.576332 | -0.260496 | 0.841489 | 0.398269 | -0.717972 | 0.810550 | -1.058326 | 0.368610 | 0.606007 | ... | 0.195267 | 0.876144 | 0.151615 | 0.094867 | 0.627353 | -0.389023 | 0.662846 | -0.857000 | 1.091218 | 1 |
5 rows × 21 columns
The prediction data has 20 numerical features, but no label - this is what you will predict:
pd.read_parquet(prediction_set_path).head()
feature_0 | feature_1 | feature_2 | feature_3 | feature_4 | feature_5 | feature_6 | feature_7 | feature_8 | feature_9 | feature_10 | feature_11 | feature_12 | feature_13 | feature_14 | feature_15 | feature_16 | feature_17 | feature_18 | feature_19 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | -2.059506 | -1.314291 | 2.721516 | -2.132869 | -0.693963 | 0.376643 | 3.017790 | 3.876329 | -1.294736 | 0.030773 | 0.401491 | 2.775699 | 2.361580 | 0.173441 | 0.879510 | 1.141007 | 4.608280 | -0.518388 | 0.129690 | 2.794967 |
1 | -1.190382 | 0.891571 | 3.726070 | 0.673870 | -0.252565 | -0.729156 | 2.646563 | 4.782729 | 0.318952 | -0.781567 | 1.473632 | 1.101721 | 3.723400 | -0.466867 | -0.056224 | 3.344701 | 0.194332 | 0.463992 | 0.292268 | 4.665876 |
2 | -0.996384 | -0.099537 | 3.421476 | 0.162771 | -1.143458 | -1.026791 | 2.114702 | 2.517553 | -0.154620 | -0.465423 | -1.723025 | 1.729386 | 2.820340 | -1.041428 | -0.331871 | 2.909172 | 2.138613 | -0.046252 | -0.732631 | 4.716266 |
3 | -0.289976 | -1.680019 | 3.126478 | -0.704451 | -1.149112 | 1.174962 | 2.860341 | 3.753661 | -0.326119 | 2.128411 | -0.508000 | 2.328688 | 3.397321 | -0.932060 | -1.442370 | 2.058517 | 3.881936 | 2.090635 | -0.045832 | 4.197315 |
4 | -0.294866 | 1.044919 | 2.924139 | 0.814049 | -1.455054 | -0.270432 | 3.380195 | 2.339669 | 1.029101 | -1.171018 | -1.459395 | 1.283565 | 0.677006 | -2.147444 | -0.494150 | 3.222041 | 6.219348 | -1.914110 | 0.317786 | 4.143443 |
Log the model with training data#
Next, log the model using MLRun experiment tracking. This is usually done in a training pipeline, but you can also bring in your pre-trained models from other sources. See Working with data and model artifacts and Automated experiment tracking for more information.
In this example, you log a training set together with the model for future comparison; however, you can also pass your training set directly to the batch prediction function.
model_artifact = project.log_model(
    key="model",
    model_file=model_path,
    framework="sklearn",
    training_set=pd.read_parquet(training_set_path),
    label_column="label",
)
# the model artifact unique URI
model_artifact.uri
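For context, a pickled model file like the one logged above could have been produced roughly as follows. This is a hedged, self-contained sketch (not the tutorial's actual training code): it trains a DecisionTreeClassifier on synthetic data shaped like the tutorial's training set (20 numerical features, binary label) and serializes it the way a .pkl model file is stored:

```python
import pickle

import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the tutorial's training set: 20 numerical features
rng = np.random.default_rng(42)
X = pd.DataFrame(
    rng.normal(size=(100, 20)),
    columns=[f"feature_{i}" for i in range(20)],
)
y = rng.integers(0, 2, size=100)  # binary (0,1) label, as in the tutorial

model = DecisionTreeClassifier(random_state=42).fit(X, y)

# Pickle round-trip: the same bytes that would be written to a model .pkl file
blob = pickle.dumps(model)
restored = pickle.loads(blob)

# The restored model predicts identically to the original
assert (restored.predict(X) == model.predict(X)).all()
```

Pickle compatibility is why the earlier cell selects a model file matching your Python/MLRun version: a model pickled under one interpreter or library version may warn (or fail) when unpickled under another.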
Enabling model monitoring#
MLRun's model monitoring service includes built-in model monitoring and reporting capabilities.
Visit MLRun's Model monitoring description to read more, and check out Realtime monitoring and drift detection.
Setting the controller frequency with the base_period parameter to 1 minute lets you see monitoring results faster; by default, its value is 10 minutes.
project.set_model_monitoring_credentials(None, "v3io", "v3io", "v3io")
project.enable_model_monitoring(
    base_period=1, wait_for_deployment=True, deploy_histogram_data_drift_app=True
)
2024-07-09 14:44:31 (info) Deploying function
2024-07-09 14:44:31 (info) Building
2024-07-09 14:44:32 (info) Staging files and preparing base images
2024-07-09 14:44:32 (warn) Using user provided base image, runtime interpreter version is provided by the base image
2024-07-09 14:44:32 (info) Building processor image
2024-07-09 14:47:26 (info) Build complete
2024-07-09 14:47:48 (info) Function deploy complete
2024-07-09 14:44:28 (info) Deploying function
2024-07-09 14:44:28 (info) Building
2024-07-09 14:44:28 (info) Staging files and preparing base images
2024-07-09 14:44:28 (warn) Using user provided base image, runtime interpreter version is provided by the base image
2024-07-09 14:44:28 (info) Building processor image
2024-07-09 14:50:13 (info) Build complete
2024-07-09 14:50:48 (info) Function deploy complete
2024-07-09 14:44:29 (info) Deploying function
2024-07-09 14:44:29 (info) Building
2024-07-09 14:44:30 (info) Staging files and preparing base images
2024-07-09 14:44:30 (warn) Using user provided base image, runtime interpreter version is provided by the base image
2024-07-09 14:44:30 (info) Building processor image
2024-07-09 14:50:04 (info) Build complete
2024-07-09 14:50:48 (info) Function deploy complete
2024-07-09 14:44:33 (info) Deploying function
2024-07-09 14:44:33 (info) Building
2024-07-09 14:44:34 (info) Staging files and preparing base images
2024-07-09 14:44:34 (warn) Using user provided base image, runtime interpreter version is provided by the base image
2024-07-09 14:44:34 (info) Building processor image
2024-07-09 14:50:03 (info) Build complete
2024-07-09 14:50:48 (info) Function deploy complete
Import and run the batch inference function#
Next, import the batch inference function from the MLRun Function Hub:
fn = mlrun.import_function("hub://batch_inference_v2")
Run batch inference#
Finally, perform the batch prediction by passing in your model and datasets.
Setting perform_drift_analysis=True saves the batch to be analyzed by the project's model monitoring applications, and generates a new model endpoint record.
A model endpoint is a unique MLRun entity that includes statistics and important details about your model and function.
You can perform the drift analysis on an existing model endpoint, but make sure you don't mix unrelated datasets that could affect the final drift analysis.
In general, it's recommended to perform the drift analysis on a new model endpoint to avoid possible analysis conflicts.
See the corresponding batch inference example notebook for an exhaustive list of other parameters that are supported:
run = project.run_function(
    fn,
    inputs={
        "dataset": prediction_set_path,
    },
    params={
        "model_path": model_artifact.uri,
        "perform_drift_analysis": True,
    },
)
> 2024-07-09 14:51:10,770 [info] Storing function: {"db":"http://mlrun-api:8080","name":"batch-inference-v2-infer","uid":"9543bc386e57490b84bda797ff71bc2b"}
> 2024-07-09 14:51:11,350 [info] Job is running in the background, pod: batch-inference-v2-infer-kntxp
project | iter | start | state | kind | name | labels | inputs | parameters | results | artifacts
---|---|---|---|---|---|---|---|---|---|---
tutorial52-dani | 0 | Jul 09 14:51:15 | completed | run | batch-inference-v2-infer | v3io_user=dani kind=job owner=dani mlrun/client_version=1.7.0-rc28 mlrun/client_python_version=3.9.18 host=batch-inference-v2-infer-kntxp | dataset | model_path=store://models/tutorial52-dani/model#0@8b229459-630d-4190-858f-e48dd557d5b6 perform_drift_analysis=True | batch_id=f6eb466eb01d69819605d87aa1976063fce726692e239597434e7973 | prediction
> 2024-07-09 14:51:32,773 [info] Run execution finished: {"name":"batch-inference-v2-infer","status":"completed"}
Predictions and drift status#
These are the batch predictions on the prediction set from the model:
run.artifact("prediction").as_df().head()
feature_0 | feature_1 | feature_2 | feature_3 | feature_4 | feature_5 | feature_6 | feature_7 | feature_8 | feature_9 | ... | feature_11 | feature_12 | feature_13 | feature_14 | feature_15 | feature_16 | feature_17 | feature_18 | feature_19 | label | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | -2.059506 | -1.314291 | 2.721516 | -2.132869 | -0.693963 | 0.376643 | 3.017790 | 3.876329 | -1.294736 | 0.030773 | ... | 2.775699 | 2.361580 | 0.173441 | 0.879510 | 1.141007 | 4.608280 | -0.518388 | 0.129690 | 2.794967 | 1 |
1 | -1.190382 | 0.891571 | 3.726070 | 0.673870 | -0.252565 | -0.729156 | 2.646563 | 4.782729 | 0.318952 | -0.781567 | ... | 1.101721 | 3.723400 | -0.466867 | -0.056224 | 3.344701 | 0.194332 | 0.463992 | 0.292268 | 4.665876 | 1 |
2 | -0.996384 | -0.099537 | 3.421476 | 0.162771 | -1.143458 | -1.026791 | 2.114702 | 2.517553 | -0.154620 | -0.465423 | ... | 1.729386 | 2.820340 | -1.041428 | -0.331871 | 2.909172 | 2.138613 | -0.046252 | -0.732631 | 4.716266 | 0 |
3 | -0.289976 | -1.680019 | 3.126478 | -0.704451 | -1.149112 | 1.174962 | 2.860341 | 3.753661 | -0.326119 | 2.128411 | ... | 2.328688 | 3.397321 | -0.932060 | -1.442370 | 2.058517 | 3.881936 | 2.090635 | -0.045832 | 4.197315 | 0 |
4 | -0.294866 | 1.044919 | 2.924139 | 0.814049 | -1.455054 | -0.270432 | 3.380195 | 2.339669 | 1.029101 | -1.171018 | ... | 1.283565 | 0.677006 | -2.147444 | -0.494150 | 3.222041 | 6.219348 | -1.914110 | 0.317786 | 4.143443 | 0 |
5 rows × 21 columns
Enabling model monitoring with the default deploy_histogram_data_drift_app provides additional artifacts, such as drift_table_plot, which compares the drift between the training data and prediction data per feature, and features_drift_results, which shows each feature's drift score.
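Conceptually, each per-feature drift score compares the feature's histogram in the training set against its histogram in the prediction set. The following is a minimal, self-contained illustration of this idea using two common histogram distances (Total Variation Distance and Hellinger distance); it is a sketch of the general technique, not MLRun's exact implementation:

```python
import numpy as np

def histogram_drift(train_col, pred_col, bins=20):
    """Compare one feature's training vs. prediction distribution
    using two common histogram distances (illustrative only)."""
    # Bin both samples over a shared range so the histograms are comparable
    lo = min(train_col.min(), pred_col.min())
    hi = max(train_col.max(), pred_col.max())
    edges = np.linspace(lo, hi, bins + 1)
    p, _ = np.histogram(train_col, bins=edges)
    q, _ = np.histogram(pred_col, bins=edges)
    p = p / p.sum()  # normalize counts to probabilities
    q = q / q.sum()
    tvd = 0.5 * np.abs(p - q).sum()  # Total Variation Distance, in [0, 1]
    hellinger = np.sqrt(0.5 * ((np.sqrt(p) - np.sqrt(q)) ** 2).sum())
    return tvd, hellinger

rng = np.random.default_rng(0)
train = rng.normal(0, 1, 1000)
stable = rng.normal(0, 1, 1000)   # same distribution as training
shifted = rng.normal(3, 1, 1000)  # a clearly drifted feature

tvd_stable, _ = histogram_drift(train, stable)
tvd_drift, _ = histogram_drift(train, shifted)
assert tvd_drift > tvd_stable  # the shifted feature scores much higher
```

A per-feature table of such scores is what the drift table plot summarizes: features whose prediction-time distribution has moved away from the training distribution stand out with high scores.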
import time

# Wait a two-minute safety interval for the controller to finish triggering the apps
# and writing to the databases.
time.sleep(120)
project.get_artifact("drift_table_plot").to_dataitem().show()
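Instead of a fixed sleep, you could poll for the artifact with a small retry loop. Here is a generic, self-contained sketch: the fetch callable stands in for project.get_artifact("drift_table_plot"), which requires a live MLRun service, so the demonstration below uses a stub:

```python
import time

def wait_for(fetch, timeout=300, interval=10):
    """Poll `fetch` until it returns a non-None value without raising,
    or give up after `timeout` seconds."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            result = fetch()
            if result is not None:
                return result
        except Exception:
            pass  # artifact not written yet; keep polling
        time.sleep(interval)
    raise TimeoutError("artifact did not appear in time")

# Usage against a live project (hypothetical):
#   plot = wait_for(lambda: project.get_artifact("drift_table_plot"))
#   plot.to_dataitem().show()

# Self-contained demonstration with a stub that succeeds on the third call:
calls = {"n": 0}
def stub():
    calls["n"] += 1
    return "ready" if calls["n"] >= 3 else None

result = wait_for(stub, timeout=5, interval=0)
assert result == "ready"
```

This returns as soon as the monitoring controller has written the artifact, rather than always waiting the full two minutes.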