You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by lu...@apache.org on 2012/10/31 18:46:42 UTC
svn commit: r1404266 - in /commons/sandbox/nabla/trunk/src:
main/java/org/apache/commons/nabla/
main/java/org/apache/commons/nabla/forward/
main/java/org/apache/commons/nabla/forward/analysis/
main/java/org/apache/commons/nabla/forward/instructions/ ma...
Author: luc
Date: Wed Oct 31 17:46:41 2012
New Revision: 1404266
URL: http://svn.apache.org/viewvc?rev=1404266&view=rev
Log:
started support for PUTFIELD/PUTSTATIC.
Added:
commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/PutTransformer.java
- copied, changed from r1395676, commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/GetTransformer.java
Modified:
commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/NablaMessages.java
commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/NablaDifferentiated.java
commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java
commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/GetTransformer.java
commons/sandbox/nabla/trunk/src/main/resources/assets/org/apache/commons/nabla/NablaMessages_fr.properties
commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/AbstractMathTest.java
commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java
Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/NablaMessages.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/NablaMessages.java?rev=1404266&r1=1404265&r2=1404266&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/NablaMessages.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/NablaMessages.java Wed Oct 31 17:46:41 2012
@@ -53,7 +53,7 @@ public enum NablaMessages implements Loc
UNKNOWN_METHOD("unknown method {0}.{1}"),
UNEXPECTED_INSTRUCTION("unexpected instruction with opcode {0}"),
UNABLE_TO_HANDLE_INSTRUCTION("unable to handle instruction with opcode {0}"),
- CANNOT_GET_VOID_FIELD("unable to get value of void type field {0}"),
+ CANNOT_USE_VOID_FIELD("unable to use value of void type field {0}"),
ILLEGAL_LDC_CONSTANT("illegal LDC constant {0}"),
INTERNAL_ERROR("internal error, please fill a bug report at {0}");
Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/NablaDifferentiated.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/NablaDifferentiated.java?rev=1404266&r1=1404265&r2=1404266&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/NablaDifferentiated.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/NablaDifferentiated.java Wed Oct 31 17:46:41 2012
@@ -51,6 +51,22 @@ public class NablaDifferentiated {
}
}
+    /** Put a field into the primitive instance.
+     * <p>
+     * The field is searched for in the runtime class of the primitive
+     * instance first, then in each of its superclasses in turn, so that
+     * fields declared in a superclass of the primitive can also be set
+     * ({@link Class#getDeclaredField(String)} only sees fields declared
+     * directly in the class it is called on).
+     * </p>
+     * @param value field value, boxed as an object
+     * @param name field name
+     * @exception RuntimeException wrapping the {@link NoSuchFieldException}
+     * if no class in the hierarchy declares the field, or wrapping the
+     * {@link IllegalAccessException} if the field cannot be set
+     */
+    protected void putPrimitiveField(final Object value, final String name) {
+        Class<?> current = primitive.getClass();
+        while (true) {
+            try {
+                final Field field = current.getDeclaredField(name);
+                field.setAccessible(true);
+                field.set(primitive, value);
+                return;
+            } catch (NoSuchFieldException nsfe) {
+                // not declared at this level, look into the parent class
+                current = current.getSuperclass();
+                if (current == null) {
+                    // reached the top of the hierarchy without finding the field
+                    throw new RuntimeException(nsfe);
+                }
+            } catch (IllegalAccessException iae) {
+                throw new RuntimeException(iae);
+            }
+        }
+    }
+
/** Get a field from the primitive class.
* @param primitiveClass class of the primitive
* @param name field name
@@ -69,6 +85,25 @@ public class NablaDifferentiated {
}
}
+    /** Put a field into the primitive class.
+     * <p>
+     * The static field is searched for in the given class first, then in
+     * each of its superclasses in turn, so that static fields declared in a
+     * superclass of the primitive can also be set
+     * ({@link Class#getDeclaredField(String)} only sees fields declared
+     * directly in the class it is called on).
+     * </p>
+     * @param value field value, boxed as an object
+     * @param primitiveClass class of the primitive
+     * @param name field name
+     * @exception RuntimeException wrapping the {@link NoSuchFieldException}
+     * if no class in the hierarchy declares the field, or wrapping the
+     * {@link IllegalAccessException} if the field cannot be set
+     */
+    protected static void putPrimitiveStaticField(final Object value,
+                                                  final Class<? extends UnivariateFunction> primitiveClass,
+                                                  final String name) {
+        Class<?> current = primitiveClass;
+        while (true) {
+            try {
+                final Field field = current.getDeclaredField(name);
+                field.setAccessible(true);
+                // static field: the target instance is ignored, null is passed
+                field.set(null, value);
+                return;
+            } catch (NoSuchFieldException nsfe) {
+                // not declared at this level, look into the parent class
+                current = current.getSuperclass();
+                if (current == null) {
+                    // reached the top of the hierarchy without finding the field
+                    throw new RuntimeException(nsfe);
+                }
+            } catch (IllegalAccessException iae) {
+                throw new RuntimeException(iae);
+            }
+        }
+    }
+
/** Get the primitive instance.
* @return primitive instance.
*/
Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java?rev=1404266&r1=1404265&r2=1404266&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/analysis/MethodDifferentiator.java Wed Oct 31 17:46:41 2012
@@ -46,6 +46,7 @@ import org.apache.commons.nabla.forward.
import org.apache.commons.nabla.forward.instructions.InvokeNonMathTransformer;
import org.apache.commons.nabla.forward.instructions.NarrowingTransformer;
import org.apache.commons.nabla.forward.instructions.Pop2Transformer;
+import org.apache.commons.nabla.forward.instructions.PutTransformer;
import org.apache.commons.nabla.forward.instructions.WideningTransformer;
import org.apache.commons.nabla.forward.trimming.SwappedDloadTrimmer;
import org.apache.commons.nabla.forward.trimming.SwappedLDCTrimmer;
@@ -578,8 +579,7 @@ public class MethodDifferentiator {
return new GetTransformer().getReplacement(insn, this, dsIndex);
case Opcodes.PUTSTATIC :
case Opcodes.PUTFIELD :
- // TODO: add support for PUTSTATIC/PUTFIELD differentiation
- throw new RuntimeException("PUTSTATIC/PUTFIELD not handled yet");
+ return new PutTransformer().getReplacement(insn, this, dsIndex);
case Opcodes.INVOKEVIRTUAL :
case Opcodes.INVOKESPECIAL :
case Opcodes.INVOKEINTERFACE :
Modified: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/GetTransformer.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/GetTransformer.java?rev=1404266&r1=1404265&r2=1404266&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/GetTransformer.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/GetTransformer.java Wed Oct 31 17:46:41 2012
@@ -77,7 +77,7 @@ public class GetTransformer implements I
final String valueMethodName;
switch (type.getSort()) {
case Type.VOID:
- throw new DifferentiationException(NablaMessages.CANNOT_GET_VOID_FIELD, fieldInsn.name);
+ throw new DifferentiationException(NablaMessages.CANNOT_USE_VOID_FIELD, fieldInsn.name);
case Type.BOOLEAN:
valueMethodName = "booleanValue";
boxedType = Type.getType(Boolean.class);
Copied: commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/PutTransformer.java (from r1395676, commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/GetTransformer.java)
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/PutTransformer.java?p2=commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/PutTransformer.java&p1=commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/GetTransformer.java&r1=1395676&r2=1404266&rev=1404266&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/GetTransformer.java (original)
+++ commons/sandbox/nabla/trunk/src/main/java/org/apache/commons/nabla/forward/instructions/PutTransformer.java Wed Oct 31 17:46:41 2012
@@ -16,6 +16,7 @@
*/
package org.apache.commons.nabla.forward.instructions;
+import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
import org.apache.commons.nabla.DifferentiationException;
import org.apache.commons.nabla.NablaMessages;
import org.apache.commons.nabla.forward.NablaDifferentiated;
@@ -28,19 +29,28 @@ import org.objectweb.asm.tree.FieldInsnN
import org.objectweb.asm.tree.InsnList;
import org.objectweb.asm.tree.LdcInsnNode;
import org.objectweb.asm.tree.MethodInsnNode;
-import org.objectweb.asm.tree.TypeInsnNode;
-/** Differentiation transformer for GETFIELD/GETSTATIC instructions.
- * <p>Each GETFIELD/GETSTATIC instruction is replaced by an instruction
- * list getting the field from the primitive class using reflection.
+/** Differentiation transformer for PUTFIELD/PUTSTATIC instructions.
+ * <p>
+ * PUTFIELD/PUTSTATIC instructions that do not hold transformed data
+ * (i.e. those that hold only values that do not directly depend on the
+ * transformed input double argument) are replaced by instructions
+ * putting the field from the primitive class using reflection.
+ * </p>
+ * <p>
+ * PUTFIELD/PUTSTATIC instructions that hold transformed data
+ * (i.e. those that hold values that do directly depend on the
+ * transformed input double argument) are replaced by instructions
+ * putting a dedicated transformed field in the differentiated class
+ * and also putting the value in the primitive class using reflection.
* </p>
* @version $Id$
*/
-public class GetTransformer implements InstructionsTransformer {
+public class PutTransformer implements InstructionsTransformer {
/** Simple constructor.
*/
- public GetTransformer() {
+ public PutTransformer() {
}
/** {@inheritDoc} */
@@ -48,81 +58,107 @@ public class GetTransformer implements I
final MethodDifferentiator methodDifferentiator,
final int dsIndex)
throws DifferentiationException {
-
- final FieldInsnNode fieldInsn = (FieldInsnNode) insn;
- final InsnList list = new InsnList();
-
- // get the field as an object
- if (insn.getOpcode() == Opcodes.GETFIELD) {
- // GETFIELD case
- list.add(new LdcInsnNode(fieldInsn.name));
- list.add(new MethodInsnNode(Opcodes.INVOKESPECIAL, Type.getInternalName(NablaDifferentiated.class),
- "getPrimitiveField",
- Type.getMethodDescriptor(Type.getType(Object.class),
- Type.getType(String.class))));
+ if (methodDifferentiator.stackElementIsConverted(insn, 0)) {
+ return getReplacementTransformedField((FieldInsnNode) insn, methodDifferentiator, dsIndex);
} else {
- // GETSTATIC case
- list.add(new LdcInsnNode(Type.getType("L" + methodDifferentiator.getPrimitiveName() + ";")));
- list.add(new LdcInsnNode(fieldInsn.name));
- list.add(new MethodInsnNode(Opcodes.INVOKESTATIC, Type.getInternalName(NablaDifferentiated.class),
- "getPrimitiveStaticField",
- Type.getMethodDescriptor(Type.getType(Object.class),
- Type.getType(Class.class),
- Type.getType(String.class))));
+ return getReplacementPrimitiveField((FieldInsnNode) insn, methodDifferentiator);
}
+ }
+
+    /** Get the replacement instructions when the field type is transformed.
+     * <p>
+     * Transformed fields are kept within the transformed class.
+     * </p>
+     * <p>
+     * This case is not supported yet: the current implementation always
+     * throws an {@link UnsupportedOperationException}.
+     * </p>
+     * @param insn original instruction
+     * @param methodDifferentiator method differentiator driving this transformer
+     * @param dsIndex index of a reference {@link DerivativeStructure derivative structure} variable
+     * @return replacement instructions (the current stub never returns normally)
+     * @exception DifferentiationException never thrown by the current stub, kept in
+     * the signature for the future implementation
+     * @exception UnsupportedOperationException always, as transformed fields are not handled yet
+     */
+    private InsnList getReplacementTransformedField(final FieldInsnNode insn,
+                                                    final MethodDifferentiator methodDifferentiator,
+                                                    final int dsIndex)
+        throws DifferentiationException {
+        // TODO add support for PUTFIELD/PUTSTATIC in the case of transformed fields
+        throw new UnsupportedOperationException("PUTFIELD/PUTSTATIC not handled yet for transformed fields");
+    }
+
+ /** Get the replacement instructions when the field type is not transformed.
+ * <p>
+ * Original fields are kept within the primitive class.
+ * </p>
+ * @param insn original instruction
+ * @param methodDifferentiator method differentiator driving this transformer
+ * @return replacement instructions
+ * @exception DifferentiationException if the field has void type, which
+ * cannot be boxed
+ */
+ private InsnList getReplacementPrimitiveField(final FieldInsnNode insn,
+ final MethodDifferentiator methodDifferentiator)
+ throws DifferentiationException {
- // convert it to the expected type
- final Type type = Type.getType(fieldInsn.desc);
+ final InsnList list = new InsnList();
+
+ // box the value into the expected type
+ final Type type = Type.getType(insn.desc);
final Type boxedType;
- final String valueMethodName;
switch (type.getSort()) {
case Type.VOID:
- throw new DifferentiationException(NablaMessages.CANNOT_GET_VOID_FIELD, fieldInsn.name);
+ throw new DifferentiationException(NablaMessages.CANNOT_USE_VOID_FIELD, insn.name);
case Type.BOOLEAN:
- valueMethodName = "booleanValue";
boxedType = Type.getType(Boolean.class);
break;
case Type.CHAR:
- valueMethodName = "charValue";
boxedType = Type.getType(Character.class);
break;
case Type.BYTE:
- valueMethodName = "byteValue";
boxedType = Type.getType(Byte.class);
break;
case Type.SHORT:
- valueMethodName = "shortValue";
boxedType = Type.getType(Short.class);
break;
case Type.INT:
- valueMethodName = "intValue";
boxedType = Type.getType(Integer.class);
break;
case Type.FLOAT:
- valueMethodName = "floatValue";
boxedType = Type.getType(Float.class);
break;
case Type.LONG:
- valueMethodName = "longValue";
boxedType = Type.getType(Long.class);
break;
case Type.DOUBLE:
- valueMethodName = "doubleValue";
boxedType = Type.getType(Double.class);
break;
default :
// do nothing for Type.ARRAY and Type.OBJECT
- valueMethodName = null;
boxedType = null;
}
if (boxedType != null) {
- list.add(new TypeInsnNode(Opcodes.CHECKCAST, boxedType.getInternalName()));
+ list.add(new MethodInsnNode(Opcodes.INVOKESTATIC, boxedType.getInternalName(),
+ "valueOf", Type.getMethodDescriptor(boxedType, type)));
}
- if (valueMethodName != null) {
- list.add(new MethodInsnNode(Opcodes.INVOKEVIRTUAL, boxedType.getInternalName(),
- valueMethodName,
- Type.getMethodDescriptor(type, new Type[0])));
+
+ // put the field as an object
+ if (insn.getOpcode() == Opcodes.PUTFIELD) {
+ // PUTFIELD case
+ list.add(new LdcInsnNode(insn.name));
+ list.add(new MethodInsnNode(Opcodes.INVOKESPECIAL, Type.getInternalName(NablaDifferentiated.class),
+ "putPrimitiveField",
+ Type.getMethodDescriptor(Type.VOID_TYPE,
+ Type.getType(Object.class),
+ Type.getType(String.class))));
+ } else {
+ // PUTSTATIC case
+ list.add(new LdcInsnNode(Type.getType("L" + methodDifferentiator.getPrimitiveName() + ";")));
+ list.add(new LdcInsnNode(insn.name));
+ list.add(new MethodInsnNode(Opcodes.INVOKESTATIC, Type.getInternalName(NablaDifferentiated.class),
+ "putPrimitiveStaticField",
+ Type.getMethodDescriptor(Type.VOID_TYPE,
+ Type.getType(Object.class),
+ Type.getType(Class.class),
+ Type.getType(String.class))));
}
return list;
Modified: commons/sandbox/nabla/trunk/src/main/resources/assets/org/apache/commons/nabla/NablaMessages_fr.properties
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/main/resources/assets/org/apache/commons/nabla/NablaMessages_fr.properties?rev=1404266&r1=1404265&r2=1404266&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/main/resources/assets/org/apache/commons/nabla/NablaMessages_fr.properties (original)
+++ commons/sandbox/nabla/trunk/src/main/resources/assets/org/apache/commons/nabla/NablaMessages_fr.properties Wed Oct 31 17:46:41 2012
@@ -24,6 +24,6 @@ UNABLE_TO_ANALYZE_METHOD = impossible d'
UNKNOWN_METHOD = m\u00e9thode {0}.{1} inconnue
UNEXPECTED_INSTRUCTION = instruction ayant l''opcode {0} inattendue
UNABLE_TO_HANDLE_INSTRUCTION = incapable de g\u00e9rer une instruction ayant l''opcode {0}
-CANNOT_GET_VOID_FIELD = impossible de r\u00e9cup\u00e9rer la valeur du champ {0} de type void
+CANNOT_USE_VOID_FIELD = impossible d''utiliser la valeur du champ de type \u00ab void \u00bb {0}
ILLEGAL_LDC_CONSTANT = constante ill\u00e9gale pour une instruction LDC : {0}
INTERNAL_ERROR = erreur interne, veuillez signaler l''erreur \u00e0 {0}
Modified: commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/AbstractMathTest.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/AbstractMathTest.java?rev=1404266&r1=1404265&r2=1404266&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/AbstractMathTest.java (original)
+++ commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/AbstractMathTest.java Wed Oct 31 17:46:41 2012
@@ -30,9 +30,9 @@ public abstract class AbstractMathTest {
double firstDerivative(double t);
}
- protected void checkReference(ReferenceFunction reference,
- double t0, double t1, int n,
- double threshold) {
+ protected UnivariateDifferentiableFunction checkReference(ReferenceFunction reference,
+ double t0, double t1, int n,
+ double threshold) {
try {
ForwardModeDifferentiator differentiator = new ForwardModeDifferentiator();
differentiator.addMathImplementation(MathExtensions.class);
@@ -47,9 +47,11 @@ public abstract class AbstractMathTest {
reference.value(t), derivative.value(dpT).getValue(),
threshold);
}
+ return derivative;
} catch (DifferentiationException de) {
de.printStackTrace(System.err);
Assert.fail(de.getLocalizedMessage());
+ return null;
}
}
Modified: commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java
URL: http://svn.apache.org/viewvc/commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java?rev=1404266&r1=1404265&r2=1404266&view=diff
==============================================================================
--- commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java (original)
+++ commons/sandbox/nabla/trunk/src/test/java/org/apache/commons/nabla/forward/ForwardModeDifferentiatorTest.java Wed Oct 31 17:46:41 2012
@@ -163,6 +163,113 @@ public class ForwardModeDifferentiatorTe
}
public double firstDerivative(double t) { return 0.5 - 1.0 / (t * t); }
}, -5.25, 5, 20, 8.0e-15);
+
+ }
+
+    /** Check that fields written by the primitive, but not involved in the
+     * computation of its value, are still updated on the primitive instance
+     * when the differentiated function is called.
+     */
+    @Test
+    public void testFieldAsSideEffectOnly() {
+        SideEffectsFunction sef = new SideEffectsFunction();
+        UnivariateDifferentiableFunction derivative = checkReference(sef, -5.25, 5, 20, 8.0e-15);
+        boolean formerZ = sef.getSideZ();
+        char    formerC = sef.getSideC();
+        byte    formerB = sef.getSideB();
+        short   formerS = sef.getSideS();
+        int     formerI = sef.getSideI();
+        long    formerJ = sef.getSideJ();
+        float   formerF = sef.getSideF();
+        double  formerD = sef.getSideD();
+
+        // this call should have a side effect on the original sef primitive object
+        derivative.value(new DerivativeStructure(1, 1, 0, 1.0));
+
+        // expected values are narrowed back to the field types, exactly as the
+        // ++ operator in SideEffectsFunction.value does, so a wrap-around
+        // (e.g. a byte going from 127 to -128) cannot make the test fail spuriously
+        Assert.assertEquals(!formerZ, sef.getSideZ());
+        Assert.assertEquals((char) (formerC + 1), sef.getSideC());
+        Assert.assertEquals((byte) (formerB + 1), sef.getSideB());
+        Assert.assertEquals((short) (formerS + 1), sef.getSideS());
+        Assert.assertEquals(formerI + 1, sef.getSideI());
+        Assert.assertEquals(formerJ + 1, sef.getSideJ());
+        Assert.assertEquals(formerF + 1, sef.getSideF(), 1.0e-15);
+        Assert.assertEquals(formerD + 1, sef.getSideD(), 1.0e-15);
+
+    }
+
+ /** Reference function whose value method updates one field of each
+ * primitive type (boolean, char, byte, short, int, long, float, double)
+ * and returns 2t, with derivative 2.
+ * <p>
+ * None of the updated fields contributes to the returned value, so the
+ * values written to them do not depend on the differentiated variable.
+ * NOTE(review): the exact statements in value are what the differentiator
+ * transforms, so rewriting them (even equivalently) changes what is tested.
+ * </p>
+ */
+ public static class SideEffectsFunction implements ReferenceFunction {
+ // these fields are changed as side effect of the call to value
+ // but they do not affect the function value; the one-letter suffix is
+ // the JVM type descriptor of each field (Z boolean, C char, B byte,
+ // S short, I int, J long, F float, D double)
+ private boolean sideZ;
+ private char sideC;
+ private byte sideB;
+ private short sideS;
+ private int sideI;
+ private long sideJ;
+ private float sideF;
+ private double sideD;
+
+ // toggles/increments every side field once, then returns 2 * t
+ public double value(double t) {
+ sideZ = !sideZ;
+ sideC++;
+ sideB++;
+ sideS++;
+ sideI++;
+ sideJ++;
+ sideF++;
+ sideD++;
+ return 2 * t;
+ }
+
+ // exact derivative of value: d(2t)/dt
+ public double firstDerivative(double t) { return 2.0; }
+
+ // accessors used by the test to observe the side effects of value
+ public boolean getSideZ() {
+ return sideZ;
+ }
+
+ public char getSideC() {
+ return sideC;
+ }
+
+ public byte getSideB() {
+ return sideB;
+ }
+
+ public short getSideS() {
+ return sideS;
+ }
+
+ public int getSideI() {
+ return sideI;
+ }
+
+ public long getSideJ() {
+ return sideJ;
+ }
+
+ public float getSideF() {
+ return sideF;
+ }
+
+ public double getSideD() {
+ return sideD;
+ }
+
+ }
+
+    /** Check differentiation of a primitive that passes a value depending on
+     * the differentiated variable from one of its methods to another through
+     * an instance field.
+     * <p>
+     * NOTE(review): PutTransformer does not yet handle PUTFIELD of transformed
+     * data, so this test presumably fails until that case is implemented.
+     * </p>
+     */
+    @Test
+    public void testFieldAsIntermediateData() {
+        final ReferenceFunction intermediate = new ReferenceFunction() {
+
+            // this field is used to store data transferred from value to f
+            private double intermediateData;
+
+            private double f() {
+                return 2 * intermediateData;
+            }
+
+            public double value(double t) {
+                intermediateData = t + 3.0;
+                return f();
+            }
+            public double firstDerivative(double t) { return 2.0; }
+        };
+        checkReference(intermediate, -5.25, 5, 20, 8.0e-15);
    }
@Test
@@ -192,9 +299,6 @@ public class ForwardModeDifferentiatorTe
public static void setA(double newA) {
a = newA;
}
- public double getX() {
- return x;
- }
public double value(double y) {
return x * y * a + y * y;
}