You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by lu...@apache.org on 2012/10/03 21:19:37 UTC

svn commit: r1393702 - in /commons/sandbox/nabla/trunk/src: main/java/org/apache/commons/nabla/ main/java/org/apache/commons/nabla/forward/ main/java/org/apache/commons/nabla/forward/analysis/ main/java/org/apache/commons/nabla/forward/instructions/ ma...

Author: luc
Date: Wed Oct  3 19:19:36 2012
New Revision: 1393702

URL: http://svn.apache.org/viewvc?rev=1393702&view=rev
Log:
Added support for INVOKESPECIAL/INVOKEVIRTUAL/INVOKEINTERFACE.

This is a great step forward for Nabla as it allows differentiating
methods that call other methods.

The basic structure is there but binding several instances together is
not done, so for now invocation will work only when the other methods
are called on the same instance of the same class.

Added:
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/NablaUnivariateDifferentiableFunction.java   (with props)
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeMathTransformer.java   (contents, props changed)
      - copied, changed from r1393701, commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeStaticTransformer.java
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeNonMathTransformer.java   (with props)
Removed:
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeStaticTransformer.java
Modified:
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/NablaMessages.java
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/NablaDifferentiated.java
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java
    commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java
    commons/sandbox/nabla/trunk/src/main/resources/assets/org/apache/commons/nabla/NablaMessages_fr.properties
    commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/NablaMessagesTest.java

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/NablaMessages.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/NablaMessages.java?rev=1393702&r1=1393701&r2=1393702&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/NablaMessages.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/NablaMessages.java Wed Oct  3 19:19:36 2012
@@ -49,7 +49,6 @@ public enum NablaMessages implements Loc
     CANNOT_INSTANTIATE_CLASS_FROM_OTHER_INSTANCE("class {0} instantiation from an instance of class {1} failed ({2})"),
     INCORRECT_GENERATED_CODE("class {0} code generated from an instance of class {1} is incorrect ({2})"),
     INTERFACE_NOT_FOUND_WHILE_DIFFERENTIATING("interface {0} not found while differentiating class {1}"),
-    CLASS_DOES_NOT_IMPLEMENT_INTERFACE("the {0} class does not implement the {1} interface"),
     UNABLE_TO_ANALYZE_METHOD("unable to analyze the {0}.{1} method ({2})"),
     UNKNOWN_METHOD("unknown method {0}.{1}"),
     UNEXPECTED_INSTRUCTION("unexpected instruction with opcode {0}"),

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java?rev=1393702&r1=1393701&r2=1393702&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java Wed Oct  3 19:19:36 2012
@@ -18,10 +18,12 @@ package org.apache.commons.nabla.forward
 
 import java.io.IOException;
 import java.io.OutputStream;
+import java.io.PrintWriter;
 import java.lang.reflect.Constructor;
 import java.lang.reflect.InvocationTargetException;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.Iterator;
 import java.util.Set;
 
 import org.apache.commons.math3.analysis.UnivariateFunction;
@@ -31,9 +33,11 @@ import org.apache.commons.math3.util.Fas
 import org.apache.commons.nabla.DifferentiationException;
 import org.apache.commons.nabla.NablaMessages;
 import org.apache.commons.nabla.forward.analysis.ClassDifferentiator;
+import org.objectweb.asm.ClassReader;
 import org.objectweb.asm.ClassWriter;
 import org.objectweb.asm.Type;
 import org.objectweb.asm.tree.ClassNode;
+import org.objectweb.asm.util.TraceClassVisitor;
 
 /** Algorithmic differentiator class in forward mode based on bytecode analysis.
  * <p>This class is an implementation of the {@link UnivariateFunctionDifferentiator}
@@ -56,7 +60,8 @@ import org.objectweb.asm.tree.ClassNode;
 public class ForwardModeDifferentiator implements UnivariateFunctionDifferentiator {
 
     /** UnivariateFunction/UnivariateDifferentiableFunction map. */
-    private final HashMap<Class<? extends UnivariateFunction>, Class<? extends NablaDifferentiated>> map;
+    private final HashMap<Class<? extends UnivariateFunction>,
+                          Class<? extends NablaUnivariateDifferentiableFunction>> map;
 
     /** Class name/ bytecode map. */
     private final HashMap<String, byte[]> byteCodeMap;
@@ -64,13 +69,22 @@ public class ForwardModeDifferentiator i
     /** Math implementation classes. */
     private final Set<String> mathClasses;
 
+    /** Pending differentiations. */
+    private final Set<DifferentiableMethod> pendingDifferentiations;
+
+    /** Processed differentiations. */
+    private final Set<DifferentiableMethod> processedDifferentiations;
+
     /** Simple constructor.
      * <p>Build a ForwardAlgorithmicDifferentiator instance with an empty cache.</p>
      */
     public ForwardModeDifferentiator() {
-        map         = new HashMap<Class<? extends UnivariateFunction>, Class<? extends NablaDifferentiated>>();
-        byteCodeMap = new HashMap<String, byte[]>();
-        mathClasses = new HashSet<String>();
+        map                       = new HashMap<Class<? extends UnivariateFunction>,
+                                                Class<? extends NablaUnivariateDifferentiableFunction>>();
+        byteCodeMap               = new HashMap<String, byte[]>();
+        mathClasses               = new HashSet<String>();
+        pendingDifferentiations   = new HashSet<ForwardModeDifferentiator.DifferentiableMethod>();
+        processedDifferentiations = new HashSet<ForwardModeDifferentiator.DifferentiableMethod>();
         addMathImplementation(Math.class);
         addMathImplementation(StrictMath.class);
         addMathImplementation(FastMath.class);
@@ -97,16 +111,58 @@ public class ForwardModeDifferentiator i
         throw new RuntimeException("not implemented yet");
     }
 
+    /** Request differentiation of a method.
+     * @param owner class in which the method is defined (binary name or JVM internal name)
+     * @param isStatic if true, the method is static
+     * @param method method name
+     * @param primitiveMethodType method type in the primitive (includes return and arguments types)
+     * @return type of the differentiated method
+     * @exception DifferentiationException if class cannot be found
+     */
+    public Type requestMethodDifferentiation(final String owner, final boolean isStatic,
+                                             final String method, final Type primitiveMethodType)
+        throws DifferentiationException {
+
+        try {
+
+            // owners read from bytecode instructions are internal names (a/b/C), Class.forName needs a.b.C
+            final DifferentiableMethod dm = new DifferentiableMethod(Class.forName(owner.replace('/', '.')),
+                                                                     isStatic, method, primitiveMethodType);
+            if (!processedDifferentiations.contains(dm)) {
+                // schedule the request if method has not been processed yet
+                pendingDifferentiations.add(dm);
+            }
+
+            return dm.getDifferentiatedMethodType();
+
+        } catch (ClassNotFoundException cnfe) {
+            throw new DifferentiationException(NablaMessages.CANNOT_READ_CLASS,
+                                               owner, cnfe.getMessage());
+        }
+
+    }
+
     /** {@inheritDoc} */
-    public NablaDifferentiated differentiate(final UnivariateFunction d) {
+    public NablaUnivariateDifferentiableFunction differentiate(final UnivariateFunction d) {
+
+        // lookup in the map if the class has already been differentiated
+        Class<? extends NablaUnivariateDifferentiableFunction> derivativeClass = map.get(d.getClass());
+
+        // build the derivative class if it does not exist yet
+        if (derivativeClass == null) {
 
-        // get the derivative class
-        final Class<? extends NablaDifferentiated> derivativeClass = getDerivativeClass(d.getClass());
+            // perform algorithmic differentiation
+            derivativeClass = createDerivativeClass(d.getClass());
+
+            // put the newly created class in the map
+            map.put(d.getClass(), derivativeClass);
+
+        }
 
         try {
 
             // create the instance
-            final Constructor<? extends NablaDifferentiated> constructor =
+            final Constructor<? extends NablaUnivariateDifferentiableFunction> constructor =
                 derivativeClass.getConstructor(d.getClass());
             return constructor.newInstance(d);
 
@@ -129,65 +185,73 @@ public class ForwardModeDifferentiator i
 
     }
 
-    /** Get the derivative class of a differentiable class.
-     * <p>The derivative class is either built on the fly
-     * or retrieved from the cache if it has been built previously.</p>
-     * @param differentiableClass class to differentiate
-     * @return derivative class
-     * @throws DifferentiationException if the class cannot be differentiated
-     */
-    private Class<? extends NablaDifferentiated>
-    getDerivativeClass(final Class<? extends UnivariateFunction> differentiableClass)
-        throws DifferentiationException {
-
-        // lookup in the map if the class has already been differentiated
-        Class<? extends NablaDifferentiated> derivativeClass = map.get(differentiableClass);
-
-        // build the derivative class if it does not exist yet
-        if (derivativeClass == null) {
-
-            // perform algorithmic differentiation
-            derivativeClass = createDerivativeClass(differentiableClass);
-
-            // put the newly created class in the map
-            map.put(differentiableClass, derivativeClass);
-
-        }
-
-        // return the derivative class
-        return derivativeClass;
-
-    }
-
     /** Build a derivative class of a differentiable class.
      * @param differentiableClass class to differentiate
      * @return derivative class
      * @throws DifferentiationException if the class cannot be differentiated
      */
-    private Class<? extends NablaDifferentiated>
+    @SuppressWarnings("unchecked")
+    private Class<? extends NablaUnivariateDifferentiableFunction>
     createDerivativeClass(final Class<? extends UnivariateFunction> differentiableClass)
         throws DifferentiationException {
         try {
 
-            // differentiate the function embedded in the differentiable class
-            final ClassDifferentiator differentiator =
-                new ClassDifferentiator(differentiableClass, mathClasses);
-            final Type dsType = Type.getType(DerivativeStructure.class);
-            differentiator.differentiateMethod("value",
-                                               Type.getMethodType(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE),
-                                               Type.getMethodType(dsType, dsType));
-
-            // create the derivative class
-            final ClassNode   derived = differentiator.getDerivedClass();
-            final ClassWriter writer  = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
-            final String name = derived.name.replace('/', '.');
-            derived.accept(writer);
-            final byte[] bytecode = writer.toByteArray();
-
-            final Class<? extends NablaDifferentiated> dClass =
-                    new DerivativeLoader(differentiableClass).defineClass(name, bytecode);
-            byteCodeMap.put(name, bytecode);
-            return dClass;
+            final Set<ClassDifferentiator> differentiators = new HashSet<ClassDifferentiator>();
+
+            // bootstrap differentiation using the top level value function from the UnivariateFunction interface
+            requestMethodDifferentiation(differentiableClass.getName(), false, "value",
+                                         Type.getMethodType(Type.DOUBLE_TYPE, Type.DOUBLE_TYPE));
+
+            while (!pendingDifferentiations.isEmpty()) {
+
+                // move the method from pending to processed
+                final Iterator<DifferentiableMethod> iterator = pendingDifferentiations.iterator();
+                final DifferentiableMethod dm = iterator.next();
+                iterator.remove();
+                processedDifferentiations.add(dm);
+
+                // find a differentiator for the class owning the method
+                ClassDifferentiator differentiator = null;
+                for (Iterator<ClassDifferentiator> dIter = differentiators.iterator();
+                     differentiator == null && dIter.hasNext();) {
+                    ClassDifferentiator current = dIter.next();
+                    if (Type.getInternalName(dm.getPrimitiveClass()).equals(current.getPrimitive().name)) {
+                        // we have already build a differentiator for the same class, reuse it
+                        differentiator = current;
+                    }
+                }
+                if (differentiator == null) {
+                    // it is the first time we process this class, create a differentiator for it
+                    differentiator = new ClassDifferentiator(dm.getPrimitiveClass(), mathClasses, this);
+                    differentiators.add(differentiator);
+                }
+
+                differentiator.differentiateMethod(dm.getMethod(), dm.getPrimitiveMethodType(),
+                                                   dm.getDifferentiatedMethodType());
+
+            }
+
+            // create the differential classes
+            Class<? extends NablaUnivariateDifferentiableFunction> nudf = null;
+            for (ClassDifferentiator differentiator : differentiators) {
+                final ClassNode   derived = differentiator.getDerived();
+                final ClassWriter writer  = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
+                final String name = derived.name.replace('/', '.');
+                derived.accept(writer);
+                final byte[] bytecode = writer.toByteArray();
+
+                final Class<? extends NablaDifferentiated> dClass =
+                        new DerivativeLoader(differentiableClass).defineClass(name, bytecode);
+                byteCodeMap.put(name, bytecode);
+
+                if (differentiator.getPrimitive().name.equals(Type.getType(differentiableClass).getInternalName())) {
+                    nudf = (Class<? extends NablaUnivariateDifferentiableFunction>) dClass;
+                }
+
+            }
+
+            // return the top level one
+            return nudf;
 
         } catch (IOException ioe) {
             throw new DifferentiationException(NablaMessages.CANNOT_READ_CLASS,
@@ -217,4 +281,104 @@ public class ForwardModeDifferentiator i
         }
     }
 
+    /** Identifier for a method to differentiate. */
+    private static class DifferentiableMethod {
+
+        /** Primitive class to which the method belongs. */
+        private final Class<?> primitiveClass;
+
+        /** Indicator for static methods. */
+        private final boolean isStatic;
+
+        /** Name of the method. */
+        private final String method;
+
+        /** Type of the method in the primitive class. */
+        private final Type primitiveMethodType;
+
+        /** Simple constructor.
+         * @param primitiveClass primitive class to which the method belongs
+         * @param isStatic if true, the method is static
+         * @param method name of the method
+         * @param primitiveMethodType type of the method in the primitive class
+         */
+        public DifferentiableMethod(final Class<?> primitiveClass, final boolean isStatic,
+                                    final String method, final Type primitiveMethodType) {
+            this.primitiveClass       = primitiveClass;
+            this.isStatic             = isStatic;
+            this.method               = method;
+            this.primitiveMethodType  = primitiveMethodType;
+        }
+
+        /** Get the primitive class to which the method belongs.
+         * @return primitive class to which the method belongs
+         */
+        public Class<?> getPrimitiveClass() {
+            return primitiveClass;
+        }
+
+        /** Get the name of the method.
+         * @return name of the method
+         */
+        public String getMethod() {
+            return method;
+        }
+
+        /** Get the type of the method in the primitive class.
+         * @return type of the method in the primitive class
+         */
+        public Type getPrimitiveMethodType() {
+            return primitiveMethodType;
+        }
+
+        /** Get the type of the method in the differentiated class.
+         * @return type of the method in the differentiated class
+         */
+        public Type getDifferentiatedMethodType() {
+
+            // transform arguments types
+            final Type[] argumentsTypes = primitiveMethodType.getArgumentTypes();
+            for (int i = 0; i < argumentsTypes.length; ++i) {
+                if (argumentsTypes[i].equals(Type.DOUBLE_TYPE)) {
+                    argumentsTypes[i] = Type.getType(DerivativeStructure.class);
+                }
+            }
+
+            // transform return type
+            final Type returnType = primitiveMethodType.getReturnType().equals(Type.DOUBLE_TYPE) ?
+                                    Type.getType(DerivativeStructure.class) : primitiveMethodType.getReturnType();
+
+            return Type.getMethodType(returnType, argumentsTypes);
+        }
+
+        /** {@inheritDoc} */
+        @Override
+        public boolean equals(final Object other) {
+            if (this == other) {
+                return true;
+            }
+            if (other instanceof DifferentiableMethod) {
+                // names and ASM types must be compared by value: a new Type instance
+                // is built for each request, so reference comparison would never match
+                final DifferentiableMethod dm = (DifferentiableMethod) other;
+                return (primitiveClass == dm.primitiveClass) &&
+                       (isStatic       == dm.isStatic)       &&
+                       method.equals(dm.method)              &&
+                       primitiveMethodType.equals(dm.primitiveMethodType);
+            }
+            return false;
+        }
+
+        /** {@inheritDoc} */
+        @Override
+        public int hashCode() {
+            // the following coefficients are arbitrarily chosen prime numbers
+            return 109 * primitiveClass.hashCode() +
+                   479 * Boolean.valueOf(isStatic).hashCode() +
+                   601 * method.hashCode() +
+                   571 * primitiveMethodType.hashCode();
+        }
+
+    }
+
 }

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/NablaDifferentiated.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/NablaDifferentiated.java?rev=1393702&r1=1393701&r2=1393702&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/NablaDifferentiated.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/NablaDifferentiated.java Wed Oct  3 19:19:36 2012
@@ -19,21 +19,19 @@ package org.apache.commons.nabla.forward
 import java.lang.reflect.Field;
 
 import org.apache.commons.math3.analysis.UnivariateFunction;
-import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
-import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
 
 /** Base class for Nabla differentiated functions.
  * @version $Id$
  */
-public abstract class NablaDifferentiated implements UnivariateDifferentiableFunction {
+public class NablaDifferentiated {
 
     /** Primitive instance field. */
-    private final UnivariateFunction primitive;
+    private final Object primitive;
 
     /** Simple constructor.
      * @param primitive primitive instance
      */
-    protected NablaDifferentiated(final UnivariateFunction primitive) {
+    protected NablaDifferentiated(final Object primitive) {
         this.primitive = primitive;
     }
 
@@ -71,12 +69,11 @@ public abstract class NablaDifferentiate
         }
     }
 
-    /** {@inheritDoc} */
-    public double value(final double x) {
-        return primitive.value(x);
+    /** Get the primitive instance.
+     * @return primitive instance.
+     */
+    protected Object getPrimitiveObject() {
+        return primitive;
     }
 
-    /** {@inheritDoc} */
-    public abstract DerivativeStructure value(DerivativeStructure t);
-
 }

Added: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/NablaUnivariateDifferentiableFunction.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/NablaUnivariateDifferentiableFunction.java?rev=1393702&view=auto
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/NablaUnivariateDifferentiableFunction.java (added)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/NablaUnivariateDifferentiableFunction.java Wed Oct  3 19:19:36 2012
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.nabla.forward;
+
+import org.apache.commons.math3.analysis.UnivariateFunction;
+import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
+import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
+
+/** Base class for Nabla differentiated functions that implement {@link UnivariateDifferentiableFunction}.
+ * @version $Id$
+ */
+public abstract class NablaUnivariateDifferentiableFunction
+    extends NablaDifferentiated implements UnivariateDifferentiableFunction {
+
+    /** Simple constructor.
+     * @param primitive primitive instance
+     */
+    protected NablaUnivariateDifferentiableFunction(final UnivariateFunction primitive) {
+        super(primitive);
+    }
+
+    /** Get the primitive instance.
+     * @return primitive instance (the one given to the constructor)
+     */
+    public UnivariateFunction getPrimitive() {
+        return (UnivariateFunction) getPrimitiveObject();
+    }
+
+    /** {@inheritDoc} <p>This implementation delegates to the primitive function.</p> */
+    public double value(final double x) {
+        return getPrimitive().value(x);
+    }
+
+    /** {@inheritDoc} */
+    public abstract DerivativeStructure value(DerivativeStructure t);
+
+}

Propchange: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/NablaUnivariateDifferentiableFunction.java
------------------------------------------------------------------------------
    svn:eol-style = native

Propchange: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/NablaUnivariateDifferentiableFunction.java
------------------------------------------------------------------------------
    svn:keywords = "Author Date Id Revision"

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java?rev=1393702&r1=1393701&r2=1393702&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java Wed Oct  3 19:19:36 2012
@@ -20,10 +20,11 @@ import java.io.IOException;
 import java.util.Set;
 
 import org.apache.commons.math3.analysis.UnivariateFunction;
-import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
 import org.apache.commons.nabla.DifferentiationException;
 import org.apache.commons.nabla.NablaMessages;
+import org.apache.commons.nabla.forward.ForwardModeDifferentiator;
 import org.apache.commons.nabla.forward.NablaDifferentiated;
+import org.apache.commons.nabla.forward.NablaUnivariateDifferentiableFunction;
 import org.objectweb.asm.ClassReader;
 import org.objectweb.asm.Opcodes;
 import org.objectweb.asm.Type;
@@ -33,12 +34,6 @@ import org.objectweb.asm.tree.MethodNode
 /**
  * Differentiator for classes using forward mode.
  * <p>
- * This differentiator transforms classes implementing the
- * {@link UnivariateFunction UnivariateFunction} interface and convert
- * them to classes implementing the {@link UnivariateDifferentiableFunction
- * UnivariateDifferentiableFunction} interface.
- * </p>
- * <p>
  * The differentiator creates a new class in the same package as the primitive class and
  * which only preserve a private reference to the primitive instance. They access the
  * current value of all necessary primitive instance fields thanks to reflection and
@@ -51,37 +46,39 @@ import org.objectweb.asm.tree.MethodNode
  */
 public class ClassDifferentiator {
 
-    /** Name for the primitive instance field. */
-    private static final String PRIMITIVE_FIELD = "primitive";
-
     /** Name fo the constructor methods. */
     private static final String INIT = "<init>";
 
+    /** Suffix for differentiated classes. */
+    private static final String CLASS_SUFFIX = "_NablaForwardModeDifferentiated";
+
     /** Math implementation classes. */
     private final Set<String> mathClasses;
 
-    /** Class to differentiate. */
-    private final Class<? extends UnivariateFunction> primitiveClass;
-
     /** Node of the class to differentiate. */
     private final ClassNode primitiveNode;
 
     /** Class to differentiate. */
     private final ClassNode classNode;
 
+    /** Global differentiator. */
+    private final ForwardModeDifferentiator forwardDifferentiator;
+
     /**
      * Simple constructor.
      * @param primitiveClass primitive class
      * @param mathClasses math implementation classes
+     * @param forwardDifferentiator global differentiator
      * @exception DifferentiationException if class cannot be differentiated
      * @throws IOException if class cannot be read
      */
-    public ClassDifferentiator(final Class<? extends UnivariateFunction> primitiveClass,
-                               final Set<String> mathClasses)
+    public ClassDifferentiator(final Class<?> primitiveClass, final Set<String> mathClasses,
+                               final ForwardModeDifferentiator forwardDifferentiator)
         throws DifferentiationException, IOException {
 
+        this.forwardDifferentiator = forwardDifferentiator;
+
         // get the original class
-        this.primitiveClass = primitiveClass;
         final ClassReader reader =
                 new ClassReader(primitiveClass.getResourceAsStream("/" + Type.getInternalName(primitiveClass) + ".class"));
         primitiveNode = new ClassNode(Opcodes.ASM4);
@@ -89,7 +86,7 @@ public class ClassDifferentiator {
         this.mathClasses = mathClasses;
         classNode = new ClassNode(Opcodes.ASM4);
 
-        // check the UnivariateFunction interface is implemented
+        // check if the UnivariateFunction interface is implemented
         final Class<UnivariateFunction> uFuncClass = UnivariateFunction.class;
         boolean isDifferentiable = false;
         for (String interf : primitiveNode.interfaces) {
@@ -108,27 +105,22 @@ public class ClassDifferentiator {
             }
         }
 
-        if (!isDifferentiable) {
-            throw new DifferentiationException(NablaMessages.CLASS_DOES_NOT_IMPLEMENT_INTERFACE,
-                                               primitiveNode.name, uFuncClass.getName());
-        }
-
         // change the class properties for the derived class
+        final Type superType = isDifferentiable ?
+                               Type.getType(NablaUnivariateDifferentiableFunction.class) :
+                               Type.getType(NablaDifferentiated.class);
         classNode.visit(primitiveNode.version, Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC,
-                        primitiveNode.name + "_NablaForwardModeUnivariateDerivative",
-                        null, Type.getType(NablaDifferentiated.class).getInternalName(),
-                        new String[] {
-                            Type.getType(UnivariateDifferentiableFunction.class).getInternalName()
-                        });
+                        primitiveNode.name + CLASS_SUFFIX,
+                        null, superType.getInternalName(), new String[0]);
 
-        // add constructor calling NablaDifferentiated superclass constructor
+        // add constructor calling superclass constructor
         final MethodNode constructor =
                 new MethodNode(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, INIT,
                                Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(primitiveClass)),
                                null, null);
         constructor.visitVarInsn(Opcodes.ALOAD, 0);
         constructor.visitVarInsn(Opcodes.ALOAD, 1);
-        constructor.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(NablaDifferentiated.class),
+        constructor.visitMethodInsn(Opcodes.INVOKESPECIAL, superType.getInternalName(),
                                     INIT, Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(UnivariateFunction.class)));
         constructor.visitInsn(Opcodes.RETURN);
         constructor.visitMaxs(2, 2);
@@ -136,18 +128,18 @@ public class ClassDifferentiator {
 
     }
 
-    /** Get the name of the primitive class.
-     * @return name of the primitive class
+    /** Get the primitive class node.
+     * @return primitive class node
      */
-    public String getPrimitiveName() {
-        return primitiveNode.name;
+    public ClassNode getPrimitive() {
+        return primitiveNode;
     }
 
-    /** Get the name of the derived class.
-     * @return name of the derived class
+    /** Get the derived class node.
+     * @return derived class node
      */
-    public String getDerivedName() {
-        return classNode.name;
+    public ClassNode getDerived() {
+        return classNode;
     }
 
     /**
@@ -170,14 +162,20 @@ public class ClassDifferentiator {
 
             }
         }
+
     }
 
-    /**
-     * Get the derived class.
-     * @return derived class
+    /** Request differentiation of a method.
+     * @param owner class in which the method is defined
+     * @param isStatic if true, the method is static
+     * @param method method name
+     * @param primitiveMethodType method type in the primitive (includes return and arguments types)
+     * @return type of the differentiated method
+     * @exception DifferentiationException if class cannot be found
      */
-    public ClassNode getDerivedClass() {
-        return classNode;
+    public Type requestMethodDifferentiation(final String owner, final boolean isStatic,
+                                             final String method, final Type primitiveMethodType) {
+        return forwardDifferentiator.requestMethodDifferentiation(owner, isStatic, method, primitiveMethodType);
     }
 
 }

Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java?rev=1393702&r1=1393701&r2=1393702&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java Wed Oct  3 19:19:36 2012
@@ -42,7 +42,8 @@ import org.apache.commons.nabla.forward.
 import org.apache.commons.nabla.forward.instructions.Dup2X1Transformer;
 import org.apache.commons.nabla.forward.instructions.Dup2X2Transformer;
 import org.apache.commons.nabla.forward.instructions.GetTransformer;
-import org.apache.commons.nabla.forward.instructions.InvokeStaticTransformer;
+import org.apache.commons.nabla.forward.instructions.InvokeNonMathTransformer;
+import org.apache.commons.nabla.forward.instructions.InvokeMathTransformer;
 import org.apache.commons.nabla.forward.instructions.NarrowingTransformer;
 import org.apache.commons.nabla.forward.instructions.Pop2Transformer;
 import org.apache.commons.nabla.forward.instructions.WideningTransformer;
@@ -97,11 +98,18 @@ public class MethodDifferentiator {
         this.successors          = new IdentityHashMap<AbstractInsnNode, Set<AbstractInsnNode>>();
     }
 
+    /** Get the name of the primitive class.
+     * @return name of the primitive class
+     */
+    public String getPrimitiveName() {
+        return classDifferentiator.getPrimitive().name;
+    }
+
     /** Get the name of the derived class.
      * @return name of the derived class
      */
     public String getDerivedName() {
-        return classDifferentiator.getDerivedName();
+        return classDifferentiator.getDerived().name;
     }
 
     /**
@@ -120,8 +128,7 @@ public class MethodDifferentiator {
             // analyze the original code, tracing values production/consumption
             final FlowAnalyzer analyzer =
                 new FlowAnalyzer(new TrackingInterpreter(), method.instructions);
-            final Frame<TrackingValue>[] array =
-                    analyzer.analyze(classDifferentiator.getPrimitiveName(), method);
+            final Frame<TrackingValue>[] array = analyzer.analyze(getPrimitiveName(), method);
 
             // convert the array into a map, since code changes will shift all indices
             for (int i = 0; i < array.length; ++i) {
@@ -153,12 +160,24 @@ public class MethodDifferentiator {
                 throw (DifferentiationException) ae.getCause();
             } else {
                 throw new DifferentiationException(NablaMessages.UNABLE_TO_ANALYZE_METHOD,
-                                                   classDifferentiator.getPrimitiveName(),
-                                                   method.name, ae.getMessage());
+                                                   getPrimitiveName(), method.name, ae.getMessage());
             }
         }
     }
 
+    /** Request differentiation of a method.
+     * @param owner class in which the method is defined
+     * @param isStatic if true, the method is static
+     * @param method method name
+     * @param primitiveMethodType method type in the primitive (includes return and arguments types)
+     * @return type of the differentiated method
+     * @exception DifferentiationException if class cannot be found
+     */
+    public Type requestMethodDifferentiation(final String owner, final boolean isStatic,
+                                             final String method, final Type primitiveMethodType) {
+        return classDifferentiator.requestMethodDifferentiation(owner, isStatic, method, primitiveMethodType);
+    }
+
     /** Identify the instructions that must be changed.
      * <p>Identification is based on data flow analysis. We start by changing
      * the local variables in the initial frame to match the parameters of
@@ -372,16 +391,17 @@ public class MethodDifferentiator {
                 // TODO: add support for PUTSTATIC/PUTFIELD differentiation
                 throw new RuntimeException("PUTSTATIC/PUTFIELD not handled yet");
             case Opcodes.INVOKEVIRTUAL :
-                // TODO: add support for INVOKEVIRTUAL differentiation
-                throw new RuntimeException("INVOKEVIRTUAL not handled yet");
             case Opcodes.INVOKESPECIAL :
-                // TODO: add support for INVOKESPECIAL differentiation
-                throw new RuntimeException("INVOKESPECIAL not handled yet");
-            case Opcodes.INVOKESTATIC :
-                return new InvokeStaticTransformer().getReplacement(insn, this, dsIndex);
             case Opcodes.INVOKEINTERFACE :
-                // TODO: add support for INVOKEINTERFACE differentiation
-                throw new RuntimeException("INVOKEINTERFACE not handled yet");
+                return new InvokeNonMathTransformer().getReplacement(insn, this, dsIndex);
+            case Opcodes.INVOKESTATIC : {
+                final MethodInsnNode methodInsn = (MethodInsnNode) insn;
+                if (mathClasses.contains(methodInsn.owner)) {
+                    return new InvokeMathTransformer().getReplacement(insn, this, dsIndex);
+                } else {
+                    return new InvokeNonMathTransformer().getReplacement(insn, this, dsIndex);
+                }
+            }
             case Opcodes.INVOKEDYNAMIC :
                 // TODO: add support for INVOKEDYNAMIC differentiation
                 throw new RuntimeException("INVOKEDYNAMIC not handled yet");
@@ -400,14 +420,6 @@ public class MethodDifferentiator {
 
     }
 
-    /** Test if a class is a math implementation class.
-     * @param name name of the class to test
-     * @return true if the named class is a math implementation class
-     */
-    public boolean isMathImplementationClass(final String name) {
-        return mathClasses.contains(name);
-    }
-
     /** Create instructions to preserve a reference {@link DerivativeStructure} variable.
      * @param derivedMethodType type of the derived method
      * @param isStatic if true, the method is a static method

Copied: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeMathTransformer.java (from r1393701, commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeStaticTransformer.java)
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeMathTransformer.java?p2=commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeMathTransformer.java&p1=commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeStaticTransformer.java&r1=1393701&r2=1393702&rev=1393702&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeStaticTransformer.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeMathTransformer.java Wed Oct  3 19:19:36 2012
@@ -28,10 +28,10 @@ import org.objectweb.asm.tree.InsnList;
 import org.objectweb.asm.tree.InsnNode;
 import org.objectweb.asm.tree.MethodInsnNode;
 
-/** Differentiation transformer for INVOKESTATIC instructions.
+/** Differentiation transformer for INVOKESTATIC instructions on math related classes.
  * @version $Id$
  */
-public class InvokeStaticTransformer implements InstructionsTransformer {
+public class InvokeMathTransformer implements InstructionsTransformer {
 
     /** {@inheritDoc} */
     public InsnList getReplacement(final AbstractInsnNode insn,
@@ -40,11 +40,6 @@ public class InvokeStaticTransformer imp
         throws DifferentiationException {
 
         final MethodInsnNode methodInsn = (MethodInsnNode) insn;
-        if (!methodDifferentiator.isMathImplementationClass(methodInsn.owner)) {
-            // TODO: handle INVOKESTATIC on non math related classes
-            throw new RuntimeException("INVOKESTATIC on non math related classes not handled yet" +
-                                       methodInsn.owner + "." + methodInsn.name);
-        }
 
         final InsnList list = new InsnList();
 

Propchange: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeMathTransformer.java
------------------------------------------------------------------------------
    svn:eol-style = native

Propchange: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeMathTransformer.java
------------------------------------------------------------------------------
    svn:keywords = "Author Date Id Revision"

Added: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeNonMathTransformer.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeNonMathTransformer.java?rev=1393702&view=auto
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeNonMathTransformer.java (added)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeNonMathTransformer.java Wed Oct  3 19:19:36 2012
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.nabla.forward.instructions;
+
+import org.apache.commons.nabla.DifferentiationException;
+import org.apache.commons.nabla.forward.analysis.InstructionsTransformer;
+import org.apache.commons.nabla.forward.analysis.MethodDifferentiator;
+import org.objectweb.asm.Opcodes;
+import org.objectweb.asm.Type;
+import org.objectweb.asm.tree.AbstractInsnNode;
+import org.objectweb.asm.tree.InsnList;
+import org.objectweb.asm.tree.MethodInsnNode;
+
+/** Differentiation transformer for INVOKESPECIAL/INVOKEVIRTUAL/INVOKESTATIC/INVOKEINTERFACE
+ * instructions on non-math related classes; the call is redirected to the derived class.
+ * @version $Id$
+ */
+public class InvokeNonMathTransformer implements InstructionsTransformer {
+
+    /** {@inheritDoc} */
+    public InsnList getReplacement(final AbstractInsnNode insn,
+                                   final MethodDifferentiator methodDifferentiator,
+                                   final int dsIndex)
+        throws DifferentiationException {
+
+        final MethodInsnNode methodInsn = (MethodInsnNode) insn;
+
+        // ask the global differentiator for the type of the differentiated invoked method
+        Type differentiatedMethodType =
+                methodDifferentiator.requestMethodDifferentiation(Type.getType("L" + methodInsn.owner + ";").getClassName(),
+                                                                  methodInsn.getOpcode() == Opcodes.INVOKESTATIC,
+                                                                  methodInsn.name, Type.getMethodType(methodInsn.desc));
+        // NOTE(review): owner below is always the derived class, so only same-class calls work yet
+        final InsnList list = new InsnList();
+        list.add(new MethodInsnNode(methodInsn.getOpcode(),
+                                    methodDifferentiator.getDerivedName(), methodInsn.name,
+                                    differentiatedMethodType.getDescriptor()));
+        return list;
+
+    }
+
+}

Propchange: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeNonMathTransformer.java
------------------------------------------------------------------------------
    svn:eol-style = native

Propchange: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/InvokeNonMathTransformer.java
------------------------------------------------------------------------------
    svn:keywords = "Author Date Id Revision"

Modified: commons/sandbox/nabla/trunk/src/main/resources/assets/org/apache/commons/nabla/NablaMessages_fr.properties
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/resources/assets/org/apache/commons/nabla/NablaMessages_fr.properties?rev=1393702&r1=1393701&r2=1393702&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/resources/assets/org/apache/commons/nabla/NablaMessages_fr.properties (original)
+++ commons/sandbox/nabla/trunk/src/main/resources/assets/org/apache/commons/nabla/NablaMessages_fr.properties Wed Oct  3 19:19:36 2012
@@ -20,7 +20,6 @@ CANNOT_BUILD_CLASS_FROM_OTHER_CLASS = la
 CANNOT_INSTANTIATE_CLASS_FROM_OTHER_INSTANCE = la classe {0} ne peut pas \u00eatre instanci\u00e9e \u00e0 partir d''une instance de la classe {1} ({2})
 INCORRECT_GENERATED_CODE = code incorrect pour la classe {0} g\u00e9n\u00e9r\u00e9e \u00e0 partir d''une instance de la classe {1} ({2})
 INTERFACE_NOT_FOUND_WHILE_DIFFERENTIATING = interface {0} non trouv\u00e9e lors de la diff\u00e9rentiation de la classe {1}
-CLASS_DOES_NOT_IMPLEMENT_INTERFACE = la classe {0} ne met pas en \u0153uvre l''interface {1}
 UNABLE_TO_ANALYZE_METHOD = impossible d''analyser la m\u00e9thode {0}.{1} ({2})
 UNKNOWN_METHOD = m\u00e9thode {0}.{1} inconnue
 UNEXPECTED_INSTRUCTION = instruction ayant l''opcode {0} inattendue

Modified: commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/NablaMessagesTest.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/NablaMessagesTest.java?rev=1393702&r1=1393701&r2=1393702&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/NablaMessagesTest.java (original)
+++ commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/NablaMessagesTest.java Wed Oct  3 19:19:36 2012
@@ -36,7 +36,7 @@ public class NablaMessagesTest {
 
     @Test
     public void testMessageNumber() {
-        Assert.assertEquals(15, NablaMessages.values().length);
+        Assert.assertEquals(14, NablaMessages.values().length);
     }
 
     @Test