scikit-learn 之 文本分类

相关链接

分类器、指标、特征提取、特征选择

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import json
import numpy as np
import warnings
warnings.filterwarnings('ignore')

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn import metrics

from sklearn.naive_bayes import GaussianNB
from sklearn.naive_bayes import MultinomialNB
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import SGDClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import ExtraTreeClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.neural_network import BernoulliRBM
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neighbors import RadiusNeighborsClassifier
from sklearn.svm import SVC
from sklearn.svm import LinearSVC
from sklearn.svm import libsvm
from sklearn.ensemble import AdaBoostClassifier
from sklearn.ensemble import BaggingClassifier
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import VotingClassifier
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis as QDA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA

from xgboost import XGBClassifier

from sklearn.feature_selection import SelectKBest
from sklearn.feature_selection import SelectPercentile
from sklearn.feature_selection import chi2

数据处理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def read_file(filename):
    """Load question texts and their coarse labels from a JSON-lines file.

    Each non-blank line must be a JSON object containing at least the keys
    ``"question"`` (the text) and ``"Coarse"`` (its coarse-grained label);
    any other keys are ignored. Whitespace-only lines are skipped.

    Args:
        filename: path of a UTF-8 encoded JSON-lines file.

    Returns:
        A pair ``(all_data, all_tag)`` of parallel lists holding the question
        texts and their labels, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record lacks ``"question"`` or ``"Coarse"``.
    """
    all_data, all_tag = [], []
    with open(filename, 'r', encoding='utf-8') as fr:
        for line in fr:
            # Tolerate empty lines (e.g. an extra blank line at end of file)
            # instead of failing inside json.loads.
            if not line.strip():
                continue
            row = json.loads(line)
            all_data.append(row["question"])
            all_tag.append(row["Coarse"])
    return all_data, all_tag


def get_train_test(train_file='./QC/data/问题集/trainquestion.json',
                   test_file='./QC/data/问题集/testquestion.json'):
    """Load the training and test splits of the question data.

    Both files are read with ``read_file``. Relative paths are resolved
    against the current working directory.

    Args:
        train_file: path of the training JSON-lines file.
        test_file: path of the test JSON-lines file.

    Returns:
        ``(x_train, y_train, x_test, y_test)``, where the ``x_*`` lists hold
        question texts and the ``y_*`` lists hold the matching labels.
    """
    x_train, y_train = read_file(train_file)
    x_test, y_test = read_file(test_file)
    return x_train, y_train, x_test, y_test


def get_ngram_seg(sen, n):
    """Return the overlapping, contiguous length-``n`` slices of ``sen``.

    A window starts at every index ``i`` for which ``i + n`` does not run
    past the end, so an input shorter than ``n`` yields an empty list.

    Example: ``get_ngram_seg("abcd", 2)`` gives ``["ab", "bc", "cd"]``.
    """
    total = len(sen)
    return [sen[start:start + n]
            for start in range(total)
            if start + n <= total]


def get_seg(DATA, n=1):
    """Turn each text into a space-separated string of its character n-grams.

    The space-joined output is meant to be tokenized by ``CountVectorizer``.
    With the default ``n=1`` every character becomes its own token, which
    avoids needing a word segmenter (such as jieba) for Chinese text.

    Args:
        DATA: iterable of strings.
        n: n-gram length in characters; defaults to 1 (single characters).

    Returns:
        A list with one space-joined n-gram string per input text.
    """
    return [" ".join(get_ngram_seg(d, n)) for d in DATA]


def PRF(y_true, y_pred):
    """Print and return accuracy and macro-averaged precision, recall and F1.

    Macro averaging gives every class equal weight regardless of its size.
    The four numbers are printed on one line, space-separated, in the order
    accuracy, precision, recall, F1.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels, aligned with ``y_true``.

    Returns:
        The tuple ``(accuracy, precision, recall, f1)``, so callers can
        compare models without parsing stdout.
    """
    scores = (
        metrics.accuracy_score(y_true, y_pred),
        metrics.precision_score(y_true, y_pred, average='macro'),
        metrics.recall_score(y_true, y_pred, average='macro'),
        metrics.f1_score(y_true, y_pred, average='macro'),
    )
    print(*scores)
    return scores


def _bow_features(x_train_seg, y_train, x_test_seg):
    """Build dense, chi2-selected bag-of-words matrices for train and test."""
    # CountVectorizer's default token_pattern requires 2+ word characters and
    # would drop every single-character token produced by get_seg; \w+ keeps them.
    vectorizer = CountVectorizer(token_pattern=r'\b\w+\b')
    bow_train = vectorizer.fit_transform(x_train_seg)  # sparse term counts
    bow_test = vectorizer.transform(x_test_seg)
    # TF-IDF weighting could be added here with TfidfTransformer.

    # percentile=100 keeps every feature; lower it (or use
    # SelectKBest(chi2, k=...)) to keep only the most label-correlated ones.
    selector = SelectPercentile(chi2, percentile=100)
    train_new = selector.fit_transform(bow_train, y_train)
    test_new = selector.transform(bow_test)  # same columns chosen on train

    # Densify: some estimators (e.g. GaussianNB, QDA) reject sparse input.
    return train_new.toarray(), test_new.toarray()


def CLF(model):
    """Train ``model`` on the question data and print its test-set scores.

    Pipeline: load the splits, segment texts into character unigrams, build
    bag-of-words counts, apply chi2 feature selection, fit, predict, score.
    Prints one line: the model's class name, a tab, then the four scores
    printed by ``PRF``.

    Args:
        model: an unfitted scikit-learn-style classifier with ``fit(X, y)``
            and ``predict(X)``.
    """
    x_train, y_train, x_test, y_test = get_train_test()
    y_train = np.array(y_train)
    train_x, test_x = _bow_features(get_seg(x_train), y_train, get_seg(x_test))

    model.fit(train_x, y_train)
    y_pred = model.predict(test_x)
    print(model.__class__.__name__, end='\t')  # model name
    # predict_proba is deliberately not called: several of the models tried
    # (e.g. LinearSVC, SGDClassifier with its default loss) do not provide it.
    PRF(y_test, y_pred)
# print(metrics.classification_report(y_test, y_pred))

测试不同分类器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
if __name__ == "__main__":
    # Each call trains one classifier and prints: name<TAB>acc P R F1.
    # CLF(GaussianNB())
    CLF(MultinomialNB())
    CLF(LogisticRegression())
    CLF(SGDClassifier())
    CLF(DecisionTreeClassifier())
    CLF(ExtraTreeClassifier())
    CLF(MLPClassifier())
    # BernoulliRBM is a feature transformer, not a classifier (no predict).
    # CLF(BernoulliRBM())
    CLF(KNeighborsClassifier())
    # RadiusNeighborsClassifier errors on test samples with no neighbours
    # inside the radius unless outlier_label is set.
    # CLF(RadiusNeighborsClassifier())
    # libsvm-based SVC; gamma is ignored by the linear kernel, and
    # probability=True enables predict_proba (at extra training cost).
    CLF(SVC(C=1, kernel='linear', gamma=1, shrinking=True, probability=True))
    CLF(LinearSVC())
    # libsvm is a low-level module, not an estimator.
    # CLF(libsvm)
    CLF(AdaBoostClassifier())
    CLF(BaggingClassifier())
    CLF(ExtraTreesClassifier())
    CLF(GradientBoostingClassifier())
    CLF(RandomForestClassifier())
    # VotingClassifier requires an `estimators` list of (name, model) pairs.
    # CLF(VotingClassifier())
    CLF(QDA())
    CLF(LDA())
    # NOTE: recent xgboost versions require integer labels 0..n_classes-1;
    # if the "Coarse" labels are strings, encode them (e.g. LabelEncoder) first.
    CLF(XGBClassifier())
坚持原创技术分享,您的支持将鼓励我继续创作!
0%