You are viewing a plain text version of this content. The canonical link for it is here.
Posted to issues@spark.apache.org by "Aseem Bansal (JIRA)" <ji...@apache.org> on 2017/02/03 14:19:51 UTC

[jira] [Created] (SPARK-19449) Inconsistent results between ml package RandomForestClassificationModel and mllib package RandomForestModel

Aseem Bansal created SPARK-19449:
------------------------------------

             Summary: Inconsistent results between ml package RandomForestClassificationModel and mllib package RandomForestModel
                 Key: SPARK-19449
                 URL: https://issues.apache.org/jira/browse/SPARK-19449
             Project: Spark
          Issue Type: Bug
          Components: ML, MLlib
    Affects Versions: 2.1.0
            Reporter: Aseem Bansal


I wrote some code to convert the ml package's RandomForestClassificationModel to the mllib package's RandomForestModel. This was needed because we have to make predictions on the order of milliseconds. I found that the results are inconsistent even though the underlying DecisionTreeModels are exactly the same. 

The code below can be used to reproduce the issue. 

{noformat}
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.classification.*;
import org.apache.spark.ml.linalg.*;
import org.apache.spark.ml.regression.RandomForestRegressionModel;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.tree.configuration.Algo;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Enumeration;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * Minimal prediction callback used to evaluate a converted model one feature
 * vector at a time.
 *
 * <p>It is an abstract class rather than an interface; the call sites in this
 * file subclass it anonymously and override {@link #predict(Vector)} with
 * package-private visibility.
 */
abstract class Predictor {
    /**
     * Returns the predicted label for a single feature vector.
     *
     * @param vector feature vector to score
     * @return the predicted label as a double
     */
    abstract double predict(Vector vector);
}

/**
 * Reproduction program: trains three {@code ml} classifiers on synthetic data, converts each
 * fitted model to its {@code mllib} counterpart, and prints the confusion-matrix statistics of
 * the original and the converted model side by side.
 */
public class MainConvertModels {

    /** Seed used for the random split and for both synthetic-data generators. */
    public static final int seed = 42;

    /**
     * Trains, converts and evaluates the models, then prints one report to standard output.
     *
     * @param args ignored
     */
    public static void main(String[] args) {

        int numRows = 1000;
        int numFeatures = 3;
        int numClasses = 2;

        double trainFraction = 0.8;
        double testFraction = 0.2;

        SparkSession spark = SparkSession.builder()
                .appName("conversion app")
                .master("local")
                .getOrCreate();

        // The session is stopped in the finally block so that it is released even when
        // training or evaluation throws.
        try {
//            Dataset<Row> data = getData(spark, "libsvm", "/opt/spark2/data/mllib/sample_libsvm_data.txt");
            Dataset<Row> data = getDummyData(spark, numRows, numFeatures, numClasses);

            Dataset<Row>[] splits = data.randomSplit(new double[]{trainFraction, testFraction}, seed);
            Dataset<Row> trainingData = splits[0];
            Dataset<Row> testData = splits[1];
            testData.cache();

            List<Double> labels = getLabels(testData);
            List<DenseVector> features = getFeatures(testData);

            DecisionTreeClassifier classifier1 = new DecisionTreeClassifier();
            DecisionTreeClassificationModel model1 = classifier1.fit(trainingData);
            final DecisionTreeModel convertedModel1 = convertDecisionTreeModel(model1, Algo.Classification());

            RandomForestClassifier classifier = new RandomForestClassifier();
            RandomForestClassificationModel model2 = classifier.fit(trainingData);
            final RandomForestModel convertedModel2 = convertRandomForestModel(model2);

            LogisticRegression lr = new LogisticRegression();
            LogisticRegressionModel model3 = lr.fit(trainingData);
            final org.apache.spark.mllib.classification.LogisticRegressionModel convertedModel3 =
                    convertLogisticRegressionModel(model3);

            // Evaluate in the same order in which the report is printed.
            String decisionTreeOriginal = getInfo(model1, testData);
            String decisionTreeConverted = getInfo(new Predictor() {
                @Override
                double predict(Vector vector) {
                    return convertedModel1.predict(vector);
                }
            }, labels, features);

            String randomForestOriginal = getInfo(model2, testData);
            String randomForestConverted = getInfo(new Predictor() {
                @Override
                double predict(Vector vector) {
                    return convertedModel2.predict(vector);
                }
            }, labels, features);

            String logisticOriginal = getInfo(model3, testData);
            String logisticConverted = getInfo(new Predictor() {
                @Override
                double predict(Vector vector) {
                    return convertedModel3.predict(vector);
                }
            }, labels, features);

            System.out.println(
                    "****** DecisionTreeClassifier\n"
                            + "** Original **" + decisionTreeOriginal + "\n"
                            + "** New      **" + decisionTreeConverted + "\n"
                            + "\n"
                            + "****** RandomForestClassifier\n"
                            + "** Original **" + randomForestOriginal + "\n"
                            + "** New      **" + randomForestConverted + "\n"
                            + "\n"
                            + "****** LogisticRegression\n"
                            + "** Original **" + logisticOriginal + "\n"
                            + "** New      **" + logisticConverted + "\n");
        } finally {
            spark.stop();
        }
    }

    /**
     * Loads a dataset through the session's reader.
     *
     * @param spark    active session
     * @param format   data source format name passed to the reader
     * @param location path passed to {@code load}
     * @return the loaded dataset
     */
    static Dataset<Row> getData(SparkSession spark, String format, String location) {

        return spark.read()
                .format(format)
                .load(location);
    }

    /**
     * Builds a synthetic dataset with a double {@code label} column and a dense
     * {@code features} vector column.
     *
     * <p>Feature values come from {@link #prepareData(int, int)}; labels are drawn from a
     * {@link Random} seeded with {@link #seed}, uniformly from {@code [0, labelUpperBound)}.
     *
     * @param spark           active session
     * @param numberRows      number of rows to generate
     * @param numberFeatures  length of each feature vector
     * @param labelUpperBound exclusive upper bound of the generated integer labels
     * @return the generated dataset
     * @throws IllegalArgumentException if {@code labelUpperBound} is not positive and at least
     *                                  one row is generated
     */
    static Dataset<Row> getDummyData(SparkSession spark, int numberRows, int numberFeatures, int labelUpperBound) {

        StructType schema = new StructType(new StructField[]{
                new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
                new StructField("features", new VectorUDT(), false, Metadata.empty())
        });

        double[][] vectors = prepareData(numberRows, numberFeatures);

        Random random = new Random(seed);
        List<Row> dataTest = new ArrayList<>(vectors.length);
        for (double[] vector : vectors) {
            // Honour the requested label range instead of a hard-coded bound of 2.
            double label = (double) random.nextInt(labelUpperBound);
            dataTest.add(RowFactory.create(label, Vectors.dense(vector)));
        }

        return spark.createDataFrame(dataTest, schema);
    }

    /**
     * Generates a {@code numRows x numFeatures} matrix of values from
     * {@link Random#nextDouble()}, using a generator seeded with {@link #seed}.
     *
     * @param numRows     number of rows
     * @param numFeatures number of columns
     * @return the generated matrix, filled row by row
     */
    static double[][] prepareData(int numRows, int numFeatures) {

        Random random = new Random(seed);

        double[][] result = new double[numRows][numFeatures];

        for (double[] rowValues : result) {
            for (int feature = 0; feature < rowValues.length; feature++) {
                rowValues[feature] = random.nextDouble();
            }
        }

        return result;
    }

    /**
     * Scores every feature vector with {@code predictor} and summarises the outcome.
     *
     * <p>The reported time covers the prediction loop; the end time is taken by
     * {@link #getInfo(Long, List, List)}.
     *
     * @param predictor model wrapper to evaluate
     * @param labels    expected labels, one per feature vector
     * @param features  feature vectors, in the same order as {@code labels}
     * @return the formatted statistics
     * @throws RuntimeException if {@code labels} and {@code features} differ in size
     */
    static String getInfo(Predictor predictor,
                          List<Double> labels,
                          List<DenseVector> features) {

        long startTime = System.currentTimeMillis();

        List<Double> predictions = new ArrayList<>(features.size());
        for (DenseVector feature : features) {
            predictions.add(predictor.predict(feature));
        }
        return getInfo(startTime, labels, predictions);
    }

    /**
     * Collects the label column (column 0) of {@code testData} to the driver.
     *
     * @param testData dataset whose first column holds double labels
     * @return the labels in collection order
     */
    static List<Double> getLabels(Dataset<Row> testData) {

        List<Double> labels = new ArrayList<>();
        for (Row row : testData.collectAsList()) {
            labels.add(row.getDouble(0));
        }
        return labels;
    }

    /**
     * Collects the feature column (column 1) of {@code testData} to the driver, copying each
     * {@code ml} vector into an {@code mllib} {@link DenseVector}.
     *
     * @param testData dataset whose second column holds {@code ml} vectors
     * @return the feature vectors in collection order
     */
    static List<DenseVector> getFeatures(Dataset<Row> testData) {

        List<DenseVector> features = new ArrayList<>();
        for (Row row : testData.collectAsList()) {
            org.apache.spark.ml.linalg.Vector mlVector = (org.apache.spark.ml.linalg.Vector) row.get(1);
            features.add(new DenseVector(mlVector.toArray()));
        }
        return features;
    }

    /**
     * Runs {@code model} over {@code testData} and summarises its predictions.
     *
     * <p>The intermediate datasets are cached because each one is counted more than once, and
     * they are unpersisted before returning.
     *
     * @param model    fitted {@code ml} transformer producing a {@code prediction} column
     * @param testData dataset with a {@code label} column
     * @return the formatted statistics, without timing information
     */
    static String getInfo(Transformer model, Dataset<Row> testData) {

        Dataset<Row> predictions = model.transform(testData);
        predictions.cache();

        Dataset<Row> correctPredictions = predictions.filter("label == prediction");
        correctPredictions.cache();

        Dataset<Row> incorrectPredictions = predictions.filter("label != prediction");
        incorrectPredictions.cache();

        try {
            long truePositives = correctPredictions.filter("prediction == 1.0").count();
            long trueNegatives = correctPredictions.filter("prediction == 0.0").count();

            long falsePositives = incorrectPredictions.filter("prediction == 1.0").count();
            long falseNegatives = incorrectPredictions.filter("prediction == 0.0").count();

            return getInfo(null, truePositives, trueNegatives, falsePositives, falseNegatives);
        } finally {
            // All counts are materialised by now, so the cached data is no longer needed.
            incorrectPredictions.unpersist();
            correctPredictions.unpersist();
            predictions.unpersist();
        }
    }

    /**
     * Builds the confusion-matrix counts from parallel label and prediction lists.
     *
     * <p>A prediction equal to its label counts as a true positive when it is {@code 1.0} and
     * as a true negative otherwise; a mismatching prediction counts as a false positive when it
     * is {@code 1.0} and as a false negative otherwise.
     *
     * @param startTime   start of the timed section in epoch milliseconds, or {@code null} to
     *                    omit timing from the report
     * @param labels      expected labels
     * @param predictions predicted labels, in the same order as {@code labels}
     * @return the formatted statistics
     * @throws RuntimeException if the two lists differ in size
     */
    static String getInfo(Long startTime, List<Double> labels, List<Double> predictions) {

        long endTime = System.currentTimeMillis();

        if (labels.size() != predictions.size()) {
            throw new RuntimeException("labels size is " + labels.size() +
                    " but predictions size is " + predictions.size());
        }

        // Primitive counters avoid boxing a new Long on every increment.
        long truePositives = 0L;
        long trueNegatives = 0L;
        long falsePositives = 0L;
        long falseNegatives = 0L;

        for (int i = 0; i < labels.size(); i++) {
            double label = labels.get(i);
            double prediction = predictions.get(i);
            boolean predictedPositive = prediction == 1.0;

            if (label == prediction) {
                if (predictedPositive) {
                    truePositives++;
                } else {
                    trueNegatives++;
                }
            } else if (predictedPositive) {
                falsePositives++;
            } else {
                falseNegatives++;
            }
        }

        // A null start time means "no timing", matching the contract of the formatter below.
        Long elapsedMilliseconds = null;
        if (startTime != null) {
            elapsedMilliseconds = endTime - startTime;
        }
        return getInfo(elapsedMilliseconds, truePositives, trueNegatives, falsePositives, falseNegatives);
    }

    /**
     * Divides {@code numerator} by {@code denominator} as doubles.
     *
     * @param numerator   dividend
     * @param denominator divisor
     * @return the quotient, or {@code 0} when either argument is zero
     */
    static double ratio(Long numerator, Long denominator) {

        if (numerator == 0 || denominator == 0) {
            return 0;
        }
        return ((double) numerator) / denominator;
    }

    /**
     * Formats confusion-matrix counts together with accuracy, precision and recall.
     *
     * @param timeTakenMilliseconds total time for all predictions, or {@code null} to omit the
     *                              average-time field
     * @param truePositives         number of true positives
     * @param trueNegatives         number of true negatives
     * @param falsePositives        number of false positives
     * @param falseNegatives        number of false negatives
     * @return the formatted statistics
     */
    static String getInfo(Long timeTakenMilliseconds, Long truePositives, Long trueNegatives, Long falsePositives,
                          Long falseNegatives) {

        Long testDataCount = truePositives + trueNegatives + falsePositives + falseNegatives;
        double accuracy = ratio(truePositives + trueNegatives, testDataCount);
        double precision = ratio(truePositives, truePositives + falsePositives);
        double recall = ratio(truePositives, truePositives + falseNegatives);

        String last = "";
        if (timeTakenMilliseconds != null) {
            last = ", Average time taken (ms) " + ratio(timeTakenMilliseconds, testDataCount);
        }

        return (
                "true positives " + truePositives
                        + ", true negatives " + trueNegatives
                        + ", false positives " + falsePositives
                        + ", false negatives " + falseNegatives
                        + ", total " + testDataCount
                        + "\n\t accuracy " + accuracy
                        + ", precision " + precision
                        + ", recall " + recall
                        + last
        );
    }


    /**
     * Wraps the root node of an {@code ml} decision tree in an {@code mllib}
     * {@link DecisionTreeModel}.
     *
     * @param model {@code ml} tree to convert
     * @param algo  {@code mllib} algorithm kind to record on the converted model
     * @return the converted tree
     */
    static DecisionTreeModel convertDecisionTreeModel(org.apache.spark.ml.tree.DecisionTreeModel model,
                                                      Enumeration.Value algo) {
        return new DecisionTreeModel(model.rootNode().toOld(1), algo);
    }

    /**
     * Converts every tree of an {@code ml} ensemble and assembles them into an {@code mllib}
     * {@link RandomForestModel}.
     *
     * <p>The algorithm kind is regression when {@code model} is a
     * {@link RandomForestRegressionModel} and classification otherwise.
     *
     * <p>NOTE(review): only the trees are carried over. Whether the two libraries combine
     * per-tree outputs the same way cannot be told from this file - confirm against the Spark
     * documentation before expecting identical ensemble predictions.
     *
     * @param model {@code ml} tree ensemble to convert
     * @return the converted forest
     */
    static RandomForestModel convertRandomForestModel(org.apache.spark.ml.tree.TreeEnsembleModel model) {

        Enumeration.Value algo;
        if (model instanceof RandomForestRegressionModel) {
            algo = Algo.Regression();
        } else {
            algo = Algo.Classification();
        }
        Object[] decisionTreeModels = model.trees();
        DecisionTreeModel[] convertedDecisionTreeModels = new DecisionTreeModel[decisionTreeModels.length];
        for (int i = 0; i < decisionTreeModels.length; i++) {
            org.apache.spark.ml.tree.DecisionTreeModel originalModel =
                    (org.apache.spark.ml.tree.DecisionTreeModel) decisionTreeModels[i];
            convertedDecisionTreeModels[i] = convertDecisionTreeModel(originalModel, algo);
        }
        return new RandomForestModel(algo, convertedDecisionTreeModels);
    }

    /**
     * Converts an {@code ml} logistic regression model to its {@code mllib} counterpart and
     * copies the threshold across.
     *
     * <p>The first attempt uses the model's single coefficient vector and intercept. If that
     * throws, the coefficient matrix is flattened instead, the flattened coefficients and the
     * intercepts are printed to standard output, and only the first intercept is passed on.
     *
     * <p>NOTE(review): the fallback keeps a single intercept (see the TODO), so the converted
     * model is presumably not equivalent to a multinomial original - confirm before relying on
     * it.
     *
     * @param model fitted {@code ml} logistic regression model
     * @return the converted model
     */
    static org.apache.spark.mllib.classification.LogisticRegressionModel convertLogisticRegressionModel(LogisticRegressionModel model) {

        org.apache.spark.mllib.classification.LogisticRegressionModel convertedModel;

        try {
            convertedModel = new org.apache.spark.mllib.classification.LogisticRegressionModel(
                    new DenseVector(model.coefficients().toArray()),
                    model.intercept(),
                    model.numFeatures(),
                    model.numClasses()
            );
        } catch (Exception ignored) {
            //Should be SparkException but that does not compile.
            // Raised when we have Multinomial Linear Regression
            // Cannot check as the relevant variable is private
            Vector coefficients = matrixToVector(model.coefficientMatrix());

            for (double v : coefficients.toArray()) {
                System.out.println(v);
            }

            System.out.println("**********");
            double[] intercepts = model.interceptVector().toDense().values();
            for (double v : intercepts) {
                System.out.println(v);
            }

            convertedModel = new org.apache.spark.mllib.classification.LogisticRegressionModel(
                    coefficients,
                    model.interceptVector().toDense().values()[0], //TODO fix this.
                    model.numFeatures(),
                    model.numClasses()
            );
        }

        convertedModel.setThreshold(model.getThreshold());

        return convertedModel;
    }

    /**
     * Flattens an {@code ml} matrix into an {@code mllib} vector.
     *
     * <p>A dense matrix is wrapped around its backing value array. Any other matrix is cast to
     * {@link SparseMatrix}.
     *
     * <p>NOTE(review): the sparse branch sizes the vector by the number of active entries and
     * reuses the row indices as vector indices, which does not look like a faithful flattening
     * of a {@code numRows x numCols} matrix - confirm the intended layout before using this
     * branch.
     *
     * @param matrix matrix to flatten
     * @return the resulting vector
     */
    static Vector matrixToVector(Matrix matrix) {

        if (matrix instanceof DenseMatrix) {
            return new DenseVector(((DenseMatrix) matrix).values());
        }

        SparseMatrix sparse = (SparseMatrix) matrix;
        return new org.apache.spark.mllib.linalg.SparseVector(
                sparse.numActives(),
                sparse.rowIndices(),
                sparse.values()
        );
    }
}
{noformat}

The output is shown below. In it, "Original" refers to the ml package version and "New" refers to the mllib package version. 

- I converted the ml version of the Decision Tree to the mllib version. I gave both versions the same input and received exactly the same output. 
- I then converted the ml version of the Random Forest to the mllib version, giving both the same underlying Decision Trees (using the previous conversion method). I gave both versions the same input but received different output. 

{noformat}
****** DecisionTreeClassifier
** Original **true positives 8128, true negatives 1923, false positives 7942, false negatives 1897, total 19890
     accuracy 0.5053293112116641, precision 0.5057871810827629, recall 0.8107730673316709
** New      **true positives 8128, true negatives 1923, false positives 7942, false negatives 1897, total 19890
     accuracy 0.5053293112116641, precision 0.5057871810827629, recall 0.8107730673316709, Average time taken (ms) 0.001558572146807441

****** RandomForestClassifier
** Original **true positives 3940, true negatives 5915, false positives 3950, false negatives 6085, total 19890
     accuracy 0.49547511312217196, precision 0.49936628643852976, recall 0.39301745635910224
** New      **true positives 2461, true negatives 7350, false positives 2515, false negatives 7564, total 19890
     accuracy 0.4932629462041227, precision 0.4945739549839228, recall 0.2454862842892768, Average time taken (ms) 0.01085972850678733

****** LogisticRegression
** Original **true positives 6728, true negatives 3321, false positives 6544, false negatives 3297, total 19890
     accuracy 0.5052287581699346, precision 0.5069318866787221, recall 0.6711221945137157
** New      **true positives 6728, true negatives 3321, false positives 6544, false negatives 3297, total 19890
     accuracy 0.5052287581699346, precision 0.5069318866787221, recall 0.6711221945137157, Average time taken (ms) 0.001558572146807441
{noformat}



--
This message was sent by Atlassian JIRA
(v6.3.15#6346)

---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscribe@spark.apache.org
For additional commands, e-mail: issues-help@spark.apache.org