You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by lu...@apache.org on 2012/10/08 17:38:44 UTC

svn commit: r1395621 - in /commons/sandbox/nabla/trunk/src: main/java/org/apache/commons/nabla/forward/ main/java/org/apache/commons/nabla/forward/analysis/ main/java/org/apache/commons/nabla/forward/instructions/ test/java/org/apache/commons/nabla/for...

Author: luc
Date: Mon Oct  8 15:38:44 2012
New Revision: 1395621

URL: http://svn.apache.org/viewvc?rev=1395621&view=rev
Log:
Fixed differentiation with multiple arguments.

Differentiation with multiple arguments now works. It no longer changes
the types of all arguments, only the necessary ones. This means that
differentiating f(double, double) generates one specific function
f(DerivativeStructure, double) when only the first argument must be
differentiated. Another, separate function can be generated at the same
time for f(double, DerivativeStructure) when only the second argument
must be differentiated.

Modified:
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeNonMathTransformer.java
    commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java?rev=1395621&r1=1395620&r2=1395621&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java Mon Oct  8 15:38:44 2012
@@ -116,25 +116,25 @@ public class ForwardModeDifferentiator i
      * @param isStatic if true, the method is static
      * @param method method name
      * @param primitiveMethodType method type in the primitive (includes return and arguments types)
-     * @return type of the differentiated method
+     * @param differentiatedMethodType method type in the differentiated class (includes return and arguments types)
      * @exception DifferentiationException if class cannot be found
      */
-    public Type requestMethodDifferentiation(final String owner, final boolean isStatic,
-                                             final String method, final Type primitiveMethodType)
+    public void requestMethodDifferentiation(final String owner, final boolean isStatic,
+                                             final String method, final Type primitiveMethodType,
+                                             final Type differentiatedMethodType)
         throws DifferentiationException {
 
         try {
 
             final DifferentiableMethod dm =
-                    new DifferentiableMethod(Class.forName(owner), isStatic, method, primitiveMethodType);
+                    new DifferentiableMethod(Class.forName(owner), isStatic,
+                                             method, primitiveMethodType, differentiatedMethodType);
 
             if (!processedDifferentiations.contains(dm)) {
                 // schedule the request if method has not been processed yet
                 pendingDifferentiations.add(dm);
             }
  
-            return dm.getDifferentiatedMethodType();
-
         } catch (ClassNotFoundException cnfe) {
             throw new DifferentiationException(NablaMessages.CANNOT_READ_CLASS,
                                                owner, cnfe.getMessage());
@@ -199,8 +199,10 @@ public class ForwardModeDifferentiator i
             final Set<ClassDifferentiator> differentiators = new HashSet<ClassDifferentiator>();
 
             // bootstrap differentiation using the top level value function from the UnivariateFunction interface
+            final Type dsType = Type.getType(DerivativeStructure.class);
             requestMethodDifferentiation(differentiableClass.getName(), false, "value",
-                                         Type.getMethodType(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE));
+                                         Type.getMethodType(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE),
+                                         Type.getMethodType(dsType, dsType));
 
             while (!pendingDifferentiations.isEmpty()) {
 
@@ -300,13 +302,24 @@ public class ForwardModeDifferentiator i
         /** Type of the method in the primitive class. */
         private final Type primitiveMethodType;
 
-        /** Simple constructor. */
+        /** Type of the method in the differentiated class. */
+        private final Type differentiatedMethodType;
+
+        /** Simple constructor.
+         * @param primitiveClass class in which the method is defined
+         * @param isStatic if true, the method is static
+         * @param method method name
+         * @param primitiveMethodType method type in the primitive (includes return and arguments types)
+         * @param differentiatedMethodType method type in the differentiated class (includes return and arguments types)
+         */
         public DifferentiableMethod(final Class<?> primitiveClass, final boolean isStatic,
-                                    final String method, final Type primitiveMethodType) {
-            this.primitiveClass       = primitiveClass;
-            this.isStatic             = isStatic;
-            this.method               = method;
-            this.primitiveMethodType  = primitiveMethodType;
+                                    final String method, final Type primitiveMethodType,
+                                    final Type differentiatedMethodType) {
+            this.primitiveClass           = primitiveClass;
+            this.isStatic                 = isStatic;
+            this.method                   = method;
+            this.primitiveMethodType      = primitiveMethodType;
+            this.differentiatedMethodType = differentiatedMethodType;
         }
 
         /** Get the primitive class to which the method belongs.
@@ -334,21 +347,7 @@ public class ForwardModeDifferentiator i
          * @return type of the method in the differentiated class
          */
         public Type getDifferentiatedMethodType() {
-
-            // transform arguments types
-            final Type[] argumentsTypes = primitiveMethodType.getArgumentTypes();
-            for (int i = 0; i < argumentsTypes.length; ++i) {
-                if (argumentsTypes[i].equals(Type.DOUBLE_TYPE)) {
-                    argumentsTypes[i] = Type.getType(DerivativeStructure.class);
-                }
-            }
-
-            // transform return type
-            final Type returnType = primitiveMethodType.getReturnType().equals(Type.DOUBLE_TYPE) ?
-                                    Type.getType(DerivativeStructure.class) : primitiveMethodType.getReturnType();
-
-            return Type.getMethodType(returnType, argumentsTypes);
-
+            return differentiatedMethodType;
         }
 
         /** {@inheritDoc} */

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java?rev=1395621&r1=1395620&r2=1395621&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java Mon Oct  8 15:38:44 2012
@@ -153,12 +153,14 @@ public class ClassDifferentiator {
                                     final Type derivativedMethodType)
         throws DifferentiationException {
 
-        for (final MethodNode method : primitiveNode.methods) {
-            if (method.name.equals(name) && Type.getType(method.desc).equals(primitiveMethodType)) {
+        for (final MethodNode primitiveMethod : primitiveNode.methods) {
+            if (primitiveMethod.name.equals(name) && Type.getType(primitiveMethod.desc).equals(primitiveMethodType)) {
 
                 final MethodDifferentiator differentiator = new MethodDifferentiator(mathClasses, this);
-                differentiator.differentiate(method, primitiveMethodType, derivativedMethodType);
-                classNode.methods.add(method);
+                final MethodNode differentiatedMethod     = differentiator.differentiate(primitiveMethod,
+                                                                                         primitiveMethodType,
+                                                                                         derivativedMethodType);
+                classNode.methods.add(differentiatedMethod);
 
             }
         }
@@ -170,12 +172,14 @@ public class ClassDifferentiator {
      * @param isStatic if true, the method is static
      * @param method method name
      * @param primitiveMethodType method type in the primitive (includes return and arguments types)
-     * @return type of the differentiated method
+     * @param differentiatedMethodType method type in the differentiated class (includes return and arguments types)
      * @exception DifferentiationException if class cannot be found
      */
-    public Type requestMethodDifferentiation(final String owner, final boolean isStatic,
-                                             final String method, final Type primitiveMethodType) {
-        return forwardDifferentiator.requestMethodDifferentiation(owner, isStatic, method, primitiveMethodType);
+    public void requestMethodDifferentiation(final String owner, final boolean isStatic,
+                                             final String method, final Type primitiveMethodType,
+                                             final Type differentiatedMethodType) {
+        forwardDifferentiator.requestMethodDifferentiation(owner, isStatic,
+                                                           method, primitiveMethodType, differentiatedMethodType);
     }
 
 }

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java?rev=1395621&r1=1395620&r2=1395621&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java Mon Oct  8 15:38:44 2012
@@ -115,16 +115,24 @@ public class MethodDifferentiator {
 
     /**
      * Differentiate a method.
-     * @param method method to differentiate (<em>will</em> be modified)
+     * @param primitiveMethod method to differentiate
      * @param primitiveMethodType type of the method in the primitive class
      * @param derivedMethodType type of the derived method
+     * @return differentiated method
      * @exception DifferentiationException if method cannot be differentiated
      */
-    public void differentiate(final MethodNode method,
-                              final Type primitiveMethodType, final Type derivedMethodType)
+    public MethodNode differentiate(final MethodNode primitiveMethod,
+                                    final Type primitiveMethodType, final Type derivedMethodType)
         throws DifferentiationException {
         try {
 
+            // copy the primitive method as a new independent node
+            final MethodNode method =
+                    new MethodNode(primitiveMethod.access | Opcodes.ACC_SYNTHETIC,
+                                   primitiveMethod.name, derivedMethodType.getDescriptor(),
+                                   null, primitiveMethod.exceptions.toArray(new String[primitiveMethod.exceptions.size()]));
+            primitiveMethod.accept(method);
+
             final boolean isStatic = (method.access & Opcodes.ACC_STATIC) != 0;
             final boolean[] usedLocals = new boolean[method.maxLocals + 1];
 
@@ -176,16 +184,14 @@ public class MethodDifferentiator {
                 }
             }
 
-            // set the method properties
-            method.desc    = derivedMethodType.getDescriptor();
-            method.access |= Opcodes.ACC_SYNTHETIC;
+            return method;
 
         } catch (AnalyzerException ae) {
             if ((ae.getCause() != null) && ae.getCause() instanceof DifferentiationException) {
                 throw (DifferentiationException) ae.getCause();
             } else {
                 throw new DifferentiationException(NablaMessages.UNABLE_TO_ANALYZE_METHOD,
-                                                   getPrimitiveName(), method.name, ae.getMessage());
+                                                   getPrimitiveName(), primitiveMethod.name, ae.getMessage());
             }
         }
     }
@@ -249,12 +255,15 @@ public class MethodDifferentiator {
      * @param isStatic if true, the method is static
      * @param method method name
      * @param primitiveMethodType method type in the primitive (includes return and arguments types)
+     * @param differentiatedMethodType method type in the differentiated class (includes return and arguments types)
      * @return type of the differentiated method
      * @exception DifferentiationException if class cannot be found
      */
-    public Type requestMethodDifferentiation(final String owner, final boolean isStatic,
-                                             final String method, final Type primitiveMethodType) {
-        return classDifferentiator.requestMethodDifferentiation(owner, isStatic, method, primitiveMethodType);
+    public void requestMethodDifferentiation(final String owner, final boolean isStatic,
+                                             final String method, final Type primitiveMethodType,
+                                             final Type differentiatedMethodType) {
+        classDifferentiator.requestMethodDifferentiation(owner, isStatic, method,
+                                                         primitiveMethodType, differentiatedMethodType);
     }
 
     /** Identify the instructions that must be changed.
@@ -451,14 +460,6 @@ public class MethodDifferentiator {
                 final Frame<TrackingValue> produced = frames.get(successor);
 
                 // check the stack cells
-                for (int i = 0; i < before.getStackSize(); ++i) {
-                    final TrackingValue value = before.getStack(i);
-                    if (((i >= produced.getStackSize()) || (value != produced.getStack(i))) &&
-                        value.getType().equals(Type.DOUBLE_TYPE) &&
-                        !converted.contains(value)) {
-                        values.add(value);
-                    }
-                }
                 for (int i = 0; i < produced.getStackSize(); ++i) {
                     final TrackingValue value = produced.getStack(i);
                     if (((i >= beforeStackSize) || (value != before.getStack(i))) &&
@@ -501,6 +502,23 @@ public class MethodDifferentiator {
 
     }
 
+    /** Get the type of a stack element before instruction is executed.
+     * @param insn current instruction
+     * @param index index of the stack element, 0 corresponding to top stack cell
+     * @return type of a stack element before instruction is executed
+     */
+    public Type stackElementType(final AbstractInsnNode insn, final int index) {
+
+        // get the frame at the start of the instruction
+        final Frame<TrackingValue> frame = frames.get(insn);
+
+        // get stack size
+        final int size = frame.getStackSize();
+
+        return frame.getStack(size - (index + 1)).getType();
+
+    }
+
     /** Get the replacement list for an instruction.
      * @param insn instruction to replace
      * @param dsIndex index of a reference {@link DerivativeStructure derivative structure} variable

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeNonMathTransformer.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeNonMathTransformer.java?rev=1395621&r1=1395620&r2=1395621&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeNonMathTransformer.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeNonMathTransformer.java Mon Oct  8 15:38:44 2012
@@ -16,6 +16,7 @@
  */
 package org.apache.commons.nabla.forward.instructions;
 
+import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
 import org.apache.commons.nabla.DifferentiationException;
 import org.apache.commons.nabla.forward.analysis.InstructionsTransformer;
 import org.apache.commons.nabla.forward.analysis.MethodDifferentiator;
@@ -39,12 +40,25 @@ public class InvokeNonMathTransformer im
 
         final MethodInsnNode methodInsn = (MethodInsnNode) insn;
 
+        // build transformed method signature based on stack elements
+        final Type dsType              = Type.getType(DerivativeStructure.class);
+        final Type primitiveMethodType = Type.getMethodType(methodInsn.desc);
+        final Type[] argumentTypes     = new Type[primitiveMethodType.getArgumentTypes().length];
+        for (int i = 0; i < argumentTypes.length; ++i) {
+            final int index = argumentTypes.length - 1 - i;
+            argumentTypes[i] = methodDifferentiator.stackElementIsConverted(insn, index) ?
+                               dsType : methodDifferentiator.stackElementType(insn, index);
+        }
+        final Type returnType = (primitiveMethodType.getReturnType() == Type.DOUBLE_TYPE) ?
+                                dsType : primitiveMethodType.getReturnType();
+        final Type differentiatedMethodType = Type.getMethodType(returnType, argumentTypes);
+
         // request the global differentiator to differentiate the invoked method
-        Type differentiatedMethodType =
-                methodDifferentiator.requestMethodDifferentiation(Type.getType("L" + methodInsn.owner + ";").getClassName(),
-                                                                  methodInsn.getOpcode() == Opcodes.INVOKESTATIC,
-                                                                  methodInsn.name, Type.getMethodType(methodInsn.desc));
+        methodDifferentiator.requestMethodDifferentiation(Type.getType("L" + methodInsn.owner + ";").getClassName(),
+                                                          methodInsn.getOpcode() == Opcodes.INVOKESTATIC,
+                                                          methodInsn.name, primitiveMethodType, differentiatedMethodType);
 
+        // create the transformed instruction
         final InsnList list = new InsnList();
         list.add(new MethodInsnNode(methodInsn.getOpcode(),
                                     methodDifferentiator.getDerivedName(), methodInsn.name,

Modified: commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java?rev=1395621&r1=1395620&r2=1395621&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java (original)
+++ commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java Mon Oct  8 15:38:44 2012
@@ -150,6 +150,22 @@ public class ForwardModeDifferentiatorTe
     }
 
     @Test
+    public void testDifferentiateWRTDifferentArguments() {
+        checkReference(new ReferenceFunction() {
+            private double f(double a, double b) {
+                return a / b;
+            }
+            public double value(double t) {
+                // the differentiator will generate two different methods,
+                // one differentiated with respect to the first argument only,
+                // one differentiated with respect to the second argument only
+                return f(t, 2.0) + f(1.0, t);
+            }
+            public double firstDerivative(double t) { return 0.5 - 1.0 / (t * t); }
+        }, -5.25, 5, 20, 8.0e-15);
+    }
+
+    @Test
     public void testPartialDerivatives() throws Exception {
         PartialFunction function = new PartialFunction(1);