Learning TensorFlow: Classification on Imbalanced Data
Reference: the official TensorFlow tutorial "Classification on imbalanced data" (https://www.tensorflow.org/tutorials/structured_data/imbalanced_data).
Import modules
import tensorflow as tf
from tensorflow import keras
import os
import datetime
import tempfile
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
mpl.rcParams['figure.figsize'] = (12, 10)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
colors
['#1f77b4',
'#ff7f0e',
'#2ca02c',
'#d62728',
'#9467bd',
'#8c564b',
'#e377c2',
'#7f7f7f',
'#bcbd22',
'#17becf']
Load the data
train_path = './TensorFlow学习之不平衡数据的分类/creditcard.csv'
raw_data = pd.read_csv(train_path)
Data exploration
raw_data.head()
| | Time | V1 | V2 | V3 | V4 | V5 | V6 | V7 | V8 | V9 | ... | V21 | V22 | V23 | V24 | V25 | V26 | V27 | V28 | Amount | Class |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.0 | -1.359807 | -0.072781 | 2.536347 | 1.378155 | -0.338321 | 0.462388 | 0.239599 | 0.098698 | 0.363787 | ... | -0.018307 | 0.277838 | -0.110474 | 0.066928 | 0.128539 | -0.189115 | 0.133558 | -0.021053 | 149.62 | 0 |
| 1 | 0.0 | 1.191857 | 0.266151 | 0.166480 | 0.448154 | 0.060018 | -0.082361 | -0.078803 | 0.085102 | -0.255425 | ... | -0.225775 | -0.638672 | 0.101288 | -0.339846 | 0.167170 | 0.125895 | -0.008983 | 0.014724 | 2.69 | 0 |
| 2 | 1.0 | -1.358354 | -1.340163 | 1.773209 | 0.379780 | -0.503198 | 1.800499 | 0.791461 | 0.247676 | -1.514654 | ... | 0.247998 | 0.771679 | 0.909412 | -0.689281 | -0.327642 | -0.139097 | -0.055353 | -0.059752 | 378.66 | 0 |
| 3 | 1.0 | -0.966272 | -0.185226 | 1.792993 | -0.863291 | -0.010309 | 1.247203 | 0.237609 | 0.377436 | -1.387024 | ... | -0.108300 | 0.005274 | -0.190321 | -1.175575 | 0.647376 | -0.221929 | 0.062723 | 0.061458 | 123.50 | 0 |
| 4 | 2.0 | -1.158233 | 0.877737 | 1.548718 | 0.403034 | -0.407193 | 0.095921 | 0.592941 | -0.270533 | 0.817739 | ... | -0.009431 | 0.798278 | -0.137458 | 0.141267 | -0.206010 | 0.502292 | 0.219422 | 0.215153 | 69.99 | 0 |

5 rows × 31 columns
raw_data.info(show_counts=True)
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 284807 entries, 0 to 284806
Data columns (total 31 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 Time 284807 non-null float64
1 V1 284807 non-null float64
2 V2 284807 non-null float64
3 V3 284807 non-null float64
4 V4 284807 non-null float64
5 V5 284807 non-null float64
6 V6 284807 non-null float64
7 V7 284807 non-null float64
8 V8 284807 non-null float64
9 V9 284807 non-null float64
10 V10 284807 non-null float64
11 V11 284807 non-null float64
12 V12 284807 non-null float64
13 V13 284807 non-null float64
14 V14 284807 non-null float64
15 V15 284807 non-null float64
16 V16 284807 non-null float64
17 V17 284807 non-null float64
18 V18 284807 non-null float64
19 V19 284807 non-null float64
20 V20 284807 non-null float64
21 V21 284807 non-null float64
22 V22 284807 non-null float64
23 V23 284807 non-null float64
24 V24 284807 non-null float64
25 V25 284807 non-null float64
26 V26 284807 non-null float64
27 V27 284807 non-null float64
28 V28 284807 non-null float64
29 Amount 284807 non-null float64
30 Class 284807 non-null int64
dtypes: float64(30), int64(1)
memory usage: 67.4 MB
raw_data.describe(include='all').T
| | count | mean | std | min | 25% | 50% | 75% | max |
|---|---|---|---|---|---|---|---|---|
| Time | 284807.0 | 9.481386e+04 | 47488.145955 | 0.000000 | 54201.500000 | 84692.000000 | 139320.500000 | 172792.000000 |
| V1 | 284807.0 | 1.168375e-15 | 1.958696 | -56.407510 | -0.920373 | 0.018109 | 1.315642 | 2.454930 |
| V2 | 284807.0 | 3.416908e-16 | 1.651309 | -72.715728 | -0.598550 | 0.065486 | 0.803724 | 22.057729 |
| V3 | 284807.0 | -1.379537e-15 | 1.516255 | -48.325589 | -0.890365 | 0.179846 | 1.027196 | 9.382558 |
| V4 | 284807.0 | 2.074095e-15 | 1.415869 | -5.683171 | -0.848640 | -0.019847 | 0.743341 | 16.875344 |
| V5 | 284807.0 | 9.604066e-16 | 1.380247 | -113.743307 | -0.691597 | -0.054336 | 0.611926 | 34.801666 |
| V6 | 284807.0 | 1.487313e-15 | 1.332271 | -26.160506 | -0.768296 | -0.274187 | 0.398565 | 73.301626 |
| V7 | 284807.0 | -5.556467e-16 | 1.237094 | -43.557242 | -0.554076 | 0.040103 | 0.570436 | 120.589494 |
| V8 | 284807.0 | 1.213481e-16 | 1.194353 | -73.216718 | -0.208630 | 0.022358 | 0.327346 | 20.007208 |
| V9 | 284807.0 | -2.406331e-15 | 1.098632 | -13.434066 | -0.643098 | -0.051429 | 0.597139 | 15.594995 |
| V10 | 284807.0 | 2.239053e-15 | 1.088850 | -24.588262 | -0.535426 | -0.092917 | 0.453923 | 23.745136 |
| V11 | 284807.0 | 1.673327e-15 | 1.020713 | -4.797473 | -0.762494 | -0.032757 | 0.739593 | 12.018913 |
| V12 | 284807.0 | -1.247012e-15 | 0.999201 | -18.683715 | -0.405571 | 0.140033 | 0.618238 | 7.848392 |
| V13 | 284807.0 | 8.190001e-16 | 0.995274 | -5.791881 | -0.648539 | -0.013568 | 0.662505 | 7.126883 |
| V14 | 284807.0 | 1.207294e-15 | 0.958596 | -19.214325 | -0.425574 | 0.050601 | 0.493150 | 10.526766 |
| V15 | 284807.0 | 4.887456e-15 | 0.915316 | -4.498945 | -0.582884 | 0.048072 | 0.648821 | 8.877742 |
| V16 | 284807.0 | 1.437716e-15 | 0.876253 | -14.129855 | -0.468037 | 0.066413 | 0.523296 | 17.315112 |
| V17 | 284807.0 | -3.772171e-16 | 0.849337 | -25.162799 | -0.483748 | -0.065676 | 0.399675 | 9.253526 |
| V18 | 284807.0 | 9.564149e-16 | 0.838176 | -9.498746 | -0.498850 | -0.003636 | 0.500807 | 5.041069 |
| V19 | 284807.0 | 1.039917e-15 | 0.814041 | -7.213527 | -0.456299 | 0.003735 | 0.458949 | 5.591971 |
| V20 | 284807.0 | 6.406204e-16 | 0.770925 | -54.497720 | -0.211721 | -0.062481 | 0.133041 | 39.420904 |
| V21 | 284807.0 | 1.654067e-16 | 0.734524 | -34.830382 | -0.228395 | -0.029450 | 0.186377 | 27.202839 |
| V22 | 284807.0 | -3.568593e-16 | 0.725702 | -10.933144 | -0.542350 | 0.006782 | 0.528554 | 10.503090 |
| V23 | 284807.0 | 2.578648e-16 | 0.624460 | -44.807735 | -0.161846 | -0.011193 | 0.147642 | 22.528412 |
| V24 | 284807.0 | 4.473266e-15 | 0.605647 | -2.836627 | -0.354586 | 0.040976 | 0.439527 | 4.584549 |
| V25 | 284807.0 | 5.340915e-16 | 0.521278 | -10.295397 | -0.317145 | 0.016594 | 0.350716 | 7.519589 |
| V26 | 284807.0 | 1.683437e-15 | 0.482227 | -2.604551 | -0.326984 | -0.052139 | 0.240952 | 3.517346 |
| V27 | 284807.0 | -3.660091e-16 | 0.403632 | -22.565679 | -0.070840 | 0.001342 | 0.091045 | 31.612198 |
| V28 | 284807.0 | -1.227390e-16 | 0.330083 | -15.430084 | -0.052960 | 0.011244 | 0.078280 | 33.847808 |
| Amount | 284807.0 | 8.834962e+01 | 250.120109 | 0.000000 | 5.600000 | 22.000000 | 77.165000 | 25691.160000 |
| Class | 284807.0 | 1.727486e-03 | 0.041527 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 1.000000 |
Check the class balance
target = 'Class'
raw_data.groupby(target, as_index=False)['Time'].count()
| | Class | Time |
|---|---|---|
| 0 | 0 | 284315 |
| 1 | 1 | 492 |
neg, pos = np.bincount(raw_data[target])
total = neg + pos
print('Examples:\n Total: {}\n Positive: {} ({:.2f}% of total)\n'.format(
total, pos, 100 * pos / total))
Examples:
Total: 284807
Positive: 492 (0.17% of total)
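With only 0.17% positives, note that a degenerate classifier that always predicts "not fraud" already reaches about 99.83% accuracy, which is why plain accuracy is a poor yardstick here. A one-line check using the counts above (a sketch, not part of the original run):

naive_acc = neg / total
print('all-negative baseline accuracy: {:.4%}'.format(naive_acc))  # ~99.8272%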
dataframe = raw_data.copy()
Data processing
- Drop the unhelpful Time feature and log-transform Amount
- Standardize the features
- Clip extreme values
Log transform
dataframe.pop('Time')
eps = 0.001 # 0 => 0.1¢
for feature in ['Amount']:
dataframe[feature+'_log'] = np.log(dataframe.pop(feature)+eps)
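The eps offset keeps the transform finite for zero-amount transactions, since log(0) is -inf. A one-line check (a sketch):

print(np.log(eps))  # -6.9078: zero amounts map here instead of to -inf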
Split into training, validation, and test sets
train, test = train_test_split(dataframe, test_size=0.2, stratify=dataframe[target], random_state=2030)
train, val = train_test_split(train, test_size=0.2, stratify=train[target], random_state=2030)
print(len(train), 'train examples')
print(len(val), 'validation examples')
print(len(test), 'test examples')
182276 train examples
45569 validation examples
56962 test examples
# Form np arrays of labels and features.
train_labels = np.array(train.pop(target))
bool_train_labels = train_labels != 0
val_labels = np.array(val.pop(target))
test_labels = np.array(test.pop(target))
train_features = np.array(train)
val_features = np.array(val)
test_features = np.array(test)
Standardization
- Note: fit the StandardScaler on the training data only, then use the fitted scaler to transform the validation and test sets.
scaler = StandardScaler()
train_features = scaler.fit_transform(train_features)
val_features = scaler.transform(val_features)
test_features = scaler.transform(test_features)
train_features = np.clip(train_features, -5, 5)
val_features = np.clip(val_features, -5, 5)
test_features = np.clip(test_features, -5, 5)
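Clipping at ±5 bounds the extreme outliers seen in describe() above (e.g. V5 has a raw minimum near -113). To see how much data the clip actually touches, one can recompute the unclipped standardized matrix from the train frame, which at this point still holds the 29 feature columns (a sketch, not part of the original run):

# fraction of standardized training entries outside [-5, 5]
unclipped = scaler.transform(np.array(train))
print('fraction clipped: {:.4%}'.format(np.mean(np.abs(unclipped) > 5)))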
print('Training labels shape:', train_labels.shape)
print('Validation labels shape:', val_labels.shape)
print('Test labels shape:', test_labels.shape)
print('Training features shape:', train_features.shape)
print('Validation features shape:', val_features.shape)
print('Test features shape:', test_features.shape)
Training labels shape: (182276,)
Validation labels shape: (45569,)
Test labels shape: (56962,)
Training features shape: (182276, 29)
Validation features shape: (45569, 29)
Test features shape: (56962, 29)
Look at the data distributions
pos_df = pd.DataFrame(train_features[ bool_train_labels], columns=train.columns)
neg_df = pd.DataFrame(train_features[~bool_train_labels], columns=train.columns)
sns.jointplot(x=pos_df['V5'], y=pos_df['V6'],
              kind='hex', xlim=(-5,5), ylim=(-5,5))
plt.suptitle("Positive distribution")
sns.jointplot(x=neg_df['V5'], y=neg_df['V6'],
              kind='hex', xlim=(-5,5), ylim=(-5,5))
_ = plt.suptitle("Negative distribution")
pos_df.describe().T
| | count | mean | std | min | 25% | 50% | 75% | max |
|---|---|---|---|---|---|---|---|---|
| V1 | 315.0 | -1.760906 | 1.909379 | -5.000000 | -2.995964 | -1.194895 | -0.232911 | 1.087466 |
| V2 | 315.0 | 1.821280 | 1.886729 | -5.000000 | 0.708689 | 1.566040 | 3.037499 | 5.000000 |
| V3 | 315.0 | -2.966703 | 1.843923 | -5.000000 | -5.000000 | -3.302910 | -1.408258 | 1.486048 |
| V4 | 315.0 | 2.940414 | 1.616449 | -0.924989 | 1.633506 | 2.848541 | 4.487645 | 5.000000 |
| V5 | 315.0 | -1.514576 | 2.295449 | -5.000000 | -3.413844 | -1.120011 | 0.140397 | 5.000000 |
| V6 | 315.0 | -1.041315 | 1.378416 | -4.764846 | -1.883094 | -1.055371 | -0.242898 | 4.816880 |
| V7 | 315.0 | -2.473715 | 2.118062 | -5.000000 | -5.000000 | -2.385480 | -0.803709 | 3.008926 |
| V8 | 315.0 | 0.661945 | 2.325665 | -5.000000 | -0.136614 | 0.471054 | 1.472319 | 5.000000 |
| V9 | 315.0 | -2.107191 | 1.819530 | -5.000000 | -3.521353 | -2.012870 | -0.719422 | 3.049323 |
| V10 | 315.0 | -3.418757 | 1.897714 | -5.000000 | -5.000000 | -4.057184 | -2.573482 | 3.691563 |
| V11 | 315.0 | 3.202421 | 1.776641 | -1.665756 | 2.001885 | 3.536426 | 5.000000 | 5.000000 |
| V12 | 315.0 | -3.778767 | 1.782103 | -5.000000 | -5.000000 | -5.000000 | -2.989429 | 1.234618 |
| V13 | 315.0 | -0.122043 | 1.112873 | -3.142507 | -0.973313 | -0.067780 | 0.622709 | 2.828751 |
| V14 | 315.0 | -4.120755 | 1.711212 | -5.000000 | -5.000000 | -5.000000 | -4.441194 | 1.639314 |
| V15 | 315.0 | -0.127527 | 1.146112 | -3.382938 | -0.759891 | -0.073017 | 0.701528 | 2.697669 |
| V16 | 315.0 | -3.043271 | 2.275678 | -5.000000 | -5.000000 | -3.836717 | -1.316847 | 2.841122 |
| V17 | 315.0 | -3.171004 | 2.850112 | -5.000000 | -5.000000 | -5.000000 | -1.612624 | 5.000000 |
| V18 | 315.0 | -2.069377 | 2.456956 | -5.000000 | -5.000000 | -1.969603 | -0.060276 | 3.926839 |
| V19 | 315.0 | 0.875712 | 1.872091 | -4.032744 | -0.339676 | 0.841347 | 1.961852 | 5.000000 |
| V20 | 315.0 | 0.486428 | 1.460203 | -5.000000 | -0.160274 | 0.370908 | 1.086496 | 5.000000 |
| V21 | 315.0 | 0.798396 | 1.932013 | -5.000000 | 0.032131 | 0.790971 | 1.855851 | 5.000000 |
| V22 | 315.0 | 0.137193 | 1.411641 | -5.000000 | -0.657904 | 0.159077 | 0.897130 | 5.000000 |
| V23 | 315.0 | 0.019576 | 1.377033 | -5.000000 | -0.546028 | -0.127559 | 0.525010 | 5.000000 |
| V24 | 315.0 | -0.124175 | 0.838200 | -3.342668 | -0.666720 | -0.042600 | 0.534642 | 1.799373 |
| V25 | 315.0 | 0.084645 | 1.419304 | -4.153589 | -0.605800 | 0.145072 | 0.876187 | 4.132529 |
| V26 | 315.0 | 0.072469 | 0.977484 | -2.388598 | -0.543418 | 0.000739 | 0.760928 | 5.000000 |
| V27 | 315.0 | 0.675460 | 2.413487 | -5.000000 | -0.045323 | 0.923004 | 2.036083 | 5.000000 |
| V28 | 315.0 | 0.239053 | 1.653805 | -5.000000 | -0.355816 | 0.465788 | 1.105123 | 5.000000 |
| Amount_log | 315.0 | -0.382619 | 1.617817 | -4.855759 | -1.450333 | -0.356358 | 0.847578 | 2.325855 |
neg_df.describe().T
| | count | mean | std | min | 25% | 50% | 75% | max |
|---|---|---|---|---|---|---|---|---|
| V1 | 181961.0 | 0.012509 | 0.915748 | -5.000000 | -0.465855 | 0.011296 | 0.672076 | 1.251724 |
| V2 | 181961.0 | 0.008627 | 0.856092 | -5.000000 | -0.359485 | 0.040111 | 0.481902 | 5.000000 |
| V3 | 181961.0 | 0.011040 | 0.936744 | -5.000000 | -0.584617 | 0.120097 | 0.678324 | 2.790418 |
| V4 | 181961.0 | -0.006194 | 0.983917 | -3.924999 | -0.600361 | -0.017643 | 0.521791 | 5.000000 |
| V5 | 181961.0 | 0.005026 | 0.882939 | -5.000000 | -0.493601 | -0.038049 | 0.438492 | 5.000000 |
| V6 | 181961.0 | 0.001449 | 0.965131 | -5.000000 | -0.570013 | -0.203485 | 0.298300 | 5.000000 |
| V7 | 181961.0 | 0.005842 | 0.818787 | -5.000000 | -0.437768 | 0.031821 | 0.452537 | 5.000000 |
| V8 | 181961.0 | 0.017644 | 0.745925 | -5.000000 | -0.172570 | 0.019889 | 0.272858 | 5.000000 |
| V9 | 181961.0 | 0.002619 | 0.983578 | -5.000000 | -0.583394 | -0.045386 | 0.543770 | 5.000000 |
| V10 | 181961.0 | 0.000444 | 0.884508 | -5.000000 | -0.488516 | -0.085221 | 0.419347 | 5.000000 |
| V11 | 181961.0 | -0.006827 | 0.980387 | -4.583258 | -0.747892 | -0.032732 | 0.720054 | 5.000000 |
| V12 | 181961.0 | 0.011853 | 0.936165 | -5.000000 | -0.402320 | 0.141804 | 0.617894 | 5.000000 |
| V13 | 181961.0 | 0.000204 | 0.999690 | -5.000000 | -0.650306 | -0.012781 | 0.666657 | 5.000000 |
| V14 | 181961.0 | 0.013696 | 0.919757 | -5.000000 | -0.441005 | 0.055106 | 0.513936 | 5.000000 |
| V15 | 181961.0 | 0.000164 | 0.999363 | -4.793884 | -0.638472 | 0.053209 | 0.708801 | 5.000000 |
| V16 | 181961.0 | 0.008581 | 0.957594 | -5.000000 | -0.530996 | 0.077795 | 0.596470 | 5.000000 |
| V17 | 181961.0 | 0.014459 | 0.855516 | -5.000000 | -0.566987 | -0.076122 | 0.470954 | 5.000000 |
| V18 | 181961.0 | 0.004775 | 0.983692 | -5.000000 | -0.595008 | -0.002535 | 0.599786 | 5.000000 |
| V19 | 181961.0 | -0.001509 | 0.996377 | -5.000000 | -0.562193 | 0.003682 | 0.561334 | 5.000000 |
| V20 | 181961.0 | 0.001119 | 0.768444 | -5.000000 | -0.273441 | -0.082090 | 0.168334 | 5.000000 |
| V21 | 181961.0 | -0.009335 | 0.713355 | -5.000000 | -0.306948 | -0.039843 | 0.251401 | 5.000000 |
| V22 | 181961.0 | 0.000416 | 0.988440 | -5.000000 | -0.743404 | 0.007948 | 0.728703 | 5.000000 |
| V23 | 181961.0 | 0.003516 | 0.695287 | -5.000000 | -0.258290 | -0.019468 | 0.233072 | 5.000000 |
| V24 | 181961.0 | -0.000064 | 0.998703 | -4.656135 | -0.583868 | 0.067759 | 0.725433 | 5.000000 |
| V25 | 181961.0 | 0.000240 | 0.989529 | -5.000000 | -0.609040 | 0.032075 | 0.672402 | 5.000000 |
| V26 | 181961.0 | -0.000407 | 0.998447 | -5.000000 | -0.677797 | -0.109895 | 0.498649 | 5.000000 |
| V27 | 181961.0 | 0.005844 | 0.799027 | -5.000000 | -0.176850 | 0.002015 | 0.224085 | 5.000000 |
| V28 | 181961.0 | -0.001089 | 0.700654 | -5.000000 | -0.161470 | 0.034111 | 0.239486 | 5.000000 |
| Amount_log | 181961.0 | 0.000662 | 0.998482 | -4.855759 | -0.587677 | 0.075498 | 0.691266 | 3.554182 |
Define the model and metrics
METRICS = [
keras.metrics.TruePositives(name='tp'),
keras.metrics.FalsePositives(name='fp'),
keras.metrics.TrueNegatives(name='tn'),
keras.metrics.FalseNegatives(name='fn'),
keras.metrics.BinaryAccuracy(name='acc'),
keras.metrics.Precision(name='precision'),
keras.metrics.Recall(name='recall'),
keras.metrics.AUC(name='auc'),
keras.metrics.AUC(name='prc', curve='PR'),
]
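A note for reading the prc (AUC of the precision-recall curve) values below: a classifier that scores examples at random has an expected AUC-PR equal to the positive prevalence, so on this dataset anything near 0.0017 is chance level. A one-line reference point (a sketch, using the counts computed earlier):

print('chance-level AUC-PR ~', pos / total)  # ~0.00173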
def func_make_model(n_x, metrics=None, output_bias=None):
"""
    :param n_x: input feature dimension
    :param metrics: list of keras metrics to track (defaults to binary accuracy)
    :param output_bias: optional constant value for initializing the output bias
    :return: a compiled keras model
"""
if metrics is None:
metrics = [keras.metrics.BinaryAccuracy(name='acc')]
if output_bias is not None:
output_bias = tf.keras.initializers.Constant(output_bias)
model = keras.Sequential([
keras.layers.Dense(16, activation='relu', input_shape=(n_x, )),
keras.layers.Dropout(0.5),
keras.layers.Dense(1, activation='sigmoid', bias_initializer=output_bias),
])
model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3),
loss=keras.losses.BinaryCrossentropy(),
metrics=metrics)
return model
baseline model
build the model
restore_best_weights: Whether to restore model weights from the epoch with the best value of the monitored quantity. If False, the model weights obtained at the last step of training are used.
EPOCHS = 100
BATCH_SIZE = 2048
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_prc',
verbose=1,
patience=10,
mode='max',
restore_best_weights=True)
model_baseline = func_make_model(n_x=train_features.shape[-1], metrics=METRICS)
model_baseline.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 16) 480
_________________________________________________________________
dropout (Dropout) (None, 16) 0
_________________________________________________________________
dense_1 (Dense) (None, 1) 17
=================================================================
Total params: 497
Trainable params: 497
Non-trainable params: 0
_________________________________________________________________
model_baseline.predict(train_features[:10])
array([[0.7945324 ],
[0.9206671 ],
[0.7186621 ],
[0.8676042 ],
[0.86442024],
[0.820997 ],
[0.8121756 ],
[0.8080936 ],
[0.820931 ],
[0.9337303 ]], dtype=float32)
set the correct initial bias
Following the usual recipe for initializing the output layer of a classifier:
- With the default bias initialization the initial loss should be about $\log(2) \approx 0.69314$.
- The correct output bias can be derived from $p_0 = \frac{pos}{pos + neg} = \frac{1}{1 + e^{-b_0}}$, which gives $b_0 = \log\frac{pos}{neg}$.
- With the correct bias the initial loss should be about the entropy of the labels, $-p_0 \log p_0 - (1 - p_0)\log(1 - p_0)$.
## 1. loss with the default (uncorrected) bias
results = model_baseline.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print(results[0])
1.9581257104873657
## 2. set the output bias to log(pos/neg)
correct_initial_bias = np.log([pos/neg])
correct_initial_bias
array([-6.35935934])
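A quick sanity check (a sketch): pushing this bias through a sigmoid recovers the positive rate, since sigmoid(log(pos/neg)) = pos/(pos+neg):

print(1 / (1 + np.exp(-correct_initial_bias)))  # [0.00172749] = pos / total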
model_baseline = func_make_model(n_x=train_features.shape[-1], metrics=METRICS, output_bias=correct_initial_bias)
model_baseline.predict(train_features[:10])
array([[0.00340313],
[0.00266778],
[0.00226334],
[0.00272266],
[0.00042819],
[0.00054607],
[0.00602424],
[0.00321096],
[0.0007532 ],
[0.00322443]], dtype=float32)
## 3. expected initial loss with the corrected bias
p_0 = pos / (pos+neg)
-1*p_0*np.log(p_0) - (1-p_0)*np.log(1-p_0)
0.012714681335936208
results = model_baseline.evaluate(train_features, train_labels, batch_size=BATCH_SIZE, verbose=0)
print(results[0])
0.015545893460512161
checkpoint the initial weights
- Save the corrected initial weights so that all of the models below can start from the same initialization.
# os.path.join(tempfile.mkdtemp(), 'initial_weights')
correct_initial_bias_path = './tmp/correct_initial_bias/correct_initial_bias'
model_baseline.save_weights(correct_initial_bias_path)
Check the effect of the corrected initialization
model_baseline = func_make_model(n_x=train_features.shape[-1], metrics=METRICS)
model_baseline.load_weights(correct_initial_bias_path)
model_baseline.layers[-1].bias.assign([0.0])  # reset the final layer's bias to 0
zero_bias_history = model_baseline.fit(train_features, train_labels, batch_size=BATCH_SIZE, epochs=20, validation_data=(val_features, val_labels), verbose=0)
model_baseline = func_make_model(n_x=train_features.shape[-1], metrics=METRICS)
model_baseline.load_weights(correct_initial_bias_path)
correct_bias_history = model_baseline.fit(train_features, train_labels, batch_size=BATCH_SIZE, epochs=20, validation_data=(val_features, val_labels), verbose=0)
def func_plot_loss(history, label, color):
# Use a log scale on y-axis to show the wide range of values.
plt.semilogy(history.epoch, history.history['loss'], color=color, label='Train ' + label)
plt.semilogy(history.epoch, history.history['val_loss'], color=color, label='Val ' + label, linestyle="--")
plt.xlabel('Epoch')
plt.ylabel('Loss')
func_plot_loss(zero_bias_history, "Zero Bias", 'red')
func_plot_loss(correct_bias_history, "Careful Bias", 'blue')
As the plot shows, the loss is clearly smaller with the corrected bias.
train the baseline model
model_baseline = func_make_model(n_x=train_features.shape[-1], metrics=METRICS)
model_baseline.load_weights(correct_initial_bias_path)
baseline_history = model_baseline.fit(train_features, train_labels, batch_size=BATCH_SIZE, epochs=EPOCHS,
callbacks=[early_stopping], validation_data=(val_features, val_labels))
Epoch 1/100
90/90 [==============================] - 4s 21ms/step - loss: 0.0147 - tp: 64.8352 - fp: 45.5934 - tn: 139430.5055 - fn: 168.6374 - acc: 0.9985 - precision: 0.5985 - recall: 0.3220 - auc: 0.7673 - prc: 0.3309 - val_loss: 0.0073 - val_tp: 0.0000e+00 - val_fp: 0.0000e+00 - val_tn: 45490.0000 - val_fn: 79.0000 - val_acc: 0.9983 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_auc: 0.8791 - val_prc: 0.6609
Epoch 2/100
90/90 [==============================] - 1s 11ms/step - loss: 0.0078 - tp: 48.4505 - fp: 13.4286 - tn: 93966.1978 - fn: 112.4945 - acc: 0.9987 - precision: 0.8123 - recall: 0.3008 - auc: 0.8303 - prc: 0.4306 - val_loss: 0.0051 - val_tp: 29.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 50.0000 - val_acc: 0.9988 - val_precision: 0.8286 - val_recall: 0.3671 - val_auc: 0.8922 - val_prc: 0.6677
Epoch 3/100
90/90 [==============================] - 1s 11ms/step - loss: 0.0081 - tp: 65.3846 - fp: 14.6813 - tn: 93958.2637 - fn: 102.2418 - acc: 0.9986 - precision: 0.7926 - recall: 0.3719 - auc: 0.8072 - prc: 0.4533 - val_loss: 0.0045 - val_tp: 44.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 35.0000 - val_acc: 0.9991 - val_precision: 0.8800 - val_recall: 0.5570 - val_auc: 0.9048 - val_prc: 0.6881
Epoch 4/100
90/90 [==============================] - 1s 11ms/step - loss: 0.0068 - tp: 71.2308 - fp: 15.0110 - tn: 93965.9011 - fn: 88.4286 - acc: 0.9989 - precision: 0.8163 - recall: 0.4337 - auc: 0.8344 - prc: 0.5153 - val_loss: 0.0042 - val_tp: 47.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 32.0000 - val_acc: 0.9992 - val_precision: 0.8868 - val_recall: 0.5949 - val_auc: 0.9175 - val_prc: 0.7249
Epoch 5/100
90/90 [==============================] - 1s 11ms/step - loss: 0.0057 - tp: 76.8901 - fp: 19.1648 - tn: 93963.4286 - fn: 81.0879 - acc: 0.9989 - precision: 0.7970 - recall: 0.4733 - auc: 0.8757 - prc: 0.5857 - val_loss: 0.0040 - val_tp: 48.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 31.0000 - val_acc: 0.9992 - val_precision: 0.8889 - val_recall: 0.6076 - val_auc: 0.9175 - val_prc: 0.7392
Epoch 6/100
90/90 [==============================] - 1s 10ms/step - loss: 0.0064 - tp: 74.9560 - fp: 15.1319 - tn: 93958.1099 - fn: 92.3736 - acc: 0.9988 - precision: 0.8518 - recall: 0.4300 - auc: 0.8618 - prc: 0.5888 - val_loss: 0.0039 - val_tp: 52.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 27.0000 - val_acc: 0.9993 - val_precision: 0.8966 - val_recall: 0.6582 - val_auc: 0.9174 - val_prc: 0.7480
Epoch 7/100
90/90 [==============================] - 1s 10ms/step - loss: 0.0047 - tp: 88.3846 - fp: 12.4615 - tn: 93968.0549 - fn: 71.6703 - acc: 0.9991 - precision: 0.8904 - recall: 0.5613 - auc: 0.9146 - prc: 0.7047 - val_loss: 0.0038 - val_tp: 54.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 25.0000 - val_acc: 0.9993 - val_precision: 0.9000 - val_recall: 0.6835 - val_auc: 0.9174 - val_prc: 0.7568
Epoch 8/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0043 - tp: 95.3297 - fp: 19.2418 - tn: 93959.6923 - fn: 66.3077 - acc: 0.9991 - precision: 0.8432 - recall: 0.6045 - auc: 0.9324 - prc: 0.7021 - val_loss: 0.0037 - val_tp: 52.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 27.0000 - val_acc: 0.9993 - val_precision: 0.8966 - val_recall: 0.6582 - val_auc: 0.9174 - val_prc: 0.7711
Epoch 9/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0049 - tp: 94.0220 - fp: 11.9890 - tn: 93963.3736 - fn: 71.1868 - acc: 0.9991 - precision: 0.9108 - recall: 0.5838 - auc: 0.8996 - prc: 0.7061 - val_loss: 0.0036 - val_tp: 53.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 26.0000 - val_acc: 0.9993 - val_precision: 0.8983 - val_recall: 0.6709 - val_auc: 0.9175 - val_prc: 0.7826
Epoch 10/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0045 - tp: 91.1209 - fp: 20.6923 - tn: 93955.2747 - fn: 73.4835 - acc: 0.9990 - precision: 0.8317 - recall: 0.5825 - auc: 0.9340 - prc: 0.6933 - val_loss: 0.0036 - val_tp: 53.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 26.0000 - val_acc: 0.9993 - val_precision: 0.8983 - val_recall: 0.6709 - val_auc: 0.9175 - val_prc: 0.7929
Epoch 11/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0045 - tp: 84.9231 - fp: 18.3956 - tn: 93962.7363 - fn: 74.5165 - acc: 0.9990 - precision: 0.8351 - recall: 0.5411 - auc: 0.9233 - prc: 0.6955 - val_loss: 0.0035 - val_tp: 55.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 24.0000 - val_acc: 0.9993 - val_precision: 0.9016 - val_recall: 0.6962 - val_auc: 0.9238 - val_prc: 0.7973
Epoch 12/100
90/90 [==============================] - 1s 10ms/step - loss: 0.0046 - tp: 101.3626 - fp: 17.9341 - tn: 93954.2088 - fn: 67.0659 - acc: 0.9991 - precision: 0.8479 - recall: 0.6011 - auc: 0.9329 - prc: 0.7196 - val_loss: 0.0035 - val_tp: 55.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 24.0000 - val_acc: 0.9993 - val_precision: 0.9016 - val_recall: 0.6962 - val_auc: 0.9175 - val_prc: 0.7939
Epoch 13/100
90/90 [==============================] - 1s 11ms/step - loss: 0.0050 - tp: 94.1099 - fp: 18.0110 - tn: 93957.0110 - fn: 71.4396 - acc: 0.9990 - precision: 0.8307 - recall: 0.5390 - auc: 0.9075 - prc: 0.6494 - val_loss: 0.0035 - val_tp: 56.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 23.0000 - val_acc: 0.9994 - val_precision: 0.9032 - val_recall: 0.7089 - val_auc: 0.9238 - val_prc: 0.7989
Epoch 14/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0047 - tp: 89.5495 - fp: 17.6264 - tn: 93958.5604 - fn: 74.8352 - acc: 0.9990 - precision: 0.8187 - recall: 0.5544 - auc: 0.9229 - prc: 0.6758 - val_loss: 0.0035 - val_tp: 58.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 21.0000 - val_acc: 0.9994 - val_precision: 0.9062 - val_recall: 0.7342 - val_auc: 0.9238 - val_prc: 0.7989
Epoch 15/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0043 - tp: 87.0000 - fp: 16.5934 - tn: 93963.8571 - fn: 73.1209 - acc: 0.9991 - precision: 0.8336 - recall: 0.5472 - auc: 0.9099 - prc: 0.6537 - val_loss: 0.0035 - val_tp: 58.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 21.0000 - val_acc: 0.9994 - val_precision: 0.9062 - val_recall: 0.7342 - val_auc: 0.9238 - val_prc: 0.7999
Epoch 16/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0040 - tp: 93.5275 - fp: 14.4725 - tn: 93967.1099 - fn: 65.4615 - acc: 0.9992 - precision: 0.8623 - recall: 0.5874 - auc: 0.9295 - prc: 0.7207 - val_loss: 0.0034 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 45485.0000 - val_fn: 21.0000 - val_acc: 0.9994 - val_precision: 0.9206 - val_recall: 0.7342 - val_auc: 0.9238 - val_prc: 0.8039
Epoch 17/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0044 - tp: 87.9451 - fp: 15.0000 - tn: 93970.7033 - fn: 66.9231 - acc: 0.9991 - precision: 0.8475 - recall: 0.5557 - auc: 0.9207 - prc: 0.6808 - val_loss: 0.0034 - val_tp: 59.0000 - val_fp: 5.0000 - val_tn: 45485.0000 - val_fn: 20.0000 - val_acc: 0.9995 - val_precision: 0.9219 - val_recall: 0.7468 - val_auc: 0.9238 - val_prc: 0.8054
Epoch 18/100
90/90 [==============================] - 1s 11ms/step - loss: 0.0042 - tp: 100.8462 - fp: 18.8791 - tn: 93962.7253 - fn: 58.1209 - acc: 0.9991 - precision: 0.8238 - recall: 0.6282 - auc: 0.9299 - prc: 0.7125 - val_loss: 0.0034 - val_tp: 58.0000 - val_fp: 5.0000 - val_tn: 45485.0000 - val_fn: 21.0000 - val_acc: 0.9994 - val_precision: 0.9206 - val_recall: 0.7342 - val_auc: 0.9238 - val_prc: 0.8091
Epoch 19/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0038 - tp: 91.3956 - fp: 16.1758 - tn: 93965.0989 - fn: 67.9011 - acc: 0.9991 - precision: 0.8653 - recall: 0.5567 - auc: 0.9329 - prc: 0.7471 - val_loss: 0.0034 - val_tp: 59.0000 - val_fp: 5.0000 - val_tn: 45485.0000 - val_fn: 20.0000 - val_acc: 0.9995 - val_precision: 0.9219 - val_recall: 0.7468 - val_auc: 0.9238 - val_prc: 0.8081
Epoch 20/100
90/90 [==============================] - 1s 10ms/step - loss: 0.0035 - tp: 101.1978 - fp: 20.1209 - tn: 93958.4615 - fn: 60.7912 - acc: 0.9992 - precision: 0.8429 - recall: 0.6451 - auc: 0.9545 - prc: 0.7794 - val_loss: 0.0034 - val_tp: 59.0000 - val_fp: 5.0000 - val_tn: 45485.0000 - val_fn: 20.0000 - val_acc: 0.9995 - val_precision: 0.9219 - val_recall: 0.7468 - val_auc: 0.9238 - val_prc: 0.8084
Epoch 21/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0033 - tp: 102.5495 - fp: 16.1868 - tn: 93968.0000 - fn: 53.8352 - acc: 0.9993 - precision: 0.8732 - recall: 0.6806 - auc: 0.9385 - prc: 0.7763 - val_loss: 0.0034 - val_tp: 52.0000 - val_fp: 5.0000 - val_tn: 45485.0000 - val_fn: 27.0000 - val_acc: 0.9993 - val_precision: 0.9123 - val_recall: 0.6582 - val_auc: 0.9238 - val_prc: 0.8112
Epoch 22/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0044 - tp: 89.4615 - fp: 16.7363 - tn: 93964.2857 - fn: 70.0879 - acc: 0.9991 - precision: 0.8323 - recall: 0.5465 - auc: 0.9189 - prc: 0.6915 - val_loss: 0.0033 - val_tp: 55.0000 - val_fp: 5.0000 - val_tn: 45485.0000 - val_fn: 24.0000 - val_acc: 0.9994 - val_precision: 0.9167 - val_recall: 0.6962 - val_auc: 0.9238 - val_prc: 0.8102
Epoch 23/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0038 - tp: 98.5385 - fp: 15.2198 - tn: 93966.8791 - fn: 59.9341 - acc: 0.9992 - precision: 0.8713 - recall: 0.6000 - auc: 0.9229 - prc: 0.7481 - val_loss: 0.0033 - val_tp: 55.0000 - val_fp: 5.0000 - val_tn: 45485.0000 - val_fn: 24.0000 - val_acc: 0.9994 - val_precision: 0.9167 - val_recall: 0.6962 - val_auc: 0.9237 - val_prc: 0.8129
Epoch 24/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0039 - tp: 97.1868 - fp: 18.4505 - tn: 93961.9890 - fn: 62.9451 - acc: 0.9992 - precision: 0.8358 - recall: 0.6341 - auc: 0.9337 - prc: 0.7182 - val_loss: 0.0034 - val_tp: 54.0000 - val_fp: 1.0000 - val_tn: 45489.0000 - val_fn: 25.0000 - val_acc: 0.9994 - val_precision: 0.9818 - val_recall: 0.6835 - val_auc: 0.9238 - val_prc: 0.8156
Epoch 25/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0042 - tp: 99.9231 - fp: 16.0769 - tn: 93957.1319 - fn: 67.4396 - acc: 0.9991 - precision: 0.8482 - recall: 0.6010 - auc: 0.9216 - prc: 0.7227 - val_loss: 0.0033 - val_tp: 59.0000 - val_fp: 5.0000 - val_tn: 45485.0000 - val_fn: 20.0000 - val_acc: 0.9995 - val_precision: 0.9219 - val_recall: 0.7468 - val_auc: 0.9238 - val_prc: 0.8132
Epoch 26/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0037 - tp: 103.0769 - fp: 16.0110 - tn: 93956.2747 - fn: 65.2088 - acc: 0.9992 - precision: 0.8841 - recall: 0.6262 - auc: 0.9341 - prc: 0.7564 - val_loss: 0.0033 - val_tp: 59.0000 - val_fp: 5.0000 - val_tn: 45485.0000 - val_fn: 20.0000 - val_acc: 0.9995 - val_precision: 0.9219 - val_recall: 0.7468 - val_auc: 0.9238 - val_prc: 0.8134
Epoch 27/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0038 - tp: 98.8022 - fp: 15.9011 - tn: 93962.5604 - fn: 63.3077 - acc: 0.9992 - precision: 0.8723 - recall: 0.6362 - auc: 0.9365 - prc: 0.7486 - val_loss: 0.0033 - val_tp: 59.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 20.0000 - val_acc: 0.9994 - val_precision: 0.9077 - val_recall: 0.7468 - val_auc: 0.9238 - val_prc: 0.8124
Epoch 28/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0040 - tp: 98.9451 - fp: 15.8022 - tn: 93965.5714 - fn: 60.2527 - acc: 0.9992 - precision: 0.8606 - recall: 0.6430 - auc: 0.9241 - prc: 0.7183 - val_loss: 0.0033 - val_tp: 59.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 20.0000 - val_acc: 0.9994 - val_precision: 0.9077 - val_recall: 0.7468 - val_auc: 0.9238 - val_prc: 0.8148
Epoch 29/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0036 - tp: 104.4286 - fp: 13.1648 - tn: 93965.1429 - fn: 57.8352 - acc: 0.9993 - precision: 0.8989 - recall: 0.6587 - auc: 0.9396 - prc: 0.7463 - val_loss: 0.0033 - val_tp: 58.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 21.0000 - val_acc: 0.9994 - val_precision: 0.9062 - val_recall: 0.7342 - val_auc: 0.9238 - val_prc: 0.8150
Epoch 30/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0039 - tp: 104.1319 - fp: 12.7033 - tn: 93967.4066 - fn: 56.3297 - acc: 0.9992 - precision: 0.8847 - recall: 0.6371 - auc: 0.9197 - prc: 0.7073 - val_loss: 0.0033 - val_tp: 55.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 24.0000 - val_acc: 0.9993 - val_precision: 0.9016 - val_recall: 0.6962 - val_auc: 0.9238 - val_prc: 0.8163
Epoch 31/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0037 - tp: 91.6154 - fp: 17.1648 - tn: 93966.5495 - fn: 65.2418 - acc: 0.9992 - precision: 0.8379 - recall: 0.5814 - auc: 0.9231 - prc: 0.7264 - val_loss: 0.0033 - val_tp: 55.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 24.0000 - val_acc: 0.9993 - val_precision: 0.9016 - val_recall: 0.6962 - val_auc: 0.9238 - val_prc: 0.8168
Epoch 32/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0038 - tp: 98.9231 - fp: 14.8681 - tn: 93962.0440 - fn: 64.7363 - acc: 0.9991 - precision: 0.8635 - recall: 0.5902 - auc: 0.9405 - prc: 0.7401 - val_loss: 0.0033 - val_tp: 56.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 23.0000 - val_acc: 0.9994 - val_precision: 0.9032 - val_recall: 0.7089 - val_auc: 0.9238 - val_prc: 0.8164
Epoch 33/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0037 - tp: 102.0549 - fp: 16.8242 - tn: 93959.9231 - fn: 61.7692 - acc: 0.9992 - precision: 0.8656 - recall: 0.6255 - auc: 0.9333 - prc: 0.7486 - val_loss: 0.0034 - val_tp: 54.0000 - val_fp: 2.0000 - val_tn: 45488.0000 - val_fn: 25.0000 - val_acc: 0.9994 - val_precision: 0.9643 - val_recall: 0.6835 - val_auc: 0.9238 - val_prc: 0.8174
Epoch 34/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0030 - tp: 106.4396 - fp: 10.6593 - tn: 93972.9231 - fn: 50.5495 - acc: 0.9994 - precision: 0.9169 - recall: 0.6839 - auc: 0.9287 - prc: 0.7915 - val_loss: 0.0034 - val_tp: 55.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 24.0000 - val_acc: 0.9993 - val_precision: 0.9016 - val_recall: 0.6962 - val_auc: 0.9238 - val_prc: 0.8158
Epoch 35/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0039 - tp: 96.9780 - fp: 15.2857 - tn: 93964.0220 - fn: 64.2857 - acc: 0.9992 - precision: 0.8767 - recall: 0.5986 - auc: 0.9371 - prc: 0.7327 - val_loss: 0.0034 - val_tp: 55.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 24.0000 - val_acc: 0.9993 - val_precision: 0.9016 - val_recall: 0.6962 - val_auc: 0.9238 - val_prc: 0.8157
Epoch 36/100
90/90 [==============================] - 1s 10ms/step - loss: 0.0039 - tp: 97.0440 - fp: 16.4725 - tn: 93962.8242 - fn: 64.2308 - acc: 0.9992 - precision: 0.8561 - recall: 0.6162 - auc: 0.9231 - prc: 0.7042 - val_loss: 0.0034 - val_tp: 53.0000 - val_fp: 2.0000 - val_tn: 45488.0000 - val_fn: 26.0000 - val_acc: 0.9994 - val_precision: 0.9636 - val_recall: 0.6709 - val_auc: 0.9238 - val_prc: 0.8161
Epoch 37/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0040 - tp: 98.1978 - fp: 16.4176 - tn: 93961.2747 - fn: 64.6813 - acc: 0.9992 - precision: 0.8608 - recall: 0.6091 - auc: 0.9226 - prc: 0.7067 - val_loss: 0.0034 - val_tp: 54.0000 - val_fp: 3.0000 - val_tn: 45487.0000 - val_fn: 25.0000 - val_acc: 0.9994 - val_precision: 0.9474 - val_recall: 0.6835 - val_auc: 0.9238 - val_prc: 0.8160
Epoch 38/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0036 - tp: 97.2198 - fp: 15.0879 - tn: 93970.4835 - fn: 57.7802 - acc: 0.9992 - precision: 0.8639 - recall: 0.6325 - auc: 0.9437 - prc: 0.7358 - val_loss: 0.0034 - val_tp: 55.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 24.0000 - val_acc: 0.9993 - val_precision: 0.9016 - val_recall: 0.6962 - val_auc: 0.9238 - val_prc: 0.8155
Epoch 39/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0041 - tp: 107.5385 - fp: 18.5604 - tn: 93953.8352 - fn: 60.6374 - acc: 0.9991 - precision: 0.8364 - recall: 0.6481 - auc: 0.9290 - prc: 0.7210 - val_loss: 0.0034 - val_tp: 53.0000 - val_fp: 2.0000 - val_tn: 45488.0000 - val_fn: 26.0000 - val_acc: 0.9994 - val_precision: 0.9636 - val_recall: 0.6709 - val_auc: 0.9238 - val_prc: 0.8160
Epoch 40/100
90/90 [==============================] - 1s 7ms/step - loss: 0.0041 - tp: 96.3846 - fp: 16.7143 - tn: 93964.8352 - fn: 62.6374 - acc: 0.9992 - precision: 0.8532 - recall: 0.6206 - auc: 0.9248 - prc: 0.6902 - val_loss: 0.0034 - val_tp: 53.0000 - val_fp: 2.0000 - val_tn: 45488.0000 - val_fn: 26.0000 - val_acc: 0.9994 - val_precision: 0.9636 - val_recall: 0.6709 - val_auc: 0.9238 - val_prc: 0.8154
Epoch 41/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0037 - tp: 103.1758 - fp: 14.4505 - tn: 93961.2198 - fn: 61.7253 - acc: 0.9992 - precision: 0.8754 - recall: 0.6197 - auc: 0.9334 - prc: 0.7572 - val_loss: 0.0034 - val_tp: 53.0000 - val_fp: 2.0000 - val_tn: 45488.0000 - val_fn: 26.0000 - val_acc: 0.9994 - val_precision: 0.9636 - val_recall: 0.6709 - val_auc: 0.9238 - val_prc: 0.8157
Epoch 42/100
90/90 [==============================] - 1s 8ms/step - loss: 0.0034 - tp: 90.2088 - fp: 15.3846 - tn: 93967.3516 - fn: 67.6264 - acc: 0.9991 - precision: 0.8634 - recall: 0.5541 - auc: 0.9456 - prc: 0.7623 - val_loss: 0.0034 - val_tp: 53.0000 - val_fp: 2.0000 - val_tn: 45488.0000 - val_fn: 26.0000 - val_acc: 0.9994 - val_precision: 0.9636 - val_recall: 0.6709 - val_auc: 0.9238 - val_prc: 0.8152
Epoch 43/100
90/90 [==============================] - 1s 9ms/step - loss: 0.0040 - tp: 96.2308 - fp: 21.1209 - tn: 93957.6813 - fn: 65.5385 - acc: 0.9990 - precision: 0.8026 - recall: 0.5805 - auc: 0.9326 - prc: 0.7090 - val_loss: 0.0034 - val_tp: 53.0000 - val_fp: 2.0000 - val_tn: 45488.0000 - val_fn: 26.0000 - val_acc: 0.9994 - val_precision: 0.9636 - val_recall: 0.6709 - val_auc: 0.9238 - val_prc: 0.8150
Restoring model weights from the end of the best epoch.
Epoch 00043: early stopping
Plot the training history
def func_plot_metrics(history):
metrics = ['loss', 'prc', 'precision', 'recall']
for n, metric in enumerate(metrics):
name = metric.replace("_"," ").capitalize()
plt.subplot(2,2,n+1)
plt.plot(history.epoch, history.history[metric], color='blue', label='Train')
plt.plot(history.epoch, history.history['val_'+metric], color='blue', linestyle="--", label='Val')
plt.xlabel('Epoch')
plt.ylabel(name)
if metric == 'loss':
pass
# plt.ylim([0, plt.ylim()[1]])
elif metric == 'auc':
plt.ylim([0.8,1])
elif metric == 'recall':
plt.ylim([0,1])
else:
plt.ylim([0,1])
plt.legend()
plt.tight_layout()
func_plot_metrics(baseline_history)
Look at the confusion matrix
train_predictions_baseline = model_baseline.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_baseline = model_baseline.predict(test_features, batch_size=BATCH_SIZE)
def func_plot_cm(labels, predictions, p=0.5):
cm = confusion_matrix(labels, predictions > p)
plt.figure(figsize=(5,5))
sns.heatmap(cm, annot=True, fmt="d")
plt.title('Confusion matrix @{:.2f}'.format(p))
plt.ylabel('Actual label')
plt.xlabel('Predicted label')
print('Legitimate Transactions Detected (True Negatives): ', cm[0][0])
print('Legitimate Transactions Incorrectly Detected (False Positives): ', cm[0][1])
print('Fraudulent Transactions Missed (False Negatives): ', cm[1][0])
print('Fraudulent Transactions Detected (True Positives): ', cm[1][1])
print('Total Fraudulent Transactions: ', np.sum(cm[1]))
baseline_results = model_baseline.evaluate(test_features, test_labels, batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model_baseline.metrics_names, baseline_results):
print(name, ': ', value)
print()
func_plot_cm(test_labels, test_predictions_baseline)
loss : 0.002730185631662607
tp : 72.0
fp : 9.0
tn : 56855.0
fn : 26.0
acc : 0.9993855357170105
precision : 0.8888888955116272
recall : 0.7346938848495483
auc : 0.9487220048904419
prc : 0.8363494873046875
Legitimate Transactions Detected (True Negatives): 56855
Legitimate Transactions Incorrectly Detected (False Positives): 9
Fraudulent Transactions Missed (False Negatives): 26
Fraudulent Transactions Detected (True Positives): 72
Total Fraudulent Transactions: 98
Plot the ROC curve
def func_plot_roc(name, labels, predictions, **kwargs):
fp, tp, _ = sklearn.metrics.roc_curve(labels, predictions)
plt.plot(100*fp, 100*tp, label=name, linewidth=2, **kwargs)
plt.xlabel('False positives [%]')
plt.ylabel('True positives [%]')
plt.xlim([-0.5,20])
plt.ylim([80,100.5])
plt.grid(True)
ax = plt.gca()
ax.set_aspect('equal')
func_plot_roc("Train Baseline", train_labels, train_predictions_baseline, color='red')
func_plot_roc("Test Baseline", test_labels, test_predictions_baseline, color='blue', linestyle='--')
plt.legend(loc='lower right')
Plot the PR curve
def func_plot_prc(name, labels, predictions, **kwargs):
precision, recall, _ = sklearn.metrics.precision_recall_curve(labels, predictions)
    plt.plot(recall, precision, label=name, linewidth=2, **kwargs)  # recall on x, precision on y, matching the axis labels
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.grid(True)
ax = plt.gca()
ax.set_aspect('equal')
func_plot_prc("Train Baseline", train_labels, train_predictions_baseline, color='red')
func_plot_prc("Test Baseline", test_labels, test_predictions_baseline, color='blue', linestyle='--')
plt.legend(loc='lower right')
model_1: add class_weight
Compute the class weights
weight_for_0 = (1 / neg) * (total / 2.0) # total / (2*neg)
weight_for_1 = (1 / pos) * (total / 2.0) # total / (2*pos)
class_weight = {
0: weight_for_0,
1: weight_for_1
}
class_weight
{0: 0.5008652375006595, 1: 289.43800813008136}
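For intuition, class_weight makes Keras scale each example's loss by the weight of its true class, so a positive counts about 289× as much as a negative, while the total/2 factor keeps the overall loss magnitude comparable to the unweighted case. A minimal NumPy sketch of the equivalent weighted binary cross-entropy (an illustration only; weighted_bce is a hypothetical helper, not the author's code):

def weighted_bce(y_true, y_pred, w0, w1, eps=1e-7):
    # per-example binary cross-entropy, scaled by the weight of the true class
    y_pred = np.clip(y_pred, eps, 1 - eps)
    per_example = -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))
    return np.mean(np.where(y_true == 1, w1, w0) * per_example)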
train a model with class weight
model_weight = func_make_model(n_x=train_features.shape[-1], metrics=METRICS)
model_weight.load_weights(correct_initial_bias_path)
model_weight_history = model_weight.fit(train_features, train_labels,
batch_size=BATCH_SIZE, epochs=EPOCHS,
callbacks=[early_stopping],
validation_data=(val_features, val_labels),
# the class weights go here
class_weight=class_weight)
Epoch 1/100
90/90 [==============================] - 4s 20ms/step - loss: 3.3108 - tp: 85.5495 - fp: 103.0659 - tn: 150735.5495 - fn: 178.4066 - acc: 0.9983 - precision: 0.5085 - recall: 0.3590 - auc: 0.7769 - prc: 0.3236 - val_loss: 0.0081 - val_tp: 10.0000 - val_fp: 6.0000 - val_tn: 45484.0000 - val_fn: 69.0000 - val_acc: 0.9984 - val_precision: 0.6250 - val_recall: 0.1266 - val_auc: 0.9132 - val_prc: 0.3528
Epoch 2/100
90/90 [==============================] - 1s 9ms/step - loss: 1.5420 - tp: 58.9451 - fp: 177.4945 - tn: 93802.3516 - fn: 101.7802 - acc: 0.9972 - precision: 0.2352 - recall: 0.3316 - auc: 0.7668 - prc: 0.1856 - val_loss: 0.0093 - val_tp: 56.0000 - val_fp: 32.0000 - val_tn: 45458.0000 - val_fn: 23.0000 - val_acc: 0.9988 - val_precision: 0.6364 - val_recall: 0.7089 - val_auc: 0.9470 - val_prc: 0.5701
Epoch 3/100
90/90 [==============================] - 1s 8ms/step - loss: 0.8803 - tp: 92.3956 - fp: 448.8462 - tn: 93533.2418 - fn: 66.0879 - acc: 0.9947 - precision: 0.1718 - recall: 0.5509 - auc: 0.8812 - prc: 0.2806 - val_loss: 0.0147 - val_tp: 62.0000 - val_fp: 94.0000 - val_tn: 45396.0000 - val_fn: 17.0000 - val_acc: 0.9976 - val_precision: 0.3974 - val_recall: 0.7848 - val_auc: 0.9680 - val_prc: 0.6804
Epoch 4/100
90/90 [==============================] - 1s 8ms/step - loss: 0.6487 - tp: 110.7582 - fp: 801.8242 - tn: 93174.9560 - fn: 53.0330 - acc: 0.9914 - precision: 0.1226 - recall: 0.6634 - auc: 0.9129 - prc: 0.2730 - val_loss: 0.0227 - val_tp: 66.0000 - val_fp: 166.0000 - val_tn: 45324.0000 - val_fn: 13.0000 - val_acc: 0.9961 - val_precision: 0.2845 - val_recall: 0.8354 - val_auc: 0.9691 - val_prc: 0.7008
Epoch 5/100
90/90 [==============================] - 1s 9ms/step - loss: 0.4749 - tp: 126.5165 - fp: 1314.5934 - tn: 92661.6813 - fn: 37.7802 - acc: 0.9859 - precision: 0.0907 - recall: 0.7836 - auc: 0.9293 - prc: 0.2892 - val_loss: 0.0308 - val_tp: 66.0000 - val_fp: 281.0000 - val_tn: 45209.0000 - val_fn: 13.0000 - val_acc: 0.9935 - val_precision: 0.1902 - val_recall: 0.8354 - val_auc: 0.9708 - val_prc: 0.6695
Epoch 6/100
90/90 [==============================] - 1s 9ms/step - loss: 0.4803 - tp: 128.5824 - fp: 1652.0330 - tn: 92324.8791 - fn: 35.0769 - acc: 0.9824 - precision: 0.0749 - recall: 0.7747 - auc: 0.9277 - prc: 0.2729 - val_loss: 0.0380 - val_tp: 66.0000 - val_fp: 448.0000 - val_tn: 45042.0000 - val_fn: 13.0000 - val_acc: 0.9899 - val_precision: 0.1284 - val_recall: 0.8354 - val_auc: 0.9745 - val_prc: 0.6359
Epoch 7/100
90/90 [==============================] - 1s 9ms/step - loss: 0.3692 - tp: 136.0110 - fp: 2004.2967 - tn: 91972.8242 - fn: 27.4396 - acc: 0.9787 - precision: 0.0646 - recall: 0.8400 - auc: 0.9357 - prc: 0.2538 - val_loss: 0.0446 - val_tp: 66.0000 - val_fp: 548.0000 - val_tn: 44942.0000 - val_fn: 13.0000 - val_acc: 0.9877 - val_precision: 0.1075 - val_recall: 0.8354 - val_auc: 0.9741 - val_prc: 0.6173
Epoch 8/100
90/90 [==============================] - 1s 8ms/step - loss: 0.3391 - tp: 133.3846 - fp: 2250.4615 - tn: 91731.2527 - fn: 25.4725 - acc: 0.9759 - precision: 0.0547 - recall: 0.8384 - auc: 0.9464 - prc: 0.2333 - val_loss: 0.0502 - val_tp: 66.0000 - val_fp: 602.0000 - val_tn: 44888.0000 - val_fn: 13.0000 - val_acc: 0.9865 - val_precision: 0.0988 - val_recall: 0.8354 - val_auc: 0.9740 - val_prc: 0.6335
Epoch 9/100
90/90 [==============================] - 1s 8ms/step - loss: 0.3891 - tp: 134.6374 - fp: 2574.4066 - tn: 91403.6154 - fn: 27.9121 - acc: 0.9727 - precision: 0.0507 - recall: 0.8210 - auc: 0.9373 - prc: 0.2371 - val_loss: 0.0550 - val_tp: 66.0000 - val_fp: 643.0000 - val_tn: 44847.0000 - val_fn: 13.0000 - val_acc: 0.9856 - val_precision: 0.0931 - val_recall: 0.8354 - val_auc: 0.9756 - val_prc: 0.6202
Epoch 10/100
90/90 [==============================] - 1s 8ms/step - loss: 0.3159 - tp: 137.5714 - fp: 2726.1538 - tn: 91248.8462 - fn: 28.0000 - acc: 0.9712 - precision: 0.0490 - recall: 0.8459 - auc: 0.9529 - prc: 0.2447 - val_loss: 0.0597 - val_tp: 67.0000 - val_fp: 707.0000 - val_tn: 44783.0000 - val_fn: 12.0000 - val_acc: 0.9842 - val_precision: 0.0866 - val_recall: 0.8481 - val_auc: 0.9754 - val_prc: 0.6138
Epoch 11/100
90/90 [==============================] - 1s 8ms/step - loss: 0.2898 - tp: 142.0110 - fp: 2837.5055 - tn: 91137.0440 - fn: 24.0110 - acc: 0.9695 - precision: 0.0480 - recall: 0.8655 - auc: 0.9555 - prc: 0.2209 - val_loss: 0.0630 - val_tp: 67.0000 - val_fp: 753.0000 - val_tn: 44737.0000 - val_fn: 12.0000 - val_acc: 0.9832 - val_precision: 0.0817 - val_recall: 0.8481 - val_auc: 0.9755 - val_prc: 0.6147
Epoch 12/100
90/90 [==============================] - 1s 8ms/step - loss: 0.2809 - tp: 134.1429 - fp: 2854.5824 - tn: 91132.2088 - fn: 19.6374 - acc: 0.9695 - precision: 0.0454 - recall: 0.8817 - auc: 0.9487 - prc: 0.2253 - val_loss: 0.0639 - val_tp: 67.0000 - val_fp: 770.0000 - val_tn: 44720.0000 - val_fn: 12.0000 - val_acc: 0.9828 - val_precision: 0.0800 - val_recall: 0.8481 - val_auc: 0.9754 - val_prc: 0.6218
Epoch 13/100
90/90 [==============================] - 1s 9ms/step - loss: 0.3396 - tp: 139.3956 - fp: 2952.3626 - tn: 91026.8462 - fn: 21.9670 - acc: 0.9683 - precision: 0.0439 - recall: 0.8461 - auc: 0.9395 - prc: 0.2182 - val_loss: 0.0630 - val_tp: 67.0000 - val_fp: 758.0000 - val_tn: 44732.0000 - val_fn: 12.0000 - val_acc: 0.9831 - val_precision: 0.0812 - val_recall: 0.8481 - val_auc: 0.9763 - val_prc: 0.6152
Epoch 14/100
90/90 [==============================] - 1s 8ms/step - loss: 0.2931 - tp: 133.0000 - fp: 2956.3736 - tn: 91029.8462 - fn: 21.3516 - acc: 0.9685 - precision: 0.0404 - recall: 0.8468 - auc: 0.9493 - prc: 0.2041 - val_loss: 0.0675 - val_tp: 67.0000 - val_fp: 811.0000 - val_tn: 44679.0000 - val_fn: 12.0000 - val_acc: 0.9819 - val_precision: 0.0763 - val_recall: 0.8481 - val_auc: 0.9769 - val_prc: 0.6114
Restoring model weights from the end of the best epoch.
Epoch 00014: early stopping
check training history
func_plot_metrics(model_weight_history)
Look at the confusion matrix
train_predictions_weight = model_weight.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_weight = model_weight.predict(test_features, batch_size=BATCH_SIZE)
results = model_weight.evaluate(test_features, test_labels, batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model_weight.metrics_names, results):
print(name, ': ', value)
print()
func_plot_cm(test_labels, test_predictions_weight)
loss : 0.02326195128262043
tp : 83.0
fp : 210.0
tn : 56654.0
fn : 15.0
acc : 0.9960500001907349
precision : 0.28327643871307373
recall : 0.8469387888908386
auc : 0.969890296459198
prc : 0.7208744287490845
Legitimate Transactions Detected (True Negatives): 56654
Legitimate Transactions Incorrectly Detected (False Positives): 210
Fraudulent Transactions Missed (False Negatives): 15
Fraudulent Transactions Detected (True Positives): 83
Total Fraudulent Transactions: 98
Plot the ROC curve
func_plot_roc("Train weight", train_labels, train_predictions_weight, color='red')
func_plot_roc("Test weight", test_labels, test_predictions_weight, color='blue', linestyle='--')
plt.legend(loc='lower right')
Plot the PR curve
func_plot_prc("Train weight", train_labels, train_predictions_weight, color='red')
func_plot_prc("Test weight", test_labels, test_predictions_weight, color='blue', linestyle='--')
plt.legend(loc='lower right')
model_2: oversample the minority class
Oversampling
pos_features = train_features[bool_train_labels]
neg_features = train_features[~bool_train_labels]
pos_labels = train_labels[bool_train_labels]
neg_labels = train_labels[~bool_train_labels]
using numpy
ids = np.arange(len(pos_features))
choices = np.random.choice(ids, len(neg_features)) # Generates a random sample from a given 1-D array
res_pos_features = pos_features[choices]
res_pos_labels = pos_labels[choices]
res_pos_features.shape
(181961, 29)
resampled_features = np.concatenate([res_pos_features, neg_features], axis=0)
resampled_labels = np.concatenate([res_pos_labels, neg_labels], axis=0)
# shuffle the combined arrays
order = np.arange(len(resampled_labels))
np.random.shuffle(order)
resampled_features = resampled_features[order]
resampled_labels = resampled_labels[order]
resampled_features.shape
(363922, 29)
using tf.data
def func_make_ds(features, labels):
ds = tf.data.Dataset.from_tensor_slices((features, labels))
    ds = ds.shuffle(len(features)).repeat()  # shuffle, then repeat indefinitely
return ds
pos_ds = func_make_ds(pos_features, pos_labels)
neg_ds = func_make_ds(neg_features, neg_labels)
for features, label in pos_ds.take(1):
print(features, label)
tf.Tensor(
[-0.72499268 2.48494345 -5. 4.71183888 -3.39522677 -1.54142183
-5. 2.32846081 -2.95406275 -5. 5. -5.
0.5493042 -5. 0.99578148 -5. -5. -5.
3.88524349 1.80758865 2.51455711 0.55774627 0.9575459 -1.26789652
-3.3464076 1.04065595 4.89460213 2.17526156 -1.45033267], shape=(29,), dtype=float64) tf.Tensor(1, shape=(), dtype=int64)
resampled_ds = tf.data.experimental.sample_from_datasets([pos_ds, neg_ds], weights=[0.5, 0.5])
resampled_ds = resampled_ds.batch(BATCH_SIZE).prefetch(tf.data.experimental.AUTOTUNE)
val_ds = tf.data.Dataset.from_tensor_slices((val_features, val_labels)).cache()
val_ds = val_ds.batch(BATCH_SIZE).prefetch(tf.data.experimental.AUTOTUNE)
for features, label in resampled_ds.take(1):
print(label.numpy().mean())
0.50537109375
resampled_steps_per_epoch = np.ceil(2.0*neg/BATCH_SIZE)  # (pos + neg examples per epoch) / batch_size
resampled_steps_per_epoch
278.0
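Since resampled_ds repeats forever, steps_per_epoch is what defines an "epoch" here: with the 50/50 mix, 278 batches of 2048 cover roughly 2*neg examples, i.e. each negative about once per epoch while the positives are recycled many times over. A quick check (a sketch):

print(resampled_steps_per_epoch * BATCH_SIZE, 2 * neg)  # 569344.0 vs 568630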
train a model on the oversampled data
model_sampled = func_make_model(n_x=train_features.shape[-1], metrics=METRICS)
model_sampled.load_weights(correct_initial_bias_path)
# reset the bias to zero, since this dataset is balanced
model_sampled.layers[-1].bias.assign([0.0])
model_sampled_history = model_sampled.fit(resampled_ds, epochs=EPOCHS, steps_per_epoch=resampled_steps_per_epoch,
callbacks=[early_stopping],
validation_data=val_ds)
Epoch 1/100
278/278 [==============================] - 9s 25ms/step - loss: 0.7187 - tp: 106770.7025 - fp: 41735.0538 - tn: 158500.6093 - fn: 36668.2939 - acc: 0.7680 - precision: 0.6690 - recall: 0.6851 - auc: 0.8375 - prc: 0.7999 - val_loss: 0.1986 - val_tp: 67.0000 - val_fp: 1074.0000 - val_tn: 44416.0000 - val_fn: 12.0000 - val_acc: 0.9762 - val_precision: 0.0587 - val_recall: 0.8481 - val_auc: 0.9660 - val_prc: 0.7480
Epoch 2/100
278/278 [==============================] - 6s 23ms/step - loss: 0.2155 - tp: 128022.4910 - fp: 8284.8065 - tn: 135214.0609 - fn: 15191.3011 - acc: 0.9150 - precision: 0.9348 - recall: 0.8921 - auc: 0.9683 - prc: 0.9752 - val_loss: 0.1105 - val_tp: 67.0000 - val_fp: 704.0000 - val_tn: 44786.0000 - val_fn: 12.0000 - val_acc: 0.9843 - val_precision: 0.0869 - val_recall: 0.8481 - val_auc: 0.9679 - val_prc: 0.7478
Epoch 3/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1596 - tp: 129871.1828 - fp: 4884.9964 - tn: 138569.3584 - fn: 13387.1219 - acc: 0.9358 - precision: 0.9629 - recall: 0.9063 - auc: 0.9833 - prc: 0.9858 - val_loss: 0.0815 - val_tp: 68.0000 - val_fp: 640.0000 - val_tn: 44850.0000 - val_fn: 11.0000 - val_acc: 0.9857 - val_precision: 0.0960 - val_recall: 0.8608 - val_auc: 0.9694 - val_prc: 0.7165
Epoch 4/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1347 - tp: 131100.8423 - fp: 4062.0036 - tn: 139467.2903 - fn: 12082.5233 - acc: 0.9432 - precision: 0.9699 - recall: 0.9146 - auc: 0.9890 - prc: 0.9900 - val_loss: 0.0661 - val_tp: 68.0000 - val_fp: 571.0000 - val_tn: 44919.0000 - val_fn: 11.0000 - val_acc: 0.9872 - val_precision: 0.1064 - val_recall: 0.8608 - val_auc: 0.9694 - val_prc: 0.7182
Epoch 5/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1192 - tp: 132650.0143 - fp: 3900.9642 - tn: 139417.2115 - fn: 10744.4695 - acc: 0.9486 - precision: 0.9712 - recall: 0.9245 - auc: 0.9918 - prc: 0.9923 - val_loss: 0.0570 - val_tp: 67.0000 - val_fp: 569.0000 - val_tn: 44921.0000 - val_fn: 12.0000 - val_acc: 0.9873 - val_precision: 0.1053 - val_recall: 0.8481 - val_auc: 0.9670 - val_prc: 0.7172
Epoch 6/100
278/278 [==============================] - 6s 22ms/step - loss: 0.1076 - tp: 133341.7706 - fp: 3745.2294 - tn: 139902.4552 - fn: 9723.2043 - acc: 0.9524 - precision: 0.9724 - recall: 0.9312 - auc: 0.9936 - prc: 0.9936 - val_loss: 0.0479 - val_tp: 67.0000 - val_fp: 508.0000 - val_tn: 44982.0000 - val_fn: 12.0000 - val_acc: 0.9886 - val_precision: 0.1165 - val_recall: 0.8481 - val_auc: 0.9682 - val_prc: 0.7194
Epoch 7/100
278/278 [==============================] - 6s 21ms/step - loss: 0.0995 - tp: 134289.4444 - fp: 3585.8961 - tn: 140065.9319 - fn: 8771.3871 - acc: 0.9566 - precision: 0.9739 - recall: 0.9381 - auc: 0.9947 - prc: 0.9946 - val_loss: 0.0433 - val_tp: 67.0000 - val_fp: 502.0000 - val_tn: 44988.0000 - val_fn: 12.0000 - val_acc: 0.9887 - val_precision: 0.1178 - val_recall: 0.8481 - val_auc: 0.9685 - val_prc: 0.7209
Epoch 8/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0906 - tp: 136117.1470 - fp: 3446.8566 - tn: 139792.3333 - fn: 7356.3226 - acc: 0.9623 - precision: 0.9752 - recall: 0.9490 - auc: 0.9956 - prc: 0.9955 - val_loss: 0.0391 - val_tp: 67.0000 - val_fp: 458.0000 - val_tn: 45032.0000 - val_fn: 12.0000 - val_acc: 0.9897 - val_precision: 0.1276 - val_recall: 0.8481 - val_auc: 0.9647 - val_prc: 0.7208
Epoch 9/100
278/278 [==============================] - 6s 21ms/step - loss: 0.0871 - tp: 136469.3692 - fp: 3419.2043 - tn: 139679.6810 - fn: 7144.4050 - acc: 0.9630 - precision: 0.9756 - recall: 0.9498 - auc: 0.9959 - prc: 0.9957 - val_loss: 0.0363 - val_tp: 67.0000 - val_fp: 448.0000 - val_tn: 45042.0000 - val_fn: 12.0000 - val_acc: 0.9899 - val_precision: 0.1301 - val_recall: 0.8481 - val_auc: 0.9607 - val_prc: 0.7211
Epoch 10/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0823 - tp: 136684.5591 - fp: 3467.4803 - tn: 140006.9319 - fn: 6553.6882 - acc: 0.9649 - precision: 0.9751 - recall: 0.9540 - auc: 0.9963 - prc: 0.9961 - val_loss: 0.0333 - val_tp: 67.0000 - val_fp: 437.0000 - val_tn: 45053.0000 - val_fn: 12.0000 - val_acc: 0.9901 - val_precision: 0.1329 - val_recall: 0.8481 - val_auc: 0.9611 - val_prc: 0.7214
Epoch 11/100
278/278 [==============================] - 6s 22ms/step - loss: 0.0779 - tp: 137468.5520 - fp: 3409.0538 - tn: 139987.1111 - fn: 5847.9427 - acc: 0.9673 - precision: 0.9760 - recall: 0.9581 - auc: 0.9967 - prc: 0.9965 - val_loss: 0.0317 - val_tp: 67.0000 - val_fp: 436.0000 - val_tn: 45054.0000 - val_fn: 12.0000 - val_acc: 0.9902 - val_precision: 0.1332 - val_recall: 0.8481 - val_auc: 0.9515 - val_prc: 0.7120
Restoring model weights from the end of the best epoch.
Epoch 00011: early stopping
check training history
func_plot_metrics(model_sampled_history)
Look at the confusion matrix
train_predictions_sampled = model_sampled.predict(train_features, batch_size=BATCH_SIZE)
test_predictions_sampled = model_sampled.predict(test_features, batch_size=BATCH_SIZE)
results = model_sampled.evaluate(test_features, test_labels, batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model_sampled.metrics_names, results):
print(name, ': ', value)
print()
func_plot_cm(test_labels, test_predictions_sampled)
loss : 0.19912180304527283
tp : 91.0
fp : 1298.0
tn : 55566.0
fn : 7.0
acc : 0.9770900011062622
precision : 0.06551475822925568
recall : 0.9285714030265808
auc : 0.9868985414505005
prc : 0.7712134718894958
Legitimate Transactions Detected (True Negatives): 55566
Legitimate Transactions Incorrectly Detected (False Positives): 1298
Fraudulent Transactions Missed (False Negatives): 7
Fraudulent Transactions Detected (True Positives): 91
Total Fraudulent Transactions: 98
Plot the ROC curve
func_plot_roc("Train sampled", train_labels, train_predictions_sampled, color='red')
func_plot_roc("Test sampled", test_labels, test_predictions_sampled, color='blue', linestyle='--')
plt.legend(loc='lower right')
Plot the PR curve
func_plot_prc("Train sampled", train_labels, train_predictions_sampled, color='red')
func_plot_prc("Test sampled", test_labels, test_predictions_sampled, color='blue', linestyle='--')
plt.legend(loc='lower right')
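As a wrap-up, the three models can be overlaid on shared ROC and PR axes by reusing the helpers defined above (a sketch that was not part of the original run; it assumes the three test_predictions_* arrays are still in scope):

# ROC comparison on the test set
func_plot_roc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0])
func_plot_roc("Test Weight", test_labels, test_predictions_weight, color=colors[1])
func_plot_roc("Test Sampled", test_labels, test_predictions_sampled, color=colors[2])
plt.legend(loc='lower right')
plt.show()

# PR comparison on the test set
func_plot_prc("Test Baseline", test_labels, test_predictions_baseline, color=colors[0])
func_plot_prc("Test Weight", test_labels, test_predictions_weight, color=colors[1])
func_plot_prc("Test Sampled", test_labels, test_predictions_sampled, color=colors[2])
_ = plt.legend(loc='lower right')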