Last active
April 5, 2023 01:48
lightgbm_multiclass.ipynb
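The notebook that follows checks that `multiclass` probabilities are the softmax of the raw scores, while `multiclassova` probabilities are a per-class sigmoid and need not sum to one. A minimal NumPy-only sketch of that relationship is given here; the raw scores are copied from the first row of the notebook's printed outputs, and the helper names `softmax` and `sigmoid` are illustrative, not LightGBM API.

```python
import numpy as np


def softmax(scores):
    """Row-wise softmax; subtracting the row max keeps exp() numerically stable."""
    shifted = scores - scores.max(axis=1, keepdims=True)
    exp_scores = np.exp(shifted)
    return exp_scores / exp_scores.sum(axis=1, keepdims=True)


def sigmoid(scores):
    """Element-wise logistic function, applied to each class score independently."""
    return 1.0 / (1.0 + np.exp(-scores))


# Raw scores for the first test row, copied from the notebook outputs.
raw_multiclass = np.array([[-1.15238622, -1.00724788, -1.17518550]])
raw_ova = np.array([[-0.77288536, -0.55393376, -0.80777452]])

proba_multiclass = softmax(raw_multiclass)  # rows sum to one by construction
proba_ova = sigmoid(raw_ova)                # rows are not normalized

print(proba_multiclass, proba_multiclass.sum(axis=1))
print(proba_ova, proba_ova.sum(axis=1))
```

The softmax row sums to one by construction, whereas the sigmoid row sums to whatever the three independent binary scores happen to give.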
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/shiyuangu/960912bdced2c36648e35c70adad6213/notebook.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "application/vnd.databricks.v1+cell": { | |
| "cellMetadata": { | |
| "byteLimit": 2048000, | |
| "rowLimit": 10000 | |
| }, | |
| "inputWidgets": {}, | |
| "nuid": "8bb22cb3-6b96-4998-bb72-fffea5edacb1", | |
| "showTitle": false, | |
| "title": "" | |
| }, | |
| "id": "bMF37kQvBBS-" | |
| }, | |
| "source": [ | |
| "# lightgbm for multiclass\n", | |
| "\n", | |
| "+ LightGBM uses a *one-vs-rest* fashion for multiclass: you can view it as *num_class* binary classification tasks, so there are *num_class* GBDTs.\n", | |
| "\n", | |
| "+ There are two objective functions for multiclass: `multiclass` and `multiclassova`. They give different scores for the same data point. `multiclassova` gives exactly the same score as the binary classification problem for the corresponding class.\n", | |
| "\n", | |
| "+ My guess is that the objective function for `multiclassova` is the [sigmoid loss](https://www.geeksforgeeks.org/sigmoid-cross-entropy-function-of-tensorflow/), while the objective function for `multiclass` is the softmax loss. As in this [Cross Validated question](https://stats.stackexchange.com/questions/563718/formal-steps-for-gradient-boosting-with-softmax-and-cross-entropy-loss-function), the score `yi` for class i is modeled by its own booster, and the boosters can be updated simultaneously. An analogy is the one-layer perceptron: for a two-class problem, we can use either the sigmoid loss or the softmax loss.\n", | |
| "\n", | |
| "+ Cf: [lgb-1518](https://github.com/microsoft/LightGBM/issues/1518), [lgb-1675](https://github.com/microsoft/LightGBM/issues/1675)\n", | |
| "\n", | |
| "Sigmoid (binary cross-entropy) loss, where `y_pred` is the raw score:\n", | |
| "\n", | |
| "$$\\text{loss} = -\\bigl(y_{\\text{true}} \\log \\sigma(y_{\\text{pred}}) + (1 - y_{\\text{true}}) \\log(1 - \\sigma(y_{\\text{pred}}))\\bigr)$$\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "application/vnd.databricks.v1+cell": { | |
| "cellMetadata": { | |
| "byteLimit": 2048000, | |
| "rowLimit": 10000 | |
| }, | |
| "inputWidgets": {}, | |
| "nuid": "42c656de-94cd-43a9-b109-7cd47ba0057e", | |
| "showTitle": false, | |
| "title": "" | |
| }, | |
| "id": "pQQOKfrvBBS_", | |
| "outputId": "dc423827-2bcd-4c58-942c-31e4c4d56a29" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<style scoped>\n", | |
| " .ansiout {\n", | |
| " display: block;\n", | |
| " unicode-bidi: embed;\n", | |
| " white-space: pre-wrap;\n", | |
| " word-wrap: break-word;\n", | |
| " word-break: break-all;\n", | |
| " font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n", | |
| " font-size: 13px;\n", | |
| " color: #555;\n", | |
| " margin-left: 4px;\n", | |
| " line-height: 19px;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<div class=\"ansiout\">Out[1]: '3.3.2'</div>" | |
| ] | |
| }, | |
| "metadata": { | |
| "application/vnd.databricks.v1+output": { | |
| "addedWidgets": {}, | |
| "arguments": {}, | |
| "data": "<div class=\"ansiout\">Out[1]: '3.3.2'</div>", | |
| "datasetInfos": [], | |
| "metadata": {}, | |
| "removedWidgets": [], | |
| "type": "html" | |
| } | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "import lightgbm as lgb\n", | |
| "from sklearn.datasets import load_iris\n", | |
| "from sklearn.model_selection import train_test_split\n", | |
| "import numpy as np\n", | |
| "lgb.__version__" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "application/vnd.databricks.v1+cell": { | |
| "cellMetadata": { | |
| "byteLimit": 2048000, | |
| "rowLimit": 10000 | |
| }, | |
| "inputWidgets": {}, | |
| "nuid": "1d19b4b6-39cc-4a33-879d-76ddef6ef53d", | |
| "showTitle": false, | |
| "title": "" | |
| }, | |
| "id": "lVSXd_TKBBTB", | |
| "outputId": "d9d50a2b-8acc-4dd3-a9b1-e82755fc41ea" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<style scoped>\n", | |
| " .ansiout {\n", | |
| " display: block;\n", | |
| " unicode-bidi: embed;\n", | |
| " white-space: pre-wrap;\n", | |
| " word-wrap: break-word;\n", | |
| " word-break: break-all;\n", | |
| " font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n", | |
| " font-size: 13px;\n", | |
| " color: #555;\n", | |
| " margin-left: 4px;\n", | |
| " line-height: 19px;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<div class=\"ansiout\">/databricks/python/lib/python3.8/site-packages/lightgbm/engine.py:239: UserWarning: 'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. Pass 'log_evaluation()' callback via 'callbacks' argument instead.\n", | |
| " _log_warning("'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. "\n", | |
| "[LightGBM] [Warning] Auto-choosing row-wise multi-threading, the overhead of testing was 0.008743 seconds.\n", | |
| "You can set `force_row_wise=true` to remove the overhead.\n", | |
| "And if memory is not enough, you can set `force_col_wise=true`.\n", | |
| "[LightGBM] [Info] Total Bins 90\n", | |
| "[LightGBM] [Info] Number of data points in the train set: 120, number of used features: 4\n", | |
| "[LightGBM] [Info] Start training from score -1.049822\n", | |
| "[LightGBM] [Info] Start training from score -1.176574\n", | |
| "[LightGBM] [Info] Start training from score -1.073920\n", | |
| "[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n", | |
| "[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n", | |
| "[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n", | |
| "[LightGBM] [Info] Number of positive: 42, number of negative: 78\n", | |
| "[LightGBM] [Info] Number of positive: 37, number of negative: 83\n", | |
| "[LightGBM] [Info] Number of positive: 41, number of negative: 79\n", | |
| "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000041 seconds.\n", | |
| "You can set `force_col_wise=true` to remove the overhead.\n", | |
| "[LightGBM] [Info] Total Bins 90\n", | |
| "[LightGBM] [Info] Number of data points in the train set: 120, number of used features: 4\n", | |
| "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.350000 -> initscore=-0.619039\n", | |
| "[LightGBM] [Info] Start training from score -0.619039\n", | |
| "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.308333 -> initscore=-0.807923\n", | |
| "[LightGBM] [Info] Start training from score -0.807923\n", | |
| "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.341667 -> initscore=-0.655876\n", | |
| "[LightGBM] [Info] Start training from score -0.655876\n", | |
| "[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n", | |
| "[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n", | |
| "[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n", | |
| "</div>" | |
| ] | |
| }, | |
| "metadata": { | |
| "application/vnd.databricks.v1+output": { | |
| "addedWidgets": {}, | |
| "arguments": {}, | |
| "data": "<div class=\"ansiout\">/databricks/python/lib/python3.8/site-packages/lightgbm/engine.py:239: UserWarning: 'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. Pass 'log_evaluation()' callback via 'callbacks' argument instead.\n _log_warning("'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. "\n[LightGBM] [Warning] Auto-choosing row-wise multi-threading, the overhead of testing was 0.008743 seconds.\nYou can set `force_row_wise=true` to remove the overhead.\nAnd if memory is not enough, you can set `force_col_wise=true`.\n[LightGBM] [Info] Total Bins 90\n[LightGBM] [Info] Number of data points in the train set: 120, number of used features: 4\n[LightGBM] [Info] Start training from score -1.049822\n[LightGBM] [Info] Start training from score -1.176574\n[LightGBM] [Info] Start training from score -1.073920\n[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n[LightGBM] [Info] Number of positive: 42, number of negative: 78\n[LightGBM] [Info] Number of positive: 37, number of negative: 83\n[LightGBM] [Info] Number of positive: 41, number of negative: 79\n[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000041 seconds.\nYou can set `force_col_wise=true` to remove the overhead.\n[LightGBM] [Info] Total Bins 90\n[LightGBM] [Info] Number of data points in the train set: 120, number of used features: 4\n[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.350000 -> initscore=-0.619039\n[LightGBM] [Info] Start training from score -0.619039\n[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.308333 -> initscore=-0.807923\n[LightGBM] [Info] Start training from score -0.807923\n[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.341667 -> initscore=-0.655876\n[LightGBM] [Info] 
Start training from score -0.655876\n[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n</div>", | |
| "datasetInfos": [], | |
| "metadata": {}, | |
| "removedWidgets": [], | |
| "type": "html" | |
| } | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "# Cf: https://github.com/chaupmcs/multiclass_vs_multiclassova/blob/master/multiclass_vs_multiclassova.ipynb\n", | |
| "\n", | |
| "iris = load_iris()\n", | |
| "X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2)\n", | |
| "\n", | |
| "train_data = lgb.Dataset(X_train, label=y_train)\n", | |
| "test_data = lgb.Dataset(X_test, label=y_test)\n", | |
| "\n", | |
| "params = {\n", | |
| " 'objective': 'multiclass',\n", | |
| " 'num_class': 3,\n", | |
| " 'metric': 'multi_logloss',\n", | |
| " 'deterministic': True\n", | |
| "}\n", | |
| "params2 = {\n", | |
| " 'objective': 'multiclassova',\n", | |
| " 'num_class': 3,\n", | |
| " 'metric': 'multi_logloss',\n", | |
| " 'deterministic': True\n", | |
| "}\n", | |
| "\n", | |
| "\n", | |
| "# 'verbose_eval' is deprecated; pass the log_evaluation() callback instead.\n", | |
| "model = lgb.train(params=params,\n", | |
| " train_set=train_data,\n", | |
| " num_boost_round=1,\n", | |
| " valid_sets=[train_data, test_data],\n", | |
| " callbacks=[lgb.log_evaluation(period=10)])\n", | |
| "model_ova = lgb.train(params=params2,\n", | |
| " train_set=train_data,\n", | |
| " num_boost_round=1,\n", | |
| " valid_sets=[train_data, test_data],\n", | |
| " callbacks=[lgb.log_evaluation(period=10)])\n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "application/vnd.databricks.v1+cell": { | |
| "cellMetadata": { | |
| "byteLimit": 2048000, | |
| "rowLimit": 10000 | |
| }, | |
| "inputWidgets": {}, | |
| "nuid": "a11fca25-22cb-42d6-8674-abdf5dabe48a", | |
| "showTitle": false, | |
| "title": "" | |
| }, | |
| "id": "OEdFVD6kBBTB", | |
| "outputId": "4f84978d-e958-4670-c48d-66ec89083181" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<style scoped>\n", | |
| " .ansiout {\n", | |
| " display: block;\n", | |
| " unicode-bidi: embed;\n", | |
| " white-space: pre-wrap;\n", | |
| " word-wrap: break-word;\n", | |
| " word-break: break-all;\n", | |
| " font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n", | |
| " font-size: 13px;\n", | |
| " color: #555;\n", | |
| " margin-left: 4px;\n", | |
| " line-height: 19px;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<div class=\"ansiout\">[[ 0.31911585 0.36896159 0.31192256 -1.15238622 -1.00724788 -1.1751855 ]\n", | |
| " [ 0.3154204 0.37626921 0.3083104 -1.15238622 -0.9759877 -1.1751855 ]\n", | |
| " [ 0.31737239 0.28352994 0.39909767 -1.15238622 -1.26514433 -0.92325589]\n", | |
| " [ 0.41833426 0.27662626 0.30503948 -0.85934594 -1.27295937 -1.1751855 ]\n", | |
| " [ 0.3154204 0.37626921 0.3083104 -1.15238622 -0.9759877 -1.1751855 ]\n", | |
| " [ 0.41833426 0.27662626 0.30503948 -0.85934594 -1.27295937 -1.1751855 ]\n", | |
| " [ 0.31171677 0.27847739 0.40980584 -1.15238622 -1.26514433 -0.87879772]\n", | |
| " [ 0.41509666 0.27816598 0.30673736 -0.87266596 -1.27295937 -1.1751855 ]\n", | |
| " [ 0.31737239 0.28352994 0.39909767 -1.15238622 -1.26514433 -0.92325589]\n", | |
| " [ 0.3154204 0.37626921 0.3083104 -1.15238622 -0.9759877 -1.1751855 ]\n", | |
| " [ 0.3154204 0.37626921 0.3083104 -1.15238622 -0.9759877 -1.1751855 ]\n", | |
| " [ 0.3154204 0.37626921 0.3083104 -1.15238622 -0.9759877 -1.1751855 ]\n", | |
| " [ 0.31911585 0.36896159 0.31192256 -1.15238622 -1.00724788 -1.1751855 ]\n", | |
| " [ 0.31737239 0.28352994 0.39909767 -1.15238622 -1.26514433 -0.92325589]\n", | |
| " [ 0.31911585 0.36896159 0.31192256 -1.15238622 -1.00724788 -1.1751855 ]\n", | |
| " [ 0.41509666 0.27816598 0.30673736 -0.87266596 -1.27295937 -1.1751855 ]\n", | |
| " [ 0.31171677 0.27847739 0.40980584 -1.15238622 -1.26514433 -0.87879772]\n", | |
| " [ 0.41509666 0.27816598 0.30673736 -0.87266596 -1.27295937 -1.1751855 ]\n", | |
| " [ 0.3154204 0.37626921 0.3083104 -1.15238622 -0.9759877 -1.1751855 ]\n", | |
| " [ 0.3154204 0.37626921 0.3083104 -1.15238622 -0.9759877 -1.1751855 ]\n", | |
| " [ 0.31171677 0.27847739 0.40980584 -1.15238622 -1.26514433 -0.87879772]\n", | |
| " [ 0.31737239 0.28352994 0.39909767 -1.15238622 -1.26514433 -0.92325589]\n", | |
| " [ 0.31171677 0.27847739 0.40980584 -1.15238622 -1.26514433 -0.87879772]\n", | |
| " [ 0.31737239 0.28352994 0.39909767 -1.15238622 -1.26514433 -0.92325589]\n", | |
| " [ 0.41509666 0.27816598 0.30673736 -0.87266596 -1.27295937 -1.1751855 ]\n", | |
| " [ 0.41833426 0.27662626 0.30503948 -0.85934594 -1.27295937 -1.1751855 ]\n", | |
| " [ 0.31171677 0.27847739 0.40980584 -1.15238622 -1.26514433 -0.87879772]\n", | |
| " [ 0.41509666 0.27816598 0.30673736 -0.87266596 -1.27295937 -1.1751855 ]\n", | |
| " [ 0.31911585 0.36896159 0.31192256 -1.15238622 -1.00724788 -1.1751855 ]\n", | |
| " [ 0.31171677 0.27847739 0.40980584 -1.15238622 -1.26514433 -0.87879772]]\n", | |
| "</div>" | |
| ] | |
| }, | |
| "metadata": { | |
| "application/vnd.databricks.v1+output": { | |
| "addedWidgets": {}, | |
| "arguments": {}, | |
| "data": "<div class=\"ansiout\">[[ 0.31911585 0.36896159 0.31192256 -1.15238622 -1.00724788 -1.1751855 ]\n [ 0.3154204 0.37626921 0.3083104 -1.15238622 -0.9759877 -1.1751855 ]\n [ 0.31737239 0.28352994 0.39909767 -1.15238622 -1.26514433 -0.92325589]\n [ 0.41833426 0.27662626 0.30503948 -0.85934594 -1.27295937 -1.1751855 ]\n [ 0.3154204 0.37626921 0.3083104 -1.15238622 -0.9759877 -1.1751855 ]\n [ 0.41833426 0.27662626 0.30503948 -0.85934594 -1.27295937 -1.1751855 ]\n [ 0.31171677 0.27847739 0.40980584 -1.15238622 -1.26514433 -0.87879772]\n [ 0.41509666 0.27816598 0.30673736 -0.87266596 -1.27295937 -1.1751855 ]\n [ 0.31737239 0.28352994 0.39909767 -1.15238622 -1.26514433 -0.92325589]\n [ 0.3154204 0.37626921 0.3083104 -1.15238622 -0.9759877 -1.1751855 ]\n [ 0.3154204 0.37626921 0.3083104 -1.15238622 -0.9759877 -1.1751855 ]\n [ 0.3154204 0.37626921 0.3083104 -1.15238622 -0.9759877 -1.1751855 ]\n [ 0.31911585 0.36896159 0.31192256 -1.15238622 -1.00724788 -1.1751855 ]\n [ 0.31737239 0.28352994 0.39909767 -1.15238622 -1.26514433 -0.92325589]\n [ 0.31911585 0.36896159 0.31192256 -1.15238622 -1.00724788 -1.1751855 ]\n [ 0.41509666 0.27816598 0.30673736 -0.87266596 -1.27295937 -1.1751855 ]\n [ 0.31171677 0.27847739 0.40980584 -1.15238622 -1.26514433 -0.87879772]\n [ 0.41509666 0.27816598 0.30673736 -0.87266596 -1.27295937 -1.1751855 ]\n [ 0.3154204 0.37626921 0.3083104 -1.15238622 -0.9759877 -1.1751855 ]\n [ 0.3154204 0.37626921 0.3083104 -1.15238622 -0.9759877 -1.1751855 ]\n [ 0.31171677 0.27847739 0.40980584 -1.15238622 -1.26514433 -0.87879772]\n [ 0.31737239 0.28352994 0.39909767 -1.15238622 -1.26514433 -0.92325589]\n [ 0.31171677 0.27847739 0.40980584 -1.15238622 -1.26514433 -0.87879772]\n [ 0.31737239 0.28352994 0.39909767 -1.15238622 -1.26514433 -0.92325589]\n [ 0.41509666 0.27816598 0.30673736 -0.87266596 -1.27295937 -1.1751855 ]\n [ 0.41833426 0.27662626 0.30503948 -0.85934594 -1.27295937 -1.1751855 ]\n [ 0.31171677 0.27847739 0.40980584 -1.15238622 -1.26514433 
-0.87879772]\n [ 0.41509666 0.27816598 0.30673736 -0.87266596 -1.27295937 -1.1751855 ]\n [ 0.31911585 0.36896159 0.31192256 -1.15238622 -1.00724788 -1.1751855 ]\n [ 0.31171677 0.27847739 0.40980584 -1.15238622 -1.26514433 -0.87879772]]\n</div>", | |
| "datasetInfos": [], | |
| "metadata": {}, | |
| "removedWidgets": [], | |
| "type": "html" | |
| } | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "y_pred_leaf = model.predict(X_test, pred_leaf=True)\n", | |
| "y_pred_proba = model.predict(X_test, raw_score=False)\n", | |
| "y_pred_score = model.predict(X_test, raw_score=True)\n", | |
| "print(np.hstack([y_pred_proba,y_pred_score]))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "application/vnd.databricks.v1+cell": { | |
| "cellMetadata": { | |
| "byteLimit": 2048000, | |
| "rowLimit": 10000 | |
| }, | |
| "inputWidgets": {}, | |
| "nuid": "44cb071b-84d9-4420-837d-9ca4668d02c4", | |
| "showTitle": false, | |
| "title": "" | |
| }, | |
| "id": "L169NEQkBBTC", | |
| "outputId": "b94f7d92-9853-4ffa-a348-0d16683532c4" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<style scoped>\n", | |
| " .ansiout {\n", | |
| " display: block;\n", | |
| " unicode-bidi: embed;\n", | |
| " white-space: pre-wrap;\n", | |
| " word-wrap: break-word;\n", | |
| " word-break: break-all;\n", | |
| " font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n", | |
| " font-size: 13px;\n", | |
| " color: #555;\n", | |
| " margin-left: 4px;\n", | |
| " line-height: 19px;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<div class=\"ansiout\">Out[4]: 3.3766115072321297e-16</div>" | |
| ] | |
| }, | |
| "metadata": { | |
| "application/vnd.databricks.v1+output": { | |
| "addedWidgets": {}, | |
| "arguments": {}, | |
| "data": "<div class=\"ansiout\">Out[4]: 3.3766115072321297e-16</div>", | |
| "datasetInfos": [], | |
| "metadata": {}, | |
| "removedWidgets": [], | |
| "type": "html" | |
| } | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
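| "# With 'objective': 'multiclass', the probabilities are the softmax of the raw scores.\n", | |
| "_t1 = np.exp(y_pred_score)\n", | |
| "_t2 = np.sum(_t1, axis=1, keepdims=True) # row sums, shape (n_samples, 1)\n", | |
| "_y = _t1 / _t2 # broadcasting avoids hard-coding the (30, 3) shape\n", | |
| "np.linalg.norm(_y - y_pred_proba) # ~0 up to floating-point error" | |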
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "application/vnd.databricks.v1+cell": { | |
| "cellMetadata": { | |
| "byteLimit": 2048000, | |
| "rowLimit": 10000 | |
| }, | |
| "inputWidgets": {}, | |
| "nuid": "878915dc-2fce-411b-aa23-7e879d0295a7", | |
| "showTitle": false, | |
| "title": "" | |
| }, | |
| "id": "N4HxDSLBBBTC", | |
| "outputId": "d7129ca5-2725-4cbf-b138-91ce9dc0a25b" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<style scoped>\n", | |
| " .ansiout {\n", | |
| " display: block;\n", | |
| " unicode-bidi: embed;\n", | |
| " white-space: pre-wrap;\n", | |
| " word-wrap: break-word;\n", | |
| " word-break: break-all;\n", | |
| " font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n", | |
| " font-size: 13px;\n", | |
| " color: #555;\n", | |
| " margin-left: 4px;\n", | |
| " line-height: 19px;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<div class=\"ansiout\">[[ 0.31585528 0.36495223 0.30836493 -0.77288536 -0.55393376 -0.80777452]\n", | |
| " [ 0.31585528 0.37588685 0.30836493 -0.77288536 -0.5070435 -0.80777452]\n", | |
| " [ 0.31585528 0.28074313 0.39415496 -0.77288536 -0.94077845 -0.4298801 ]\n", | |
| " [ 0.41743183 0.27838213 0.30836493 -0.33332494 -0.95250101 -0.80777452]\n", | |
| " [ 0.31585528 0.37588685 0.30836493 -0.77288536 -0.5070435 -0.80777452]\n", | |
| " [ 0.41743183 0.27838213 0.30836493 -0.33332494 -0.95250101 -0.80777452]\n", | |
| " [ 0.31585528 0.28074313 0.41018689 -0.77288536 -0.94077845 -0.36319285]\n", | |
| " [ 0.41258121 0.27838213 0.30836493 -0.35330496 -0.95250101 -0.80777452]\n", | |
| " [ 0.31585528 0.28074313 0.39415496 -0.77288536 -0.94077845 -0.4298801 ]\n", | |
| " [ 0.31585528 0.37588685 0.30836493 -0.77288536 -0.5070435 -0.80777452]\n", | |
| " [ 0.31585528 0.37588685 0.30836493 -0.77288536 -0.5070435 -0.80777452]\n", | |
| " [ 0.31585528 0.37588685 0.30836493 -0.77288536 -0.5070435 -0.80777452]\n", | |
| " [ 0.31585528 0.36495223 0.30836493 -0.77288536 -0.55393376 -0.80777452]\n", | |
| " [ 0.31585528 0.28074313 0.39415496 -0.77288536 -0.94077845 -0.4298801 ]\n", | |
| " [ 0.31585528 0.36495223 0.30836493 -0.77288536 -0.55393376 -0.80777452]\n", | |
| " [ 0.41258121 0.27838213 0.30836493 -0.35330496 -0.95250101 -0.80777452]\n", | |
| " [ 0.31585528 0.28074313 0.41018689 -0.77288536 -0.94077845 -0.36319285]\n", | |
| " [ 0.41258121 0.27838213 0.30836493 -0.35330496 -0.95250101 -0.80777452]\n", | |
| " [ 0.31585528 0.37588685 0.30836493 -0.77288536 -0.5070435 -0.80777452]\n", | |
| " [ 0.31585528 0.37588685 0.30836493 -0.77288536 -0.5070435 -0.80777452]\n", | |
| " [ 0.31585528 0.28074313 0.41018689 -0.77288536 -0.94077845 -0.36319285]\n", | |
| " [ 0.31585528 0.28074313 0.39415496 -0.77288536 -0.94077845 -0.4298801 ]\n", | |
| " [ 0.31585528 0.28074313 0.41018689 -0.77288536 -0.94077845 -0.36319285]\n", | |
| " [ 0.31585528 0.28074313 0.39415496 -0.77288536 -0.94077845 -0.4298801 ]\n", | |
| " [ 0.41258121 0.27838213 0.30836493 -0.35330496 -0.95250101 -0.80777452]\n", | |
| " [ 0.41743183 0.27838213 0.30836493 -0.33332494 -0.95250101 -0.80777452]\n", | |
| " [ 0.31585528 0.28074313 0.41018689 -0.77288536 -0.94077845 -0.36319285]\n", | |
| " [ 0.41258121 0.27838213 0.30836493 -0.35330496 -0.95250101 -0.80777452]\n", | |
| " [ 0.31585528 0.36495223 0.30836493 -0.77288536 -0.55393376 -0.80777452]\n", | |
| " [ 0.31585528 0.28074313 0.41018689 -0.77288536 -0.94077845 -0.36319285]]\n", | |
| "</div>" | |
| ] | |
| }, | |
| "metadata": { | |
| "application/vnd.databricks.v1+output": { | |
| "addedWidgets": {}, | |
| "arguments": {}, | |
| "data": "<div class=\"ansiout\">[[ 0.31585528 0.36495223 0.30836493 -0.77288536 -0.55393376 -0.80777452]\n [ 0.31585528 0.37588685 0.30836493 -0.77288536 -0.5070435 -0.80777452]\n [ 0.31585528 0.28074313 0.39415496 -0.77288536 -0.94077845 -0.4298801 ]\n [ 0.41743183 0.27838213 0.30836493 -0.33332494 -0.95250101 -0.80777452]\n [ 0.31585528 0.37588685 0.30836493 -0.77288536 -0.5070435 -0.80777452]\n [ 0.41743183 0.27838213 0.30836493 -0.33332494 -0.95250101 -0.80777452]\n [ 0.31585528 0.28074313 0.41018689 -0.77288536 -0.94077845 -0.36319285]\n [ 0.41258121 0.27838213 0.30836493 -0.35330496 -0.95250101 -0.80777452]\n [ 0.31585528 0.28074313 0.39415496 -0.77288536 -0.94077845 -0.4298801 ]\n [ 0.31585528 0.37588685 0.30836493 -0.77288536 -0.5070435 -0.80777452]\n [ 0.31585528 0.37588685 0.30836493 -0.77288536 -0.5070435 -0.80777452]\n [ 0.31585528 0.37588685 0.30836493 -0.77288536 -0.5070435 -0.80777452]\n [ 0.31585528 0.36495223 0.30836493 -0.77288536 -0.55393376 -0.80777452]\n [ 0.31585528 0.28074313 0.39415496 -0.77288536 -0.94077845 -0.4298801 ]\n [ 0.31585528 0.36495223 0.30836493 -0.77288536 -0.55393376 -0.80777452]\n [ 0.41258121 0.27838213 0.30836493 -0.35330496 -0.95250101 -0.80777452]\n [ 0.31585528 0.28074313 0.41018689 -0.77288536 -0.94077845 -0.36319285]\n [ 0.41258121 0.27838213 0.30836493 -0.35330496 -0.95250101 -0.80777452]\n [ 0.31585528 0.37588685 0.30836493 -0.77288536 -0.5070435 -0.80777452]\n [ 0.31585528 0.37588685 0.30836493 -0.77288536 -0.5070435 -0.80777452]\n [ 0.31585528 0.28074313 0.41018689 -0.77288536 -0.94077845 -0.36319285]\n [ 0.31585528 0.28074313 0.39415496 -0.77288536 -0.94077845 -0.4298801 ]\n [ 0.31585528 0.28074313 0.41018689 -0.77288536 -0.94077845 -0.36319285]\n [ 0.31585528 0.28074313 0.39415496 -0.77288536 -0.94077845 -0.4298801 ]\n [ 0.41258121 0.27838213 0.30836493 -0.35330496 -0.95250101 -0.80777452]\n [ 0.41743183 0.27838213 0.30836493 -0.33332494 -0.95250101 -0.80777452]\n [ 0.31585528 0.28074313 0.41018689 -0.77288536 
-0.94077845 -0.36319285]\n [ 0.41258121 0.27838213 0.30836493 -0.35330496 -0.95250101 -0.80777452]\n [ 0.31585528 0.36495223 0.30836493 -0.77288536 -0.55393376 -0.80777452]\n [ 0.31585528 0.28074313 0.41018689 -0.77288536 -0.94077845 -0.36319285]]\n</div>", | |
| "datasetInfos": [], | |
| "metadata": {}, | |
| "removedWidgets": [], | |
| "type": "html" | |
| } | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "y_pred_leaf_ova = model_ova.predict(X_test, pred_leaf=True)\n", | |
| "y_pred_proba_ova = model_ova.predict(X_test, raw_score=False)\n", | |
| "y_pred_score_ova = model_ova.predict(X_test, raw_score=True)\n", | |
| "print(np.hstack([y_pred_proba_ova,y_pred_score_ova]))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "application/vnd.databricks.v1+cell": { | |
| "cellMetadata": { | |
| "byteLimit": 2048000, | |
| "rowLimit": 10000 | |
| }, | |
| "inputWidgets": {}, | |
| "nuid": "92d62b97-df11-4168-af67-9fad8f2b74ea", | |
| "showTitle": false, | |
| "title": "" | |
| }, | |
| "id": "0-mIDG1YBBTC", | |
| "outputId": "b2e32bfd-3bbe-49b9-dc3a-aac9c70a2d35" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<style scoped>\n", | |
| " .ansiout {\n", | |
| " display: block;\n", | |
| " unicode-bidi: embed;\n", | |
| " white-space: pre-wrap;\n", | |
| " word-wrap: break-word;\n", | |
| " word-break: break-all;\n", | |
| " font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n", | |
| " font-size: 13px;\n", | |
| " color: #555;\n", | |
| " margin-left: 4px;\n", | |
| " line-height: 19px;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<div class=\"ansiout\"></div>" | |
| ] | |
| }, | |
| "metadata": { | |
| "application/vnd.databricks.v1+output": { | |
| "addedWidgets": {}, | |
| "arguments": {}, | |
| "data": "<div class=\"ansiout\"></div>", | |
| "datasetInfos": [], | |
| "metadata": {}, | |
| "removedWidgets": [], | |
| "type": "html" | |
| } | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "def sigmoid(x):\n", | |
| " return 1.0 / (1.0 + np.exp(-x))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "application/vnd.databricks.v1+cell": { | |
| "cellMetadata": { | |
| "byteLimit": 2048000, | |
| "rowLimit": 10000 | |
| }, | |
| "inputWidgets": {}, | |
| "nuid": "f9218b6a-92c6-48c7-ade5-2906568b41af", | |
| "showTitle": false, | |
| "title": "" | |
| }, | |
| "id": "-NQRA6Q9BBTC", | |
| "outputId": "ff596466-2230-4d17-d626-3a9d30d255f0" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<style scoped>\n", | |
| " .ansiout {\n", | |
| " display: block;\n", | |
| " unicode-bidi: embed;\n", | |
| " white-space: pre-wrap;\n", | |
| " word-wrap: break-word;\n", | |
| " word-break: break-all;\n", | |
| " font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n", | |
| " font-size: 13px;\n", | |
| " color: #555;\n", | |
| " margin-left: 4px;\n", | |
| " line-height: 19px;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<div class=\"ansiout\">Out[7]: 0.0</div>" | |
| ] | |
| }, | |
| "metadata": { | |
| "application/vnd.databricks.v1+output": { | |
| "addedWidgets": {}, | |
| "arguments": {}, | |
| "data": "<div class=\"ansiout\">Out[7]: 0.0</div>", | |
| "datasetInfos": [], | |
| "metadata": {}, | |
| "removedWidgets": [], | |
| "type": "html" | |
| } | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "# the proba is obtained from the sigmoid for 'objective': 'multiclassova'\n", | |
| "np.linalg.norm(y_pred_proba_ova - sigmoid(y_pred_score_ova))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "application/vnd.databricks.v1+cell": { | |
| "cellMetadata": { | |
| "byteLimit": 2048000, | |
| "rowLimit": 10000 | |
| }, | |
| "inputWidgets": {}, | |
| "nuid": "cdec4744-bd3b-4386-afc3-8d31b3275e6e", | |
| "showTitle": false, | |
| "title": "" | |
| }, | |
| "id": "PJjMY94YBBTD", | |
| "outputId": "7447047c-5d61-4f5a-ea41-f5cc40dfa44b" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<style scoped>\n", | |
| " .ansiout {\n", | |
| " display: block;\n", | |
| " unicode-bidi: embed;\n", | |
| " white-space: pre-wrap;\n", | |
| " word-wrap: break-word;\n", | |
| " word-break: break-all;\n", | |
| " font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n", | |
| " font-size: 13px;\n", | |
| " color: #555;\n", | |
| " margin-left: 4px;\n", | |
| " line-height: 19px;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<div class=\"ansiout\">Out[8]: (array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", | |
| " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),\n", | |
| " array([0.98917244, 1.00010707, 0.99075337, 1.00417889, 1.00010707,\n", | |
| " 1.00417889, 1.00678529, 0.99932827, 0.99075337, 1.00010707,\n", | |
| " 1.00010707, 1.00010707, 0.98917244, 0.99075337, 0.98917244,\n", | |
| " 0.99932827, 1.00678529, 0.99932827, 1.00010707, 1.00010707,\n", | |
| " 1.00678529, 0.99075337, 1.00678529, 0.99075337, 0.99932827,\n", | |
| " 1.00417889, 1.00678529, 0.99932827, 0.98917244, 1.00678529]))</div>" | |
| ] | |
| }, | |
| "metadata": { | |
| "application/vnd.databricks.v1+output": { | |
| "addedWidgets": {}, | |
| "arguments": {}, | |
| "data": "<div class=\"ansiout\">Out[8]: (array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),\n array([0.98917244, 1.00010707, 0.99075337, 1.00417889, 1.00010707,\n 1.00417889, 1.00678529, 0.99932827, 0.99075337, 1.00010707,\n 1.00010707, 1.00010707, 0.98917244, 0.99075337, 0.98917244,\n 0.99932827, 1.00678529, 0.99932827, 1.00010707, 1.00010707,\n 1.00678529, 0.99075337, 1.00678529, 0.99075337, 0.99932827,\n 1.00417889, 1.00678529, 0.99932827, 0.98917244, 1.00678529]))</div>", | |
| "datasetInfos": [], | |
| "metadata": {}, | |
| "removedWidgets": [], | |
| "type": "html" | |
| } | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "# the proba for all classes doesn't sum up to one for 'objective': 'multiclassova'\n", | |
| "np.sum(y_pred_proba,axis=1), np.sum(y_pred_proba_ova,axis=1)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "application/vnd.databricks.v1+cell": { | |
| "cellMetadata": { | |
| "byteLimit": 2048000, | |
| "rowLimit": 10000 | |
| }, | |
| "inputWidgets": {}, | |
| "nuid": "d6cf2282-dfc3-4910-adde-51813596d2e0", | |
| "showTitle": false, | |
| "title": "" | |
| }, | |
| "id": "XMCVYPRwBBTD", | |
| "outputId": "92b48857-ab4d-4760-e4e7-e53a11eb750f" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<style scoped>\n", | |
| " .ansiout {\n", | |
| " display: block;\n", | |
| " unicode-bidi: embed;\n", | |
| " white-space: pre-wrap;\n", | |
| " word-wrap: break-word;\n", | |
| " word-break: break-all;\n", | |
| " font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n", | |
| " font-size: 13px;\n", | |
| " color: #555;\n", | |
| " margin-left: 4px;\n", | |
| " line-height: 19px;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<div class=\"ansiout\"></div>" | |
| ] | |
| }, | |
| "metadata": { | |
| "application/vnd.databricks.v1+output": { | |
| "addedWidgets": {}, | |
| "arguments": {}, | |
| "data": "<div class=\"ansiout\"></div>", | |
| "datasetInfos": [], | |
| "metadata": {}, | |
| "removedWidgets": [], | |
| "type": "html" | |
| } | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "import pandas as pd\n", | |
| "pd.set_option('display.expand_frame_repr', False) # keep each DataFrame row on one line instead of wrapping it\n", | |
| "pd.set_option('display.precision', 5)\n", | |
| "pd.set_option('display.max_rows', 100)\n", | |
| "pd.set_option('display.min_rows',30)\n", | |
| "pd.set_option('display.max_columns', 100)\n", | |
| "\n", | |
| "np.set_printoptions(threshold=np.inf) # print full np array without truncation Cf: " | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "application/vnd.databricks.v1+cell": { | |
| "cellMetadata": { | |
| "byteLimit": 2048000, | |
| "rowLimit": 10000 | |
| }, | |
| "inputWidgets": {}, | |
| "nuid": "e77784c0-02e2-4580-9886-0e8bfc6c25a4", | |
| "showTitle": false, | |
| "title": "" | |
| }, | |
| "id": "j7lpmuWkBBTD", | |
| "outputId": "67b9b34b-8a20-4c42-f882-7de2fb6bf935" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<style scoped>\n", | |
| " .ansiout {\n", | |
| " display: block;\n", | |
| " unicode-bidi: embed;\n", | |
| " white-space: pre-wrap;\n", | |
| " word-wrap: break-word;\n", | |
| " word-break: break-all;\n", | |
| " font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n", | |
| " font-size: 13px;\n", | |
| " color: #555;\n", | |
| " margin-left: 4px;\n", | |
| " line-height: 19px;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<div class=\"ansiout\">Out[10]: array([[-1.15238622, -1.00724788, -1.1751855 , 3. , 3. , 3. ],\n", | |
| " [-1.15238622, -0.9759877 , -1.1751855 , 1. , 1. , 3. ],\n", | |
| " [-1.15238622, -1.26514433, -0.92325589, 3. , 2. , 1. ],\n", | |
| " [-0.85934594, -1.27295937, -1.1751855 , 0. , 0. , 3. ],\n", | |
| " [-1.15238622, -0.9759877 , -1.1751855 , 1. , 1. , 0. ],\n", | |
| " [-0.85934594, -1.27295937, -1.1751855 , 0. , 0. , 0. ],\n", | |
| " [-1.15238622, -1.26514433, -0.87879772, 3. , 2. , 2. ],\n", | |
| " [-0.87266596, -1.27295937, -1.1751855 , 2. , 0. , 0. ],\n", | |
| " [-1.15238622, -1.26514433, -0.92325589, 1. , 2. , 1. ],\n", | |
| " [-1.15238622, -0.9759877 , -1.1751855 , 1. , 1. , 0. ],\n", | |
| " [-1.15238622, -0.9759877 , -1.1751855 , 1. , 1. , 0. ],\n", | |
| " [-1.15238622, -0.9759877 , -1.1751855 , 1. , 1. , 0. ],\n", | |
| " [-1.15238622, -1.00724788, -1.1751855 , 3. , 3. , 3. ],\n", | |
| " [-1.15238622, -1.26514433, -0.92325589, 1. , 2. , 1. ],\n", | |
| " [-1.15238622, -1.00724788, -1.1751855 , 3. , 3. , 3. ],\n", | |
| " [-0.87266596, -1.27295937, -1.1751855 , 2. , 0. , 0. ],\n", | |
| " [-1.15238622, -1.26514433, -0.87879772, 3. , 2. , 2. ],\n", | |
| " [-0.87266596, -1.27295937, -1.1751855 , 2. , 0. , 0. ],\n", | |
| " [-1.15238622, -0.9759877 , -1.1751855 , 1. , 1. , 3. ],\n", | |
| " [-1.15238622, -0.9759877 , -1.1751855 , 1. , 1. , 3. ],\n", | |
| " [-1.15238622, -1.26514433, -0.87879772, 3. , 2. , 2. ],\n", | |
| " [-1.15238622, -1.26514433, -0.92325589, 1. , 2. , 1. ],\n", | |
| " [-1.15238622, -1.26514433, -0.87879772, 3. , 2. , 2. ],\n", | |
| " [-1.15238622, -1.26514433, -0.92325589, 3. , 2. , 1. ],\n", | |
| " [-0.87266596, -1.27295937, -1.1751855 , 2. , 0. , 0. ],\n", | |
| " [-0.85934594, -1.27295937, -1.1751855 , 0. , 0. , 0. ],\n", | |
| " [-1.15238622, -1.26514433, -0.87879772, 3. , 2. , 2. ],\n", | |
| " [-0.87266596, -1.27295937, -1.1751855 , 2. , 0. , 3. ],\n", | |
| " [-1.15238622, -1.00724788, -1.1751855 , 1. , 3. , 3. ],\n", | |
| " [-1.15238622, -1.26514433, -0.87879772, 3. , 2. , 2. ]])</div>" | |
| ] | |
| }, | |
| "metadata": { | |
| "application/vnd.databricks.v1+output": { | |
| "addedWidgets": {}, | |
| "arguments": {}, | |
| "data": "<div class=\"ansiout\">Out[10]: array([[-1.15238622, -1.00724788, -1.1751855 , 3. , 3. , 3. ],\n [-1.15238622, -0.9759877 , -1.1751855 , 1. , 1. , 3. ],\n [-1.15238622, -1.26514433, -0.92325589, 3. , 2. , 1. ],\n [-0.85934594, -1.27295937, -1.1751855 , 0. , 0. , 3. ],\n [-1.15238622, -0.9759877 , -1.1751855 , 1. , 1. , 0. ],\n [-0.85934594, -1.27295937, -1.1751855 , 0. , 0. , 0. ],\n [-1.15238622, -1.26514433, -0.87879772, 3. , 2. , 2. ],\n [-0.87266596, -1.27295937, -1.1751855 , 2. , 0. , 0. ],\n [-1.15238622, -1.26514433, -0.92325589, 1. , 2. , 1. ],\n [-1.15238622, -0.9759877 , -1.1751855 , 1. , 1. , 0. ],\n [-1.15238622, -0.9759877 , -1.1751855 , 1. , 1. , 0. ],\n [-1.15238622, -0.9759877 , -1.1751855 , 1. , 1. , 0. ],\n [-1.15238622, -1.00724788, -1.1751855 , 3. , 3. , 3. ],\n [-1.15238622, -1.26514433, -0.92325589, 1. , 2. , 1. ],\n [-1.15238622, -1.00724788, -1.1751855 , 3. , 3. , 3. ],\n [-0.87266596, -1.27295937, -1.1751855 , 2. , 0. , 0. ],\n [-1.15238622, -1.26514433, -0.87879772, 3. , 2. , 2. ],\n [-0.87266596, -1.27295937, -1.1751855 , 2. , 0. , 0. ],\n [-1.15238622, -0.9759877 , -1.1751855 , 1. , 1. , 3. ],\n [-1.15238622, -0.9759877 , -1.1751855 , 1. , 1. , 3. ],\n [-1.15238622, -1.26514433, -0.87879772, 3. , 2. , 2. ],\n [-1.15238622, -1.26514433, -0.92325589, 1. , 2. , 1. ],\n [-1.15238622, -1.26514433, -0.87879772, 3. , 2. , 2. ],\n [-1.15238622, -1.26514433, -0.92325589, 3. , 2. , 1. ],\n [-0.87266596, -1.27295937, -1.1751855 , 2. , 0. , 0. ],\n [-0.85934594, -1.27295937, -1.1751855 , 0. , 0. , 0. ],\n [-1.15238622, -1.26514433, -0.87879772, 3. , 2. , 2. ],\n [-0.87266596, -1.27295937, -1.1751855 , 2. , 0. , 3. ],\n [-1.15238622, -1.00724788, -1.1751855 , 1. , 3. , 3. ],\n [-1.15238622, -1.26514433, -0.87879772, 3. , 2. , 2. ]])</div>", | |
| "datasetInfos": [], | |
| "metadata": {}, | |
| "removedWidgets": [], | |
| "type": "html" | |
| } | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
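| "np.set_printoptions(linewidth=128)\n", | |
| "# columns 0-2: raw score per class; columns 3-5: index of the leaf reached in each class's tree\n", | |
| "np.hstack([y_pred_score, y_pred_leaf])" | |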
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "application/vnd.databricks.v1+cell": { | |
| "cellMetadata": { | |
| "byteLimit": 2048000, | |
| "rowLimit": 10000 | |
| }, | |
| "inputWidgets": {}, | |
| "nuid": "a45b6ba3-a1f3-47ae-bae0-40c5acac4b51", | |
| "showTitle": false, | |
| "title": "" | |
| }, | |
| "id": "xeMxnw4DBBTE", | |
| "outputId": "0d8132c1-ec9e-4c11-abc7-f0d5e5ec4f94" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<style scoped>\n", | |
| " .ansiout {\n", | |
| " display: block;\n", | |
| " unicode-bidi: embed;\n", | |
| " white-space: pre-wrap;\n", | |
| " word-wrap: break-word;\n", | |
| " word-break: break-all;\n", | |
| " font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n", | |
| " font-size: 13px;\n", | |
| " color: #555;\n", | |
| " margin-left: 4px;\n", | |
| " line-height: 19px;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<div class=\"ansiout\">(21, 14)\n", | |
| "Out[11]: </div>" | |
| ] | |
| }, | |
| "metadata": { | |
| "application/vnd.databricks.v1+output": { | |
| "addedWidgets": {}, | |
| "arguments": {}, | |
| "data": "<div class=\"ansiout\">(21, 14)\nOut[11]: </div>", | |
| "datasetInfos": [], | |
| "metadata": {}, | |
| "removedWidgets": [], | |
| "type": "html" | |
| } | |
| }, | |
| "output_type": "display_data" | |
| }, | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>tree_index</th>\n", | |
| " <th>node_depth</th>\n", | |
| " <th>left_child</th>\n", | |
| " <th>right_child</th>\n", | |
| " <th>parent_index</th>\n", | |
| " <th>split_feature</th>\n", | |
| " <th>split_gain</th>\n", | |
| " <th>threshold</th>\n", | |
| " <th>decision_type</th>\n", | |
| " <th>missing_direction</th>\n", | |
| " <th>missing_type</th>\n", | |
| " <th>value</th>\n", | |
| " <th>weight</th>\n", | |
| " <th>count</th>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>node_index</th>\n", | |
| " <th></th>\n", | |
| " <th></th>\n", | |
| " <th></th>\n", | |
| " <th></th>\n", | |
| " <th></th>\n", | |
| " <th></th>\n", | |
| " <th></th>\n", | |
| " <th></th>\n", | |
| " <th></th>\n", | |
| " <th></th>\n", | |
| " <th></th>\n", | |
| " <th></th>\n", | |
| " <th></th>\n", | |
| " <th></th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0-L0</th>\n", | |
| " <td>0</td>\n", | |
| " <td>3</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>0-S1</td>\n", | |
| " <td>None</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>-0.85935</td>\n", | |
| " <td>7.16625</td>\n", | |
| " <td>21</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>0-L2</th>\n", | |
| " <td>0</td>\n", | |
| " <td>3</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>0-S1</td>\n", | |
| " <td>None</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>-0.87267</td>\n", | |
| " <td>7.50750</td>\n", | |
| " <td>22</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>0-L1</th>\n", | |
| " <td>0</td>\n", | |
| " <td>3</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>0-S2</td>\n", | |
| " <td>None</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>-1.15239</td>\n", | |
| " <td>15.69750</td>\n", | |
| " <td>46</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>0-L3</th>\n", | |
| " <td>0</td>\n", | |
| " <td>3</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>0-S2</td>\n", | |
| " <td>None</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>-1.15239</td>\n", | |
| " <td>10.57875</td>\n", | |
| " <td>31</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1-L0</th>\n", | |
| " <td>1</td>\n", | |
| " <td>2</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>1-S0</td>\n", | |
| " <td>None</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>-1.27296</td>\n", | |
| " <td>12.79583</td>\n", | |
| " <td>40</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1-L1</th>\n", | |
| " <td>1</td>\n", | |
| " <td>4</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>1-S2</td>\n", | |
| " <td>None</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>-0.97599</td>\n", | |
| " <td>6.39792</td>\n", | |
| " <td>20</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1-L3</th>\n", | |
| " <td>1</td>\n", | |
| " <td>4</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>1-S2</td>\n", | |
| " <td>None</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>-1.00725</td>\n", | |
| " <td>6.39792</td>\n", | |
| " <td>20</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1-L2</th>\n", | |
| " <td>1</td>\n", | |
| " <td>3</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>1-S1</td>\n", | |
| " <td>None</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>-1.26514</td>\n", | |
| " <td>12.79583</td>\n", | |
| " <td>40</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2-L0</th>\n", | |
| " <td>2</td>\n", | |
| " <td>3</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>2-S2</td>\n", | |
| " <td>None</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>-1.17519</td>\n", | |
| " <td>16.53240</td>\n", | |
| " <td>49</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2-L3</th>\n", | |
| " <td>2</td>\n", | |
| " <td>3</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>2-S2</td>\n", | |
| " <td>None</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>-1.17519</td>\n", | |
| " <td>9.10969</td>\n", | |
| " <td>27</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2-L1</th>\n", | |
| " <td>2</td>\n", | |
| " <td>3</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>2-S1</td>\n", | |
| " <td>None</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>-0.92326</td>\n", | |
| " <td>6.74792</td>\n", | |
| " <td>20</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2-L2</th>\n", | |
| " <td>2</td>\n", | |
| " <td>3</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>2-S1</td>\n", | |
| " <td>None</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>NaN</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>None</td>\n", | |
| " <td>-0.87880</td>\n", | |
| " <td>8.09750</td>\n", | |
| " <td>24</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ] | |
| }, | |
| "metadata": { | |
| "application/vnd.databricks.v1+output": { | |
| "addedWidgets": {}, | |
| "arguments": {}, | |
| "data": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>tree_index</th>\n <th>node_depth</th>\n <th>left_child</th>\n <th>right_child</th>\n <th>parent_index</th>\n <th>split_feature</th>\n <th>split_gain</th>\n <th>threshold</th>\n <th>decision_type</th>\n <th>missing_direction</th>\n <th>missing_type</th>\n <th>value</th>\n <th>weight</th>\n <th>count</th>\n </tr>\n <tr>\n <th>node_index</th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n <th></th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0-L0</th>\n <td>0</td>\n <td>3</td>\n <td>None</td>\n <td>None</td>\n <td>0-S1</td>\n <td>None</td>\n <td>NaN</td>\n <td>NaN</td>\n <td>None</td>\n <td>None</td>\n <td>None</td>\n <td>-0.85935</td>\n <td>7.16625</td>\n <td>21</td>\n </tr>\n <tr>\n <th>0-L2</th>\n <td>0</td>\n <td>3</td>\n <td>None</td>\n <td>None</td>\n <td>0-S1</td>\n <td>None</td>\n <td>NaN</td>\n <td>NaN</td>\n <td>None</td>\n <td>None</td>\n <td>None</td>\n <td>-0.87267</td>\n <td>7.50750</td>\n <td>22</td>\n </tr>\n <tr>\n <th>0-L1</th>\n <td>0</td>\n <td>3</td>\n <td>None</td>\n <td>None</td>\n <td>0-S2</td>\n <td>None</td>\n <td>NaN</td>\n <td>NaN</td>\n <td>None</td>\n <td>None</td>\n <td>None</td>\n <td>-1.15239</td>\n <td>15.69750</td>\n <td>46</td>\n </tr>\n <tr>\n <th>0-L3</th>\n <td>0</td>\n <td>3</td>\n <td>None</td>\n <td>None</td>\n <td>0-S2</td>\n <td>None</td>\n <td>NaN</td>\n <td>NaN</td>\n <td>None</td>\n <td>None</td>\n <td>None</td>\n <td>-1.15239</td>\n <td>10.57875</td>\n <td>31</td>\n </tr>\n <tr>\n <th>1-L0</th>\n <td>1</td>\n <td>2</td>\n <td>None</td>\n <td>None</td>\n <td>1-S0</td>\n <td>None</td>\n 
<td>NaN</td>\n <td>NaN</td>\n <td>None</td>\n <td>None</td>\n <td>None</td>\n <td>-1.27296</td>\n <td>12.79583</td>\n <td>40</td>\n </tr>\n <tr>\n <th>1-L1</th>\n <td>1</td>\n <td>4</td>\n <td>None</td>\n <td>None</td>\n <td>1-S2</td>\n <td>None</td>\n <td>NaN</td>\n <td>NaN</td>\n <td>None</td>\n <td>None</td>\n <td>None</td>\n <td>-0.97599</td>\n <td>6.39792</td>\n <td>20</td>\n </tr>\n <tr>\n <th>1-L3</th>\n <td>1</td>\n <td>4</td>\n <td>None</td>\n <td>None</td>\n <td>1-S2</td>\n <td>None</td>\n <td>NaN</td>\n <td>NaN</td>\n <td>None</td>\n <td>None</td>\n <td>None</td>\n <td>-1.00725</td>\n <td>6.39792</td>\n <td>20</td>\n </tr>\n <tr>\n <th>1-L2</th>\n <td>1</td>\n <td>3</td>\n <td>None</td>\n <td>None</td>\n <td>1-S1</td>\n <td>None</td>\n <td>NaN</td>\n <td>NaN</td>\n <td>None</td>\n <td>None</td>\n <td>None</td>\n <td>-1.26514</td>\n <td>12.79583</td>\n <td>40</td>\n </tr>\n <tr>\n <th>2-L0</th>\n <td>2</td>\n <td>3</td>\n <td>None</td>\n <td>None</td>\n <td>2-S2</td>\n <td>None</td>\n <td>NaN</td>\n <td>NaN</td>\n <td>None</td>\n <td>None</td>\n <td>None</td>\n <td>-1.17519</td>\n <td>16.53240</td>\n <td>49</td>\n </tr>\n <tr>\n <th>2-L3</th>\n <td>2</td>\n <td>3</td>\n <td>None</td>\n <td>None</td>\n <td>2-S2</td>\n <td>None</td>\n <td>NaN</td>\n <td>NaN</td>\n <td>None</td>\n <td>None</td>\n <td>None</td>\n <td>-1.17519</td>\n <td>9.10969</td>\n <td>27</td>\n </tr>\n <tr>\n <th>2-L1</th>\n <td>2</td>\n <td>3</td>\n <td>None</td>\n <td>None</td>\n <td>2-S1</td>\n <td>None</td>\n <td>NaN</td>\n <td>NaN</td>\n <td>None</td>\n <td>None</td>\n <td>None</td>\n <td>-0.92326</td>\n <td>6.74792</td>\n <td>20</td>\n </tr>\n <tr>\n <th>2-L2</th>\n <td>2</td>\n <td>3</td>\n <td>None</td>\n <td>None</td>\n <td>2-S1</td>\n <td>None</td>\n <td>NaN</td>\n <td>NaN</td>\n <td>None</td>\n <td>None</td>\n <td>None</td>\n <td>-0.87880</td>\n <td>8.09750</td>\n <td>24</td>\n </tr>\n </tbody>\n</table>\n</div>", | |
| "datasetInfos": [], | |
| "metadata": {}, | |
| "removedWidgets": [], | |
| "textData": null, | |
| "type": "htmlSandbox" | |
| } | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "# with a single boosting round, each raw score is the 'value' of the leaf reached in the tree for the corresponding class\n", | |
| "bst = model\n", | |
| "df_tree = bst.trees_to_dataframe().set_index('node_index')\n", | |
| "sel = df_tree['left_child'].isnull() # leaf nodes have no children\n", | |
| "print(df_tree.shape)\n", | |
| "df_tree[sel]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "HQszhTXIBBTE" | |
| }, | |
| "source": [ | |
| "\n", | |
| "# multiclassova gives the same score as the binary classification problem for the corresponding class" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "application/vnd.databricks.v1+cell": { | |
| "cellMetadata": { | |
| "byteLimit": 2048000, | |
| "rowLimit": 10000 | |
| }, | |
| "inputWidgets": {}, | |
| "nuid": "4dfd32c2-c27d-4201-85f2-6bfa52aac240", | |
| "showTitle": false, | |
| "title": "" | |
| }, | |
| "id": "FEpoZg6XBBTE", | |
| "outputId": "d9b64e79-89e2-47c3-8812-fefe63460508" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<style scoped>\n", | |
| " .ansiout {\n", | |
| " display: block;\n", | |
| " unicode-bidi: embed;\n", | |
| " white-space: pre-wrap;\n", | |
| " word-wrap: break-word;\n", | |
| " word-break: break-all;\n", | |
| " font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n", | |
| " font-size: 13px;\n", | |
| " color: #555;\n", | |
| " margin-left: 4px;\n", | |
| " line-height: 19px;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<div class=\"ansiout\">[LightGBM] [Info] Number of positive: 37, number of negative: 83\n", | |
| "[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000051 seconds.\n", | |
| "You can set `force_col_wise=true` to remove the overhead.\n", | |
| "[LightGBM] [Info] Total Bins 90\n", | |
| "[LightGBM] [Info] Number of data points in the train set: 120, number of used features: 4\n", | |
| "[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.308333 -> initscore=-0.807923\n", | |
| "[LightGBM] [Info] Start training from score -0.807923\n", | |
| "[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n", | |
| "</div>" | |
| ] | |
| }, | |
| "metadata": { | |
| "application/vnd.databricks.v1+output": { | |
| "addedWidgets": {}, | |
| "arguments": {}, | |
| "data": "<div class=\"ansiout\">[LightGBM] [Info] Number of positive: 37, number of negative: 83\n[LightGBM] [Warning] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000051 seconds.\nYou can set `force_col_wise=true` to remove the overhead.\n[LightGBM] [Info] Total Bins 90\n[LightGBM] [Info] Number of data points in the train set: 120, number of used features: 4\n[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.308333 -> initscore=-0.807923\n[LightGBM] [Info] Start training from score -0.807923\n[LightGBM] [Warning] No further splits with positive gain, best gain: -inf\n</div>", | |
| "datasetInfos": [], | |
| "metadata": {}, | |
| "removedWidgets": [], | |
| "type": "html" | |
| } | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
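| "# relabel to a binary problem: class 1 stays positive, class 2 is merged into class 0 (class 1 vs. rest)\n", | |
| "y_train_bin = np.copy(y_train)\n", | |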
| "y_test_bin = np.copy(y_test)\n", | |
| "y_train_bin[y_train_bin==2] = 0\n", | |
| "y_test_bin[y_test_bin==2] = 0\n", | |
| "params_bin = {\n", | |
| " 'objective': 'binary',\n", | |
| " 'num_class': 1,\n", | |
| " 'metric': 'binary_logloss',\n", | |
| " 'deterministic': True\n", | |
| "}\n", | |
| "train_data_bin = lgb.Dataset(X_train, label=y_train_bin)\n", | |
| "test_data_bin = lgb.Dataset(X_test, label=y_test_bin)\n", | |
| "model_bin = lgb.train(params=params_bin,\n", | |
| " train_set= train_data_bin,\n", | |
| " num_boost_round=1,\n", | |
| " valid_sets=[train_data_bin, test_data_bin],\n", | |
| " verbose_eval=10)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "application/vnd.databricks.v1+cell": { | |
| "cellMetadata": { | |
| "byteLimit": 2048000, | |
| "rowLimit": 10000 | |
| }, | |
| "inputWidgets": {}, | |
| "nuid": "dd4afa3c-a337-4e9d-80ad-78ac8e95fc71", | |
| "showTitle": false, | |
| "title": "" | |
| }, | |
| "id": "A4fUfOBaBBTE", | |
| "outputId": "b6429be9-6f5e-4987-a52e-c6accc41a457" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<style scoped>\n", | |
| " .ansiout {\n", | |
| " display: block;\n", | |
| " unicode-bidi: embed;\n", | |
| " white-space: pre-wrap;\n", | |
| " word-wrap: break-word;\n", | |
| " word-break: break-all;\n", | |
| " font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n", | |
| " font-size: 13px;\n", | |
| " color: #555;\n", | |
| " margin-left: 4px;\n", | |
| " line-height: 19px;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<div class=\"ansiout\">[[ 0.36495223 -0.55393376]\n", | |
| " [ 0.37588685 -0.5070435 ]\n", | |
| " [ 0.28074313 -0.94077845]\n", | |
| " [ 0.27838213 -0.95250101]\n", | |
| " [ 0.37588685 -0.5070435 ]\n", | |
| " [ 0.27838213 -0.95250101]\n", | |
| " [ 0.28074313 -0.94077845]\n", | |
| " [ 0.27838213 -0.95250101]\n", | |
| " [ 0.28074313 -0.94077845]\n", | |
| " [ 0.37588685 -0.5070435 ]\n", | |
| " [ 0.37588685 -0.5070435 ]\n", | |
| " [ 0.37588685 -0.5070435 ]\n", | |
| " [ 0.36495223 -0.55393376]\n", | |
| " [ 0.28074313 -0.94077845]\n", | |
| " [ 0.36495223 -0.55393376]\n", | |
| " [ 0.27838213 -0.95250101]\n", | |
| " [ 0.28074313 -0.94077845]\n", | |
| " [ 0.27838213 -0.95250101]\n", | |
| " [ 0.37588685 -0.5070435 ]\n", | |
| " [ 0.37588685 -0.5070435 ]\n", | |
| " [ 0.28074313 -0.94077845]\n", | |
| " [ 0.28074313 -0.94077845]\n", | |
| " [ 0.28074313 -0.94077845]\n", | |
| " [ 0.28074313 -0.94077845]\n", | |
| " [ 0.27838213 -0.95250101]\n", | |
| " [ 0.27838213 -0.95250101]\n", | |
| " [ 0.28074313 -0.94077845]\n", | |
| " [ 0.27838213 -0.95250101]\n", | |
| " [ 0.36495223 -0.55393376]\n", | |
| " [ 0.28074313 -0.94077845]]\n", | |
| "</div>" | |
| ] | |
| }, | |
| "metadata": { | |
| "application/vnd.databricks.v1+output": { | |
| "addedWidgets": {}, | |
| "arguments": {}, | |
| "data": "<div class=\"ansiout\">[[ 0.36495223 -0.55393376]\n [ 0.37588685 -0.5070435 ]\n [ 0.28074313 -0.94077845]\n [ 0.27838213 -0.95250101]\n [ 0.37588685 -0.5070435 ]\n [ 0.27838213 -0.95250101]\n [ 0.28074313 -0.94077845]\n [ 0.27838213 -0.95250101]\n [ 0.28074313 -0.94077845]\n [ 0.37588685 -0.5070435 ]\n [ 0.37588685 -0.5070435 ]\n [ 0.37588685 -0.5070435 ]\n [ 0.36495223 -0.55393376]\n [ 0.28074313 -0.94077845]\n [ 0.36495223 -0.55393376]\n [ 0.27838213 -0.95250101]\n [ 0.28074313 -0.94077845]\n [ 0.27838213 -0.95250101]\n [ 0.37588685 -0.5070435 ]\n [ 0.37588685 -0.5070435 ]\n [ 0.28074313 -0.94077845]\n [ 0.28074313 -0.94077845]\n [ 0.28074313 -0.94077845]\n [ 0.28074313 -0.94077845]\n [ 0.27838213 -0.95250101]\n [ 0.27838213 -0.95250101]\n [ 0.28074313 -0.94077845]\n [ 0.27838213 -0.95250101]\n [ 0.36495223 -0.55393376]\n [ 0.28074313 -0.94077845]]\n</div>", | |
| "datasetInfos": [], | |
| "metadata": {}, | |
| "removedWidgets": [], | |
| "type": "html" | |
| } | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "y_pred_leaf_bin = model_bin.predict(X_test, pred_leaf=True)\n", | |
| "y_pred_proba_bin = model_bin.predict(X_test, raw_score=False)\n", | |
| "y_pred_score_bin = model_bin.predict(X_test, raw_score=True)\n", | |
| "# column 0: predicted probability, column 1: raw score\n", | |
| "print(np.hstack([y_pred_proba_bin[:,np.newaxis],y_pred_score_bin[:,np.newaxis]]))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "application/vnd.databricks.v1+cell": { | |
| "cellMetadata": { | |
| "byteLimit": 2048000, | |
| "rowLimit": 10000 | |
| }, | |
| "inputWidgets": {}, | |
| "nuid": "2fe53b14-d9be-4a88-ac0c-41fe63c76621", | |
| "showTitle": false, | |
| "title": "" | |
| }, | |
| "id": "f-2htuP9BBTF", | |
| "outputId": "f819287a-4145-4001-cb2a-27bf58737cbb" | |
| }, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/html": [ | |
| "<style scoped>\n", | |
| " .ansiout {\n", | |
| " display: block;\n", | |
| " unicode-bidi: embed;\n", | |
| " white-space: pre-wrap;\n", | |
| " word-wrap: break-word;\n", | |
| " word-break: break-all;\n", | |
| " font-family: \"Source Code Pro\", \"Menlo\", monospace;;\n", | |
| " font-size: 13px;\n", | |
| " color: #555;\n", | |
| " margin-left: 4px;\n", | |
| " line-height: 19px;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<div class=\"ansiout\">Out[20]: (0.0, 0.0)</div>" | |
| ] | |
| }, | |
| "metadata": { | |
| "application/vnd.databricks.v1+output": { | |
| "addedWidgets": {}, | |
| "arguments": {}, | |
| "data": "<div class=\"ansiout\">Out[20]: (0.0, 0.0)</div>", | |
| "datasetInfos": [], | |
| "metadata": {}, | |
| "removedWidgets": [], | |
| "type": "html" | |
| } | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
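| "# both norms are zero: the binary model reproduces column 1 of the multiclassova predictions (raw scores and probabilities)\n", | |
| "np.linalg.norm(y_pred_score_bin - y_pred_score_ova[:,1]), np.linalg.norm(y_pred_proba_bin - y_pred_proba_ova[:,1])" | |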
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "application/vnd.databricks.v1+cell": { | |
| "cellMetadata": {}, | |
| "inputWidgets": {}, | |
| "nuid": "d154b099-0da7-41c2-b3f5-c35aabb8439c", | |
| "showTitle": false, | |
| "title": "" | |
| }, | |
| "id": "U7akV5sFBBTF" | |
| }, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "application/vnd.databricks.v1+notebook": { | |
| "dashboards": [], | |
| "language": "python", | |
| "notebookMetadata": { | |
| "pythonIndentUnit": 2 | |
| }, | |
| "notebookName": "202301", | |
| "notebookOrigID": 3593247923918104, | |
| "widgets": {} | |
| }, | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.9.12" | |
| }, | |
| "colab": { | |
| "provenance": [], | |
| "name": "lightgbm_multiclass.ipynb", | |
| "include_colab_link": true | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 0 | |
| } |
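The printed outputs in the notebook can be cross-checked by hand with a small NumPy sketch. It assumes, as the notebook's introduction guesses, that `binary` and `multiclassova` turn a raw score into a probability with a sigmoid, and that `multiclass` applies a softmax across the per-class raw scores. The helper names `sigmoid` and `softmax` are illustrative and are not part of LightGBM; the numbers are copied from the outputs shown in the notebook.

```python
import numpy as np


def sigmoid(x):
    """Logistic function, applied element-wise."""
    return 1.0 / (1.0 + np.exp(-x))


def softmax(x):
    """Row-wise softmax, shifted by the row maximum for numerical stability."""
    z = x - np.max(x, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)


# (probability, raw score) pairs copied from the binary model's printed output
proba_bin = np.array([0.36495223, 0.37588685, 0.28074313, 0.27838213])
score_bin = np.array([-0.55393376, -0.5070435, -0.94077845, -0.95250101])

# raw per-class scores copied from the first two rows of the 'multiclass' output
score_multi = np.array([
    [-1.15238622, -1.00724788, -1.1751855],
    [-1.15238622, -0.9759877, -1.1751855],
])

# largest difference between sigmoid(raw score) and the printed probability
sigmoid_gap = np.max(np.abs(sigmoid(score_bin) - proba_bin))

# under the softmax assumption, each row of probabilities sums to one
softmax_row_sums = softmax(score_multi).sum(axis=1)

print(sigmoid_gap, softmax_row_sums)
```

Because each `multiclassova` class gets its own independent sigmoid, nothing forces the per-class probabilities to add up to one, which is consistent with the row sums near, but not exactly, 1 printed in the notebook.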