{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TMVA 101: Training\n",
    "This tutorial shows how you can train a machine learning model with any package\n",
    "by reading the training data directly from ROOT files. Using XGBoost, we illustrate\n",
    "how you can convert an externally trained model into a format that can be serialized\n",
    "and read by the fast tree inference engine offered by TMVA.\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "**Author:** Stefan Wunsch  \n",
    "<i><small>This notebook tutorial was automatically generated with <a href= \"https://github.com/root-project/root/blob/master/documentation/doxygen/converttonotebook.py\">ROOTBOOK-izer</a> from the macro found in the ROOT repository  on Tuesday, May 24, 2022 at 05:45 PM.</small></i>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import ROOT\n",
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "from tmva100_DataPreparation import variables\n",
    "\n",
    "\n",
    "def load_data(signal_filename, background_filename, tree_name=\"Events\"):\n",
    "    \"\"\"Load signal and background events as a labelled, weighted training set.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    signal_filename, background_filename : str\n",
    "        ROOT files holding the signal and background events.\n",
    "    tree_name : str, optional\n",
    "        Name of the tree read from both files (default ``\"Events\"``).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    x : numpy.ndarray\n",
    "        Inputs with one row per event and one column per entry of ``variables``;\n",
    "        signal rows come first, followed by background rows.\n",
    "    y : numpy.ndarray\n",
    "        Labels, 1 for signal and 0 for background.\n",
    "    w : numpy.ndarray\n",
    "        Per-event weights scaled so that both classes carry the same total weight.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If either file contains no events (the class weight would be undefined).\n",
    "    \"\"\"\n",
    "    # Read only the input variables from the ROOT files\n",
    "    data_sig = ROOT.RDataFrame(tree_name, signal_filename).AsNumpy(columns=list(variables))\n",
    "    data_bkg = ROOT.RDataFrame(tree_name, background_filename).AsNumpy(columns=list(variables))\n",
    "\n",
    "    # Convert inputs to an (events, variables) matrix readable by machine learning tools\n",
    "    x_sig = np.column_stack([data_sig[var] for var in variables])\n",
    "    x_bkg = np.column_stack([data_bkg[var] for var in variables])\n",
    "    x = np.vstack([x_sig, x_bkg])\n",
    "\n",
    "    # Create labels\n",
    "    num_sig = x_sig.shape[0]\n",
    "    num_bkg = x_bkg.shape[0]\n",
    "    if num_sig == 0 or num_bkg == 0:\n",
    "        raise ValueError(\n",
    "            \"Cannot balance classes: found {} signal and {} background events\".format(num_sig, num_bkg)\n",
    "        )\n",
    "    y = np.hstack([np.ones(num_sig), np.zeros(num_bkg)])\n",
    "\n",
    "    # Weight each class by num_all / num_class so that both classes contribute\n",
    "    # the same total weight (num_all) to the training\n",
    "    num_all = num_sig + num_bkg\n",
    "    w = np.hstack([np.full(num_sig, num_all / num_sig), np.full(num_bkg, num_all / num_bkg)])\n",
    "\n",
    "    return x, y, w\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Load data\n",
    "    x, y, w = load_data(\"train_signal.root\", \"train_background.root\")\n",
    "\n",
    "    # Fit xgboost model: a boosted decision tree ensemble of 500 shallow (depth 3) trees\n",
    "    from xgboost import XGBClassifier\n",
    "    bdt = XGBClassifier(max_depth=3, n_estimators=500)\n",
    "    # Pass the weights by keyword: recent XGBoost releases deprecate or reject\n",
    "    # positional arguments after X and y in the scikit-learn fit() API\n",
    "    bdt.fit(x, y, sample_weight=w)\n",
    "\n",
    "    # Save model in TMVA format\n",
    "    ROOT.TMVA.Experimental.SaveXGBoost(bdt, \"myBDT\", \"tmva101.root\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
