You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by lu...@apache.org on 2009/08/12 16:14:25 UTC

svn commit: r803525 - in /commons/sandbox/nabla/trunk/src: main/java/org/apache/commons/nabla/algorithmic/forward/analysis/MethodDifferentiator.java test/java/org/apache/commons/nabla/algorithmic/AbstractMathTest.java

Author: luc
Date: Wed Aug 12 14:14:23 2009
New Revision: 803525

URL: http://svn.apache.org/viewvc?rev=803525&view=rev
Log:
Fixed an error when differentiating a constant function:
the derivative was correctly set to 0, but the value was also set to 0 instead of to the constant.

Modified:
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/MethodDifferentiator.java
    commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/algorithmic/AbstractMathTest.java

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/MethodDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/MethodDifferentiator.java?rev=803525&r1=803524&r2=803525&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/MethodDifferentiator.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/MethodDifferentiator.java Wed Aug 12 14:14:23 2009
@@ -91,7 +91,6 @@
 import org.objectweb.asm.MethodVisitor;
 import org.objectweb.asm.Opcodes;
 import org.objectweb.asm.tree.AbstractInsnNode;
-import org.objectweb.asm.tree.FieldInsnNode;
 import org.objectweb.asm.tree.IincInsnNode;
 import org.objectweb.asm.tree.InsnList;
 import org.objectweb.asm.tree.InsnNode;
@@ -119,6 +118,9 @@
     /** Descriptor for the derivative class f method. */
     public static final String DP_RETURN_DP_DESCRIPTOR = "(" + DP_DESCRIPTOR + ")" + DP_DESCRIPTOR;
 
+    /** Descriptor for <code>DifferentialPair f(double)</code> methods. */
+    public static final String D_RETURN_DP_DESCRIPTOR = "(D)" + DP_DESCRIPTOR;
+
     /** Descriptor for <code>double f()</code> methods. */
     private static final String VOID_RETURN_D_DESCRIPTOR = "()D";
 
@@ -225,7 +227,7 @@
     public void visitEnd() {
         try {
 
-            // at start, "this" and one differential pair are used
+            // at start, "this" and one differential pair are already used
             maxLocals  = 2 * (maxLocals + MAX_TEMP) - 1;
             usedLocals = new boolean[maxLocals];
             useLocal(0, 1);
@@ -249,10 +251,19 @@
             if (changes.isEmpty()) {
 
                 // the method does not depend on the parameter at all!
-                // we replace all code by a simple "return DifferentialPair.ZERO;"
-                instructions.clear();
-                instructions.add(new FieldInsnNode(Opcodes.GETSTATIC, DP_NAME, "ZERO", DP_DESCRIPTOR));
-                instructions.add(new InsnNode(Opcodes.ARETURN));
+                // we replace all "return d;" by "return DifferentialPair.newConstant(d);"
+                for (final Iterator<?> i = instructions.iterator(); i.hasNext();) {
+                    final AbstractInsnNode insn = (AbstractInsnNode) i.next();
+                    if (insn.getOpcode() == Opcodes.DRETURN) {
+                        final InsnList list = new InsnList();
+                        list.add(new MethodInsnNode(Opcodes.INVOKESTATIC,
+                                                    MethodDifferentiator.DP_NAME,
+                                                    "newConstant", D_RETURN_DP_DESCRIPTOR));
+                        list.add(new InsnNode(Opcodes.ARETURN));
+                        instructions.insert(insn, list);
+                        instructions.remove(insn);
+                    }
+                }
 
             } else {
 
@@ -624,9 +635,7 @@
             throw new RuntimeException("MULTIANEWARRAY not handled yet");
         default:
             throw new DifferentiationException("unable to handle instruction with opcode {0}",
-                                          new Object[] {
-                                              Integer.valueOf(insn.getOpcode())
-                                          });
+                                               insn.getOpcode());
         }
 
     }
@@ -650,7 +659,7 @@
                 final MathInvocationTransformer transformer = MATH_TRANSFORMERS.get(methodInsn.name);
                 if (transformer == null) {
                     throw new DifferentiationException(UNKNOWN_METHOD_FMT,
-                                                  methodInsn.owner, methodInsn.name);
+                                                       methodInsn.owner, methodInsn.name);
                 }
                 return transformer.getReplacementList(methodInsn.owner, this);
             } else if ("(DD)D".equals(methodInsn.desc)) {
@@ -672,7 +681,7 @@
                     final MathInvocationTransformer transformer = MATH_TRANSFORMERS.get(name);
                     if (transformer == null) {
                         throw new DifferentiationException(UNKNOWN_METHOD_FMT,
-                                                      methodInsn.owner, methodInsn.name);
+                                                           methodInsn.owner, methodInsn.name);
                     }
                     return transformer.getReplacementList(methodInsn.owner, this);
                 }
@@ -699,11 +708,10 @@
      */
     public void useLocal(final int index, final int size)
         throws DifferentiationException {
-        if ((index < 0) || ((index + size - 1) >= usedLocals.length)) {
+        if ((index < 0) || ((index + size) > usedLocals.length)) {
             throw new DifferentiationException("index of size {0} local variable ({1}) " +
-                                          "outside of [{2}, {3}] range",
-                                          Integer.valueOf(size), Integer.valueOf(index),
-                                          Integer.valueOf(1), Integer.valueOf(MAX_TEMP));
+                                               "outside of [{2}, {3}] range",
+                                               size, index, 1, MAX_TEMP);
         }
         for (int i = index; i < index + size; ++i) {
             usedLocals[i] = true;
@@ -733,9 +741,7 @@
     public int getTmp(final int number) throws DifferentiationException {
         if ((number < 0) || (number > MAX_TEMP)) {
             throw new DifferentiationException("number of temporary variable ({0}) outside of [{1}, {2}] range",
-                                               Integer.valueOf(number),
-                                               Integer.valueOf(1),
-                                               Integer.valueOf(MAX_TEMP));
+                                               number, 1, MAX_TEMP);
         }
         final int index = usedLocals.length - 2 * number;
         useLocal(index, 2);

Modified: commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/algorithmic/AbstractMathTest.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/algorithmic/AbstractMathTest.java?rev=803525&r1=803524&r2=803525&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/algorithmic/AbstractMathTest.java (original)
+++ commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/algorithmic/AbstractMathTest.java Wed Aug 12 14:14:23 2009
@@ -37,6 +37,7 @@
                 double t = ((n - 1 - i) * t0 + i * t1) / (n - 1);
                 DifferentialPair dpT = DifferentialPair.newVariable(t);
                 Assert.assertEquals(reference.fPrime(t), derivative.f(dpT).getFirstDerivative(), threshold);
+                Assert.assertEquals(reference.f(t), derivative.f(dpT).getValue(), threshold);
             }
         } catch (DifferentiationException de) {
             Assert.fail(de.getMessage());