You are viewing a plain-text version of this content. (The link to the canonical version is not included in this plain-text copy.)
Posted to commits@commons.apache.org by lu...@apache.org on 2009/08/17 21:50:52 UTC
svn commit: r805114 - in
/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic:
./ forward/ forward/analysis/ forward/instructions/
Author: luc
Date: Mon Aug 17 19:50:51 2009
New Revision: 805114
URL: http://svn.apache.org/viewvc?rev=805114&view=rev
Log:
Use the ASM tree API for classes too (in addition to methods).
This will allow modifying the fields accessed and the methods called by the differentiated method.
Added:
commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/Descriptors.java (with props)
Modified:
commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/ForwardModeAlgorithmicDifferentiator.java
commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/ForwardModeClassDifferentiator.java
commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/MethodDifferentiator.java
commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/instructions/DReturnTransformer.java
Added: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/Descriptors.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/Descriptors.java?rev=805114&view=auto
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/Descriptors.java (added)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/Descriptors.java Mon Aug 17 19:50:51 2009
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.nabla.algorithmic;
+
+import org.apache.commons.nabla.core.DifferentialPair;
+
+/**
+ * Interface grouping the JVM type and method descriptors used for differentiation.
+ */
+public interface Descriptors {
+
+ /** Internal (slash-separated) name for the DifferentialPair class. */
+ String DP_NAME = DifferentialPair.class.getName().replace('.', '/');
+
+ /** Type descriptor (<code>L...;</code> form) for the DifferentialPair class. */
+ String DP_DESCRIPTOR = "L" + DP_NAME + ";";
+
+ /** Descriptor for <code>double f(double)</code> methods (the primitive class f method). */
+ String D_RETURN_D_DESCRIPTOR = "(D)D";
+
+ /** Descriptor for <code>DifferentialPair f(DifferentialPair)</code> methods (the derivative class f method). */
+ String DP_RETURN_DP_DESCRIPTOR = "(" + DP_DESCRIPTOR + ")" + DP_DESCRIPTOR;
+
+ /** Descriptor for <code>DifferentialPair f(double)</code> methods. */
+ String D_RETURN_DP_DESCRIPTOR = "(D)" + DP_DESCRIPTOR;
+
+ /** Descriptor for <code>double f()</code> methods. */
+ String VOID_RETURN_D_DESCRIPTOR = "()D";
+
+}
Propchange: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/Descriptors.java
------------------------------------------------------------------------------
svn:eol-style = native
Propchange: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/Descriptors.java
------------------------------------------------------------------------------
svn:keywords = Author Date Id Revision
Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/ForwardModeAlgorithmicDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/ForwardModeAlgorithmicDifferentiator.java?rev=805114&r1=805113&r2=805114&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/ForwardModeAlgorithmicDifferentiator.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/ForwardModeAlgorithmicDifferentiator.java Mon Aug 17 19:50:51 2009
@@ -25,6 +25,7 @@
import java.util.HashSet;
import java.util.Set;
+import org.apache.commons.nabla.algorithmic.Descriptors;
import org.apache.commons.nabla.algorithmic.forward.analysis.ForwardModeClassDifferentiator;
import org.apache.commons.nabla.core.DifferentiationException;
import org.apache.commons.nabla.core.UnivariateDerivative;
@@ -32,6 +33,7 @@
import org.apache.commons.nabla.core.UnivariateDifferentiator;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassWriter;
+import org.objectweb.asm.tree.ClassNode;
/** Algorithmic differentiator class in forward mode based on bytecode analysis.
* <p>This class is an implementation of the {@link UnivariateDifferentiator}
@@ -163,19 +165,21 @@
throws DifferentiationException {
try {
- // set up both ends of the class transform chain
+ // get the original class
final String classResourceName = "/" + differentiableClass.getName().replace('.', '/') + ".class";
final InputStream stream = differentiableClass.getResourceAsStream(classResourceName);
final ClassReader reader = new ClassReader(stream);
- final ClassWriter writer = new ClassWriter(reader, ClassWriter.COMPUTE_FRAMES);
// differentiate the function embedded in the differentiable class
- final ForwardModeClassDifferentiator differentiator = new ForwardModeClassDifferentiator(mathClasses, writer);
- reader.accept(differentiator, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);
- differentiator.reportErrors();
+ final ForwardModeClassDifferentiator differentiator =
+ new ForwardModeClassDifferentiator(reader, mathClasses);
+ differentiator.differentiateMethod("f", Descriptors.D_RETURN_D_DESCRIPTOR,
+ Descriptors.DP_RETURN_DP_DESCRIPTOR);
// create the derivative class
- return new DerivativeLoader(differentiableClass).defineClass(differentiator, writer);
+ final ClassNode derived = differentiator.getDerivedClass();
+ final ClassWriter writer = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
+ return new DerivativeLoader(differentiableClass).defineClass(derived, writer);
} catch (IOException ioe) {
throw new DifferentiationException("class {0} cannot be read ({1})",
@@ -194,14 +198,15 @@
}
/** Define a derivative class.
- * @param differentiator class differentiator
+ * @param classNode differentiated class, serialized into the writer by this method
* @param writer class writer
* @return a generated derivative class
*/
@SuppressWarnings("unchecked")
public Class<? extends UnivariateDerivative>
- defineClass(final ForwardModeClassDifferentiator differentiator, final ClassWriter writer) {
- final String name = differentiator.getDerivativeClassName().replace('/', '.');
+ defineClass(final ClassNode classNode, final ClassWriter writer) {
+ final String name = classNode.name.replace('/', '.');
+ classNode.accept(writer);
final byte[] bytecode = writer.toByteArray();
return (Class<? extends UnivariateDerivative>) defineClass(name, bytecode, 0, bytecode.length);
}
Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/ForwardModeClassDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/ForwardModeClassDifferentiator.java?rev=805114&r1=805113&r2=805114&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/ForwardModeClassDifferentiator.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/ForwardModeClassDifferentiator.java Mon Aug 17 19:50:51 2009
@@ -16,37 +16,38 @@
*/
package org.apache.commons.nabla.algorithmic.forward.analysis;
+import java.util.ArrayList;
+import java.util.List;
import java.util.Set;
import org.apache.commons.nabla.core.DifferentiationException;
import org.apache.commons.nabla.core.UnivariateDerivative;
import org.apache.commons.nabla.core.UnivariateDifferentiable;
-import org.objectweb.asm.AnnotationVisitor;
-import org.objectweb.asm.Attribute;
-import org.objectweb.asm.ClassVisitor;
-import org.objectweb.asm.FieldVisitor;
-import org.objectweb.asm.MethodVisitor;
+import org.objectweb.asm.ClassReader;
import org.objectweb.asm.Opcodes;
+import org.objectweb.asm.tree.ClassNode;
+import org.objectweb.asm.tree.FieldNode;
+import org.objectweb.asm.tree.MethodNode;
/**
- * Visitor (in asm sense) for differentiating classes using forward mode.
+ * Differentiator for classes using forward mode.
* <p>
- * This visitor visits classes implementing the
+ * This differentiator transforms classes implementing the
* {@link UnivariateDifferentiable UnivariateDifferentiable} interface and convert
* them to classes implementing the {@link UnivariateDerivative
* UnivariateDerivative} interface.
* </p>
* <p>
- * The visitor creates a new class as an inner class of the visited class.
+ * The differentiator creates a new class as an inner class of the visited class.
* Instances of the generated class are therefore automatically bound to their
* primitive instance which is their directly enclosing instance. As such they
* have access to the current value of all fields.
* </p>
* <p>
- * The visited class bytecode is not changed at all.
+ * The original class bytecode is not changed at all.
* </p>
*/
-public class ForwardModeClassDifferentiator implements ClassVisitor {
+public class ForwardModeClassDifferentiator {
/** Name for the primitive instance field. */
private static final String PRIMITIVE_FIELD = "primitive";
@@ -54,57 +55,42 @@
/** Math implementation classes. */
private final Set<String> mathClasses;
- /** Class generating visitor. */
- private final ClassVisitor generator;
-
- /** Error reporter. */
- private final ErrorReporter errorReporter;
+ /** Class to differentiate. */
+ private final ClassNode classNode;
/** Primitive class name. */
- private String primitiveName;
-
- /** Descriptor for the primitive class. */
- private String primitiveDesc;
+ private final String primitiveName;
- /** Derivative class name. */
- private String derivativeName;
+ /** Primitive class methods. */
+ private final List<MethodNode> primitiveMethods;
- /** Indicator for specific fields and method addition. */
- private boolean specificMembersAdded;
+ /** Descriptor for the primitive class. */
+ private final String primitiveDesc;
/**
* Simple constructor.
+ * @param reader reader for the primitive class
* @param mathClasses math implementation classes
- * @param generator visitor to which class generation calls will be delegated
+ * @exception DifferentiationException if class cannot be differentiated
*/
- public ForwardModeClassDifferentiator(final Set<String> mathClasses,
- final ClassVisitor generator) {
- this.mathClasses = mathClasses;
- this.generator = generator;
- errorReporter = new ErrorReporter();
- }
+ @SuppressWarnings("unchecked")
+ public ForwardModeClassDifferentiator(final ClassReader reader,
+ final Set<String> mathClasses)
+ throws DifferentiationException {
- /**
- * Get the name of the derivative class.
- * @return name of the (generated) derivative class
- */
- public String getDerivativeClassName() {
- return derivativeName;
- }
+ classNode = new ClassNode();
+ reader.accept(classNode, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);
+ this.mathClasses = mathClasses;
- /** {@inheritDoc} */
- public void visit(final int version, final int access,
- final String name, final String signature,
- final String superName, final String[] interfaces) {
- // set up the various names
- primitiveName = name;
- derivativeName = primitiveName + "$NablaForwardModeUnivariateDerivative";
- primitiveDesc = "L" + primitiveName + ";";
+ // store the primitive class properties
+ primitiveName = classNode.name;
+ primitiveDesc = "L" + primitiveName + ";";
+ primitiveMethods = classNode.methods;
// check the UnivariateDifferentiable interface is implemented
final Class<UnivariateDifferentiable> uDerClass = UnivariateDifferentiable.class;
boolean isDifferentiable = false;
- for (String interf : interfaces) {
+ for (String interf : (List<String>) classNode.interfaces) {
final String interfName = interf.replace('/', '.');
Class<?> interfClass = null;
try {
@@ -112,162 +98,106 @@
} catch (ClassNotFoundException cnfe) {
// this should never occur since class has already been loaded
// and an instance already exists ...
- errorReporter.register(new DifferentiationException("interface {0} not found " +
- "while differentiating class {1}",
- interfName, name));
+ throw new DifferentiationException("interface {0} not found " +
+ "while differentiating class {1}",
+ interfName, primitiveName);
}
if (interfClass != null) {
isDifferentiable = isDifferentiable || uDerClass.isAssignableFrom(interfClass);
}
}
- if (isDifferentiable) {
- // generate the new class implementing the UnivariateDerivative interface
- generator.visit(version, Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC,
- derivativeName, signature, superName,
- new String[] {
- UnivariateDerivative.class.getName().replace('.', '/')
- });
- } else {
- errorReporter.register(new DifferentiationException("the {0} class does not implement " +
- "the {1} interface",
- name, uDerClass.getName()));
- }
-
- specificMembersAdded = false;
+ if (!isDifferentiable) {
+ throw new DifferentiationException("the {0} class does not implement the {1} interface",
+ primitiveName, uDerClass.getName());
+ }
+
+ // change the class properties for the derived class
+ classNode.access = Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC;
+ classNode.name = primitiveName + "$NablaForwardModeUnivariateDerivative";
+ classNode.fields = new ArrayList<FieldNode>();
+ classNode.methods = new ArrayList<MethodNode>();
+ classNode.interfaces.clear();
+ classNode.interfaces.add(UnivariateDerivative.class.getName().replace('.', '/'));
+
+ // primitive instance field and methods setting/getting it
+ addPrimitiveField();
+ addConstructor();
+ addGetPrimitive();
}
- /** {@inheritDoc} */
- public MethodVisitor visitMethod(final int access, final String name,
- final String desc, final String signature,
- final String[] exceptions) {
-
- // don't do anything if an error has already been encountered
- if (errorReporter.hasError()) {
- return null;
- }
-
- if (!specificMembersAdded) {
- // add the specific members we need
- addPrimitiveField();
- addConstructor();
- addGetPrimitive();
- specificMembersAdded = true;
- }
-
- // is it the "public double f(double)" method we want to differentiate ?
- if (((access & Opcodes.ACC_PUBLIC) == Opcodes.ACC_PUBLIC) &&
- "f".equals(name) && "(D)D".equals(desc) &&
- ((exceptions == null) || (exceptions.length == 0))) {
-
- // get a generator for the method we are going to create
- final MethodVisitor visitor =
- generator.visitMethod(access | Opcodes.ACC_SYNTHETIC, name,
- MethodDifferentiator.DP_RETURN_DP_DESCRIPTOR, null, null);
-
- // make sure our own differentiator will be used to transform the code
- return new MethodDifferentiator(access, name, desc, signature, exceptions,
- visitor, primitiveName, mathClasses, errorReporter);
+ /**
+ * Differentiate the primitive methods matching a name and descriptor (does nothing if none match).
+ * @param name name of the method
+ * @param primitiveDesc descriptor of the method in the primitive class
+ * @param derivativeDesc descriptor of the method in the derivative class
+ * @exception DifferentiationException if method cannot be differentiated
+ */
+ @SuppressWarnings("unchecked")
+ public void differentiateMethod(final String name, final String primitiveDesc,
+ final String derivativeDesc)
+ throws DifferentiationException {
+
+ for (final MethodNode method : primitiveMethods) {
+ if (method.name.equals(name) && method.desc.equals(primitiveDesc)) {
+
+ final MethodDifferentiator differentiator = new MethodDifferentiator(mathClasses);
+ differentiator.differentiate(primitiveName, method);
+ classNode.methods.add(method);
+ }
}
-
- // we are not interested in this method
- return null;
-
- }
-
- /** {@inheritDoc} */
- public FieldVisitor visitField(final int access, final String name,
- final String desc, final String signature,
- final Object value) {
- // we are not interested in any fields
- return null;
- }
-
- /** {@inheritDoc} */
- public void visitSource(final String source, final String debug) {
- }
-
- /** {@inheritDoc} */
- public void visitOuterClass(final String owner, final String name,
- final String desc) {
- }
-
- /** {@inheritDoc} */
- public AnnotationVisitor visitAnnotation(final String desc,
- final boolean visible) {
- return null;
- }
-
- /** {@inheritDoc} */
- public void visitAttribute(final Attribute attr) {
- }
-
- /** {@inheritDoc} */
- public void visitInnerClass(final String name, final String outerName,
- final String innerName, final int access) {
}
- /** {@inheritDoc} */
- public void visitEnd() {
-
- // don't do anything if an error has already been encountered
- if (errorReporter.hasError()) {
- return;
- }
-
- generator.visitEnd();
-
+ /**
+ * Get the derived (derivative) class built by this differentiator.
+ * @return derived class node (the internal instance, not a copy)
+ */
+ public ClassNode getDerivedClass() {
+ return classNode;
}
/** Add the primitive field.
*/
+ @SuppressWarnings("unchecked")
private void addPrimitiveField() {
- final FieldVisitor visitor =
- generator.visitField(Opcodes.ACC_PRIVATE | Opcodes.ACC_FINAL | Opcodes.ACC_SYNTHETIC,
- PRIMITIVE_FIELD, primitiveDesc, null, null);
- visitor.visitEnd();
+ FieldNode primitiveField =
+ new FieldNode(Opcodes.ACC_PRIVATE | Opcodes.ACC_FINAL | Opcodes.ACC_SYNTHETIC,
+ PRIMITIVE_FIELD, primitiveDesc, null, null);
+ classNode.fields.add(primitiveField);
}
/** Add the class constructor.
*/
+ @SuppressWarnings("unchecked")
private void addConstructor() {
final String init = "<init>";
- final MethodVisitor visitor =
- generator.visitMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, init,
- "(" + primitiveDesc + ")V", null, null);
- visitor.visitCode();
- visitor.visitVarInsn(Opcodes.ALOAD, 0);
- visitor.visitMethodInsn(Opcodes.INVOKESPECIAL, "java/lang/Object", init, "()V");
- visitor.visitVarInsn(Opcodes.ALOAD, 0);
- visitor.visitVarInsn(Opcodes.ALOAD, 1);
- visitor.visitFieldInsn(Opcodes.PUTFIELD, derivativeName, PRIMITIVE_FIELD, primitiveDesc);
- visitor.visitInsn(Opcodes.RETURN);
- visitor.visitMaxs(0, 0);
- visitor.visitEnd();
+ final MethodNode constructor =
+ new MethodNode(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, init,
+ "(" + primitiveDesc + ")V", null, null);
+ constructor.visitVarInsn(Opcodes.ALOAD, 0);
+ constructor.visitMethodInsn(Opcodes.INVOKESPECIAL, "java/lang/Object", init, "()V");
+ constructor.visitVarInsn(Opcodes.ALOAD, 0);
+ constructor.visitVarInsn(Opcodes.ALOAD, 1);
+ constructor.visitFieldInsn(Opcodes.PUTFIELD, classNode.name, PRIMITIVE_FIELD, primitiveDesc);
+ constructor.visitInsn(Opcodes.RETURN);
+ constructor.visitMaxs(0, 0);
+ classNode.methods.add(constructor);
}
/** Add the {@link UnivariateDerivative#getPrimitive() getPrimitive()} method.
*/
+ @SuppressWarnings("unchecked")
private void addGetPrimitive() {
- final MethodVisitor visitor =
- generator.visitMethod(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, "getPrimitive",
- "()" + primitiveDesc, null, null);
- visitor.visitCode();
- visitor.visitVarInsn(Opcodes.ALOAD, 0);
- visitor.visitFieldInsn(Opcodes.GETFIELD, derivativeName, PRIMITIVE_FIELD, primitiveDesc);
- visitor.visitInsn(Opcodes.ARETURN);
- visitor.visitMaxs(0, 0);
- visitor.visitEnd();
- }
-
- /** Report the errors that may have occurred during analysis.
- * @exception DifferentiationException if the derivative class
- * could not be generated
- */
- public void reportErrors() throws DifferentiationException {
- errorReporter.reportErrors();
- }
+ final MethodNode method =
+ new MethodNode(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, "getPrimitive",
+ "()" + primitiveDesc, null, null);
+ method.visitVarInsn(Opcodes.ALOAD, 0);
+ method.visitFieldInsn(Opcodes.GETFIELD, classNode.name, PRIMITIVE_FIELD, primitiveDesc);
+ method.visitInsn(Opcodes.ARETURN);
+ method.visitMaxs(0, 0);
+ classNode.methods.add(method);
+ }
}
Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/MethodDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/MethodDifferentiator.java?rev=805114&r1=805113&r2=805114&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/MethodDifferentiator.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/analysis/MethodDifferentiator.java Mon Aug 17 19:50:51 2009
@@ -25,6 +25,7 @@
import java.util.Map;
import java.util.Set;
+import org.apache.commons.nabla.algorithmic.Descriptors;
import org.apache.commons.nabla.algorithmic.forward.arithmetic.DAddTransformer1;
import org.apache.commons.nabla.algorithmic.forward.arithmetic.DAddTransformer12;
import org.apache.commons.nabla.algorithmic.forward.arithmetic.DAddTransformer2;
@@ -86,9 +87,7 @@
import org.apache.commons.nabla.algorithmic.forward.trimming.DLoadPop2Trimmer;
import org.apache.commons.nabla.algorithmic.forward.trimming.SwappedDloadTrimmer;
import org.apache.commons.nabla.algorithmic.forward.trimming.SwappedDstoreTrimmer;
-import org.apache.commons.nabla.core.DifferentialPair;
import org.apache.commons.nabla.core.DifferentiationException;
-import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.tree.AbstractInsnNode;
import org.objectweb.asm.tree.IincInsnNode;
@@ -107,22 +106,7 @@
/** Class transforming a method computing a value to a method
* computing both a value and its differential.
*/
-public class MethodDifferentiator extends MethodNode {
-
- /** Name for the DifferentialPair class. */
- public static final String DP_NAME = DifferentialPair.class.getName().replace('.', '/');
-
- /** Descriptor for the DifferentialPair class. */
- public static final String DP_DESCRIPTOR = "L" + DP_NAME + ";";
-
- /** Descriptor for the derivative class f method. */
- public static final String DP_RETURN_DP_DESCRIPTOR = "(" + DP_DESCRIPTOR + ")" + DP_DESCRIPTOR;
-
- /** Descriptor for <code>DifferentialPair f(double)</code> methods. */
- public static final String D_RETURN_DP_DESCRIPTOR = "(D)" + DP_DESCRIPTOR;
-
- /** Descriptor for <code>double f()</code> methods. */
- private static final String VOID_RETURN_D_DESCRIPTOR = "()D";
+public class MethodDifferentiator {
/** Math functions transformer. */
private static final Map<String, MathInvocationTransformer> MATH_TRANSFORMERS =
@@ -168,18 +152,9 @@
/** Math implementation classes. */
private final Set<String> mathClasses;
- /** Generator to use. */
- private final MethodVisitor generator;
-
/** Used locals variables array. */
private boolean[] usedLocals;
- /** Primitive class name. */
- private final String primitiveName;
-
- /** Error reporter to use. */
- private final ErrorReporter errorReporter;
-
/** Set of converted values. */
private final Set<TrackingValue> converted;
@@ -193,110 +168,95 @@
private final Map<LabelNode, LabelNode> clonedLabels;
/** Build a differentiator for a method.
- * @param access access flags of the method
- * @param name name of the method
- * @param desc descriptor of the method
- * @param signature signature of the method
- * @param exceptions exceptions thrown by the method
- * @param generator bytecode generator to use for the transformed method
- * @param primitiveName primitive class name
* @param mathClasses math implementation classes
- * @param errorReporter reporter used for delaying exceptions
*/
- public MethodDifferentiator(final int access, final String name, final String desc,
- final String signature, final String[] exceptions,
- final MethodVisitor generator,final String primitiveName,
- final Set<String> mathClasses,
- final ErrorReporter errorReporter) {
-
- super(access, name, desc, signature, exceptions);
- this.generator = generator;
+ public MethodDifferentiator(final Set<String> mathClasses) {
this.usedLocals = null;
- this.primitiveName = primitiveName;
this.mathClasses = mathClasses;
- this.errorReporter = errorReporter;
this.converted = new HashSet<TrackingValue>();
this.frames = new IdentityHashMap<AbstractInsnNode, Frame>();
this.successors = new IdentityHashMap<AbstractInsnNode, Set<AbstractInsnNode>>();
this.clonedLabels = new HashMap<LabelNode, LabelNode>();
-
}
- /** {@inheritDoc} */
- @Override
- public void visitEnd() {
+ /**
+ * Differentiate a method.
+ * @param primitiveName primitive class name
+ * @param method method to differentiate (<em>will</em> be modified)
+ * @exception DifferentiationException if method cannot be differentiated
+ */
+ public void differentiate(final String primitiveName, final MethodNode method)
+ throws DifferentiationException {
try {
// at start, "this" and one differential pair are already used
- maxLocals = 2 * (maxLocals + MAX_TEMP) - 1;
- usedLocals = new boolean[maxLocals];
+ method.maxLocals = 2 * (method.maxLocals + MAX_TEMP) - 1;
+ usedLocals = new boolean[method.maxLocals];
useLocal(0, 1);
useLocal(1, 4);
// add spare cells to hold new variables if needed
- addSpareLocalVariables();
+ addSpareLocalVariables(method.instructions);
// analyze the original code, tracing values production/consumption
- final Frame[] array =
- new FlowAnalyzer(new TrackingInterpreter()).analyze(primitiveName, this);
+ final FlowAnalyzer analyzer =
+ new FlowAnalyzer(new TrackingInterpreter(), method.instructions);
+ final Frame[] array = analyzer.analyze(primitiveName, method);
// convert the array into a map, since code changes will shift all indices
for (int i = 0; i < array.length; ++i) {
- frames.put(instructions.get(i), array[i]);
+ frames.put(method.instructions.get(i), array[i]);
}
// identify the needed changes
- final Set<AbstractInsnNode> changes = identifyChanges();
+ final Set<AbstractInsnNode> changes = identifyChanges(method.instructions);
if (changes.isEmpty()) {
// the method does not depend on the parameter at all!
// we replace all "return d;" by "return DifferentialPair.newConstant(d);"
- for (final Iterator<?> i = instructions.iterator(); i.hasNext();) {
+ for (final Iterator<?> i = method.instructions.iterator(); i.hasNext();) {
final AbstractInsnNode insn = (AbstractInsnNode) i.next();
if (insn.getOpcode() == Opcodes.DRETURN) {
final InsnList list = new InsnList();
list.add(new MethodInsnNode(Opcodes.INVOKESTATIC,
- MethodDifferentiator.DP_NAME,
- "newConstant", D_RETURN_DP_DESCRIPTOR));
+ Descriptors.DP_NAME,
+ "newConstant",
+ Descriptors.D_RETURN_DP_DESCRIPTOR));
list.add(new InsnNode(Opcodes.ARETURN));
- instructions.insert(insn, list);
- instructions.remove(insn);
+ method.instructions.insert(insn, list);
+ method.instructions.remove(insn);
}
}
} else {
// perform the code changes
- changeCode(changes);
-
- // remove the local variables added at the beginning and not used
- removeUnusedSpareLocalVariables();
+ changeCode(method.instructions, changes);
// trim generated instructions list
- SwappedDloadTrimmer.getInstance().trim(instructions);
- SwappedDstoreTrimmer.getInstance().trim(instructions);
- DLoadPop2Trimmer.getInstance().trim(instructions);
+ SwappedDloadTrimmer.getInstance().trim(method.instructions);
+ SwappedDstoreTrimmer.getInstance().trim(method.instructions);
+ DLoadPop2Trimmer.getInstance().trim(method.instructions);
}
- // change the descriptor to its true final value
- desc = DP_RETURN_DP_DESCRIPTOR;
+ // remove the local variables added at the beginning and not used
+ removeUnusedSpareLocalVariables(method.instructions);
- // generate the method
- accept(generator);
+ // change the method properties to the derivative ones
+ method.desc = Descriptors.DP_RETURN_DP_DESCRIPTOR;
+ method.access |= Opcodes.ACC_SYNTHETIC;
+ method.maxLocals = maxVariables();
} catch (AnalyzerException ae) {
+ ae.printStackTrace(System.err);
if ((ae.getCause() != null) && ae.getCause() instanceof DifferentiationException) {
- errorReporter.register((DifferentiationException) ae.getCause());
+ throw (DifferentiationException) ae.getCause();
} else {
- final DifferentiationException de =
- new DifferentiationException("unable to analyze the {0}.{1} method ({2})",
- primitiveName, name, ae.getMessage());
- errorReporter.register(de);
+ throw new DifferentiationException("unable to analyze the {0}.{1} method ({2})",
+ primitiveName, method.name, ae.getMessage());
}
- } catch (DifferentiationException de) {
- errorReporter.register(de);
}
}
@@ -309,11 +269,13 @@
* be referenced by the converted instructions in the following passes.</p>
* <p>The spare cells that will not be used will be reclaimed after
* conversion, to avoid wasting memory.</p>
+ * @param instructions instructions of the method
* @exception DifferentiationException if local variables array has not been
* expanded appropriately beforehand
* @see #removeUnusedSpareLocalVariables()
*/
- private void addSpareLocalVariables() throws DifferentiationException {
+ private void addSpareLocalVariables(final InsnList instructions)
+ throws DifferentiationException {
for (final Iterator<?> i = instructions.iterator(); i.hasNext();) {
final AbstractInsnNode insn = (AbstractInsnNode) i.next();
if (insn.getType() == AbstractInsnNode.VAR_INSN) {
@@ -340,9 +302,10 @@
}
/** Remove the unused spare cells introduced at conversion start.
+ * @param instructions instructions of the method
* @see #addSpareLocalVariables()
*/
- private void removeUnusedSpareLocalVariables() {
+ private void removeUnusedSpareLocalVariables(final InsnList instructions) {
for (final Iterator<?> i = instructions.iterator(); i.hasNext();) {
final AbstractInsnNode insn = (AbstractInsnNode) i.next();
if (insn.getType() == AbstractInsnNode.VAR_INSN) {
@@ -358,9 +321,10 @@
* instructions path, updating stack cells and local variables as needed.
* Instructions that must be changed are the ones that consume changed
* variables or stack cells.</p>
+ * @param instructions instructions of the method
* @return set containing all the instructions that must be changed
*/
- private Set<AbstractInsnNode> identifyChanges() {
+ private Set<AbstractInsnNode> identifyChanges(final InsnList instructions) {
// the pending set contains the values (local variables or stack cells)
// that have been changed, they will trigger changes on the instructions
@@ -461,21 +425,22 @@
}
/** Perform the code changes.
+ * @param instructions instructions of the method
* @param changes instructions that must be changed
* @exception DifferentiationException if some instruction cannot be handled
*/
- private void changeCode(final Set<AbstractInsnNode> changes)
+ private void changeCode(final InsnList instructions, final Set<AbstractInsnNode> changes)
throws DifferentiationException {
// insert the parameter conversion code at method start
final InsnList list = new InsnList();
list.add(new VarInsnNode(Opcodes.ALOAD, 1));
list.add(new InsnNode(Opcodes.DUP));
- list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DP_NAME,
- "getValue", VOID_RETURN_D_DESCRIPTOR));
+ list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, Descriptors.DP_NAME,
+ "getValue", Descriptors.VOID_RETURN_D_DESCRIPTOR));
list.add(new VarInsnNode(Opcodes.DSTORE, 1));
- list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, DP_NAME,
- "getFirstDerivative", VOID_RETURN_D_DESCRIPTOR));
+ list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, Descriptors.DP_NAME,
+ "getFirstDerivative", Descriptors.VOID_RETURN_D_DESCRIPTOR));
list.add(new VarInsnNode(Opcodes.DSTORE, 3));
instructions.insertBefore(instructions.get(0), list);
@@ -688,7 +653,7 @@
* @param name name of the class to test
* @return true if the named class is a math implementation class
*/
- public boolean isMathImplementationClass(final String name) {
+ private boolean isMathImplementationClass(final String name) {
return mathClasses.contains(name);
}
@@ -744,7 +709,7 @@
/** Shifted the index of a variable instruction.
* @param insn variable instruction
*/
- public void shiftVariable(final VarInsnNode insn) {
+ private void shiftVariable(final VarInsnNode insn) {
int shifted = 0;
for (int i = 0; i < insn.var; ++i) {
if (usedLocals[i]) {
@@ -754,6 +719,19 @@
insn.var = shifted;
}
+ /** Count the local variable slots flagged as used in usedLocals.
+ * @return number of true entries in usedLocals, i.e. the locals count left once shiftVariable has compacted the indices
+ */
+ private int maxVariables() {
+ int max = 0;
+ for (final boolean isUsed : usedLocals) {
+ if (isUsed) {
+ ++max;
+ }
+ }
+ return max;
+ }
+
/** Clone an instruction.
* @param insn instruction to clone
* @return cloned instruction
@@ -765,11 +743,17 @@
/** Analyzer preserving instructions successors information. */
private class FlowAnalyzer extends Analyzer {
+ /** Instructions of the method. */
+ private final InsnList instructions;
+
/** Simple constructor.
* @param interpreter associated interpreter
+ * @param instructions instructions of the method
*/
- public FlowAnalyzer(final Interpreter interpreter) {
+ public FlowAnalyzer(final Interpreter interpreter,
+ final InsnList instructions) {
super(interpreter);
+ this.instructions = instructions;
}
/** Store a new edge.
Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/instructions/DReturnTransformer.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/instructions/DReturnTransformer.java?rev=805114&r1=805113&r2=805114&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/instructions/DReturnTransformer.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/algorithmic/forward/instructions/DReturnTransformer.java Mon Aug 17 19:50:51 2009
@@ -16,6 +16,7 @@
*/
package org.apache.commons.nabla.algorithmic.forward.instructions;
+import org.apache.commons.nabla.algorithmic.Descriptors;
import org.apache.commons.nabla.algorithmic.forward.analysis.InstructionsTransformer;
import org.apache.commons.nabla.algorithmic.forward.analysis.MethodDifferentiator;
import org.apache.commons.nabla.core.DifferentiationException;
@@ -68,17 +69,16 @@
final InsnList list = new InsnList();
// operand stack initial state: a0, a1
- list.add(new VarInsnNode(Opcodes.DSTORE, 3)); // => a0
- list.add(new VarInsnNode(Opcodes.DSTORE, 1)); // =>
- list.add(new TypeInsnNode(Opcodes.NEW,
- MethodDifferentiator.DP_NAME)); // => o,
- list.add(new InsnNode(Opcodes.DUP)); // => o, o
- list.add(new VarInsnNode(Opcodes.DLOAD, 1)); // => o, o, a0
- list.add(new VarInsnNode(Opcodes.DLOAD, 3)); // => o, o, a0, a1
+ list.add(new VarInsnNode(Opcodes.DSTORE, 3)); // => a0
+ list.add(new VarInsnNode(Opcodes.DSTORE, 1)); // =>
+ list.add(new TypeInsnNode(Opcodes.NEW, Descriptors.DP_NAME)); // => o,
+ list.add(new InsnNode(Opcodes.DUP)); // => o, o
+ list.add(new VarInsnNode(Opcodes.DLOAD, 1)); // => o, o, a0
+ list.add(new VarInsnNode(Opcodes.DLOAD, 3)); // => o, o, a0, a1
list.add(new MethodInsnNode(Opcodes.INVOKESPECIAL,
- MethodDifferentiator.DP_NAME,
- "<init>", "(DD)V")); // => dp
- list.add(new InsnNode(Opcodes.ARETURN)); // =>
+ Descriptors.DP_NAME,
+ "<init>", "(DD)V")); // => dp
+ list.add(new InsnNode(Opcodes.ARETURN)); // =>
return list;
}