@shiyuangu
Last active April 5, 2023 01:48
lightgbm_multiclass.ipynb
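The notebook below compares LightGBM's `multiclass` and `multiclassova` objectives on the iris data. The snippet below is a minimal standalone sketch of the score-to-probability conversion that the notebook checks. It uses NumPy only, with no LightGBM. The raw scores are copied from the first row of the notebook's printed `raw_score=True` outputs. The helpers `softmax` and `sigmoid` are written here for illustration and are not LightGBM functions.

```python
import numpy as np


def softmax(scores: np.ndarray) -> np.ndarray:
    """Row-wise softmax, the link the notebook finds for 'multiclass'."""
    z = scores - scores.max(axis=-1, keepdims=True)  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def sigmoid(x: np.ndarray) -> np.ndarray:
    """Element-wise sigmoid, the link the notebook finds for 'multiclassova'."""
    return 1.0 / (1.0 + np.exp(-x))


# First test row of raw scores, copied from the notebook's printed outputs
score_multiclass = np.array([-1.15238622, -1.00724788, -1.1751855])
score_ova = np.array([-0.77288536, -0.55393376, -0.80777452])

proba_multiclass = softmax(score_multiclass)  # coupled across classes
proba_ova = sigmoid(score_ova)                # one independent sigmoid per class

print(proba_multiclass, proba_multiclass.sum())  # sums to 1
print(proba_ova, proba_ova.sum())                # need not sum to 1
```

Subtracting the row maximum before exponentiating leaves the softmax result unchanged. It also avoids overflow when the raw scores are large.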
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/shiyuangu/960912bdced2c36648e35c70adad6213/notebook.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "8bb22cb3-6b96-4998-bb72-fffea5edacb1",
"showTitle": false,
"title": ""
},
"id": "bMF37kQvBBS-"
},
"source": [
"# LightGBM for multiclass classification\n",
"\n",
"+ For multiclass problems LightGBM grows one tree per class in every boosting round, so the model can be viewed as *num_class* GBDTs, one per class. Only `multiclassova` treats them as independent one-vs-rest binary tasks.\n",
"\n",
"+ There are two objective functions for multiclass problems: `multiclass` and `multiclassova`. They give different scores for the same data point. `multiclassova` gives exactly the same score as the binary classification problem for the corresponding class.\n",
"\n",
"+ I would guess that the objective for `multiclassova` is the [sigmoid loss](https://www.geeksforgeeks.org/sigmoid-cross-entropy-function-of-tensorflow/) applied to each class separately, while the objective for `multiclass` is the softmax cross-entropy. As in [SO](https://stats.stackexchange.com/questions/563718/formal-steps-for-gradient-boosting-with-softmax-and-cross-entropy-loss-function), the score `y_i` for class *i* is modeled by its own booster, and all boosters are updated simultaneously. There is an analogy with a one-layer perceptron: for a two-class problem we can use either a sigmoid loss or a softmax loss.\n",
"\n",
"+ Cf: [lgb-1518](https://github.com/microsoft/LightGBM/issues/1518), [lgb-1675](https://github.com/microsoft/LightGBM/issues/1675)\n",
"\n",
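"**Sigmoid loss** (with $\\sigma(z) = 1 / (1 + e^{-z})$):\n",
"\n",
"$$\\text{loss} = -\\left[y_{\\text{true}} \\log \\sigma(y_{\\text{pred}}) + (1 - y_{\\text{true}}) \\log\\left(1 - \\sigma(y_{\\text{pred}})\\right)\\right]$$\n"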
]
},
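{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a sketch of the two losses for one sample with raw scores $f_1, \\dots, f_K$ (one per class) and one-hot label $y$. It assumes the guess above holds, i.e. `multiclass` is softmax cross-entropy and `multiclassova` is one sigmoid loss per class:\n",
"\n",
"$$p_k = \\frac{e^{f_k}}{\\sum_{j=1}^{K} e^{f_j}}, \\qquad L_{\\text{softmax}} = -\\sum_{k=1}^{K} y_k \\log p_k, \\qquad \\frac{\\partial L_{\\text{softmax}}}{\\partial f_k} = p_k - y_k$$\n",
"\n",
"$$q_k = \\sigma(f_k), \\qquad L_{\\text{ova}} = -\\sum_{k=1}^{K} \\left[y_k \\log q_k + (1 - y_k) \\log(1 - q_k)\\right], \\qquad \\frac{\\partial L_{\\text{ova}}}{\\partial f_k} = q_k - y_k$$\n",
"\n",
"Both gradients have the form prediction minus label. However, $p_k$ depends on every $f_j$ through the normalizing sum, while $q_k$ depends only on $f_k$. This is consistent with `multiclassova` reproducing $K$ separate binary problems and with its probabilities not summing to one."
]
},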
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "42c656de-94cd-43a9-b109-7cd47ba0057e",
"showTitle": false,
"title": ""
},
"id": "pQQOKfrvBBS_",
"outputId": "dc423827-2bcd-4c58-942c-31e4c4d56a29"
},
"outputs": [
{
"data": {
"text/html": [
"<style scoped>\n",
" .ansiout {\n",
" display: block;\n",
" unicode-bidi: embed;\n",
" white-space: pre-wrap;\n",
" word-wrap: break-word;\n",
" word-break: break-all;\n",
" font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n",
" font-size: 13px;\n",
" color: #555;\n",
" margin-left: 4px;\n",
" line-height: 19px;\n",
" }\n",
"</style>\n",
"<div class=\"ansiout\">Out[1]: &#39;3.3.2&#39;</div>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import lightgbm as lgb\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import train_test_split\n",
"import numpy as np\n",
"lgb.__version__"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "1d19b4b6-39cc-4a33-879d-76ddef6ef53d",
"showTitle": false,
"title": ""
},
"id": "lVSXd_TKBBTB",
"outputId": "d9d50a2b-8acc-4dd3-a9b1-e82755fc41ea"
},
"outputs": [
{
"data": {
"text/html": [
"<style scoped>\n",
" .ansiout {\n",
" display: block;\n",
" unicode-bidi: embed;\n",
" white-space: pre-wrap;\n",
" word-wrap: break-word;\n",
" word-break: break-all;\n",
" font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n",
" font-size: 13px;\n",
" color: #555;\n",
" margin-left: 4px;\n",
" line-height: 19px;\n",
" }\n",
"</style>\n",
"<div class=\"ansiout\">/databricks/python/lib/python3.8/site-packages/lightgbm/engine.py:239: UserWarning: &#39;verbose_eval&#39; argument is deprecated and will be removed in a future release of LightGBM. Pass &#39;log_evaluation()&#39; callback via &#39;callbacks&#39; argument instead.\n",
" _log_warning(&#34;&#39;verbose_eval&#39; argument is deprecated and will be removed in a future release of LightGBM. &#34;\n",
"[LightGBM] [Warning] Auto-choosing row-wise multi-threading, the overhead of testing was 0.008743 seconds.\n",
"You can set `force_row_wise=true` to remove the overhead.\n",
"And if memory is not enough, you can set `force_col_wise=true`.\n",
"[LightGBM] [Info] Total Bins 90\n",
"[LightGBM] [Info] Number of data points in the train set: 120, number of used features: 4\n",
"[LightGBM] [Info] Start training from score -1.049822\n",
"[LightGBM] [Info] Start training from score -1.176574\n",
"[LightGBM] [Info] Start training from score -1.073920\n",
"[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n",
"[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n",
"[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n",
"[LightGBM] [Info] Number of positive: 42, number of negative: 78\n",
"[LightGBM] [Info] Number of positive: 37, number of negative: 83\n",
"[LightGBM] [Info] Number of positive: 41, number of negative: 79\n",
"[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000041 seconds.\n",
"You can set `force_col_wise=true` to remove the overhead.\n",
"[LightGBM] [Info] Total Bins 90\n",
"[LightGBM] [Info] Number of data points in the train set: 120, number of used features: 4\n",
"[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.350000 -&gt; initscore=-0.619039\n",
"[LightGBM] [Info] Start training from score -0.619039\n",
"[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.308333 -&gt; initscore=-0.807923\n",
"[LightGBM] [Info] Start training from score -0.807923\n",
"[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.341667 -&gt; initscore=-0.655876\n",
"[LightGBM] [Info] Start training from score -0.655876\n",
"[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n",
"[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n",
"[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n",
"</div>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Cf: https://github.com/chaupmcs/multiclass_vs_multiclassova/blob/master/multiclass_vs_multiclassova.ipynb\n",
"\n",
"iris = load_iris()\n",
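"# no random_state is set, so the split (and the outputs stored in this notebook) change from run to run\n",
"X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2)\n",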
"\n",
"train_data = lgb.Dataset(X_train, label=y_train)\n",
"test_data = lgb.Dataset(X_test, label=y_test)\n",
"\n",
"params = {\n",
" 'objective': 'multiclass',\n",
" 'num_class': 3,\n",
" 'metric': 'multi_logloss',\n",
" 'deterministic': True\n",
"}\n",
"params2 = {\n",
" 'objective': 'multiclassova',\n",
" 'num_class': 3,\n",
" 'metric': 'multi_logloss',\n",
" 'deterministic': True\n",
"}\n",
"\n",
"\n",
"model = lgb.train(params=params,\n",
" train_set=train_data,\n",
" num_boost_round=1,\n",
" valid_sets=[train_data, test_data],\n",
" callbacks=[lgb.log_evaluation(10)])  # replaces the deprecated verbose_eval=10\n",
"model_ova = lgb.train(params=params2,\n",
" train_set=train_data,\n",
" num_boost_round=1,\n",
" valid_sets=[train_data, test_data],\n",
" callbacks=[lgb.log_evaluation(10)])\n",
"\n"
]
},
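{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `Start training from score` lines in the log above can be reproduced from the class frequencies $\\bar{p}_k$ of the training set. These are 42, 37 and 41 of the 120 rows, per the `Number of positive` lines. The values are consistent with `multiclass` starting each class from its log-prior and `multiclassova` starting from the log-odds, as a binary objective does:\n",
"\n",
"$$f_k^{(0)} = \\log \\bar{p}_k \\quad (\\text{multiclass}), \\qquad f_k^{(0)} = \\log \\frac{\\bar{p}_k}{1 - \\bar{p}_k} \\quad (\\text{multiclassova})$$\n",
"\n",
"For class 0, $\\bar{p}_0 = 42/120 = 0.35$. This gives $\\log 0.35 \\approx -1.049822$ and $\\log(0.35 / 0.65) \\approx -0.619039$, matching the first start score printed for each model."
]
},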
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "a11fca25-22cb-42d6-8674-abdf5dabe48a",
"showTitle": false,
"title": ""
},
"id": "OEdFVD6kBBTB",
"outputId": "4f84978d-e958-4670-c48d-66ec89083181"
},
"outputs": [
{
"data": {
"text/html": [
"<style scoped>\n",
" .ansiout {\n",
" display: block;\n",
" unicode-bidi: embed;\n",
" white-space: pre-wrap;\n",
" word-wrap: break-word;\n",
" word-break: break-all;\n",
" font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n",
" font-size: 13px;\n",
" color: #555;\n",
" margin-left: 4px;\n",
" line-height: 19px;\n",
" }\n",
"</style>\n",
"<div class=\"ansiout\">[[ 0.31911585 0.36896159 0.31192256 -1.15238622 -1.00724788 -1.1751855 ]\n",
" [ 0.3154204 0.37626921 0.3083104 -1.15238622 -0.9759877 -1.1751855 ]\n",
" [ 0.31737239 0.28352994 0.39909767 -1.15238622 -1.26514433 -0.92325589]\n",
" [ 0.41833426 0.27662626 0.30503948 -0.85934594 -1.27295937 -1.1751855 ]\n",
" [ 0.3154204 0.37626921 0.3083104 -1.15238622 -0.9759877 -1.1751855 ]\n",
" [ 0.41833426 0.27662626 0.30503948 -0.85934594 -1.27295937 -1.1751855 ]\n",
" [ 0.31171677 0.27847739 0.40980584 -1.15238622 -1.26514433 -0.87879772]\n",
" [ 0.41509666 0.27816598 0.30673736 -0.87266596 -1.27295937 -1.1751855 ]\n",
" [ 0.31737239 0.28352994 0.39909767 -1.15238622 -1.26514433 -0.92325589]\n",
" [ 0.3154204 0.37626921 0.3083104 -1.15238622 -0.9759877 -1.1751855 ]\n",
" [ 0.3154204 0.37626921 0.3083104 -1.15238622 -0.9759877 -1.1751855 ]\n",
" [ 0.3154204 0.37626921 0.3083104 -1.15238622 -0.9759877 -1.1751855 ]\n",
" [ 0.31911585 0.36896159 0.31192256 -1.15238622 -1.00724788 -1.1751855 ]\n",
" [ 0.31737239 0.28352994 0.39909767 -1.15238622 -1.26514433 -0.92325589]\n",
" [ 0.31911585 0.36896159 0.31192256 -1.15238622 -1.00724788 -1.1751855 ]\n",
" [ 0.41509666 0.27816598 0.30673736 -0.87266596 -1.27295937 -1.1751855 ]\n",
" [ 0.31171677 0.27847739 0.40980584 -1.15238622 -1.26514433 -0.87879772]\n",
" [ 0.41509666 0.27816598 0.30673736 -0.87266596 -1.27295937 -1.1751855 ]\n",
" [ 0.3154204 0.37626921 0.3083104 -1.15238622 -0.9759877 -1.1751855 ]\n",
" [ 0.3154204 0.37626921 0.3083104 -1.15238622 -0.9759877 -1.1751855 ]\n",
" [ 0.31171677 0.27847739 0.40980584 -1.15238622 -1.26514433 -0.87879772]\n",
" [ 0.31737239 0.28352994 0.39909767 -1.15238622 -1.26514433 -0.92325589]\n",
" [ 0.31171677 0.27847739 0.40980584 -1.15238622 -1.26514433 -0.87879772]\n",
" [ 0.31737239 0.28352994 0.39909767 -1.15238622 -1.26514433 -0.92325589]\n",
" [ 0.41509666 0.27816598 0.30673736 -0.87266596 -1.27295937 -1.1751855 ]\n",
" [ 0.41833426 0.27662626 0.30503948 -0.85934594 -1.27295937 -1.1751855 ]\n",
" [ 0.31171677 0.27847739 0.40980584 -1.15238622 -1.26514433 -0.87879772]\n",
" [ 0.41509666 0.27816598 0.30673736 -0.87266596 -1.27295937 -1.1751855 ]\n",
" [ 0.31911585 0.36896159 0.31192256 -1.15238622 -1.00724788 -1.1751855 ]\n",
" [ 0.31171677 0.27847739 0.40980584 -1.15238622 -1.26514433 -0.87879772]]\n",
"</div>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"y_pred_leaf = model.predict(X_test, pred_leaf=True)\n",
"y_pred_proba = model.predict(X_test, raw_score=False)\n",
"y_pred_score = model.predict(X_test, raw_score=True)\n",
"print(np.hstack([y_pred_proba,y_pred_score]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "44cb071b-84d9-4420-837d-9ca4668d02c4",
"showTitle": false,
"title": ""
},
"id": "L169NEQkBBTC",
"outputId": "b94f7d92-9853-4ffa-a348-0d16683532c4"
},
"outputs": [
{
"data": {
"text/html": [
"<style scoped>\n",
" .ansiout {\n",
" display: block;\n",
" unicode-bidi: embed;\n",
" white-space: pre-wrap;\n",
" word-wrap: break-word;\n",
" word-break: break-all;\n",
" font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n",
" font-size: 13px;\n",
" color: #555;\n",
" margin-left: 4px;\n",
" line-height: 19px;\n",
" }\n",
"</style>\n",
"<div class=\"ansiout\">Out[4]: 3.3766115072321297e-16</div>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# 'objective': 'multiclass' gives softmax: proba = exp(score) / sum over classes of exp(score)\n",
"_t1 = np.exp(y_pred_score)\n",
"_y = _t1 / _t1.sum(axis=1, keepdims=True)  # broadcasting; no hard-coded (30, 3) shape\n",
"np.linalg.norm(_y - y_pred_proba)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "878915dc-2fce-411b-aa23-7e879d0295a7",
"showTitle": false,
"title": ""
},
"id": "N4HxDSLBBBTC",
"outputId": "d7129ca5-2725-4cbf-b138-91ce9dc0a25b"
},
"outputs": [
{
"data": {
"text/html": [
"<style scoped>\n",
" .ansiout {\n",
" display: block;\n",
" unicode-bidi: embed;\n",
" white-space: pre-wrap;\n",
" word-wrap: break-word;\n",
" word-break: break-all;\n",
" font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n",
" font-size: 13px;\n",
" color: #555;\n",
" margin-left: 4px;\n",
" line-height: 19px;\n",
" }\n",
"</style>\n",
"<div class=\"ansiout\">[[ 0.31585528 0.36495223 0.30836493 -0.77288536 -0.55393376 -0.80777452]\n",
" [ 0.31585528 0.37588685 0.30836493 -0.77288536 -0.5070435 -0.80777452]\n",
" [ 0.31585528 0.28074313 0.39415496 -0.77288536 -0.94077845 -0.4298801 ]\n",
" [ 0.41743183 0.27838213 0.30836493 -0.33332494 -0.95250101 -0.80777452]\n",
" [ 0.31585528 0.37588685 0.30836493 -0.77288536 -0.5070435 -0.80777452]\n",
" [ 0.41743183 0.27838213 0.30836493 -0.33332494 -0.95250101 -0.80777452]\n",
" [ 0.31585528 0.28074313 0.41018689 -0.77288536 -0.94077845 -0.36319285]\n",
" [ 0.41258121 0.27838213 0.30836493 -0.35330496 -0.95250101 -0.80777452]\n",
" [ 0.31585528 0.28074313 0.39415496 -0.77288536 -0.94077845 -0.4298801 ]\n",
" [ 0.31585528 0.37588685 0.30836493 -0.77288536 -0.5070435 -0.80777452]\n",
" [ 0.31585528 0.37588685 0.30836493 -0.77288536 -0.5070435 -0.80777452]\n",
" [ 0.31585528 0.37588685 0.30836493 -0.77288536 -0.5070435 -0.80777452]\n",
" [ 0.31585528 0.36495223 0.30836493 -0.77288536 -0.55393376 -0.80777452]\n",
" [ 0.31585528 0.28074313 0.39415496 -0.77288536 -0.94077845 -0.4298801 ]\n",
" [ 0.31585528 0.36495223 0.30836493 -0.77288536 -0.55393376 -0.80777452]\n",
" [ 0.41258121 0.27838213 0.30836493 -0.35330496 -0.95250101 -0.80777452]\n",
" [ 0.31585528 0.28074313 0.41018689 -0.77288536 -0.94077845 -0.36319285]\n",
" [ 0.41258121 0.27838213 0.30836493 -0.35330496 -0.95250101 -0.80777452]\n",
" [ 0.31585528 0.37588685 0.30836493 -0.77288536 -0.5070435 -0.80777452]\n",
" [ 0.31585528 0.37588685 0.30836493 -0.77288536 -0.5070435 -0.80777452]\n",
" [ 0.31585528 0.28074313 0.41018689 -0.77288536 -0.94077845 -0.36319285]\n",
" [ 0.31585528 0.28074313 0.39415496 -0.77288536 -0.94077845 -0.4298801 ]\n",
" [ 0.31585528 0.28074313 0.41018689 -0.77288536 -0.94077845 -0.36319285]\n",
" [ 0.31585528 0.28074313 0.39415496 -0.77288536 -0.94077845 -0.4298801 ]\n",
" [ 0.41258121 0.27838213 0.30836493 -0.35330496 -0.95250101 -0.80777452]\n",
" [ 0.41743183 0.27838213 0.30836493 -0.33332494 -0.95250101 -0.80777452]\n",
" [ 0.31585528 0.28074313 0.41018689 -0.77288536 -0.94077845 -0.36319285]\n",
" [ 0.41258121 0.27838213 0.30836493 -0.35330496 -0.95250101 -0.80777452]\n",
" [ 0.31585528 0.36495223 0.30836493 -0.77288536 -0.55393376 -0.80777452]\n",
" [ 0.31585528 0.28074313 0.41018689 -0.77288536 -0.94077845 -0.36319285]]\n",
"</div>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"y_pred_leaf_ova = model_ova.predict(X_test, pred_leaf=True)\n",
"y_pred_proba_ova = model_ova.predict(X_test, raw_score=False)\n",
"y_pred_score_ova = model_ova.predict(X_test, raw_score=True)\n",
"print(np.hstack([y_pred_proba_ova,y_pred_score_ova]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "92d62b97-df11-4168-af67-9fad8f2b74ea",
"showTitle": false,
"title": ""
},
"id": "0-mIDG1YBBTC",
"outputId": "b2e32bfd-3bbe-49b9-dc3a-aac9c70a2d35"
},
"outputs": [
{
"data": {
"text/html": [
"<style scoped>\n",
" .ansiout {\n",
" display: block;\n",
" unicode-bidi: embed;\n",
" white-space: pre-wrap;\n",
" word-wrap: break-word;\n",
" word-break: break-all;\n",
" font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n",
" font-size: 13px;\n",
" color: #555;\n",
" margin-left: 4px;\n",
" line-height: 19px;\n",
" }\n",
"</style>\n",
"<div class=\"ansiout\"></div>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def sigmoid(x):\n",
" return 1.0 / (1.0 + np.exp(-x))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "f9218b6a-92c6-48c7-ade5-2906568b41af",
"showTitle": false,
"title": ""
},
"id": "-NQRA6Q9BBTC",
"outputId": "ff596466-2230-4d17-d626-3a9d30d255f0"
},
"outputs": [
{
"data": {
"text/html": [
"<style scoped>\n",
" .ansiout {\n",
" display: block;\n",
" unicode-bidi: embed;\n",
" white-space: pre-wrap;\n",
" word-wrap: break-word;\n",
" word-break: break-all;\n",
" font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n",
" font-size: 13px;\n",
" color: #555;\n",
" margin-left: 4px;\n",
" line-height: 19px;\n",
" }\n",
"</style>\n",
"<div class=\"ansiout\">Out[7]: 0.0</div>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# for 'objective': 'multiclassova', each class probability is the sigmoid of that class's raw score\n",
"np.linalg.norm(y_pred_proba_ova - sigmoid(y_pred_score_ova))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "cdec4744-bd3b-4386-afc3-8d31b3275e6e",
"showTitle": false,
"title": ""
},
"id": "PJjMY94YBBTD",
"outputId": "7447047c-5d61-4f5a-ea41-f5cc40dfa44b"
},
"outputs": [
{
"data": {
"text/html": [
"<style scoped>\n",
" .ansiout {\n",
" display: block;\n",
" unicode-bidi: embed;\n",
" white-space: pre-wrap;\n",
" word-wrap: break-word;\n",
" word-break: break-all;\n",
" font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n",
" font-size: 13px;\n",
" color: #555;\n",
" margin-left: 4px;\n",
" line-height: 19px;\n",
" }\n",
"</style>\n",
"<div class=\"ansiout\">Out[8]: (array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),\n",
" array([0.98917244, 1.00010707, 0.99075337, 1.00417889, 1.00010707,\n",
" 1.00417889, 1.00678529, 0.99932827, 0.99075337, 1.00010707,\n",
" 1.00010707, 1.00010707, 0.98917244, 0.99075337, 0.98917244,\n",
" 0.99932827, 1.00678529, 0.99932827, 1.00010707, 1.00010707,\n",
" 1.00678529, 0.99075337, 1.00678529, 0.99075337, 0.99932827,\n",
" 1.00417889, 1.00678529, 0.99932827, 0.98917244, 1.00678529]))</div>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# with 'multiclass' each row of probabilities sums to one; with 'multiclassova' it does not\n",
"np.sum(y_pred_proba,axis=1), np.sum(y_pred_proba_ova,axis=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "d6cf2282-dfc3-4910-adde-51813596d2e0",
"showTitle": false,
"title": ""
},
"id": "XMCVYPRwBBTD",
"outputId": "92b48857-ab4d-4760-e4e7-e53a11eb750f"
},
"outputs": [
{
"data": {
"text/html": [
"<style scoped>\n",
" .ansiout {\n",
" display: block;\n",
" unicode-bidi: embed;\n",
" white-space: pre-wrap;\n",
" word-wrap: break-word;\n",
" word-break: break-all;\n",
" font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n",
" font-size: 13px;\n",
" color: #555;\n",
" margin-left: 4px;\n",
" line-height: 19px;\n",
" }\n",
"</style>\n",
"<div class=\"ansiout\"></div>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import pandas as pd\n",
"pd.set_option('display.expand_frame_repr', False)  # keep each row on one line instead of wrapping it\n",
"pd.set_option('display.precision', 5)\n",
"pd.set_option('display.max_rows', 100)\n",
"pd.set_option('display.min_rows',30)\n",
"pd.set_option('display.max_columns', 100)\n",
"\n",
"np.set_printoptions(threshold=np.inf) # print full np array without truncation Cf: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "e77784c0-02e2-4580-9886-0e8bfc6c25a4",
"showTitle": false,
"title": ""
},
"id": "j7lpmuWkBBTD",
"outputId": "67b9b34b-8a20-4c42-f882-7de2fb6bf935"
},
"outputs": [
{
"data": {
"text/html": [
"<style scoped>\n",
" .ansiout {\n",
" display: block;\n",
" unicode-bidi: embed;\n",
" white-space: pre-wrap;\n",
" word-wrap: break-word;\n",
" word-break: break-all;\n",
" font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n",
" font-size: 13px;\n",
" color: #555;\n",
" margin-left: 4px;\n",
" line-height: 19px;\n",
" }\n",
"</style>\n",
"<div class=\"ansiout\">Out[10]: array([[-1.15238622, -1.00724788, -1.1751855 , 3. , 3. , 3. ],\n",
" [-1.15238622, -0.9759877 , -1.1751855 , 1. , 1. , 3. ],\n",
" [-1.15238622, -1.26514433, -0.92325589, 3. , 2. , 1. ],\n",
" [-0.85934594, -1.27295937, -1.1751855 , 0. , 0. , 3. ],\n",
" [-1.15238622, -0.9759877 , -1.1751855 , 1. , 1. , 0. ],\n",
" [-0.85934594, -1.27295937, -1.1751855 , 0. , 0. , 0. ],\n",
" [-1.15238622, -1.26514433, -0.87879772, 3. , 2. , 2. ],\n",
" [-0.87266596, -1.27295937, -1.1751855 , 2. , 0. , 0. ],\n",
" [-1.15238622, -1.26514433, -0.92325589, 1. , 2. , 1. ],\n",
" [-1.15238622, -0.9759877 , -1.1751855 , 1. , 1. , 0. ],\n",
" [-1.15238622, -0.9759877 , -1.1751855 , 1. , 1. , 0. ],\n",
" [-1.15238622, -0.9759877 , -1.1751855 , 1. , 1. , 0. ],\n",
" [-1.15238622, -1.00724788, -1.1751855 , 3. , 3. , 3. ],\n",
" [-1.15238622, -1.26514433, -0.92325589, 1. , 2. , 1. ],\n",
" [-1.15238622, -1.00724788, -1.1751855 , 3. , 3. , 3. ],\n",
" [-0.87266596, -1.27295937, -1.1751855 , 2. , 0. , 0. ],\n",
" [-1.15238622, -1.26514433, -0.87879772, 3. , 2. , 2. ],\n",
" [-0.87266596, -1.27295937, -1.1751855 , 2. , 0. , 0. ],\n",
" [-1.15238622, -0.9759877 , -1.1751855 , 1. , 1. , 3. ],\n",
" [-1.15238622, -0.9759877 , -1.1751855 , 1. , 1. , 3. ],\n",
" [-1.15238622, -1.26514433, -0.87879772, 3. , 2. , 2. ],\n",
" [-1.15238622, -1.26514433, -0.92325589, 1. , 2. , 1. ],\n",
" [-1.15238622, -1.26514433, -0.87879772, 3. , 2. , 2. ],\n",
" [-1.15238622, -1.26514433, -0.92325589, 3. , 2. , 1. ],\n",
" [-0.87266596, -1.27295937, -1.1751855 , 2. , 0. , 0. ],\n",
" [-0.85934594, -1.27295937, -1.1751855 , 0. , 0. , 0. ],\n",
" [-1.15238622, -1.26514433, -0.87879772, 3. , 2. , 2. ],\n",
" [-0.87266596, -1.27295937, -1.1751855 , 2. , 0. , 3. ],\n",
" [-1.15238622, -1.00724788, -1.1751855 , 1. , 3. , 3. ],\n",
" [-1.15238622, -1.26514433, -0.87879772, 3. , 2. , 2. ]])</div>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"np.set_printoptions(linewidth=128)\n",
"np.hstack([y_pred_score,y_pred_leaf])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "a45b6ba3-a1f3-47ae-bae0-40c5acac4b51",
"showTitle": false,
"title": ""
},
"id": "xeMxnw4DBBTE",
"outputId": "0d8132c1-ec9e-4c11-abc7-f0d5e5ec4f94"
},
"outputs": [
{
"data": {
"text/html": [
"<style scoped>\n",
" .ansiout {\n",
" display: block;\n",
" unicode-bidi: embed;\n",
" white-space: pre-wrap;\n",
" word-wrap: break-word;\n",
" word-break: break-all;\n",
" font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n",
" font-size: 13px;\n",
" color: #555;\n",
" margin-left: 4px;\n",
" line-height: 19px;\n",
" }\n",
"</style>\n",
"<div class=\"ansiout\">(21, 14)\n",
"Out[11]: </div>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>tree_index</th>\n",
" <th>node_depth</th>\n",
" <th>left_child</th>\n",
" <th>right_child</th>\n",
" <th>parent_index</th>\n",
" <th>split_feature</th>\n",
" <th>split_gain</th>\n",
" <th>threshold</th>\n",
" <th>decision_type</th>\n",
" <th>missing_direction</th>\n",
" <th>missing_type</th>\n",
" <th>value</th>\n",
" <th>weight</th>\n",
" <th>count</th>\n",
" </tr>\n",
" <tr>\n",
" <th>node_index</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0-L0</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>0-S1</td>\n",
" <td>None</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>-0.85935</td>\n",
" <td>7.16625</td>\n",
" <td>21</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0-L2</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>0-S1</td>\n",
" <td>None</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>-0.87267</td>\n",
" <td>7.50750</td>\n",
" <td>22</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0-L1</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>0-S2</td>\n",
" <td>None</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>-1.15239</td>\n",
" <td>15.69750</td>\n",
" <td>46</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0-L3</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>0-S2</td>\n",
" <td>None</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>-1.15239</td>\n",
" <td>10.57875</td>\n",
" <td>31</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1-L0</th>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>1-S0</td>\n",
" <td>None</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>-1.27296</td>\n",
" <td>12.79583</td>\n",
" <td>40</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1-L1</th>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>1-S2</td>\n",
" <td>None</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>-0.97599</td>\n",
" <td>6.39792</td>\n",
" <td>20</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1-L3</th>\n",
" <td>1</td>\n",
" <td>4</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>1-S2</td>\n",
" <td>None</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>-1.00725</td>\n",
" <td>6.39792</td>\n",
" <td>20</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1-L2</th>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>1-S1</td>\n",
" <td>None</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>-1.26514</td>\n",
" <td>12.79583</td>\n",
" <td>40</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2-L0</th>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>2-S2</td>\n",
" <td>None</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>-1.17519</td>\n",
" <td>16.53240</td>\n",
" <td>49</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2-L3</th>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>2-S2</td>\n",
" <td>None</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>-1.17519</td>\n",
" <td>9.10969</td>\n",
" <td>27</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2-L1</th>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>2-S1</td>\n",
" <td>None</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>-0.92326</td>\n",
" <td>6.74792</td>\n",
" <td>20</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2-L2</th>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>2-S1</td>\n",
" <td>None</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>None</td>\n",
" <td>-0.87880</td>\n",
" <td>8.09750</td>\n",
" <td>24</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
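"# With a single boosting round, the raw score for class k is the 'value' of the leaf that\n",
"# the sample reaches in tree k (tree_index 0, 1, 2 correspond to classes 0, 1, 2).\n",
"bst = model\n",
"df_tree = bst.trees_to_dataframe().set_index('node_index')\n",
"sel = df_tree['left_child'].isnull()  # leaves have no children\n",
"print(df_tree.shape)\n",
"df_tree[sel]"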
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HQszhTXIBBTE"
},
"source": [
"\n",
"# `multiclassova` gives the same score as the binary classification problem for the corresponding class"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "4dfd32c2-c27d-4201-85f2-6bfa52aac240",
"showTitle": false,
"title": ""
},
"id": "FEpoZg6XBBTE",
"outputId": "d9b64e79-89e2-47c3-8812-fefe63460508"
},
"outputs": [
{
"data": {
"text/html": [
"<style scoped>\n",
" .ansiout {\n",
" display: block;\n",
" unicode-bidi: embed;\n",
" white-space: pre-wrap;\n",
" word-wrap: break-word;\n",
" word-break: break-all;\n",
" font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n",
" font-size: 13px;\n",
" color: #555;\n",
" margin-left: 4px;\n",
" line-height: 19px;\n",
" }\n",
"</style>\n",
"<div class=\"ansiout\">[LightGBM] [Info] Number of positive: 37, number of negative: 83\n",
"[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000051 seconds.\n",
"You can set `force_col_wise=true` to remove the overhead.\n",
"[LightGBM] [Info] Total Bins 90\n",
"[LightGBM] [Info] Number of data points in the train set: 120, number of used features: 4\n",
"[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.308333 -&gt; initscore=-0.807923\n",
"[LightGBM] [Info] Start training from score -0.807923\n",
"[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n",
"</div>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# One-vs-rest binary problem for class 1: class 1 keeps label 1,\n",
"# classes 0 and 2 both get the negative label 0.\n",
"y_train_bin = np.copy(y_train)\n",
"y_test_bin = np.copy(y_test)\n",
"y_train_bin[y_train_bin==2] = 0\n",
"y_test_bin[y_test_bin==2] = 0\n",
"params_bin = {\n",
" 'objective': 'binary',\n",
" 'num_class': 1,\n",
" 'metric': 'binary_logloss',  # 'bin_logloss' is not a recognized LightGBM metric name\n",
" 'deterministic': True\n",
"}\n",
"train_data_bin = lgb.Dataset(X_train, label=y_train_bin)\n",
"test_data_bin = lgb.Dataset(X_test, label=y_test_bin)\n",
"model_bin = lgb.train(params=params_bin,\n",
"                      train_set=train_data_bin,\n",
"                      num_boost_round=1,\n",
"                      valid_sets=[train_data_bin, test_data_bin],\n",
"                      callbacks=[lgb.log_evaluation(10)])  # verbose_eval was removed in LightGBM 4.0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "dd4afa3c-a337-4e9d-80ad-78ac8e95fc71",
"showTitle": false,
"title": ""
},
"id": "A4fUfOBaBBTE",
"outputId": "b6429be9-6f5e-4987-a52e-c6accc41a457"
},
"outputs": [
{
"data": {
"text/html": [
"<style scoped>\n",
" .ansiout {\n",
" display: block;\n",
" unicode-bidi: embed;\n",
" white-space: pre-wrap;\n",
" word-wrap: break-word;\n",
" word-break: break-all;\n",
" font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n",
" font-size: 13px;\n",
" color: #555;\n",
" margin-left: 4px;\n",
" line-height: 19px;\n",
" }\n",
"</style>\n",
"<div class=\"ansiout\">[[ 0.36495223 -0.55393376]\n",
" [ 0.37588685 -0.5070435 ]\n",
" [ 0.28074313 -0.94077845]\n",
" [ 0.27838213 -0.95250101]\n",
" [ 0.37588685 -0.5070435 ]\n",
" [ 0.27838213 -0.95250101]\n",
" [ 0.28074313 -0.94077845]\n",
" [ 0.27838213 -0.95250101]\n",
" [ 0.28074313 -0.94077845]\n",
" [ 0.37588685 -0.5070435 ]\n",
" [ 0.37588685 -0.5070435 ]\n",
" [ 0.37588685 -0.5070435 ]\n",
" [ 0.36495223 -0.55393376]\n",
" [ 0.28074313 -0.94077845]\n",
" [ 0.36495223 -0.55393376]\n",
" [ 0.27838213 -0.95250101]\n",
" [ 0.28074313 -0.94077845]\n",
" [ 0.27838213 -0.95250101]\n",
" [ 0.37588685 -0.5070435 ]\n",
" [ 0.37588685 -0.5070435 ]\n",
" [ 0.28074313 -0.94077845]\n",
" [ 0.28074313 -0.94077845]\n",
" [ 0.28074313 -0.94077845]\n",
" [ 0.28074313 -0.94077845]\n",
" [ 0.27838213 -0.95250101]\n",
" [ 0.27838213 -0.95250101]\n",
" [ 0.28074313 -0.94077845]\n",
" [ 0.27838213 -0.95250101]\n",
" [ 0.36495223 -0.55393376]\n",
" [ 0.28074313 -0.94077845]]\n",
"</div>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"y_pred_leaf_bin = model_bin.predict(X_test, pred_leaf=True)\n",
"y_pred_proba_bin = model_bin.predict(X_test, raw_score=False)  # probability\n",
"y_pred_score_bin = model_bin.predict(X_test, raw_score=True)  # raw score (log-odds)\n",
"# column 0: probability, column 1: raw score\n",
"print(np.hstack([y_pred_proba_bin[:,np.newaxis],y_pred_score_bin[:,np.newaxis]]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {
"byteLimit": 2048000,
"rowLimit": 10000
},
"inputWidgets": {},
"nuid": "2fe53b14-d9be-4a88-ac0c-41fe63c76621",
"showTitle": false,
"title": ""
},
"id": "f-2htuP9BBTF",
"outputId": "f819287a-4145-4001-cb2a-27bf58737cbb"
},
"outputs": [
{
"data": {
"text/html": [
"<style scoped>\n",
" .ansiout {\n",
" display: block;\n",
" unicode-bidi: embed;\n",
" white-space: pre-wrap;\n",
" word-wrap: break-word;\n",
" word-break: break-all;\n",
" font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n",
" font-size: 13px;\n",
" color: #555;\n",
" margin-left: 4px;\n",
" line-height: 19px;\n",
" }\n",
"</style>\n",
"<div class=\"ansiout\">Out[20]: (0.0, 0.0)</div>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Both norms are 0: the binary model reproduces column 1 (class 1) of the multiclassova model\n",
"np.linalg.norm(y_pred_score_bin - y_pred_score_ova[:,1]), np.linalg.norm(y_pred_proba_bin - y_pred_proba_ova[:,1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"application/vnd.databricks.v1+cell": {
"cellMetadata": {},
"inputWidgets": {},
"nuid": "d154b099-0da7-41c2-b3f5-c35aabb8439c",
"showTitle": false,
"title": ""
},
"id": "U7akV5sFBBTF"
},
"outputs": [],
"source": []
}
],
"metadata": {
"application/vnd.databricks.v1+notebook": {
"dashboards": [],
"language": "python",
"notebookMetadata": {
"pythonIndentUnit": 2
},
"notebookName": "202301",
"notebookOrigID": 3593247923918104,
"widgets": {}
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
},
"colab": {
"provenance": [],
"name": "lightgbm_multiclass.ipynb",
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 0
}
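The cells above print raw scores (`raw_score=True`) next to probabilities. LightGBM's parameter documentation describes `multiclass` as the softmax objective and `multiclassova` as a one-vs-all binary objective, so the probabilities should come from a softmax over the per-class raw scores for `multiclass`, and from a sigmoid of each raw score for `binary` (and, class by class, for `multiclassova`). The sketch below is a minimal, stdlib-only illustration of those two link functions under that reading, assuming LightGBM's default `sigmoid=1.0`; it does not call LightGBM. The helper names `sigmoid` and `softmax` are hypothetical, and the numbers are copied from the printed outputs above.

```python
import math
from typing import List, Sequence


def sigmoid(z: float, scale: float = 1.0) -> float:
    """Logistic link (binary, and per class for multiclassova).

    Assumes LightGBM's default sigmoid=1.0, passed here as `scale`.
    """
    return 1.0 / (1.0 + math.exp(-scale * z))


def softmax(scores: Sequence[float]) -> List[float]:
    """Softmax link over one raw score per class (multiclass objective)."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# (raw score, probability) pairs printed by the binary model above
binary_rows = [(-0.55393376, 0.36495223), (-0.5070435, 0.37588685)]
for raw, proba in binary_rows:
    print(f"sigmoid({raw:.8f}) = {sigmoid(raw):.8f}, predict() gave {proba:.8f}")

# First test row of the multiclass model: one raw score per class
multiclass_raw = [-1.15238622, -1.00724788, -1.1751855]
probs = softmax(multiclass_raw)
print("softmax probabilities:", [round(p, 4) for p in probs])
print("sum of probabilities:", round(sum(probs), 12))
```

The softmax couples the classes, so the three `multiclass` probabilities always sum to 1. The per-class sigmoids of `multiclassova` are computed independently and are not normalized, which is why a `multiclassova` column can match a standalone binary model exactly while a `multiclass` column cannot.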