Функция продолжает выбирать первый clf как лучшую модель

#python #pandas #scikit-learn #neural-network #classification

Вопрос:

Я следовал этому руководству по созданию классификатора жанров для midi-файлов и, наконец, добрался до той части, где я могу экспериментировать со скрытыми размерами сети и т. Д., Однако я обнаружил, что моя функция продолжает выбирать первую модель как лучшую, независимо от ее точности. Код выглядит так, как будто он должен работать для меня, но я чувствую, что мне не хватает чего-то действительно простого.

 def train_model(t_features, t_labels, v_features, v_labels):
"""
this func trains a nn using a few dif configs
INPUT: training features(nparray flt), training labels(nparray int), validation features(nparray float), validation labels(nparay int)

OUTPUT: the classifier which achieved the best validation accuracy (sklearn neural multilayer perceptron)
"""
#NN and SVM Configs
clf1 = MLPClassifier(solver='adam', alpha=1e-4, hidden_layer_sizes=(100, 100), random_state= 1)
clf2 = MLPClassifier(solver='adam', alpha=1e-4, hidden_layer_sizes=( 40, 20, 10, 5, 1), random_state= 1)
clf3 = MLPClassifier(solver='adam', alpha=1e-4, hidden_layer_sizes=(240, 120, 80, 40, 20, 10, 1), random_state= 1)
clf4 = MLPClassifier(solver='sgd', alpha=1e-4, hidden_layer_sizes=(12, 10), random_state= 1)
clf_svm = SVC()

#keep track of best model
best_clf = None
best_accurracy = 0


#test accuracies of models and get the best one
for clf in [clf1, clf2, clf3, clf4, clf_svm]:
    t_labels_hot = one_hot(t_labels)
    v_labels_hot = one_hot(v_labels)
    if (type(clf) == SVC):
        clf = clf.fit(t_features, t_labels)
    else:
        clf = clf.fit(t_features, t_labels_hot)
    predictions = clf.predict(v_features)
    count = 0
    for i in range(len(v_labels)):
        if (type(clf) != SVC):
            if np.array_equal(v_labels_hot[i], predictions[i]):
                count  = 1
        else:
            if (v_labels[i] == predictions[i]):
                count  = 1

    accuracy = count / len(v_labels_hot)
    if accuracy > best_accurracy:
        best_accurracy= accuracy
        best_clf = clf
print(best_clf)
print("best accuracy:", best_accurracy)
return best_clf