Batch inference and drift detection#
This tutorial leverages a function from the MLRun Function Hub to perform batch inference using a logged model and a new prediction dataset. The function also calculates data drift by comparing the new prediction dataset with the original training set.
Make sure you have reviewed the basics in MLRun Quick Start Tutorial.
Tutorial steps:
1. Set up a project
2. View the data
3. Log the model with training data
4. Enable model monitoring
5. Import and run the batch inference function
6. Examine the predictions, drift status, and dashboard results
MLRun installation and configuration#
Before running this notebook, make sure mlrun is installed and that you have configured access to the MLRun service.
!/User/align_mlrun.sh
Both server & client are aligned (1.7.0rc28).
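If mlrun is not already installed in your environment, you can typically install the client with pip, as sketched below (an assumption about your setup; match the client version to your MLRun service - the align_mlrun.sh script above is specific to managed platform environments):
# Install the MLRun client if needed (match the version to your MLRun service)
%pip install mlrun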
Set up a project#
First, import the dependencies and create an MLRun project. The project contains all of your models, functions, datasets, etc.:
import mlrun
import pandas as pd
from mlrun.datastore.datastore_profile import DatastoreProfileV3io
project = mlrun.get_or_create_project("tutorial", context="./", user_project=True)
Note
This tutorial does not focus on training a model. Instead, it starts with a trained model and its corresponding training and prediction dataset.
You'll use the following model files and datasets to perform the batch prediction. The model is a DecisionTreeClassifier from sklearn, and the datasets are in parquet format.
# Choose the correct model to avoid pickle warnings
suffix = mlrun.__version__.split("-")[0].replace(".", "_")
model_path = mlrun.get_sample_path(f"models/batch-predict/model-{suffix}.pkl")
training_set_path = mlrun.get_sample_path("data/batch-predict/training_set.parquet")
prediction_set_path = mlrun.get_sample_path("data/batch-predict/prediction_set.parquet")
drifted_prediction_set_path = mlrun.get_sample_path(
"data/batch-predict/drifted_prediction_set.parquet"
)
View the data#
The training data has 20 numerical features and a binary (0,1) label:
pd.read_parquet(training_set_path).head()
| | feature_0 | feature_1 | feature_2 | feature_3 | feature_4 | feature_5 | feature_6 | feature_7 | feature_8 | feature_9 | ... | feature_11 | feature_12 | feature_13 | feature_14 | feature_15 | feature_16 | feature_17 | feature_18 | feature_19 | label |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.572754 | 0.171079 | 0.403080 | 0.955429 | 0.272039 | 0.360277 | -0.995429 | 0.437239 | 0.991556 | 0.010004 | ... | 0.112194 | -0.319256 | -0.392631 | -0.290766 | 1.265054 | 1.037082 | -1.200076 | 0.820992 | 0.834868 | 0 |
1 | 0.623733 | -0.149823 | -1.410537 | -0.729388 | -1.996337 | -1.213348 | 1.461307 | 1.187854 | -1.790926 | -0.981600 | ... | 0.428653 | -0.503820 | -0.798035 | 2.038105 | -3.080463 | 0.408561 | 1.647116 | -0.838553 | 0.680983 | 1 |
2 | 0.814168 | -0.221412 | 0.020822 | 1.066718 | -0.573164 | 0.067838 | 0.923045 | 0.338146 | 0.981413 | 1.481757 | ... | -1.052559 | -0.241873 | -1.232272 | -0.010758 | 0.806800 | 0.661162 | 0.589018 | 0.522137 | -0.924624 | 0 |
3 | 1.062279 | -0.966309 | 0.341471 | -0.737059 | 1.460671 | 0.367851 | -0.435336 | 0.445308 | -0.655663 | -0.196220 | ... | 0.641017 | 0.099059 | 1.902592 | -1.024929 | 0.030703 | -0.198751 | -0.342009 | -1.286865 | -1.118373 | 1 |
4 | 0.195755 | 0.576332 | -0.260496 | 0.841489 | 0.398269 | -0.717972 | 0.810550 | -1.058326 | 0.368610 | 0.606007 | ... | 0.195267 | 0.876144 | 0.151615 | 0.094867 | 0.627353 | -0.389023 | 0.662846 | -0.857000 | 1.091218 | 1 |
5 rows × 21 columns
The prediction data has 20 numerical features, but no label - this is what you will predict:
pd.read_parquet(prediction_set_path).head()
| | feature_0 | feature_1 | feature_2 | feature_3 | feature_4 | feature_5 | feature_6 | feature_7 | feature_8 | feature_9 | feature_10 | feature_11 | feature_12 | feature_13 | feature_14 | feature_15 | feature_16 | feature_17 | feature_18 | feature_19 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | -2.059506 | -1.314291 | 2.721516 | -2.132869 | -0.693963 | 0.376643 | 3.017790 | 3.876329 | -1.294736 | 0.030773 | 0.401491 | 2.775699 | 2.361580 | 0.173441 | 0.879510 | 1.141007 | 4.608280 | -0.518388 | 0.129690 | 2.794967 |
1 | -1.190382 | 0.891571 | 3.726070 | 0.673870 | -0.252565 | -0.729156 | 2.646563 | 4.782729 | 0.318952 | -0.781567 | 1.473632 | 1.101721 | 3.723400 | -0.466867 | -0.056224 | 3.344701 | 0.194332 | 0.463992 | 0.292268 | 4.665876 |
2 | -0.996384 | -0.099537 | 3.421476 | 0.162771 | -1.143458 | -1.026791 | 2.114702 | 2.517553 | -0.154620 | -0.465423 | -1.723025 | 1.729386 | 2.820340 | -1.041428 | -0.331871 | 2.909172 | 2.138613 | -0.046252 | -0.732631 | 4.716266 |
3 | -0.289976 | -1.680019 | 3.126478 | -0.704451 | -1.149112 | 1.174962 | 2.860341 | 3.753661 | -0.326119 | 2.128411 | -0.508000 | 2.328688 | 3.397321 | -0.932060 | -1.442370 | 2.058517 | 3.881936 | 2.090635 | -0.045832 | 4.197315 |
4 | -0.294866 | 1.044919 | 2.924139 | 0.814049 | -1.455054 | -0.270432 | 3.380195 | 2.339669 | 1.029101 | -1.171018 | -1.459395 | 1.283565 | 0.677006 | -2.147444 | -0.494150 | 3.222041 | 6.219348 | -1.914110 | 0.317786 | 4.143443 |
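Even before any formal drift analysis, a quick comparison of per-feature means hints at which features have shifted between the two sets. This is a minimal pandas sketch, not part of the original tutorial flow:
# Compare per-feature means of the training and prediction sets;
# large gaps foreshadow the drift scores computed later in this tutorial
train_df = pd.read_parquet(training_set_path).drop(columns=["label"])
pred_df = pd.read_parquet(prediction_set_path)
mean_shift = (pred_df.mean() - train_df.mean()).abs().sort_values(ascending=False)
print(mean_shift.head())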
Log the model with training data#
Next, log the model using MLRun experiment tracking. This is usually done in a training pipeline, but you can also bring in your pre-trained models from other sources. See Working with data and model artifacts and Automated experiment tracking for more information.
In this example, you log a training set with the model for future comparison; however, you can also pass your training set directly to the batch prediction function.
model_artifact = project.log_model(
key="model",
model_file=model_path,
framework="sklearn",
training_set=pd.read_parquet(training_set_path),
label_column="label",
)
# the model artifact unique URI
model_artifact.uri
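Beyond the in-memory model_artifact object, you can also fetch the artifact back from the project store by its key - a minimal sketch, assuming the default "latest" tag:
# Fetch the logged model artifact back from the project store
logged_model = project.get_artifact("model")
print(logged_model.uri)  # should match model_artifact.uri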
Enabling model monitoring#
MLRun's model monitoring service includes built-in model monitoring and reporting capabilities.
Visit MLRun's Model monitoring architecture to read more and check out the Model monitoring tutorial.
Setting the controller frequency to 1 minute via the base_period parameter lets you see monitoring results faster; by default, its value is 10 minutes.
tsdb_profile = DatastoreProfileV3io(name="v3io-tsdb-profile")
project.register_datastore_profile(tsdb_profile)
stream_profile = DatastoreProfileV3io(
name="v3io-stream-profile",
v3io_access_key=mlrun.mlconf.get_v3io_access_key(),
)
project.register_datastore_profile(stream_profile)
project.set_model_monitoring_credentials(
tsdb_profile_name=tsdb_profile.name,
stream_profile_name=stream_profile.name,
)
project.enable_model_monitoring(
base_period=1, wait_for_deployment=True, deploy_histogram_data_drift_app=False
)
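To confirm which datastore profiles the project now has registered, recent MLRun versions expose a listing call - a hedged check, as availability may vary by version:
# List the datastore profiles registered on the project
for profile in project.list_datastore_profiles():
    print(profile.name)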
Change the histogram data drift application defaults#
To generate the drift table plot artifact using MLRun's histogram data drift application, you have to change the application defaults. Keep the application's default name - "histogram-data-drift" - for its full functionality, including the statistics used in the "Model Endpoint" -> "Feature Analysis" view in the UI.
import mlrun.model_monitoring.applications.histogram_data_drift as histogram_data_drift
custom_hist_app = project.set_model_monitoring_function(
name=histogram_data_drift.HistogramDataDriftApplicationConstants.NAME, # keep the default name
func=histogram_data_drift.__file__,
application_class=histogram_data_drift.HistogramDataDriftApplication.__name__,
produce_plotly_artifact=True,
)
project.deploy_function(custom_hist_app)
Import and run the batch inference function#
Next, import the batch inference function from the MLRun Function Hub:
fn = mlrun.import_function("hub://batch_inference_v2")
Run batch inference#
Finally, perform the batch prediction by passing in your model and datasets. Setting perform_drift_analysis to True saves the batch to be analyzed by the project's model monitoring applications and generates a new model endpoint record. A model endpoint is a unique MLRun entity that includes statistics and important details about your model and function. You can perform the drift analysis on an existing model endpoint, but make sure you don't mix unrelated datasets that could affect the final drift analysis. In general, it's recommended to perform the drift analysis on a new model endpoint to avoid possible analysis conflicts. See the corresponding batch inference example notebook for an exhaustive list of the other supported parameters:
run = project.run_function(
fn,
inputs={
"dataset": prediction_set_path,
},
params={
"model_path": model_artifact.uri,
"perform_drift_analysis": True,
},
)
> 2024-07-09 14:51:10,770 [info] Storing function: {"db":"http://mlrun-api:8080","name":"batch-inference-v2-infer","uid":"9543bc386e57490b84bda797ff71bc2b"}
> 2024-07-09 14:51:11,350 [info] Job is running in the background, pod: batch-inference-v2-infer-kntxp
project | uid | iter | start | state | kind | name | labels | inputs | parameters | results | artifacts |
---|---|---|---|---|---|---|---|---|---|---|---|
tutorial52-dani | | 0 | Jul 09 14:51:15 | completed | run | batch-inference-v2-infer | v3io_user=dani kind=job owner=dani mlrun/client_version=1.7.0-rc28 mlrun/client_python_version=3.9.18 host=batch-inference-v2-infer-kntxp | dataset | model_path=store://models/tutorial52-dani/model#0@8b229459-630d-4190-858f-e48dd557d5b6 perform_drift_analysis=True | batch_id=f6eb466eb01d69819605d87aa1976063fce726692e239597434e7973 | prediction |
> 2024-07-09 14:51:32,773 [info] Run execution finished: {"name":"batch-inference-v2-infer","status":"completed"}
Predictions and drift status#
These are the batch predictions on the prediction set from the model:
run.artifact("prediction").as_df().head()
| | feature_0 | feature_1 | feature_2 | feature_3 | feature_4 | feature_5 | feature_6 | feature_7 | feature_8 | feature_9 | ... | feature_11 | feature_12 | feature_13 | feature_14 | feature_15 | feature_16 | feature_17 | feature_18 | feature_19 | label |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | -2.059506 | -1.314291 | 2.721516 | -2.132869 | -0.693963 | 0.376643 | 3.017790 | 3.876329 | -1.294736 | 0.030773 | ... | 2.775699 | 2.361580 | 0.173441 | 0.879510 | 1.141007 | 4.608280 | -0.518388 | 0.129690 | 2.794967 | 1 |
1 | -1.190382 | 0.891571 | 3.726070 | 0.673870 | -0.252565 | -0.729156 | 2.646563 | 4.782729 | 0.318952 | -0.781567 | ... | 1.101721 | 3.723400 | -0.466867 | -0.056224 | 3.344701 | 0.194332 | 0.463992 | 0.292268 | 4.665876 | 1 |
2 | -0.996384 | -0.099537 | 3.421476 | 0.162771 | -1.143458 | -1.026791 | 2.114702 | 2.517553 | -0.154620 | -0.465423 | ... | 1.729386 | 2.820340 | -1.041428 | -0.331871 | 2.909172 | 2.138613 | -0.046252 | -0.732631 | 4.716266 | 0 |
3 | -0.289976 | -1.680019 | 3.126478 | -0.704451 | -1.149112 | 1.174962 | 2.860341 | 3.753661 | -0.326119 | 2.128411 | ... | 2.328688 | 3.397321 | -0.932060 | -1.442370 | 2.058517 | 3.881936 | 2.090635 | -0.045832 | 4.197315 | 0 |
4 | -0.294866 | 1.044919 | 2.924139 | 0.814049 | -1.455054 | -0.270432 | 3.380195 | 2.339669 | 1.029101 | -1.171018 | ... | 1.283565 | 0.677006 | -2.147444 | -0.494150 | 3.222041 | 6.219348 | -1.914110 | 0.317786 | 4.143443 | 0 |
5 rows × 21 columns
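The run object also exposes its outputs programmatically, which is handy when chaining pipeline steps - a short check that is not part of the original notebook:
# The run outputs include the generated batch_id and the prediction artifact
print(run.outputs)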
The custom histogram data drift model monitoring application you deployed provides additional plots, such as drift_table_plot, which compares the drift between the training data and prediction data per feature, and features_drift_results, which shows each feature's drift score.
import time
# Wait two minutes for the model monitoring applications flow to finish
time.sleep(120)
project.get_artifact("drift_table_plot").to_dataitem().show()
Finally, you also get a numerical drift metric and a boolean flag denoting whether data drift is detected:
import mlrun.model_monitoring.api
import yaml
endpoint = mlrun.model_monitoring.api.get_or_create_model_endpoint(
project=project.name,
model_endpoint_name="batch-infer",
function_name="batch-inference-v2",
feature_analysis=True,
)
print(f"Drift detected : {'no' if endpoint.status.result_status=='0' else 'yes'}")
print(
f"Drift metrics :\n{yaml.dump(endpoint.status.drift_measures, default_flow_style=False)}"
)
Drift detected : no
Drift metrics :
feature_0:
hellinger: 0.033573682
kld: 0.0080091461
tvd: 0.0228
feature_1:
hellinger: 0.046301454
kld: 0.0151360992
tvd: 0.0436
feature_10:
hellinger: 0.0447340712
kld: 0.0161075955
tvd: 0.0404
feature_11:
hellinger: 0.8058863402
kld: 8.2074183749
tvd: 0.8432
feature_12:
hellinger: 0.8079575232
kld: 8.5406740489
tvd: 0.8336
feature_13:
hellinger: 0.0451944931
kld: 0.0143495624
tvd: 0.04
feature_14:
hellinger: 0.0547294476
kld: 0.0223062969
tvd: 0.038
feature_15:
hellinger: 0.7468151368
kld: 7.4931645795
tvd: 0.7436
feature_16:
hellinger: 0.8003245178
kld: 8.0970639059
tvd: 0.8412
feature_17:
hellinger: 0.038955715
kld: 0.0106774482
tvd: 0.0332
feature_18:
hellinger: 0.0464746522
kld: 0.0158942375
tvd: 0.0424
feature_19:
hellinger: 0.7993397396
kld: 7.6134558152
tvd: 0.8284
feature_2:
hellinger: 0.7900559843
kld: 7.4047027397
tvd: 0.8264
feature_3:
hellinger: 0.049139638
kld: 0.0173319463
tvd: 0.0384
feature_4:
hellinger: 0.0469112823
kld: 0.016764728
tvd: 0.0388
feature_5:
hellinger: 0.0540843967
kld: 0.0214804341
tvd: 0.0504
feature_6:
hellinger: 0.7926084404
kld: 7.6316683385
tvd: 0.84
feature_7:
hellinger: 0.794981259
kld: 7.7541374562
tvd: 0.8368
feature_8:
hellinger: 0.0397202637
kld: 0.0110989917
tvd: 0.0388
feature_9:
hellinger: 0.0465672701
kld: 0.0142817875
tvd: 0.0432
label:
hellinger: 0.6496873941
kld: 7.0054651625
tvd: 0.4952
# Data/concept drift per feature
import json
json.loads(project.get_artifact("features_drift_results").to_dataitem().get())
{'feature_0': 0.028186841,
'feature_1': 0.044950727,
'feature_2': 0.8082279922,
'feature_3': 0.043769819,
'feature_4': 0.0428556412,
'feature_5': 0.0522421983,
'feature_6': 0.8163042202,
'feature_7': 0.8158906295,
'feature_8': 0.0392601319,
'feature_9': 0.044883635,
'feature_10': 0.0425670356,
'feature_11': 0.8245431701,
'feature_12': 0.8207787616,
'feature_13': 0.0425972466,
'feature_14': 0.0463647238,
'feature_15': 0.7452075684,
'feature_16': 0.8207622589,
'feature_17': 0.0360778575,
'feature_18': 0.0444373261,
'feature_19': 0.8138698698,
'label': 0.572443697}
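From these per-feature scores you can, for example, pick out the features that cross the application's "Drift detected" threshold of 0.7 (described in the next section) - a minimal sketch reusing the artifact fetched above:
# Features whose drift score crosses the 0.7 "drift detected" threshold
drift_results = json.loads(
    project.get_artifact("features_drift_results").to_dataitem().get()
)
drifted_features = [f for f, score in drift_results.items() if score >= 0.7]
print(drifted_features)  # feature_2, feature_6, feature_7, feature_11, ...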
Examining the drift results in the dashboard#
This section reviews the main charts and statistics that can be found on the platform dashboard. See Model monitoring architecture to learn more about the available model monitoring features and how to use them.
Before analyzing the results in the visual dashboards, run another batch infer job, but this time with drifted data, to get a drifted result. The drift decision rule is the per-feature mean of the Total Variation Distance (TVD) and Hellinger distance scores. In the histogram-data-drift application, the "Drift detected" threshold is 0.7 and the "Drift suspected" threshold is 0.3.
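As a concrete illustration, this minimal sketch applies the decision rule described above; the helper function itself is hypothetical, but the rule (per-feature mean of TVD and Hellinger) and the thresholds follow the description:
# Per-feature drift score = mean of the TVD and Hellinger distances,
# checked against the "suspected" (0.3) and "detected" (0.7) thresholds
def drift_status(tvd: float, hellinger: float) -> str:
    score = (tvd + hellinger) / 2
    if score >= 0.7:
        return "Drift detected"
    if score >= 0.3:
        return "Drift suspected"
    return "No drift"

# feature_11 from the metrics above: (0.8432 + 0.8059) / 2 ≈ 0.825
print(drift_status(tvd=0.8432, hellinger=0.8058863402))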
run = project.run_function(
fn,
inputs={
"dataset": drifted_prediction_set_path,
},
params={
"model_path": model_artifact.uri,
"perform_drift_analysis": True,
"model_endpoint_name": "drifted-model-endpoint",
},
)
> 2024-07-09 14:53:33,903 [info] Storing function: {"db":"http://mlrun-api:8080","name":"batch-inference-v2-infer","uid":"0f1dd3eaf3644358bc77a43388eb527c"}
> 2024-07-09 14:53:34,425 [info] Job is running in the background, pod: batch-inference-v2-infer-6fddk
project | uid | iter | start | state | kind | name | labels | inputs | parameters | results | artifacts |
---|---|---|---|---|---|---|---|---|---|---|---|
tutorial52-dani | | 0 | Jul 09 14:53:39 | completed | run | batch-inference-v2-infer | v3io_user=dani kind=job owner=dani mlrun/client_version=1.7.0-rc28 mlrun/client_python_version=3.9.18 host=batch-inference-v2-infer-6fddk | dataset | model_path=store://models/tutorial52-dani/model#0@8b229459-630d-4190-858f-e48dd557d5b6 perform_drift_analysis=True model_endpoint_name=drifted-model-endpoint | batch_id=e83a7dc0d76df11abeba32397460884b9e7f2949ef9a4e75222bea66 | prediction |
> 2024-07-09 14:53:55,768 [info] Run execution finished: {"name":"batch-inference-v2-infer","status":"completed"}
Now you can observe the drift result:
time.sleep(120)
endpoint = mlrun.model_monitoring.api.get_or_create_model_endpoint(
project=project.name,
model_endpoint_name="drifted-model-endpoint",
function_name="batch-inference-v2",
feature_analysis=True,
)
print(f"Drift detected : {'no' if endpoint.status.result_status=='0' else 'yes'}")
print(
f"Drift metrics :\n{yaml.dump(endpoint.status.drift_measures, default_flow_style=False)}"
)
Drift detected : yes
Drift metrics :
feature_0:
hellinger: 1.0
kld: 15.9478317661
tvd: 1.0
feature_1:
hellinger: 1.0
kld: 15.9425307407
tvd: 1.0
feature_10:
hellinger: 0.0447340712
kld: 0.0161075955
tvd: 0.0404
feature_11:
hellinger: 0.8058863402
kld: 8.2074183749
tvd: 0.8432
feature_12:
hellinger: 0.8079575232
kld: 8.5406740489
tvd: 0.8336
feature_13:
hellinger: 0.0451944931
kld: 0.0143495624
tvd: 0.04
feature_14:
hellinger: 0.0547294476
kld: 0.0223062969
tvd: 0.038
feature_15:
hellinger: 0.7468151368
kld: 7.4931645795
tvd: 0.7436
feature_16:
hellinger: 0.8003245178
kld: 8.0970639059
tvd: 0.8412
feature_17:
hellinger: 0.038955715
kld: 0.0106774482
tvd: 0.0332
feature_18:
hellinger: 0.0464746522
kld: 0.0158942375
tvd: 0.0424
feature_19:
hellinger: 0.7993397396
kld: 7.6134558152
tvd: 0.8284
feature_2:
hellinger: 1.0
kld: 15.9542498714
tvd: 1.0
feature_3:
hellinger: 1.0
kld: 16.0409151351
tvd: 1.0
feature_4:
hellinger: 1.0
kld: 16.0700327124
tvd: 1.0
feature_5:
hellinger: 1.0
kld: 15.8530461861
tvd: 1.0
feature_6:
hellinger: 1.0
kld: 15.8925883625
tvd: 1.0
feature_7:
hellinger: 1.0
kld: 15.9013904377
tvd: 1.0
feature_8:
hellinger: 1.0
kld: 15.9352176612
tvd: 1.0
feature_9:
hellinger: 1.0
kld: 15.9650317552
tvd: 1.0
label:
hellinger: 0.6974778813
kld: 8.2594612504
tvd: 0.4952
Model Endpoints#
In the Projects page > Model endpoint summary list, you can see the two new model endpoints, including their drift status:
You can zoom into one of the model endpoints to get an overview of the selected endpoint, including the calculated statistical drift metrics:
Press Features Analysis to see details of the drift analysis in a table format, with each feature in the selected model on its own line, including the predicted label:
Next steps#
In a production setting, you probably want to incorporate this as part of a larger pipeline or application.
For example, if you use this function for its prediction capabilities, you can pass the prediction output as the input to another pipeline step, store it in an external location like S3, or send it to an application or user.
If you use this function for its drift detection capabilities, you can use the drift_status and drift_metrics outputs to automate further pipeline steps, send a notification, or kick off a re-training pipeline, as sketched below.
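For instance, a downstream step could branch on the endpoint's drift status - a hedged sketch, where "retrain-model" is a hypothetical function name:
# Branch on the drift outcome recorded on the model endpoint
if endpoint.status.result_status != "0":
    # Drift detected: trigger a (hypothetical) re-training job
    project.run_function("retrain-model", inputs={"dataset": training_set_path})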
Done!#
Congratulations! You've completed Part 6 of the MLRun getting-started tutorial. You may want to review the additional tutorials: