-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathlog_reg.py
59 lines (45 loc) · 2.04 KB
/
log_reg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
import numpy as np
from time import time
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from pyspark import SparkContext
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.classification import LogisticRegressionWithLBFGS
def parse_doc(values):
    """Turn one row of the stacked label/feature matrix into a LabeledPoint.

    values -- an indexable, sliceable row; element 0 is passed to
              LabeledPoint as the label and the remaining elements are
              passed as the feature vector.
    """
    label = values[0]
    features = values[1:]
    return LabeledPoint(label, features)
def main():
    """Train and evaluate a logistic regression text classifier with Spark MLlib.

    Pipeline, as implemented below:
      1. Fetch the 20 newsgroups data restricted to four categories.
      2. Build a tf-idf matrix capped at ``num_features`` vocabulary terms.
      3. Prepend the label column, distribute the rows as an RDD of
         LabeledPoint and split it with weights 8:2 into train and test.
      4. Fit LogisticRegressionWithLBFGS with L1 regularisation.
      5. Print the fraction of misclassified test documents.

    The SparkContext created here is always stopped before returning, even
    when training or evaluation raises.
    """
    # parameters
    num_features = 400  # vocabulary size (max_features of the vectorizer)

    # load data
    print("loading 20 newsgroups dataset...")
    categories = ['rec.autos', 'rec.sport.hockey', 'comp.graphics', 'sci.space']
    tic = time()
    dataset = fetch_20newsgroups(shuffle=True, random_state=0,
                                 categories=categories,
                                 remove=('headers', 'footers', 'quotes'))
    # Only documents from the four selected categories are returned, not the
    # whole corpus.
    train_corpus = dataset.data
    train_labels = dataset.target
    toc = time()
    print("elapsed time: %.4f sec" % (toc - tic))

    # tf-idf vectorizer
    tfidf = TfidfVectorizer(max_df=0.5, max_features=num_features,
                            min_df=2, stop_words='english', use_idf=True)
    # Densified so that the label column can be stacked on with np.hstack.
    X_tfidf = tfidf.fit_transform(train_corpus).toarray()

    # append document labels: column 0 is the label, the rest are features,
    # which is the row layout parse_doc expects
    train_labels = train_labels.reshape(-1, 1)
    X_all = np.hstack([train_labels, X_tfidf])

    # distribute the data
    sc = SparkContext('local', 'log_reg')
    try:
        rdd = sc.parallelize(X_all)
        labeled_corpus = rdd.map(parse_doc)
        train_RDD, test_RDD = labeled_corpus.randomSplit([8, 2], seed=0)

        # distributed logistic regression
        print("training logistic regression...")
        model = LogisticRegressionWithLBFGS.train(
            train_RDD, regParam=1, regType='l1', numClasses=len(categories))

        # evaluate the model on test data
        labels_and_preds = test_RDD.map(
            lambda p: (p.label, model.predict(p.features)))
        # Index into the (label, prediction) pair instead of using a
        # tuple-parameter lambda, which is a SyntaxError on Python 3.
        num_wrong = labels_and_preds.filter(
            lambda pair: pair[0] != pair[1]).count()
        test_err = num_wrong / float(test_RDD.count())
        # Formatted so the output matches the old `print a, b` statement
        # on both Python 2 and Python 3.
        print("%s %s" % ("log-reg test error: ", test_err))

        # To persist the fitted model, uncomment:
        # model.save(sc, './log_reg_lbfgs_model')
    finally:
        # Release the local Spark context even if a step above failed.
        sc.stop()
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, so that importing
    # this module does not start a Spark job.
    main()