Create a basic training job
In this section, you create a simple job that trains a model and records metrics, logs, and plots using MLRun’s auto-logging.
Define the training code#
The training code is shown below. Note that a single MLRun call, apply_mlrun, adds all the MLOps capabilities:
%%writefile trainer.py
from sklearn import ensemble
from sklearn.model_selection import train_test_split

import mlrun
from mlrun.frameworks.sklearn import apply_mlrun


def train(
    dataset: mlrun.DataItem,  # data inputs are of type DataItem (abstract the data source)
    label_column: str = "label",
    n_estimators: int = 100,
    learning_rate: float = 0.1,
    max_depth: int = 3,
    model_name: str = "cancer_classifier",
):
    # Get the input dataframe (Use DataItem.as_df() to access any data source)
    df = dataset.as_df()

    # Initialize the x & y data
    X = df.drop(label_column, axis=1)
    y = df[label_column]

    # Train/Test split the dataset
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # Pick an ideal ML model
    model = ensemble.GradientBoostingClassifier(
        n_estimators=n_estimators, learning_rate=learning_rate, max_depth=max_depth
    )

    # -------------------- The only line you need to add for MLOps -------------------------
    # Wraps the model with MLOps (test set is provided for analysis & accuracy measurements)
    apply_mlrun(model=model, model_name=model_name, x_test=X_test, y_test=y_test)
    # --------------------------------------------------------------------------------------

    # Train the model
    model.fit(X_train, y_train)
Writing trainer.py
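Before packaging the code, it can be useful to check the modelling logic on its own. The sketch below repeats the split and fit steps without MLRun. It assumes that scikit-learn's built-in breast cancer dataset (load_breast_cancer) is a reasonable stand-in for the CSV used later in this section. The helper name train_locally and the random_state passed to the classifier are illustrative additions that do not appear in trainer.py.

```python
from sklearn import ensemble
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split


def train_locally(n_estimators=100, learning_rate=0.1, max_depth=3):
    """Hypothetical helper: same split and model as trainer.py, without MLRun."""
    # Assumption: the built-in dataset stands in for the CSV used by the job
    X, y = load_breast_cancer(return_X_y=True)

    # Same split settings as the handler
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )

    # random_state is added here only to make the local check repeatable
    model = ensemble.GradientBoostingClassifier(
        n_estimators=n_estimators,
        learning_rate=learning_rate,
        max_depth=max_depth,
        random_state=42,
    )
    model.fit(X_train, y_train)

    # Return the test accuracy and the size of the held-out set
    return accuracy_score(y_test, model.predict(X_test)), len(X_test)


accuracy, n_test = train_locally()
print(f"test rows: {n_test}, accuracy: {accuracy:.3f}")
```

A local check like this only exercises the scikit-learn part. The logging of metrics, plots, and the model itself still comes from apply_mlrun when the handler runs as an MLRun job.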
Create the job#
Next, use code_to_function to package the code as a job that is ready to run on the cluster:
import mlrun

training_job = mlrun.code_to_function(
    name="basic-training",
    filename="trainer.py",
    kind="job",
    image="mlrun/mlrun",
    handler="train",
)
Run the job#
Finally, run the job. Here the dataset is read directly from S3; in practice it is usually the output of a previous step in a pipeline.
run = training_job.run(
    inputs={"dataset": "https://igz-demo-datasets.s3.us-east-2.amazonaws.com/cancer-dataset.csv"},
    params={"n_estimators": 100, "learning_rate": 1e-1, "max_depth": 3},
)
> 2022-07-22 22:27:15,162 [info] starting run basic-training-train uid=bc1c6ad491c340e1a3b9b91bb520454f DB=http://mlrun-api:8080
> 2022-07-22 22:27:15,349 [info] Job is running in the background, pod: basic-training-train-kkntj
> 2022-07-22 22:27:20,927 [info] run executed, status=completed
final state: completed
project | uid | iter | start | state | name | labels | inputs | parameters | results | artifacts |
---|---|---|---|---|---|---|---|---|---|---|
default | | 0 | Jul 22 22:27:18 | completed | basic-training-train | v3io_user=nick kind=job owner=nick mlrun/client_version=1.0.4 host=basic-training-train-kkntj | dataset | n_estimators=100 learning_rate=0.1 max_depth=3 | accuracy=0.956140350877193 f1_score=0.965034965034965 precision_score=0.9583333333333334 recall_score=0.971830985915493 | feature-importance test_set confusion-matrix roc-curves calibration-curve model |
> to track results use the .show() or .logs() methods or click here to open in UI
> 2022-07-22 22:27:21,640 [info] run executed, status=completed
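The run above passes only three params, yet the model is stored as cancer_classifier and the label column is still found, because the handler's defaults cover the arguments that were not supplied. The sketch below illustrates that merge in plain Python. It is not MLRun's implementation: the stand-in train function only copies the handler's signature, and effective_arguments is a hypothetical helper.

```python
import inspect


def train(
    dataset,
    label_column="label",
    n_estimators=100,
    learning_rate=0.1,
    max_depth=3,
    model_name="cancer_classifier",
):
    """Stand-in with the same signature as the handler in trainer.py."""


def effective_arguments(handler, params):
    """Hypothetical helper: handler defaults, overridden by the supplied params."""
    signature = inspect.signature(handler)
    defaults = {
        name: parameter.default
        for name, parameter in signature.parameters.items()
        if parameter.default is not inspect.Parameter.empty
    }
    # Supplied params take precedence over the defaults
    return {**defaults, **params}


arguments = effective_arguments(
    train, {"n_estimators": 100, "learning_rate": 1e-1, "max_depth": 3}
)
print(arguments)
```

The dataset argument has no default, which is why it must be provided through inputs.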
View job results#
Once the job is complete, you can view the output metrics and visualize the artifacts.
run.outputs
{'accuracy': 0.956140350877193,
'f1_score': 0.965034965034965,
'precision_score': 0.9583333333333334,
'recall_score': 0.971830985915493,
'feature-importance': 'v3io:///projects/default/artifacts/feature-importance.html',
'test_set': 'store://artifacts/default/basic-training-train_test_set:bc1c6ad491c340e1a3b9b91bb520454f',
'confusion-matrix': 'v3io:///projects/default/artifacts/confusion-matrix.html',
'roc-curves': 'v3io:///projects/default/artifacts/roc-curves.html',
'calibration-curve': 'v3io:///projects/default/artifacts/calibration-curve.html',
'model': 'store://artifacts/default/cancer_classifier:bc1c6ad491c340e1a3b9b91bb520454f'}
run.artifact("confusion-matrix").show()
run.artifact("feature-importance").show()
run.artifact("test_set").show()
index | mean radius | mean texture | mean perimeter | mean area | mean smoothness | mean compactness | mean concavity | mean concave points | mean symmetry | mean fractal dimension | ... | worst texture | worst perimeter | worst area | worst smoothness | worst compactness | worst concavity | worst concave points | worst symmetry | worst fractal dimension | label |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 12.47 | 18.60 | 81.09 | 481.9 | 0.09965 | 0.10580 | 0.08005 | 0.03821 | 0.1925 | 0.06373 | ... | 24.64 | 96.05 | 677.9 | 0.1426 | 0.2378 | 0.2671 | 0.10150 | 0.3014 | 0.08750 | 1 |
1 | 18.94 | 21.31 | 123.60 | 1130.0 | 0.09009 | 0.10290 | 0.10800 | 0.07951 | 0.1582 | 0.05461 | ... | 26.58 | 165.90 | 1866.0 | 0.1193 | 0.2336 | 0.2687 | 0.17890 | 0.2551 | 0.06589 | 0 |
2 | 15.46 | 19.48 | 101.70 | 748.9 | 0.10920 | 0.12230 | 0.14660 | 0.08087 | 0.1931 | 0.05796 | ... | 26.00 | 124.90 | 1156.0 | 0.1546 | 0.2394 | 0.3791 | 0.15140 | 0.2837 | 0.08019 | 0 |
3 | 12.40 | 17.68 | 81.47 | 467.8 | 0.10540 | 0.13160 | 0.07741 | 0.02799 | 0.1811 | 0.07102 | ... | 22.91 | 89.61 | 515.8 | 0.1450 | 0.2629 | 0.2403 | 0.07370 | 0.2556 | 0.09359 | 1 |
4 | 11.54 | 14.44 | 74.65 | 402.9 | 0.09984 | 0.11200 | 0.06737 | 0.02594 | 0.1818 | 0.06782 | ... | 19.68 | 78.78 | 457.8 | 0.1345 | 0.2118 | 0.1797 | 0.06918 | 0.2329 | 0.08134 | 1 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
109 | 14.64 | 16.85 | 94.21 | 666.0 | 0.08641 | 0.06698 | 0.05192 | 0.02791 | 0.1409 | 0.05355 | ... | 25.44 | 106.00 | 831.0 | 0.1142 | 0.2070 | 0.2437 | 0.07828 | 0.2455 | 0.06596 | 1 |
110 | 16.07 | 19.65 | 104.10 | 817.7 | 0.09168 | 0.08424 | 0.09769 | 0.06638 | 0.1798 | 0.05391 | ... | 24.56 | 128.80 | 1223.0 | 0.1500 | 0.2045 | 0.2829 | 0.15200 | 0.2650 | 0.06387 | 0 |
111 | 11.52 | 14.93 | 73.87 | 406.3 | 0.10130 | 0.07808 | 0.04328 | 0.02929 | 0.1883 | 0.06168 | ... | 21.19 | 80.88 | 491.8 | 0.1389 | 0.1582 | 0.1804 | 0.09608 | 0.2664 | 0.07809 | 1 |
112 | 14.22 | 27.85 | 92.55 | 623.9 | 0.08223 | 0.10390 | 0.11030 | 0.04408 | 0.1342 | 0.06129 | ... | 40.54 | 102.50 | 764.0 | 0.1081 | 0.2426 | 0.3064 | 0.08219 | 0.1890 | 0.07796 | 1 |
113 | 20.73 | 31.12 | 135.70 | 1419.0 | 0.09469 | 0.11430 | 0.13670 | 0.08646 | 0.1769 | 0.05674 | ... | 47.16 | 214.00 | 3432.0 | 0.1401 | 0.2644 | 0.3442 | 0.16590 | 0.2868 | 0.08218 | 0 |
114 rows × 31 columns