You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by lu...@apache.org on 2012/10/08 17:38:17 UTC

svn commit: r1395619 - in /commons/sandbox/nabla/trunk/src: main/java/org/apache/commons/nabla/forward/ main/java/org/apache/commons/nabla/forward/analysis/ test/java/org/apache/commons/nabla/forward/

Author: luc
Date: Mon Oct  8 15:38:15 2012
New Revision: 1395619

URL: http://svn.apache.org/viewvc?rev=1395619&view=rev
Log:
Work in progress on handling multi-argument functions.

This commit improves several parts but breaks others: it is only work
in progress, and some existing tests that passed before these changes
no longer pass. This should be fixed in the next commit.

Modified:
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java
    commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java?rev=1395619&r1=1395618&r2=1395619&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java Mon Oct  8 15:38:15 2012
@@ -243,6 +243,10 @@ public class ForwardModeDifferentiator i
                 final Class<? extends NablaDifferentiated> dClass =
                         new DerivativeLoader(differentiableClass).defineClass(name, bytecode);
                 byteCodeMap.put(name, bytecode);
+ 
+                // TODO: remove development trace
+                new ClassReader(differentiableClass.getResourceAsStream("/" + Type.getInternalName(differentiableClass) + ".class")).accept(new TraceClassVisitor(new PrintWriter(System.out)), 0);
+                new ClassReader(bytecode).accept(new TraceClassVisitor(new PrintWriter(System.err)), 0);
 
                 if (differentiator.getPrimitive().name.equals(Type.getType(differentiableClass).getInternalName())) {
                     nudf = (Class<? extends NablaUnivariateDifferentiableFunction>) dClass;

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java?rev=1395619&r1=1395618&r2=1395619&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java Mon Oct  8 15:38:15 2012
@@ -157,7 +157,7 @@ public class ClassDifferentiator {
             if (method.name.equals(name) && Type.getType(method.desc).equals(primitiveMethodType)) {
 
                 final MethodDifferentiator differentiator = new MethodDifferentiator(mathClasses, this);
-                differentiator.differentiate(method, derivativedMethodType);
+                differentiator.differentiate(method, primitiveMethodType, derivativedMethodType);
                 classNode.methods.add(method);
 
             }

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java?rev=1395619&r1=1395618&r2=1395619&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java Mon Oct  8 15:38:15 2012
@@ -53,6 +53,7 @@ import org.apache.commons.nabla.forward.
 import org.objectweb.asm.Opcodes;
 import org.objectweb.asm.Type;
 import org.objectweb.asm.tree.AbstractInsnNode;
+import org.objectweb.asm.tree.IincInsnNode;
 import org.objectweb.asm.tree.InsnList;
 import org.objectweb.asm.tree.InsnNode;
 import org.objectweb.asm.tree.MethodInsnNode;
@@ -115,15 +116,17 @@ public class MethodDifferentiator {
     /**
      * Differentiate a method.
      * @param method method to differentiate (<em>will</em> be modified)
+     * @param primitiveMethodType type of the method in the primitive class
      * @param derivedMethodType type of the derived method
      * @exception DifferentiationException if method cannot be differentiated
      */
-    public void differentiate(final MethodNode method, final Type derivedMethodType)
+    public void differentiate(final MethodNode method,
+                              final Type primitiveMethodType, final Type derivedMethodType)
         throws DifferentiationException {
         try {
 
-            final int     dsIndex  = method.maxLocals;
             final boolean isStatic = (method.access & Opcodes.ACC_STATIC) != 0;
+            final boolean[] usedLocals = new boolean[method.maxLocals + 1];
 
             // analyze the original code, tracing values production/consumption
             final FlowAnalyzer analyzer =
@@ -135,8 +138,14 @@ public class MethodDifferentiator {
                 frames.put(method.instructions.get(i), array[i]);
             }
 
+            // identify the needed changes in code
+            Set<AbstractInsnNode> changes =
+                    identifyChanges(method.name, usedLocals, primitiveMethodType.getArgumentTypes(),
+                                    derivedMethodType.getArgumentTypes(),
+                                    method.instructions, isStatic);
+
             // perform the code changes
-            for (final AbstractInsnNode insn : identifyChanges(method.instructions, isStatic)) {
+            for (final AbstractInsnNode insn : changes) {
                 method.instructions.insert(insn, getReplacement(insn, method.maxLocals));
                 method.instructions.remove(insn);
             }
@@ -146,16 +155,32 @@ public class MethodDifferentiator {
             new SwappedDstoreTrimmer().trim(method.instructions);
             new DLoadPop2Trimmer().trim(method.instructions);
 
-            // insert the preservation of the reference derivative structure
-            method.instructions.insert(preserveReferenceDerivativeStructure(derivedMethodType, isStatic, dsIndex));
+            // mark the local variables that are used
+            markUsedLocalVariables(method.instructions, usedLocals);
+
+            if (usedLocals[usedLocals.length - 1]) {
+                // insert the preservation of the reference derivative structure
+                // (we know we have reserved the last local variable for this)
+                method.instructions.insert(preserveReferenceDerivativeStructure(derivedMethodType, isStatic,
+                                                                                usedLocals.length - 1));
+            }
+
+            // remove the local variables added at the beginning and not used
+            removeUnusedLocalVariables(method.instructions, usedLocals);
+
+            // count the number of local variables really used
+            method.maxLocals = 0;
+            for (final boolean used : usedLocals) {
+                if (used) {
+                    ++method.maxLocals;
+                }
+            }
 
             // set the method properties
-            method.desc      = derivedMethodType.getDescriptor();
-            method.access   |= Opcodes.ACC_SYNTHETIC;
-            method.maxLocals = dsIndex + 1;
+            method.desc    = derivedMethodType.getDescriptor();
+            method.access |= Opcodes.ACC_SYNTHETIC;
 
         } catch (AnalyzerException ae) {
-            ae.printStackTrace(System.err);
             if ((ae.getCause() != null) && ae.getCause() instanceof DifferentiationException) {
                 throw (DifferentiationException) ae.getCause();
             } else {
@@ -165,6 +190,60 @@ public class MethodDifferentiator {
         }
     }
 
+    /** Mark local variables usage.
+     * @param instructions methods instructions
+     * @param usedLocals array of variables use indicators to fill in
+     */
+    private void markUsedLocalVariables(final InsnList instructions, final boolean[] usedLocals) {
+        for (final Iterator<AbstractInsnNode> i = instructions.iterator(); i.hasNext();) {
+            final AbstractInsnNode insn = i.next();
+            final int opcode = insn.getOpcode();
+            if (insn.getOpcode() == Opcodes.IINC) {
+                usedLocals[((IincInsnNode) insn).var] = true;
+            } else if (opcode == Opcodes.ILOAD  || opcode == Opcodes.FLOAD  || opcode == Opcodes.ALOAD ||
+                       opcode == Opcodes.ISTORE || opcode == Opcodes.FSTORE || opcode == Opcodes.ASTORE) {
+                usedLocals[((VarInsnNode) insn).var] = true;
+            } else if (opcode == Opcodes.DLOAD  || opcode == Opcodes.LLOAD ||
+                       opcode == Opcodes.DSTORE || opcode == Opcodes.LSTORE) {
+                usedLocals[((VarInsnNode) insn).var]     = true;
+                usedLocals[((VarInsnNode) insn).var + 1] = true;
+            }
+        }
+    }
+
+    /** Remove the unused spare cells introduced at conversion start.
+     * @param instructions instructions of the method
+     * @param usedLocals array of variables use indicators to fill in
+     */
+    private void removeUnusedLocalVariables(final InsnList instructions, final boolean[] usedLocals) {
+
+        for (final Iterator<AbstractInsnNode> i = instructions.iterator(); i.hasNext();) {
+            final AbstractInsnNode insn = i.next();
+            if (insn.getType() == AbstractInsnNode.VAR_INSN) {
+                // shift the index of the instruction, collapsing unused local variables
+                final VarInsnNode varInsn = (VarInsnNode) insn;
+                int shifted = 0;
+                for (int j = 0; j < varInsn.var; ++j) {
+                    if (usedLocals[j]) {
+                        ++shifted;
+                    }
+                }
+                varInsn.var = shifted;
+            } else if (insn.getType() == AbstractInsnNode.IINC_INSN) {
+                // shift the index of the instruction, collapsing unused local variables
+                final IincInsnNode iincInsn = (IincInsnNode) insn;
+                int shifted = 0;
+                for (int j = 0; j < iincInsn.var; ++j) {
+                    if (usedLocals[j]) {
+                        ++shifted;
+                    }
+                }
+                iincInsn.var = shifted;
+            }
+        }
+
+    }
+
     /** Request differentiation of a method.
      * @param owner class in which the method is defined
      * @param isStatic if true, the method is static
@@ -195,25 +274,29 @@ public class MethodDifferentiator {
      *       converted or not, as in some branch codes the value may return
      *       simple constants like "return 0").</li>
      * </ul>
+     * @param name method name
+     * @param usedLocals array of variables use indicators to fill in
+     * @param primitiveArguments type of the method arguments in the primitive class
+     * @param derivedArguments type of the method arguments in the derived class
      * @param instructions instructions of the method
      * @param isStatic if true, the method is a static method
      * @return set containing all the instructions that must be changed
+     * @exception DifferentiationException if some unsupported bytecode is found
      */
-    private Set<AbstractInsnNode> identifyChanges(final InsnList instructions, final boolean isStatic) {
+    private Set<AbstractInsnNode> identifyChanges(final String name, final boolean[] usedLocals,
+                                                  final Type[] primitiveArguments, final Type[] derivedArguments,
+                                                  final InsnList instructions, final boolean isStatic)
+        throws DifferentiationException {
 
-        // the pending set contains the values (local variables or stack cells)
-        // that have been changed, they will trigger changes on the instructions
-        // that consume them
-        final Set<TrackingValue> pending = new HashSet<TrackingValue>();
+        // the pending set contains the values (local variables or stack cells) that have
+        // been changed, they will trigger changes on the instructions that consume them,
+        // we bootstrap the analysis by looking at changed method arguments
+        final Set<TrackingValue> pending =
+                identifyArguments(isStatic, instructions, primitiveArguments, derivedArguments, usedLocals);
 
         // the changes set contains the instructions that must be changed
         final Set<AbstractInsnNode> changes = new HashSet<AbstractInsnNode>();
 
-        // start by converting the parameter of the method,
-        // which is kept in local variable 0 or 1 of the initial frame (depending on the method being static or not)
-        final TrackingValue dpParameter = frames.get(instructions.get(0)).getLocal(isStatic ? 0 : 1);
-        pending.add(dpParameter);
-
         // propagate the values conversions throughout the method
         while (!pending.isEmpty()) {
 
@@ -262,6 +345,89 @@ public class MethodDifferentiator {
 
     }
 
+    /** Identify how method arguments are used.
+     * @param isStatic if true, the method is static
+     * @param instructions instructions of the method
+     * @param primitiveArguments type of the method arguments in the primitive class
+     * @param derivedArguments type of the method arguments in the derived class
+     * @param usedLocals array of variables use indicators to fill in
+     * @return set of argument local variables that are changed
+     */
+    private Set<TrackingValue> identifyArguments(final boolean isStatic, final InsnList instructions,
+                                                 final Type[] primitiveArguments, final Type[] derivedArguments,
+                                                 final boolean[] usedLocals) {
+
+        final Set<TrackingValue> changedValues = new HashSet<TrackingValue>();
+
+        int index = 0;
+        if (!isStatic) {
+            // non-static methods use variable 0 for "this"
+            usedLocals[index++] = true;
+        }
+
+        // start by converting the arguments of the method
+        final Frame<TrackingValue> initialFrame = frames.get(instructions.get(0));
+        for (int i = 0; i < primitiveArguments.length; ++i) {
+            if (!primitiveArguments[i].equals(derivedArguments[i])) {
+
+                // the argument type is changed, we have to track the instructions that depend on it
+                changedValues.add(initialFrame.getLocal(index));
+
+                // the second half of this argument local variable will disappear
+                // TODO: should Nabla support such arguments override?
+                for (ListIterator<AbstractInsnNode> iterator = instructions.iterator(); iterator.hasNext();) {
+                    final AbstractInsnNode insn = iterator.next();
+                    if (insn.getType() == AbstractInsnNode.VAR_INSN) {
+                        final int var = ((VarInsnNode) insn).var;
+                        if (var == index) {
+                            if (insn.getOpcode() == Opcodes.DSTORE || insn.getOpcode() == Opcodes.DLOAD) {
+                                // this is an expected use of the argument normally as a double
+                                // we need to change the type and track it
+                                changedValues.add(frames.get(insn).getLocal(index));
+                            } else {
+                                // the argument is overridden by some other variable type
+                                // for now we don't know how to handle this case
+                                throw DifferentiationException.createInternalError(null);
+                            }
+                        } else if (var == index + 1) {
+                            // the second half of the argument is overridden by some other variable type
+                            // for now we don't know how to handle this case
+                            throw DifferentiationException.createInternalError(null); 
+                        }
+                    } else if (insn.getType() == AbstractInsnNode.IINC_INSN) {
+                        final int var = ((IincInsnNode) insn).var;
+                        if (var == index || var == index + 1) {
+                            // the argument is overridden by some other variable type
+                            // for now we don't know how to handle this case
+                            throw DifferentiationException.createInternalError(null); 
+                        }
+                    }
+                }
+
+                // mark the first part of the argument as used
+                // the second part will not be used anymore after the change
+                // so the following variables indices will be shifted later on
+                // to remove the unused slot
+                usedLocals[index] = true;
+
+            } else {
+
+                // mark the argument as used
+                usedLocals[index] = true;
+                if (primitiveArguments[i].getSize() > 1) {
+                    usedLocals[index + 1] = true;
+                }
+
+            }
+
+            index += primitiveArguments[i].getSize();
+
+        }
+
+        return changedValues;
+
+    }
+
     /** Get the list of double values produced by an instruction and not yet converted.
      * @param instruction instruction producing the values
      * @return list of double values produced
@@ -285,6 +451,14 @@ public class MethodDifferentiator {
                 final Frame<TrackingValue> produced = frames.get(successor);
 
                 // check the stack cells
+                for (int i = 0; i < before.getStackSize(); ++i) {
+                    final TrackingValue value = before.getStack(i);
+                    if (((i >= produced.getStackSize()) || (value != produced.getStack(i))) &&
+                        value.getType().equals(Type.DOUBLE_TYPE) &&
+                        !converted.contains(value)) {
+                        values.add(value);
+                    }
+                }
                 for (int i = 0; i < produced.getStackSize(); ++i) {
                     final TrackingValue value = produced.getStack(i);
                     if (((i >= beforeStackSize) || (value != before.getStack(i))) &&
@@ -431,15 +605,18 @@ public class MethodDifferentiator {
 
         final Type dsType = Type.getType(DerivativeStructure.class);
         final Type[] parameterTypes = derivedMethodType.getArgumentTypes();
+        int var = isStatic ? 0 : 1;
         for (int i = 0; i < parameterTypes.length; ++i) {
             if (parameterTypes[i].equals(dsType)) {
                 // we have found the first derivative structure parameter
 
                 // preserve the parameter as a new variable
                 final InsnList list = new InsnList();
-                list.add(new VarInsnNode(Opcodes.ALOAD, isStatic ? i : (i + 1)));
+                list.add(new VarInsnNode(Opcodes.ALOAD, var));
                 list.add(new VarInsnNode(Opcodes.ASTORE, dsIndex));
                 return list;
+            } else {
+                var += parameterTypes[i].getSize();
             }
 
         }
@@ -468,7 +645,7 @@ public class MethodDifferentiator {
                                     Type.getInternalName(DerivativeStructure.class),
                                     "getFreeParameters",
                                     Type.getMethodDescriptor(Type.INT_TYPE)));       // => y_ds y_ds d params
-        list.add(new VarInsnNode(Opcodes.ALOAD, 1));                                 // => y_ds y_ds d params x_ds
+        list.add(new VarInsnNode(Opcodes.ALOAD, dsIndex));                           // => y_ds y_ds d params x_ds
         list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL,
                                     Type.getInternalName(DerivativeStructure.class),
                                     "getOrder",

Modified: commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java?rev=1395619&r1=1395618&r2=1395619&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java (original)
+++ commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java Mon Oct  8 15:38:15 2012
@@ -119,21 +119,33 @@ public class ForwardModeDifferentiatorTe
     }
 
     @Test
-    public void testEmbeddedInvoke() {
+    public void testArgumentOverride() {
+        checkReference(new ReferenceFunction() {
+            public double value(double t) {
+                t = 1.0; // here, we override the argument with something that does not depend on it
+                return 2 * t;
+            }
+            public double firstDerivative(double t) { return 0.0; }
+        }, -5, 5, 20, 8.0e-15);
+    }
+
+    @Test
+    public void testMultipleFunctions() {
         checkReference(new ReferenceFunction() {
             private double f(double t) {
                 return 2 * t;
             }
-            private double g(double h) {
-                return h * h;
+            private double g(double h, double a) {
+                a = a + 1;
+                return h * a;
             }
-            private double h(double t) {
-                return t - 1;
+            private double h(long cL, double cD, int cI) {
+                return cL + cD + cI;
             }
             public double value(double t) {
-                return f(t) + g(h(t));
+                return f(t) * g(h(-1l, t, 2), 2.0);
             }
-            public double firstDerivative(double t) { return 2 * t; }
+            public double firstDerivative(double t) { return 12 * t + 6; }
         }, -5, 5, 20, 8.0e-15);
     }