Posted to commits@systemml.apache.org by lr...@apache.org on 2015/12/03 19:46:19 UTC
[48/78] [abbrv] [partial] incubator-systemml git commit: Move files to new package folder structure
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/276d9257/src/main/java/com/ibm/bi/dml/api/MLOutput.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/api/MLOutput.java b/src/main/java/com/ibm/bi/dml/api/MLOutput.java
deleted file mode 100644
index 82f74a0..0000000
--- a/src/main/java/com/ibm/bi/dml/api/MLOutput.java
+++ /dev/null
@@ -1,401 +0,0 @@
-/**
- * (C) Copyright IBM Corp. 2010, 2015
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- */
-
-package com.ibm.bi.dml.api;
-
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map.Entry;
-
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.api.java.function.PairFlatMapFunction;
-import org.apache.spark.mllib.linalg.DenseVector;
-import org.apache.spark.mllib.linalg.VectorUDT;
-import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.types.DataTypes;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
-
-import scala.Tuple2;
-
-import com.ibm.bi.dml.runtime.DMLRuntimeException;
-import com.ibm.bi.dml.runtime.instructions.spark.functions.GetMLBlock;
-import com.ibm.bi.dml.runtime.instructions.spark.utils.RDDConverterUtilsExt;
-import com.ibm.bi.dml.runtime.matrix.MatrixCharacteristics;
-import com.ibm.bi.dml.runtime.matrix.data.MatrixBlock;
-import com.ibm.bi.dml.runtime.matrix.data.MatrixIndexes;
-import com.ibm.bi.dml.runtime.util.UtilFunctions;
-
-/**
- * This is a simple container object that holds the output of an MLContext execute call.
- *
- */
-public class MLOutput {
-
-
-
- HashMap<String, JavaPairRDD<MatrixIndexes,MatrixBlock>> _outputs;
- private HashMap<String, MatrixCharacteristics> _outMetadata = null;
-
- public MLOutput(HashMap<String, JavaPairRDD<MatrixIndexes,MatrixBlock>> outputs, HashMap<String, MatrixCharacteristics> outMetadata) {
- this._outputs = outputs;
- this._outMetadata = outMetadata;
- }
-
- public JavaPairRDD<MatrixIndexes,MatrixBlock> getBinaryBlockedRDD(String varName) throws DMLRuntimeException {
- if(_outputs.containsKey(varName)) {
- return _outputs.get(varName);
- }
- throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
- }
-
- public MatrixCharacteristics getMatrixCharacteristics(String varName) throws DMLRuntimeException {
- if(_outputs.containsKey(varName)) {
- return _outMetadata.get(varName);
- }
- throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
- }
-
- /**
- * Note, the output DataFrame has an additional column named ID.
- * An easy way to get a DataFrame without the ID column is df.sort("ID").drop("ID")
- * @param sqlContext
- * @param varName
- * @return
- * @throws DMLRuntimeException
- */
- public DataFrame getDF(SQLContext sqlContext, String varName) throws DMLRuntimeException {
- JavaPairRDD<MatrixIndexes,MatrixBlock> rdd = getBinaryBlockedRDD(varName);
- if(rdd != null) {
- MatrixCharacteristics mc = _outMetadata.get(varName);
- return RDDConverterUtilsExt.binaryBlockToDataFrame(rdd, mc, sqlContext);
- }
- throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
- }
-
- /**
- *
- * @param sqlContext
- * @param varName
- * @param outputVector if true, returns a DataFrame with two columns: ID and org.apache.spark.mllib.linalg.Vector
- * @return
- * @throws DMLRuntimeException
- */
- public DataFrame getDF(SQLContext sqlContext, String varName, boolean outputVector) throws DMLRuntimeException {
- if(outputVector) {
- JavaPairRDD<MatrixIndexes,MatrixBlock> rdd = getBinaryBlockedRDD(varName);
- if(rdd != null) {
- MatrixCharacteristics mc = _outMetadata.get(varName);
- return RDDConverterUtilsExt.binaryBlockToVectorDataFrame(rdd, mc, sqlContext);
- }
- throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
- }
- else {
- return getDF(sqlContext, varName);
- }
-
- }
-
- /**
- * This method improves the performance of MLPipeline wrappers.
- * @param sqlContext
- * @param varName
- * @param range map from output column name to an inclusive, 1-based column range
- * @return
- * @throws DMLRuntimeException
- */
- public DataFrame getDF(SQLContext sqlContext, String varName, HashMap<String, Tuple2<Long, Long>> range) throws DMLRuntimeException {
- JavaPairRDD<MatrixIndexes,MatrixBlock> binaryBlockRDD = getBinaryBlockedRDD(varName);
- if(binaryBlockRDD == null) {
- throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
- }
- MatrixCharacteristics mc = _outMetadata.get(varName);
- long rlen = mc.getRows(); long clen = mc.getCols();
- int brlen = mc.getRowsPerBlock(); int bclen = mc.getColsPerBlock();
-
- ArrayList<Tuple2<String, Tuple2<Long, Long>>> alRange = new ArrayList<Tuple2<String, Tuple2<Long, Long>>>();
- for(Entry<String, Tuple2<Long, Long>> e : range.entrySet()) {
- alRange.add(new Tuple2<String, Tuple2<Long,Long>>(e.getKey(), e.getValue()));
- }
-
- // Very expensive operation here: groupByKey (the number of keys, i.e. row indexes, can be very large)
- JavaRDD<Row> rowsRDD = binaryBlockRDD.flatMapToPair(new ProjectRows(rlen, clen, brlen, bclen))
- .groupByKey().map(new ConvertDoubleArrayToRangeRows(clen, bclen, alRange));
-
- int numColumns = (int) clen;
- if(numColumns <= 0) {
- throw new DMLRuntimeException("Output dimensions are unknown after executing the script, hence the DataFrame cannot be created");
- }
-
- List<StructField> fields = new ArrayList<StructField>();
- // LongTypes throw an error: java.lang.Double incompatible with java.lang.Long
- fields.add(DataTypes.createStructField("ID", DataTypes.DoubleType, false));
- for(int k = 0; k < alRange.size(); k++) {
- String colName = alRange.get(k)._1;
- long low = alRange.get(k)._2._1;
- long high = alRange.get(k)._2._2;
- if(low != high)
- fields.add(DataTypes.createStructField(colName, new VectorUDT(), false));
- else
- fields.add(DataTypes.createStructField(colName, DataTypes.DoubleType, false));
- }
-
- // This will cause infinite recursion due to bug in Spark
- // https://issues.apache.org/jira/browse/SPARK-6999
- // return sqlContext.createDataFrame(rowsRDD, colNames); // where ArrayList<String> colNames
- return sqlContext.createDataFrame(rowsRDD.rdd(), DataTypes.createStructType(fields));
-
- }
-
- public JavaRDD<String> getStringRDD(String varName, String format) throws DMLRuntimeException {
- if(format.compareTo("text") == 0) {
- JavaPairRDD<MatrixIndexes, MatrixBlock> binaryRDD = getBinaryBlockedRDD(varName);
- MatrixCharacteristics mcIn = getMatrixCharacteristics(varName);
- return RDDConverterUtilsExt.binaryBlockToStringRDD(binaryRDD, mcIn, format);
- }
-// else if(format.compareTo("csv") == 0) {
-//
-// }
- else {
- throw new DMLRuntimeException("The output format:" + format + " is not implemented yet.");
- }
-
- }
-
- public MLMatrix getMLMatrix(MLContext ml, SQLContext sqlContext, String varName) throws DMLRuntimeException {
- JavaPairRDD<MatrixIndexes,MatrixBlock> rdd = getBinaryBlockedRDD(varName);
- if(rdd != null) {
- MatrixCharacteristics mc = getMatrixCharacteristics(varName);
- StructType schema = MLBlock.getDefaultSchemaForBinaryBlock();
- return new MLMatrix(sqlContext.createDataFrame(rdd.map(new GetMLBlock()).rdd(), schema), mc, ml);
- }
- throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
- }
-
-// /**
-// * Experimental: Please use this with caution as it will fail in many corner cases.
-// * @return org.apache.spark.mllib.linalg.distributed.BlockMatrix
-// * @throws DMLRuntimeException
-// */
-// public BlockMatrix getMLLibBlockedMatrix(MLContext ml, SQLContext sqlContext, String varName) throws DMLRuntimeException {
-// return getMLMatrix(ml, sqlContext, varName).toBlockedMatrix();
-// }
-
- public static class ProjectRows implements PairFlatMapFunction<Tuple2<MatrixIndexes,MatrixBlock>, Long, Tuple2<Long, Double[]>> {
- private static final long serialVersionUID = -4792573268900472749L;
- long rlen; long clen;
- int brlen; int bclen;
- public ProjectRows(long rlen, long clen, int brlen, int bclen) {
- this.rlen = rlen;
- this.clen = clen;
- this.brlen = brlen;
- this.bclen = bclen;
- }
-
- @Override
- public Iterable<Tuple2<Long, Tuple2<Long, Double[]>>> call(Tuple2<MatrixIndexes, MatrixBlock> kv) throws Exception {
- // ------------------------------------------------------------------
- // Compute local block size:
- // Example: For matrix: 1500 X 1100 with block length 1000 X 1000
- // We will have four local block sizes (1000X1000, 1000X100, 500X1000 and 500X100)
- long blockRowIndex = kv._1.getRowIndex();
- long blockColIndex = kv._1.getColumnIndex();
- int lrlen = UtilFunctions.computeBlockSize(rlen, blockRowIndex, brlen);
- int lclen = UtilFunctions.computeBlockSize(clen, blockColIndex, bclen);
- // ------------------------------------------------------------------
-
- long startRowIndex = (kv._1.getRowIndex()-1) * brlen;
- MatrixBlock blk = kv._2;
- ArrayList<Tuple2<Long, Tuple2<Long, Double[]>>> retVal = new ArrayList<Tuple2<Long,Tuple2<Long,Double[]>>>();
- for(int i = 0; i < lrlen; i++) {
- Double[] partialRow = new Double[lclen];
- for(int j = 0; j < lclen; j++) {
- partialRow[j] = blk.getValue(i, j);
- }
- retVal.add(new Tuple2<Long, Tuple2<Long,Double[]>>(startRowIndex + i, new Tuple2<Long,Double[]>(kv._1.getColumnIndex(), partialRow)));
- }
- return (Iterable<Tuple2<Long, Tuple2<Long, Double[]>>>) retVal;
- }
- }
-
- public static class ConvertDoubleArrayToRows implements Function<Tuple2<Long, Iterable<Tuple2<Long, Double[]>>>, Row> {
- private static final long serialVersionUID = 4441184411670316972L;
-
- int bclen; long clen;
- boolean outputVector;
- public ConvertDoubleArrayToRows(long clen, int bclen, boolean outputVector) {
- this.bclen = bclen;
- this.clen = clen;
- this.outputVector = outputVector;
- }
-
- @Override
- public Row call(Tuple2<Long, Iterable<Tuple2<Long, Double[]>>> arg0)
- throws Exception {
-
- HashMap<Long, Double[]> partialRows = new HashMap<Long, Double[]>();
- int sizeOfPartialRows = 0;
- for(Tuple2<Long, Double[]> kv : arg0._2) {
- partialRows.put(kv._1, kv._2);
- sizeOfPartialRows += kv._2.length;
- }
-
- // Insert the row index as the first field
- Object[] row = null;
- if(outputVector) {
- row = new Object[2];
- double [] vecVals = new double[sizeOfPartialRows];
-
- for(long columnBlockIndex = 1; columnBlockIndex <= partialRows.size(); columnBlockIndex++) {
- if(partialRows.containsKey(columnBlockIndex)) {
- Double [] array = partialRows.get(columnBlockIndex);
- // ------------------------------------------------------------------
- // Compute local block size:
- int lclen = UtilFunctions.computeBlockSize(clen, columnBlockIndex, bclen);
- // ------------------------------------------------------------------
- if(array.length != lclen) {
- throw new Exception("Incorrect double array provided by ProjectRows");
- }
- for(int i = 0; i < lclen; i++) {
- vecVals[(int) ((columnBlockIndex-1)*bclen + i)] = array[i];
- }
- }
- else {
- throw new Exception("The block for column index " + columnBlockIndex + " is missing. Make sure the last instruction is not returning empty blocks");
- }
- }
-
- long rowIndex = arg0._1;
- row[0] = new Double(rowIndex);
- row[1] = new DenseVector(vecVals); // breeze.util.JavaArrayOps.arrayDToDv(vecVals);
- }
- else {
- row = new Double[sizeOfPartialRows + 1];
- long rowIndex = arg0._1;
- row[0] = new Double(rowIndex);
- for(long columnBlockIndex = 1; columnBlockIndex <= partialRows.size(); columnBlockIndex++) {
- if(partialRows.containsKey(columnBlockIndex)) {
- Double [] array = partialRows.get(columnBlockIndex);
- // ------------------------------------------------------------------
- // Compute local block size:
- int lclen = UtilFunctions.computeBlockSize(clen, columnBlockIndex, bclen);
- // ------------------------------------------------------------------
- if(array.length != lclen) {
- throw new Exception("Incorrect double array provided by ProjectRows");
- }
- for(int i = 0; i < lclen; i++) {
- row[(int) ((columnBlockIndex-1)*bclen + i) + 1] = array[i];
- }
- }
- else {
- throw new Exception("The block for column index " + columnBlockIndex + " is missing. Make sure the last instruction is not returning empty blocks");
- }
- }
- }
- Object[] row_fields = row;
- return RowFactory.create(row_fields);
- }
- }
-
-
- public static class ConvertDoubleArrayToRangeRows implements Function<Tuple2<Long, Iterable<Tuple2<Long, Double[]>>>, Row> {
- private static final long serialVersionUID = 4441184411670316972L;
-
- int bclen; long clen;
- ArrayList<Tuple2<String, Tuple2<Long, Long>>> range;
- public ConvertDoubleArrayToRangeRows(long clen, int bclen, ArrayList<Tuple2<String, Tuple2<Long, Long>>> range) {
- this.bclen = bclen;
- this.clen = clen;
- this.range = range;
- }
-
- @Override
- public Row call(Tuple2<Long, Iterable<Tuple2<Long, Double[]>>> arg0)
- throws Exception {
-
- HashMap<Long, Double[]> partialRows = new HashMap<Long, Double[]>();
- int sizeOfPartialRows = 0;
- for(Tuple2<Long, Double[]> kv : arg0._2) {
- partialRows.put(kv._1, kv._2);
- sizeOfPartialRows += kv._2.length;
- }
-
- // Insert the row index as the first field
- Object[] row = null;
- row = new Object[range.size() + 1];
-
- double [] vecVals = new double[sizeOfPartialRows];
-
- for(long columnBlockIndex = 1; columnBlockIndex <= partialRows.size(); columnBlockIndex++) {
- if(partialRows.containsKey(columnBlockIndex)) {
- Double [] array = partialRows.get(columnBlockIndex);
- // ------------------------------------------------------------------
- // Compute local block size:
- int lclen = UtilFunctions.computeBlockSize(clen, columnBlockIndex, bclen);
- // ------------------------------------------------------------------
- if(array.length != lclen) {
- throw new Exception("Incorrect double array provided by ProjectRows");
- }
- for(int i = 0; i < lclen; i++) {
- vecVals[(int) ((columnBlockIndex-1)*bclen + i)] = array[i];
- }
- }
- else {
- throw new Exception("The block for column index " + columnBlockIndex + " is missing. Make sure the last instruction is not returning empty blocks");
- }
- }
-
- long rowIndex = arg0._1;
- row[0] = new Double(rowIndex);
-
- int i = 1;
-
- //for(Entry<String, Tuple2<Long, Long>> e : range.entrySet()) {
- for(int k = 0; k < range.size(); k++) {
- long low = range.get(k)._2._1;
- long high = range.get(k)._2._2;
-
- if(high < low) {
- throw new Exception("Incorrect range:" + high + "<" + low);
- }
-
- if(low == high) {
- row[i] = new Double(vecVals[(int) (low-1)]);
- }
- else {
- int lengthOfVector = (int) (high - low + 1);
- double [] tempVector = new double[lengthOfVector];
- for(int j = 0; j < lengthOfVector; j++) {
- tempVector[j] = vecVals[(int) (low + j - 1)];
- }
- row[i] = new DenseVector(tempVector);
- }
-
- i++;
- }
-
- Object[] row_fields = row;
- return RowFactory.create(row_fields);
- }
- }
-}
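
A minimal usage sketch for the MLOutput API above (not part of this commit): sc is an assumed
existing SparkContext, the script path and the output variable name "B_out" are illustrative,
and exception handling is omitted. The MLContext calls mirror those in LogisticRegression.train
further down in this commit.

    import java.util.HashMap;
    import org.apache.spark.sql.DataFrame;
    import org.apache.spark.sql.SQLContext;
    import scala.Tuple2;
    import com.ibm.bi.dml.api.MLContext;
    import com.ibm.bi.dml.api.MLOutput;

    MLContext ml = new MLContext(sc);                 // sc: assumed existing SparkContext
    ml.registerOutput("B_out");                       // illustrative output variable
    MLOutput out = ml.execute("/path/to/script.dml",  // illustrative script path
                              new HashMap<String, String>());

    SQLContext sqlContext = new SQLContext(sc);
    DataFrame df = out.getDF(sqlContext, "B_out");
    // As the getDF javadoc notes, drop the extra ID column once rows are ordered:
    DataFrame noId = df.sort("ID").drop("ID");

    // Range variant: inclusive 1-based column ranges; a multi-column range becomes
    // a Vector column, a single-column range becomes a Double column.
    HashMap<String, Tuple2<Long, Long>> range = new HashMap<String, Tuple2<Long, Long>>();
    range.put("features", new Tuple2<Long, Long>(1L, 10L)); // columns 1..10 as a Vector
    range.put("label", new Tuple2<Long, Long>(11L, 11L));   // single column as a Double
    DataFrame df2 = out.getDF(sqlContext, "B_out", range);
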
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/276d9257/src/main/java/com/ibm/bi/dml/api/jmlc/Connection.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/api/jmlc/Connection.java b/src/main/java/com/ibm/bi/dml/api/jmlc/Connection.java
deleted file mode 100644
index 57d389f..0000000
--- a/src/main/java/com/ibm/bi/dml/api/jmlc/Connection.java
+++ /dev/null
@@ -1,244 +0,0 @@
-/**
- * (C) Copyright IBM Corp. 2010, 2015
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- */
-
-package com.ibm.bi.dml.api.jmlc;
-
-import java.io.BufferedReader;
-import java.io.ByteArrayInputStream;
-import java.io.FileReader;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.InputStreamReader;
-import java.util.HashMap;
-
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.Path;
-
-import com.ibm.bi.dml.api.DMLException;
-import com.ibm.bi.dml.api.DMLScript;
-import com.ibm.bi.dml.api.DMLScript.RUNTIME_PLATFORM;
-import com.ibm.bi.dml.conf.ConfigurationManager;
-import com.ibm.bi.dml.conf.DMLConfig;
-import com.ibm.bi.dml.hops.OptimizerUtils;
-import com.ibm.bi.dml.hops.rewrite.ProgramRewriter;
-import com.ibm.bi.dml.hops.rewrite.RewriteRemovePersistentReadWrite;
-import com.ibm.bi.dml.parser.AParserWrapper;
-import com.ibm.bi.dml.parser.DMLProgram;
-import com.ibm.bi.dml.parser.DMLTranslator;
-import com.ibm.bi.dml.parser.DataExpression;
-import com.ibm.bi.dml.runtime.DMLRuntimeException;
-import com.ibm.bi.dml.runtime.controlprogram.Program;
-import com.ibm.bi.dml.runtime.controlprogram.caching.CacheableData;
-import com.ibm.bi.dml.runtime.io.MatrixReaderFactory;
-import com.ibm.bi.dml.runtime.io.ReaderTextCell;
-import com.ibm.bi.dml.runtime.matrix.data.InputInfo;
-import com.ibm.bi.dml.runtime.matrix.data.MatrixBlock;
-import com.ibm.bi.dml.runtime.util.DataConverter;
-
-/**
- * JMLC (Java Machine Learning Connector) API:
- *
- * NOTES:
- * * The API and implementation are currently fused in order to reduce complexity.
- * * See SystemTMulticlassSVMScoreTest for a usage example.
- */
-public class Connection
-{
-
- private DMLConfig _conf = null;
-
- /**
- * Connection constructor, starting point for any other JMLC API calls.
- *
- */
- public Connection()
- {
- //setup basic parameters for embedded execution
- DataExpression.REJECT_READ_UNKNOWN_SIZE = false;
- DMLScript.rtplatform = RUNTIME_PLATFORM.SINGLE_NODE;
- OptimizerUtils.PARALLEL_CP_READ_TEXTFORMATS = false;
- OptimizerUtils.PARALLEL_CP_WRITE_TEXTFORMATS = false;
- OptimizerUtils.PARALLEL_CP_READ_BINARYFORMATS = false;
- OptimizerUtils.PARALLEL_CP_WRITE_BINARYFORMATS = false;
- CacheableData.disableCaching();
-
- //create default configuration
- _conf = new DMLConfig();
- ConfigurationManager.setConfig(_conf);
- }
-
- /**
- *
- * @param script
- * @param inputs
- * @param outputs
- * @return
- * @throws DMLException
- */
- public PreparedScript prepareScript( String script, String[] inputs, String[] outputs, boolean parsePyDML)
- throws DMLException
- {
- return prepareScript(script, new HashMap<String,String>(), inputs, outputs, parsePyDML);
- }
-
- /**
- *
- * @param script
- * @param args
- * @param inputs
- * @param outputs
- * @return
- * @throws DMLException
- */
- public PreparedScript prepareScript( String script, HashMap<String, String> args, String[] inputs, String[] outputs, boolean parsePyDML)
- throws DMLException
- {
- //prepare arguments
-
- //simplified compilation chain
- Program rtprog = null;
- try
- {
- //parsing
- AParserWrapper parser = AParserWrapper.createParser(parsePyDML);
- DMLProgram prog = parser.parse(null, script, args);
-
- //language validate
- DMLTranslator dmlt = new DMLTranslator(prog);
- dmlt.liveVariableAnalysis(prog);
- dmlt.validateParseTree(prog);
-
- //hop construct/rewrite
- dmlt.constructHops(prog);
- dmlt.rewriteHopsDAG(prog);
-
- //rewrite persistent reads/writes
- RewriteRemovePersistentReadWrite rewrite = new RewriteRemovePersistentReadWrite(inputs, outputs);
- ProgramRewriter rewriter2 = new ProgramRewriter(rewrite);
- rewriter2.rewriteProgramHopDAGs(prog);
-
- //lop construct and runtime prog generation
- dmlt.constructLops(prog);
- rtprog = prog.getRuntimeProgram(_conf);
-
- //final cleanup runtime prog
- JMLCUtils.cleanupRuntimeProgram(rtprog, outputs);
-
- //System.out.println(Explain.explain(rtprog));
- }
- catch(Exception ex)
- {
- throw new DMLException(ex);
- }
-
- //return newly created precompiled script
- return new PreparedScript(rtprog, inputs, outputs);
- }
-
- /**
- *
- */
- public void close()
- {
-
- }
-
- /**
- *
- * @param fname
- * @return
- * @throws IOException
- */
- public String readScript(String fname)
- throws IOException
- {
- StringBuilder sb = new StringBuilder();
- BufferedReader in = null;
- try
- {
- //read from hdfs or gpfs file system
- if( fname.startsWith("hdfs:")
- || fname.startsWith("gpfs:") )
- {
- FileSystem fs = FileSystem.get(ConfigurationManager.getCachedJobConf());
- Path scriptPath = new Path(fname);
- in = new BufferedReader(new InputStreamReader(fs.open(scriptPath)));
- }
- // from local file system
- else
- {
- in = new BufferedReader(new FileReader(fname));
- }
-
- //core script reading
- String tmp = null;
- while ((tmp = in.readLine()) != null)
- {
- sb.append( tmp );
- sb.append( "\n" );
- }
- }
- catch (IOException ex)
- {
- throw ex;
- }
- finally
- {
- if( in != null )
- in.close();
- }
-
- return sb.toString();
- }
-
- /**
- * Converts an input string representation of a matrix in textcell format
- * into a dense double array. The number of rows and columns need to be
- * specified because textcell only represents non-zero values and hence
- * does not define the dimensions in the general case.
- *
- * @param input a string representation of an input matrix,
- * in format textcell (rowindex colindex value)
- * @param rows number of rows
- * @param cols number of columns
- * @return matrix as a dense double array
- * @throws IOException
- */
- public double[][] convertToDoubleMatrix(String input, int rows, int cols)
- throws IOException
- {
- double[][] ret = null;
-
- try
- {
- //read input matrix
- InputStream is = new ByteArrayInputStream(input.getBytes("UTF-8"));
- ReaderTextCell reader = (ReaderTextCell)MatrixReaderFactory.createMatrixReader(InputInfo.TextCellInputInfo);
- MatrixBlock mb = reader.readMatrixFromInputStream(is, rows, cols, DMLTranslator.DMLBlockSize, DMLTranslator.DMLBlockSize, (long)rows*cols);
-
- //convert to double array
- ret = DataConverter.convertToDoubleMatrix( mb );
- }
- catch(DMLRuntimeException rex)
- {
- throw new IOException( rex );
- }
-
- return ret;
- }
-
-}
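
Since the Connection javadoc above points to SystemTMulticlassSVMScoreTest for usage, here is a
compact JMLC sketch against this API (not part of this commit): the script path, the variable
names "X" and "Y", and the textcell input are illustrative; exception handling is omitted.

    import com.ibm.bi.dml.api.jmlc.Connection;
    import com.ibm.bi.dml.api.jmlc.PreparedScript;

    Connection conn = new Connection();
    // Read a DML script from the local FS, HDFS ("hdfs:") or GPFS ("gpfs:")
    String script = conn.readScript("/path/to/script.dml");
    // Precompile with declared inputs/outputs; 'false' selects DML rather than PyDML
    PreparedScript ps = conn.prepareScript(script,
            new String[]{"X"}, new String[]{"Y"}, false);
    // Parse a textcell string ("rowindex colindex value" per line) into a dense 2x2 array
    double[][] x = conn.convertToDoubleMatrix("1 1 1.0\n2 2 2.0", 2, 2);
    conn.close();
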
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/276d9257/src/main/java/com/ibm/bi/dml/api/jmlc/JMLCUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/api/jmlc/JMLCUtils.java b/src/main/java/com/ibm/bi/dml/api/jmlc/JMLCUtils.java
deleted file mode 100644
index 3cf3e18..0000000
--- a/src/main/java/com/ibm/bi/dml/api/jmlc/JMLCUtils.java
+++ /dev/null
@@ -1,107 +0,0 @@
-/**
- * (C) Copyright IBM Corp. 2010, 2015
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- */
-
-package com.ibm.bi.dml.api.jmlc;
-
-import java.util.ArrayList;
-import java.util.Map;
-import java.util.Map.Entry;
-
-import com.ibm.bi.dml.runtime.controlprogram.ForProgramBlock;
-import com.ibm.bi.dml.runtime.controlprogram.FunctionProgramBlock;
-import com.ibm.bi.dml.runtime.controlprogram.IfProgramBlock;
-import com.ibm.bi.dml.runtime.controlprogram.Program;
-import com.ibm.bi.dml.runtime.controlprogram.ProgramBlock;
-import com.ibm.bi.dml.runtime.controlprogram.WhileProgramBlock;
-import com.ibm.bi.dml.runtime.instructions.Instruction;
-import com.ibm.bi.dml.runtime.instructions.cp.VariableCPInstruction;
-
-public class JMLCUtils
-{
-
-
- /**
- * Removes rmvar instructions that would remove any of the given outputs.
- * This is important for keeping registered outputs after the program terminates.
- *
- * @param prog
- */
- public static void cleanupRuntimeProgram( Program prog, String[] outputs)
- {
- Map<String, FunctionProgramBlock> funcMap = prog.getFunctionProgramBlocks();
- if( funcMap != null && !funcMap.isEmpty() )
- {
- for( Entry<String, FunctionProgramBlock> e : funcMap.entrySet() )
- {
- FunctionProgramBlock fpb = e.getValue();
- for( ProgramBlock pb : fpb.getChildBlocks() )
- rCleanupRuntimeProgram(pb, outputs);
- }
- }
-
- for( ProgramBlock pb : prog.getProgramBlocks() )
- rCleanupRuntimeProgram(pb, outputs);
- }
-
- /**
- *
- * @param pb
- * @param outputs
- */
- private static void rCleanupRuntimeProgram( ProgramBlock pb, String[] outputs )
- {
- if( pb instanceof WhileProgramBlock )
- {
- WhileProgramBlock wpb = (WhileProgramBlock)pb;
- for( ProgramBlock pbc : wpb.getChildBlocks() )
- rCleanupRuntimeProgram(pbc,outputs);
- }
- else if( pb instanceof IfProgramBlock )
- {
- IfProgramBlock ipb = (IfProgramBlock)pb;
- for( ProgramBlock pbc : ipb.getChildBlocksIfBody() )
- rCleanupRuntimeProgram(pbc,outputs);
- for( ProgramBlock pbc : ipb.getChildBlocksElseBody() )
- rCleanupRuntimeProgram(pbc,outputs);
- }
- else if( pb instanceof ForProgramBlock )
- {
- ForProgramBlock fpb = (ForProgramBlock)pb;
- for( ProgramBlock pbc : fpb.getChildBlocks() )
- rCleanupRuntimeProgram(pbc,outputs);
- }
- else
- {
- ArrayList<Instruction> tmp = pb.getInstructions();
- for( int i=0; i<tmp.size(); i++ )
- {
- Instruction linst = tmp.get(i);
- if( linst instanceof VariableCPInstruction && ((VariableCPInstruction)linst).isRemoveVariable() )
- {
- VariableCPInstruction varinst = (VariableCPInstruction) linst;
- for( String var : outputs )
- if( varinst.isRemoveVariable(var) )
- {
- tmp.remove(i);
- i--;
- break;
- }
- }
- }
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/276d9257/src/main/java/com/ibm/bi/dml/api/jmlc/PreparedScript.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/api/jmlc/PreparedScript.java b/src/main/java/com/ibm/bi/dml/api/jmlc/PreparedScript.java
deleted file mode 100644
index c3d4555..0000000
--- a/src/main/java/com/ibm/bi/dml/api/jmlc/PreparedScript.java
+++ /dev/null
@@ -1,231 +0,0 @@
-/**
- * (C) Copyright IBM Corp. 2010, 2015
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- */
-
-package com.ibm.bi.dml.api.jmlc;
-
-import java.util.HashSet;
-
-import com.ibm.bi.dml.api.DMLException;
-import com.ibm.bi.dml.conf.ConfigurationManager;
-import com.ibm.bi.dml.conf.DMLConfig;
-import com.ibm.bi.dml.parser.Expression.ValueType;
-import com.ibm.bi.dml.runtime.controlprogram.LocalVariableMap;
-import com.ibm.bi.dml.runtime.controlprogram.Program;
-import com.ibm.bi.dml.runtime.controlprogram.caching.MatrixObject;
-import com.ibm.bi.dml.runtime.controlprogram.context.ExecutionContext;
-import com.ibm.bi.dml.runtime.controlprogram.context.ExecutionContextFactory;
-import com.ibm.bi.dml.runtime.instructions.cp.BooleanObject;
-import com.ibm.bi.dml.runtime.instructions.cp.DoubleObject;
-import com.ibm.bi.dml.runtime.instructions.cp.IntObject;
-import com.ibm.bi.dml.runtime.instructions.cp.ScalarObject;
-import com.ibm.bi.dml.runtime.instructions.cp.StringObject;
-import com.ibm.bi.dml.runtime.matrix.MatrixCharacteristics;
-import com.ibm.bi.dml.runtime.matrix.MatrixFormatMetaData;
-import com.ibm.bi.dml.runtime.matrix.data.InputInfo;
-import com.ibm.bi.dml.runtime.matrix.data.MatrixBlock;
-import com.ibm.bi.dml.runtime.matrix.data.OutputInfo;
-import com.ibm.bi.dml.runtime.util.DataConverter;
-
-/**
- * JMLC (Java Machine Learning Connector) API:
- *
- * NOTE: The API and implementation are currently fused in order to reduce complexity.
- */
-public class PreparedScript
-{
-
- //input/output specification
- private HashSet<String> _inVarnames = null;
- private HashSet<String> _outVarnames = null;
-
- //internal state (reused)
- private Program _prog = null;
- private LocalVariableMap _vars = null;
-
- /**
- * Meant to be invoked only from Connection
- */
- protected PreparedScript( Program prog, String[] inputs, String[] outputs )
- {
- _prog = prog;
- _vars = new LocalVariableMap();
-
- //populate input/output vars
- _inVarnames = new HashSet<String>();
- for( String var : inputs )
- _inVarnames.add( var );
- _outVarnames = new HashSet<String>();
- for( String var : outputs )
- _outVarnames.add( var );
- }
-
- /**
- *
- * @param varname
- * @param scalar
- * @throws DMLException
- */
- public void setScalar(String varname, ScalarObject scalar)
- throws DMLException
- {
- if( !_inVarnames.contains(varname) )
- throw new DMLException("Unspecified input variable: "+varname);
-
- _vars.put(varname, scalar);
- }
-
- /**
- *
- * @param varname
- * @param scalar
- * @throws DMLException
- */
- public void setScalar(String varname, boolean scalar)
- throws DMLException
- {
- if( !_inVarnames.contains(varname) )
- throw new DMLException("Unspecified input variable: "+varname);
-
- BooleanObject bo = new BooleanObject(varname, scalar);
- _vars.put(varname, bo);
- }
-
- /**
- *
- * @param varname
- * @param scalar
- * @throws DMLException
- */
- public void setScalar(String varname, long scalar)
- throws DMLException
- {
- if( !_inVarnames.contains(varname) )
- throw new DMLException("Unspecified input variable: "+varname);
-
- IntObject io = new IntObject(varname, scalar);
- _vars.put(varname, io);
- }
-
- /**
- *
- * @param varname
- * @param scalar
- * @throws DMLException
- */
- public void setScalar(String varname, double scalar)
- throws DMLException
- {
- if( !_inVarnames.contains(varname) )
- throw new DMLException("Unspecified input variable: "+varname);
-
- DoubleObject doo = new DoubleObject(varname, scalar);
- _vars.put(varname, doo);
- }
-
- /**
- *
- * @param varname
- * @param scalar
- * @throws DMLException
- */
- public void setScalar(String varname, String scalar)
- throws DMLException
- {
- if( !_inVarnames.contains(varname) )
- throw new DMLException("Unspecified input variable: "+varname);
-
- StringObject so = new StringObject(varname, scalar);
- _vars.put(varname, so);
- }
-
- /**
- *
- * @param varname
- * @param matrix
- * @throws DMLException
- */
- public void setMatrix(String varname, MatrixBlock matrix)
- throws DMLException
- {
- if( !_inVarnames.contains(varname) )
- throw new DMLException("Unspecified input variable: "+varname);
-
-
- DMLConfig conf = ConfigurationManager.getConfig();
- String scratch_space = conf.getTextValue(DMLConfig.SCRATCH_SPACE);
- int blocksize = conf.getIntValue(DMLConfig.DEFAULT_BLOCK_SIZE);
-
- //create new matrix object
- MatrixCharacteristics mc = new MatrixCharacteristics(matrix.getNumRows(), matrix.getNumColumns(), blocksize, blocksize);
- MatrixFormatMetaData meta = new MatrixFormatMetaData(mc, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
- MatrixObject mo = new MatrixObject(ValueType.DOUBLE, scratch_space+"/"+varname, meta);
- mo.acquireModify(matrix);
- mo.release();
-
- //put created matrix wrapper into symbol table
- _vars.put(varname, mo);
- }
-
- /**
- *
- * @param varname
- * @param matrix
- * @throws DMLException
- */
- public void setMatrix(String varname, double[][] matrix)
- throws DMLException
- {
- if( !_inVarnames.contains(varname) )
- throw new DMLException("Unspecified input variable: "+varname);
-
- MatrixBlock mb = DataConverter.convertToMatrixBlock(matrix);
- setMatrix(varname, mb);
- }
-
-
- /**
- *
- */
- public void clearParameters()
- {
- _vars.removeAll();
- }
-
- /**
- *
- * @return
- * @throws DMLException
- */
- public ResultVariables executeScript()
- throws DMLException
- {
- //create and populate execution context
- ExecutionContext ec = ExecutionContextFactory.createContext(_prog);
- ec.setVariables(_vars);
-
- //core execute runtime program
- _prog.execute( ec );
-
- //construct results
- ResultVariables rvars = new ResultVariables();
- for( String ovar : _outVarnames )
- if( _vars.keySet().contains(ovar) )
- rvars.addResult(ovar, _vars.get(ovar));
-
- return rvars;
- }
-}
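
Continuing the JMLC sketch from the Connection diff above (not part of this commit; the
variable names "X" and "Y" remain illustrative), a PreparedScript is bound and executed
as follows:

    import com.ibm.bi.dml.api.jmlc.ResultVariables;

    // Bind the inputs declared at prepareScript time, execute, then read results
    ps.setMatrix("X", new double[][]{{1.0, 2.0}, {3.0, 4.0}});
    ResultVariables res = ps.executeScript();
    double[][] y = res.getMatrix("Y");
    // The precompiled program can be re-executed; clear bound inputs between runs
    ps.clearParameters();
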
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/276d9257/src/main/java/com/ibm/bi/dml/api/jmlc/ResultVariables.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/api/jmlc/ResultVariables.java b/src/main/java/com/ibm/bi/dml/api/jmlc/ResultVariables.java
deleted file mode 100644
index dedb08f..0000000
--- a/src/main/java/com/ibm/bi/dml/api/jmlc/ResultVariables.java
+++ /dev/null
@@ -1,92 +0,0 @@
-/**
- * (C) Copyright IBM Corp. 2010, 2015
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- */
-
-package com.ibm.bi.dml.api.jmlc;
-
-import java.util.HashMap;
-import java.util.Set;
-
-import com.ibm.bi.dml.api.DMLException;
-import com.ibm.bi.dml.runtime.controlprogram.caching.MatrixObject;
-import com.ibm.bi.dml.runtime.instructions.cp.Data;
-import com.ibm.bi.dml.runtime.matrix.data.MatrixBlock;
-import com.ibm.bi.dml.runtime.util.DataConverter;
-
-/**
- * JMLC (Java Machine Learning Connector) API:
- *
- * NOTE: The API and implementation are currently fused in order to reduce complexity.
- */
-public class ResultVariables
-{
-
- private HashMap<String, Data> _out = null;
-
- public ResultVariables()
- {
- _out = new HashMap<String, Data>();
- }
-
- public Set<String> getVariableNames()
- {
- return _out.keySet();
- }
-
- public int size()
- {
- return _out.size();
- }
-
- /**
- *
- * @param varname output variable name
- * @return
- * @throws DMLException
- */
- public double[][] getMatrix(String varname)
- throws DMLException
- {
- if( !_out.containsKey(varname) )
- throw new DMLException("Non-existing output variable: "+varname);
-
- double[][] ret = null;
- Data dat = _out.get(varname);
-
- //basic checks for data type
- if( !(dat instanceof MatrixObject) )
- throw new DMLException("Expected matrix result '"+varname+"' not a matrix.");
-
- //convert output matrix to double array
- MatrixObject mo = (MatrixObject)dat;
- MatrixBlock mb = mo.acquireRead();
- ret = DataConverter.convertToDoubleMatrix(mb);
- mo.release();
-
- return ret;
- }
-
- /**
- *
- *
- * @param ovar
- * @param data
- */
- protected void addResult(String ovar, Data data)
- {
- _out.put(ovar, data);
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/276d9257/src/main/java/com/ibm/bi/dml/api/ml/LogisticRegression.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/api/ml/LogisticRegression.java b/src/main/java/com/ibm/bi/dml/api/ml/LogisticRegression.java
deleted file mode 100644
index 8803d76..0000000
--- a/src/main/java/com/ibm/bi/dml/api/ml/LogisticRegression.java
+++ /dev/null
@@ -1,464 +0,0 @@
-package com.ibm.bi.dml.api.ml;
-
-import java.io.File;
-import java.io.IOException;
-import java.util.HashMap;
-
-import org.apache.spark.SparkContext;
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.classification.LogisticRegressionParams;
-import org.apache.spark.ml.classification.ProbabilisticClassifier;
-import org.apache.spark.ml.param.BooleanParam;
-import org.apache.spark.ml.param.DoubleParam;
-import org.apache.spark.ml.param.IntParam;
-import org.apache.spark.ml.param.StringArrayParam;
-import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.SQLContext;
-
-import com.ibm.bi.dml.api.DMLException;
-import com.ibm.bi.dml.api.MLContext;
-import com.ibm.bi.dml.api.MLOutput;
-import com.ibm.bi.dml.api.ml.LogisticRegressionModel;
-import com.ibm.bi.dml.api.ml.functions.ConvertSingleColumnToString;
-import com.ibm.bi.dml.parser.ParseException;
-import com.ibm.bi.dml.runtime.DMLRuntimeException;
-import com.ibm.bi.dml.runtime.instructions.spark.utils.RDDConverterUtilsExt;
-import com.ibm.bi.dml.runtime.matrix.MatrixCharacteristics;
-import com.ibm.bi.dml.runtime.matrix.data.MatrixBlock;
-import com.ibm.bi.dml.runtime.matrix.data.MatrixIndexes;
-
-/**
- *
- * This class shows how SystemML can be integrated into MLPipeline. Note, it has not been optimized for performance and
- * is implemented as a proof of concept. An optimized pipeline can be constructed by using DML's 'parfor' construct.
- *
- * TODO:
- * - Please note that this class expects 1-based labels. To run the example below,
- *   please set the environment variable 'SYSTEMML_HOME', create the folder 'algorithms',
- *   and place at least the two scripts 'MultiLogReg.dml' and 'GLM-predict.dml' in that folder.
- * - It is not yet optimized for performance.
- * - Also, it needs to be extended to surface all the parameters of MultiLogReg.dml
- *
- * Example usage:
- * <pre><code>
- * // Code to demonstrate usage of pipeline
- * import org.apache.spark.ml.Pipeline
- * import com.ibm.bi.dml.api.ml.LogisticRegression
- * import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
- * import org.apache.spark.mllib.linalg.Vector
- * case class LabeledDocument(id: Long, text: String, label: Double)
- * case class Document(id: Long, text: String)
- * val training = sc.parallelize(Seq(
- * LabeledDocument(0L, "a b c d e spark", 1.0),
- * LabeledDocument(1L, "b d", 2.0),
- * LabeledDocument(2L, "spark f g h", 1.0),
- * LabeledDocument(3L, "hadoop mapreduce", 2.0),
- * LabeledDocument(4L, "b spark who", 1.0),
- * LabeledDocument(5L, "g d a y", 2.0),
- * LabeledDocument(6L, "spark fly", 1.0),
- * LabeledDocument(7L, "was mapreduce", 2.0),
- * LabeledDocument(8L, "e spark program", 1.0),
- * LabeledDocument(9L, "a e c l", 2.0),
- * LabeledDocument(10L, "spark compile", 1.0),
- * LabeledDocument(11L, "hadoop software", 2.0)))
- * val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
- * val hashingTF = new HashingTF().setNumFeatures(1000).setInputCol(tokenizer.getOutputCol).setOutputCol("features")
- * val lr = new LogisticRegression(sc, sqlContext)
- * val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))
- * val model = pipeline.fit(training.toDF)
- * val test = sc.parallelize(Seq(
- * Document(12L, "spark i j k"),
- * Document(13L, "l m n"),
- * Document(14L, "mapreduce spark"),
- * Document(15L, "apache hadoop")))
- * model.transform(test.toDF).show
- *
- * // Code to demonstrate usage of cross-validation
- * import org.apache.spark.ml.Pipeline
- * import com.ibm.bi.dml.api.ml.LogisticRegression
- * import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
- * import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
- * import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
- * import org.apache.spark.mllib.linalg.Vector
- * case class LabeledDocument(id: Long, text: String, label: Double)
- * case class Document(id: Long, text: String)
- * val training = sc.parallelize(Seq(
- * LabeledDocument(0L, "a b c d e spark", 1.0),
- * LabeledDocument(1L, "b d", 2.0),
- * LabeledDocument(2L, "spark f g h", 1.0),
- * LabeledDocument(3L, "hadoop mapreduce", 2.0),
- * LabeledDocument(4L, "b spark who", 1.0),
- * LabeledDocument(5L, "g d a y", 2.0),
- * LabeledDocument(6L, "spark fly", 1.0),
- * LabeledDocument(7L, "was mapreduce", 2.0),
- * LabeledDocument(8L, "e spark program", 1.0),
- * LabeledDocument(9L, "a e c l", 2.0),
- * LabeledDocument(10L, "spark compile", 1.0),
- * LabeledDocument(11L, "hadoop software", 2.0)))
- * val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
- * val hashingTF = new HashingTF().setNumFeatures(1000).setInputCol(tokenizer.getOutputCol).setOutputCol("features")
- * val lr = new LogisticRegression(sc, sqlContext)
- * val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))
- * val crossval = new CrossValidator().setEstimator(pipeline).setEvaluator(new BinaryClassificationEvaluator)
- * val paramGrid = new ParamGridBuilder().addGrid(hashingTF.numFeatures, Array(10, 100, 1000)).addGrid(lr.regParam, Array(0.1, 0.01)).build()
- * crossval.setEstimatorParamMaps(paramGrid)
- * crossval.setNumFolds(2)
- * val cvModel = crossval.fit(training.toDF)
- * val test = sc.parallelize(Seq(
- * Document(12L, "spark i j k"),
- * Document(13L, "l m n"),
- * Document(14L, "mapreduce spark"),
- * Document(15L, "apache hadoop")))
- * cvModel.transform(test.toDF).show
- * </code></pre>
- *
- */
-public class LogisticRegression extends ProbabilisticClassifier<Vector, LogisticRegression, LogisticRegressionModel>
- implements LogisticRegressionParams {
-
- private static final long serialVersionUID = 7763813395635870734L;
-
- private SparkContext sc = null;
- private SQLContext sqlContext = null;
- private HashMap<String, String> cmdLineParams = new HashMap<String, String>();
-
- private IntParam icpt = new IntParam(this, "icpt", "Value of intercept");
- private DoubleParam reg = new DoubleParam(this, "reg", "Value of regularization parameter");
- private DoubleParam tol = new DoubleParam(this, "tol", "Value of tolerance");
- private IntParam moi = new IntParam(this, "moi", "Max outer iterations");
- private IntParam mii = new IntParam(this, "mii", "Max inner iterations");
- private IntParam labelIndex = new IntParam(this, "li", "Index of the label column");
- private StringArrayParam inputCol = new StringArrayParam(this, "icname", "Feature column name");
- private StringArrayParam outputCol = new StringArrayParam(this, "ocname", "Label column name");
- private int intMin = Integer.MIN_VALUE;
- @SuppressWarnings("unused")
- private int li = 0;
- private String[] icname = new String[1];
- private String[] ocname = new String[1];
-
- public LogisticRegression() {
- }
-
- public LogisticRegression(String uid) {
- }
-
- @Override
- public LogisticRegression copy(org.apache.spark.ml.param.ParamMap paramMap) {
- try {
- // Copy deals with command-line parameter of script MultiLogReg.dml
- LogisticRegression lr = new LogisticRegression(sc, sqlContext);
- lr.cmdLineParams.put(icpt.name(), paramMap.getOrElse(icpt, 0).toString());
- lr.cmdLineParams.put(reg.name(), paramMap.getOrElse(reg, 0.0f).toString());
- lr.cmdLineParams.put(tol.name(), paramMap.getOrElse(tol, 0.000001f).toString());
- lr.cmdLineParams.put(moi.name(), paramMap.getOrElse(moi, 100).toString());
- lr.cmdLineParams.put(mii.name(), paramMap.getOrElse(mii, 0).toString());
-
- return lr;
- } catch (DMLRuntimeException e) {
- e.printStackTrace();
- }
- return null;
-
- }
-
- public LogisticRegression(SparkContext sc, SQLContext sqlContext) throws DMLRuntimeException {
- this.sc = sc;
- this.sqlContext = sqlContext;
-
- setDefault(intercept(), 0);
- cmdLineParams.put(icpt.name(), "0");
- setDefault(regParam(), 0.0f);
- cmdLineParams.put(reg.name(), "0.0f");
- setDefault(tol(), 0.000001f);
- cmdLineParams.put(tol.name(), "0.000001f");
- setDefault(maxOuterIter(), 100);
- cmdLineParams.put(moi.name(), "100");
- setDefault(maxInnerIter(), 0);
- cmdLineParams.put(mii.name(), "0");
- setDefault(labelIdx(), intMin);
- li = intMin;
- setDefault(inputCol(), icname);
- icname[0] = "";
- setDefault(outputCol(), ocname);
- ocname[0] = "";
- }
-
- public LogisticRegression(SparkContext sc, SQLContext sqlContext, int icpt, double reg, double tol, int moi, int mii) throws DMLRuntimeException {
- this.sc = sc;
- this.sqlContext = sqlContext;
-
- setDefault(intercept(), icpt);
- cmdLineParams.put(this.icpt.name(), Integer.toString(icpt));
- setDefault(regParam(), reg);
- cmdLineParams.put(this.reg.name(), Double.toString(reg));
- setDefault(tol(), tol);
- cmdLineParams.put(this.tol.name(), Double.toString(tol));
- setDefault(maxOuterIter(), moi);
- cmdLineParams.put(this.moi.name(), Integer.toString(moi));
- setDefault(maxInnerIter(), mii);
- cmdLineParams.put(this.mii.name(), Integer.toString(mii));
- setDefault(labelIdx(), intMin);
- li = intMin;
- setDefault(inputCol(), icname);
- icname[0] = "";
- setDefault(outputCol(), ocname);
- ocname[0] = "";
- }
-
- @Override
- public String uid() {
- return Long.toString(LogisticRegression.serialVersionUID);
- }
-
- public LogisticRegression setRegParam(double value) {
- cmdLineParams.put(reg.name(), Double.toString(value));
- return (LogisticRegression) setDefault(reg, value);
- }
-
- @Override
- public org.apache.spark.sql.types.StructType validateAndTransformSchema(org.apache.spark.sql.types.StructType arg0, boolean arg1, org.apache.spark.sql.types.DataType arg2) {
- return null;
- }
-
- @Override
- public double getRegParam() {
- return Double.parseDouble(cmdLineParams.get(reg.name()));
- }
-
- @Override
- public void org$apache$spark$ml$param$shared$HasRegParam$_setter_$regParam_$eq(DoubleParam arg0) {
-
- }
-
- @Override
- public DoubleParam regParam() {
- return reg;
- }
-
- @Override
- public DoubleParam elasticNetParam() {
- return null;
- }
-
- @Override
- public double getElasticNetParam() {
- return 0.0f;
- }
-
- @Override
- public void org$apache$spark$ml$param$shared$HasElasticNetParam$_setter_$elasticNetParam_$eq(DoubleParam arg0) {
-
- }
-
- @Override
- public int getMaxIter() {
- return 0;
- }
-
- @Override
- public IntParam maxIter() {
- return null;
- }
-
- public LogisticRegression setMaxOuterIter(int value) {
- cmdLineParams.put(moi.name(), Integer.toString(value));
- return (LogisticRegression) setDefault(moi, value);
- }
-
- public int getMaxOuterIter() {
- return Integer.parseInt(cmdLineParams.get(moi.name()));
- }
-
- public IntParam maxOuterIter() {
- return this.moi;
- }
-
- public LogisticRegression setMaxInnerIter(int value) {
- cmdLineParams.put(mii.name(), Integer.toString(value));
- return (LogisticRegression) setDefault(mii, value);
- }
-
- public int getMaxInnerIter() {
- return Integer.parseInt(cmdLineParams.get(mii.name()));
- }
-
- public IntParam maxInnerIter() {
- return mii;
- }
-
- @Override
- public void org$apache$spark$ml$param$shared$HasMaxIter$_setter_$maxIter_$eq(IntParam arg0) {
-
- }
-
- public LogisticRegression setIntercept(int value) {
- cmdLineParams.put(icpt.name(), Integer.toString(value));
- return (LogisticRegression) setDefault(icpt, value);
- }
-
- public int getIntercept() {
- return Integer.parseInt(cmdLineParams.get(icpt.name()));
- }
-
- public IntParam intercept() {
- return icpt;
- }
-
- @Override
- public BooleanParam fitIntercept() {
- return null;
- }
-
- @Override
- public boolean getFitIntercept() {
- return false;
- }
-
- @Override
- public void org$apache$spark$ml$param$shared$HasFitIntercept$_setter_$fitIntercept_$eq(BooleanParam arg0) {
-
- }
-
- public LogisticRegression setTol(double value) {
- cmdLineParams.put(tol.name(), Double.toString(value));
- return (LogisticRegression) setDefault(tol, value);
- }
-
- @Override
- public double getTol() {
- return Double.parseDouble(cmdLineParams.get(tol.name()));
- }
-
- @Override
- public void org$apache$spark$ml$param$shared$HasTol$_setter_$tol_$eq(DoubleParam arg0) {
-
- }
-
- @Override
- public DoubleParam tol() {
- return tol;
- }
-
- @Override
- public double getThreshold() {
- return 0;
- }
-
- @Override
- public void org$apache$spark$ml$param$shared$HasThreshold$_setter_$threshold_$eq(DoubleParam arg0) {
-
- }
-
- @Override
- public DoubleParam threshold() {
- return null;
- }
-
- public LogisticRegression setLabelIndex(int value) {
- li = value;
- return (LogisticRegression) setDefault(labelIndex, value);
- }
-
- public int getLabelIndex() {
- return Integer.parseInt(cmdLineParams.get(labelIndex.name()));
- }
-
- public IntParam labelIdx() {
- return labelIndex;
- }
-
- public LogisticRegression setInputCol(String[] value) {
- icname[0] = value[0];
- return (LogisticRegression) setDefault(inputCol, value);
- }
-
- public String getInputCol() {
- return icname[0];
- }
-
- public StringArrayParam inputCol() {
- return inputCol;
- }
-
- public LogisticRegression setOutputCol(String[] value) {
- ocname[0] = value[0];
- return (LogisticRegression) setDefault(outputCol, value);
- }
-
- public String getOutputCol() {
- return ocname[0];
- }
-
- public StringArrayParam outputCol() {
- return outputCol;
- }
-
- @Override
- public LogisticRegressionModel train(DataFrame df) {
- MLContext ml = null;
- MLOutput out = null;
-
- try {
- ml = new MLContext(this.sc);
- } catch (DMLRuntimeException e1) {
- e1.printStackTrace();
- return null;
- }
-
- // Convert input data to format that SystemML accepts
- MatrixCharacteristics mcXin = new MatrixCharacteristics();
- JavaPairRDD<MatrixIndexes, MatrixBlock> Xin;
- try {
- Xin = RDDConverterUtilsExt.vectorDataFrameToBinaryBlock(new JavaSparkContext(this.sc), df, mcXin, false, "features");
- } catch (DMLRuntimeException e1) {
- e1.printStackTrace();
- return null;
- }
-
- JavaRDD<String> yin = df.select("label").rdd().toJavaRDD().map(new ConvertSingleColumnToString());
-
- try {
- // Register the input/output variables of script 'MultiLogReg.dml'
- ml.registerInput("X", Xin, mcXin);
- ml.registerInput("Y_vec", yin, "csv");
- ml.registerOutput("B_out");
-
- // Pass dummy values for the X, Y and B command-line arguments (alternatively, add ifdef defaults in MultiLogReg.dml)
- cmdLineParams.put("X", " ");
- cmdLineParams.put("Y", " ");
- cmdLineParams.put("B", " ");
-
-
- // ------------------------------------------------------------------------------------
- // Please note that this logic is a placeholder and is subject to change
- String systemmlHome = System.getenv("SYSTEMML_HOME");
- if(systemmlHome == null) {
- System.err.println("ERROR: The environment variable SYSTEMML_HOME is not set.");
- return null;
- }
-
- String dmlFilePath = systemmlHome + File.separator + "algorithms" + File.separator + "MultiLogReg.dml";
- // ------------------------------------------------------------------------------------
-
- synchronized(MLContext.class) {
- // static synchronization is necessary before execute call
- out = ml.execute(dmlFilePath, cmdLineParams);
- }
-
- JavaPairRDD<MatrixIndexes, MatrixBlock> b_out = out.getBinaryBlockedRDD("B_out");
- MatrixCharacteristics b_outMC = out.getMatrixCharacteristics("B_out");
- return new LogisticRegressionModel(b_out, b_outMC, sc).setParent(this);
- } catch (IOException e) {
- throw new RuntimeException(e);
- } catch (DMLRuntimeException e) {
- throw new RuntimeException(e);
- } catch (DMLException e) {
- throw new RuntimeException(e);
- } catch (ParseException e) {
- throw new RuntimeException(e);
- }
- }
-}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/276d9257/src/main/java/com/ibm/bi/dml/api/ml/LogisticRegressionModel.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/api/ml/LogisticRegressionModel.java b/src/main/java/com/ibm/bi/dml/api/ml/LogisticRegressionModel.java
deleted file mode 100644
index 590372e..0000000
--- a/src/main/java/com/ibm/bi/dml/api/ml/LogisticRegressionModel.java
+++ /dev/null
@@ -1,169 +0,0 @@
-package com.ibm.bi.dml.api.ml;
-
-import java.io.File;
-import java.io.IOException;
-import java.util.HashMap;
-
-import org.apache.spark.SparkContext;
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.ml.classification.ProbabilisticClassificationModel;
-import org.apache.spark.ml.param.ParamMap;
-import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
-
-import com.ibm.bi.dml.api.DMLException;
-import com.ibm.bi.dml.api.MLContext;
-import com.ibm.bi.dml.api.MLOutput;
-import com.ibm.bi.dml.parser.ParseException;
-import com.ibm.bi.dml.runtime.DMLRuntimeException;
-import com.ibm.bi.dml.runtime.instructions.spark.utils.RDDConverterUtilsExt;
-import com.ibm.bi.dml.runtime.matrix.MatrixCharacteristics;
-import com.ibm.bi.dml.runtime.matrix.data.MatrixBlock;
-import com.ibm.bi.dml.runtime.matrix.data.MatrixIndexes;
-
-public class LogisticRegressionModel extends ProbabilisticClassificationModel<Vector, LogisticRegressionModel> {
-
- private static final long serialVersionUID = -6464693773946415027L;
- private JavaPairRDD<MatrixIndexes, MatrixBlock> b_out;
- private SparkContext sc;
- private MatrixCharacteristics b_outMC;
- @Override
- public LogisticRegressionModel copy(ParamMap paramMap) {
- return this;
- }
-
- public LogisticRegressionModel(JavaPairRDD<MatrixIndexes, MatrixBlock> b_out2, MatrixCharacteristics b_outMC, SparkContext sc) {
- this.b_out = b_out2;
- this.b_outMC = b_outMC;
- this.sc = sc;
- //this.cmdLineParams = cmdLineParams;
- }
-
- public LogisticRegressionModel() {
- }
-
- public LogisticRegressionModel(String uid) {
- }
-
- @Override
- public String uid() {
- return Long.toString(LogisticRegressionModel.serialVersionUID);
- }
-
- @Override
- public Vector raw2probabilityInPlace(Vector arg0) {
- return arg0;
- }
-
- @Override
- public int numClasses() {
- return 2;
- }
-
- @Override
- public Vector predictRaw(Vector arg0) {
- return arg0;
- }
-
-
- @Override
- public double predict(Vector features) {
- return super.predict(features);
- }
-
- @Override
- public double raw2prediction(Vector rawPrediction) {
- return super.raw2prediction(rawPrediction);
- }
-
- @Override
- public double probability2prediction(Vector probability) {
- return super.probability2prediction(probability);
- }
-
- public static class ConvertIntToRow implements Function<Integer, Row> {
-
- private static final long serialVersionUID = -3480953015655773622L;
-
- @Override
- public Row call(Integer arg0) throws Exception {
- Object[] row_fields = new Object[1];
- row_fields[0] = new Double(arg0);
- return RowFactory.create(row_fields);
- }
-
- }
-
- @Override
- public DataFrame transform(DataFrame dataset) {
- try {
- MatrixCharacteristics mcXin = new MatrixCharacteristics();
- JavaPairRDD<MatrixIndexes, MatrixBlock> Xin;
- try {
- Xin = RDDConverterUtilsExt.vectorDataFrameToBinaryBlock(new JavaSparkContext(this.sc), dataset, mcXin, false, "features");
- } catch (DMLRuntimeException e1) {
- e1.printStackTrace();
- return null;
- }
- MLContext ml = new MLContext(sc);
- ml.registerInput("X", Xin, mcXin);
- ml.registerInput("B_full", b_out, b_outMC); // Changed MLContext for this method
- ml.registerOutput("means");
- HashMap<String, String> param = new HashMap<String, String>();
- param.put("dfam", "3");
-
- // ------------------------------------------------------------------------------------
- // Please note that this logic is a placeholder and is subject to change
- String systemmlHome = System.getenv("SYSTEMML_HOME");
- if(systemmlHome == null) {
- System.err.println("ERROR: The environment variable SYSTEMML_HOME is not set.");
- return null;
- }
- // Pass dummy values for the X and B command-line arguments (alternatively, add ifdef defaults in GLM-predict.dml)
- param.put("X", " ");
- param.put("B", " ");
-
- String dmlFilePath = systemmlHome + File.separator + "algorithms" + File.separator + "GLM-predict.dml";
- // ------------------------------------------------------------------------------------
- MLOutput out = ml.execute(dmlFilePath, param);
-
- SQLContext sqlContext = new SQLContext(sc);
- DataFrame prob = out.getDF(sqlContext, "means", true).withColumnRenamed("C1", "probability");
-
- MLContext mlNew = new MLContext(sc);
- mlNew.registerInput("X", Xin, mcXin);
- mlNew.registerInput("B_full", b_out, b_outMC);
- mlNew.registerInput("Prob", out.getBinaryBlockedRDD("means"), out.getMatrixCharacteristics("means"));
- mlNew.registerOutput("Prediction");
- mlNew.registerOutput("rawPred");
-			MLOutput outNew = mlNew.executeScript("Prob = read(\"temp1\"); "
-					+ "Prediction = rowIndexMax(Prob); "
-					+ "write(Prediction, \"tempOut\", \"csv\"); " // semicolon required before the next statement
-					+ "X = read(\"temp2\"); "
-					+ "B_full = read(\"temp3\"); "
-					+ "rawPred = 1 / (1 + exp(- X %*% t(B_full))); " // raw prediction: sigmoid of the matrix product of features and coefficients
-					+ "write(rawPred, \"tempOut1\", \"csv\");");
-
- // TODO: Perform joins in the DML
- DataFrame pred = outNew.getDF(sqlContext, "Prediction").withColumnRenamed("C1", "prediction").withColumnRenamed("ID", "ID1");
- DataFrame rawPred = outNew.getDF(sqlContext, "rawPred", true).withColumnRenamed("C1", "rawPrediction").withColumnRenamed("ID", "ID2");
- DataFrame predictionsNProb = prob.join(pred, prob.col("ID").equalTo(pred.col("ID1"))).select("ID", "probability", "prediction");
- predictionsNProb = predictionsNProb.join(rawPred, predictionsNProb.col("ID").equalTo(rawPred.col("ID2"))).select("ID", "probability", "prediction", "rawPrediction");
- DataFrame dataset1 = RDDConverterUtilsExt.addIDToDataFrame(dataset, sqlContext, "ID");
-			return dataset1.join(predictionsNProb, dataset1.col("ID").equalTo(predictionsNProb.col("ID"))).orderBy("ID"); // match the column's declared case
- } catch (IOException e) {
- throw new RuntimeException(e);
- } catch (DMLRuntimeException e) {
- throw new RuntimeException(e);
- } catch (DMLException e) {
- throw new RuntimeException(e);
- } catch (ParseException e) {
- throw new RuntimeException(e);
- }
- }
-}
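
The transform() above chains three MLContext round trips: score the features with GLM-predict.dml, derive the row-wise argmax prediction, and join both back onto the input DataFrame. A minimal caller sketch, assuming a SparkContext sc, fitted coefficient blocks bOut/bOutMC, and a DataFrame testDf with a "features" vector column (all hypothetical names; SYSTEMML_HOME must also point at the algorithm scripts):

    // hedged sketch; bOut, bOutMC and testDf are assumed to come from a prior fit/load step
    LogisticRegressionModel model = new LogisticRegressionModel(bOut, bOutMC, sc);
    DataFrame scored = model.transform(testDf); // appends probability, prediction and rawPrediction columns
    scored.show();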
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/276d9257/src/main/java/com/ibm/bi/dml/api/ml/functions/ConvertSingleColumnToString.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/api/ml/functions/ConvertSingleColumnToString.java b/src/main/java/com/ibm/bi/dml/api/ml/functions/ConvertSingleColumnToString.java
deleted file mode 100644
index 32844de..0000000
--- a/src/main/java/com/ibm/bi/dml/api/ml/functions/ConvertSingleColumnToString.java
+++ /dev/null
@@ -1,15 +0,0 @@
-package com.ibm.bi.dml.api.ml.functions;
-
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.sql.Row;
-
-public class ConvertSingleColumnToString implements Function<Row, String> {
-
- private static final long serialVersionUID = -499763403738768970L;
-
- @Override
- public String call(Row row) throws Exception {
- return row.apply(0).toString();
- }
-}
-
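
ConvertSingleColumnToString is a thin adapter for feeding a single-column DataFrame into APIs that expect a JavaRDD<String>. A hedged usage sketch, where df is a hypothetical single-column DataFrame:

    // df is an assumption; each Row's only field is rendered via toString()
    JavaRDD<String> lines = df.javaRDD().map(new ConvertSingleColumnToString());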
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/276d9257/src/main/java/com/ibm/bi/dml/api/ml/functions/ConvertVectorToDouble.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/api/ml/functions/ConvertVectorToDouble.java b/src/main/java/com/ibm/bi/dml/api/ml/functions/ConvertVectorToDouble.java
deleted file mode 100644
index e9c77ba..0000000
--- a/src/main/java/com/ibm/bi/dml/api/ml/functions/ConvertVectorToDouble.java
+++ /dev/null
@@ -1,16 +0,0 @@
-package com.ibm.bi.dml.api.ml.functions;
-
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.sql.Row;
-
-public class ConvertVectorToDouble implements Function<Row, Double> {
-
- private static final long serialVersionUID = -6612447783777073929L;
-
-	@Override
-	public Double call(Row row) throws Exception {
-		// assumes the first column of the row already holds a double value
-		return row.getDouble(0);
-	}
-}
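
ConvertVectorToDouble, despite its name, simply reads the first column of each Row as a double, so it fits extracting a numeric label column. A hedged usage sketch (df and the "label" column are assumptions):

    // df and "label" are assumed names; the selected column must already be a double
    JavaRDD<Double> labels = df.select("label").javaRDD().map(new ConvertVectorToDouble());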
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/276d9257/src/main/java/com/ibm/bi/dml/api/monitoring/InstructionComparator.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/api/monitoring/InstructionComparator.java b/src/main/java/com/ibm/bi/dml/api/monitoring/InstructionComparator.java
deleted file mode 100644
index 906ebb4..0000000
--- a/src/main/java/com/ibm/bi/dml/api/monitoring/InstructionComparator.java
+++ /dev/null
@@ -1,38 +0,0 @@
-/**
- * (C) Copyright IBM Corp. 2010, 2015
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- */
-package com.ibm.bi.dml.api.monitoring;
-
-import java.util.Comparator;
-import java.util.HashMap;
-
-public class InstructionComparator implements Comparator<String>{
-
- HashMap<String, Long> instructionCreationTime;
- public InstructionComparator(HashMap<String, Long> instructionCreationTime) {
- this.instructionCreationTime = instructionCreationTime;
- }
-	@Override
-	public int compare(String o1, String o2) {
-		Long t1 = instructionCreationTime.get(o1);
-		Long t2 = instructionCreationTime.get(o2);
-		// treat missing entries as equal; returning -1 on exception made the
-		// comparator asymmetric and violated the Comparator contract
-		if(t1 == null || t2 == null)
-			return 0;
-		return t1.compareTo(t2);
-	}
-
-}
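
The comparator orders instruction strings by their recorded creation time, which SparkMonitoringUtil (below) uses to separate recompiled instructions. A minimal sketch mirroring that call site, where instructions and instructionCreationTime are the maps the monitoring utility builds:

    // sort the instructions attached to a script location by creation time
    List<String> listInst = new ArrayList<String>(instructions.get(loc));
    Collections.sort(listInst, new InstructionComparator(instructionCreationTime));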
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/276d9257/src/main/java/com/ibm/bi/dml/api/monitoring/Location.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/api/monitoring/Location.java b/src/main/java/com/ibm/bi/dml/api/monitoring/Location.java
deleted file mode 100644
index 3cc2957..0000000
--- a/src/main/java/com/ibm/bi/dml/api/monitoring/Location.java
+++ /dev/null
@@ -1,77 +0,0 @@
-/**
- * (C) Copyright IBM Corp. 2010, 2015
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- */
-package com.ibm.bi.dml.api.monitoring;
-
-public class Location implements Comparable<Location> {
- public int beginLine;
- public int endLine;
- public int beginCol;
- public int endCol;
- public Location(int beginLine, int endLine, int beginCol, int endCol) {
- this.beginLine = beginLine;
- this.endLine = endLine;
- this.beginCol = beginCol;
- this.endCol = endCol;
- }
-
-	@Override
-	public boolean equals(Object other) {
-		if(other instanceof Location) {
-			Location loc = (Location) other;
-			return loc.beginLine == beginLine && loc.endLine == endLine && loc.beginCol == beginCol && loc.endCol == endCol;
-		}
-		return false;
-	}
-
-	@Override
-	public int hashCode() {
-		// preserve the equals/hashCode contract
-		return 31 * (31 * (31 * beginLine + endLine) + beginCol) + endCol;
-	}
-
-	private int compare(int v1, int v2) {
-		return Integer.compare(v1, v2);
-	}
-
-	@Override
-	public String toString() {
-		return beginLine + ":" + beginCol + ", " + endLine + ":" + endCol;
-	}
-
-	@Override
-	public int compareTo(Location loc) {
-		// lexicographic order: beginLine, then beginCol, then endLine, then endCol
-		int retVal = compare(beginLine, loc.beginLine);
-		if(retVal == 0)
-			retVal = compare(beginCol, loc.beginCol);
-		if(retVal == 0)
-			retVal = compare(endLine, loc.endLine);
-		if(retVal == 0)
-			retVal = compare(endCol, loc.endCol);
-		return retVal;
-	}
-}
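
Location's compareTo imposes a lexicographic order over (beginLine, beginCol, endLine, endCol), which is what lets the monitoring utility's TreeMultimap keep expressions in script order. A small illustrative sketch (note the constructor's argument order is beginLine, endLine, beginCol, endCol):

    Location a = new Location(3, 3, 1, 10);
    Location b = new Location(3, 4, 1, 5);
    System.out.println(a.compareTo(b) < 0); // true: lines and begin columns tie, then endLine 3 < 4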
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/276d9257/src/main/java/com/ibm/bi/dml/api/monitoring/SparkMonitoringUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/com/ibm/bi/dml/api/monitoring/SparkMonitoringUtil.java b/src/main/java/com/ibm/bi/dml/api/monitoring/SparkMonitoringUtil.java
deleted file mode 100644
index 9b4996b..0000000
--- a/src/main/java/com/ibm/bi/dml/api/monitoring/SparkMonitoringUtil.java
+++ /dev/null
@@ -1,602 +0,0 @@
-/**
- * (C) Copyright IBM Corp. 2010, 2015
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *
- */
-package com.ibm.bi.dml.api.monitoring;
-
-import java.io.BufferedWriter;
-import java.io.FileWriter;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map.Entry;
-
-import scala.collection.Seq;
-import scala.xml.Node;
-
-import com.google.common.collect.Multimap;
-import com.google.common.collect.TreeMultimap;
-import com.ibm.bi.dml.lops.Lop;
-import com.ibm.bi.dml.runtime.DMLRuntimeException;
-import com.ibm.bi.dml.runtime.instructions.Instruction;
-import com.ibm.bi.dml.runtime.instructions.spark.SPInstruction;
-import com.ibm.bi.dml.runtime.instructions.spark.functions.SparkListener;
-
-/**
- * Usage guide:
- * MLContext mlCtx = new MLContext(sc, true);
- * mlCtx.register...
- * mlCtx.execute(...)
- * mlCtx.getMonitoringUtil().getRuntimeInfoInHTML("runtime.html");
- */
-public class SparkMonitoringUtil {
- // ----------------------------------------------------
- // For VLDB Demo:
- private Multimap<Location, String> instructions = TreeMultimap.create();
- private Multimap<String, Integer> stageIDs = TreeMultimap.create(); // instruction -> stageIds
- private Multimap<String, Integer> jobIDs = TreeMultimap.create(); // instruction -> jobIds
- private HashMap<String, String> lineageInfo = new HashMap<String, String>(); // instruction -> lineageInfo
- private HashMap<String, Long> instructionCreationTime = new HashMap<String, Long>();
-
- private Multimap<Integer, String> rddInstructionMapping = TreeMultimap.create();
-
- private HashSet<String> getRelatedInstructions(int stageID) {
- HashSet<String> retVal = new HashSet<String>();
- if(_sparkListener != null) {
- ArrayList<Integer> rdds = _sparkListener.stageRDDMapping.get(stageID);
- for(Integer rddID : rdds) {
- retVal.addAll(rddInstructionMapping.get(rddID));
- }
- }
- return retVal;
- }
-
- private SparkListener _sparkListener = null;
- public SparkListener getSparkListener() {
- return _sparkListener;
- }
-
- private String explainOutput = "";
-
- public String getExplainOutput() {
- return explainOutput;
- }
-
- public void setExplainOutput(String explainOutput) {
- this.explainOutput = explainOutput;
- }
-
- public SparkMonitoringUtil(SparkListener sparkListener) {
- _sparkListener = sparkListener;
- }
-
- public void addCurrentInstruction(SPInstruction inst) {
- if(_sparkListener != null) {
- _sparkListener.addCurrentInstruction(inst);
- }
- }
-
- public void addRDDForInstruction(SPInstruction inst, Integer rddID) {
- this.rddInstructionMapping.put(rddID, getInstructionString(inst));
- }
-
- public void removeCurrentInstruction(SPInstruction inst) {
- if(_sparkListener != null) {
- _sparkListener.removeCurrentInstruction(inst);
- }
- }
-
- public void setDMLString(String dmlStr) {
- this.dmlStrForMonitoring = dmlStr;
- }
-
- public void resetMonitoringData() {
- if(_sparkListener != null && _sparkListener.stageDAGs != null)
- _sparkListener.stageDAGs.clear();
- if(_sparkListener != null && _sparkListener.stageTimeline != null)
- _sparkListener.stageTimeline.clear();
- }
-
- // public Multimap<Location, String> hops = ArrayListMultimap.create(); TODO:
- private String dmlStrForMonitoring = null;
- public void getRuntimeInfoInHTML(String htmlFilePath) throws DMLRuntimeException, IOException {
- String jsAndCSSFiles = "<script src=\"js/lodash.min.js\"></script>"
- + "<script src=\"js/jquery-1.11.1.min.js\"></script>"
- + "<script src=\"js/d3.min.js\"></script>"
- + "<script src=\"js/bootstrap-tooltip.js\"></script>"
- + "<script src=\"js/dagre-d3.min.js\"></script>"
- + "<script src=\"js/graphlib-dot.min.js\"></script>"
- + "<script src=\"js/spark-dag-viz.js\"></script>"
- + "<script src=\"js/timeline-view.js\"></script>"
- + "<script src=\"js/vis.min.js\"></script>"
- + "<link rel=\"stylesheet\" href=\"css/bootstrap.min.css\">"
- + "<link rel=\"stylesheet\" href=\"css/vis.min.css\">"
- + "<link rel=\"stylesheet\" href=\"css/spark-dag-viz.css\">"
- + "<link rel=\"stylesheet\" href=\"css/timeline-view.css\"> ";
- BufferedWriter bw = new BufferedWriter(new FileWriter(htmlFilePath));
- bw.write("<html><head>\n");
- bw.write(jsAndCSSFiles + "\n");
- bw.write("</head><body>\n<table border=1>\n");
-
- bw.write("<tr>\n");
- bw.write("<td><b>Position in script</b></td>\n");
- bw.write("<td><b>DML</b></td>\n");
- bw.write("<td><b>Instruction</b></td>\n");
- bw.write("<td><b>StageIDs</b></td>\n");
- bw.write("<td><b>RDD Lineage</b></td>\n");
- bw.write("</tr>\n");
-
- for(Location loc : instructions.keySet()) {
- String dml = getExpression(loc);
-
- // Sort the instruction with time - so as to separate recompiled instructions
- List<String> listInst = new ArrayList<String>(instructions.get(loc));
- Collections.sort(listInst, new InstructionComparator(instructionCreationTime));
-
- if(dml != null && dml.trim().length() > 1) {
- bw.write("<tr>\n");
- int rowSpan = listInst.size();
- bw.write("<td rowspan=\"" + rowSpan + "\">" + loc.toString() + "</td>\n");
- bw.write("<td rowspan=\"" + rowSpan + "\">" + dml + "</td>\n");
- boolean firstTime = true;
- for(String inst : listInst) {
- if(!firstTime)
- bw.write("<tr>\n");
-
- if(inst.startsWith("SPARK"))
- bw.write("<td style=\"color:red\">" + inst + "</td>\n");
- else if(isInterestingCP(inst))
- bw.write("<td style=\"color:blue\">" + inst + "</td>\n");
- else
- bw.write("<td>" + inst + "</td>\n");
-
- bw.write("<td>" + getStageIDAsString(inst) + "</td>\n");
- if(lineageInfo.containsKey(inst))
- bw.write("<td>" + lineageInfo.get(inst).replaceAll("\n", "<br />") + "</td>\n");
- else
- bw.write("<td></td>\n");
-
- bw.write("</tr>\n");
- firstTime = false;
- }
-
- }
-
- }
-
- bw.write("</table></body>\n</html>");
- bw.close();
- }
-
- private String getInQuotes(String str) {
- return "\"" + str + "\"";
- }
-	private String getEscapedJSON(String json) {
-		if(json == null)
-			return "";
-		else {
-			return json
-					.replaceAll("\\\\", "\\\\\\\\") // escape backslashes first so later escapes are not double-escaped
-					.replaceAll("\\t", "\\\\t")
-					.replaceAll("/", "\\\\/")
-					.replaceAll("\"", "\\\\\"")
-					.replaceAll("\\r?\\n", "\\\\n");
-		}
-	}
-
- private long maxExpressionExecutionTime = 0;
- HashMap<Integer, Long> stageExecutionTimes = new HashMap<Integer, Long>();
- HashMap<String, Long> expressionExecutionTimes = new HashMap<String, Long>();
- HashMap<String, Long> instructionExecutionTimes = new HashMap<String, Long>();
- HashMap<Integer, HashSet<String>> relatedInstructionsPerStage = new HashMap<Integer, HashSet<String>>();
- private void fillExecutionTimes() {
- stageExecutionTimes.clear();
- expressionExecutionTimes.clear();
- for(Location loc : instructions.keySet()) {
- List<String> listInst = new ArrayList<String>(instructions.get(loc));
- long expressionExecutionTime = 0;
-
- if(listInst != null && listInst.size() > 0) {
- for(String inst : listInst) {
- long instructionExecutionTime = 0;
- for(Integer stageId : stageIDs.get(inst)) {
- try {
- if(getStageExecutionTime(stageId) != null) {
- long stageExecTime = getStageExecutionTime(stageId);
- instructionExecutionTime += stageExecTime;
- expressionExecutionTime += stageExecTime;
- stageExecutionTimes.put(stageId, stageExecTime);
- }
- }
-					catch(Exception e) { /* ignore stages without a recorded execution time */ }
-
- relatedInstructionsPerStage.put(stageId, getRelatedInstructions(stageId));
- }
- instructionExecutionTimes.put(inst, instructionExecutionTime);
- }
- expressionExecutionTime /= listInst.size(); // average
- }
- maxExpressionExecutionTime = Math.max(maxExpressionExecutionTime, expressionExecutionTime);
- expressionExecutionTimes.put(loc.toString(), expressionExecutionTime);
- }
-
- // Now fill empty instructions
- for(Entry<String, Long> kv : instructionExecutionTimes.entrySet()) {
- if(kv.getValue() == 0) {
- // Find all stages that contain this as related instruction
- long sumExecutionTime = 0;
- for(Entry<Integer, HashSet<String>> kv1 : relatedInstructionsPerStage.entrySet()) {
- if(kv1.getValue().contains(kv.getKey())) {
- sumExecutionTime += stageExecutionTimes.get(kv1.getKey());
- }
- }
- kv.setValue(sumExecutionTime);
- }
- }
-
- for(Location loc : instructions.keySet()) {
- if(expressionExecutionTimes.get(loc.toString()) == 0) {
- List<String> listInst = new ArrayList<String>(instructions.get(loc));
- long expressionExecutionTime = 0;
-				if(listInst != null && listInst.size() > 0) {
-					for(String inst : listInst) {
-						expressionExecutionTime += instructionExecutionTimes.get(inst);
-					}
-					expressionExecutionTime /= listInst.size(); // average; divide inside the guard to avoid division by zero
-				}
- maxExpressionExecutionTime = Math.max(maxExpressionExecutionTime, expressionExecutionTime);
- expressionExecutionTimes.put(loc.toString(), expressionExecutionTime);
- }
- }
-
- }
-
-	/**
-	 * Writes the runtime information to a JSON file, which avoids passing a large String through Py4J.
-	 * @param fileName output file path
-	 * @throws DMLRuntimeException
-	 * @throws IOException
-	 */
- public void saveRuntimeInfoInJSONFormat(String fileName) throws DMLRuntimeException, IOException {
- String json = getRuntimeInfoInJSONFormat();
- BufferedWriter bw = new BufferedWriter(new FileWriter(fileName));
- bw.write(json);
- bw.close();
- }
-
- public String getRuntimeInfoInJSONFormat() throws DMLRuntimeException, IOException {
- StringBuilder retVal = new StringBuilder("{\n");
-
- retVal.append(getInQuotes("dml") + ":" + getInQuotes(getEscapedJSON(dmlStrForMonitoring)) + ",\n");
- retVal.append(getInQuotes("expressions") + ":" + "[\n");
-
- boolean isFirstExpression = true;
- fillExecutionTimes();
-
- for(Location loc : instructions.keySet()) {
- String dml = getEscapedJSON(getExpressionInJSON(loc));
-
- if(dml != null) {
- // Sort the instruction with time - so as to separate recompiled instructions
- List<String> listInst = new ArrayList<String>(instructions.get(loc));
- Collections.sort(listInst, new InstructionComparator(instructionCreationTime));
-
- if(!isFirstExpression) {
- retVal.append(",\n");
- }
- retVal.append("{\n");
- isFirstExpression = false;
-
- retVal.append(getInQuotes("beginLine") + ":" + loc.beginLine + ",\n");
- retVal.append(getInQuotes("beginCol") + ":" + loc.beginCol + ",\n");
- retVal.append(getInQuotes("endLine") + ":" + loc.endLine + ",\n");
- retVal.append(getInQuotes("endCol") + ":" + loc.endCol + ",\n");
-
- long expressionExecutionTime = expressionExecutionTimes.get(loc.toString());
- retVal.append(getInQuotes("expressionExecutionTime") + ":" + expressionExecutionTime + ",\n");
- retVal.append(getInQuotes("expressionHeavyHitterFactor") + ":" + ((double)expressionExecutionTime / (double)maxExpressionExecutionTime) + ",\n");
-
- retVal.append(getInQuotes("expression") + ":" + getInQuotes(dml) + ",\n");
-
- retVal.append(getInQuotes("instructions") + ":" + "[\n");
-
- boolean firstTime = true;
- for(String inst : listInst) {
-
- if(!firstTime)
- retVal.append(", {");
- else
- retVal.append("{");
-
- if(inst.startsWith("SPARK")) {
- retVal.append(getInQuotes("isSpark") + ":" + "true,\n");
- }
- else if(isInterestingCP(inst)) {
- retVal.append(getInQuotes("isInteresting") + ":" + "true,\n");
- }
-
- retVal.append(getStageIDAsJSONString(inst) + "\n");
- if(lineageInfo.containsKey(inst)) {
- retVal.append(getInQuotes("lineageInfo") + ":" + getInQuotes(getEscapedJSON(lineageInfo.get(inst))) + ",\n");
- }
-
- retVal.append(getInQuotes("instruction") + ":" + getInQuotes(getEscapedJSON(inst)));
- retVal.append("}");
- firstTime = false;
- }
-
- retVal.append("]\n");
- retVal.append("}\n");
- }
-
- }
-
- return retVal.append("]\n}").toString();
- }
-
-
- private boolean isInterestingCP(String inst) {
- if(inst.startsWith("CP rmvar") || inst.startsWith("CP cpvar") || inst.startsWith("CP mvvar"))
- return false;
- else if(inst.startsWith("CP"))
- return true;
- else
- return false;
- }
-
- private String getStageIDAsString(String instruction) {
- String retVal = "";
- for(Integer stageId : stageIDs.get(instruction)) {
- String stageDAG = "";
- String stageTimeLine = "";
-
- if(getStageDAGs(stageId) != null) {
- stageDAG = getStageDAGs(stageId).toString();
- }
-
- if(getStageTimeLine(stageId) != null) {
- stageTimeLine = getStageTimeLine(stageId).toString();
- }
-
- retVal += "Stage:" + stageId +
- " ("
- + "<div>"
- + stageDAG.replaceAll("toggleDagViz\\(false\\)", "toggleDagViz(false, this)")
- + "</div>, "
- + "<div id=\"timeline-" + stageId + "\">"
- + stageTimeLine
- .replaceAll("drawTaskAssignmentTimeline\\(", "registerTimelineData(" + stageId + ", ")
- .replaceAll("class=\"expand-task-assignment-timeline\"", "class=\"expand-task-assignment-timeline\" onclick=\"toggleStageTimeline(this)\"")
- + "</div>"
- + ")";
- }
- return retVal;
- }
-
- private String getStageIDAsJSONString(String instruction) {
- long instructionExecutionTime = instructionExecutionTimes.get(instruction);
-
- StringBuilder retVal = new StringBuilder(getInQuotes("instructionExecutionTime") + ":" + instructionExecutionTime + ",\n");
-
- boolean isFirst = true;
- if(stageIDs.get(instruction).size() == 0) {
- // Find back references
- HashSet<Integer> relatedStages = new HashSet<Integer>();
- for(Entry<Integer, HashSet<String>> kv : relatedInstructionsPerStage.entrySet()) {
- if(kv.getValue().contains(instruction)) {
- relatedStages.add(kv.getKey());
- }
- }
- HashSet<String> relatedInstructions = new HashSet<String>();
- for(Entry<String, Integer> kv : stageIDs.entries()) {
- if(relatedStages.contains(kv.getValue())) {
- relatedInstructions.add(kv.getKey());
- }
- }
-
- retVal.append(getInQuotes("backReferences") + ": [\n");
- boolean isFirstRelInst = true;
- for(String relInst : relatedInstructions) {
- if(!isFirstRelInst) {
- retVal.append(",\n");
- }
- retVal.append(getInQuotes(relInst));
- isFirstRelInst = false;
- }
- retVal.append("], \n");
- }
- else {
- retVal.append(getInQuotes("stages") + ": {");
- for(Integer stageId : stageIDs.get(instruction)) {
- String stageDAG = "";
- String stageTimeLine = "";
-
- if(getStageDAGs(stageId) != null) {
- stageDAG = getStageDAGs(stageId).toString();
- }
-
- if(getStageTimeLine(stageId) != null) {
- stageTimeLine = getStageTimeLine(stageId).toString();
- }
-
- long stageExecutionTime = stageExecutionTimes.get(stageId);
- if(!isFirst) {
- retVal.append(",\n");
- }
-
- retVal.append(getInQuotes("" + stageId) + ": {");
-
- // Now add related instructions
- HashSet<String> relatedInstructions = relatedInstructionsPerStage.get(stageId);
-
- retVal.append(getInQuotes("relatedInstructions") + ": [\n");
- boolean isFirstRelInst = true;
- for(String relInst : relatedInstructions) {
- if(!isFirstRelInst) {
- retVal.append(",\n");
- }
- retVal.append(getInQuotes(relInst));
- isFirstRelInst = false;
- }
- retVal.append("],\n");
-
- retVal.append(getInQuotes("DAG") + ":")
- .append(
- getInQuotes(
- getEscapedJSON(stageDAG.replaceAll("toggleDagViz\\(false\\)", "toggleDagViz(false, this)"))
- ) + ",\n"
- )
- .append(getInQuotes("stageExecutionTime") + ":" + stageExecutionTime + ",\n")
- .append(getInQuotes("timeline") + ":")
- .append(
- getInQuotes(
- getEscapedJSON(
- stageTimeLine
- .replaceAll("drawTaskAssignmentTimeline\\(", "registerTimelineData(" + stageId + ", ")
- .replaceAll("class=\"expand-task-assignment-timeline\"", "class=\"expand-task-assignment-timeline\" onclick=\"toggleStageTimeline(this)\""))
- )
- )
- .append("}");
-
- isFirst = false;
- }
- retVal.append("}, ");
- }
-
-
- retVal.append(getInQuotes("jobs") + ": {");
- isFirst = true;
- for(Integer jobId : jobIDs.get(instruction)) {
- String jobDAG = "";
-
- if(getJobDAGs(jobId) != null) {
- jobDAG = getJobDAGs(jobId).toString();
- }
- if(!isFirst) {
- retVal.append(",\n");
- }
-
- retVal.append(getInQuotes("" + jobId) + ": {")
- .append(getInQuotes("DAG") + ":" )
- .append(getInQuotes(
- getEscapedJSON(jobDAG.replaceAll("toggleDagViz\\(true\\)", "toggleDagViz(true, this)"))
- ) + "}\n");
-
- isFirst = false;
- }
- retVal.append("}, ");
-
- return retVal.toString();
- }
-
-
- String [] dmlLines = null;
- private String getExpression(Location loc) {
- try {
- if(dmlLines == null) {
- dmlLines = dmlStrForMonitoring.split("\\r?\\n");
- }
- if(loc.beginLine == loc.endLine) {
- return dmlLines[loc.beginLine-1].substring(loc.beginCol-1, loc.endCol);
- }
- else {
- String retVal = dmlLines[loc.beginLine-1].substring(loc.beginCol-1);
- for(int i = loc.beginLine+1; i < loc.endLine; i++) {
- retVal += "<br />" + dmlLines[i-1];
- }
- retVal += "<br />" + dmlLines[loc.endLine-1].substring(0, loc.endCol);
- return retVal;
- }
- }
- catch(Exception e) {
- return null; // "[[" + loc.beginLine + "," + loc.endLine + "," + loc.beginCol + "," + loc.endCol + "]]";
- }
- }
-
-
- private String getExpressionInJSON(Location loc) {
- try {
- if(dmlLines == null) {
- dmlLines = dmlStrForMonitoring.split("\\r?\\n");
- }
- if(loc.beginLine == loc.endLine) {
- return dmlLines[loc.beginLine-1].substring(loc.beginCol-1, loc.endCol);
- }
- else {
- String retVal = dmlLines[loc.beginLine-1].substring(loc.beginCol-1);
- for(int i = loc.beginLine+1; i < loc.endLine; i++) {
- retVal += "\\n" + dmlLines[i-1];
- }
- retVal += "\\n" + dmlLines[loc.endLine-1].substring(0, loc.endCol);
- return retVal;
- }
- }
- catch(Exception e) {
- return null; // "[[" + loc.beginLine + "," + loc.endLine + "," + loc.beginCol + "," + loc.endCol + "]]";
- }
- }
-
-	public Seq<Node> getStageDAGs(int stageID) {
-		if(_sparkListener == null || _sparkListener.stageDAGs == null)
-			return null;
-		else
-			return _sparkListener.stageDAGs.get(stageID);
-	}
-
-	public Long getStageExecutionTime(int stageID) {
-		// guard the map that is actually read (stageExecutionTime, not stageDAGs)
-		if(_sparkListener == null || _sparkListener.stageExecutionTime == null)
-			return null;
-		else
-			return _sparkListener.stageExecutionTime.get(stageID);
-	}
-
-	public Seq<Node> getJobDAGs(int jobID) {
-		if(_sparkListener == null || _sparkListener.jobDAGs == null)
-			return null;
-		else
-			return _sparkListener.jobDAGs.get(jobID);
-	}
-
-	public Seq<Node> getStageTimeLine(int stageID) {
-		if(_sparkListener == null || _sparkListener.stageTimeline == null)
-			return null;
-		else
-			return _sparkListener.stageTimeline.get(stageID);
-	}
- public void setLineageInfo(Instruction inst, String plan) {
- lineageInfo.put(getInstructionString(inst), plan);
- }
- public void setStageId(Instruction inst, int stageId) {
- stageIDs.put(getInstructionString(inst), stageId);
- }
- public void setJobId(Instruction inst, int jobId) {
- jobIDs.put(getInstructionString(inst), jobId);
- }
- public void setInstructionLocation(Location loc, Instruction inst) {
- String instStr = getInstructionString(inst);
- instructions.put(loc, instStr);
- instructionCreationTime.put(instStr, System.currentTimeMillis());
- }
- private String getInstructionString(Instruction inst) {
- String tmp = inst.toString();
- tmp = tmp.replaceAll(Lop.OPERAND_DELIMITOR, " ");
- tmp = tmp.replaceAll(Lop.DATATYPE_PREFIX, ".");
- tmp = tmp.replaceAll(Lop.INSTRUCTION_DELIMITOR, ", ");
- return tmp;
- }
-}
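
The class javadoc above sketches the intended flow; spelled out a little further (sc, xIn/xInMC, dmlFilePath and params are all hypothetical names), it would look like:

    // hedged sketch of the monitoring workflow; per the javadoc, the second
    // MLContext constructor argument enables monitoring
    MLContext mlCtx = new MLContext(sc, true);
    mlCtx.registerInput("X", xIn, xInMC);
    mlCtx.registerOutput("Y");
    mlCtx.execute(dmlFilePath, params);
    mlCtx.getMonitoringUtil().getRuntimeInfoInHTML("runtime.html");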