You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by lu...@apache.org on 2009/08/17 21:50:52 UTC

svn commit: r805114 - in /commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic: ./ forward/ forward/analysis/ forward/instructions/

Author: luc
Date: Mon Aug 17 19:50:51 2009
New Revision: 805114

URL: http://svn.apache.org/viewvc?rev=805114&view=rev
Log:
Use the ASM tree API for classes too (in addition to methods).
This will allow modifying fields and methods called by the differentiated method.

Added:
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/Descriptors.java   (with props)
Modified:
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/ForwardModeAlgorithmicDifferentiator.java
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/ForwardModeClassDifferentiator.java
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/MethodDifferentiator.java
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/instructions/DReturnTransformer.java

Added: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/Descriptors.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/Descriptors.java?rev=805114&view=auto
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/Descriptors.java (added)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/Descriptors.java Mon Aug 17 19:50:51 2009
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.nabla.algorithmic;
+
+import org.apache.commons.nabla.core.DifferentialPair;
+
+/**
+ * JVM internal names and descriptors shared by the algorithmic differentiators.
+ */
+public interface Descriptors {
+
+    /** Internal (slash-separated) name of the DifferentialPair class. */
+    String DP_NAME = DifferentialPair.class.getName().replace('.', '/');
+
+    /** Field type descriptor of the DifferentialPair class. */
+    String DP_DESCRIPTOR = "L" + DP_NAME + ";";
+
+    /** Descriptor for <code>double f(double)</code> methods (primitive class f). */
+    String D_RETURN_D_DESCRIPTOR = "(D)D";
+
+    /** Descriptor for <code>DifferentialPair f(DifferentialPair)</code> methods (derivative class f). */
+    String DP_RETURN_DP_DESCRIPTOR = "(" + DP_DESCRIPTOR + ")" + DP_DESCRIPTOR;
+
+    /** Descriptor for <code>DifferentialPair f(double)</code> methods. */
+    String D_RETURN_DP_DESCRIPTOR = "(D)" + DP_DESCRIPTOR;
+
+    /** Descriptor for <code>double f()</code> methods. */
+    String VOID_RETURN_D_DESCRIPTOR = "()D";
+
+}

Propchange: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/Descriptors.java
------------------------------------------------------------------------------
    svn:eol-style = native

Propchange: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/Descriptors.java
------------------------------------------------------------------------------
    svn:keywords = Author Date Id Revision

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/ForwardModeAlgorithmicDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/ForwardModeAlgorithmicDifferentiator.java?rev=805114&r1=805113&r2=805114&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/ForwardModeAlgorithmicDifferentiator.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/ForwardModeAlgorithmicDifferentiator.java Mon Aug 17 19:50:51 2009
@@ -25,6 +25,7 @@
 import java.util.HashSet;
 import java.util.Set;
 
+import org.apache.commons.nabla.algorithmic.Descriptors;
 import org.apache.commons.nabla.algorithmic.forward.analysis.ForwardModeClassDifferentiator;
 import org.apache.commons.nabla.core.DifferentiationException;
 import org.apache.commons.nabla.core.UnivariateDerivative;
@@ -32,6 +33,7 @@
 import org.apache.commons.nabla.core.UnivariateDifferentiator;
 import org.objectweb.asm.ClassReader;
 import org.objectweb.asm.ClassWriter;
+import org.objectweb.asm.tree.ClassNode;
 
 /** Algorithmic differentiator class in forward mode based on bytecode analysis.
  * <p>This class is an implementation of the {@link UnivariateDifferentiator}
@@ -163,19 +165,21 @@
         throws DifferentiationException {
         try {
 
-            // set up both ends of the class transform chain
+            // get the original class
             final String classResourceName = "/" + differentiableClass.getName().replace('.', '/') + ".class";
             final InputStream stream = differentiableClass.getResourceAsStream(classResourceName);
             final ClassReader reader = new ClassReader(stream);
-            final ClassWriter writer = new ClassWriter(reader, ClassWriter.COMPUTE_FRAMES);
 
             // differentiate the function embedded in the differentiable class
-            final ForwardModeClassDifferentiator differentiator = new ForwardModeClassDifferentiator(mathClasses, writer);
-            reader.accept(differentiator, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);
-            differentiator.reportErrors();
+            final ForwardModeClassDifferentiator differentiator =
+                new ForwardModeClassDifferentiator(reader, mathClasses);
+            differentiator.differentiateMethod("f", Descriptors.D_RETURN_D_DESCRIPTOR,
+                                               Descriptors.DP_RETURN_DP_DESCRIPTOR);
 
             // create the derivative class
-            return new DerivativeLoader(differentiableClass).defineClass(differentiator, writer);
+            final ClassNode   derived = differentiator.getDerivedClass();
+            final ClassWriter writer  = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
+            return new DerivativeLoader(differentiableClass).defineClass(derived, writer);
 
         } catch (IOException ioe) {
             throw new DifferentiationException("class {0} cannot be read ({1})",
@@ -194,14 +198,15 @@
         }
 
         /** Define a derivative class.
-         * @param differentiator class differentiator
+         * @param classNode node of the derivative class to define (serialized into {@code writer})
          * @param writer class writer
          * @return a generated derivative class
          */
         @SuppressWarnings("unchecked")
         public Class<? extends UnivariateDerivative>
-        defineClass(final ForwardModeClassDifferentiator differentiator, final ClassWriter writer) {
-            final String name = differentiator.getDerivativeClassName().replace('/', '.');
+        defineClass(final ClassNode classNode, final ClassWriter writer) {
+            final String name = classNode.name.replace('/', '.');
+            classNode.accept(writer);
             final byte[] bytecode = writer.toByteArray();
             return (Class<? extends UnivariateDerivative>) defineClass(name, bytecode, 0, bytecode.length);
         }

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/ForwardModeClassDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/ForwardModeClassDifferentiator.java?rev=805114&r1=805113&r2=805114&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/ForwardModeClassDifferentiator.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/ForwardModeClassDifferentiator.java Mon Aug 17 19:50:51 2009
@@ -16,37 +16,38 @@
  */
 package org.apache.commons.nabla.algorithmic.forward.analysis;
 
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Set;
 
 import org.apache.commons.nabla.core.DifferentiationException;
 import org.apache.commons.nabla.core.UnivariateDerivative;
 import org.apache.commons.nabla.core.UnivariateDifferentiable;
-import org.objectweb.asm.AnnotationVisitor;
-import org.objectweb.asm.Attribute;
-import org.objectweb.asm.ClassVisitor;
-import org.objectweb.asm.FieldVisitor;
-import org.objectweb.asm.MethodVisitor;
+import org.objectweb.asm.ClassReader;
 import org.objectweb.asm.Opcodes;
+import org.objectweb.asm.tree.ClassNode;
+import org.objectweb.asm.tree.FieldNode;
+import org.objectweb.asm.tree.MethodNode;
 
 /**
- * Visitor (in asm sense) for differentiating classes using forward mode.
+ * Differentiator for classes using forward mode.
  * <p>
- * This visitor visits classes implementing the
+ * This differentiator transforms classes implementing the
  * {@link UnivariateDifferentiable UnivariateDifferentiable} interface and convert
  * them to classes implementing the {@link UnivariateDerivative
  * UnivariateDerivative} interface.
  * </p>
  * <p>
- * The visitor creates a new class as an inner class of the visited class.
+ * The differentiator creates a new class as an inner class of the visited class.
  * Instances of the generated class are therefore automatically bound to their
  * primitive instance which is their directly enclosing instance. As such they
  * have access to the current value of all fields.
  * </p>
  * <p>
- * The visited class bytecode is not changed at all.
+ * The original class bytecode is not changed at all.
  * </p>
  */
-public class ForwardModeClassDifferentiator implements ClassVisitor {
+public class ForwardModeClassDifferentiator {
 
     /** Name for the primitive instance field. */
     private static final String PRIMITIVE_FIELD = "primitive";
@@ -54,57 +55,42 @@
     /** Math implementation classes. */
     private final Set<String> mathClasses;
 
-    /** Class generating visitor. */
-    private final ClassVisitor generator;
-
-    /** Error reporter. */
-    private final ErrorReporter errorReporter;
+    /** Class to differentiate. */
+    private final ClassNode classNode;
 
     /** Primitive class name. */
-    private String primitiveName;
-
-    /** Descriptor for the primitive class. */
-    private String primitiveDesc;
+    private final String primitiveName;
 
-    /** Derivative class name. */
-    private String derivativeName;
+    /** Primitive class methods. */
+    private final List<MethodNode> primitiveMethods;
 
-    /** Indicator for specific fields and method addition. */
-    private boolean specificMembersAdded;
+    /** Descriptor for the primitive class. */
+    private final String primitiveDesc;
 
     /**
      * Simple constructor.
+     * @param reader reader for the primitive class
      * @param mathClasses math implementation classes
-     * @param generator visitor to which class generation calls will be delegated
+     * @exception DifferentiationException if class cannot be differentiated
      */
-    public ForwardModeClassDifferentiator(final Set<String> mathClasses,
-                                          final ClassVisitor generator) {
-        this.mathClasses = mathClasses;
-        this.generator   = generator;
-        errorReporter    = new ErrorReporter();
-    }
+    @SuppressWarnings("unchecked")
+    public ForwardModeClassDifferentiator(final ClassReader reader,
+                                          final Set<String> mathClasses)
+        throws DifferentiationException {
 
-    /**
-     * Get the name of the derivative class.
-     * @return name of the (generated) derivative class
-     */
-    public String getDerivativeClassName() {
-        return derivativeName;
-    }
+        classNode = new ClassNode();
+        reader.accept(classNode, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);
+        this.mathClasses = mathClasses;
 
-    /** {@inheritDoc} */
-    public void visit(final int version, final int access,
-                      final String name, final String signature,
-                      final String superName, final String[] interfaces) {
-        // set up the various names
-        primitiveName  = name;
-        derivativeName = primitiveName + "$NablaForwardModeUnivariateDerivative";
-        primitiveDesc  = "L" + primitiveName + ";";
+        // store the primitive class properties
+        primitiveName    = classNode.name;
+        primitiveDesc    = "L" + primitiveName + ";";
+        primitiveMethods = classNode.methods;
 
         // check the UnivariateDifferentiable interface is implemented
         final Class<UnivariateDifferentiable> uDerClass = UnivariateDifferentiable.class;
         boolean isDifferentiable = false;
-        for (String interf : interfaces) {
+        for (String interf : (List<String>) classNode.interfaces) {
             final String interfName = interf.replace('/', '.');
             Class<?> interfClass = null;
             try {
@@ -112,162 +98,106 @@
             } catch (ClassNotFoundException cnfe) {
                 // this should never occur since class has already been loaded
                 // and an instance already exists ...
-                errorReporter.register(new DifferentiationException("interface {0} not found " +
-                                                                    "while differentiating class {1}",
-                                                                    interfName, name));
+                throw new DifferentiationException("interface {0} not found " +
+                                                   "while differentiating class {1}",
+                                                   interfName, primitiveName);
             }
             if (interfClass != null) {
                 isDifferentiable = isDifferentiable || uDerClass.isAssignableFrom(interfClass);
             }
         }
 
-        if (isDifferentiable) {
-            // generate the new class implementing the UnivariateDerivative interface
-            generator.visit(version, Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC,
-                            derivativeName, signature, superName,
-                            new String[] {
-                                UnivariateDerivative.class.getName().replace('.', '/')
-                            });
-        } else {
-            errorReporter.register(new DifferentiationException("the {0} class does not implement " +
-                                                                "the {1} interface",
-                                                                name, uDerClass.getName()));
-        }
-
-        specificMembersAdded = false;
+        if (!isDifferentiable) {
+            throw new DifferentiationException("the {0} class does not implement the {1} interface",
+                                               primitiveName, uDerClass.getName());
+        }
+
+        // change the class properties for the derived class
+        classNode.access     = Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC;
+        classNode.name       = primitiveName + "$NablaForwardModeUnivariateDerivative";
+        classNode.fields     = new ArrayList<FieldNode>();
+        classNode.methods    = new ArrayList<MethodNode>();
+        classNode.interfaces.clear();
+        classNode.interfaces.add(UnivariateDerivative.class.getName().replace('.', '/'));
+
+        // primitive instance field and methods setting/getting it
+        addPrimitiveField();
+        addConstructor();
+        addGetPrimitive();
 
     }
 
-    /** {@inheritDoc} */
-    public MethodVisitor visitMethod(final int access, final String name,
-                                     final String desc, final String signature,
-                                     final String[] exceptions) {
-
-        // don't do anything if an error has already been encountered
-        if (errorReporter.hasError()) {
-            return null;
-        }
-
-        if (!specificMembersAdded) {
-            // add the specific members we need
-            addPrimitiveField();
-            addConstructor();
-            addGetPrimitive();
-            specificMembersAdded = true;
-        }
-
-        // is it the "public double f(double)" method we want to differentiate ?
-        if (((access & Opcodes.ACC_PUBLIC) == Opcodes.ACC_PUBLIC) &&
-                "f".equals(name) && "(D)D".equals(desc) &&
-                ((exceptions == null) || (exceptions.length == 0))) {
-
-            // get a generator for the method we are going to create
-            final MethodVisitor visitor =
-                generator.visitMethod(access | Opcodes.ACC_SYNTHETIC, name,
-                                      MethodDifferentiator.DP_RETURN_DP_DESCRIPTOR, null, null);
-
-            // make sure our own differentiator will be used to transform the code
-            return new MethodDifferentiator(access, name, desc, signature, exceptions,
-                                            visitor, primitiveName, mathClasses, errorReporter);
+    /**
+     * Differentiate a method of the primitive class and add it to the derivative class.
+     * @param name name of the method
+     * @param primitiveMethodDesc descriptor of the method in the primitive class
+     * @param derivativeDesc descriptor of the method in the derivative class (currently unused)
+     * @exception DifferentiationException if method cannot be differentiated
+     */
+    @SuppressWarnings("unchecked")
+    public void differentiateMethod(final String name, final String primitiveMethodDesc,
+                                    final String derivativeDesc)
+        throws DifferentiationException {
+
+        for (final MethodNode method : primitiveMethods) {
+            if (method.name.equals(name) && method.desc.equals(primitiveMethodDesc)) {
+
+                final MethodDifferentiator differentiator = new MethodDifferentiator(mathClasses);
+                differentiator.differentiate(primitiveName, method);
+                classNode.methods.add(method);
 
+            }
         }
-
-        // we are not interested in this method
-        return null;
-
-    }
-
-    /** {@inheritDoc} */
-    public FieldVisitor visitField(final int access, final String name,
-                                   final String desc, final  String signature,
-                                   final Object value) {
-        // we are not interested in any fields
-        return null;
-    }
-
-    /** {@inheritDoc} */
-    public void visitSource(final String source, final String debug) {
-    }
-
-    /** {@inheritDoc} */
-    public void visitOuterClass(final String owner, final String name,
-                                final String desc) {
-    }
-
-    /** {@inheritDoc} */
-    public AnnotationVisitor visitAnnotation(final String desc,
-                                             final boolean visible) {
-        return null;
-    }
-
-    /** {@inheritDoc} */
-    public void visitAttribute(final Attribute attr) {
-    }
-
-    /** {@inheritDoc} */
-    public void visitInnerClass(final String name, final String outerName,
-                                final String innerName, final int access) {
     }
 
-    /** {@inheritDoc} */
-    public void visitEnd() {
-
-        // don't do anything if an error has already been encountered
-        if (errorReporter.hasError()) {
-            return;
-        }
-
-        generator.visitEnd();
-
+    /**
+     * Get the derivative class built from the primitive class.
+     * @return derivative class node (the instance held and updated by this differentiator)
+     */
+    public ClassNode getDerivedClass() {
+        return classNode;
     }
 
     /** Add the primitive field.
      */
+    @SuppressWarnings("unchecked")
     private void addPrimitiveField() {
-        final FieldVisitor visitor =
-            generator.visitField(Opcodes.ACC_PRIVATE | Opcodes.ACC_FINAL | Opcodes.ACC_SYNTHETIC,
-                                 PRIMITIVE_FIELD, primitiveDesc, null, null);
-        visitor.visitEnd();
+        final FieldNode primitiveField =
+            new FieldNode(Opcodes.ACC_PRIVATE | Opcodes.ACC_FINAL | Opcodes.ACC_SYNTHETIC,
+                          PRIMITIVE_FIELD, primitiveDesc, null, null);
+        classNode.fields.add(primitiveField);
     }
 
     /** Add the class constructor.
      */
+    @SuppressWarnings("unchecked")
     private void addConstructor() {
         final String init = "<init>";
-        final MethodVisitor visitor =
-            generator.visitMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, init,
-                                  "(" + primitiveDesc + ")V", null, null);
-        visitor.visitCode();
-        visitor.visitVarInsn(Opcodes.ALOAD, 0);
-        visitor.visitMethodInsn(Opcodes.INVOKESPECIAL, "java/lang/Object", init, "()V");
-        visitor.visitVarInsn(Opcodes.ALOAD, 0);
-        visitor.visitVarInsn(Opcodes.ALOAD, 1);
-        visitor.visitFieldInsn(Opcodes.PUTFIELD, derivativeName, PRIMITIVE_FIELD, primitiveDesc);
-        visitor.visitInsn(Opcodes.RETURN);
-        visitor.visitMaxs(0, 0);
-        visitor.visitEnd();
+        final MethodNode constructor =
+            new MethodNode(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, init,
+                           "(" + primitiveDesc + ")V", null, null);
+        constructor.visitVarInsn(Opcodes.ALOAD, 0);
+        constructor.visitMethodInsn(Opcodes.INVOKESPECIAL, "java/lang/Object", init, "()V");
+        constructor.visitVarInsn(Opcodes.ALOAD, 0);
+        constructor.visitVarInsn(Opcodes.ALOAD, 1);
+        constructor.visitFieldInsn(Opcodes.PUTFIELD, classNode.name, PRIMITIVE_FIELD, primitiveDesc);
+        constructor.visitInsn(Opcodes.RETURN);
+        constructor.visitMaxs(0, 0);
+        classNode.methods.add(constructor);
     }
 
     /** Add the {@link UnivariateDerivative#getPrimitive() getPrimitive()} method.
      */
+    @SuppressWarnings("unchecked")
     private void addGetPrimitive() {
-        final MethodVisitor visitor =
-            generator.visitMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, "getPrimitive",
-                                  "()" + primitiveDesc, null, null);
-        visitor.visitCode();
-        visitor.visitVarInsn(Opcodes.ALOAD, 0);
-        visitor.visitFieldInsn(Opcodes.GETFIELD, derivativeName, PRIMITIVE_FIELD, primitiveDesc);
-        visitor.visitInsn(Opcodes.ARETURN);
-        visitor.visitMaxs(0, 0);
-        visitor.visitEnd();
-    }
-
-    /** Report the errors that may have occurred during analysis.
-     * @exception DifferentiationException if the derivative class
-     * could not be generated
-     */
-    public void reportErrors() throws DifferentiationException {
-        errorReporter.reportErrors();
-    }
+        final MethodNode method =
+            new MethodNode(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, "getPrimitive",
+                           "()" + primitiveDesc, null, null);
+        method.visitVarInsn(Opcodes.ALOAD, 0);
+        method.visitFieldInsn(Opcodes.GETFIELD, classNode.name, PRIMITIVE_FIELD, primitiveDesc);
+        method.visitInsn(Opcodes.ARETURN);
+        method.visitMaxs(0, 0);
+        classNode.methods.add(method);
+    }
 
 }

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/MethodDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/MethodDifferentiator.java?rev=805114&r1=805113&r2=805114&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/MethodDifferentiator.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/MethodDifferentiator.java Mon Aug 17 19:50:51 2009
@@ -25,6 +25,7 @@
 import java.util.Map;
 import java.util.Set;
 
+import org.apache.commons.nabla.algorithmic.Descriptors;
 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DAddTransformer1;
 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DAddTransformer12;
 import org.apache.commons.nabla.algorithmic.forward.arithmetic.DAddTransformer2;
@@ -86,9 +87,7 @@
 import org.apache.commons.nabla.algorithmic.forward.trimming.DLoadPop2Trimmer;
 import org.apache.commons.nabla.algorithmic.forward.trimming.SwappedDloadTrimmer;
 import org.apache.commons.nabla.algorithmic.forward.trimming.SwappedDstoreTrimmer;
-import org.apache.commons.nabla.core.DifferentialPair;
 import org.apache.commons.nabla.core.DifferentiationException;
-import org.objectweb.asm.MethodVisitor;
 import org.objectweb.asm.Opcodes;
 import org.objectweb.asm.tree.AbstractInsnNode;
 import org.objectweb.asm.tree.IincInsnNode;
@@ -107,22 +106,7 @@
 /** Class transforming a method computing a value to a method
  * computing both a value and its differential.
  */
-public class MethodDifferentiator extends MethodNode {
-
-    /** Name for the DifferentialPair class. */
-    public static final String DP_NAME = DifferentialPair.class.getName().replace('.', '/');
-
-    /** Descriptor for the DifferentialPair class. */
-    public static final String DP_DESCRIPTOR = "L" + DP_NAME + ";";
-
-    /** Descriptor for the derivative class f method. */
-    public static final String DP_RETURN_DP_DESCRIPTOR = "(" + DP_DESCRIPTOR + ")" + DP_DESCRIPTOR;
-
-    /** Descriptor for <code>DifferentialPair f(double)</code> methods. */
-    public static final String D_RETURN_DP_DESCRIPTOR = "(D)" + DP_DESCRIPTOR;
-
-    /** Descriptor for <code>double f()</code> methods. */
-    private static final String VOID_RETURN_D_DESCRIPTOR = "()D";
+public class MethodDifferentiator {
 
     /** Math functions transformer. */
     private static final Map<String, MathInvocationTransformer> MATH_TRANSFORMERS =
@@ -168,18 +152,9 @@
     /** Math implementation classes. */
     private final Set<String> mathClasses;
 
-    /** Generator to use. */
-    private final MethodVisitor generator;
-
     /** Used locals variables array. */
     private boolean[] usedLocals;
 
-    /** Primitive class name. */
-    private final String primitiveName;
-
-    /** Error reporter to use. */
-    private final ErrorReporter errorReporter;
-
     /** Set of converted values. */
     private final Set<TrackingValue> converted;
 
@@ -193,110 +168,95 @@
     private final Map<LabelNode, LabelNode> clonedLabels;
 
     /** Build a differentiator for a method.
-     * @param access access flags of the method
-     * @param name name of the method
-     * @param desc descriptor of the method
-     * @param signature signature of the method
-     * @param exceptions exceptions thrown by the method
-     * @param generator bytecode generator to use for the transformed method
-     * @param primitiveName primitive class name
      * @param mathClasses math implementation classes
-     * @param errorReporter reporter used for delaying exceptions
      */
-    public MethodDifferentiator(final int access, final String name, final String desc,
-                                final String signature, final String[] exceptions,
-                                final MethodVisitor generator,final  String primitiveName,
-                                final Set<String> mathClasses,
-                                final ErrorReporter errorReporter) {
-
-        super(access, name, desc, signature, exceptions);
-        this.generator     = generator;
+    public MethodDifferentiator(final Set<String> mathClasses) {
         this.usedLocals    = null;
-        this.primitiveName = primitiveName;
         this.mathClasses   = mathClasses;
-        this.errorReporter = errorReporter;
         this.converted     = new HashSet<TrackingValue>();
         this.frames        = new IdentityHashMap<AbstractInsnNode, Frame>();
         this.successors    = new IdentityHashMap<AbstractInsnNode, Set<AbstractInsnNode>>();
         this.clonedLabels  = new HashMap<LabelNode, LabelNode>();
-
     }
 
-    /** {@inheritDoc} */
-    @Override
-    public void visitEnd() {
+    /**
+     * Differentiate a method.
+     * @param primitiveName primitive class name
+     * @param method method to differentiate (<em>will</em> be modified)
+     * @exception DifferentiationException if method cannot be differentiated
+     */
+    public void differentiate(final String primitiveName, final MethodNode method)
+        throws DifferentiationException {
         try {
 
             // at start, "this" and one differential pair are already used
-            maxLocals  = 2 * (maxLocals + MAX_TEMP) - 1;
-            usedLocals = new boolean[maxLocals];
+            method.maxLocals  = 2 * (method.maxLocals + MAX_TEMP) - 1;
+            usedLocals = new boolean[method.maxLocals];
             useLocal(0, 1);
             useLocal(1, 4);
 
             // add spare cells to hold new variables if needed
-            addSpareLocalVariables();
+            addSpareLocalVariables(method.instructions);
 
             // analyze the original code, tracing values production/consumption
-            final Frame[] array =
-                new FlowAnalyzer(new TrackingInterpreter()).analyze(primitiveName, this);
+            final FlowAnalyzer analyzer =
+                new FlowAnalyzer(new TrackingInterpreter(), method.instructions);
+            final Frame[] array = analyzer.analyze(primitiveName, method);
 
             // convert the array into a map, since code changes will shift all indices
             for (int i = 0; i < array.length; ++i) {
-                frames.put(instructions.get(i), array[i]);
+                frames.put(method.instructions.get(i), array[i]);
             }
 
             // identify the needed changes
-            final Set<AbstractInsnNode> changes = identifyChanges();
+            final Set<AbstractInsnNode> changes = identifyChanges(method.instructions);
 
             if (changes.isEmpty()) {
 
                 // the method does not depend on the parameter at all!
                 // we replace all "return d;" by "return DifferentialPair.newConstant(d);"
-                for (final Iterator<?> i = instructions.iterator(); i.hasNext();) {
+                for (final Iterator<?> i = method.instructions.iterator(); i.hasNext();) {
                     final AbstractInsnNode insn = (AbstractInsnNode) i.next();
                     if (insn.getOpcode() == Opcodes.DRETURN) {
                         final InsnList list = new InsnList();
                         list.add(new MethodInsnNode(Opcodes.INVOKESTATIC,
-                                                    MethodDifferentiator.DP_NAME,
-                                                    "newConstant", D_RETURN_DP_DESCRIPTOR));
+                                                    Descriptors.DP_NAME,
+                                                    "newConstant",
+                                                    Descriptors.D_RETURN_DP_DESCRIPTOR));
                         list.add(new InsnNode(Opcodes.ARETURN));
-                        instructions.insert(insn, list);
-                        instructions.remove(insn);
+                        method.instructions.insert(insn, list);
+                        method.instructions.remove(insn);
                     }
                 }
 
             } else {
 
                 // perform the code changes
-                changeCode(changes);
-
-                // remove the local variables added at the beginning and not used
-                removeUnusedSpareLocalVariables();
+                changeCode(method.instructions, changes);
 
                 // trim generated instructions list
-                SwappedDloadTrimmer.getInstance().trim(instructions);
-                SwappedDstoreTrimmer.getInstance().trim(instructions);
-                DLoadPop2Trimmer.getInstance().trim(instructions);
+                SwappedDloadTrimmer.getInstance().trim(method.instructions);
+                SwappedDstoreTrimmer.getInstance().trim(method.instructions);
+                DLoadPop2Trimmer.getInstance().trim(method.instructions);
 
             }
 
-            // change the descriptor to its true final value
-            desc = DP_RETURN_DP_DESCRIPTOR;
+            // remove the local variables added at the beginning and not used
+            removeUnusedSpareLocalVariables(method.instructions);
 
-            // generate the method
-            accept(generator);
+            // change the method properties to the derivative ones
+            method.desc       = Descriptors.DP_RETURN_DP_DESCRIPTOR;
+            method.access    |= Opcodes.ACC_SYNTHETIC;
+            method.maxLocals  = maxVariables();
 
         } catch (AnalyzerException ae) {
+            // translate the analyzer failure into a DifferentiationException for the caller
             if ((ae.getCause() != null) && ae.getCause() instanceof DifferentiationException) {
-                errorReporter.register((DifferentiationException) ae.getCause());
+                throw (DifferentiationException) ae.getCause();
             } else {
-                final DifferentiationException de =
-                    new DifferentiationException("unable to analyze the {0}.{1} method ({2})",
-                                                 primitiveName, name, ae.getMessage());
-                errorReporter.register(de);
+                throw new DifferentiationException("unable to analyze the {0}.{1} method ({2})",
+                                                   primitiveName, method.name, ae.getMessage());
             }
-        } catch (DifferentiationException de) {
-            errorReporter.register(de);
         }
     }
 
@@ -309,11 +269,13 @@
      * be referenced by the converted instructions in the following passes.</p>
      * <p>The spare cells that will not be used will be reclaimed after
      * conversion, to avoid wasting memory.</p>
+     * @param instructions instructions of the method
      * @exception DifferentiationException if local variables array has not been
      * expanded appropriately beforehand
      * @see #removeUnusedSpareLocalVariables()
      */
-    private void addSpareLocalVariables() throws DifferentiationException {
+    private void addSpareLocalVariables(final InsnList instructions)
+        throws DifferentiationException {
         for (final Iterator<?> i = instructions.iterator(); i.hasNext();) {
             final AbstractInsnNode insn = (AbstractInsnNode) i.next();
             if (insn.getType() == AbstractInsnNode.VAR_INSN) {
@@ -340,9 +302,10 @@
     }
 
     /** Remove the unused spare cells introduced at conversion start.
+     * @param instructions instructions of the method
      * @see #addSpareLocalVariables()
      */
-    private void removeUnusedSpareLocalVariables() {
+    private void removeUnusedSpareLocalVariables(final InsnList instructions) {
         for (final Iterator<?> i = instructions.iterator(); i.hasNext();) {
             final AbstractInsnNode insn = (AbstractInsnNode) i.next();
             if (insn.getType() == AbstractInsnNode.VAR_INSN) {
@@ -358,9 +321,10 @@
      * instructions path, updating stack cells and local variables as needed.
      * Instructions that must be changed are the ones that consume changed
      * variables or stack cells.</p>
+     * @param instructions instructions of the method
      * @return set containing all the instructions that must be changed
      */
-    private Set<AbstractInsnNode> identifyChanges() {
+    private Set<AbstractInsnNode> identifyChanges(final InsnList instructions) {
 
         // the pending set contains the values (local variables or stack cells)
         // that have been changed, they will trigger changes on the instructions
@@ -461,21 +425,22 @@
     }
 
     /** Perform the code changes.
+     * @param instructions instructions of the method
      * @param changes instructions that must be changed
      * @exception DifferentiationException if some instruction cannot be handled
      */
-    private void changeCode(final Set<AbstractInsnNode> changes)
+    private void changeCode(final InsnList instructions, final Set<AbstractInsnNode> changes)
         throws DifferentiationException {
 
         // insert the parameter conversion code at method start
         final InsnList list = new InsnList();
         list.add(new VarInsnNode(Opcodes.ALOAD, 1));
         list.add(new InsnNode(Opcodes.DUP));
-        list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DP_NAME,
-                                    "getValue", VOID_RETURN_D_DESCRIPTOR));
+        list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, Descriptors.DP_NAME,
+                                    "getValue", Descriptors.VOID_RETURN_D_DESCRIPTOR));
         list.add(new VarInsnNode(Opcodes.DSTORE, 1));
-        list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DP_NAME,
-                                    "getFirstDerivative", VOID_RETURN_D_DESCRIPTOR));
+        list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, Descriptors.DP_NAME,
+                                    "getFirstDerivative", Descriptors.VOID_RETURN_D_DESCRIPTOR));
         list.add(new VarInsnNode(Opcodes.DSTORE, 3));
 
         instructions.insertBefore(instructions.get(0), list);
@@ -688,7 +653,7 @@
      * @param name name of the class to test
      * @return true if the named class is a math implementation class
      */
-    public boolean isMathImplementationClass(final String name) {
+    private boolean isMathImplementationClass(final String name) { // membership test on mathClasses
         return mathClasses.contains(name);
     }
 
@@ -744,7 +709,7 @@
     /** Shifted the index of a variable instruction.
      * @param insn variable instruction
      */
-    public void shiftVariable(final VarInsnNode insn) {
+    private void shiftVariable(final VarInsnNode insn) {
         int shifted = 0;
         for (int i = 0; i < insn.var; ++i) {
             if (usedLocals[i]) {
@@ -754,6 +719,19 @@
         insn.var = shifted;
     }
 
+    /** Count the local variable slots that are still marked as used.
+     * @return number of used local variable slots (the caller stores it in maxLocals)
+     */
+    private int maxVariables() {
+        int max = 0;
+        for (final boolean isUsed : usedLocals) { // one flag per local variable slot
+            if (isUsed) {
+                ++max;
+            }
+        }
+        return max; // presumably matches shiftVariable's renumbering of used slots to 0..max-1
+    }
+
     /** Clone an instruction.
      * @param insn instruction to clone
      * @return cloned instruction
@@ -765,11 +743,17 @@
     /** Analyzer preserving instructions successors information. */
     private class FlowAnalyzer extends Analyzer {
 
+        /** Instructions of the method. */
+        private final InsnList instructions;
+
         /** Simple constructor.
          * @param interpreter associated interpreter
+         * @param instructions instructions of the method
          */
-        public FlowAnalyzer(final Interpreter interpreter) {
+        public FlowAnalyzer(final Interpreter interpreter,
+                            final InsnList instructions) {
             super(interpreter);
+            this.instructions = instructions;
         }
 
         /** Store a new edge.

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/instructions/DReturnTransformer.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/instructions/DReturnTransformer.java?rev=805114&r1=805113&r2=805114&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/instructions/DReturnTransformer.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/instructions/DReturnTransformer.java Mon Aug 17 19:50:51 2009
@@ -16,6 +16,7 @@
  */
 package org.apache.commons.nabla.algorithmic.forward.instructions;
 
+import org.apache.commons.nabla.algorithmic.Descriptors;
 import org.apache.commons.nabla.algorithmic.forward.analysis.InstructionsTransformer;
 import org.apache.commons.nabla.algorithmic.forward.analysis.MethodDifferentiator;
 import org.apache.commons.nabla.core.DifferentiationException;
@@ -68,17 +69,16 @@
 
         final InsnList list = new InsnList();
         // operand stack initial state: a0, a1
-        list.add(new VarInsnNode(Opcodes.DSTORE, 3));             // => a0
-        list.add(new VarInsnNode(Opcodes.DSTORE, 1));             // =>
-        list.add(new TypeInsnNode(Opcodes.NEW,
-                                  MethodDifferentiator.DP_NAME)); // => o,
-        list.add(new InsnNode(Opcodes.DUP));                      // => o, o
-        list.add(new VarInsnNode(Opcodes.DLOAD, 1));              // => o, o, a0
-        list.add(new VarInsnNode(Opcodes.DLOAD, 3));              // => o, o, a0, a1
+        list.add(new VarInsnNode(Opcodes.DSTORE, 3));                 // => a0
+        list.add(new VarInsnNode(Opcodes.DSTORE, 1));                 // =>
+        list.add(new TypeInsnNode(Opcodes.NEW, Descriptors.DP_NAME)); // => o,
+        list.add(new InsnNode(Opcodes.DUP));                          // => o, o
+        list.add(new VarInsnNode(Opcodes.DLOAD, 1));                  // => o, o, a0
+        list.add(new VarInsnNode(Opcodes.DLOAD, 3));                  // => o, o, a0, a1
         list.add(new MethodInsnNode(Opcodes.INVOKESPECIAL,
-                                    MethodDifferentiator.DP_NAME,
-                                    "<init>", "(DD)V"));          // => dp
-        list.add(new InsnNode(Opcodes.ARETURN));                  // =>
+                                    Descriptors.DP_NAME,
+                                    "<init>", "(DD)V"));              // => dp
+        list.add(new InsnNode(Opcodes.ARETURN));                      // =>
         return list;
 
     }