Skip to content

Commit 517fbb0

Browse files
Add files via upload
1 parent de2fc8d commit 517fbb0

File tree

1 file changed

+52
-19
lines changed

1 file changed

+52
-19
lines changed

train.py

Lines changed: 52 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import datetime
66
from sklearn.metrics import confusion_matrix
77
from sklearn.metrics import f1_score, precision_score, recall_score
8-
import time
98
import warnings
109
warnings.filterwarnings("ignore")
1110

@@ -17,13 +16,13 @@
1716
]
1817

1918
params = {
20-
'max_depth': 4,
21-
'eta': 0.05,
19+
'max_depth': 3,
20+
'eta': 0.03,
2221
'objective': 'binary:logistic',
2322
'eval_metric': 'auc',
2423
}
2524

26-
def train(train_features,train_labels,num_round=400):
25+
def train(train_features,train_labels,num_round=900):
2726
dtrain = xgb.DMatrix(train_features, label=train_labels)
2827
bst = xgb.train(params, dtrain, num_round)
2928
# get best_threshold
@@ -90,19 +89,15 @@ def get_feature_importances(bst):
9089
importance = sorted(importance, key=lambda x: x[0][1], reverse=True)
9190
return importance
9291

93-
if __name__ == '__main__':
94-
model_save_pth = "model/"+datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
95-
if not os.path.exists(model_save_pth):
96-
os.makedirs(model_save_pth)
92+
def train_loop(num_round=900):
9793
precision_list = []
9894
recall_list = []
9995
f1_list = []
10096
c_matrix_list = []
10197
feature_importance_list = []
10298
for i in range(len(os.listdir("Input"))):
103-
time.sleep(1)
10499
train_features, train_labels, test_features, test_labels = preprocess("Input/" + str(i))
105-
bst, best_threshold = train(train_features, train_labels)
100+
bst, best_threshold = train(train_features, train_labels, num_round)
106101
precision, recall, f1, c_matrix = test(bst,best_threshold, test_features, test_labels)
107102
feature_importance = get_feature_importances(bst)
108103
#print(f"Positive rate in Training: {sum(train_labels)/len(train_labels)*100:.2f}%")
@@ -116,12 +111,8 @@ def get_feature_importances(bst):
116111
bst.save_model(model_save_pth+f"/{i}.model")
117112
with open(model_save_pth+f"/{i}.threshold",'w') as f:
118113
f.write(str(best_threshold))
119-
# give evaluation results
120-
print("Average Precision: %.2f" % np.mean(precision_list))
121-
print("Average Recall: %.2f" % np.mean(recall_list))
122-
print("Average F1: %.2f" % np.mean(f1_list))
123-
print(f1_list)
124-
print(np.mean(c_matrix_list,axis=0))
114+
#print(f1_list)
115+
#print(np.mean(c_matrix_list,axis=0))
125116
# evaluate feature importance
126117
feature_name_importance = {}
127118
for feature_importance in feature_importance_list:
@@ -131,6 +122,48 @@ def get_feature_importances(bst):
131122
else:
132123
feature_name_importance[feature_name] = im[1]
133124
feature_name_importance = sorted(feature_name_importance.items(), key=lambda x: x[1], reverse=True)
134-
print('feature importance:')
135-
for item in feature_name_importance:
136-
print(item)
125+
return precision_list, recall_list, f1_list, c_matrix_list, feature_name_importance
126+
127+
def optimize_hyperparameter(eta_candid, max_depth_candid, num_round_candid):
    """Grid-search over eta, max_depth and num_round, scored by mean F1.

    Mutates the module-level ``params`` dict for each candidate and runs a
    full ``train_loop`` per combination, so this is expensive.

    Args:
        eta_candid: iterable of learning-rate candidates.
        max_depth_candid: iterable of max tree-depth candidates.
        num_round_candid: iterable of boosting-round candidates.

    Returns:
        (best_params, best_precision, best_recall, best_f1) where
        best_params is a snapshot of ``params`` for the winning combination
        plus a ``'num_round'`` entry (num_round is not part of ``params``).
        best_params/best_precision/best_recall are None if no combination
        achieved a mean F1 above 0.
    """
    best_f1 = 0
    # Initialize so we never hit UnboundLocalError when nothing beats best_f1.
    best_params = None
    best_precision = None
    best_recall = None
    for eta in eta_candid:
        for max_depth in max_depth_candid:
            for num_round in num_round_candid:
                print(eta, max_depth, num_round)
                params["eta"] = eta
                params["max_depth"] = max_depth
                precision_list, recall_list, f1_list, c_matrix_list, feature_name_importance = train_loop(num_round)
                mean_f1 = np.mean(f1_list)
                if mean_f1 > best_f1:
                    best_f1 = mean_f1
                    # Copy: ``params`` is mutated on later iterations; storing the
                    # dict itself would silently overwrite the recorded best setting.
                    best_params = dict(params)
                    # Record the winning round count too — it lives outside ``params``.
                    best_params["num_round"] = num_round
                    best_precision = np.mean(precision_list)
                    best_recall = np.mean(recall_list)
    return best_params, best_precision, best_recall, best_f1
142+
143+
144+
if __name__ == '__main__':
    # Timestamped per-run model directory; note train_loop() reads this as a
    # module-level global when saving models/thresholds — it must be set first.
    model_save_pth = "model/" + datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    if not os.path.exists(model_save_pth):
        os.makedirs(model_save_pth)

    # Named toggle instead of a bare `if False:` — flip to True to run the
    # hyperparameter grid search before the final training pass.
    TUNE_HYPERPARAMETERS = False
    if TUNE_HYPERPARAMETERS:
        eta_candidate = [0.3, 0.2, 0.15, 0.1, 0.08, 0.05, 0.03]
        max_depth_candidate = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]
        num_round_candidate = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
        best_params, best_precision, best_recall, best_f1 = optimize_hyperparameter(
            eta_candidate, max_depth_candidate, num_round_candidate)
        print(best_params)
        print(best_precision)
        print(best_recall)
        print(best_f1)

    # Final training over all input folds with the default settings.
    precision_list, recall_list, f1_list, c_matrix_list, feature_name_importance = train_loop()
    # give evaluation results
    print("Average Precision: %.3f" % np.mean(precision_list))
    print("Average Recall: %.3f" % np.mean(recall_list))
    print("Average F1: %.3f" % np.mean(f1_list))
    print(f1_list)
    print("Average Confusion Matrix: \n", np.mean(c_matrix_list, axis=0))
    print("Feature Importance:")
    for importance in feature_name_importance:
        print(f"{importance[0]}: {importance[1]}")

0 commit comments

Comments
 (0)