You are viewing a plain text version of this content. The canonical link for it is here.
Posted to issues@spark.apache.org by "zhengruifeng (JIRA)" <ji...@apache.org> on 2019/05/08 10:35:00 UTC

[jira] [Resolved] (SPARK-16872) Include Gaussian Naive Bayes Classifier

     [ https://issues.apache.org/jira/browse/SPARK-16872?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]

zhengruifeng resolved SPARK-16872.
----------------------------------
    Resolution: Not A Problem

> Include Gaussian Naive Bayes Classifier
> ---------------------------------------
>
>                 Key: SPARK-16872
>                 URL: https://issues.apache.org/jira/browse/SPARK-16872
>             Project: Spark
>          Issue Type: New Feature
>          Components: ML
>            Reporter: zhengruifeng
>            Assignee: zhengruifeng
>            Priority: Major
>
> I implemented Gaussian NB according to scikit-learn's {{GaussianNB}}.
> In the GaussianNB model, the {{theta}} matrix is used to store the means, and there is an extra {{sigma}} matrix storing the variance of each feature.
> GaussianNB in Spark
> {code}
> scala> import org.apache.spark.ml.classification.GaussianNaiveBayes
> import org.apache.spark.ml.classification.GaussianNaiveBayes
> scala> val path = "/Users/zrf/.dev/spark-2.1.0-bin-hadoop2.7/data/mllib/sample_multiclass_classification_data.txt"
> path: String = /Users/zrf/.dev/spark-2.1.0-bin-hadoop2.7/data/mllib/sample_multiclass_classification_data.txt
> scala> val data = spark.read.format("libsvm").load(path).persist()
> data: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: double, features: vector]
> scala> val gnb = new GaussianNaiveBayes()
> gnb: org.apache.spark.ml.classification.GaussianNaiveBayes = gnb_54c50467306c
> scala> val model = gnb.fit(data)
> 17/01/03 14:25:48 INFO Instrumentation: GaussianNaiveBayes-gnb_54c50467306c-720112035-1: training: numPartitions=1 storageLevel=StorageLevel(1 replicas)
> 17/01/03 14:25:48 INFO Instrumentation: GaussianNaiveBayes-gnb_54c50467306c-720112035-1: {}
> 17/01/03 14:25:49 INFO Instrumentation: GaussianNaiveBayes-gnb_54c50467306c-720112035-1: {"numFeatures":4}
> 17/01/03 14:25:49 INFO Instrumentation: GaussianNaiveBayes-gnb_54c50467306c-720112035-1: {"numClasses":3}
> 17/01/03 14:25:49 INFO Instrumentation: GaussianNaiveBayes-gnb_54c50467306c-720112035-1: training finished
> model: org.apache.spark.ml.classification.GaussianNaiveBayesModel = GaussianNaiveBayesModel (uid=gnb_54c50467306c) with 3 classes
> scala> model.pi
> res0: org.apache.spark.ml.linalg.Vector = [-1.0986122886681098,-1.0986122886681098,-1.0986122886681098]
> scala> model.pi.toArray.map(math.exp)
> res1: Array[Double] = Array(0.3333333333333333, 0.3333333333333333, 0.3333333333333333)
> scala> model.theta
> res2: org.apache.spark.ml.linalg.Matrix =
> 0.2711110067018001   -0.18833335400000006  0.5430507200000001   0.605000046
> -0.6077777799999998  0.181666672           -0.8427117400000006  -0.8800001399999998
> -0.0911111425964     -0.3583333580000001   0.105084738          0.021666701507102017
> scala> model.sigma
> res3: org.apache.spark.ml.linalg.Matrix =
> 0.1223012510889361   0.07078051983960698  0.03430000595243976   0.051336071297393815
> 0.03758145300924998  0.09880280046403413  0.003390296940069426  0.007822241779598893
> 0.08058763609659315  0.06701386661293329  0.024866409227781675  0.02661391644759426
> scala> model.transform(data).select("probability").take(10)
> [rdd_68_0]
> res4: Array[org.apache.spark.sql.Row] = Array([[1.0627410543476422E-21,0.9999999999999938,6.2765233965353945E-15]], [[7.254521422345374E-26,1.0,1.3849442153180895E-18]], [[1.9629244119173135E-24,0.9999999999999998,1.9424765181237926E-16]], [[6.061218297948492E-22,0.9999999999999902,9.853216073401884E-15]], [[0.9972225671942837,8.844241161578932E-165,0.002777432805716399]], [[5.361683970373604E-26,1.0,2.3004604508982183E-18]], [[0.01062850630038623,3.3102617689978775E-100,0.9893714936996136]], [[1.9297314618271785E-4,2.124922209137708E-71,0.9998070268538172]], [[3.118816393732361E-27,1.0,6.5310299615983584E-21]], [[0.9999926009854522,8.734773657627494E-206,7.399014547943611E-6]])
> scala> model.transform(data).select("prediction").take(10)
> [rdd_68_0]
> res5: Array[org.apache.spark.sql.Row] = Array([1.0], [1.0], [1.0], [1.0], [0.0], [1.0], [2.0], [2.0], [1.0], [0.0])
> {code}
> GaussianNB in scikit-learn
> {code}
> import numpy as np
> from sklearn.naive_bayes import GaussianNB
> from sklearn.datasets import load_svmlight_file
> path = '/Users/zrf/.dev/spark-2.1.0-bin-hadoop2.7/data/mllib/sample_multiclass_classification_data.txt'
> X, y = load_svmlight_file(path)
> X = X.toarray()
> clf = GaussianNB()
> clf.fit(X, y)
> >>> clf.class_prior_
> array([ 0.33333333,  0.33333333,  0.33333333])
> >>> clf.theta_
> array([[ 0.27111101, -0.18833335,  0.54305072,  0.60500005],
>        [-0.60777778,  0.18166667, -0.84271174, -0.88000014],
>        [-0.09111114, -0.35833336,  0.10508474,  0.0216667 ]])
>        
> >>> clf.sigma_
> array([[ 0.12230125,  0.07078052,  0.03430001,  0.05133607],
>        [ 0.03758145,  0.0988028 ,  0.0033903 ,  0.00782224],
>        [ 0.08058764,  0.06701387,  0.02486641,  0.02661392]])
>        
> >>> clf.predict_proba(X)[:10]
> array([[  1.06274105e-021,   1.00000000e+000,   6.27652340e-015],
>        [  7.25452142e-026,   1.00000000e+000,   1.38494422e-018],
>        [  1.96292441e-024,   1.00000000e+000,   1.94247652e-016],
>        [  6.06121830e-022,   1.00000000e+000,   9.85321607e-015],
>        [  9.97222567e-001,   8.84424116e-165,   2.77743281e-003],
>        [  5.36168397e-026,   1.00000000e+000,   2.30046045e-018],
>        [  1.06285063e-002,   3.31026177e-100,   9.89371494e-001],
>        [  1.92973146e-004,   2.12492221e-071,   9.99807027e-001],
>        [  3.11881639e-027,   1.00000000e+000,   6.53102996e-021],
>        [  9.99992601e-001,   8.73477366e-206,   7.39901455e-006]])
>        
> >>> clf.predict(X)[:10]
> array([ 1.,  1.,  1.,  1.,  0.,  1.,  2.,  2.,  1.,  0.])
> {code}



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscribe@spark.apache.org
For additional commands, e-mail: issues-help@spark.apache.org