Data diagnosis for high dimensional multi-class machine learning
CS 6780 (Spring 2015)
Abstract
Multiclass labeling of high-dimensional data is important for industries such as online retailing. Because of the complexity of the data and the related nature of the product classes, this learning task is not trivial. Our primary interest is to understand whether a particular learning algorithm is best suited to this data set. Finding no clear winner, we diagnosed the data to identify why learning was difficult for all of the algorithms, and attempted to address the difficulty with available learning algorithms.
1. Introduction
Proper classification of products is key to an accurate understanding of product feedback, quality control, and a number of other business operations for a typical online retailer. Now expand this scenario to a wide array of online retailers operating under a single parent company. Specifically, we are interested in data from the Otto Group, one of the world’s biggest e-commerce companies, which sells millions of products worldwide each day and receives a constant incoming stream of thousands of new products. Because of the large number of subsidiaries and their varying infrastructures, identical products may be classified differently across subsidiaries at a non-negligible rate. We aim to evaluate how well various learning algorithms classify these products, and to investigate why certain algorithms outperform others in specific categories. The work begins with a survey of available learning packages and concludes with an in-depth look at the differences and similarities in accuracy and performance.
2. Data
The data is from Kaggle (Kaggle, 2015) and was made available by the Otto Group. There are a total of 93 features and 9 classes. Each record gives the count of each of the 93 features, along with the true class label. The specific nature of the features and classes has not been disclosed; they carry only numerical identifiers (1 to 93 for features, 1 to 9 for classes).
Preliminary work. Under review by the International Conference on Machine Learning (ICML). Do not distribute.
The dataset has 61,878 records. We split this data into a training set and a test set of 75% and 25%, respectively. On the training set, a 5-fold cross-validation scheme is used to evaluate the performance and robustness of each learning method.
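A minimal sketch of this split with scikit-learn's `train_test_split` and `KFold`. It splits row indices rather than the data itself, so it runs without the Kaggle files; the random seeds and the use of shuffled, unstratified folds are our assumptions, since the text does not say how the split was drawn.

```python
import numpy as np
from sklearn.model_selection import KFold, train_test_split

N_RECORDS = 61878  # number of labeled records in the Otto data

# Hold out 25% of the rows as the test set (the seed is an arbitrary assumption).
indices = np.arange(N_RECORDS)
train_idx, test_idx = train_test_split(indices, test_size=0.25, random_state=0)

# 5-fold cross validation over the 75% training portion only;
# shuffled, unstratified folds are an assumption.
kfold = KFold(n_splits=5, shuffle=True, random_state=0)
fold_sizes = [len(val_pos) for _, val_pos in kfold.split(train_idx)]

print(len(train_idx), len(test_idx))  # 46408 15470
print(fold_sizes)                     # [9282, 9282, 9282, 9281, 9281]
```

scikit-learn rounds the test fraction up, so the test set gets 15,470 rows and the training set the remaining 46,408.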
We began by looking for high correlation values between pairs of the 93 features in the dataset. A correlation plot is shown in Figure 1. Although many of the features appear loosely correlated, no features seemed worth excluding altogether. We also examined the correlation matrices for each of the nine classes to see whether any class had a particular “footprint” in correlation that would inform feature selection, but found that this was not the case. Thus, we moved forward with the complete feature set.
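One way to run the correlation check described above, as a NumPy sketch: compute the 93 × 93 Pearson correlation matrix, zero its diagonal, and report the largest remaining magnitude, both overall and within each class. The synthetic Poisson data is a hypothetical stand-in for the Otto counts, not the real features.

```python
import numpy as np

def max_offdiag_corr(X):
    """Largest absolute Pearson correlation between two distinct feature columns."""
    corr = np.corrcoef(X, rowvar=False)  # (n_features, n_features)
    np.fill_diagonal(corr, 0.0)          # ignore each feature's correlation with itself
    return float(np.max(np.abs(corr)))

def per_class_max_corr(X, y):
    """Repeat the check within each class, looking for a class-specific 'footprint'."""
    return {int(c): max_offdiag_corr(X[y == c]) for c in np.unique(y)}

# Hypothetical stand-in for the Otto counts: independent Poisson features, labels 1..9.
rng = np.random.default_rng(0)
X = rng.poisson(1.0, size=(900, 93))
y = rng.integers(1, 10, size=900)

print(max_offdiag_corr(X))  # small, since these synthetic features are independent
print(per_class_max_corr(X, y))
```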
[Figure 1: heat map of the feature correlation matrix; both axes are feature #.]
Figure 1. Plot of feature correlation matrix. High correlation values only occur on the diagonal.
CS 6780: Advanced Machine Learning Final Project
Because this is a Kaggle competition, an additional test set of 18,000 records is also available through Kaggle, but its class labels are not released. To see how our models compare with others, we submitted predictions from our best models on Kaggle's test set and checked how they scored on the competition's log-loss grading scale. Note, however, that scoring well on the leaderboard is not the main intention of this project.
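For reference, the competition's log-loss score is, as we understand Kaggle's description of it, the mean of $-\log p_{i,y_i}$ over all N records, where $p_{i,y_i}$ is the predicted probability of record i's true class, computed after clipping probabilities to $[\epsilon, 1-\epsilon]$ and rescaling each row to sum to one. A minimal NumPy sketch; the value $\epsilon = 10^{-15}$ is an assumption.

```python
import numpy as np

def multiclass_log_loss(y_true, probs, eps=1e-15):
    """Mean negative log probability assigned to the true class.

    y_true: integer class indices 0..K-1, shape (N,)
    probs:  predicted class probabilities, shape (N, K)
    """
    p = np.clip(np.asarray(probs, dtype=float), eps, 1 - eps)
    p = p / p.sum(axis=1, keepdims=True)  # rescale each row to sum to 1
    rows = np.arange(p.shape[0])
    return float(-np.mean(np.log(p[rows, np.asarray(y_true)])))

# An uninformative model that predicts 1/9 for every class scores log(9).
uniform = np.full((4, 9), 1.0 / 9)
print(round(multiclass_log_loss([0, 3, 5, 8], uniform), 4))  # 2.1972
```

Lower is better; confident wrong predictions are penalized heavily, which is why the clipping matters.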
3. Learning Methods
We used a variety of methods to begin investigating the best way to approach this high-dimensional multiclass machine learning problem. An outline of our overall workflow is provided in Figure 2. In Stages 1-3, we used only our training set (75% of the overall data); after finalizing our models, we applied them to the test set in Stage 4. In Stage 1, we explored a variety of learning algorithms. In Stage 2, we aggregated the results of our promising models, weighted in proportion to the performance we saw in Stage 1. In Stage 3, we re-trained three of the models from Stage 1 and the two models from Stage 2 on the complete training set. Finally, in Stage 4, we tested the five trained models in order to compare their performance.
[Figure 2: workflow diagram spanning Stages 1-4. Key: SVM, RF, NN, NB. Training set: 75% of total data; test set: 25% of total data; 5-fold CV in Stages 1 and 2.]
Figure 2. An overview of the workflow covered in this paper. Stage 1 consisted of the basic methods, including random forests (RF), support vector machines (SVM), naive Bayes (NB), and neural networks (NN), optimized over appropriate grid searches where relevant. This set was five-fold cross validated on the training set. Stage 2 aggregated different combinations of promising models from Stage 1 and weighted them according to their Stage 1 cross-validation performance. These were again cross validated on the same five-fold scheme. In Stage 3, we re-trained five of the models on the whole training set. In Stage 4, we ran the trained models from Stage 3 on the remaining test set.
As shown in Figure 2, we began with basic methods from different families of learning algorithms, including random forests (RF), support vector machines (SVM), naive Bayes (NB), and neural networks (NN). In addition to the four models shown under Stage 1 in the figure, we initially looked at decision trees (DT) before moving to a random forest ensemble. A survey of these methods enabled us to gauge the potential performance of the different learning algorithms on our particular data set. We used machine learning packages from scikit-learn (Pedregosa et al., 2011) for the first three methods, and Keras (Chollet, Francois, 2015; LISA lab, University of Montreal, 2015) to implement the neural network. Computing was performed in both Python and MATLAB.
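A minimal sketch of the Stage 1 survey with scikit-learn's `cross_val_score`, assuming default hyperparameters for each model family (the tuned settings appear in the subsections below) and `MultinomialNB` as the naive Bayes variant, which the text does not name. The neural network is omitted because it was built in Keras, and the data is a hypothetical Poisson stand-in for the Otto counts.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.naive_bayes import MultinomialNB
from sklearn.svm import LinearSVC
from sklearn.tree import DecisionTreeClassifier

def survey_models(X, y, cv=5):
    """Mean k-fold CV accuracy for one representative of each model family."""
    models = {
        "DT": DecisionTreeClassifier(random_state=0),
        "RF": RandomForestClassifier(n_estimators=50, random_state=0),
        "SVM": LinearSVC(C=1.0, max_iter=5000),
        # Count-based NB variant; the paper does not say which NB was used.
        "NB": MultinomialNB(),
    }
    return {name: float(cross_val_score(model, X, y, cv=cv).mean())
            for name, model in models.items()}

# Hypothetical count data with class-dependent Poisson rates (not the Otto data).
rng = np.random.default_rng(0)
y = rng.integers(0, 9, size=900)
rates = rng.uniform(0.2, 2.0, size=(9, 93))
X = rng.poisson(rates[y])

print(survey_models(X, y))
```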
After an initial pass through the four methods, the decision tree seemed promising because it is fast and easy to use within ensemble models. The neural network also appeared to have an advantage over the other methods. The linear SVM did not perform as well, but we pursued it to discern any insights about our data that the others might not provide. Naive Bayes performed worst; its assumption that the 93 features are conditionally independent given the class does not hold for this data set (Murphy, 2012; Shalev-Shwartz & Ben-David, 2014). We therefore excluded naive Bayes from subsequent analysis.
3.1. Decision Trees: Ensemble Methods
The decision tree is a natural building block for an ensemble, because it is computationally cheap to generate many shallow trees, or “weak learners”. We initially investigated both bootstrap aggregating (“bagging”) (Breiman, 1996) and random forests (Breiman, 2001). For bagging, we adjusted the learning rate and number of estimators. For the random forest, we searched over maximum tree depth (20-60) and number of estimators (10-200), as shown in Figure 3, and chose a maximum depth of 38 and 100 estimators. We moved forward with the random forest since it outperformed bagging in early trials. The mean 5-fold cross-validation accuracy of the random forest model is 80.4%.
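The random forest search can be sketched with scikit-learn's `GridSearchCV`. The default grid uses only values stated in the text (the range endpoints plus the chosen depth of 38 and 100 estimators); the exact grid points the authors searched are not given, and the usage lines shrink the grid so the example runs quickly on hypothetical data.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

def tune_random_forest(X, y, depths=(20, 38, 60), n_estimators=(10, 100, 200), cv=5):
    """Grid-search max_depth and n_estimators by k-fold CV accuracy."""
    search = GridSearchCV(
        RandomForestClassifier(random_state=0),
        param_grid={"max_depth": list(depths), "n_estimators": list(n_estimators)},
        scoring="accuracy",
        cv=cv,
    )
    search.fit(X, y)
    return search.best_params_, float(search.best_score_)

# Usage on small hypothetical count data with a reduced grid so it runs quickly.
rng = np.random.default_rng(0)
y = rng.integers(0, 9, size=450)
X = rng.poisson(rng.uniform(0.2, 2.0, size=(9, 93))[y])
best_params, best_score = tune_random_forest(X, y, depths=(5, 20), n_estimators=(10, 30))
print(best_params, round(best_score, 3))
```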
Figure 3. Grid search on maximum depth of trees and number of estimators for random forest. The vertical axis represents log accuracy of 5-fold cross validation. Optimal parameters are a maximum depth of 38 and 100 estimators.
3.2. Linear SVMs
From the outset, the linear SVM appeared less promising than the other options, since the classes did not appear to be linearly separable. We nevertheless kept this model and tuned its parameters to improve cross-validation performance (Cristianini & Shawe-Taylor, 2000). We adjusted the regularization parameter C and searched over the number of features used in training, with features ranked by their χ² (chi-squared) statistic with respect to the class label. This analysis led us to choose C = 1 and to restrict training to the top 60 features, which appeared to balance under- and over-fitting. Figure 4(a) shows the validation scores as a function of C using all features, and Figure 4(b) shows the scores as a function of the number of features with C = 1. The mean 5-fold cross-validation accuracy of the linear SVM is 73.1%.
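A sketch of chi-squared feature selection followed by a linear SVM, assuming scikit-learn's `SelectKBest`, `chi2`, and `LinearSVC`; `max_iter=5000` and the synthetic data are our assumptions. Placing the selector inside a `Pipeline` means the chi-squared ranking is recomputed within each cross-validation fold, so the held-out fold never influences which 60 features are kept.

```python
import numpy as np
from sklearn.feature_selection import SelectKBest, chi2
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC

def chi2_linear_svm(k=60, C=1.0):
    """Keep the k features with the highest chi-squared scores, then fit a linear SVM."""
    return Pipeline([
        ("select", SelectKBest(chi2, k=k)),  # chi2 requires non-negative features; counts qualify
        ("svm", LinearSVC(C=C, max_iter=5000)),
    ])

# Hypothetical count data (not the Otto data).
rng = np.random.default_rng(0)
y = rng.integers(0, 9, size=600)
X = rng.poisson(rng.uniform(0.2, 2.0, size=(9, 93))[y])

cv_acc = cross_val_score(chi2_linear_svm(), X, y, cv=5).mean()
pipe = chi2_linear_svm().fit(X, y)
print(round(float(cv_acc), 3), int(pipe.named_steps["select"].get_support().sum()))
```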
[Figure 4: (a) validation scores versus regularization parameter C (log scale); (b) validation scores versus # of features.]
Figure 4. (a) Validation scores with respect to the regularization parameter C, and (b) with respect to the number of features for the linear SVM. Optimal parameters are C = 1 and 60 features.
3.3. Neural Networks
In applying neural networks to this learning problem, we explored two main architectures: first, a pyramid-shaped network whose first hidden layer has fewer units than the input layer and which narrows with each subsequent layer; and second, a network that expands in width from the input of 93 features, then narrows, and is relatively shallow (Daniel Nouri, 2014). We expected better performance from the former (a narrow, deep network), but found that the latter (a wide, shallow network) outperformed it during cross validation.
The final neural network model consisted of 3 hidden layers: from an input of 93 features, it widened to 500 hidden units, then narrowed to 400 and 300, and finally to an output of 9 classes. Figure 5 shows the relation between the maximum number of epochs and cross-validation scores. We capped training at 30 epochs to prevent overfitting. The mean 5-fold cross-validation accuracy of the neural network is 79.8%.
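The network above was implemented in Keras. As a lighter-weight stand-in, the sketch below builds the same hidden-layer widths (500, 400, 300) with scikit-learn's `MLPClassifier` and caps training at 30 epochs. This is a swapped-in library, not the authors' code: activation (ReLU), optimizer (Adam), batch size, and all other settings are scikit-learn defaults and are assumptions.

```python
import warnings

import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.neural_network import MLPClassifier

def wide_shallow_mlp(max_epochs=30, seed=0):
    """93 inputs -> 500 -> 400 -> 300 hidden units -> one output per class."""
    return MLPClassifier(
        hidden_layer_sizes=(500, 400, 300),
        max_iter=max_epochs,  # with the default adam solver this is the epoch cap
        random_state=seed,
    )

# Hypothetical count data (not the Otto data).
rng = np.random.default_rng(0)
y = rng.integers(0, 9, size=600)
X = rng.poisson(rng.uniform(0.2, 2.0, size=(9, 93))[y]).astype(float)

model = wide_shallow_mlp()
with warnings.catch_warnings():
    # Stopping at the epoch cap can raise a ConvergenceWarning; we silence it here.
    warnings.simplefilter("ignore", ConvergenceWarning)
    model.fit(X, y)
print(model.predict_proba(X[:5]).shape, model.n_iter_)
```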
[Figure 5: validation scores versus max epoch (0-100).]
Figure 5. Relation between number of epochs and cross-validation scores for the neural network with 3 hidden layers of 500, 400, and 300 hidden units.
4. Data Diagnosis
All of the methods plateaued at roughly 73-80% accuracy during cross validation, which called for further investigation. We began by identifying which classes were the most difficult to label correctly, and whether this class-by-class accuracy differed between methods. Figure 6 presents the per-class accuracy of each learning algorithm, $P(\hat{Y} = y \mid Y = y)$, where $Y$ is the true class and $\hat{Y}$ is the predicted class. The figure shows that three classes, 1, 3, and 4, were frequently misclassified.
Within these classes, Class 1 was often mistaken for Class 9, while Classes 3 and 4 were often labeled as Class 2. This was the case across all three learning algorithms, as shown in the confusion matrices for the three models (RF, SVM, and NN in Figures 7, 8, and 9, respectively). When we instead looked at how often a prediction of class $y$ was correct, that is, $P(Y = y \mid \hat{Y} = y)$, the differences between classes were more subtle, and the lower-performing classes included Classes 2 and 9, the classes that misclassified records were most often assigned to (Figure 10). These trends held across all learning algorithms.
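Both quantities can be read directly off a confusion matrix. The sketch below hard-codes the random forest matrix from Figure 7, assuming rows are predicted classes and columns are true classes (our reading of the figure's axes). Recall $P(\hat{Y} = y \mid Y = y)$ divides the diagonal by column totals, and precision $P(Y = y \mid \hat{Y} = y)$ divides it by row totals.

```python
import numpy as np

# Random forest confusion matrix from Figure 7 (cross validation),
# assumed layout: rows = predicted class 1..9, columns = true class 1..9.
RF_CM = np.array([
    [522,     1,    2,   0,    1,     31,   18,   24,   55],
    [ 61, 10879, 3066, 894,   61,    110,  312,   71,   90],
    [  7,  1065, 2748, 239,    1,     11,  149,   22,    1],
    [  1,    57,   47, 817,    0,      8,   23,    0,    0],
    [  3,    12,    0,  13, 1937,      2,   13,    2,    7],
    [145,    27,   12,  64,    6,  10043,  167,  188,  143],
    [ 42,    69,   58,  14,    6,    109, 1240,   25,   15],
    [316,    26,   29,   2,    4,    169,  212, 5862,  165],
    [330,    14,    7,   2,    4,    123,   28,   92, 3264],
])

def per_class_rates(cm):
    """Return (recall, precision, accuracy) for a predicted-by-true count matrix."""
    diag = np.diag(cm).astype(float)
    recall = diag / cm.sum(axis=0)     # divide by true-class totals (columns)
    precision = diag / cm.sum(axis=1)  # divide by predicted-class totals (rows)
    accuracy = diag.sum() / cm.sum()
    return recall, precision, accuracy

recall, precision, accuracy = per_class_rates(RF_CM)
print(round(float(accuracy), 3))                     # 0.804
print(sorted((np.argsort(recall)[:3] + 1).tolist()))  # [1, 3, 4]
```

The overall accuracy recovered this way matches the 80.4% random forest cross-validation accuracy, and the three lowest recalls belong to Classes 1, 3, and 4.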
[Figure 6: per-class accuracy (0 to 1) versus class (1-9) for RF, SVM, and NN.]
Figure 6. Performance $P(\hat{Y} = y \mid Y = y)$ of all methods in each class. $Y$ is the true class and $\hat{Y}$ is the predicted class.
Random Forest (rows: predicted class; columns: true class)
       1     2     3     4     5     6     7     8     9
 1   522     1     2     0     1    31    18    24    55
 2    61 10879  3066   894    61   110   312    71    90
 3     7  1065  2748   239     1    11   149    22     1
 4     1    57    47   817     0     8    23     0     0
 5     3    12     0    13  1937     2    13     2     7
 6   145    27    12    64     6 10043   167   188   143
 7    42    69    58    14     6   109  1240    25    15
 8   316    26    29     2     4   169   212  5862   165
 9   330    14     7     2     4   123    28    92  3264
Figure 7. Random forest confusion matrix during cross validation.
To investigate further, we extracted these five classes (1, 2, 3, 4, and 9) to see whether classification improved when training only on them instead of all nine, but the low-accuracy classes still posed problems. Similarly, we trained exclusively on Classes 2 and 3, but the two classes remained difficult to separate.
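The class-subset experiment can be sketched by masking the data to the chosen labels before cross validation. The random forest settings and the synthetic data are hypothetical; the text does not say which model was used for this check.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

def subset_cv_accuracy(X, y, classes, cv=5, seed=0):
    """k-fold CV accuracy of a random forest trained and scored only on `classes`."""
    mask = np.isin(y, classes)  # keep only records from the chosen classes
    model = RandomForestClassifier(n_estimators=50, random_state=seed)
    return float(cross_val_score(model, X[mask], y[mask], cv=cv).mean())

# Hypothetical count data with labels 1..9 (not the Otto data).
rng = np.random.default_rng(0)
y = rng.integers(1, 10, size=900)
X = rng.poisson(rng.uniform(0.2, 2.0, size=(9, 93))[y - 1])

print(subset_cv_accuracy(X, y, [1, 2, 3, 4, 9]))  # the five classes discussed above
print(subset_cv_accuracy(X, y, [2, 3]))           # Classes 2 and 3 only
```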
To get a clearer picture of what our data looks like, we used the t-distributed stochastic neighbor embedding (t-SNE) algorithm, an effective visualization tool, to reduce our 93 dimensions to a 2-D plot (Van der Maaten & Hinton, 2008). The resulting figure suggested that our difficulty in classifying Class 2 and Class 3 was not a function of our learning algorithms, but of the highly non-linear, difficult-to-separate nature of these two classes. While the other classes show some visual separation in the t-SNE plot (Figure 11), Classes 2 and 3, and to some extent Class 4, share almost the same space and characteristics. Running the t-SNE algorithm again on only Classes 2 and 3 makes the similarity even more apparent (Figure 12).
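A sketch of the t-SNE step with scikit-learn's `TSNE`, assuming perplexity 30, PCA initialization, and a fixed seed (the text does not report its settings). Plotting the two output columns as a scatter plot colored by class yields plots like Figures 11 and 12; the two-class synthetic data is hypothetical.

```python
import numpy as np
from sklearn.manifold import TSNE

def tsne_2d(X, seed=0, perplexity=30.0):
    """Embed the rows of X in two dimensions with t-SNE."""
    tsne = TSNE(n_components=2, perplexity=perplexity, init="pca", random_state=seed)
    return tsne.fit_transform(np.asarray(X, dtype=float))

# Hypothetical counts for two classes only, mirroring the Class 2 vs. Class 3 plot.
rng = np.random.default_rng(0)
y = rng.integers(2, 4, size=300)  # labels 2 and 3
X = rng.poisson(rng.uniform(0.2, 2.0, size=(2, 93))[y - 2])
embedding = tsne_2d(X)
print(embedding.shape)  # (300, 2)
```

t-SNE preserves local neighborhoods rather than global distances, so overlapping clusters in the plot are evidence of similarity, not proof that no classifier can separate the classes.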
SVM (rows: predicted class; columns: true class)
       1     2     3     4     5     6     7     8     9
 1   201     1     1     0     0    23    35    70    27
 2   124 10906  4463  1369    52   144   394   124   134
 3     5   834  1170   177     1    30   161    25    10
 4     1    33    23   268     0     9     8     0     1
 5    13    73    12    21  1957     4    14     7    10
 6   212    75    48   164     4  9831   230   376   221
 7    56    93    89    29     0   126  1073    52    13
 8   366    80   111     8     5   268   197  5501   311
 9   449    55    52     9     1   171    50   131  3013
Figure 8. SVM confusion matrix during cross validation.
Neural Networks (rows: predicted class; columns: true class)
       1     2     3     4     5     6     7     8     9
 1   736    10     2     3     5    71    64   118   129
 2    29  9629  2331   641    29    54   163    45    43
 3    12  2089  3310   351     8    20   180    22     4
 4     1   187   137   941     1    21    20     2     1
 5     3    28     4    21  1960     2    14     3     7
 6   102    34    12    52     6  9926   129   124   116
 7    66   106   143    30     8   168  1477    65    28
 8   200    34    21     5     3   195   101  5802   147
 9   278    33     9     1     0   149    14   105  3265
Figure 9. Neural network confusion matrix during cross validation.
In light of this information, we then aggregated the predictions of our three models: a small, simple ensemble that reuses our existing models by giving each one a vote with an assigned weight. We tried two cases. In the first, we gave the random forest and neural network predictions a weight of 0.4 each and the SVM predictions a weight of 0.2. In the second, we left the SVM model out and split the weight evenly between the random forest and neural network models. The resulting confusion matrices for these two models under our 5-fold cross validation are presented in Figures 13 and 14.
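The weighted vote can be sketched in NumPy as a weighted sum of per-model class scores followed by an argmax. The text does not say whether probabilities or hard labels were combined, so this helper accepts either: pass class probabilities for a soft vote, or one-hot encoded predictions for a hard vote (scikit-learn's `LinearSVC` has no native probability output). The example inputs are hypothetical; only the weights come from the text.

```python
import numpy as np

def weighted_vote(score_matrices, weights):
    """Weighted sum of per-model class scores (each N x K); returns the winning class index per row."""
    w = np.asarray(weights, dtype=float)
    stacked = np.stack([np.asarray(s, dtype=float) for s in score_matrices])  # (M, N, K)
    combined = np.tensordot(w, stacked, axes=1)                               # (N, K)
    return combined.argmax(axis=1)

def one_hot(labels, n_classes):
    """Turn hard predictions into 0/1 score rows so they can be weighted like probabilities."""
    return np.eye(n_classes)[np.asarray(labels)]

# Soft votes: two models' class probabilities, equal weights (the RF + NN case).
p_rf = [[0.6, 0.4], [0.2, 0.8]]
p_nn = [[0.3, 0.7], [0.4, 0.6]]
print(weighted_vote([p_rf, p_nn], [0.5, 0.5]))  # [1 1]

# Hard votes: three models with weights 0.4 / 0.4 / 0.2 (the RF + NN + SVM case).
rf, nn, svm = [0, 1, 2], [0, 2, 2], [1, 1, 1]
print(weighted_vote([one_hot(p, 3) for p in (rf, nn, svm)], [0.4, 0.4, 0.2]))  # [0 1 2]
```

With hard votes and equal weights, any disagreement between two models is a tie, and `argmax` falls back to the lower class index; that is one reason a soft vote may suit the RF + NN case better.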
[Figure 10: per-class accuracy (0 to 1) versus class (1-9) for RF, SVM, and NN.]
Figure 10. Performance $P(Y = y \mid \hat{Y} = y)$ of all methods for each class.
[Figure 11: 2-D t-SNE embedding with points colored by class (1-9).]
Figure 11. t-SNE embedding of the 93-dimensional data in 2-D for all 9 classes.
With the SVM results included, cross-validation accuracy (79.5%) did not improve on the single random forest (80.4%) or neural network (79.8%) models.
The model with equal votes from the random forest and neural network models, however, reached 81.2%, an improvement of roughly one percentage point over either of those models alone. At this point, given how far Classes 2 and 3 are from being linearly separable in the training data, we moved on to re-training the models on the entire training set in order to obtain final test results.
[Figure 12: 2-D t-SNE embedding of Classes 2 and 3 only.]
Figure 12. t-SNE embedding of the 93-dimensional data in 2-D for Class 2 and Class 3.
RF + NN + SVM (rows: predicted class; columns: true class)
       1     2     3     4     5     6     7     8     9
 1   599     3     2     1     1    31    32    57    63
 2    69 10968  3698  1022    50    95   281    81    71
 3     2   927  2115   183     4    17   150    10     5
 4     0    61    35   731     0     8    12     0     0
 5     4    25     2    20  1957     4    11     4     8
 6   120    24    14    63     3  9990   149   143   117
 7    43    88    76    21     1   130  1381    33    20
 8   251    33    18     4     4   196   127  5878   171
 9   339    21     9     0     0   135    19    80  3285
Figure 13. Weighted 3-model aggregate (random forest, neural networks, and SVM) confusion matrix during cross validation.
5. Results
On our test set, we evaluated the random forest, SVM, and neural network models, as well as the two aggregate models described above. In addition to comparing performance across the five models, we wanted to check whether the 75-25 training/test split was reasonable, and whether our cross-validation scheme was representative of the results on the test set.
For this reason, we again present the confusion matrices for the two aggregate models, this time with results from the test set, in Figures 15 and 16. These can be compared with the confusion matrices for the aggregate models from cross validation on the training set (Figures 13 and 14). All of the results are summarized in Table 1.
RF + NN (rows: predicted class; columns: true class)
       1     2     3     4     5     6     7     8     9
 1   735     9     3     2     2    55    56    83    95
 2    43 10471  2748   754    37    73   208    58    55
 3     6  1376  3004   278     7    16   158    17     4
 4     0   111    80   912     1    17    22     1     1
 5     4    23     1    18  1960     3     8     1     7
 6    95    20    11    55     5 10004   126   127   116
 7    61    91   102    24     6   144  1474    54    26
 8   194    26    16     2     2   166    96  5852   145
 9   289    23     4     0     0   128    14    93  3291
Figure 14. Weighted 2-model aggregate (random forest, neural networks) confusion matrix during cross validation.
Table 1. Cross-validation and test accuracies for all models.
Metric         RF     SVM    NN     RF+NN+SVM   RF+NN
CV acc. (%)    80.4   73.1   79.8   79.5        81.2
Test acc. (%)  80.6   73.3   79.7   79.4        81.1
We see close agreement between the cross-validation and test-set accuracies, which differ by at most 0.2 percentage points for every model. Overall, the RF+NN model outperformed the other four models in both the CV stages (Stages 1 and 2) and the test stage (Stage 4). SVM performance is consistently lower than that of the other models, and including it in the aggregate model yields lower accuracy than either RF or NN alone.
RF + NN + SVM, test set (rows: predicted class; columns: true class)
       1     2     3     4     5     6     7     8     9
 1   219     0     0     0     0     9     8    16    12
 2    20  3620  1312   327    31    29    91    24    30
 3     0   280   666    61     0     5    41     1     0
 4     0    21     9   221     0     2     7     0     1
 5     3     6     0     5   686     1     5     3     0
 6    32     9     1    23     0  3322    48    53    40
 7    11    19    38     4     0    42   433     8     2
 8    91     8     6     4     2    75    36  2036    45
 9   124     8     3     1     0    44     8    37  1085
Figure 15. Confusion matrix for the weighted RF+NN+SVM aggregate with results from the test set.
RF + NN, test set (rows: predicted class; columns: true class)
       1     2     3     4     5     6     7     8     9
 1   258     1     0     0     0    21    18    30    25
 2    12  3476  1011   229    23    21    66    18    22
 3     0   390   943    99     2     2    36     3     1
 4     0    48    25   285     0     3     9     0     0
 5     3     7     0     2   684     0     4     2     0
 6    30    10     2    20     1  3332    47    45    39
 7    13    26    46     6     5    43   461    19     4
 8    75     7     3     4     3    68    29  2028    39
 9   109     6     5     1     1    39     7    33  1085
Figure 16. Confusion matrix for the weighted RF+NN aggregate with results from the test set.
6. Conclusions and Future Work
Using a visualization tool for high-dimensional data (t-SNE) and a range of machine learning algorithms, we were able to classify these products, but only up to a point. Classes 2 and 3 appear to have very similar characteristics, and because this Kaggle competition does not disclose what the classes represent, we do not know why this may be the case.
For example, these may be two inherently similar categories that share many characteristics and would be difficult even for a human to separate. Alternatively, they may be two distinct categories that happen to share many features, even though a person would readily recognize them as separate. Unfortunately, the features are identified only by number (1 to 93), the classes are likewise only numbered (1 to 9), and the organizers do not intend to elaborate on this information, so we have no way of knowing which of these scenarios applies.
Because additional data pre-processing or more extensive learning did not appear likely to separate these classes, we instead chose to aggregate the predictions of our best-performing models. This led to a slight improvement in performance on both the test set and the validation set. Our results were also consistent between our cross-validation scheme and our test set split, which suggests that 5-fold cross validation and a 75-25 split were reasonable choices given the size of our data set.
This data set highlighted the importance of understanding the nature of the data. Naive Bayes, for example, did poorly due to the high number of features and the poorly founded assumption of independence between those features. The linear SVM struggled compared to the random forest and neural network models because the classes are not linearly separable. We believe that further work with random forests and neural networks could lead to better results, especially if a larger ensemble of these types of models were used.
References
Breiman, Leo. Bagging predictors. Machine Learning, 24(2):123–140, 1996.
Breiman, Leo. Random forests. Machine Learning, 45(1):5–32, 2001.
Chollet, Francois. Keras documentation - Theano-based deep learning library, 2015. URL http://keras.io/.
Cristianini, N. and Shawe-Taylor, J. An introduction to support vector machines and other kernel-based learning methods. Cambridge University Press, 2000. ISBN 9780521780193. URL http://books.google.com/books?id=B-Y88GdO1yYC.
Daniel Nouri. Using convolutional neural nets to detect facial keypoints tutorial, 2014. URL http://danielnouri.org/notes/category/programming/.
Kaggle. Kaggle: Otto group product classification challenge - classify products into the correct category, March 2015. URL https://www.kaggle.com/c/otto-group-product-classification-challenge.
LISA lab, University of Montreal. Theano 0.7 documentation - Python library for defining, optimizing, and evaluating mathematical expressions involving multi-dimensional arrays efficiently, 2015. URL http://deeplearning.net/software/theano/.
Murphy, Kevin P. Machine learning: a probabilistic perspective. MIT Press, 2012.
Pedregosa, F., Varoquaux, G., Gramfort, A., Michel, V., Thirion, B., Grisel, O., Blondel, M., Prettenhofer, P., Weiss, R., Dubourg, V., Vanderplas, J., Passos, A., Cournapeau, D., Brucher, M., Perrot, M., and Duchesnay, E. Scikit-learn: Machine learning in Python. Journal of Machine Learning Research, 12:2825–2830, 2011.
Shalev-Shwartz, Shai and Ben-David, Shai. Understanding Machine Learning: From Theory to Algorithms. Cambridge University Press, 2014.
Van der Maaten, Laurens and Hinton, Geoffrey. Visualizing data using t-SNE. Journal of Machine Learning Research, 9:2579–2605, 2008.