You are viewing a plain-text version of this message; the canonical version is available in the mailing-list archive (the original hyperlink was lost in the plain-text conversion).
Posted to commits@commons.apache.org by tn...@apache.org on 2013/12/17 23:11:46 UTC

svn commit: r1551735 - in /commons/proper/math/trunk/src: changes/changes.xml main/java/org/apache/commons/math3/optim/linear/SimplexSolver.java main/java/org/apache/commons/math3/optim/linear/SimplexTableau.java

Author: tn
Date: Tue Dec 17 22:11:45 2013
New Revision: 1551735

URL: http://svn.apache.org/r1551735
Log:
[MATH-1079] Improve performance of SimplexSolver.

Modified:
    commons/proper/math/trunk/src/changes/changes.xml
    commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optim/linear/SimplexSolver.java
    commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optim/linear/SimplexTableau.java

Modified: commons/proper/math/trunk/src/changes/changes.xml
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/changes/changes.xml?rev=1551735&r1=1551734&r2=1551735&view=diff
==============================================================================
--- commons/proper/math/trunk/src/changes/changes.xml (original)
+++ commons/proper/math/trunk/src/changes/changes.xml Tue Dec 17 22:11:45 2013
@@ -51,6 +51,10 @@ If the output is not quite correct, chec
   </properties>
   <body>
     <release version="3.3" date="TBD" description="TBD">
+      <action dev="tn" type="fix" issue="MATH-1079">
+        Improved performance of "SimplexSolver" in package o.a.c.math3.optim.linear by
+        directly performing row operations and keeping track of the current basic variables.
+      </action>
       <action dev="tn" type="update" issue="MATH-1080">
         The "LinearConstraintSet" will now return the enclosed collection of "LinearConstraint"
         objects in the same order as they have been added.

Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optim/linear/SimplexSolver.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optim/linear/SimplexSolver.java?rev=1551735&r1=1551734&r2=1551735&view=diff
==============================================================================
--- commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optim/linear/SimplexSolver.java (original)
+++ commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optim/linear/SimplexSolver.java Tue Dec 17 22:11:45 2013
@@ -258,7 +258,7 @@ public class SimplexSolver extends Linea
                     minRatioPositions.add(i);
                 } else if (cmp < 0) {
                     minRatio = ratio;
-                    minRatioPositions = new ArrayList<Integer>();
+                    minRatioPositions.clear();
                     minRatioPositions.add(i);
                 }
             }
@@ -290,15 +290,11 @@ public class SimplexSolver extends Linea
 
             Integer minRow = null;
             int minIndex = tableau.getWidth();
-            final int varStart = tableau.getNumObjectiveFunctions();
-            final int varEnd = tableau.getWidth() - 1;
             for (Integer row : minRatioPositions) {
-                for (int i = varStart; i < varEnd && !row.equals(minRow); i++) {
-                    final Integer basicRow = tableau.getBasicRow(i);
-                    if (basicRow != null && basicRow.equals(row) && i < minIndex) {
-                        minIndex = i;
-                        minRow = row;
-                    }
+                final int basicVar = tableau.getBasicVariable(row);
+                if (basicVar < minIndex) {
+                    minIndex = basicVar;
+                    minRow = row;
                 }
             }
             return minRow;
@@ -325,17 +321,7 @@ public class SimplexSolver extends Linea
             throw new UnboundedSolutionException();
         }
 
-        // set the pivot element to 1
-        double pivotVal = tableau.getEntry(pivotRow, pivotCol);
-        tableau.divideRow(pivotRow, pivotVal);
-
-        // set the rest of the pivot column to 0
-        for (int i = 0; i < tableau.getHeight(); i++) {
-            if (i != pivotRow) {
-                final double multiplier = tableau.getEntry(i, pivotCol);
-                tableau.subtractRow(i, pivotRow, multiplier);
-            }
-        }
+        tableau.performRowOperations(pivotCol, pivotRow);
     }
 
     /**

Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optim/linear/SimplexTableau.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optim/linear/SimplexTableau.java?rev=1551735&r1=1551734&r2=1551735&view=diff
==============================================================================
--- commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optim/linear/SimplexTableau.java (original)
+++ commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optim/linear/SimplexTableau.java Tue Dec 17 22:11:45 2013
@@ -21,6 +21,7 @@ import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
 import java.io.Serializable;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.HashSet;
 import java.util.List;
@@ -29,7 +30,6 @@ import java.util.TreeSet;
 
 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
 import org.apache.commons.math3.linear.MatrixUtils;
-import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.commons.math3.linear.RealVector;
 import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
 import org.apache.commons.math3.optim.PointValuePair;
@@ -82,7 +82,7 @@ class SimplexTableau implements Serializ
     private final List<String> columnLabels = new ArrayList<String>();
 
     /** Simple tableau. */
-    private transient RealMatrix tableau;
+    private transient Array2DRowRealMatrix tableau;
 
     /** Number of decision variables. */
     private final int numDecisionVariables;
@@ -102,6 +102,12 @@ class SimplexTableau implements Serializ
     /** Cut-off value for entries in the tableau. */
     private final double cutOff;
 
+    /** Maps basic variables to row they are basic in. */
+    private int[] basicVariables;
+    
+    /** Maps rows to their corresponding basic variables. */
+    private int[] basicRows;
+    
     /**
      * Builds a tableau for a linear problem.
      *
@@ -118,7 +124,7 @@ class SimplexTableau implements Serializ
                    final boolean restrictToNonNegative,
                    final double epsilon) {
         this(f, constraints, goalType, restrictToNonNegative, epsilon,
-                SimplexSolver.DEFAULT_ULPS, SimplexSolver.DEFAULT_CUT_OFF);
+             SimplexSolver.DEFAULT_ULPS, SimplexSolver.DEFAULT_CUT_OFF);
     }
 
     /**
@@ -162,13 +168,15 @@ class SimplexTableau implements Serializ
         this.epsilon                = epsilon;
         this.maxUlps                = maxUlps;
         this.cutOff                 = cutOff;
-        this.numDecisionVariables   = f.getCoefficients().getDimension() +
-                                      (restrictToNonNegative ? 0 : 1);
+        this.numDecisionVariables   = f.getCoefficients().getDimension() + (restrictToNonNegative ? 0 : 1);
         this.numSlackVariables      = getConstraintTypeCounts(Relationship.LEQ) +
                                       getConstraintTypeCounts(Relationship.GEQ);
         this.numArtificialVariables = getConstraintTypeCounts(Relationship.EQ) +
                                       getConstraintTypeCounts(Relationship.GEQ);
         this.tableau = createTableau(goalType == GoalType.MAXIMIZE);
+        // initialize the basic variables for phase 1:
+        //   we know that only slack or artificial variables can be basic
+        initializeBasicVariables(getSlackVariableOffset());
         initializeColumnLabels();
     }
 
@@ -200,7 +208,7 @@ class SimplexTableau implements Serializ
      * @param maximize if true, goal is to maximize the objective function
      * @return created tableau
      */
-    protected RealMatrix createTableau(final boolean maximize) {
+    protected Array2DRowRealMatrix createTableau(final boolean maximize) {
 
         // create a matrix of the correct size
         int width = numDecisionVariables + numSlackVariables +
@@ -212,17 +220,16 @@ class SimplexTableau implements Serializ
         if (getNumObjectiveFunctions() == 2) {
             matrix.setEntry(0, 0, -1);
         }
+
         int zIndex = (getNumObjectiveFunctions() == 1) ? 0 : 1;
         matrix.setEntry(zIndex, zIndex, maximize ? 1 : -1);
-        RealVector objectiveCoefficients =
-            maximize ? f.getCoefficients().mapMultiply(-1) : f.getCoefficients();
+        RealVector objectiveCoefficients = maximize ? f.getCoefficients().mapMultiply(-1) : f.getCoefficients();
         copyArray(objectiveCoefficients.toArray(), matrix.getDataRef()[zIndex]);
-        matrix.setEntry(zIndex, width - 1,
-            maximize ? f.getConstantTerm() : -1 * f.getConstantTerm());
+        matrix.setEntry(zIndex, width - 1, maximize ? f.getConstantTerm() : -1 * f.getConstantTerm());
 
         if (!restrictToNonNegative) {
             matrix.setEntry(zIndex, getSlackVariableOffset() - 1,
-                getInvertedCoefficientSum(objectiveCoefficients));
+                            getInvertedCoefficientSum(objectiveCoefficients));
         }
 
         // initialize the constraint rows
@@ -238,7 +245,7 @@ class SimplexTableau implements Serializ
             // x-
             if (!restrictToNonNegative) {
                 matrix.setEntry(row, getSlackVariableOffset() - 1,
-                    getInvertedCoefficientSum(constraint.getCoefficients()));
+                                getInvertedCoefficientSum(constraint.getCoefficients()));
             }
 
             // RHS
@@ -253,7 +260,7 @@ class SimplexTableau implements Serializ
 
             // artificial variables
             if ((constraint.getRelationship() == Relationship.EQ) ||
-                    (constraint.getRelationship() == Relationship.GEQ)) {
+                (constraint.getRelationship() == Relationship.GEQ)) {
                 matrix.setEntry(0, getArtificialVariableOffset() + artificialVar, 1);
                 matrix.setEntry(row, getArtificialVariableOffset() + artificialVar++, 1);
                 matrix.setRowVector(0, matrix.getRowVector(0).subtract(matrix.getRowVector(row)));
@@ -333,6 +340,44 @@ class SimplexTableau implements Serializ
      * @return the row that the variable is basic in.  null if the column is not basic
      */
     protected Integer getBasicRow(final int col) {
+        final int row = basicVariables[col];
+        return row == -1 ? null : row;
+    }
+    
+    /**
+     * Returns the variable that is basic in this row.
+     * @param row the index of the row to check
+     * @return the variable that is basic for this row.
+     */
+    protected int getBasicVariable(final int row) {
+        return basicRows[row];
+    }
+
+    /**
+     * Initializes the basic variable / row mapping.
+     * @param startColumn the column to start
+     */
+    private void initializeBasicVariables(final int startColumn) {
+        basicVariables = new int[getWidth() - 1];
+        basicRows = new int[getHeight()];
+
+        Arrays.fill(basicVariables, -1);
+
+        for (int i = startColumn; i < getWidth() - 1; i++) {
+            Integer row = findBasicRow(i);
+            if (row != null) {
+                basicVariables[i] = row;
+                basicRows[row] = i;
+            }
+        }
+    }
+
+    /**
+     * Returns the row in which the given column is basic.
+     * @param col index of the column
+     * @return the row that the variable is basic in, or {@code null} if the variable is not basic.
+     */
+    private Integer findBasicRow(final int col) {
         Integer row = null;
         for (int i = 0; i < getHeight(); i++) {
             final double entry = getEntry(i, col);
@@ -354,12 +399,12 @@ class SimplexTableau implements Serializ
             return;
         }
 
-        Set<Integer> columnsToDrop = new TreeSet<Integer>();
+        final Set<Integer> columnsToDrop = new TreeSet<Integer>();
         columnsToDrop.add(0);
 
         // positive cost non-artificial variables
         for (int i = getNumObjectiveFunctions(); i < getArtificialVariableOffset(); i++) {
-            final double entry = tableau.getEntry(0, i);
+            final double entry = getEntry(0, i);
             if (Precision.compareTo(entry, 0d, epsilon) > 0) {
                 columnsToDrop.add(i);
             }
@@ -373,12 +418,12 @@ class SimplexTableau implements Serializ
             }
         }
 
-        double[][] matrix = new double[getHeight() - 1][getWidth() - columnsToDrop.size()];
+        final double[][] matrix = new double[getHeight() - 1][getWidth() - columnsToDrop.size()];
         for (int i = 1; i < getHeight(); i++) {
             int col = 0;
             for (int j = 0; j < getWidth(); j++) {
                 if (!columnsToDrop.contains(j)) {
-                    matrix[i - 1][col++] = tableau.getEntry(i, j);
+                    matrix[i - 1][col++] = getEntry(i, j);
                 }
             }
         }
@@ -391,6 +436,8 @@ class SimplexTableau implements Serializ
 
         this.tableau = new Array2DRowRealMatrix(matrix);
         this.numArtificialVariables = 0;
+        // need to update the basic variable mappings as row/columns have been dropped
+        initializeBasicVariables(getNumObjectiveFunctions());
     }
 
     /**
@@ -406,8 +453,10 @@ class SimplexTableau implements Serializ
      * @return whether the model has been solved
      */
     boolean isOptimal() {
-        for (int i = getNumObjectiveFunctions(); i < getWidth() - 1; i++) {
-            final double entry = tableau.getEntry(0, i);
+        final double[] objectiveFunctionRow = getRow(0);
+        final int end = getRhsOffset();
+        for (int i = getNumObjectiveFunctions(); i < end; i++) {
+            final double entry = objectiveFunctionRow[i];
             if (Precision.compareTo(entry, 0d, epsilon) < 0) {
                 return false;
             }
@@ -424,8 +473,8 @@ class SimplexTableau implements Serializ
         Integer negativeVarBasicRow = negativeVarColumn > 0 ? getBasicRow(negativeVarColumn) : null;
         double mostNegative = negativeVarBasicRow == null ? 0 : getEntry(negativeVarBasicRow, getRhsOffset());
 
-        Set<Integer> basicRows = new HashSet<Integer>();
-        double[] coefficients = new double[getOriginalNumDecisionVariables()];
+        final Set<Integer> basicRows = new HashSet<Integer>();
+        final double[] coefficients = new double[getOriginalNumDecisionVariables()];
         for (int i = 0; i < coefficients.length; i++) {
             int colIndex = columnLabels.indexOf("x" + i);
             if (colIndex < 0) {
@@ -453,6 +502,32 @@ class SimplexTableau implements Serializ
     }
 
     /**
+     * Perform the row operations of the simplex algorithm with the selected
+     * pivot column and row.
+     * @param pivotCol the pivot column
+     * @param pivotRow the pivot row
+     */
+    protected void performRowOperations(int pivotCol, int pivotRow) {
+        // set the pivot element to 1
+        final double pivotVal = getEntry(pivotRow, pivotCol);
+        divideRow(pivotRow, pivotVal);
+
+        // set the rest of the pivot column to 0
+        for (int i = 0; i < getHeight(); i++) {
+            if (i != pivotRow) {
+                final double multiplier = getEntry(i, pivotCol);
+                subtractRow(i, pivotRow, multiplier);
+            }
+        }
+
+        // update the basic variable mappings
+        final int previousBasicVariable = getBasicVariable(pivotRow);
+        basicVariables[previousBasicVariable] = -1;
+        basicVariables[pivotCol] = pivotRow;
+        basicRows[pivotRow] = pivotCol;
+    }
+
+    /**
      * Divides one row by a given divisor.
      * <p>
      * After application of this operation, the following will hold:
@@ -461,9 +536,10 @@ class SimplexTableau implements Serializ
      * @param dividendRow index of the row
      * @param divisor value of the divisor
      */
-    protected void divideRow(final int dividendRow, final double divisor) {
+    protected void divideRow(final int dividendRowIndex, final double divisor) {
+        final double[] dividendRow = getRow(dividendRowIndex);
         for (int j = 0; j < getWidth(); j++) {
-            tableau.setEntry(dividendRow, j, tableau.getEntry(dividendRow, j) / divisor);
+            dividendRow[j] /= divisor;
         }
     }
 
@@ -477,15 +553,16 @@ class SimplexTableau implements Serializ
      * @param subtrahendRow row index
      * @param multiple multiplication factor
      */
-    protected void subtractRow(final int minuendRow, final int subtrahendRow,
-                               final double multiple) {
+    protected void subtractRow(final int minuendRowIndex, final int subtrahendRowIndex, final double multiplier) {
+        final double[] minuendRow = getRow(minuendRowIndex);
+        final double[] subtrahendRow = getRow(subtrahendRowIndex);
         for (int i = 0; i < getWidth(); i++) {
-            double result = tableau.getEntry(minuendRow, i) - tableau.getEntry(subtrahendRow, i) * multiple;
+            double result = minuendRow[i] - subtrahendRow[i] * multiplier;
             // cut-off values smaller than the cut-off threshold, otherwise may lead to numerical instabilities
-            if (FastMath.abs(result) < cutOff) {
+            if (result != 0.0 && FastMath.abs(result) < cutOff) {
                 result = 0.0;
             }
-            tableau.setEntry(minuendRow, i, result);
+            minuendRow[i] = result;
         }
     }
 
@@ -521,8 +598,7 @@ class SimplexTableau implements Serializ
      * @param column column index
      * @param value for the entry
      */
-    protected final void setEntry(final int row, final int column,
-                                  final double value) {
+    protected final void setEntry(final int row, final int column, final double value) {
         tableau.setEntry(row, column, value);
     }
 
@@ -589,6 +665,15 @@ class SimplexTableau implements Serializ
     }
 
     /**
+     * Get the row from the tableau.
+     * @param row the row index
+     * @return the reference to the underlying row data
+     */
+    protected final double[] getRow(int row) {
+        return tableau.getDataRef()[row];
+    }
+
+    /**
      * Get the tableau data.
      * @return tableau data
      */