A Spam SMS Classifier in Python 2.7

  I recently took part in a competition and wrote a spam SMS classifier; this post is a record of how it works.

  The official data comes as CSV files: the training set has 800,000 records and the test set has 200,000. Each training record has the format: line number, label (0 for a normal message, 1 for spam), message content. Each test record has the format: line number, message content. The required output format is: line number, label, saved as CSV.

  The approach boils down to the following steps:

    1. Read the input files.

    2. Split each line into line number, label, and message content. Since the message content itself may contain spaces, a simple split() is unreliable; the regular-expression module re is used to match and split instead (see the sketch after this list).

    3. Store the parsed records in a MySQL database, so that later runs can read them back directly and skip the steps above.

    4. Segment the message content into words with the third-party library jieba: https://github.com/fxsjy/jieba

    5. Train a model on the segmented words using the naive Bayes algorithm, via the third-party library scikit-learn: http://scikit-learn.org/stable

    6. Read the test set from the database, classify each message, and write the results to a file.
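
  As a sketch of step 2, this is how the regular expressions used later in ImportIntoDB.py pull the fields apart even when the content contains spaces. The sample line below is hypothetical; real records follow the same layout.

# -*- coding:utf-8 -*-
import re

# pattern for the training set: line number, label (0/1), message content
split_pattern_80w = re.compile(u'([0-9]+).*?([01])(.*)')

line = u'12345\t0\t今天 天气 不错'  # hypothetical training line; the content contains spaces
match = re.match(split_pattern_80w, line)
sms_id, sms_type, content = match.group(1), match.group(2), match.group(3).lstrip()
print sms_id, sms_type, content  # -> 12345 0 今天 天气 不错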

  The final implementation consists of four .py files:

    1. ImportIntoDB.py: preprocesses the data and imports it into the database; used only on the first run.

    2. DataHandler.py: reads the data from the database, segments the messages into words, and prepares the data used to train the model.

    3. Classifier.py: reads the test set from the database, classifies it with the trained model, and writes the results to a file.

    4. Main.py: the program's entry point.


  A full run of the final program takes 260-270 seconds on average. The code follows:

  ImportIntoDB.py:

# -*- coding:utf-8 -*-
__author__ = 'Jz'

import MySQLdb
import codecs
import re
import time

# txt_path = 'D:/coding_file/python_file/Big Data/trash message/train80w.txt'
txt_path = 'D:/coding_file/python_file/Big Data/trash message/test20w.txt'

# use regular expressions to split each line into its parts
# split_pattern_80w = re.compile(u'([0-9]+).*?([01])(.*)')
split_pattern_20w = re.compile(u'([0-9]+)(.*)')

txt = codecs.open(txt_path, 'r')
lines = txt.readlines()
start_time = time.time()

# connect to the MySQL database
con = MySQLdb.connect(host='localhost', port=3306, user='root', passwd='*****', db='TrashMessage', charset='UTF8')
cur = con.cursor()

# insert into the 'train' table
# sql = 'insert into train(sms_id, sms_type, content) values (%s, %s, %s)'
# for line in lines:
#     match = re.match(split_pattern_80w, line)
#     sms_id, sms_type, content = match.group(1), match.group(2), match.group(3).lstrip()
#     cur.execute(sql, (sms_id, sms_type, content))
#     print sms_id
# # commit transaction
# con.commit()

# insert into the 'test' table
sql = 'insert into test(sms_id, content) values (%s, %s)'
for line in lines:
    match = re.match(split_pattern_20w, line)
    sms_id, content = match.group(1), match.group(2).lstrip()
    cur.execute(sql, (sms_id, content))
    print sms_id
# commit transaction
con.commit()

cur.close()
con.close()
txt.close()
end_time = time.time()
print 'time-consuming: ' + str(end_time - start_time) + 's.'
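
  The scripts assume the train and test tables already exist in the TrashMessage database. Below is a minimal one-off setup sketch matching the insert statements above; the column names come from the code, but the types are my assumption.

# -*- coding:utf-8 -*-
# one-off schema setup; column types are assumptions, column names come from the inserts
import MySQLdb

con = MySQLdb.connect(host='localhost', port=3306, user='root', passwd='*****', db='TrashMessage', charset='UTF8')
cur = con.cursor()
cur.execute('create table if not exists train (sms_id int primary key, sms_type int, content text)')
cur.execute('create table if not exists test (sms_id int primary key, content text)')
con.commit()
cur.close()
con.close()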

  DataHandler.py:

# -*- coding:utf-8 -*-
__author__ = 'Jz'

import MySQLdb
import jieba

class DataHandler:
    def __init__(self):
        try:
            self.con = MySQLdb.connect(host='localhost', port=3306, user='root', passwd='*****', db='TrashMessage', charset='UTF8')
            self.cur = self.con.cursor()
        except MySQLdb.OperationalError, oe:
            print 'Connection error! Details:', oe

    def __del__(self):
        self.cur.close()
        self.con.close()

    def query(self, sql):
        self.cur.execute(sql)
        result_set = self.cur.fetchall()
        return result_set

    def resultSetTransformer(self, train, test):
        # space-joined words produced by the jieba module, de-duplicated per message
        train_division = []
        test_division = []
        # list of the class label of each message
        train_class = []
        # segment each training message into words
        for record in train:
            train_class.append(record[1])
            division = jieba.cut(record[2])
            filtered_division_set = set()
            for word in division:
                filtered_division_set.add(word + ' ')
            division = list(filtered_division_set)
            str_word = ''.join(division)
            train_division.append(str_word)

        # handle the test set in the same way
        for record in test:
            division = jieba.cut(record[1])
            filtered_division_set = set()
            for word in division:
                filtered_division_set.add(word + ' ')
            division = list(filtered_division_set)
            str_word = ''.join(division)
            test_division.append(str_word)

        return train_division, train_class, test_division
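
  Note that the set in resultSetTransformer removes duplicate words, so term frequency inside a single message is discarded (the tf part of tf-idf is effectively binary per message) and word order is lost. A small sketch of what one record goes through; the sample message is hypothetical:

# -*- coding:utf-8 -*-
import jieba

content = u'恭喜您获得大奖,请点击链接领取'  # hypothetical spam message
words = set(word + ' ' for word in jieba.cut(content))  # each word kept at most once
print ''.join(words)  # space-joined string handed to the vectorizer; order is arbitrary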

  Classifier.py:

# -*- coding:utf-8 -*-
__author__ = 'Jz'

from DataHandler import DataHandler
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
import time

class Classifier:
    def __init__(self):
        start_time = time.time()
        self.data_handler = DataHandler()
        # get the result sets
        self.train = self.data_handler.query('select * from train')
        self.test = self.data_handler.query('select * from test')
        self.train_division, self.train_class, self.test_division = self.data_handler.resultSetTransformer(self.train, self.test)
        end_time = time.time()
        print 'Classifier finished initializing, time-consuming: ' + str(end_time - start_time) + 's.'

    def getMatrices(self):
        start_time = time.time()
        # convert the collection of raw documents to a matrix of TF-IDF features
        self.tfidf_vectorizer = TfidfVectorizer()
        # learn vocabulary and idf; despite the variable name, this is not a raw count
        # matrix: TfidfVectorizer already returns a tf-idf weighted term-document matrix [sample, feature]
        self.train_count_matrix = self.tfidf_vectorizer.fit_transform(self.train_division)
        # re-normalize the training matrix with a separate TfidfTransformer
        self.tfidf_transformer = TfidfTransformer()
        self.train_tfidf_matrix = self.tfidf_transformer.fit_transform(self.train_count_matrix)
        end_time = time.time()
        print 'Classifier finished getting matrices, time-consuming: ' + str(end_time - start_time) + 's.'

    def classify(self):
        self.getMatrices()
        start_time = time.time()
        # convert the test documents to a matrix of token counts;
        # CountVectorizer's default tokenizer cannot segment Chinese, so we reuse
        # the vocabulary learned from the jieba-segmented training set
        test_tfidf_vectorizer = CountVectorizer(vocabulary=self.tfidf_vectorizer.vocabulary_)
        # build the term-document matrix over that fixed vocabulary
        test_count_matrix = test_tfidf_vectorizer.fit_transform(self.test_division)
        # fit idf weights on the training matrix, then transform the test counts
        test_tfidf_transformer = TfidfTransformer()
        test_tfidf_matrix = test_tfidf_transformer.fit(self.train_count_matrix).transform(test_count_matrix)

        # the multinomial naive Bayes classifier is suitable for classification with
        # discrete features (e.g. word counts for text classification)
        naive_bayes = MultinomialNB(alpha=0.65)
        naive_bayes.fit(self.train_tfidf_matrix, self.train_class)
        prediction = naive_bayes.predict(test_tfidf_matrix)

        # write the results to a csv file
        index = 0
        csv = open('result.csv', 'w')
        for sms_type in prediction:
            csv.write(str(self.test[index][0]) + ',' + str(sms_type) + '\n')
            index += 1
        csv.close()
        end_time = time.time()
        print 'Classifier finished classifying, time-consuming: ' + str(end_time - start_time) + 's.'
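
  A side note on getMatrices: since TfidfVectorizer already applies idf weighting, the extra TfidfTransformer weights the training matrix a second time. The code above is kept as it ran in the competition; below is a minimal sketch of a simpler alternative that uses a single vectorizer for both sets. The tiny sample data is hypothetical.

# -*- coding:utf-8 -*-
# sketch: one TfidfVectorizer shared between train and test
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB

train_division = [u'恭喜 大奖 点击 链接 ', u'明天 开会 时间 ']  # hypothetical segmented messages
train_class = ['1', '0']
test_division = [u'点击 领取 大奖 ']

vectorizer = TfidfVectorizer()
train_matrix = vectorizer.fit_transform(train_division)  # learn vocabulary and idf
test_matrix = vectorizer.transform(test_division)        # reuse the same vocabulary and idf

naive_bayes = MultinomialNB(alpha=0.65)
naive_bayes.fit(train_matrix, train_class)
print naive_bayes.predict(test_matrix)  # e.g. ['1']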

  Main.py:

# -*- coding:utf-8 -*-
__author__ = 'Jz'

import time
from Classifier import Classifier

start_time = time.time()
classifier = Classifier()
classifier.classify()
end_time = time.time()
print 'total time-consuming: ' + str(end_time - start_time) + 's.'
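
  To reproduce: run ImportIntoDB.py once per data file to populate the database (swapping in the commented-out block to import the training set), then run Main.py for each classification pass.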