You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by lu...@apache.org on 2012/10/03 21:18:39 UTC
svn commit: r1393700 - in /commons/sandbox/nabla/trunk/src:
main/java/org/apache/commons/nabla/forward/
main/java/org/apache/commons/nabla/forward/analysis/
main/java/org/apache/commons/nabla/forward/instructions/
test/java/org/apache/commons/nabla/for...
Author: luc
Date: Wed Oct 3 19:18:38 2012
New Revision: 1393700
URL: http://svn.apache.org/viewvc?rev=1393700&view=rev
Log:
Use a normal base class instead of creating boilerplate code manually.
The fact that the differentiated class will now extend a base class provided
by Nabla implies there will remain a dependency on Nabla at runtime,
even if we set up class persistence and save the classes in a jar. So
the base class NablaDifferentiated should probably be packaged
separately with maven as a nabla-runtime jar (which will contain only
this class).
Added:
commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/NablaDifferentiated.java (with props)
Modified:
commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java
commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java
commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/GetTransformer.java
commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java
Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java?rev=1393700&r1=1393699&r2=1393700&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/ForwardModeDifferentiator.java Wed Oct 3 19:18:38 2012
@@ -26,7 +26,6 @@ import java.util.Set;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
-import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
import org.apache.commons.math3.analysis.differentiation.UnivariateFunctionDifferentiator;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.nabla.DifferentiationException;
@@ -57,8 +56,7 @@ import org.objectweb.asm.tree.ClassNode;
public class ForwardModeDifferentiator implements UnivariateFunctionDifferentiator {
/** UnivariateFunction/UnivariateDifferentiableFunction map. */
- private final HashMap<Class<? extends UnivariateFunction>,
- Class<? extends UnivariateDifferentiableFunction>> map;
+ private final HashMap<Class<? extends UnivariateFunction>, Class<? extends NablaDifferentiated>> map;
/** Class name/ bytecode map. */
private final HashMap<String, byte[]> byteCodeMap;
@@ -70,8 +68,7 @@ public class ForwardModeDifferentiator i
* <p>Build a ForwardAlgorithmicDifferentiator instance with an empty cache.</p>
*/
public ForwardModeDifferentiator() {
- map = new HashMap<Class<? extends UnivariateFunction>,
- Class<? extends UnivariateDifferentiableFunction>>();
+ map = new HashMap<Class<? extends UnivariateFunction>, Class<? extends NablaDifferentiated>>();
byteCodeMap = new HashMap<String, byte[]>();
mathClasses = new HashSet<String>();
addMathImplementation(Math.class);
@@ -101,16 +98,15 @@ public class ForwardModeDifferentiator i
}
/** {@inheritDoc} */
- public UnivariateDifferentiableFunction differentiate(final UnivariateFunction d) {
+ public NablaDifferentiated differentiate(final UnivariateFunction d) {
// get the derivative class
- final Class<? extends UnivariateDifferentiableFunction> derivativeClass =
- getDerivativeClass(d.getClass());
+ final Class<? extends NablaDifferentiated> derivativeClass = getDerivativeClass(d.getClass());
try {
// create the instance
- final Constructor<? extends UnivariateDifferentiableFunction> constructor =
+ final Constructor<? extends NablaDifferentiated> constructor =
derivativeClass.getConstructor(d.getClass());
return constructor.newInstance(d);
@@ -140,13 +136,12 @@ public class ForwardModeDifferentiator i
* @return derivative class
* @throws DifferentiationException if the class cannot be differentiated
*/
- private Class<? extends UnivariateDifferentiableFunction>
+ private Class<? extends NablaDifferentiated>
getDerivativeClass(final Class<? extends UnivariateFunction> differentiableClass)
throws DifferentiationException {
// lookup in the map if the class has already been differentiated
- Class<? extends UnivariateDifferentiableFunction> derivativeClass =
- map.get(differentiableClass);
+ Class<? extends NablaDifferentiated> derivativeClass = map.get(differentiableClass);
// build the derivative class if it does not exist yet
if (derivativeClass == null) {
@@ -169,7 +164,7 @@ public class ForwardModeDifferentiator i
* @return derivative class
* @throws DifferentiationException if the class cannot be differentiated
*/
- private Class<? extends UnivariateDifferentiableFunction>
+ private Class<? extends NablaDifferentiated>
createDerivativeClass(final Class<? extends UnivariateFunction> differentiableClass)
throws DifferentiationException {
try {
@@ -189,7 +184,7 @@ public class ForwardModeDifferentiator i
derived.accept(writer);
final byte[] bytecode = writer.toByteArray();
- final Class<? extends UnivariateDifferentiableFunction> dClass =
+ final Class<? extends NablaDifferentiated> dClass =
new DerivativeLoader(differentiableClass).defineClass(name, bytecode);
byteCodeMap.put(name, bytecode);
return dClass;
@@ -216,9 +211,9 @@ public class ForwardModeDifferentiator i
* @return a generated derivative class
*/
@SuppressWarnings("unchecked")
- public Class<? extends UnivariateDifferentiableFunction>
+ public Class<? extends NablaDifferentiated>
defineClass(final String name, final byte[] bytecode) {
- return (Class<? extends UnivariateDifferentiableFunction>) defineClass(name, bytecode, 0, bytecode.length);
+ return (Class<? extends NablaDifferentiated>) defineClass(name, bytecode, 0, bytecode.length);
}
}
Added: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/NablaDifferentiated.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/NablaDifferentiated.java?rev=1393700&view=auto
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/NablaDifferentiated.java (added)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/NablaDifferentiated.java Wed Oct 3 19:18:38 2012
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.nabla.forward;
+
+import java.lang.reflect.Field;
+
+import org.apache.commons.math3.analysis.UnivariateFunction;
+import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
+import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
+
+/** Base class for the derivative classes generated by Nabla in forward mode.
+ * @version $Id$
+ */
+public abstract class NablaDifferentiated implements UnivariateDifferentiableFunction {
+
+ /** Primitive function wrapped by this derivative. */
+ private final UnivariateFunction primitive;
+
+ /** Build an instance wrapping the primitive function.
+ * @param primitive primitive instance (stored as is, not copied)
+ */
+ protected NablaDifferentiated(final UnivariateFunction primitive) {
+ this.primitive = primitive;
+ }
+
+ /** Get (reflectively, bypassing access checks) a field of the primitive instance.
+ * @param name name of a field declared by the primitive's runtime class (inherited fields are not searched)
+ * @return field value, boxed as an object
+ * @throws RuntimeException wrapping any NoSuchFieldException or IllegalAccessException */
+ protected Object getPrimitiveField(final String name) {
+ try {
+ final Field field = primitive.getClass().getDeclaredField(name);
+ field.setAccessible(true);
+ return field.get(primitive);
+ } catch (NoSuchFieldException nsfe) {
+ throw new RuntimeException(nsfe);
+ } catch (IllegalAccessException iae) {
+ throw new RuntimeException(iae);
+ }
+ }
+
+ /** Get (reflectively, bypassing access checks) a static field of the primitive class.
+ * @param primitiveClass class of the primitive
+ * @param name name of a static field declared by primitiveClass (inherited fields are not searched)
+ * @return field value, boxed as an object
+ * @throws RuntimeException wrapping any NoSuchFieldException or IllegalAccessException */
+ protected static Object getPrimitiveStaticField(final Class<? extends UnivariateFunction> primitiveClass,
+ final String name) {
+ try {
+ final Field field = primitiveClass.getDeclaredField(name);
+ field.setAccessible(true);
+ return field.get(null);
+ } catch (NoSuchFieldException nsfe) {
+ throw new RuntimeException(nsfe);
+ } catch (IllegalAccessException iae) {
+ throw new RuntimeException(iae);
+ }
+ }
+
+ /** {@inheritDoc} The plain value is simply delegated to the primitive function. */
+ public double value(double x) {
+ return primitive.value(x);
+ }
+
+ /** {@inheritDoc} Left abstract: the generated derivative subclass implements it. */
+ public abstract DerivativeStructure value(DerivativeStructure t);
+
+}
Propchange: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/NablaDifferentiated.java
------------------------------------------------------------------------------
svn:eol-style = native
Propchange: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/NablaDifferentiated.java
------------------------------------------------------------------------------
svn:keywords = "Author Date Id Revision"
Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java?rev=1393700&r1=1393699&r2=1393700&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/ClassDifferentiator.java Wed Oct 3 19:18:38 2012
@@ -17,19 +17,17 @@
package org.apache.commons.nabla.forward.analysis;
import java.io.IOException;
-import java.lang.reflect.Field;
import java.util.Set;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.analysis.differentiation.UnivariateDifferentiableFunction;
import org.apache.commons.nabla.DifferentiationException;
import org.apache.commons.nabla.NablaMessages;
+import org.apache.commons.nabla.forward.NablaDifferentiated;
import org.objectweb.asm.ClassReader;
-import org.objectweb.asm.Label;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.ClassNode;
-import org.objectweb.asm.tree.FieldNode;
import org.objectweb.asm.tree.MethodNode;
/**
@@ -118,16 +116,23 @@ public class ClassDifferentiator {
// change the class properties for the derived class
classNode.visit(primitiveNode.version, Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC,
primitiveNode.name + "_NablaForwardModeUnivariateDerivative",
- null, Type.getType(Object.class).getInternalName(),
+ null, Type.getType(NablaDifferentiated.class).getInternalName(),
new String[] {
Type.getType(UnivariateDifferentiableFunction.class).getInternalName()
});
- // add boilerplate code
- addPrimitiveField();
- addConstructor();
- addGetPrimitiveFieldMethod();
- addGetPrimitiveStaticFieldMethod();
+ // add constructor calling NablaDifferentiated superclass constructor
+ final MethodNode constructor =
+ new MethodNode(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, INIT,
+ Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(primitiveClass)),
+ null, null);
+ constructor.visitVarInsn(Opcodes.ALOAD, 0);
+ constructor.visitVarInsn(Opcodes.ALOAD, 1);
+ constructor.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(NablaDifferentiated.class),
+ INIT, Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(UnivariateFunction.class)));
+ constructor.visitInsn(Opcodes.RETURN);
+ constructor.visitMaxs(2, 2);
+ classNode.methods.add(constructor);
}
@@ -175,114 +180,4 @@ public class ClassDifferentiator {
return classNode;
}
- /** Add the primitive field.
- */
- private void addPrimitiveField() {
- final FieldNode primitiveField =
- new FieldNode(Opcodes.ACC_PRIVATE | Opcodes.ACC_FINAL | Opcodes.ACC_SYNTHETIC,
- PRIMITIVE_FIELD, Type.getDescriptor(primitiveClass), null, null);
- classNode.fields.add(primitiveField);
- }
-
- /** Add the class constructor.
- */
- private void addConstructor() {
- final MethodNode constructor =
- new MethodNode(Opcodes.ACC_PUBLIC | Opcodes.ACC_SYNTHETIC, INIT,
- Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(primitiveClass)),
- null, null);
- constructor.visitVarInsn(Opcodes.ALOAD, 0);
- constructor.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getType(Object.class).getInternalName(),
- INIT, Type.getMethodDescriptor(Type.VOID_TYPE));
- constructor.visitVarInsn(Opcodes.ALOAD, 0);
- constructor.visitVarInsn(Opcodes.ALOAD, 1);
- constructor.visitFieldInsn(Opcodes.PUTFIELD, classNode.name, PRIMITIVE_FIELD,
- Type.getDescriptor(primitiveClass));
- constructor.visitInsn(Opcodes.RETURN);
- constructor.visitMaxs(0, 0);
- classNode.methods.add(constructor);
- }
-
- /** Add the getPrimitiveField method.
- */
- private void addGetPrimitiveFieldMethod() {
- final MethodNode method =
- new MethodNode(Opcodes.ACC_PRIVATE | Opcodes.ACC_SYNTHETIC, "getPrimitiveField",
- Type.getMethodDescriptor(Type.getType(Object.class), Type.getType(String.class)),
- null, null);
- final Label start = new Label();
- final Label end = new Label();
- method.visitTryCatchBlock(start, end, end, Type.getInternalName(IllegalAccessException.class));
- method.visitTryCatchBlock(start, end, end, Type.getInternalName(NoSuchFieldException.class));
- method.visitLabel(start);
- method.visitLdcInsn(Type.getType(primitiveClass));
- method.visitVarInsn(Opcodes.ALOAD, 1);
- method.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Class.class),
- "getDeclaredField",
- Type.getMethodDescriptor(Type.getType(Field.class), Type.getType(String.class)));
- method.visitVarInsn(Opcodes.ASTORE, 2);
- method.visitVarInsn(Opcodes.ALOAD, 2);
- method.visitInsn(Opcodes.ICONST_1);
- method.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Field.class),
- "setAccessible",
- Type.getMethodDescriptor(Type.VOID_TYPE, Type.BOOLEAN_TYPE));
- method.visitVarInsn(Opcodes.ALOAD, 2);
- method.visitVarInsn(Opcodes.ALOAD, 0);
- method.visitFieldInsn(Opcodes.GETFIELD, classNode.name, PRIMITIVE_FIELD,
- Type.getDescriptor(primitiveClass));
- method.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Field.class),
- "get",
- Type.getMethodDescriptor(Type.getType(Object.class), Type.getType(Object.class)));
- method.visitInsn(Opcodes.ARETURN);
- method.visitLabel(end);
- method.visitVarInsn(Opcodes.ASTORE, 2);
- method.visitTypeInsn(Opcodes.NEW, Type.getInternalName(RuntimeException.class));
- method.visitInsn(Opcodes.DUP);
- method.visitVarInsn(Opcodes.ALOAD, 2);
- method.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(RuntimeException.class),
- INIT, Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(Throwable.class)));
- method.visitInsn(Opcodes.ATHROW);
- classNode.methods.add(method);
- }
-
- /** Add the getPrimitiveStaticField method.
- */
- private void addGetPrimitiveStaticFieldMethod() {
- final MethodNode method =
- new MethodNode(Opcodes.ACC_STATIC | Opcodes.ACC_PRIVATE | Opcodes.ACC_SYNTHETIC, "getPrimitiveStaticField",
- Type.getMethodDescriptor(Type.getType(Object.class), Type.getType(String.class)),
- null, null);
- final Label start = new Label();
- final Label end = new Label();
- method.visitTryCatchBlock(start, end, end, Type.getInternalName(IllegalAccessException.class));
- method.visitTryCatchBlock(start, end, end, Type.getInternalName(NoSuchFieldException.class));
- method.visitLabel(start);
- method.visitLdcInsn(Type.getType(primitiveClass));
- method.visitVarInsn(Opcodes.ALOAD, 0);
- method.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Class.class),
- "getDeclaredField",
- Type.getMethodDescriptor(Type.getType(Field.class), Type.getType(String.class)));
- method.visitVarInsn(Opcodes.ASTORE, 1);
- method.visitVarInsn(Opcodes.ALOAD, 1);
- method.visitInsn(Opcodes.ICONST_1);
- method.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Field.class),
- "setAccessible",
- Type.getMethodDescriptor(Type.VOID_TYPE, Type.BOOLEAN_TYPE));
- method.visitVarInsn(Opcodes.ALOAD, 1);
- method.visitInsn(Opcodes.ACONST_NULL);
- method.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Field.class),
- "get",
- Type.getMethodDescriptor(Type.getType(Object.class), Type.getType(Object.class)));
- method.visitInsn(Opcodes.ARETURN);
- method.visitLabel(end);
- method.visitVarInsn(Opcodes.ASTORE, 2);
- method.visitTypeInsn(Opcodes.NEW, Type.getInternalName(RuntimeException.class));
- method.visitInsn(Opcodes.DUP);
- method.visitVarInsn(Opcodes.ALOAD, 2);
- method.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(RuntimeException.class),
- INIT, Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(Throwable.class)));
- method.visitInsn(Opcodes.ATHROW);
- classNode.methods.add(method);
- }
-
}
Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/GetTransformer.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/GetTransformer.java?rev=1393700&r1=1393699&r2=1393700&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/GetTransformer.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/GetTransformer.java Wed Oct 3 19:18:38 2012
@@ -18,6 +18,7 @@ package org.apache.commons.nabla.forward
import org.apache.commons.nabla.DifferentiationException;
import org.apache.commons.nabla.NablaMessages;
+import org.apache.commons.nabla.forward.NablaDifferentiated;
import org.apache.commons.nabla.forward.analysis.InstructionsTransformer;
import org.apache.commons.nabla.forward.analysis.MethodDifferentiator;
import org.objectweb.asm.Opcodes;
@@ -52,18 +53,21 @@ public class GetTransformer implements I
final InsnList list = new InsnList();
// get the field as an object
- list.add(new LdcInsnNode(fieldInsn.name));
if (insn.getOpcode() == Opcodes.GETFIELD) {
// GETFIELD case
- list.add(new MethodInsnNode(Opcodes.INVOKESPECIAL, methodDifferentiator.getDerivedName(),
+ list.add(new LdcInsnNode(fieldInsn.name));
+ list.add(new MethodInsnNode(Opcodes.INVOKESPECIAL, Type.getInternalName(NablaDifferentiated.class),
"getPrimitiveField",
Type.getMethodDescriptor(Type.getType(Object.class),
Type.getType(String.class))));
} else {
// GETSTATIC case
- list.add(new MethodInsnNode(Opcodes.INVOKESTATIC, methodDifferentiator.getDerivedName(),
+ list.add(new LdcInsnNode(Type.getType("L" + methodDifferentiator.getPrimitiveName() + ";")));
+ list.add(new LdcInsnNode(fieldInsn.name));
+ list.add(new MethodInsnNode(Opcodes.INVOKESTATIC, Type.getInternalName(NablaDifferentiated.class),
"getPrimitiveStaticField",
Type.getMethodDescriptor(Type.getType(Object.class),
+ Type.getType(Class.class),
Type.getType(String.class))));
}
Modified: commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java?rev=1393700&r1=1393699&r2=1393700&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java (original)
+++ commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java Wed Oct 3 19:18:38 2012
@@ -119,6 +119,25 @@ public class ForwardModeDifferentiatorTe
}
@Test
+ public void testEmbeddedInvoke() {
+ checkReference(new ReferenceFunction() {
+ private double f(double t) {
+ return 2 * t;
+ }
+ private double g(double h) {
+ return h * h;
+ }
+ private double h(double t) {
+ return t - 1;
+ }
+ public double value(double t) {
+ return f(t) + g(h(t));
+ }
+ public double firstDerivative(double t) { return 2 * t; }
+ }, -5, 5, 20, 8.0e-15);
+ }
+
+ @Test
public void testPartialDerivatives() throws Exception {
PartialFunction function = new PartialFunction(1);