Learning TensorFlow: Classification on Imbalanced Data

Reference: the official TensorFlow tutorial "Classification on imbalanced data" (https://www.tensorflow.org/tutorials/structured_data/imbalanced_data)

Import modules

import tensorflow as tf
from tensorflow import keras

import os
import datetime
import tempfile

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import sklearn
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
mpl.rcParams['figure.figsize'] = (12, 10)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
colors
['#1f77b4',
 '#ff7f0e',
 '#2ca02c',
 '#d62728',
 '#9467bd',
 '#8c564b',
 '#e377c2',
 '#7f7f7f',
 '#bcbd22',
 '#17becf']

Load the data

  • The goal is to predict whether a credit-card transaction is fraudulent (the Class column)
train_path = './TensorFlow学习之不平衡数据的分类/creditcard.csv'
raw_data = pd.read_csv(train_path)

Data exploration

raw_data.head()
Time V1 V2 V3 V4 V5 V6 V7 V8 V9 ... V21 V22 V23 V24 V25 V26 V27 V28 Amount Class
0 0.0 -1.359807 -0.072781 2.536347 1.378155 -0.338321 0.462388 0.239599 0.098698 0.363787 ... -0.018307 0.277838 -0.110474 0.066928 0.128539 -0.189115 0.133558 -0.021053 149.62 0
1 0.0 1.191857 0.266151 0.166480 0.448154 0.060018 -0.082361 -0.078803 0.085102 -0.255425 ... -0.225775 -0.638672 0.101288 -0.339846 0.167170 0.125895 -0.008983 0.014724 2.69 0
2 1.0 -1.358354 -1.340163 1.773209 0.379780 -0.503198 1.800499 0.791461 0.247676 -1.514654 ... 0.247998 0.771679 0.909412 -0.689281 -0.327642 -0.139097 -0.055353 -0.059752 378.66 0
3 1.0 -0.966272 -0.185226 1.792993 -0.863291 -0.010309 1.247203 0.237609 0.377436 -1.387024 ... -0.108300 0.005274 -0.190321 -1.175575 0.647376 -0.221929 0.062723 0.061458 123.50 0
4 2.0 -1.158233 0.877737 1.548718 0.403034 -0.407193 0.095921 0.592941 -0.270533 0.817739 ... -0.009431 0.798278 -0.137458 0.141267 -0.206010 0.502292 0.219422 0.215153 69.99 0

5 rows × 31 columns

raw_data.info(show_counts=True)
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 284807 entries, 0 to 284806
Data columns (total 31 columns):
 #   Column  Non-Null Count   Dtype  
---  ------  --------------   -----  
 0   Time    284807 non-null  float64
 1   V1      284807 non-null  float64
 2   V2      284807 non-null  float64
 3   V3      284807 non-null  float64
 4   V4      284807 non-null  float64
 5   V5      284807 non-null  float64
 6   V6      284807 non-null  float64
 7   V7      284807 non-null  float64
 8   V8      284807 non-null  float64
 9   V9      284807 non-null  float64
 10  V10     284807 non-null  float64
 11  V11     284807 non-null  float64
 12  V12     284807 non-null  float64
 13  V13     284807 non-null  float64
 14  V14     284807 non-null  float64
 15  V15     284807 non-null  float64
 16  V16     284807 non-null  float64
 17  V17     284807 non-null  float64
 18  V18     284807 non-null  float64
 19  V19     284807 non-null  float64
 20  V20     284807 non-null  float64
 21  V21     284807 non-null  float64
 22  V22     284807 non-null  float64
 23  V23     284807 non-null  float64
 24  V24     284807 non-null  float64
 25  V25     284807 non-null  float64
 26  V26     284807 non-null  float64
 27  V27     284807 non-null  float64
 28  V28     284807 non-null  float64
 29  Amount  284807 non-null  float64
 30  Class   284807 non-null  int64  
dtypes: float64(30), int64(1)
memory usage: 67.4 MB
raw_data.describe(include='all').T
count mean std min 25% 50% 75% max
Time 284807.0 9.481386e+04 47488.145955 0.000000 54201.500000 84692.000000 139320.500000 172792.000000
V1 284807.0 1.168375e-15 1.958696 -56.407510 -0.920373 0.018109 1.315642 2.454930
V2 284807.0 3.416908e-16 1.651309 -72.715728 -0.598550 0.065486 0.803724 22.057729
V3 284807.0 -1.379537e-15 1.516255 -48.325589 -0.890365 0.179846 1.027196 9.382558
V4 284807.0 2.074095e-15 1.415869 -5.683171 -0.848640 -0.019847 0.743341 16.875344
V5 284807.0 9.604066e-16 1.380247 -113.743307 -0.691597 -0.054336 0.611926 34.801666
V6 284807.0 1.487313e-15 1.332271 -26.160506 -0.768296 -0.274187 0.398565 73.301626
V7 284807.0 -5.556467e-16 1.237094 -43.557242 -0.554076 0.040103 0.570436 120.589494
V8 284807.0 1.213481e-16 1.194353 -73.216718 -0.208630 0.022358 0.327346 20.007208
V9 284807.0 -2.406331e-15 1.098632 -13.434066 -0.643098 -0.051429 0.597139 15.594995
V10 284807.0 2.239053e-15 1.088850 -24.588262 -0.535426 -0.092917 0.453923 23.745136
V11 284807.0 1.673327e-15 1.020713 -4.797473 -0.762494 -0.032757 0.739593 12.018913
V12 284807.0 -1.247012e-15 0.999201 -18.683715 -0.405571 0.140033 0.618238 7.848392
V13 284807.0 8.190001e-16 0.995274 -5.791881 -0.648539 -0.013568 0.662505 7.126883
V14 284807.0 1.207294e-15 0.958596 -19.214325 -0.425574 0.050601 0.493150 10.526766
V15 284807.0 4.887456e-15 0.915316 -4.498945 -0.582884 0.048072 0.648821 8.877742
V16 284807.0 1.437716e-15 0.876253 -14.129855 -0.468037 0.066413 0.523296 17.315112
V17 284807.0 -3.772171e-16 0.849337 -25.162799 -0.483748 -0.065676 0.399675 9.253526
V18 284807.0 9.564149e-16 0.838176 -9.498746 -0.498850 -0.003636 0.500807 5.041069
V19 284807.0 1.039917e-15 0.814041 -7.213527 -0.456299 0.003735 0.458949 5.591971
V20 284807.0 6.406204e-16 0.770925 -54.497720 -0.211721 -0.062481 0.133041 39.420904
V21 284807.0 1.654067e-16 0.734524 -34.830382 -0.228395 -0.029450 0.186377 27.202839
V22 284807.0 -3.568593e-16 0.725702 -10.933144 -0.542350 0.006782 0.528554 10.503090
V23 284807.0 2.578648e-16 0.624460 -44.807735 -0.161846 -0.011193 0.147642 22.528412
V24 284807.0 4.473266e-15 0.605647 -2.836627 -0.354586 0.040976 0.439527 4.584549
V25 284807.0 5.340915e-16 0.521278 -10.295397 -0.317145 0.016594 0.350716 7.519589
V26 284807.0 1.683437e-15 0.482227 -2.604551 -0.326984 -0.052139 0.240952 3.517346
V27 284807.0 -3.660091e-16 0.403632 -22.565679 -0.070840 0.001342 0.091045 31.612198
V28 284807.0 -1.227390e-16 0.330083 -15.430084 -0.052960 0.011244 0.078280 33.847808
Amount 284807.0 8.834962e+01 250.120109 0.000000 5.600000 22.000000 77.165000 25691.160000
Class 284807.0 1.727486e-03 0.041527 0.000000 0.000000 0.000000 0.000000 1.000000

Check the class balance

target = 'Class'
raw_data.groupby(target, as_index=False)['Time'].count()
   Class    Time
0      0  284315
1      1     492
neg, pos = np.bincount(raw_data[target])
total = neg + pos
print('Examples:\n    Total: {}\n    Positive: {} ({:.2f}% of total)\n'.format(
    total, pos, 100 * pos / total))
Examples:
    Total: 284807
    Positive: 492 (0.17% of total)
dataframe = raw_data.copy()

Data preprocessing

  1. Drop the unhelpful Time feature and log-transform Amount
  2. Standardize the features
  3. Clip extreme values

Log transform

dataframe.pop('Time')

eps = 0.001 # 0 => 0.1¢
for feature in ['Amount']:
    dataframe[feature+'_log'] = np.log(dataframe.pop(feature)+eps)

Split into training, validation, and test sets

train, test = train_test_split(dataframe, test_size=0.2, stratify=dataframe[target], random_state=2030)
train, val = train_test_split(train, test_size=0.2, stratify=train[target], random_state=2030)
print(len(train), 'train examples')
print(len(val), 'validation examples')
print(len(test), 'test examples')
182276 train examples
45569 validation examples
56962 test examples
# Form np arrays of labels and features.
train_labels = np.array(train.pop(target))
bool_train_labels = train_labels != 0
val_labels = np.array(val.pop(target))
test_labels = np.array(test.pop(target))

train_features = np.array(train)
val_features = np.array(val)
test_features = np.array(test)
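A quick sanity check (added here, not in the original notebook): because the splits are stratified on the label, the positive rate stays at roughly 0.17% in every split.

for name, labels in [('train', train_labels), ('val', val_labels), ('test', test_labels)]:
    print(f'{name} positive rate: {labels.mean():.5f}')  # ~0.0017 in each split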

Standardization

  • Note: fit the scaler on the training data only, then use the fitted scaler to transform the validation and test sets
scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)

val_features = scaler.transform(val_features)
test_features = scaler.transform(test_features)

train_features = np.clip(train_features, -5, 5)
val_features = np.clip(val_features, -5, 5)
test_features = np.clip(test_features, -5, 5)


print('Training labels shape:', train_labels.shape)
print('Validation labels shape:', val_labels.shape)
print('Test labels shape:', test_labels.shape)

print('Training features shape:', train_features.shape)
print('Validation features shape:', val_features.shape)
print('Test features shape:', test_features.shape)
Training labels shape: (182276,)
Validation labels shape: (45569,)
Test labels shape: (56962,)
Training features shape: (182276, 29)
Validation features shape: (45569, 29)
Test features shape: (56962, 29)
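For scoring new data later, the same preprocessing can be bundled into one helper. A minimal sketch (added here, not part of the original notebook; it assumes the Class label has already been removed and the remaining columns match creditcard.csv):

def preprocess(raw_df, fitted_scaler, eps=0.001):
    # same pipeline as above: drop Time, log-transform Amount,
    # standardize with the scaler fitted on the training set, clip to [-5, 5]
    df = raw_df.copy()
    df.pop('Time')
    df['Amount_log'] = np.log(df.pop('Amount') + eps)
    return np.clip(fitted_scaler.transform(np.array(df)), -5, 5)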

Examine the feature distributions

pos_df = pd.DataFrame(train_features[ bool_train_labels], columns=train.columns)
neg_df = pd.DataFrame(train_features[~bool_train_labels], columns=train.columns)

# pass x/y as keywords to avoid seaborn's FutureWarning about positional args
sns.jointplot(x=pos_df['V5'], y=pos_df['V6'],
              kind='hex', xlim=(-5, 5), ylim=(-5, 5))
plt.suptitle("Positive distribution")

sns.jointplot(x=neg_df['V5'], y=neg_df['V6'],
              kind='hex', xlim=(-5, 5), ylim=(-5, 5))
_ = plt.suptitle("Negative distribution")

[Figure: hexbin joint distribution of V5 vs V6 for the positive class]

[Figure: hexbin joint distribution of V5 vs V6 for the negative class]

pos_df.describe().T
count mean std min 25% 50% 75% max
V1 315.0 -1.760906 1.909379 -5.000000 -2.995964 -1.194895 -0.232911 1.087466
V2 315.0 1.821280 1.886729 -5.000000 0.708689 1.566040 3.037499 5.000000
V3 315.0 -2.966703 1.843923 -5.000000 -5.000000 -3.302910 -1.408258 1.486048
V4 315.0 2.940414 1.616449 -0.924989 1.633506 2.848541 4.487645 5.000000
V5 315.0 -1.514576 2.295449 -5.000000 -3.413844 -1.120011 0.140397 5.000000
V6 315.0 -1.041315 1.378416 -4.764846 -1.883094 -1.055371 -0.242898 4.816880
V7 315.0 -2.473715 2.118062 -5.000000 -5.000000 -2.385480 -0.803709 3.008926
V8 315.0 0.661945 2.325665 -5.000000 -0.136614 0.471054 1.472319 5.000000
V9 315.0 -2.107191 1.819530 -5.000000 -3.521353 -2.012870 -0.719422 3.049323
V10 315.0 -3.418757 1.897714 -5.000000 -5.000000 -4.057184 -2.573482 3.691563
V11 315.0 3.202421 1.776641 -1.665756 2.001885 3.536426 5.000000 5.000000
V12 315.0 -3.778767 1.782103 -5.000000 -5.000000 -5.000000 -2.989429 1.234618
V13 315.0 -0.122043 1.112873 -3.142507 -0.973313 -0.067780 0.622709 2.828751
V14 315.0 -4.120755 1.711212 -5.000000 -5.000000 -5.000000 -4.441194 1.639314
V15 315.0 -0.127527 1.146112 -3.382938 -0.759891 -0.073017 0.701528 2.697669
V16 315.0 -3.043271 2.275678 -5.000000 -5.000000 -3.836717 -1.316847 2.841122
V17 315.0 -3.171004 2.850112 -5.000000 -5.000000 -5.000000 -1.612624 5.000000
V18 315.0 -2.069377 2.456956 -5.000000 -5.000000 -1.969603 -0.060276 3.926839
V19 315.0 0.875712 1.872091 -4.032744 -0.339676 0.841347 1.961852 5.000000
V20 315.0 0.486428 1.460203 -5.000000 -0.160274 0.370908 1.086496 5.000000
V21 315.0 0.798396 1.932013 -5.000000 0.032131 0.790971 1.855851 5.000000
V22 315.0 0.137193 1.411641 -5.000000 -0.657904 0.159077 0.897130 5.000000
V23 315.0 0.019576 1.377033 -5.000000 -0.546028 -0.127559 0.525010 5.000000
V24 315.0 -0.124175 0.838200 -3.342668 -0.666720 -0.042600 0.534642 1.799373
V25 315.0 0.084645 1.419304 -4.153589 -0.605800 0.145072 0.876187 4.132529
V26 315.0 0.072469 0.977484 -2.388598 -0.543418 0.000739 0.760928 5.000000
V27 315.0 0.675460 2.413487 -5.000000 -0.045323 0.923004 2.036083 5.000000
V28 315.0 0.239053 1.653805 -5.000000 -0.355816 0.465788 1.105123 5.000000
Amount_log 315.0 -0.382619 1.617817 -4.855759 -1.450333 -0.356358 0.847578 2.325855
neg_df.describe().T
count mean std min 25% 50% 75% max
V1 181961.0 0.012509 0.915748 -5.000000 -0.465855 0.011296 0.672076 1.251724
V2 181961.0 0.008627 0.856092 -5.000000 -0.359485 0.040111 0.481902 5.000000
V3 181961.0 0.011040 0.936744 -5.000000 -0.584617 0.120097 0.678324 2.790418
V4 181961.0 -0.006194 0.983917 -3.924999 -0.600361 -0.017643 0.521791 5.000000
V5 181961.0 0.005026 0.882939 -5.000000 -0.493601 -0.038049 0.438492 5.000000
V6 181961.0 0.001449 0.965131 -5.000000 -0.570013 -0.203485 0.298300 5.000000
V7 181961.0 0.005842 0.818787 -5.000000 -0.437768 0.031821 0.452537 5.000000
V8 181961.0 0.017644 0.745925 -5.000000 -0.172570 0.019889 0.272858 5.000000
V9 181961.0 0.002619 0.983578 -5.000000 -0.583394 -0.045386 0.543770 5.000000
V10 181961.0 0.000444 0.884508 -5.000000 -0.488516 -0.085221 0.419347 5.000000
V11 181961.0 -0.006827 0.980387 -4.583258 -0.747892 -0.032732 0.720054 5.000000
V12 181961.0 0.011853 0.936165 -5.000000 -0.402320 0.141804 0.617894 5.000000
V13 181961.0 0.000204 0.999690 -5.000000 -0.650306 -0.012781 0.666657 5.000000
V14 181961.0 0.013696 0.919757 -5.000000 -0.441005 0.055106 0.513936 5.000000
V15 181961.0 0.000164 0.999363 -4.793884 -0.638472 0.053209 0.708801 5.000000
V16 181961.0 0.008581 0.957594 -5.000000 -0.530996 0.077795 0.596470 5.000000
V17 181961.0 0.014459 0.855516 -5.000000 -0.566987 -0.076122 0.470954 5.000000
V18 181961.0 0.004775 0.983692 -5.000000 -0.595008 -0.002535 0.599786 5.000000
V19 181961.0 -0.001509 0.996377 -5.000000 -0.562193 0.003682 0.561334 5.000000
V20 181961.0 0.001119 0.768444 -5.000000 -0.273441 -0.082090 0.168334 5.000000
V21 181961.0 -0.009335 0.713355 -5.000000 -0.306948 -0.039843 0.251401 5.000000
V22 181961.0 0.000416 0.988440 -5.000000 -0.743404 0.007948 0.728703 5.000000
V23 181961.0 0.003516 0.695287 -5.000000 -0.258290 -0.019468 0.233072 5.000000
V24 181961.0 -0.000064 0.998703 -4.656135 -0.583868 0.067759 0.725433 5.000000
V25 181961.0 0.000240 0.989529 -5.000000 -0.609040 0.032075 0.672402 5.000000
V26 181961.0 -0.000407 0.998447 -5.000000 -0.677797 -0.109895 0.498649 5.000000
V27 181961.0 0.005844 0.799027 -5.000000 -0.176850 0.002015 0.224085 5.000000
V28 181961.0 -0.001089 0.700654 -5.000000 -0.161470 0.034111 0.239486 5.000000
Amount_log 181961.0 0.000662 0.998482 -4.855759 -0.587677 0.075498 0.691266 3.554182

Define the model and metrics

METRICS = [
    keras.metrics.TruePositives(name='tp'),
    keras.metrics.FalsePositives(name='fp'),
    keras.metrics.TrueNegatives(name='tn'),
    keras.metrics.FalseNegatives(name='fn'),
    keras.metrics.BinaryAccuracy(name='acc'),
    keras.metrics.Precision(name='precision'),
    keras.metrics.Recall(name='recall'),
    keras.metrics.AUC(name='auc'),
    keras.metrics.AUC(name='prc', curve='PR'),
]
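Keras metrics are stateful objects, so any entry in this list can also be exercised on its own, outside model.compile. A toy check (illustrative only):

m = keras.metrics.AUC(name='prc', curve='PR')
m.update_state([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(m.result().numpy())  # PR AUC on this toy batch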
def func_make_model(n_x, metrics=None, output_bias=None):
    """
    :param n_x: number of input features
    :param metrics: list of Keras metrics to track
    :param output_bias: initial bias for the output layer (e.g. log(pos/neg))
    :return: a compiled Keras model
    """
    if metrics is None:
        metrics = [keras.metrics.BinaryAccuracy(name='acc')]
    if output_bias is not None:
        output_bias = tf.keras.initializers.Constant(output_bias)

    model = keras.Sequential([
        keras.layers.Dense(16, activation='relu', input_shape=(n_x, )),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(1, activation='sigmoid', bias_initializer=output_bias),
    ])

    model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3),
                  loss=keras.losses.BinaryCrossentropy(),
                  metrics=metrics)
    return model

baseline model

build the model

  • restore_best_weights: Whether to restore model weights from the epoch with the best value of the monitored quantity. If False, the model weights obtained at the last step of training are used.
EPOCHS = 100
BATCH_SIZE = 2048

early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_prc',
                                                  verbose=1,
                                                  patience=10,
                                                  mode='max',
                                                  restore_best_weights=True)
model_baseline = func_make_model(n_x=train_features.shape[-1], metrics=METRICS)
model_baseline.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 16)                480       
_________________________________________________________________
dropout (Dropout)            (None, 16)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 17        
=================================================================
Total params: 497
Trainable params: 497
Non-trainable params: 0
_________________________________________________________________
model_baseline.predict(train_features[:10])
array([[0.7945324 ],
       [0.9206671 ],
       [0.7186621 ],
       [0.8676042 ],
       [0.86442024],
       [0.820997  ],
       [0.8121756 ],
       [0.8080936 ],
       [0.820931  ],
       [0.9337303 ]], dtype=float32)

set the correct initial bias

Following the usual recipe for training neural networks:

  1. With the default bias initialization the loss should be about math.log(2) = 0.69314.
  2. The correct bias to set can be derived from p_0 = pos/(pos+neg) = 1/(1+exp(-b_0)), which gives b_0 = log(pos/neg).
  3. With the correct bias the initial loss should be about the entropy of the label distribution: -p_0*log(p_0) - (1-p_0)*log(1-p_0).
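A worked check (added here) with the class counts computed above, pos = 492 and neg = 284315:

$$b_0 = \log\frac{pos}{neg} = \log\frac{492}{284315} \approx -6.359$$

$$p_0 = \frac{pos}{pos+neg} \approx 0.001727, \qquad -p_0\log p_0 - (1-p_0)\log(1-p_0) \approx 0.01271$$

Both numbers reappear in the outputs below.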
## 1
results = model_baseline.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print(results[0])
1.9581257104873657
## 2
correct_initial_bias = np.log([pos/neg])
correct_initial_bias
array([-6.35935934])
model_baseline = func_make_model(n_x=train_features.shape[-1], metrics=METRICS, output_bias=correct_initial_bias)
model_baseline.predict(train_features[:10])
array([[0.00340313],
       [0.00266778],
       [0.00226334],
       [0.00272266],
       [0.00042819],
       [0.00054607],
       [0.00602424],
       [0.00321096],
       [0.0007532 ],
       [0.00322443]], dtype=float32)
## 3
p_0 = pos / (pos+neg)
-1*p_0*np.log(p_0) - (1-p_0)*np.log(1-p_0)
0.012714681335936208
results = model_baseline.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print(results[0])
0.015545893460512161

checkpoint the initial weights

  • Save the bias-corrected initial weights so that the later models can all start from the same initialization
# os.path.join(tempfile.mkdtemp(), 'initial_weights')
correct_initial_bias_path = './tmp/correct_initial_bias/correct_initial_bias'
model_baseline.save_weights(correct_initial_bias_path)

Check the effect of the corrected bias initialization

model_baseline = func_make_model(n_x=train_features.shape[-1], metrics=METRICS)
model_baseline.load_weights(correct_initial_bias_path)
model_baseline.layers[-1].bias.assign([0.0])  # reset the output layer's bias to zero
zero_bias_history = model_baseline.fit(train_features, train_labels, batch_size=BATCH_SIZE, epochs=20, validation_data=(val_features, val_labels), verbose=0)
model_baseline = func_make_model(n_x=train_features.shape[-1], metrics=METRICS)
model_baseline.load_weights(correct_initial_bias_path)
correct_bias_history = model_baseline.fit(train_features, train_labels, batch_size=BATCH_SIZE, epochs=20, validation_data=(val_features, val_labels), verbose=0)
def func_plot_loss(history, label, color):
    # Use a log scale on y-axis to show the wide range of values.
    plt.semilogy(history.epoch, history.history['loss'], color=color, label='Train ' + label)
    plt.semilogy(history.epoch, history.history['val_loss'], color=color, label='Val ' + label, linestyle="--")
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
func_plot_loss(zero_bias_history, "Zero Bias", 'red')
func_plot_loss(correct_bias_history, "Careful Bias", 'blue')

[Figure: training and validation loss, zero bias vs. careful bias initialization]

As the plot shows, the carefully initialized model starts with a much smaller loss.

train the baseline model

model_baseline = func_make_model(n_x=train_features.shape[-1], metrics=METRICS)
model_baseline.load_weights(correct_initial_bias_path)
baseline_history = model_baseline.fit(train_features, train_labels, batch_size=BATCH_SIZE, epochs=EPOCHS, 
                                      callbacks=[early_stopping], validation_data=(val_features, val_labels))
Epoch 1/100
90/90 [==============================] - 4s 21ms/step - loss: 0.0147 - tp: 64.8352 - fp: 45.5934 - tn: 139430.5055 - fn: 168.6374 - acc: 0.9985 - precision: 0.5985 - recall: 0.3220 - auc: 0.7673 - prc: 0.3309 - val_loss: 0.0073 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 45490.0000 - val_fn: 79.0000 - val_acc: 0.9983 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_auc: 0.8791 - val_prc: 0.6609
Epoch 2/100
90/90 [==============================] - 1s 11ms/step - loss: 0.0078 - tp: 48.4505 - fp: 13.4286 - tn: 93966.1978 - fn: 112.4945 - acc: 0.9987 - precision: 0.8123 - recall: 0.3008 - auc: 0.8303 - prc: 0.4306 - val_loss: 0.0051 - val_tp: 29.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 50.0000 - val_acc: 0.9988 - val_precision: 0.8286 - val_recall: 0.3671 - val_auc: 0.8922 - val_prc: 0.6677
Epoch 3/100
90/90 [==============================] - 1s 11ms/step - loss: 0.0081 - tp: 65.3846 - fp: 14.6813 - tn: 93958.2637 - fn: 102.2418 - acc: 0.9986 - precision: 0.7926 - recall: 0.3719 - auc: 0.8072 - prc: 0.4533 - val_loss: 0.0045 - val_tp: 44.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 35.0000 - val_acc: 0.9991 - val_precision: 0.8800 - val_recall: 0.5570 - val_auc: 0.9048 - val_prc: 0.6881
Epoch 4/100
90/90 [==============================] - 1s 11ms/step - loss: 0.0068 - tp: 71.2308 - fp: 15.0110 - tn: 93965.9011 - fn: 88.4286 - acc: 0.9989 - precision: 0.8163 - recall: 0.4337 - auc: 0.8344 - prc: 0.5153 - val_loss: 0.0042 - val_tp: 47.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 32.0000 - val_acc: 0.9992 - val_precision: 0.8868 - val_recall: 0.5949 - val_auc: 0.9175 - val_prc: 0.7249
Epoch 5/100
90/90 [==============================] - 1s 11ms/step - loss: 0.0057 - tp: 76.8901 - fp: 19.1648 - tn: 93963.4286 - fn: 81.0879 - acc: 0.9989 - precision: 0.7970 - recall: 0.4733 - auc: 0.8757 - prc: 0.5857 - val_loss: 0.0040 - val_tp: 48.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 31.0000 - val_acc: 0.9992 - val_precision: 0.8889 - val_recall: 0.6076 - val_auc: 0.9175 - val_prc: 0.7392
Epoch 6/100
90/90 [==============================] - 1s 10ms/step - loss: 0.0064 - tp: 74.9560 - fp: 15.1319 - tn: 93958.1099 - fn: 92.3736 - acc: 0.9988 - precision: 0.8518 - recall: 0.4300 - auc: 0.8618 - prc: 0.5888 - val_loss: 0.0039 - val_tp: 52.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 27.0000 - val_acc: 0.9993 - val_precision: 0.8966 - val_recall: 0.6582 - val_auc: 0.9174 - val_prc: 0.7480
Epoch 7/100
90/90 [==============================] - 1s 10ms/step - loss: 0.0047 - tp: 88.3846 - fp: 12.4615 - tn: 93968.0549 - fn: 71.6703 - acc: 0.9991 - precision: 0.8904 - recall: 0.5613 - auc: 0.9146 - prc: 0.7047 - val_loss: 0.0038 - val_tp: 54.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 25.0000 - val_acc: 0.9993 - val_precision: 0.9000 - val_recall: 0.6835 - val_auc: 0.9174 - val_prc: 0.7568
Epoch 8/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0043 - tp: 95.3297 - fp: 19.2418 - tn: 93959.6923 - fn: 66.3077 - acc: 0.9991 - precision: 0.8432 - recall: 0.6045 - auc: 0.9324 - prc: 0.7021 - val_loss: 0.0037 - val_tp: 52.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 27.0000 - val_acc: 0.9993 - val_precision: 0.8966 - val_recall: 0.6582 - val_auc: 0.9174 - val_prc: 0.7711
Epoch 9/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0049 - tp: 94.0220 - fp: 11.9890 - tn: 93963.3736 - fn: 71.1868 - acc: 0.9991 - precision: 0.9108 - recall: 0.5838 - auc: 0.8996 - prc: 0.7061 - val_loss: 0.0036 - val_tp: 53.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 26.0000 - val_acc: 0.9993 - val_precision: 0.8983 - val_recall: 0.6709 - val_auc: 0.9175 - val_prc: 0.7826
Epoch 10/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0045 - tp: 91.1209 - fp: 20.6923 - tn: 93955.2747 - fn: 73.4835 - acc: 0.9990 - precision: 0.8317 - recall: 0.5825 - auc: 0.9340 - prc: 0.6933 - val_loss: 0.0036 - val_tp: 53.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 26.0000 - val_acc: 0.9993 - val_precision: 0.8983 - val_recall: 0.6709 - val_auc: 0.9175 - val_prc: 0.7929
Epoch 11/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0045 - tp: 84.9231 - fp: 18.3956 - tn: 93962.7363 - fn: 74.5165 - acc: 0.9990 - precision: 0.8351 - recall: 0.5411 - auc: 0.9233 - prc: 0.6955 - val_loss: 0.0035 - val_tp: 55.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 24.0000 - val_acc: 0.9993 - val_precision: 0.9016 - val_recall: 0.6962 - val_auc: 0.9238 - val_prc: 0.7973
Epoch 12/100
90/90 [==============================] - 1s 10ms/step - loss: 0.0046 - tp: 101.3626 - fp: 17.9341 - tn: 93954.2088 - fn: 67.0659 - acc: 0.9991 - precision: 0.8479 - recall: 0.6011 - auc: 0.9329 - prc: 0.7196 - val_loss: 0.0035 - val_tp: 55.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 24.0000 - val_acc: 0.9993 - val_precision: 0.9016 - val_recall: 0.6962 - val_auc: 0.9175 - val_prc: 0.7939
Epoch 13/100
90/90 [==============================] - 1s 11ms/step - loss: 0.0050 - tp: 94.1099 - fp: 18.0110 - tn: 93957.0110 - fn: 71.4396 - acc: 0.9990 - precision: 0.8307 - recall: 0.5390 - auc: 0.9075 - prc: 0.6494 - val_loss: 0.0035 - val_tp: 56.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 23.0000 - val_acc: 0.9994 - val_precision: 0.9032 - val_recall: 0.7089 - val_auc: 0.9238 - val_prc: 0.7989
Epoch 14/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0047 - tp: 89.5495 - fp: 17.6264 - tn: 93958.5604 - fn: 74.8352 - acc: 0.9990 - precision: 0.8187 - recall: 0.5544 - auc: 0.9229 - prc: 0.6758 - val_loss: 0.0035 - val_tp: 58.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 21.0000 - val_acc: 0.9994 - val_precision: 0.9062 - val_recall: 0.7342 - val_auc: 0.9238 - val_prc: 0.7989
Epoch 15/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0043 - tp: 87.0000 - fp: 16.5934 - tn: 93963.8571 - fn: 73.1209 - acc: 0.9991 - precision: 0.8336 - recall: 0.5472 - auc: 0.9099 - prc: 0.6537 - val_loss: 0.0035 - val_tp: 58.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 21.0000 - val_acc: 0.9994 - val_precision: 0.9062 - val_recall: 0.7342 - val_auc: 0.9238 - val_prc: 0.7999
Epoch 16/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0040 - tp: 93.5275 - fp: 14.4725 - tn: 93967.1099 - fn: 65.4615 - acc: 0.9992 - precision: 0.8623 - recall: 0.5874 - auc: 0.9295 - prc: 0.7207 - val_loss: 0.0034 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 45485.0000 - val_fn: 21.0000 - val_acc: 0.9994 - val_precision: 0.9206 - val_recall: 0.7342 - val_auc: 0.9238 - val_prc: 0.8039
Epoch 17/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0044 - tp: 87.9451 - fp: 15.0000 - tn: 93970.7033 - fn: 66.9231 - acc: 0.9991 - precision: 0.8475 - recall: 0.5557 - auc: 0.9207 - prc: 0.6808 - val_loss: 0.0034 - val_tp: 59.0000 - val_fp: 5.0000 - val_tn: 45485.0000 - val_fn: 20.0000 - val_acc: 0.9995 - val_precision: 0.9219 - val_recall: 0.7468 - val_auc: 0.9238 - val_prc: 0.8054
Epoch 18/100
90/90 [==============================] - 1s 11ms/step - loss: 0.0042 - tp: 100.8462 - fp: 18.8791 - tn: 93962.7253 - fn: 58.1209 - acc: 0.9991 - precision: 0.8238 - recall: 0.6282 - auc: 0.9299 - prc: 0.7125 - val_loss: 0.0034 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 45485.0000 - val_fn: 21.0000 - val_acc: 0.9994 - val_precision: 0.9206 - val_recall: 0.7342 - val_auc: 0.9238 - val_prc: 0.8091
Epoch 19/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0038 - tp: 91.3956 - fp: 16.1758 - tn: 93965.0989 - fn: 67.9011 - acc: 0.9991 - precision: 0.8653 - recall: 0.5567 - auc: 0.9329 - prc: 0.7471 - val_loss: 0.0034 - val_tp: 59.0000 - val_fp: 5.0000 - val_tn: 45485.0000 - val_fn: 20.0000 - val_acc: 0.9995 - val_precision: 0.9219 - val_recall: 0.7468 - val_auc: 0.9238 - val_prc: 0.8081
Epoch 20/100
90/90 [==============================] - 1s 10ms/step - loss: 0.0035 - tp: 101.1978 - fp: 20.1209 - tn: 93958.4615 - fn: 60.7912 - acc: 0.9992 - precision: 0.8429 - recall: 0.6451 - auc: 0.9545 - prc: 0.7794 - val_loss: 0.0034 - val_tp: 59.0000 - val_fp: 5.0000 - val_tn: 45485.0000 - val_fn: 20.0000 - val_acc: 0.9995 - val_precision: 0.9219 - val_recall: 0.7468 - val_auc: 0.9238 - val_prc: 0.8084
Epoch 21/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0033 - tp: 102.5495 - fp: 16.1868 - tn: 93968.0000 - fn: 53.8352 - acc: 0.9993 - precision: 0.8732 - recall: 0.6806 - auc: 0.9385 - prc: 0.7763 - val_loss: 0.0034 - val_tp: 52.0000 - val_fp: 5.0000 - val_tn: 45485.0000 - val_fn: 27.0000 - val_acc: 0.9993 - val_precision: 0.9123 - val_recall: 0.6582 - val_auc: 0.9238 - val_prc: 0.8112
Epoch 22/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0044 - tp: 89.4615 - fp: 16.7363 - tn: 93964.2857 - fn: 70.0879 - acc: 0.9991 - precision: 0.8323 - recall: 0.5465 - auc: 0.9189 - prc: 0.6915 - val_loss: 0.0033 - val_tp: 55.0000 - val_fp: 5.0000 - val_tn: 45485.0000 - val_fn: 24.0000 - val_acc: 0.9994 - val_precision: 0.9167 - val_recall: 0.6962 - val_auc: 0.9238 - val_prc: 0.8102
Epoch 23/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0038 - tp: 98.5385 - fp: 15.2198 - tn: 93966.8791 - fn: 59.9341 - acc: 0.9992 - precision: 0.8713 - recall: 0.6000 - auc: 0.9229 - prc: 0.7481 - val_loss: 0.0033 - val_tp: 55.0000 - val_fp: 5.0000 - val_tn: 45485.0000 - val_fn: 24.0000 - val_acc: 0.9994 - val_precision: 0.9167 - val_recall: 0.6962 - val_auc: 0.9237 - val_prc: 0.8129
Epoch 24/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0039 - tp: 97.1868 - fp: 18.4505 - tn: 93961.9890 - fn: 62.9451 - acc: 0.9992 - precision: 0.8358 - recall: 0.6341 - auc: 0.9337 - prc: 0.7182 - val_loss: 0.0034 - val_tp: 54.0000 - val_fp: 1.0000 - val_tn: 45489.0000 - val_fn: 25.0000 - val_acc: 0.9994 - val_precision: 0.9818 - val_recall: 0.6835 - val_auc: 0.9238 - val_prc: 0.8156
Epoch 25/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0042 - tp: 99.9231 - fp: 16.0769 - tn: 93957.1319 - fn: 67.4396 - acc: 0.9991 - precision: 0.8482 - recall: 0.6010 - auc: 0.9216 - prc: 0.7227 - val_loss: 0.0033 - val_tp: 59.0000 - val_fp: 5.0000 - val_tn: 45485.0000 - val_fn: 20.0000 - val_acc: 0.9995 - val_precision: 0.9219 - val_recall: 0.7468 - val_auc: 0.9238 - val_prc: 0.8132
Epoch 26/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0037 - tp: 103.0769 - fp: 16.0110 - tn: 93956.2747 - fn: 65.2088 - acc: 0.9992 - precision: 0.8841 - recall: 0.6262 - auc: 0.9341 - prc: 0.7564 - val_loss: 0.0033 - val_tp: 59.0000 - val_fp: 5.0000 - val_tn: 45485.0000 - val_fn: 20.0000 - val_acc: 0.9995 - val_precision: 0.9219 - val_recall: 0.7468 - val_auc: 0.9238 - val_prc: 0.8134
Epoch 27/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0038 - tp: 98.8022 - fp: 15.9011 - tn: 93962.5604 - fn: 63.3077 - acc: 0.9992 - precision: 0.8723 - recall: 0.6362 - auc: 0.9365 - prc: 0.7486 - val_loss: 0.0033 - val_tp: 59.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 20.0000 - val_acc: 0.9994 - val_precision: 0.9077 - val_recall: 0.7468 - val_auc: 0.9238 - val_prc: 0.8124
Epoch 28/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0040 - tp: 98.9451 - fp: 15.8022 - tn: 93965.5714 - fn: 60.2527 - acc: 0.9992 - precision: 0.8606 - recall: 0.6430 - auc: 0.9241 - prc: 0.7183 - val_loss: 0.0033 - val_tp: 59.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 20.0000 - val_acc: 0.9994 - val_precision: 0.9077 - val_recall: 0.7468 - val_auc: 0.9238 - val_prc: 0.8148
Epoch 29/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0036 - tp: 104.4286 - fp: 13.1648 - tn: 93965.1429 - fn: 57.8352 - acc: 0.9993 - precision: 0.8989 - recall: 0.6587 - auc: 0.9396 - prc: 0.7463 - val_loss: 0.0033 - val_tp: 58.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 21.0000 - val_acc: 0.9994 - val_precision: 0.9062 - val_recall: 0.7342 - val_auc: 0.9238 - val_prc: 0.8150
Epoch 30/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0039 - tp: 104.1319 - fp: 12.7033 - tn: 93967.4066 - fn: 56.3297 - acc: 0.9992 - precision: 0.8847 - recall: 0.6371 - auc: 0.9197 - prc: 0.7073 - val_loss: 0.0033 - val_tp: 55.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 24.0000 - val_acc: 0.9993 - val_precision: 0.9016 - val_recall: 0.6962 - val_auc: 0.9238 - val_prc: 0.8163
Epoch 31/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0037 - tp: 91.6154 - fp: 17.1648 - tn: 93966.5495 - fn: 65.2418 - acc: 0.9992 - precision: 0.8379 - recall: 0.5814 - auc: 0.9231 - prc: 0.7264 - val_loss: 0.0033 - val_tp: 55.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 24.0000 - val_acc: 0.9993 - val_precision: 0.9016 - val_recall: 0.6962 - val_auc: 0.9238 - val_prc: 0.8168
Epoch 32/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0038 - tp: 98.9231 - fp: 14.8681 - tn: 93962.0440 - fn: 64.7363 - acc: 0.9991 - precision: 0.8635 - recall: 0.5902 - auc: 0.9405 - prc: 0.7401 - val_loss: 0.0033 - val_tp: 56.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 23.0000 - val_acc: 0.9994 - val_precision: 0.9032 - val_recall: 0.7089 - val_auc: 0.9238 - val_prc: 0.8164
Epoch 33/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0037 - tp: 102.0549 - fp: 16.8242 - tn: 93959.9231 - fn: 61.7692 - acc: 0.9992 - precision: 0.8656 - recall: 0.6255 - auc: 0.9333 - prc: 0.7486 - val_loss: 0.0034 - val_tp: 54.0000 - val_fp: 2.0000 - val_tn: 45488.0000 - val_fn: 25.0000 - val_acc: 0.9994 - val_precision: 0.9643 - val_recall: 0.6835 - val_auc: 0.9238 - val_prc: 0.8174
Epoch 34/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0030 - tp: 106.4396 - fp: 10.6593 - tn: 93972.9231 - fn: 50.5495 - acc: 0.9994 - precision: 0.9169 - recall: 0.6839 - auc: 0.9287 - prc: 0.7915 - val_loss: 0.0034 - val_tp: 55.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 24.0000 - val_acc: 0.9993 - val_precision: 0.9016 - val_recall: 0.6962 - val_auc: 0.9238 - val_prc: 0.8158
Epoch 35/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0039 - tp: 96.9780 - fp: 15.2857 - tn: 93964.0220 - fn: 64.2857 - acc: 0.9992 - precision: 0.8767 - recall: 0.5986 - auc: 0.9371 - prc: 0.7327 - val_loss: 0.0034 - val_tp: 55.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 24.0000 - val_acc: 0.9993 - val_precision: 0.9016 - val_recall: 0.6962 - val_auc: 0.9238 - val_prc: 0.8157
Epoch 36/100
90/90 [==============================] - 1s 10ms/step - loss: 0.0039 - tp: 97.0440 - fp: 16.4725 - tn: 93962.8242 - fn: 64.2308 - acc: 0.9992 - precision: 0.8561 - recall: 0.6162 - auc: 0.9231 - prc: 0.7042 - val_loss: 0.0034 - val_tp: 53.0000 - val_fp: 2.0000 - val_tn: 45488.0000 - val_fn: 26.0000 - val_acc: 0.9994 - val_precision: 0.9636 - val_recall: 0.6709 - val_auc: 0.9238 - val_prc: 0.8161
Epoch 37/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0040 - tp: 98.1978 - fp: 16.4176 - tn: 93961.2747 - fn: 64.6813 - acc: 0.9992 - precision: 0.8608 - recall: 0.6091 - auc: 0.9226 - prc: 0.7067 - val_loss: 0.0034 - val_tp: 54.0000 - val_fp: 3.0000 - val_tn: 45487.0000 - val_fn: 25.0000 - val_acc: 0.9994 - val_precision: 0.9474 - val_recall: 0.6835 - val_auc: 0.9238 - val_prc: 0.8160
Epoch 38/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0036 - tp: 97.2198 - fp: 15.0879 - tn: 93970.4835 - fn: 57.7802 - acc: 0.9992 - precision: 0.8639 - recall: 0.6325 - auc: 0.9437 - prc: 0.7358 - val_loss: 0.0034 - val_tp: 55.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 24.0000 - val_acc: 0.9993 - val_precision: 0.9016 - val_recall: 0.6962 - val_auc: 0.9238 - val_prc: 0.8155
Epoch 39/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0041 - tp: 107.5385 - fp: 18.5604 - tn: 93953.8352 - fn: 60.6374 - acc: 0.9991 - precision: 0.8364 - recall: 0.6481 - auc: 0.9290 - prc: 0.7210 - val_loss: 0.0034 - val_tp: 53.0000 - val_fp: 2.0000 - val_tn: 45488.0000 - val_fn: 26.0000 - val_acc: 0.9994 - val_precision: 0.9636 - val_recall: 0.6709 - val_auc: 0.9238 - val_prc: 0.8160
Epoch 40/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0041 - tp: 96.3846 - fp: 16.7143 - tn: 93964.8352 - fn: 62.6374 - acc: 0.9992 - precision: 0.8532 - recall: 0.6206 - auc: 0.9248 - prc: 0.6902 - val_loss: 0.0034 - val_tp: 53.0000 - val_fp: 2.0000 - val_tn: 45488.0000 - val_fn: 26.0000 - val_acc: 0.9994 - val_precision: 0.9636 - val_recall: 0.6709 - val_auc: 0.9238 - val_prc: 0.8154
Epoch 41/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0037 - tp: 103.1758 - fp: 14.4505 - tn: 93961.2198 - fn: 61.7253 - acc: 0.9992 - precision: 0.8754 - recall: 0.6197 - auc: 0.9334 - prc: 0.7572 - val_loss: 0.0034 - val_tp: 53.0000 - val_fp: 2.0000 - val_tn: 45488.0000 - val_fn: 26.0000 - val_acc: 0.9994 - val_precision: 0.9636 - val_recall: 0.6709 - val_auc: 0.9238 - val_prc: 0.8157
Epoch 42/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0034 - tp: 90.2088 - fp: 15.3846 - tn: 93967.3516 - fn: 67.6264 - acc: 0.9991 - precision: 0.8634 - recall: 0.5541 - auc: 0.9456 - prc: 0.7623 - val_loss: 0.0034 - val_tp: 53.0000 - val_fp: 2.0000 - val_tn: 45488.0000 - val_fn: 26.0000 - val_acc: 0.9994 - val_precision: 0.9636 - val_recall: 0.6709 - val_auc: 0.9238 - val_prc: 0.8152
Epoch 43/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0040 - tp: 96.2308 - fp: 21.1209 - tn: 93957.6813 - fn: 65.5385 - acc: 0.9990 - precision: 0.8026 - recall: 0.5805 - auc: 0.9326 - prc: 0.7090 - val_loss: 0.0034 - val_tp: 53.0000 - val_fp: 2.0000 - val_tn: 45488.0000 - val_fn: 26.0000 - val_acc: 0.9994 - val_precision: 0.9636 - val_recall: 0.6709 - val_auc: 0.9238 - val_prc: 0.8150
Restoring model weights from the end of the best epoch.
Epoch 00043: early stopping

Plot the training history

def func_plot_metrics(history):
    metrics = ['loss', 'prc', 'precision', 'recall']
    for n, metric in enumerate(metrics):
        name = metric.replace("_", " ").capitalize()
        plt.subplot(2, 2, n+1)
        plt.plot(history.epoch, history.history[metric], color='blue', label='Train')
        plt.plot(history.epoch, history.history['val_'+metric], color='blue', linestyle="--", label='Val')
        plt.xlabel('Epoch')
        plt.ylabel(name)
        if metric != 'loss':
            plt.ylim([0, 1])  # loss keeps its natural scale
        plt.legend()
    plt.tight_layout()
func_plot_metrics(baseline_history)

[Figure: baseline model training history — loss, PR AUC, precision, recall]

Inspect the confusion matrix

train_predictions_baseline = model_baseline.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_baseline = model_baseline.predict(test_features, batch_size=BATCH_SIZE)
def func_plot_cm(labels, predictions, p=0.5):
    cm = confusion_matrix(labels, predictions > p)
    plt.figure(figsize=(5,5))
    sns.heatmap(cm, annot=True, fmt="d")
    plt.title('Confusion matrix @{:.2f}'.format(p))
    plt.ylabel('Actual label')
    plt.xlabel('Predicted label')

    print('Legitimate Transactions Detected (True Negatives): ', cm[0][0])
    print('Legitimate Transactions Incorrectly Detected (False Positives): ', cm[0][1])
    print('Fraudulent Transactions Missed (False Negatives): ', cm[1][0])
    print('Fraudulent Transactions Detected (True Positives): ', cm[1][1])
    print('Total Fraudulent Transactions: ', np.sum(cm[1]))
baseline_results = model_baseline.evaluate(test_features, test_labels, batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model_baseline.metrics_names, baseline_results):
    print(name, ': ', value)
print()

func_plot_cm(test_labels, test_predictions_baseline)
loss :  0.002730185631662607
tp :  72.0
fp :  9.0
tn :  56855.0
fn :  26.0
acc :  0.9993855357170105
precision :  0.8888888955116272
recall :  0.7346938848495483
auc :  0.9487220048904419
prc :  0.8363494873046875

Legitimate Transactions Detected (True Negatives):  56855
Legitimate Transactions Incorrectly Detected (False Positives):  9
Fraudulent Transactions Missed (False Negatives):  26
Fraudulent Transactions Detected (True Positives):  72
Total Fraudulent Transactions:  98

[Figure: baseline confusion matrix on the test set at threshold 0.5]
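The 0.5 threshold above is not sacred. If a missed fraud costs far more than a false alarm, the threshold can be lowered to trade precision for recall; an added illustration reusing the same helper:

func_plot_cm(test_labels, test_predictions_baseline, p=0.1)  # more recall, less precision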

Plot the ROC curve

def func_plot_roc(name, labels, predictions, **kwargs):
    fp, tp, _ = sklearn.metrics.roc_curve(labels, predictions)

    plt.plot(100*fp, 100*tp, label=name, linewidth=2, **kwargs)
    plt.xlabel('False positives [%]')
    plt.ylabel('True positives [%]')
    plt.xlim([-0.5,20])
    plt.ylim([80,100.5])
    plt.grid(True)
    ax = plt.gca()
    ax.set_aspect('equal')
func_plot_roc("Train Baseline", train_labels, train_predictions_baseline, color='red')
func_plot_roc("Test Baseline", test_labels, test_predictions_baseline, color='blue', linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7f2cc0264550>

[Figure: ROC curves for the baseline model, train vs. test]

Plot the precision-recall curve

def func_plot_prc(name, labels, predictions, **kwargs):
    precision, recall, _ = sklearn.metrics.precision_recall_curve(labels, predictions)

    # plot recall on the x-axis and precision on the y-axis to match the labels
    plt.plot(recall, precision, label=name, linewidth=2, **kwargs)
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.grid(True)
    ax = plt.gca()
    ax.set_aspect('equal')
func_plot_prc("Train Baseline", train_labels, train_predictions_baseline, color='red')
func_plot_prc("Test Baseline", test_labels, test_predictions_baseline, color='blue', linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7f2cc00d0e20>

[Figure: precision-recall curves for the baseline model, train vs. test]

model_1: add class_weight

Compute the class weights

weight_for_0 = (1 / neg) * (total / 2.0)  # total / (2*neg)
weight_for_1 = (1 / pos) * (total / 2.0)  # total / (2*pos)

class_weight = {
    0: weight_for_0,
    1: weight_for_1
}

class_weight
{0: 0.5008652375006595, 1: 289.43800813008136}
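What class_weight does under the hood: each example's loss term is multiplied by the weight of its true class, so the rare positives contribute roughly as much total loss as the negatives. A minimal NumPy sketch of the weighted binary cross-entropy (illustrative; weighted_bce is not a Keras function):

def weighted_bce(y_true, y_prob, class_weight, eps=1e-7):
    y_prob = np.clip(y_prob, eps, 1 - eps)
    per_example = -(y_true * np.log(y_prob) + (1 - y_true) * np.log(1 - y_prob))
    weights = np.where(y_true == 1, class_weight[1], class_weight[0])
    return np.mean(weights * per_example)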

train a model with class weights

model_weight = func_make_model(n_x=train_features.shape[-1], metrics=METRICS)
model_weight.load_weights(correct_initial_bias_path)
model_weight_history = model_weight.fit(train_features, train_labels, 
                                        batch_size=BATCH_SIZE, epochs=EPOCHS, 
                                        callbacks=[early_stopping], 
                                        validation_data=(val_features, val_labels),
                                        # the class weights go here
                                        class_weight=class_weight)
Epoch 1/100
90/90 [==============================] - 4s 20ms/step - loss: 3.3108 - tp: 85.5495 - fp: 103.0659 - tn: 150735.5495 - fn: 178.4066 - acc: 0.9983 - precision: 0.5085 - recall: 0.3590 - auc: 0.7769 - prc: 0.3236 - val_loss: 0.0081 - val_tp: 10.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 69.0000 - val_acc: 0.9984 - val_precision: 0.6250 - val_recall: 0.1266 - val_auc: 0.9132 - val_prc: 0.3528
Epoch 2/100
90/90 [==============================] - 1s 9ms/step - loss: 1.5420 - tp: 58.9451 - fp: 177.4945 - tn: 93802.3516 - fn: 101.7802 - acc: 0.9972 - precision: 0.2352 - recall: 0.3316 - auc: 0.7668 - prc: 0.1856 - val_loss: 0.0093 - val_tp: 56.0000 - val_fp: 32.0000 - val_tn: 45458.0000 - val_fn: 23.0000 - val_acc: 0.9988 - val_precision: 0.6364 - val_recall: 0.7089 - val_auc: 0.9470 - val_prc: 0.5701
Epoch 3/100
90/90 [==============================] - 1s 8ms/step - loss: 0.8803 - tp: 92.3956 - fp: 448.8462 - tn: 93533.2418 - fn: 66.0879 - acc: 0.9947 - precision: 0.1718 - recall: 0.5509 - auc: 0.8812 - prc: 0.2806 - val_loss: 0.0147 - val_tp: 62.0000 - val_fp: 94.0000 - val_tn: 45396.0000 - val_fn: 17.0000 - val_acc: 0.9976 - val_precision: 0.3974 - val_recall: 0.7848 - val_auc: 0.9680 - val_prc: 0.6804
Epoch 4/100
90/90 [==============================] - 1s 8ms/step - loss: 0.6487 - tp: 110.7582 - fp: 801.8242 - tn: 93174.9560 - fn: 53.0330 - acc: 0.9914 - precision: 0.1226 - recall: 0.6634 - auc: 0.9129 - prc: 0.2730 - val_loss: 0.0227 - val_tp: 66.0000 - val_fp: 166.0000 - val_tn: 45324.0000 - val_fn: 13.0000 - val_acc: 0.9961 - val_precision: 0.2845 - val_recall: 0.8354 - val_auc: 0.9691 - val_prc: 0.7008
Epoch 5/100
90/90 [==============================] - 1s 9ms/step - loss: 0.4749 - tp: 126.5165 - fp: 1314.5934 - tn: 92661.6813 - fn: 37.7802 - acc: 0.9859 - precision: 0.0907 - recall: 0.7836 - auc: 0.9293 - prc: 0.2892 - val_loss: 0.0308 - val_tp: 66.0000 - val_fp: 281.0000 - val_tn: 45209.0000 - val_fn: 13.0000 - val_acc: 0.9935 - val_precision: 0.1902 - val_recall: 0.8354 - val_auc: 0.9708 - val_prc: 0.6695
Epoch 6/100
90/90 [==============================] - 1s 9ms/step - loss: 0.4803 - tp: 128.5824 - fp: 1652.0330 - tn: 92324.8791 - fn: 35.0769 - acc: 0.9824 - precision: 0.0749 - recall: 0.7747 - auc: 0.9277 - prc: 0.2729 - val_loss: 0.0380 - val_tp: 66.0000 - val_fp: 448.0000 - val_tn: 45042.0000 - val_fn: 13.0000 - val_acc: 0.9899 - val_precision: 0.1284 - val_recall: 0.8354 - val_auc: 0.9745 - val_prc: 0.6359
Epoch 7/100
90/90 [==============================] - 1s 9ms/step - loss: 0.3692 - tp: 136.0110 - fp: 2004.2967 - tn: 91972.8242 - fn: 27.4396 - acc: 0.9787 - precision: 0.0646 - recall: 0.8400 - auc: 0.9357 - prc: 0.2538 - val_loss: 0.0446 - val_tp: 66.0000 - val_fp: 548.0000 - val_tn: 44942.0000 - val_fn: 13.0000 - val_acc: 0.9877 - val_precision: 0.1075 - val_recall: 0.8354 - val_auc: 0.9741 - val_prc: 0.6173
Epoch 8/100
90/90 [==============================] - 1s 8ms/step - loss: 0.3391 - tp: 133.3846 - fp: 2250.4615 - tn: 91731.2527 - fn: 25.4725 - acc: 0.9759 - precision: 0.0547 - recall: 0.8384 - auc: 0.9464 - prc: 0.2333 - val_loss: 0.0502 - val_tp: 66.0000 - val_fp: 602.0000 - val_tn: 44888.0000 - val_fn: 13.0000 - val_acc: 0.9865 - val_precision: 0.0988 - val_recall: 0.8354 - val_auc: 0.9740 - val_prc: 0.6335
Epoch 9/100
90/90 [==============================] - 1s 8ms/step - loss: 0.3891 - tp: 134.6374 - fp: 2574.4066 - tn: 91403.6154 - fn: 27.9121 - acc: 0.9727 - precision: 0.0507 - recall: 0.8210 - auc: 0.9373 - prc: 0.2371 - val_loss: 0.0550 - val_tp: 66.0000 - val_fp: 643.0000 - val_tn: 44847.0000 - val_fn: 13.0000 - val_acc: 0.9856 - val_precision: 0.0931 - val_recall: 0.8354 - val_auc: 0.9756 - val_prc: 0.6202
Epoch 10/100
90/90 [==============================] - 1s 8ms/step - loss: 0.3159 - tp: 137.5714 - fp: 2726.1538 - tn: 91248.8462 - fn: 28.0000 - acc: 0.9712 - precision: 0.0490 - recall: 0.8459 - auc: 0.9529 - prc: 0.2447 - val_loss: 0.0597 - val_tp: 67.0000 - val_fp: 707.0000 - val_tn: 44783.0000 - val_fn: 12.0000 - val_acc: 0.9842 - val_precision: 0.0866 - val_recall: 0.8481 - val_auc: 0.9754 - val_prc: 0.6138
Epoch 11/100
90/90 [==============================] - 1s 8ms/step - loss: 0.2898 - tp: 142.0110 - fp: 2837.5055 - tn: 91137.0440 - fn: 24.0110 - acc: 0.9695 - precision: 0.0480 - recall: 0.8655 - auc: 0.9555 - prc: 0.2209 - val_loss: 0.0630 - val_tp: 67.0000 - val_fp: 753.0000 - val_tn: 44737.0000 - val_fn: 12.0000 - val_acc: 0.9832 - val_precision: 0.0817 - val_recall: 0.8481 - val_auc: 0.9755 - val_prc: 0.6147
Epoch 12/100
90/90 [==============================] - 1s 8ms/step - loss: 0.2809 - tp: 134.1429 - fp: 2854.5824 - tn: 91132.2088 - fn: 19.6374 - acc: 0.9695 - precision: 0.0454 - recall: 0.8817 - auc: 0.9487 - prc: 0.2253 - val_loss: 0.0639 - val_tp: 67.0000 - val_fp: 770.0000 - val_tn: 44720.0000 - val_fn: 12.0000 - val_acc: 0.9828 - val_precision: 0.0800 - val_recall: 0.8481 - val_auc: 0.9754 - val_prc: 0.6218
Epoch 13/100
90/90 [==============================] - 1s 9ms/step - loss: 0.3396 - tp: 139.3956 - fp: 2952.3626 - tn: 91026.8462 - fn: 21.9670 - acc: 0.9683 - precision: 0.0439 - recall: 0.8461 - auc: 0.9395 - prc: 0.2182 - val_loss: 0.0630 - val_tp: 67.0000 - val_fp: 758.0000 - val_tn: 44732.0000 - val_fn: 12.0000 - val_acc: 0.9831 - val_precision: 0.0812 - val_recall: 0.8481 - val_auc: 0.9763 - val_prc: 0.6152
Epoch 14/100
90/90 [==============================] - 1s 8ms/step - loss: 0.2931 - tp: 133.0000 - fp: 2956.3736 - tn: 91029.8462 - fn: 21.3516 - acc: 0.9685 - precision: 0.0404 - recall: 0.8468 - auc: 0.9493 - prc: 0.2041 - val_loss: 0.0675 - val_tp: 67.0000 - val_fp: 811.0000 - val_tn: 44679.0000 - val_fn: 12.0000 - val_acc: 0.9819 - val_precision: 0.0763 - val_recall: 0.8481 - val_auc: 0.9769 - val_prc: 0.6114
Restoring model weights from the end of the best epoch.
Epoch 00014: early stopping

check training history

func_plot_metrics(model_weight_history)

[Figure: class-weighted model training history]

Inspect the confusion matrix

train_predictions_weight = model_weight.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_weight = model_weight.predict(test_features, batch_size=BATCH_SIZE)
results = model_weight.evaluate(test_features, test_labels, batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model_weight.metrics_names, results):
    print(name, ': ', value)
print()

func_plot_cm(test_labels, test_predictions_weight)
loss :  0.02326195128262043
tp :  83.0
fp :  210.0
tn :  56654.0
fn :  15.0
acc :  0.9960500001907349
precision :  0.28327643871307373
recall :  0.8469387888908386
auc :  0.969890296459198
prc :  0.7208744287490845

Legitimate Transactions Detected (True Negatives):  56654
Legitimate Transactions Incorrectly Detected (False Positives):  210
Fraudulent Transactions Missed (False Negatives):  15
Fraudulent Transactions Detected (True Positives):  83
Total Fraudulent Transactions:  98

[Figure: class-weighted model confusion matrix on the test set]

Plot the ROC curve

func_plot_roc("Train weight", train_labels, train_predictions_weight, color='red')
func_plot_roc("Test weight", test_labels, test_predictions_weight, color='blue', linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7f2c9c227dc0>

[Figure: ROC curves for the class-weighted model, train vs. test]

Plot the precision-recall curve

func_plot_prc("Train weight", train_labels, train_predictions_weight, color='red')
func_plot_prc("Test weight", test_labels, test_predictions_weight, color='blue', linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7f2c9c19fd90>

[Figure: precision-recall curves for the class-weighted model, train vs. test]

model_2: oversample the minority class

Oversampling

pos_features = train_features[bool_train_labels]
neg_features = train_features[~bool_train_labels]

pos_labels = train_labels[bool_train_labels]
neg_labels = train_labels[~bool_train_labels]

using numpy

ids = np.arange(len(pos_features))
choices = np.random.choice(ids, len(neg_features))  # Generates a random sample from a given 1-D array

res_pos_features = pos_features[choices]
res_pos_labels = pos_labels[choices]

res_pos_features.shape
(181961, 29)
resampled_features = np.concatenate([res_pos_features, neg_features], axis=0)
resampled_labels = np.concatenate([res_pos_labels, neg_labels], axis=0)

# shuffle the combined examples
order = np.arange(len(resampled_labels))
np.random.shuffle(order)
resampled_features = resampled_features[order]
resampled_labels = resampled_labels[order]

resampled_features.shape
(363922, 29)
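An equivalent oversampling with scikit-learn (an alternative sketch, not used in the rest of the notebook):

from sklearn.utils import resample

res_pos_features_alt, res_pos_labels_alt = resample(
    pos_features, pos_labels, replace=True,
    n_samples=len(neg_features), random_state=2030)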

using tf.data

def func_make_ds(features, labels):
    ds = tf.data.Dataset.from_tensor_slices((features, labels))
    ds = ds.shuffle(len(features)).repeat()  # repeat indefinitely
    return ds
pos_ds = func_make_ds(pos_features, pos_labels)
neg_ds = func_make_ds(neg_features, neg_labels)

for features, label in pos_ds.take(1):
    print(features, label)
tf.Tensor(
[-0.72499268  2.48494345 -5.          4.71183888 -3.39522677 -1.54142183
 -5.          2.32846081 -2.95406275 -5.          5.         -5.
  0.5493042  -5.          0.99578148 -5.         -5.         -5.
  3.88524349  1.80758865  2.51455711  0.55774627  0.9575459  -1.26789652
 -3.3464076   1.04065595  4.89460213  2.17526156 -1.45033267], shape=(29,), dtype=float64) tf.Tensor(1, shape=(), dtype=int64)
resampled_ds = tf.data.experimental.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])
resampled_ds = resampled_ds.batch(BATCH_SIZE).prefetch(tf.data.experimental.AUTOTUNE)

val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).cache()
val_ds = val_ds.batch(BATCH_SIZE).prefetch(tf.data.experimental.AUTOTUNE)
for features, label in resampled_ds.take(1):
    print(label.numpy().mean())
0.50537109375
# the resampled dataset repeats forever, so define one "epoch" manually:
# about 2*neg examples (all negatives plus an equal number of positives) per epoch
resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE)
resampled_steps_per_epoch
278.0

train a model on the oversampled data

model_sampled = func_make_model(n_x=train_features.shape[-1], metrics=METRICS)
model_sampled.load_weights(correct_initial_bias_path)

# reset the bias to zero, since this dataset is balanced
model_sampled.layers[-1].bias.assign([0.0])

model_sampled_history = model_sampled.fit(resampled_ds, epochs=EPOCHS, steps_per_epoch=resampled_steps_per_epoch,
                                          callbacks=[early_stopping], 
                                          validation_data=val_ds)
Epoch 1/100
278/278 [==============================] - 9s 25ms/step - loss: 0.7187 - tp: 106770.7025 - fp: 41735.0538 - tn: 158500.6093 - fn: 36668.2939 - acc: 0.7680 - precision: 0.6690 - recall: 0.6851 - auc: 0.8375 - prc: 0.7999 - val_loss: 0.1986 - val_tp: 67.0000 - val_fp: 1074.0000 - val_tn: 44416.0000 - val_fn: 12.0000 - val_acc: 0.9762 - val_precision: 0.0587 - val_recall: 0.8481 - val_auc: 0.9660 - val_prc: 0.7480
Epoch 2/100
278/278 [==============================] - 6s 23ms/step - loss: 0.2155 - tp: 128022.4910 - fp: 8284.8065 - tn: 135214.0609 - fn: 15191.3011 - acc: 0.9150 - precision: 0.9348 - recall: 0.8921 - auc: 0.9683 - prc: 0.9752 - val_loss: 0.1105 - val_tp: 67.0000 - val_fp: 704.0000 - val_tn: 44786.0000 - val_fn: 12.0000 - val_acc: 0.9843 - val_precision: 0.0869 - val_recall: 0.8481 - val_auc: 0.9679 - val_prc: 0.7478
Epoch 3/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1596 - tp: 129871.1828 - fp: 4884.9964 - tn: 138569.3584 - fn: 13387.1219 - acc: 0.9358 - precision: 0.9629 - recall: 0.9063 - auc: 0.9833 - prc: 0.9858 - val_loss: 0.0815 - val_tp: 68.0000 - val_fp: 640.0000 - val_tn: 44850.0000 - val_fn: 11.0000 - val_acc: 0.9857 - val_precision: 0.0960 - val_recall: 0.8608 - val_auc: 0.9694 - val_prc: 0.7165
Epoch 4/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1347 - tp: 131100.8423 - fp: 4062.0036 - tn: 139467.2903 - fn: 12082.5233 - acc: 0.9432 - precision: 0.9699 - recall: 0.9146 - auc: 0.9890 - prc: 0.9900 - val_loss: 0.0661 - val_tp: 68.0000 - val_fp: 571.0000 - val_tn: 44919.0000 - val_fn: 11.0000 - val_acc: 0.9872 - val_precision: 0.1064 - val_recall: 0.8608 - val_auc: 0.9694 - val_prc: 0.7182
Epoch 5/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1192 - tp: 132650.0143 - fp: 3900.9642 - tn: 139417.2115 - fn: 10744.4695 - acc: 0.9486 - precision: 0.9712 - recall: 0.9245 - auc: 0.9918 - prc: 0.9923 - val_loss: 0.0570 - val_tp: 67.0000 - val_fp: 569.0000 - val_tn: 44921.0000 - val_fn: 12.0000 - val_acc: 0.9873 - val_precision: 0.1053 - val_recall: 0.8481 - val_auc: 0.9670 - val_prc: 0.7172
Epoch 6/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1076 - tp: 133341.7706 - fp: 3745.2294 - tn: 139902.4552 - fn: 9723.2043 - acc: 0.9524 - precision: 0.9724 - recall: 0.9312 - auc: 0.9936 - prc: 0.9936 - val_loss: 0.0479 - val_tp: 67.0000 - val_fp: 508.0000 - val_tn: 44982.0000 - val_fn: 12.0000 - val_acc: 0.9886 - val_precision: 0.1165 - val_recall: 0.8481 - val_auc: 0.9682 - val_prc: 0.7194
Epoch 7/100
278/278 [==============================] - 6s 21ms/step - loss: 0.0995 - tp: 134289.4444 - fp: 3585.8961 - tn: 140065.9319 - fn: 8771.3871 - acc: 0.9566 - precision: 0.9739 - recall: 0.9381 - auc: 0.9947 - prc: 0.9946 - val_loss: 0.0433 - val_tp: 67.0000 - val_fp: 502.0000 - val_tn: 44988.0000 - val_fn: 12.0000 - val_acc: 0.9887 - val_precision: 0.1178 - val_recall: 0.8481 - val_auc: 0.9685 - val_prc: 0.7209
Epoch 8/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0906 - tp: 136117.1470 - fp: 3446.8566 - tn: 139792.3333 - fn: 7356.3226 - acc: 0.9623 - precision: 0.9752 - recall: 0.9490 - auc: 0.9956 - prc: 0.9955 - val_loss: 0.0391 - val_tp: 67.0000 - val_fp: 458.0000 - val_tn: 45032.0000 - val_fn: 12.0000 - val_acc: 0.9897 - val_precision: 0.1276 - val_recall: 0.8481 - val_auc: 0.9647 - val_prc: 0.7208
Epoch 9/100
278/278 [==============================] - 6s 21ms/step - loss: 0.0871 - tp: 136469.3692 - fp: 3419.2043 - tn: 139679.6810 - fn: 7144.4050 - acc: 0.9630 - precision: 0.9756 - recall: 0.9498 - auc: 0.9959 - prc: 0.9957 - val_loss: 0.0363 - val_tp: 67.0000 - val_fp: 448.0000 - val_tn: 45042.0000 - val_fn: 12.0000 - val_acc: 0.9899 - val_precision: 0.1301 - val_recall: 0.8481 - val_auc: 0.9607 - val_prc: 0.7211
Epoch 10/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0823 - tp: 136684.5591 - fp: 3467.4803 - tn: 140006.9319 - fn: 6553.6882 - acc: 0.9649 - precision: 0.9751 - recall: 0.9540 - auc: 0.9963 - prc: 0.9961 - val_loss: 0.0333 - val_tp: 67.0000 - val_fp: 437.0000 - val_tn: 45053.0000 - val_fn: 12.0000 - val_acc: 0.9901 - val_precision: 0.1329 - val_recall: 0.8481 - val_auc: 0.9611 - val_prc: 0.7214
Epoch 11/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0779 - tp: 137468.5520 - fp: 3409.0538 - tn: 139987.1111 - fn: 5847.9427 - acc: 0.9673 - precision: 0.9760 - recall: 0.9581 - auc: 0.9967 - prc: 0.9965 - val_loss: 0.0317 - val_tp: 67.0000 - val_fp: 436.0000 - val_tn: 45054.0000 - val_fn: 12.0000 - val_acc: 0.9902 - val_precision: 0.1332 - val_recall: 0.8481 - val_auc: 0.9515 - val_prc: 0.7120
Restoring model weights from the end of the best epoch.
Epoch 00011: early stopping

check training history

func_plot_metrics(model_sampled_history)

[Figure: oversampled model training history]

Inspect the confusion matrix

train_predictions_sampled = model_sampled.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_sampled = model_sampled.predict(test_features, batch_size=BATCH_SIZE)
results = model_sampled.evaluate(test_features, test_labels, batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model_sampled.metrics_names, results):
    print(name, ': ', value)
print()

func_plot_cm(test_labels, test_predictions_sampled)
loss :  0.19912180304527283
tp :  91.0
fp :  1298.0
tn :  55566.0
fn :  7.0
acc :  0.9770900011062622
precision :  0.06551475822925568
recall :  0.9285714030265808
auc :  0.9868985414505005
prc :  0.7712134718894958

Legitimate Transactions Detected (True Negatives):  55566
Legitimate Transactions Incorrectly Detected (False Positives):  1298
Fraudulent Transactions Missed (False Negatives):  7
Fraudulent Transactions Detected (True Positives):  91
Total Fraudulent Transactions:  98

[Figure: oversampled model confusion matrix on the test set]

Plot the ROC curve

func_plot_roc("Train sampled", train_labels, train_predictions_sampled, color='red')
func_plot_roc("Test sampled", test_labels, test_predictions_sampled, color='blue', linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7f2c5c561fd0>

[Figure: ROC curves for the oversampled model, train vs. test]

Plot the precision-recall curve

func_plot_prc("Train sampled", train_labels, train_predictions_sampled, color='red')
func_plot_prc("Test sampled", test_labels, test_predictions_sampled, color='blue', linestyle='--')
plt.legend(loc='lower right')
<matplotlib.legend.Legend at 0x7f2c5c7c57c0>

[Figure: precision-recall curves for the oversampled model, train vs. test]
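To compare the three models on one plot, the helpers above can simply be called together (an added convenience, reusing the plotting helper defined earlier):

func_plot_roc("Test Baseline", test_labels, test_predictions_baseline, color='blue', linestyle='--')
func_plot_roc("Test Weight", test_labels, test_predictions_weight, color='green', linestyle='--')
func_plot_roc("Test Sampled", test_labels, test_predictions_sampled, color='red', linestyle='--')
plt.legend(loc='lower right')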



