-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathestimator.pyn
1 lines (1 loc) · 64.4 KB
/
estimator.pyn
1
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"estimator.ipynb","version":"0.3.2","provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"TPU"},"cells":[{"metadata":{"id":"3YIz_1vuL1Kc","colab_type":"code","colab":{}},"cell_type":"code","source":["import tensorflow as tf\n","import tensorflow.feature_column as fc \n","\n","import os\n","import sys\n","\n","import matplotlib.pyplot as plt\n","from IPython.display import clear_output"],"execution_count":0,"outputs":[]},{"metadata":{"id":"X4HQiwsZMLRP","colab_type":"code","colab":{}},"cell_type":"code","source":["tf.enable_eager_execution()"],"execution_count":0,"outputs":[]},{"metadata":{"id":"czgazPmcMN7o","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":161},"outputId":"85f0996c-80d4-4d13-d3d7-76165d459ec1","executionInfo":{"status":"ok","timestamp":1543305827646,"user_tz":-330,"elapsed":62513,"user":{"displayName":"RISHABH CHAKARABARTY","photoUrl":"","userId":"12942336028013708336"}}},"cell_type":"code","source":["! pip install -q requests\n","! 
git clone --depth 1 https://github.com/tensorflow/models"],"execution_count":3,"outputs":[{"output_type":"stream","text":["Cloning into 'models'...\n","remote: Enumerating objects: 3053, done.\u001b[K\n","remote: Counting objects: 100% (3053/3053), done.\u001b[K\n","remote: Compressing objects: 100% (2562/2562), done.\u001b[K\n","remote: Total 3053 (delta 526), reused 2235 (delta 414), pack-reused 0\u001b[K\n","Receiving objects: 100% (3053/3053), 377.08 MiB | 8.89 MiB/s, done.\n","Resolving deltas: 100% (526/526), done.\n","Checking out files: 100% (2879/2879), done.\n"],"name":"stdout"}]},{"metadata":{"id":"QHG39oa2MY-m","colab_type":"code","colab":{}},"cell_type":"code","source":["models_path = os.path.join(os.getcwd(), 'models')\n","\n","sys.path.append(models_path)"],"execution_count":0,"outputs":[]},{"metadata":{"id":"HV8ZfCW6MvMB","colab_type":"code","colab":{}},"cell_type":"code","source":["from official.wide_deep import census_dataset\n","from official.wide_deep import census_main\n","\n","census_dataset.download(\"/tmp/census_data/\")"],"execution_count":0,"outputs":[]},{"metadata":{"id":"M55VhPgxNIEq","colab_type":"code","colab":{}},"cell_type":"code","source":["#export PYTHONPATH=${PYTHONPATH}:\"$(pwd)/models\"\n","#running from python you need to set the `os.environ` or the subprocess will not see the directory.\n","\n","if \"PYTHONPATH\" in os.environ:\n"," os.environ['PYTHONPATH'] += os.pathsep + models_path\n","else:\n"," os.environ['PYTHONPATH'] = models_path"],"execution_count":0,"outputs":[]},{"metadata":{"id":"OxXd3Y19NRTt","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":935},"outputId":"32d3cdb9-dc17-4ca0-e3ae-21ced4281993","executionInfo":{"status":"ok","timestamp":1543306039200,"user_tz":-330,"elapsed":8128,"user":{"displayName":"RISHABH CHAKARABARTY","photoUrl":"","userId":"12942336028013708336"}}},"cell_type":"code","source":["!python -m official.wide_deep.census_main 
--help"],"execution_count":8,"outputs":[{"output_type":"stream","text":["Train DNN on census income dataset.\n","flags:\n","\n","/content/models/official/wide_deep/census_main.py:\n"," -bs,--batch_size:\n"," Batch size for training and evaluation. When using multiple gpus, this is\n"," the\n"," global batch size for all devices. For example, if the batch size is 32 and\n"," there are 4 GPUs, each GPU will get 8 examples on each step.\n"," (default: '40')\n"," (an integer)\n"," --[no]clean:\n"," If set, model_dir will be removed if it exists.\n"," (default: 'false')\n"," -dd,--data_dir:\n"," The location of the input data.\n"," (default: '/tmp/census_data')\n"," --[no]download_if_missing:\n"," Download data to data_dir if it is not already present.\n"," (default: 'true')\n"," -ebe,--epochs_between_evals:\n"," The number of training epochs to run between evaluations.\n"," (default: '2')\n"," (an integer)\n"," -ed,--export_dir:\n"," If set, a SavedModel serialization of the model will be exported to this\n"," directory at the end of training. 
See the README for more details and\n"," relevant\n"," links.\n"," -hk,--hooks:\n"," A list of (case insensitive) strings to specify the names of training hooks.\n"," Hook:\n"," loggingtensorhook\n"," profilerhook\n"," examplespersecondhook\n"," loggingmetrichook\n"," Example: `--hooks ProfilerHook,ExamplesPerSecondHook`\n"," See official.utils.logs.hooks_helper for details.\n"," (default: 'LoggingTensorHook')\n"," (a comma separated list)\n"," -md,--model_dir:\n"," The location of the model checkpoint files.\n"," (default: '/tmp/census_model')\n"," -mt,--model_type: <wide|deep|wide_deep>: Select model topology.\n"," (default: 'wide_deep')\n"," -te,--train_epochs:\n"," The number of epochs used to train.\n"," (default: '40')\n"," (an integer)\n","\n","Try --helpfull to get a list of all flags.\n"],"name":"stdout"}]},{"metadata":{"id":"rdUKcXEZNXo9","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":3619},"outputId":"1135617c-8fb5-4454-f32b-99d177d81645","executionInfo":{"status":"ok","timestamp":1543306100025,"user_tz":-330,"elapsed":40593,"user":{"displayName":"RISHABH CHAKARABARTY","photoUrl":"","userId":"12942336028013708336"}}},"cell_type":"code","source":["!python -m official.wide_deep.census_main --model_type=wide --train_epochs=4"],"execution_count":9,"outputs":[{"output_type":"stream","text":["I1127 08:07:43.596814 139937464895360 tf_logging.py:115] Using config: {'_model_dir': '/tmp/census_model', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': device_count {\n"," key: \"GPU\"\n","}\n",", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f45912f74e0>, '_task_type': 'worker', 
'_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n","W1127 08:07:43.598325 139937464895360 tf_logging.py:120] 'cpuinfo' not imported. CPU info will not be logged.\n","I1127 08:07:48.702143 139937464895360 tf_logging.py:115] Benchmark run: {'model_name': 'wide_deep', 'dataset': {'name': 'Census Income'}, 'machine_config': {'gpu_info': {'count': 0}, 'memory_total': 13655257088, 'memory_available': 12363923456}, 'test_id': None, 'run_date': '2018-11-27T08:07:43.597771Z', 'tensorflow_version': {'version': '1.12.0', 'git_hash': 'v1.12.0-0-ga6d8ffae09'}, 'tensorflow_environment_variables': [{'name': 'TF_FORCE_GPU_ALLOW_GROWTH', 'value': 'true'}], 'run_parameters': [{'name': 'batch_size', 'long_value': 40}, {'name': 'model_type', 'string_value': 'wide'}, {'name': 'train_epochs', 'long_value': 4}]}\n","I1127 08:07:48.731210 139937464895360 tf_logging.py:115] Parsing /tmp/census_data/adult.data\n","I1127 08:07:48.787902 139937464895360 tf_logging.py:115] Calling model_fn.\n","I1127 08:07:50.158993 139937464895360 tf_logging.py:115] Done calling model_fn.\n","I1127 08:07:50.159475 139937464895360 tf_logging.py:115] Create CheckpointSaverHook.\n","I1127 08:07:50.757225 139937464895360 tf_logging.py:115] Graph was finalized.\n","I1127 08:07:50.875898 139937464895360 tf_logging.py:115] Running local_init_op.\n","I1127 08:07:50.902672 139937464895360 tf_logging.py:115] Done running local_init_op.\n","I1127 08:07:51.791126 139937464895360 tf_logging.py:115] Saving checkpoints for 0 into /tmp/census_model/model.ckpt.\n","I1127 08:07:52.650137 139937464895360 tf_logging.py:115] average_loss = 0.6931472, loss = 27.725887\n","I1127 08:07:52.650769 139937464895360 tf_logging.py:115] loss = 27.725887, step = 1\n","I1127 08:07:53.479779 139937464895360 tf_logging.py:115] global_step/sec: 120.443\n","I1127 08:07:53.480746 139937464895360 tf_logging.py:115] average_loss = 0.46968547, 
loss = 18.787418 (0.831 sec)\n","I1127 08:07:53.481006 139937464895360 tf_logging.py:115] loss = 18.787418, step = 101 (0.830 sec)\n","I1127 08:07:53.851895 139937464895360 tf_logging.py:115] global_step/sec: 268.658\n","I1127 08:07:53.852893 139937464895360 tf_logging.py:115] average_loss = 0.27091664, loss = 10.836665 (0.372 sec)\n","I1127 08:07:53.853144 139937464895360 tf_logging.py:115] loss = 10.836665, step = 201 (0.372 sec)\n","I1127 08:07:54.211934 139937464895360 tf_logging.py:115] global_step/sec: 277.743\n","I1127 08:07:54.212884 139937464895360 tf_logging.py:115] average_loss = 0.33863768, loss = 13.545507 (0.360 sec)\n","I1127 08:07:54.213126 139937464895360 tf_logging.py:115] loss = 13.545507, step = 301 (0.360 sec)\n","I1127 08:07:54.601094 139937464895360 tf_logging.py:115] global_step/sec: 256.976\n","I1127 08:07:54.602104 139937464895360 tf_logging.py:115] average_loss = 0.4212533, loss = 16.850132 (0.389 sec)\n","I1127 08:07:54.602350 139937464895360 tf_logging.py:115] loss = 16.850132, step = 401 (0.389 sec)\n","I1127 08:07:54.972352 139937464895360 tf_logging.py:115] global_step/sec: 269.351\n","I1127 08:07:54.973420 139937464895360 tf_logging.py:115] average_loss = 0.30443913, loss = 12.177565 (0.371 sec)\n","I1127 08:07:54.973792 139937464895360 tf_logging.py:115] loss = 12.177565, step = 501 (0.371 sec)\n","I1127 08:07:55.335776 139937464895360 tf_logging.py:115] global_step/sec: 275.213\n","I1127 08:07:55.336746 139937464895360 tf_logging.py:115] average_loss = 0.29009664, loss = 11.603866 (0.363 sec)\n","I1127 08:07:55.337017 139937464895360 tf_logging.py:115] loss = 11.603866, step = 601 (0.363 sec)\n","I1127 08:07:55.695524 139937464895360 tf_logging.py:115] global_step/sec: 277.911\n","I1127 08:07:55.697038 139937464895360 tf_logging.py:115] average_loss = 0.33260164, loss = 13.304066 (0.360 sec)\n","I1127 08:07:55.697303 139937464895360 tf_logging.py:115] loss = 13.304066, step = 701 (0.360 sec)\n","I1127 08:07:56.051309 
139937464895360 tf_logging.py:115] global_step/sec: 281.069\n","I1127 08:07:56.052269 139937464895360 tf_logging.py:115] average_loss = 0.19858952, loss = 7.9435806 (0.355 sec)\n","I1127 08:07:56.052630 139937464895360 tf_logging.py:115] loss = 7.9435806, step = 801 (0.355 sec)\n","I1127 08:07:56.504485 139937464895360 tf_logging.py:115] global_step/sec: 220.67\n","I1127 08:07:56.505445 139937464895360 tf_logging.py:115] average_loss = 0.33129963, loss = 13.251986 (0.453 sec)\n","I1127 08:07:56.505778 139937464895360 tf_logging.py:115] loss = 13.251986, step = 901 (0.453 sec)\n","I1127 08:07:56.884061 139937464895360 tf_logging.py:115] global_step/sec: 263.445\n","I1127 08:07:56.885076 139937464895360 tf_logging.py:115] average_loss = 0.27387685, loss = 10.955073 (0.380 sec)\n","I1127 08:07:56.885320 139937464895360 tf_logging.py:115] loss = 10.955073, step = 1001 (0.380 sec)\n","I1127 08:07:57.246779 139937464895360 tf_logging.py:115] global_step/sec: 275.762\n","I1127 08:07:57.247735 139937464895360 tf_logging.py:115] average_loss = 0.27138096, loss = 10.855239 (0.363 sec)\n","I1127 08:07:57.247984 139937464895360 tf_logging.py:115] loss = 10.855239, step = 1101 (0.363 sec)\n","I1127 08:07:57.616047 139937464895360 tf_logging.py:115] global_step/sec: 270.758\n","I1127 08:07:57.617026 139937464895360 tf_logging.py:115] average_loss = 0.34365693, loss = 13.746277 (0.369 sec)\n","I1127 08:07:57.617273 139937464895360 tf_logging.py:115] loss = 13.746277, step = 1201 (0.369 sec)\n","I1127 08:07:57.940456 139937464895360 tf_logging.py:115] global_step/sec: 308.218\n","I1127 08:07:57.941362 139937464895360 tf_logging.py:115] average_loss = 0.47960597, loss = 19.184238 (0.324 sec)\n","I1127 08:07:57.941719 139937464895360 tf_logging.py:115] loss = 19.184238, step = 1301 (0.324 sec)\n","I1127 08:07:58.265777 139937464895360 tf_logging.py:115] global_step/sec: 307.43\n","I1127 08:07:58.266605 139937464895360 tf_logging.py:115] average_loss = 0.329755, loss = 13.1902 (0.325 
sec)\n","I1127 08:07:58.266871 139937464895360 tf_logging.py:115] loss = 13.1902, step = 1401 (0.325 sec)\n","I1127 08:07:58.584824 139937464895360 tf_logging.py:115] global_step/sec: 313.408\n","I1127 08:07:58.585636 139937464895360 tf_logging.py:115] average_loss = 0.3187499, loss = 12.749996 (0.319 sec)\n","I1127 08:07:58.585861 139937464895360 tf_logging.py:115] loss = 12.749996, step = 1501 (0.319 sec)\n","I1127 08:07:58.887100 139937464895360 tf_logging.py:115] global_step/sec: 330.794\n","I1127 08:07:58.887975 139937464895360 tf_logging.py:115] average_loss = 0.32686654, loss = 13.074661 (0.302 sec)\n","I1127 08:07:58.888185 139937464895360 tf_logging.py:115] loss = 13.074661, step = 1601 (0.302 sec)\n","I1127 08:07:58.981289 139937464895360 tf_logging.py:115] Saving checkpoints for 1629 into /tmp/census_model/model.ckpt.\n","I1127 08:07:59.195343 139937464895360 tf_logging.py:115] Loss for final step: 0.3833938.\n","I1127 08:07:59.209255 139937464895360 tf_logging.py:115] Parsing /tmp/census_data/adult.test\n","I1127 08:07:59.242372 139937464895360 tf_logging.py:115] Calling model_fn.\n","W1127 08:08:00.628344 139937464895360 tf_logging.py:125] Trapezoidal rule is known to produce incorrect PR-AUCs; please switch to \"careful_interpolation\" instead.\n","W1127 08:08:00.653543 139937464895360 tf_logging.py:125] Trapezoidal rule is known to produce incorrect PR-AUCs; please switch to \"careful_interpolation\" instead.\n","I1127 08:08:00.677987 139937464895360 tf_logging.py:115] Done calling model_fn.\n","I1127 08:08:00.702668 139937464895360 tf_logging.py:115] Starting evaluation at 2018-11-27-08:08:00\n","I1127 08:08:00.863348 139937464895360 tf_logging.py:115] Graph was finalized.\n","I1127 08:08:00.866291 139937464895360 tf_logging.py:115] Restoring parameters from /tmp/census_model/model.ckpt-1629\n","I1127 08:08:00.959494 139937464895360 tf_logging.py:115] Running local_init_op.\n","I1127 08:08:01.001961 139937464895360 tf_logging.py:115] Done running 
local_init_op.\n","I1127 08:08:03.310803 139937464895360 tf_logging.py:115] Finished evaluation at 2018-11-27-08:08:03\n","I1127 08:08:03.311129 139937464895360 tf_logging.py:115] Saving dict for global step 1629: accuracy = 0.836128, accuracy_baseline = 0.76377374, auc = 0.88423723, auc_precision_recall = 0.6953511, average_loss = 0.35064015, global_step = 1629, label/mean = 0.23622628, loss = 13.992089, precision = 0.69123375, prediction/mean = 0.23587264, recall = 0.55356216\n","I1127 08:08:03.688581 139937464895360 tf_logging.py:115] Saving 'checkpoint_path' summary for global step 1629: /tmp/census_model/model.ckpt-1629\n","I1127 08:08:03.689368 139937464895360 tf_logging.py:115] Results at epoch 2 / 4\n","I1127 08:08:03.689507 139937464895360 tf_logging.py:115] ------------------------------------------------------------\n","I1127 08:08:03.689632 139937464895360 tf_logging.py:115] accuracy: 0.836128\n","I1127 08:08:03.689713 139937464895360 tf_logging.py:115] accuracy_baseline: 0.76377374\n","I1127 08:08:03.689788 139937464895360 tf_logging.py:115] auc: 0.88423723\n","I1127 08:08:03.689870 139937464895360 tf_logging.py:115] auc_precision_recall: 0.6953511\n","I1127 08:08:03.689944 139937464895360 tf_logging.py:115] average_loss: 0.35064015\n","I1127 08:08:03.690024 139937464895360 tf_logging.py:115] global_step: 1629\n","I1127 08:08:03.690094 139937464895360 tf_logging.py:115] label/mean: 0.23622628\n","I1127 08:08:03.690162 139937464895360 tf_logging.py:115] loss: 13.992089\n","I1127 08:08:03.690230 139937464895360 tf_logging.py:115] precision: 0.69123375\n","I1127 08:08:03.690298 139937464895360 tf_logging.py:115] prediction/mean: 0.23587264\n","I1127 08:08:03.690367 139937464895360 tf_logging.py:115] recall: 0.55356216\n","I1127 08:08:03.690540 139937464895360 tf_logging.py:115] Benchmark metric: {'name': 'accuracy', 'value': 0.8361279964447021, 'unit': None, 'global_step': 1629, 'timestamp': '2018-11-27T08:08:03.690498Z', 'extras': []}\n","I1127 
08:08:03.690697 139937464895360 tf_logging.py:115] Benchmark metric: {'name': 'accuracy_baseline', 'value': 0.7637737393379211, 'unit': None, 'global_step': 1629, 'timestamp': '2018-11-27T08:08:03.690673Z', 'extras': []}\n","I1127 08:08:03.690803 139937464895360 tf_logging.py:115] Benchmark metric: {'name': 'auc', 'value': 0.8842372298240662, 'unit': None, 'global_step': 1629, 'timestamp': '2018-11-27T08:08:03.690781Z', 'extras': []}\n","I1127 08:08:03.690909 139937464895360 tf_logging.py:115] Benchmark metric: {'name': 'auc_precision_recall', 'value': 0.6953511238098145, 'unit': None, 'global_step': 1629, 'timestamp': '2018-11-27T08:08:03.690889Z', 'extras': []}\n","I1127 08:08:03.691008 139937464895360 tf_logging.py:115] Benchmark metric: {'name': 'average_loss', 'value': 0.3506401479244232, 'unit': None, 'global_step': 1629, 'timestamp': '2018-11-27T08:08:03.690989Z', 'extras': []}\n","I1127 08:08:03.691108 139937464895360 tf_logging.py:115] Benchmark metric: {'name': 'label/mean', 'value': 0.23622627556324005, 'unit': None, 'global_step': 1629, 'timestamp': '2018-11-27T08:08:03.691088Z', 'extras': []}\n","I1127 08:08:03.691207 139937464895360 tf_logging.py:115] Benchmark metric: {'name': 'loss', 'value': 13.99208927154541, 'unit': None, 'global_step': 1629, 'timestamp': '2018-11-27T08:08:03.691190Z', 'extras': []}\n","I1127 08:08:03.691303 139937464895360 tf_logging.py:115] Benchmark metric: {'name': 'precision', 'value': 0.69123375415802, 'unit': None, 'global_step': 1629, 'timestamp': '2018-11-27T08:08:03.691287Z', 'extras': []}\n","I1127 08:08:03.691400 139937464895360 tf_logging.py:115] Benchmark metric: {'name': 'prediction/mean', 'value': 0.23587264120578766, 'unit': None, 'global_step': 1629, 'timestamp': '2018-11-27T08:08:03.691383Z', 'extras': []}\n","I1127 08:08:03.691496 139937464895360 tf_logging.py:115] Benchmark metric: {'name': 'recall', 'value': 0.5535621643066406, 'unit': None, 'global_step': 1629, 'timestamp': '2018-11-27T08:08:03.691480Z', 
'extras': []}\n","I1127 08:08:03.709854 139937464895360 tf_logging.py:115] Parsing /tmp/census_data/adult.data\n","I1127 08:08:03.743275 139937464895360 tf_logging.py:115] Calling model_fn.\n","I1127 08:08:05.010474 139937464895360 tf_logging.py:115] Done calling model_fn.\n","I1127 08:08:05.010934 139937464895360 tf_logging.py:115] Create CheckpointSaverHook.\n","I1127 08:08:05.266838 139937464895360 tf_logging.py:115] Graph was finalized.\n","I1127 08:08:05.269242 139937464895360 tf_logging.py:115] Restoring parameters from /tmp/census_model/model.ckpt-1629\n","I1127 08:08:05.373376 139937464895360 tf_logging.py:115] Running local_init_op.\n","I1127 08:08:05.405231 139937464895360 tf_logging.py:115] Done running local_init_op.\n","I1127 08:08:06.416283 139937464895360 tf_logging.py:115] Saving checkpoints for 1629 into /tmp/census_model/model.ckpt.\n","I1127 08:08:07.189285 139937464895360 tf_logging.py:115] average_loss = 0.1843472, loss = 7.373888\n","I1127 08:08:07.189635 139937464895360 tf_logging.py:115] loss = 7.373888, step = 1630\n","I1127 08:08:07.941508 139937464895360 tf_logging.py:115] global_step/sec: 132.872\n","I1127 08:08:07.942368 139937464895360 tf_logging.py:115] average_loss = 0.3856508, loss = 15.426033 (0.753 sec)\n","I1127 08:08:07.942588 139937464895360 tf_logging.py:115] loss = 15.426033, step = 1730 (0.753 sec)\n","I1127 08:08:08.271078 139937464895360 tf_logging.py:115] global_step/sec: 303.455\n","I1127 08:08:08.272059 139937464895360 tf_logging.py:115] average_loss = 0.3742516, loss = 14.970064 (0.330 sec)\n","I1127 08:08:08.272405 139937464895360 tf_logging.py:115] loss = 14.970064, step = 1830 (0.330 sec)\n","I1127 08:08:08.592138 139937464895360 tf_logging.py:115] global_step/sec: 311.426\n","I1127 08:08:08.592994 139937464895360 tf_logging.py:115] average_loss = 0.23632841, loss = 9.453136 (0.321 sec)\n","I1127 08:08:08.593194 139937464895360 tf_logging.py:115] loss = 9.453136, step = 1930 (0.321 sec)\n","I1127 08:08:08.901837 
139937464895360 tf_logging.py:115] global_step/sec: 322.902\n","I1127 08:08:08.902703 139937464895360 tf_logging.py:115] average_loss = 0.2470378, loss = 9.881512 (0.310 sec)\n","I1127 08:08:08.903037 139937464895360 tf_logging.py:115] loss = 9.881512, step = 2030 (0.310 sec)\n","I1127 08:08:09.209832 139937464895360 tf_logging.py:115] global_step/sec: 324.701\n","I1127 08:08:09.210710 139937464895360 tf_logging.py:115] average_loss = 0.30908054, loss = 12.363222 (0.308 sec)\n","I1127 08:08:09.210924 139937464895360 tf_logging.py:115] loss = 12.363222, step = 2130 (0.308 sec)\n","I1127 08:08:09.537304 139937464895360 tf_logging.py:115] global_step/sec: 305.356\n","I1127 08:08:09.538141 139937464895360 tf_logging.py:115] average_loss = 0.43621874, loss = 17.44875 (0.327 sec)\n","I1127 08:08:09.538358 139937464895360 tf_logging.py:115] loss = 17.44875, step = 2230 (0.327 sec)\n","I1127 08:08:09.845594 139937464895360 tf_logging.py:115] global_step/sec: 324.411\n","I1127 08:08:09.846725 139937464895360 tf_logging.py:115] average_loss = 0.30277234, loss = 12.110893 (0.309 sec)\n","I1127 08:08:09.847110 139937464895360 tf_logging.py:115] loss = 12.110893, step = 2330 (0.309 sec)\n","I1127 08:08:10.168853 139937464895360 tf_logging.py:115] global_step/sec: 309.323\n","I1127 08:08:10.169778 139937464895360 tf_logging.py:115] average_loss = 0.2955112, loss = 11.820448 (0.323 sec)\n","I1127 08:08:10.170096 139937464895360 tf_logging.py:115] loss = 11.820448, step = 2430 (0.323 sec)\n","I1127 08:08:10.562435 139937464895360 tf_logging.py:115] global_step/sec: 254.076\n","I1127 08:08:10.563301 139937464895360 tf_logging.py:115] average_loss = 0.32810897, loss = 13.124359 (0.394 sec)\n","I1127 08:08:10.563669 139937464895360 tf_logging.py:115] loss = 13.124359, step = 2530 (0.394 sec)\n","I1127 08:08:10.903673 139937464895360 tf_logging.py:115] global_step/sec: 293.061\n","I1127 08:08:10.904599 139937464895360 tf_logging.py:115] average_loss = 0.5822643, loss = 23.290573 
(0.341 sec)\n","I1127 08:08:10.904953 139937464895360 tf_logging.py:115] loss = 23.290573, step = 2630 (0.341 sec)\n","I1127 08:08:11.250682 139937464895360 tf_logging.py:115] global_step/sec: 288.184\n","I1127 08:08:11.251506 139937464895360 tf_logging.py:115] average_loss = 0.31346822, loss = 12.538729 (0.347 sec)\n","I1127 08:08:11.251856 139937464895360 tf_logging.py:115] loss = 12.538729, step = 2730 (0.347 sec)\n","I1127 08:08:11.592947 139937464895360 tf_logging.py:115] global_step/sec: 292.171\n","I1127 08:08:11.593856 139937464895360 tf_logging.py:115] average_loss = 0.42921144, loss = 17.168457 (0.342 sec)\n","I1127 08:08:11.594071 139937464895360 tf_logging.py:115] loss = 17.168457, step = 2830 (0.342 sec)\n","I1127 08:08:11.905867 139937464895360 tf_logging.py:115] global_step/sec: 319.54\n","I1127 08:08:11.906746 139937464895360 tf_logging.py:115] average_loss = 0.5771195, loss = 23.084782 (0.313 sec)\n","I1127 08:08:11.907114 139937464895360 tf_logging.py:115] loss = 23.084782, step = 2930 (0.313 sec)\n","I1127 08:08:12.245632 139937464895360 tf_logging.py:115] global_step/sec: 294.341\n","I1127 08:08:12.246529 139937464895360 tf_logging.py:115] average_loss = 0.21133348, loss = 8.45334 (0.340 sec)\n","I1127 08:08:12.247070 139937464895360 tf_logging.py:115] loss = 8.45334, step = 3030 (0.340 sec)\n","I1127 08:08:12.581545 139937464895360 tf_logging.py:115] global_step/sec: 297.674\n","I1127 08:08:12.582428 139937464895360 tf_logging.py:115] average_loss = 0.33993918, loss = 13.597567 (0.336 sec)\n","I1127 08:08:12.582686 139937464895360 tf_logging.py:115] loss = 13.597567, step = 3130 (0.336 sec)\n","I1127 08:08:12.907716 139937464895360 tf_logging.py:115] global_step/sec: 306.615\n","I1127 08:08:12.908648 139937464895360 tf_logging.py:115] average_loss = 0.33410692, loss = 13.364277 (0.326 sec)\n","I1127 08:08:12.908975 139937464895360 tf_logging.py:115] loss = 13.364277, step = 3230 (0.326 sec)\n","I1127 08:08:13.000138 139937464895360 
tf_logging.py:115] Saving checkpoints for 3258 into /tmp/census_model/model.ckpt.\n","I1127 08:08:13.229955 139937464895360 tf_logging.py:115] Loss for final step: 0.114464805.\n","I1127 08:08:13.240944 139937464895360 tf_logging.py:115] Parsing /tmp/census_data/adult.test\n","I1127 08:08:13.273993 139937464895360 tf_logging.py:115] Calling model_fn.\n","W1127 08:08:14.654273 139937464895360 tf_logging.py:125] Trapezoidal rule is known to produce incorrect PR-AUCs; please switch to \"careful_interpolation\" instead.\n","W1127 08:08:14.677444 139937464895360 tf_logging.py:125] Trapezoidal rule is known to produce incorrect PR-AUCs; please switch to \"careful_interpolation\" instead.\n","I1127 08:08:14.700820 139937464895360 tf_logging.py:115] Done calling model_fn.\n","I1127 08:08:14.725973 139937464895360 tf_logging.py:115] Starting evaluation at 2018-11-27-08:08:14\n","I1127 08:08:14.877731 139937464895360 tf_logging.py:115] Graph was finalized.\n","I1127 08:08:14.879782 139937464895360 tf_logging.py:115] Restoring parameters from /tmp/census_model/model.ckpt-3258\n","I1127 08:08:14.988103 139937464895360 tf_logging.py:115] Running local_init_op.\n","I1127 08:08:15.035158 139937464895360 tf_logging.py:115] Done running local_init_op.\n","I1127 08:08:17.381218 139937464895360 tf_logging.py:115] Finished evaluation at 2018-11-27-08:08:17\n","I1127 08:08:17.381540 139937464895360 tf_logging.py:115] Saving dict for global step 3258: accuracy = 0.83600515, accuracy_baseline = 0.76377374, auc = 0.88373554, auc_precision_recall = 0.6949612, average_loss = 0.3515615, global_step = 3258, label/mean = 0.23622628, loss = 14.028854, precision = 0.68726116, prediction/mean = 0.23725261, recall = 0.56110245\n","I1127 08:08:17.382174 139937464895360 tf_logging.py:115] Saving 'checkpoint_path' summary for global step 3258: /tmp/census_model/model.ckpt-3258\n","I1127 08:08:17.382891 139937464895360 tf_logging.py:115] Results at epoch 4 / 4\n","I1127 08:08:17.383020 139937464895360 
tf_logging.py:115] ------------------------------------------------------------\n","I1127 08:08:17.383113 139937464895360 tf_logging.py:115] accuracy: 0.83600515\n","I1127 08:08:17.383188 139937464895360 tf_logging.py:115] accuracy_baseline: 0.76377374\n","I1127 08:08:17.383261 139937464895360 tf_logging.py:115] auc: 0.88373554\n","I1127 08:08:17.383333 139937464895360 tf_logging.py:115] auc_precision_recall: 0.6949612\n","I1127 08:08:17.383406 139937464895360 tf_logging.py:115] average_loss: 0.3515615\n","I1127 08:08:17.383484 139937464895360 tf_logging.py:115] global_step: 3258\n","I1127 08:08:17.383564 139937464895360 tf_logging.py:115] label/mean: 0.23622628\n","I1127 08:08:17.383710 139937464895360 tf_logging.py:115] loss: 14.028854\n","I1127 08:08:17.383789 139937464895360 tf_logging.py:115] precision: 0.68726116\n","I1127 08:08:17.383867 139937464895360 tf_logging.py:115] prediction/mean: 0.23725261\n","I1127 08:08:17.383934 139937464895360 tf_logging.py:115] recall: 0.56110245\n","I1127 08:08:17.384071 139937464895360 tf_logging.py:115] Benchmark metric: {'name': 'accuracy', 'value': 0.8360051512718201, 'unit': None, 'global_step': 3258, 'timestamp': '2018-11-27T08:08:17.384030Z', 'extras': []}\n","I1127 08:08:17.384190 139937464895360 tf_logging.py:115] Benchmark metric: {'name': 'accuracy_baseline', 'value': 0.7637737393379211, 'unit': None, 'global_step': 3258, 'timestamp': '2018-11-27T08:08:17.384167Z', 'extras': []}\n","I1127 08:08:17.384295 139937464895360 tf_logging.py:115] Benchmark metric: {'name': 'auc', 'value': 0.8837355375289917, 'unit': None, 'global_step': 3258, 'timestamp': '2018-11-27T08:08:17.384277Z', 'extras': []}\n","I1127 08:08:17.384394 139937464895360 tf_logging.py:115] Benchmark metric: {'name': 'auc_precision_recall', 'value': 0.6949611902236938, 'unit': None, 'global_step': 3258, 'timestamp': '2018-11-27T08:08:17.384376Z', 'extras': []}\n","I1127 08:08:17.384489 139937464895360 tf_logging.py:115] Benchmark metric: {'name': 
'average_loss', 'value': 0.3515614867210388, 'unit': None, 'global_step': 3258, 'timestamp': '2018-11-27T08:08:17.384472Z', 'extras': []}\n","I1127 08:08:17.384681 139937464895360 tf_logging.py:115] Benchmark metric: {'name': 'label/mean', 'value': 0.23622627556324005, 'unit': None, 'global_step': 3258, 'timestamp': '2018-11-27T08:08:17.384608Z', 'extras': []}\n","I1127 08:08:17.384795 139937464895360 tf_logging.py:115] Benchmark metric: {'name': 'loss', 'value': 14.028854370117188, 'unit': None, 'global_step': 3258, 'timestamp': '2018-11-27T08:08:17.384777Z', 'extras': []}\n","I1127 08:08:17.384905 139937464895360 tf_logging.py:115] Benchmark metric: {'name': 'precision', 'value': 0.687261164188385, 'unit': None, 'global_step': 3258, 'timestamp': '2018-11-27T08:08:17.384888Z', 'extras': []}\n","I1127 08:08:17.385003 139937464895360 tf_logging.py:115] Benchmark metric: {'name': 'prediction/mean', 'value': 0.2372526079416275, 'unit': None, 'global_step': 3258, 'timestamp': '2018-11-27T08:08:17.384986Z', 'extras': []}\n","I1127 08:08:17.385100 139937464895360 tf_logging.py:115] Benchmark metric: {'name': 'recall', 'value': 0.5611024498939514, 'unit': None, 'global_step': 3258, 'timestamp': '2018-11-27T08:08:17.385083Z', 'extras': []}\n"],"name":"stdout"}]},{"metadata":{"id":"izPFSrjyOnfR","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":35},"outputId":"f52e78b8-ace1-49f2-94bc-21385698b0c7","executionInfo":{"status":"ok","timestamp":1543306469983,"user_tz":-330,"elapsed":8836,"user":{"displayName":"RISHABH CHAKARABARTY","photoUrl":"","userId":"12942336028013708336"}}},"cell_type":"code","source":["!ls /tmp/census_data/"],"execution_count":10,"outputs":[{"output_type":"stream","text":["adult.data adult.test\n"],"name":"stdout"}]},{"metadata":{"id":"UWZpPsVrO8eK","colab_type":"code","colab":{}},"cell_type":"code","source":["train_file = \"/tmp/census_data/adult.data\"\n","test_file = 
\"/tmp/census_data/adult.test\""],"execution_count":0,"outputs":[]},{"metadata":{"id":"DSBtk2SUPHnQ","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":292},"outputId":"765dc34e-77f1-4ed3-eda7-113aed46df2c","executionInfo":{"status":"ok","timestamp":1543306560885,"user_tz":-330,"elapsed":1801,"user":{"displayName":"RISHABH CHAKARABARTY","photoUrl":"","userId":"12942336028013708336"}}},"cell_type":"code","source":["import pandas\n","\n","train_df = pandas.read_csv(train_file, header = None, names = census_dataset._CSV_COLUMNS)\n","test_df = pandas.read_csv(test_file, header = None, names = census_dataset._CSV_COLUMNS)\n","\n","train_df.head()"],"execution_count":13,"outputs":[{"output_type":"execute_result","data":{"text/html":["<div>\n","<style scoped>\n"," .dataframe tbody tr th:only-of-type {\n"," vertical-align: middle;\n"," }\n","\n"," .dataframe tbody tr th {\n"," vertical-align: top;\n"," }\n","\n"," .dataframe thead th {\n"," text-align: right;\n"," }\n","</style>\n","<table border=\"1\" class=\"dataframe\">\n"," <thead>\n"," <tr style=\"text-align: right;\">\n"," <th></th>\n"," <th>age</th>\n"," <th>workclass</th>\n"," <th>fnlwgt</th>\n"," <th>education</th>\n"," <th>education_num</th>\n"," <th>marital_status</th>\n"," <th>occupation</th>\n"," <th>relationship</th>\n"," <th>race</th>\n"," <th>gender</th>\n"," <th>capital_gain</th>\n"," <th>capital_loss</th>\n"," <th>hours_per_week</th>\n"," <th>native_country</th>\n"," <th>income_bracket</th>\n"," </tr>\n"," </thead>\n"," <tbody>\n"," <tr>\n"," <th>0</th>\n"," <td>39</td>\n"," <td>State-gov</td>\n"," <td>77516</td>\n"," <td>Bachelors</td>\n"," <td>13</td>\n"," <td>Never-married</td>\n"," <td>Adm-clerical</td>\n"," <td>Not-in-family</td>\n"," <td>White</td>\n"," <td>Male</td>\n"," <td>2174</td>\n"," <td>0</td>\n"," <td>40</td>\n"," <td>United-States</td>\n"," <td><=50K</td>\n"," </tr>\n"," <tr>\n"," <th>1</th>\n"," <td>50</td>\n"," <td>Self-emp-not-inc</td>\n"," <td>83311</td>\n"," 
<td>Bachelors</td>\n"," <td>13</td>\n"," <td>Married-civ-spouse</td>\n"," <td>Exec-managerial</td>\n"," <td>Husband</td>\n"," <td>White</td>\n"," <td>Male</td>\n"," <td>0</td>\n"," <td>0</td>\n"," <td>13</td>\n"," <td>United-States</td>\n"," <td><=50K</td>\n"," </tr>\n"," <tr>\n"," <th>2</th>\n"," <td>38</td>\n"," <td>Private</td>\n"," <td>215646</td>\n"," <td>HS-grad</td>\n"," <td>9</td>\n"," <td>Divorced</td>\n"," <td>Handlers-cleaners</td>\n"," <td>Not-in-family</td>\n"," <td>White</td>\n"," <td>Male</td>\n"," <td>0</td>\n"," <td>0</td>\n"," <td>40</td>\n"," <td>United-States</td>\n"," <td><=50K</td>\n"," </tr>\n"," <tr>\n"," <th>3</th>\n"," <td>53</td>\n"," <td>Private</td>\n"," <td>234721</td>\n"," <td>11th</td>\n"," <td>7</td>\n"," <td>Married-civ-spouse</td>\n"," <td>Handlers-cleaners</td>\n"," <td>Husband</td>\n"," <td>Black</td>\n"," <td>Male</td>\n"," <td>0</td>\n"," <td>0</td>\n"," <td>40</td>\n"," <td>United-States</td>\n"," <td><=50K</td>\n"," </tr>\n"," <tr>\n"," <th>4</th>\n"," <td>28</td>\n"," <td>Private</td>\n"," <td>338409</td>\n"," <td>Bachelors</td>\n"," <td>13</td>\n"," <td>Married-civ-spouse</td>\n"," <td>Prof-specialty</td>\n"," <td>Wife</td>\n"," <td>Black</td>\n"," <td>Female</td>\n"," <td>0</td>\n"," <td>0</td>\n"," <td>40</td>\n"," <td>Cuba</td>\n"," <td><=50K</td>\n"," </tr>\n"," </tbody>\n","</table>\n","</div>"],"text/plain":[" age workclass fnlwgt education education_num \\\n","0 39 State-gov 77516 Bachelors 13 \n","1 50 Self-emp-not-inc 83311 Bachelors 13 \n","2 38 Private 215646 HS-grad 9 \n","3 53 Private 234721 11th 7 \n","4 28 Private 338409 Bachelors 13 \n","\n"," marital_status occupation relationship race gender \\\n","0 Never-married Adm-clerical Not-in-family White Male \n","1 Married-civ-spouse Exec-managerial Husband White Male \n","2 Divorced Handlers-cleaners Not-in-family White Male \n","3 Married-civ-spouse Handlers-cleaners Husband Black Male \n","4 Married-civ-spouse Prof-specialty Wife Black Female \n","\n"," 
capital_gain capital_loss hours_per_week native_country income_bracket \n","0 2174 0 40 United-States <=50K \n","1 0 0 13 United-States <=50K \n","2 0 0 40 United-States <=50K \n","3 0 0 40 United-States <=50K \n","4 0 0 40 Cuba <=50K "]},"metadata":{"tags":[]},"execution_count":13}]},{"metadata":{"id":"SYl18FXxP0Sv","colab_type":"code","colab":{}},"cell_type":"code","source":["def easy_input_function(df, label_key, num_epochs, shuffle, batch_size):\n"," label = df[label_key]\n"," ds = tf.data.Dataset.from_tensor_slices((dict(df),label))\n","\n"," if shuffle:\n"," ds = ds.shuffle(10000)\n","\n"," ds = ds.batch(batch_size).repeat(num_epochs)\n","\n"," return ds"],"execution_count":0,"outputs":[]},{"metadata":{"id":"k4H0ua_OP2Dg","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":143},"outputId":"3fff88a5-21e5-4737-c426-20024b6da084","executionInfo":{"status":"ok","timestamp":1543306774986,"user_tz":-330,"elapsed":1852,"user":{"displayName":"RISHABH CHAKARABARTY","photoUrl":"","userId":"12942336028013708336"}}},"cell_type":"code","source":["ds = easy_input_function(train_df, label_key='income_bracket', num_epochs=5, shuffle=True, batch_size=10)\n","\n","for feature_batch, label_batch in ds.take(1):\n"," print('Some feature keys:', list(feature_batch.keys())[:5])\n"," print()\n"," print('A batch of Ages :', feature_batch['age'])\n"," print()\n"," print('A batch of Labels:', label_batch )"],"execution_count":15,"outputs":[{"output_type":"stream","text":["Some feature keys: ['age', 'workclass', 'fnlwgt', 'education', 'education_num']\n","\n","A batch of Ages : tf.Tensor([31 27 24 39 37 38 25 75 34 28], shape=(10,), dtype=int32)\n","\n","A batch of Labels: tf.Tensor(\n","[b'<=50K' b'<=50K' b'<=50K' b'<=50K' b'<=50K' b'>50K' b'<=50K' b'>50K'\n"," b'<=50K' b'<=50K'], shape=(10,), 
dtype=string)\n"],"name":"stdout"}]},{"metadata":{"id":"mFb5PoLiQSaB","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":521},"outputId":"ed614ea9-2514-4155-c3ed-d4756e97465e","executionInfo":{"status":"ok","timestamp":1543306821709,"user_tz":-330,"elapsed":3884,"user":{"displayName":"RISHABH CHAKARABARTY","photoUrl":"","userId":"12942336028013708336"}}},"cell_type":"code","source":["import inspect\n","print(inspect.getsource(census_dataset.input_fn))"],"execution_count":16,"outputs":[{"output_type":"stream","text":["def input_fn(data_file, num_epochs, shuffle, batch_size):\n"," \"\"\"Generate an input function for the Estimator.\"\"\"\n"," assert tf.gfile.Exists(data_file), (\n"," '%s not found. Please make sure you have run census_dataset.py and '\n"," 'set the --data_dir argument to the correct path.' % data_file)\n","\n"," def parse_csv(value):\n"," tf.logging.info('Parsing {}'.format(data_file))\n"," columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)\n"," features = dict(zip(_CSV_COLUMNS, columns))\n"," labels = features.pop('income_bracket')\n"," classes = tf.equal(labels, '>50K') # binary classification\n"," return features, classes\n","\n"," # Extract lines from input files using the Dataset API.\n"," dataset = tf.data.TextLineDataset(data_file)\n","\n"," if shuffle:\n"," dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])\n","\n"," dataset = dataset.map(parse_csv, num_parallel_calls=5)\n","\n"," # We call repeat after shuffling, rather than before, to prevent separate\n"," # epochs from blending together.\n"," dataset = dataset.repeat(num_epochs)\n"," dataset = dataset.batch(batch_size)\n"," return 
dataset\n","\n"],"name":"stdout"}]},{"metadata":{"id":"m_HbyHjNRC4X","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":521},"outputId":"145ebdd6-777f-4a8d-9fbe-b7f39c3d39d9","executionInfo":{"status":"ok","timestamp":1543307018159,"user_tz":-330,"elapsed":2328,"user":{"displayName":"RISHABH CHAKARABARTY","photoUrl":"","userId":"12942336028013708336"}}},"cell_type":"code","source":["import inspect\n","print(inspect.getsource(census_dataset.input_fn))"],"execution_count":17,"outputs":[{"output_type":"stream","text":["def input_fn(data_file, num_epochs, shuffle, batch_size):\n"," \"\"\"Generate an input function for the Estimator.\"\"\"\n"," assert tf.gfile.Exists(data_file), (\n"," '%s not found. Please make sure you have run census_dataset.py and '\n"," 'set the --data_dir argument to the correct path.' % data_file)\n","\n"," def parse_csv(value):\n"," tf.logging.info('Parsing {}'.format(data_file))\n"," columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)\n"," features = dict(zip(_CSV_COLUMNS, columns))\n"," labels = features.pop('income_bracket')\n"," classes = tf.equal(labels, '>50K') # binary classification\n"," return features, classes\n","\n"," # Extract lines from input files using the Dataset API.\n"," dataset = tf.data.TextLineDataset(data_file)\n","\n"," if shuffle:\n"," dataset = dataset.shuffle(buffer_size=_NUM_EXAMPLES['train'])\n","\n"," dataset = dataset.map(parse_csv, num_parallel_calls=5)\n","\n"," # We call repeat after shuffling, rather than before, to prevent separate\n"," # epochs from blending together.\n"," dataset = dataset.repeat(num_epochs)\n"," dataset = dataset.batch(batch_size)\n"," return 
dataset\n","\n"],"name":"stdout"}]},{"metadata":{"id":"4Es6XoWYSFMB","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":161},"outputId":"9dd16027-b7c7-4e12-8678-7a54e89f3d7b","executionInfo":{"status":"ok","timestamp":1543307385080,"user_tz":-330,"elapsed":2578,"user":{"displayName":"RISHABH CHAKARABARTY","photoUrl":"","userId":"12942336028013708336"}}},"cell_type":"code","source":["ds = census_dataset.input_fn(train_file, num_epochs=5, shuffle=True, batch_size=10)\n","\n","for feature_batch, label_batch in ds.take(1):\n"," print('Feature keys:', list(feature_batch.keys())[:5])\n"," print()\n"," print('Age batch :', feature_batch['age'])\n"," print()\n"," print('Label batch :', label_batch )"],"execution_count":18,"outputs":[{"output_type":"stream","text":["INFO:tensorflow:Parsing /tmp/census_data/adult.data\n"],"name":"stdout"},{"output_type":"stream","text":["WARNING: Logging before flag parsing goes to stderr.\n","I1127 08:29:43.140318 140539311126400 tf_logging.py:115] Parsing /tmp/census_data/adult.data\n"],"name":"stderr"},{"output_type":"stream","text":["Feature keys: ['age', 'workclass', 'fnlwgt', 'education', 'education_num']\n","\n","Age batch : tf.Tensor([30 25 58 27 50 22 24 37 21 48], shape=(10,), dtype=int32)\n","\n","Label batch : tf.Tensor([False False False False False False False True False False], shape=(10,), dtype=bool)\n"],"name":"stdout"}]},{"metadata":{"id":"hYi36z9KSpwc","colab_type":"code","colab":{}},"cell_type":"code","source":["import functools\n","\n","train_inpf = functools.partial(census_dataset.input_fn, train_file, num_epochs=2, shuffle=True, batch_size=64)\n","test_inpf = functools.partial(census_dataset.input_fn, test_file, num_epochs=1, shuffle=False, batch_size=64)"],"execution_count":0,"outputs":[]},{"metadata":{"id":"BOLwmq8FSs5e","colab_type":"code","colab":{}},"cell_type":"code","source":["age = 
fc.numeric_column('age')"],"execution_count":0,"outputs":[]},{"metadata":{"id":"Ii84-fEkVYUM","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":197},"outputId":"82b15b50-2675-43e8-91b9-afe3bea008de","executionInfo":{"status":"ok","timestamp":1543308159191,"user_tz":-330,"elapsed":2027,"user":{"displayName":"RISHABH CHAKARABARTY","photoUrl":"","userId":"12942336028013708336"}}},"cell_type":"code","source":["fc.input_layer(feature_batch, [age]).numpy()"],"execution_count":22,"outputs":[{"output_type":"execute_result","data":{"text/plain":["array([[30.],\n"," [25.],\n"," [58.],\n"," [27.],\n"," [50.],\n"," [22.],\n"," [24.],\n"," [37.],\n"," [21.],\n"," [48.]], dtype=float32)"]},"metadata":{"tags":[]},"execution_count":22}]},{"metadata":{"id":"Fm_JIwq6VbtP","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":55},"outputId":"50a7588e-292b-4ac5-dfa4-0bebc82812ec","executionInfo":{"status":"ok","timestamp":1543308180904,"user_tz":-330,"elapsed":8,"user":{"displayName":"RISHABH CHAKARABARTY","photoUrl":"","userId":"12942336028013708336"}}},"cell_type":"code","source":["classifier = tf.estimator.LinearClassifier(feature_columns=[age])\n","classifier.train(train_inpf)\n","result = classifier.evaluate(test_inpf)\n","\n","clear_output() # used for display in notebook\n","print(result)"],"execution_count":23,"outputs":[{"output_type":"stream","text":["{'accuracy': 0.7631595, 'accuracy_baseline': 0.76377374, 'auc': 0.6783709, 'auc_precision_recall': 0.31141186, 'average_loss': 0.5240148, 'label/mean': 0.23622628, 'loss': 33.456802, 'precision': 0.29166666, 'prediction/mean': 0.22446921, 'recall': 0.0018200728, 'global_step': 1018}\n"],"name":"stdout"}]},{"metadata":{"id":"MXFIsik3VbrF","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":197},"outputId":"11d4aabf-308a-415c-a0f8-70f6c36056b3","executionInfo":{"status":"ok","timestamp":1543308190564,"user_tz":-330,"elapsed":1874,"user":{"displayName":"RISHABH 
CHAKARABARTY","photoUrl":"","userId":"12942336028013708336"}}},"cell_type":"code","source":["education_num = tf.feature_column.numeric_column('education_num')\n","capital_gain = tf.feature_column.numeric_column('capital_gain')\n","capital_loss = tf.feature_column.numeric_column('capital_loss')\n","hours_per_week = tf.feature_column.numeric_column('hours_per_week')\n","\n","my_numeric_columns = [age,education_num, capital_gain, capital_loss, hours_per_week]\n","\n","fc.input_layer(feature_batch, my_numeric_columns).numpy()"],"execution_count":24,"outputs":[{"output_type":"execute_result","data":{"text/plain":["array([[30., 0., 0., 13., 40.],\n"," [25., 0., 0., 13., 50.],\n"," [58., 0., 0., 9., 40.],\n"," [27., 0., 0., 7., 45.],\n"," [50., 0., 0., 4., 40.],\n"," [22., 0., 0., 9., 40.],\n"," [24., 0., 0., 9., 30.],\n"," [37., 0., 0., 10., 40.],\n"," [21., 0., 0., 9., 32.],\n"," [48., 0., 0., 9., 37.]], dtype=float32)"]},"metadata":{"tags":[]},"execution_count":24}]},{"metadata":{"id":"xI8ZjyasVboL","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":215},"outputId":"cd0d6143-8219-4212-9194-3bb625f0b8bc","executionInfo":{"status":"ok","timestamp":1543308213798,"user_tz":-330,"elapsed":10,"user":{"displayName":"RISHABH CHAKARABARTY","photoUrl":"","userId":"12942336028013708336"}}},"cell_type":"code","source":["classifier = tf.estimator.LinearClassifier(feature_columns=my_numeric_columns)\n","classifier.train(train_inpf)\n","\n","result = classifier.evaluate(test_inpf)\n","\n","clear_output()\n","\n","for key,value in sorted(result.items()):\n"," print('%s: %s' % (key, value))"],"execution_count":25,"outputs":[{"output_type":"stream","text":["accuracy: 0.7826301\n","accuracy_baseline: 0.76377374\n","auc: 0.72746056\n","auc_precision_recall: 0.5237079\n","average_loss: 0.8549547\n","global_step: 1018\n","label/mean: 0.23622628\n","loss: 54.58634\n","precision: 0.61619985\n","prediction/mean: 0.29669306\n","recall: 
0.21164846\n"],"name":"stdout"}]},{"metadata":{"id":"fqb_ylrsVbli","colab_type":"code","colab":{}},"cell_type":"code","source":["relationship = fc.categorical_column_with_vocabulary_list(\n"," 'relationship',\n"," ['Husband', 'Not-in-family', 'Wife', 'Own-child', 'Unmarried', 'Other-relative'])"],"execution_count":0,"outputs":[]},{"metadata":{"id":"HqfVDVGaVzBb","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":343},"outputId":"4ff22abf-8828-4f7f-b099-332b63531215","executionInfo":{"status":"ok","timestamp":1543308263508,"user_tz":-330,"elapsed":2016,"user":{"displayName":"RISHABH CHAKARABARTY","photoUrl":"","userId":"12942336028013708336"}}},"cell_type":"code","source":["fc.input_layer(feature_batch, [age, fc.indicator_column(relationship)])"],"execution_count":27,"outputs":[{"output_type":"stream","text":["WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/sparse_ops.py:1165: sparse_to_dense (from tensorflow.python.ops.sparse_ops) is deprecated and will be removed in a future version.\n","Instructions for updating:\n","Create a `tf.sparse.SparseTensor` and use `tf.sparse.to_dense` instead.\n"],"name":"stdout"},{"output_type":"stream","text":["W1127 08:44:22.706813 140539311126400 tf_logging.py:125] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/sparse_ops.py:1165: sparse_to_dense (from tensorflow.python.ops.sparse_ops) is deprecated and will be removed in a future version.\n","Instructions for updating:\n","Create a `tf.sparse.SparseTensor` and use `tf.sparse.to_dense` instead.\n"],"name":"stderr"},{"output_type":"execute_result","data":{"text/plain":["<tf.Tensor: id=4578, shape=(10, 7), dtype=float32, numpy=\n","array([[30., 0., 0., 0., 0., 1., 0.],\n"," [25., 0., 1., 0., 0., 0., 0.],\n"," [58., 0., 1., 0., 0., 0., 0.],\n"," [27., 0., 1., 0., 0., 0., 0.],\n"," [50., 0., 1., 0., 0., 0., 0.],\n"," [22., 0., 0., 0., 1., 0., 0.],\n"," [24., 0., 1., 0., 0., 0., 0.],\n"," [37., 1., 0., 0., 
0., 0., 0.],\n"," [21., 0., 0., 0., 1., 0., 0.],\n"," [48., 0., 0., 0., 0., 1., 0.]], dtype=float32)>"]},"metadata":{"tags":[]},"execution_count":27}]},{"metadata":{"id":"myf_pJJuV003","colab_type":"code","colab":{}},"cell_type":"code","source":["occupation = tf.feature_column.categorical_column_with_hash_bucket(\n"," 'occupation', hash_bucket_size=1000)"],"execution_count":0,"outputs":[]},{"metadata":{"id":"TaDm84pNV4LS","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":197},"outputId":"2ba11a3f-858d-42b5-fbf8-6970d68e51a9","executionInfo":{"status":"ok","timestamp":1543308284410,"user_tz":-330,"elapsed":1919,"user":{"displayName":"RISHABH CHAKARABARTY","photoUrl":"","userId":"12942336028013708336"}}},"cell_type":"code","source":["for item in feature_batch['occupation'].numpy():\n"," print(item.decode())"],"execution_count":29,"outputs":[{"output_type":"stream","text":["Sales\n","Exec-managerial\n","Adm-clerical\n","Craft-repair\n","Craft-repair\n","Machine-op-inspct\n","Sales\n","Other-service\n","Other-service\n","Adm-clerical\n"],"name":"stdout"}]},{"metadata":{"id":"NpyhTyrwV6zL","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":35},"outputId":"f89e0c37-a4a5-4a99-90b6-166a3904f5cc","executionInfo":{"status":"ok","timestamp":1543308295201,"user_tz":-330,"elapsed":1513,"user":{"displayName":"RISHABH CHAKARABARTY","photoUrl":"","userId":"12942336028013708336"}}},"cell_type":"code","source":["occupation_result = fc.input_layer(feature_batch, [fc.indicator_column(occupation)])\n","\n","occupation_result.numpy().shape"],"execution_count":30,"outputs":[{"output_type":"execute_result","data":{"text/plain":["(10, 
1000)"]},"metadata":{"tags":[]},"execution_count":30}]},{"metadata":{"id":"FvnFwFWsV-4o","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":35},"outputId":"d87c6574-5831-4d82-ad2d-fa827813e478","executionInfo":{"status":"ok","timestamp":1543308312773,"user_tz":-330,"elapsed":2558,"user":{"displayName":"RISHABH CHAKARABARTY","photoUrl":"","userId":"12942336028013708336"}}},"cell_type":"code","source":["tf.argmax(occupation_result, axis=1).numpy()"],"execution_count":31,"outputs":[{"output_type":"execute_result","data":{"text/plain":["array([631, 800, 96, 466, 466, 911, 631, 527, 527, 96])"]},"metadata":{"tags":[]},"execution_count":31}]},{"metadata":{"id":"BttAzj2YWBaL","colab_type":"code","colab":{}},"cell_type":"code","source":["education = tf.feature_column.categorical_column_with_vocabulary_list(\n"," 'education', [\n"," 'Bachelors', 'HS-grad', '11th', 'Masters', '9th', 'Some-college',\n"," 'Assoc-acdm', 'Assoc-voc', '7th-8th', 'Doctorate', 'Prof-school',\n"," '5th-6th', '10th', '1st-4th', 'Preschool', '12th'])\n","\n","marital_status = tf.feature_column.categorical_column_with_vocabulary_list(\n"," 'marital_status', [\n"," 'Married-civ-spouse', 'Divorced', 'Married-spouse-absent',\n"," 'Never-married', 'Separated', 'Married-AF-spouse', 'Widowed'])\n","\n","workclass = tf.feature_column.categorical_column_with_vocabulary_list(\n"," 'workclass', [\n"," 'Self-emp-not-inc', 'Private', 'State-gov', 'Federal-gov',\n"," 'Local-gov', '?', 'Self-emp-inc', 'Without-pay', 'Never-worked'])\n","\n","\n","my_categorical_columns = [relationship, occupation, education, marital_status, workclass]"],"execution_count":0,"outputs":[]},{"metadata":{"id":"ot3fR3NfWEcM","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":215},"outputId":"8eeaefd3-ed57-4d00-9a8b-17b38f3a7f3b","executionInfo":{"status":"ok","timestamp":1543308348306,"user_tz":-330,"elapsed":6,"user":{"displayName":"RISHABH 
CHAKARABARTY","photoUrl":"","userId":"12942336028013708336"}}},"cell_type":"code","source":["classifier = tf.estimator.LinearClassifier(feature_columns=my_numeric_columns+my_categorical_columns)\n","classifier.train(train_inpf)\n","result = classifier.evaluate(test_inpf)\n","\n","clear_output()\n","\n","for key,value in sorted(result.items()):\n"," print('%s: %s' % (key, value))"],"execution_count":33,"outputs":[{"output_type":"stream","text":["accuracy: 0.8179473\n","accuracy_baseline: 0.76377374\n","auc: 0.8668444\n","auc_precision_recall: 0.6395996\n","average_loss: 2.8736408\n","global_step: 1018\n","label/mean: 0.23622628\n","loss: 183.47351\n","precision: 0.6082474\n","prediction/mean: 0.28163025\n","recall: 0.64430577\n"],"name":"stdout"}]},{"metadata":{"id":"Y12Yp5iHWMXA","colab_type":"code","colab":{}},"cell_type":"code","source":["age_buckets = tf.feature_column.bucketized_column(\n"," age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])"],"execution_count":0,"outputs":[]},{"metadata":{"id":"ThP4z3ScWPFD","colab_type":"code","colab":{}},"cell_type":"code","source":["education_x_occupation = tf.feature_column.crossed_column(\n"," ['education', 'occupation'], hash_bucket_size=1000)"],"execution_count":0,"outputs":[]},{"metadata":{"id":"6jH_-vTFWRRr","colab_type":"code","colab":{}},"cell_type":"code","source":["age_buckets_x_education_x_occupation = tf.feature_column.crossed_column(\n"," [age_buckets, 'education', 'occupation'], hash_bucket_size=1000)"],"execution_count":0,"outputs":[]},{"metadata":{"id":"YlJ1egOSWVwq","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":325},"outputId":"fad16430-b970-41f7-e276-d3fe66d38692","executionInfo":{"status":"ok","timestamp":1543308405335,"user_tz":-330,"elapsed":1629,"user":{"displayName":"RISHABH CHAKARABARTY","photoUrl":"","userId":"12942336028013708336"}}},"cell_type":"code","source":["import tempfile\n","\n","base_columns = [\n"," education, marital_status, relationship, workclass, 
occupation,\n"," age_buckets,\n","]\n","\n","# Reuse the crossed columns defined in the cells above instead of rebuilding them.\n","crossed_columns = [\n"," education_x_occupation,\n"," age_buckets_x_education_x_occupation,\n","]\n","\n","model = tf.estimator.LinearClassifier(\n"," model_dir=tempfile.mkdtemp(), \n"," feature_columns=base_columns + crossed_columns,\n"," optimizer=tf.train.FtrlOptimizer(learning_rate=0.1))"],"execution_count":37,"outputs":[{"output_type":"stream","text":["INFO:tensorflow:Using default config.\n"],"name":"stdout"},{"output_type":"stream","text":["I1127 08:46:44.373726 140539311126400 tf_logging.py:115] Using default config.\n"],"name":"stderr"},{"output_type":"stream","text":["INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp_xzg48qz', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n","graph_options {\n"," rewrite_options {\n"," meta_optimizer_iterations: ONE\n"," }\n","}\n",", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fd1a06a8828>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n"],"name":"stdout"},{"output_type":"stream","text":["I1127 08:46:44.377197 140539311126400 tf_logging.py:115] Using config: {'_model_dir': '/tmp/tmp_xzg48qz', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n","graph_options {\n"," rewrite_options {\n","
meta_optimizer_iterations: ONE\n"," }\n","}\n",", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fd1a06a8828>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n"],"name":"stderr"}]},{"metadata":{"id":"iy7ugMwtWWmR","colab_type":"code","colab":{}},"cell_type":"code","source":["train_inpf = functools.partial(census_dataset.input_fn, train_file, \n"," num_epochs=40, shuffle=True, batch_size=64)\n","\n","model.train(train_inpf)\n","\n","clear_output() # used for notebook display"],"execution_count":0,"outputs":[]},{"metadata":{"id":"aalKNCTMWa_S","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":215},"outputId":"8cbb63fc-20dd-49fe-c57b-93987a7a73f1","executionInfo":{"status":"ok","timestamp":1543308535126,"user_tz":-330,"elapsed":5,"user":{"displayName":"RISHABH CHAKARABARTY","photoUrl":"","userId":"12942336028013708336"}}},"cell_type":"code","source":["results = model.evaluate(test_inpf)\n","\n","clear_output()\n","\n","# Print the metrics of `model` just evaluated (not the earlier `result`).\n","for key,value in sorted(results.items()):\n"," print('%s: %0.2f' % (key, value))"],"execution_count":39,"outputs":[{"output_type":"stream","text":["accuracy: 0.82\n","accuracy_baseline: 0.76\n","auc: 0.87\n","auc_precision_recall: 0.64\n","average_loss: 2.87\n","global_step: 1018.00\n","label/mean: 0.24\n","loss: 183.47\n","precision: 0.61\n","prediction/mean: 0.28\n","recall: 
0.64\n"],"name":"stdout"}]},{"metadata":{"id":"Q0DwNMVJWRLm","colab_type":"code","colab":{"base_uri":"https://localhost:8080/","height":669},"outputId":"aa027c35-14ed-4db0-8d45-83343bb75dfb","executionInfo":{"status":"ok","timestamp":1543308565193,"user_tz":-330,"elapsed":7,"user":{"displayName":"RISHABH CHAKARABARTY","photoUrl":"","userId":"12942336028013708336"}}},"cell_type":"code","source":["import numpy as np\n","\n","predict_df = test_df[:20].copy()\n","\n","pred_iter = model.predict(\n"," lambda:easy_input_function(predict_df, label_key='income_bracket',\n"," num_epochs=1, shuffle=False, batch_size=10))\n","\n","classes = np.array(['<=50K', '>50K'])\n","pred_class_id = []\n","\n","for pred_dict in pred_iter:\n"," pred_class_id.append(pred_dict['class_ids'])\n","\n","predict_df['predicted_class'] = classes[np.array(pred_class_id)]\n","predict_df['correct'] = predict_df['predicted_class'] == predict_df['income_bracket']\n","\n","clear_output()\n","\n","predict_df[['income_bracket','predicted_class', 'correct']]"],"execution_count":40,"outputs":[{"output_type":"execute_result","data":{"text/html":["<div>\n","<style scoped>\n"," .dataframe tbody tr th:only-of-type {\n"," vertical-align: middle;\n"," }\n","\n"," .dataframe tbody tr th {\n"," vertical-align: top;\n"," }\n","\n"," .dataframe thead th {\n"," text-align: right;\n"," }\n","</style>\n","<table border=\"1\" class=\"dataframe\">\n"," <thead>\n"," <tr style=\"text-align: right;\">\n"," <th></th>\n"," <th>income_bracket</th>\n"," <th>predicted_class</th>\n"," <th>correct</th>\n"," </tr>\n"," </thead>\n"," <tbody>\n"," <tr>\n"," <th>0</th>\n"," <td><=50K</td>\n"," <td><=50K</td>\n"," <td>True</td>\n"," </tr>\n"," <tr>\n"," <th>1</th>\n"," <td><=50K</td>\n"," <td><=50K</td>\n"," <td>True</td>\n"," </tr>\n"," <tr>\n"," <th>2</th>\n"," <td>>50K</td>\n"," <td><=50K</td>\n"," <td>False</td>\n"," </tr>\n"," <tr>\n"," <th>3</th>\n"," <td>>50K</td>\n"," <td><=50K</td>\n"," <td>False</td>\n"," </tr>\n"," <tr>\n"," 
<th>4</th>\n"," <td><=50K</td>\n"," <td><=50K</td>\n"," <td>True</td>\n"," </tr>\n"," <tr>\n"," <th>5</th>\n"," <td><=50K</td>\n"," <td><=50K</td>\n"," <td>True</td>\n"," </tr>\n"," <tr>\n"," <th>6</th>\n"," <td><=50K</td>\n"," <td><=50K</td>\n"," <td>True</td>\n"," </tr>\n"," <tr>\n"," <th>7</th>\n"," <td>>50K</td>\n"," <td>>50K</td>\n"," <td>True</td>\n"," </tr>\n"," <tr>\n"," <th>8</th>\n"," <td><=50K</td>\n"," <td><=50K</td>\n"," <td>True</td>\n"," </tr>\n"," <tr>\n"," <th>9</th>\n"," <td><=50K</td>\n"," <td><=50K</td>\n"," <td>True</td>\n"," </tr>\n"," <tr>\n"," <th>10</th>\n"," <td>>50K</td>\n"," <td><=50K</td>\n"," <td>False</td>\n"," </tr>\n"," <tr>\n"," <th>11</th>\n"," <td><=50K</td>\n"," <td>>50K</td>\n"," <td>False</td>\n"," </tr>\n"," <tr>\n"," <th>12</th>\n"," <td><=50K</td>\n"," <td><=50K</td>\n"," <td>True</td>\n"," </tr>\n"," <tr>\n"," <th>13</th>\n"," <td><=50K</td>\n"," <td><=50K</td>\n"," <td>True</td>\n"," </tr>\n"," <tr>\n"," <th>14</th>\n"," <td>>50K</td>\n"," <td><=50K</td>\n"," <td>False</td>\n"," </tr>\n"," <tr>\n"," <th>15</th>\n"," <td>>50K</td>\n"," <td>>50K</td>\n"," <td>True</td>\n"," </tr>\n"," <tr>\n"," <th>16</th>\n"," <td><=50K</td>\n"," <td><=50K</td>\n"," <td>True</td>\n"," </tr>\n"," <tr>\n"," <th>17</th>\n"," <td><=50K</td>\n"," <td><=50K</td>\n"," <td>True</td>\n"," </tr>\n"," <tr>\n"," <th>18</th>\n"," <td><=50K</td>\n"," <td><=50K</td>\n"," <td>True</td>\n"," </tr>\n"," <tr>\n"," <th>19</th>\n"," <td>>50K</td>\n"," <td>>50K</td>\n"," <td>True</td>\n"," </tr>\n"," </tbody>\n","</table>\n","</div>"],"text/plain":[" income_bracket predicted_class correct\n","0 <=50K <=50K True\n","1 <=50K <=50K True\n","2 >50K <=50K False\n","3 >50K <=50K False\n","4 <=50K <=50K True\n","5 <=50K <=50K True\n","6 <=50K <=50K True\n","7 >50K >50K True\n","8 <=50K <=50K True\n","9 <=50K <=50K True\n","10 >50K <=50K False\n","11 <=50K >50K False\n","12 <=50K <=50K True\n","13 <=50K <=50K True\n","14 >50K <=50K False\n","15 >50K >50K True\n","16 
<=50K <=50K True\n","17 <=50K <=50K True\n","18 <=50K <=50K True\n","19 >50K >50K True"]},"metadata":{"tags":[]},"execution_count":40}]},
{"metadata":{"id":"QcJP33byXFme","colab_type":"code","colab":{}},"cell_type":"code","source":["def get_flat_weights(estimator):\n","  \"\"\"Return the trained linear-model weights of `estimator` as one flat numpy array.\n","\n","  Only variables whose names contain 'linear_model' are kept, and those containing\n","  'Ftrl' (the FTRL optimizer's slot variables) are skipped.\n","  \"\"\"\n","  weight_names = [\n","      name for name in estimator.get_variable_names()\n","      if 'linear_model' in name and 'Ftrl' not in name]\n","  weight_values = [estimator.get_variable_value(name) for name in weight_names]\n","  return np.concatenate([value.flatten() for value in weight_values], axis=0)\n","\n","\n","weights_flat = get_flat_weights(model)\n","# Drop weights that are exactly 0 (e.g. hash buckets never updated during training)\n","# so the histogram shows only the weights the model actually learned.\n","weights_base = weights_flat[weights_flat != 0]\n","\n","plt.figure()\n","_ = plt.hist(weights_base, bins=np.linspace(-3, 3, 30))\n","plt.title('Base Model')\n","plt.ylim([0, 500]);"],"execution_count":0,"outputs":[]}
]}