Overview
I have recently been updating the new-user model, which is still trained with XGBoost. Since the training set has grown to roughly 200,000 samples, I wanted to try a neural network and see how it performs.
TensorFlow 2.0 integrates Keras and is very easy to use; the standalone Keras package no longer receives separate updates and now lives on as a module of TensorFlow. In this post we use tf.keras from TensorFlow 2.0 to train a model on our structured data.
1. Import the feature list and data
import numpy as np
import pandas as pd
import sklearn
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
from tensorflow.keras import models, layers, losses, metrics
# Feature list: feature_list.txt is assumed to hold one feature name per line;
# reading it with index_col=0 puts the names in the index, and transposing
# moves them to the column labels, so list(...) returns them as a Python list
name_list = pd.read_csv('feature_list.txt', header=None, index_col=0)
my_feature_names = list(name_list.transpose())
# Load the data
df_total = pd.read_csv('data_total.csv')
df_total.head()
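As a side note, assuming feature_list.txt really does contain just one feature name per line, the same list can be built without the transpose trick used above:
# Equivalent, more direct way to read the feature names (assumes one name per line)
my_feature_names = pd.read_csv('feature_list.txt', header=None)[0].tolist()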
2. Data preprocessing and dataset split
# Fill missing values with 0
df_total = df_total.fillna(0)
# Split the dataset by application time
df_train = df_total[df_total.apply_time < '2020-01-21 00:00:00']
df_val = df_total[(df_total.apply_time >= '2020-01-21 00:00:00') & (df_total.apply_time < '2020-02-01 00:00:00')]
df_test = df_total[df_total.apply_time >= '2020-02-01 00:00:00']
# Select the feature and label columns we need
train_x = df_train[my_feature_names]
train_y = df_train['label']
val_x = df_val[my_feature_names]
val_y = df_val['label']
test_x = df_test[my_feature_names]
test_y = df_test['label']
# Standardize the features (fit the scaler on the training set only)
scaler = StandardScaler()
train_x = scaler.fit_transform(train_x)
val_x = scaler.transform(val_x)
test_x = scaler.transform(test_x)
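Before moving on, it can be worth sanity-checking the size and label balance of each split, since the class weights computed later depend on them. A minimal sketch using the variables defined above:
# Sanity check: sample counts and positive-class rate per split
for name, (x, y) in {'train': (train_x, train_y),
                     'val': (val_x, val_y),
                     'test': (test_x, test_y)}.items():
    print(f'{name}: {x.shape[0]} samples, positive rate {y.mean():.4f}')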
3. Model construction
tf.keras.backend.clear_session()
METRICS = [
    tf.keras.metrics.AUC(name='auc'),
]
def make_model(metrics=METRICS, output_bias=None):
    # Optionally initialize the output unit's bias (useful for imbalanced data)
    if output_bias is not None:
        output_bias = tf.keras.initializers.Constant(output_bias)
    model = tf.keras.Sequential([
        layers.Dense(64, activation='relu',
                     input_shape=(train_x.shape[-1],)),
        layers.Dropout(0.2),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.2),
        layers.Dense(32, activation='relu'),
        layers.Dense(1, activation='sigmoid',
                     bias_initializer=output_bias),
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
        loss=losses.BinaryCrossentropy(),
        metrics=metrics)
    return model
# Set up early stopping on the validation AUC
EPOCHS = 100
BATCH_SIZE = 2000
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_auc',
    verbose=1,
    patience=20,
    mode='max',
    restore_best_weights=True)
# Handle class imbalance: weight each class inversely to its frequency,
# scaled so the overall loss magnitude stays roughly unchanged
neg = len(train_y) - sum(train_y)
pos = sum(train_y)
total = len(train_y)
weight_for_0 = (1 / neg) * (total) / 2.0
weight_for_1 = (1 / pos) * (total) / 2.0
class_weight = {0: weight_for_0, 1: weight_for_1}
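make_model also accepts an output_bias argument that this post leaves at its default. Following the referenced Classification on imbalanced data tutorial, one could additionally initialize the output layer's bias to log(pos/neg) so that the untrained model's predictions already reflect the class ratio; a small optional sketch, not used in the run below:
# Optional: output-bias initialization as in the "Classification on
# imbalanced data" tutorial; not used in the run below
initial_bias = np.log([pos / neg])
# model = make_model(output_bias=initial_bias)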
# Build the model
model = make_model()
model.summary()
The output is as follows:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 64) 16384
_________________________________________________________________
dropout (Dropout) (None, 64) 0
_________________________________________________________________
dense_1 (Dense) (None, 128) 8320
_________________________________________________________________
dropout_1 (Dropout) (None, 128) 0
_________________________________________________________________
dense_2 (Dense) (None, 32) 4128
_________________________________________________________________
dense_3 (Dense) (None, 1) 33
=================================================================
Total params: 28,865
Trainable params: 28,865
Non-trainable params: 0
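As a quick check on the architecture, each Dense layer contributes (inputs + 1) × units parameters: the first layer's 16,384 = (255 + 1) × 64, which implies 255 input features, and 8,320 = (64 + 1) × 128, 4,128 = (128 + 1) × 32, and 33 = (32 + 1) × 1 account for the rest of the 28,865 total.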
4. Model training
weighted_history = model.fit(
    train_x,
    train_y,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    callbacks=[early_stopping],
    validation_data=(val_x, val_y),
    # Pass the class weights computed above
    class_weight=class_weight)
The output is as follows:
Train on 206917 samples, validate on 15830 samples
Epoch 1/100
206917/206917 [==============================] - 3s 12us/sample - loss: 0.6584 - auc: 0.6498 - val_loss: 0.6108 - val_auc: 0.6729
Epoch 2/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.6305 - auc: 0.6974 - val_loss: 0.6042 - val_auc: 0.6840
Epoch 3/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.6238 - auc: 0.7075 - val_loss: 0.6018 - val_auc: 0.6895
Epoch 4/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.6190 - auc: 0.7142 - val_loss: 0.5987 - val_auc: 0.6940
Epoch 5/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.6157 - auc: 0.7190 - val_loss: 0.5978 - val_auc: 0.6961
Epoch 6/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.6126 - auc: 0.7230 - val_loss: 0.5957 - val_auc: 0.6989
Epoch 7/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.6104 - auc: 0.7257 - val_loss: 0.5951 - val_auc: 0.7007
Epoch 8/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.6082 - auc: 0.7284 - val_loss: 0.5947 - val_auc: 0.7019
Epoch 9/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.6067 - auc: 0.7301 - val_loss: 0.5937 - val_auc: 0.7034
Epoch 10/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.6043 - auc: 0.7335 - val_loss: 0.5937 - val_auc: 0.7038
Epoch 11/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.6035 - auc: 0.7344 - val_loss: 0.5934 - val_auc: 0.7036
Epoch 12/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.6016 - auc: 0.7365 - val_loss: 0.5924 - val_auc: 0.7046
Epoch 13/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.6013 - auc: 0.7367 - val_loss: 0.5930 - val_auc: 0.7041
Epoch 14/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5996 - auc: 0.7390 - val_loss: 0.5925 - val_auc: 0.7042
Epoch 15/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5984 - auc: 0.7403 - val_loss: 0.5930 - val_auc: 0.7045
Epoch 16/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5976 - auc: 0.7412 - val_loss: 0.5937 - val_auc: 0.7034
Epoch 17/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5961 - auc: 0.7430 - val_loss: 0.5942 - val_auc: 0.7034
Epoch 18/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5948 - auc: 0.7444 - val_loss: 0.5946 - val_auc: 0.7027
Epoch 19/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5938 - auc: 0.7455 - val_loss: 0.5949 - val_auc: 0.7023
Epoch 20/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5924 - auc: 0.7472 - val_loss: 0.5944 - val_auc: 0.7024
Epoch 21/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5925 - auc: 0.7471 - val_loss: 0.5953 - val_auc: 0.7028
Epoch 22/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5915 - auc: 0.7482 - val_loss: 0.5944 - val_auc: 0.7022
Epoch 23/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5906 - auc: 0.7488 - val_loss: 0.5964 - val_auc: 0.7008
Epoch 24/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5900 - auc: 0.7496 - val_loss: 0.5947 - val_auc: 0.7025
Epoch 25/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5894 - auc: 0.7503 - val_loss: 0.5956 - val_auc: 0.7031
Epoch 26/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5882 - auc: 0.7517 - val_loss: 0.5944 - val_auc: 0.7028
Epoch 27/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5870 - auc: 0.7532 - val_loss: 0.5975 - val_auc: 0.7001
Epoch 28/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5869 - auc: 0.7530 - val_loss: 0.5965 - val_auc: 0.7022
Epoch 29/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5861 - auc: 0.7537 - val_loss: 0.5970 - val_auc: 0.7011
Epoch 30/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5854 - auc: 0.7543 - val_loss: 0.5960 - val_auc: 0.7015
Epoch 31/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5844 - auc: 0.7559 - val_loss: 0.5994 - val_auc: 0.6989
Epoch 32/100
Restoring model weights from the end of the best epoch.
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5836 - auc: 0.7568 - val_loss: 0.5982 - val_auc: 0.6992
Epoch 00032: early stopping
The best validation AUC is 0.7046, which still falls somewhat short of the XGBoost model; with some hyperparameter tuning the gap should narrow.
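Since early stopping restored the weights from the best epoch, the held-out test set can also be scored to compare against XGBoost on the same basis. A minimal sketch using sklearn's roc_auc_score (the exact number will of course vary):
from sklearn.metrics import roc_auc_score

# Score the restored best-epoch weights on the held-out test set
test_pred = model.predict(test_x, batch_size=BATCH_SIZE).ravel()
print('Test AUC:', roc_auc_score(test_y, test_pred))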
This post mainly draws on the following official TensorFlow documentation: Classify structured data and Classification on imbalanced data.