You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@cassandra.apache.org by sn...@apache.org on 2015/05/15 19:37:10 UTC

[2/5] cassandra git commit: Better support of null for UDF

Better support of null for UDF

patch by Robert Stupp; reviewed by Benjamin Lerer for CASSANDRA-8374


Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo
Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/1937bed9
Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/1937bed9
Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/1937bed9

Branch: refs/heads/cassandra-2.2
Commit: 1937bed9035e953b4cb9099ddeb581d3bf38bca3
Parents: 94480b1
Author: Robert Stupp <sn...@snazy.de>
Authored: Fri May 15 19:34:32 2015 +0200
Committer: Robert Stupp <sn...@snazy.de>
Committed: Fri May 15 19:34:32 2015 +0200

----------------------------------------------------------------------
 CHANGES.txt                                     |   3 +-
 doc/cql3/CQL.textile                            |  26 +-
 pylib/cqlshlib/cql3handling.py                  |   1 +
 src/java/org/apache/cassandra/cql3/Cql.g        |   9 +-
 .../cql3/functions/JavaSourceUDFFactory.java    |  14 +-
 .../cql3/functions/NativeScalarFunction.java    |   5 +
 .../cql3/functions/ScalarFunction.java          |   2 +
 .../cql3/functions/ScriptBasedUDF.java          |   5 +-
 .../cassandra/cql3/functions/UDAggregate.java   |  20 +-
 .../cassandra/cql3/functions/UDFunction.java    |  68 ++-
 .../cassandra/cql3/functions/UDHelper.java      |  39 +-
 .../statements/CreateAggregateStatement.java    |   5 +-
 .../statements/CreateFunctionStatement.java     |   8 +-
 .../cassandra/schema/LegacySchemaTables.java    |   7 +-
 .../cassandra/cql3/functions/JavaSourceUDF.txt  |   6 +-
 .../apache/cassandra/cql3/AggregationTest.java  | 196 ++++++-
 .../org/apache/cassandra/cql3/JsonTest.java     |   4 +-
 .../org/apache/cassandra/cql3/PgStringTest.java |   4 +-
 .../org/apache/cassandra/cql3/UFAuthTest.java   |   6 +
 .../cassandra/cql3/UFIdentificationTest.java    |   1 +
 test/unit/org/apache/cassandra/cql3/UFTest.java | 527 ++++++++++++++++---
 21 files changed, 821 insertions(+), 135 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/cassandra/blob/1937bed9/CHANGES.txt
----------------------------------------------------------------------
diff --git a/CHANGES.txt b/CHANGES.txt
index 08e1005..c5c505a 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,5 +1,6 @@
 2.2.0-beta1
- * Use ecj instead of javassist for UDFs
+ * Better support of null for UDF (CASSANDRA-8374)
+ * Use ecj instead of javassist for UDFs (CASSANDRA-8241)
  * faster async logback configuration for tests (CASSANDRA-9376)
  * Add `smallint` and `tinyint` data types (CASSANDRA-8951)
  * Avoid thrift schema creation when native driver is used in stress tool (CASSANDRA-9374)

http://git-wip-us.apache.org/repos/asf/cassandra/blob/1937bed9/doc/cql3/CQL.textile
----------------------------------------------------------------------
diff --git a/doc/cql3/CQL.textile b/doc/cql3/CQL.textile
index 89bd1b9..2e58615 100644
--- a/doc/cql3/CQL.textile
+++ b/doc/cql3/CQL.textile
@@ -591,8 +591,9 @@ __Syntax:__
 bc(syntax).. 
 <create-function-stmt> ::= CREATE ( OR REPLACE )? 
                             FUNCTION ( IF NOT EXISTS )?
-                            ( ( <keyspace> '.' )? <function-name> )?
+                            ( <keyspace> '.' )? <function-name>
                             '(' <arg-name> <arg-type> ( ',' <arg-name> <arg-type> )* ')'
+                            ( CALLED | RETURNS NULL ) ON NULL INPUT
                             RETURNS <type>
                             LANGUAGE <language>
                             AS <body>
@@ -602,6 +603,7 @@ __Sample:__
 bc(sample). 
 CREATE OR REPLACE FUNCTION somefunction
     ( somearg int, anotherarg text, complexarg frozen<someUDT>, listarg list<bigint> )
+    RETURNS NULL ON NULL INPUT
     RETURNS text
     LANGUAGE java
     AS $$
@@ -609,6 +611,7 @@ CREATE OR REPLACE FUNCTION somefunction
     $$;
 CREATE FUNCTION akeyspace.fname IF NOT EXISTS
     ( someArg int )
+    CALLED ON NULL INPUT
     RETURNS text
     LANGUAGE java
     AS $$
@@ -628,6 +631,11 @@ Note that keyspace names, function names and argument types are subject to the d
 
 @CREATE FUNCTION@ with the optional @OR REPLACE@ keywords either creates a function or replaces an existing one with the same signature. A @CREATE FUNCTION@ without @OR REPLACE@ fails if a function with the same signature already exists.
 
+Behavior on invocation with @null@ values must be defined for each function. There are two options:
+
+# @RETURNS NULL ON NULL INPUT@ declares that the function will always return @null@ if any of the input arguments is @null@; the function body is not executed in that case.
+# @CALLED ON NULL INPUT@ declares that the function will always be executed, even if some of its input arguments are @null@.
+
 If the optional @IF NOT EXISTS@ keywords are used, the function will only be created if another function with the same signature does not exist.
 
 @OR REPLACE@ and @IF NOT EXIST@ cannot be used together.
@@ -642,7 +650,7 @@ __Syntax:__
 
 bc(syntax).. 
 <drop-function-stmt> ::= DROP FUNCTION ( IF EXISTS )?
-                         ( ( <keyspace> '.' )? <function-name> )?
+                         ( <keyspace> '.' )? <function-name>
                          ( '(' <arg-type> ( ',' <arg-type> )* ')' )?
 
 p. 
@@ -666,7 +674,7 @@ __Syntax:__
 bc(syntax).. 
 <create-aggregate-stmt> ::= CREATE ( OR REPLACE )? 
                             AGGREGATE ( IF NOT EXISTS )?
-                            ( ( <keyspace> '.' )? <aggregate-name> )?
+                            ( <keyspace> '.' )? <aggregate-name>
                             '(' <arg-type> ( ',' <arg-type> )* ')'
                             SFUNC ( <keyspace> '.' )? <state-functionname>
                             STYPE <state-type>
@@ -698,11 +706,11 @@ Signatures for user-defined aggregates follow the "same rules":#functionSignatur
 
 @STYPE@ defines the type of the state value and must be specified.
 
-The optional @INITCOND@ defines the initial state value for the aggregate. It defaults to @null@.
+The optional @INITCOND@ defines the initial state value for the aggregate. It defaults to @null@. A non-@null@ @INITCOND@ must be specified for state functions that are declared with @RETURNS NULL ON NULL INPUT@.
 
-@SFUNC@ references an existing function to be used as the state modifying function. The type of first argument of the state function must match @STYPE@. The remaining argument types of the state function must match the argument types of the aggregate function.
+@SFUNC@ references an existing function to be used as the state modifying function. The type of the first argument of the state function must match @STYPE@. The remaining argument types of the state function must match the argument types of the aggregate function. If the state function is declared with @RETURNS NULL ON NULL INPUT@ and is called with a @null@ argument, it is not executed and the state is left unchanged.
 
-The optional @FINALFUNC@ is called just before the aggregate result is returned. It must take only one argument with type @STYPE@. The return type of the @FINALFUNC@ may be a different type.
+The optional @FINALFUNC@ is called just before the aggregate result is returned. It must take only one argument with type @STYPE@. The return type of the @FINALFUNC@ may be a different type. If the final function is declared with @RETURNS NULL ON NULL INPUT@ and the last state is @null@, the aggregate returns @null@.
 
 If no @FINALFUNC@ is defined, the overall return type of the aggregate function is @STYPE@.  If a @FINALFUNC@ is defined, it is the return type of that function.
 
@@ -713,8 +721,8 @@ h3(#dropAggregateStmt). DROP AGGREGATE
 __Syntax:__
 
 bc(syntax).. 
-<drop-function-stmt> ::= DROP AGGREGATE ( IF EXISTS )?
-                         ( ( <keyspace> '.' )? <functionname> )?
+<drop-aggregate-stmt> ::= DROP AGGREGATE ( IF EXISTS )?
+                         ( <keyspace> '.' )? <aggregate-name>
                          ( '(' <arg-type> ( ',' <arg-type> )* ')' )?
 p. 
 
@@ -1447,6 +1455,7 @@ A complete working example for user-defined aggregates (assuming that a keyspace
 
 bc(sample).. 
 CREATE FUNCTION averageState ( state tuple<int,bigint>, val int )
+  CALLED ON NULL INPUT
   RETURNS tuple<int,bigint>
   LANGUAGE java
   AS '
@@ -1458,6 +1467,7 @@ CREATE FUNCTION averageState ( state tuple<int,bigint>, val int )
   ';
 
 CREATE FUNCTION averageFinal ( state tuple<int,bigint> )
+  CALLED ON NULL INPUT
   RETURNS double
   LANGUAGE java
   AS '

http://git-wip-us.apache.org/repos/asf/cassandra/blob/1937bed9/pylib/cqlshlib/cql3handling.py
----------------------------------------------------------------------
diff --git a/pylib/cqlshlib/cql3handling.py b/pylib/cqlshlib/cql3handling.py
index add4a7f..3e155d0 100644
--- a/pylib/cqlshlib/cql3handling.py
+++ b/pylib/cqlshlib/cql3handling.py
@@ -1027,6 +1027,7 @@ syntax_rules += r'''
                             ( "(" ( newcol=<cident> <storageType>
                               ( "," [newcolname]=<cident> <storageType> )* )?
                             ")" )?
+                            ("RETURNS" "NULL" | "CALLED") "ON" "NULL" "INPUT"
                             "RETURNS" <storageType>
                             "LANGUAGE" <cident> "AS" <stringLiteral>
                          ;

http://git-wip-us.apache.org/repos/asf/cassandra/blob/1937bed9/src/java/org/apache/cassandra/cql3/Cql.g
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/cql3/Cql.g b/src/java/org/apache/cassandra/cql3/Cql.g
index 831b012..3600cd1 100644
--- a/src/java/org/apache/cassandra/cql3/Cql.g
+++ b/src/java/org/apache/cassandra/cql3/Cql.g
@@ -594,6 +594,7 @@ createFunctionStatement returns [CreateFunctionStatement expr]
 
         List<ColumnIdentifier> argsNames = new ArrayList<>();
         List<CQL3Type.Raw> argsTypes = new ArrayList<>();
+        boolean calledOnNullInput = false;
     }
     : K_CREATE (K_OR K_REPLACE { orReplace = true; })?
       K_FUNCTION
@@ -605,10 +606,12 @@ createFunctionStatement returns [CreateFunctionStatement expr]
           ( ',' k=ident v=comparatorType { argsNames.add(k); argsTypes.add(v); } )*
         )?
       ')'
+      ( (K_RETURNS K_NULL) | (K_CALLED { calledOnNullInput=true; })) K_ON K_NULL K_INPUT
       K_RETURNS rt = comparatorType
       K_LANGUAGE language = IDENT
       K_AS body = STRING_LITERAL
-      { $expr = new CreateFunctionStatement(fn, $language.text.toLowerCase(), $body.text, argsNames, argsTypes, rt, orReplace, ifNotExists); }
+      { $expr = new CreateFunctionStatement(fn, $language.text.toLowerCase(), $body.text,
+                                            argsNames, argsTypes, rt, calledOnNullInput, orReplace, ifNotExists); }
     ;
 
 dropFunctionStatement returns [DropFunctionStatement expr]
@@ -1550,6 +1553,8 @@ basic_unreserved_keyword returns [String str]
         | K_RETURNS
         | K_LANGUAGE
         | K_JSON
+        | K_CALLED
+        | K_INPUT
         ) { $str = $k.text; }
     ;
 
@@ -1678,6 +1683,8 @@ K_STYPE:       S T Y P E;
 K_FINALFUNC:   F I N A L F U N C;
 K_INITCOND:    I N I T C O N D;
 K_RETURNS:     R E T U R N S;
+K_CALLED:      C A L L E D;
+K_INPUT:       I N P U T;
 K_LANGUAGE:    L A N G U A G E;
 K_OR:          O R;
 K_REPLACE:     R E P L A C E;

http://git-wip-us.apache.org/repos/asf/cassandra/blob/1937bed9/src/java/org/apache/cassandra/cql3/functions/JavaSourceUDFFactory.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/cql3/functions/JavaSourceUDFFactory.java b/src/java/org/apache/cassandra/cql3/functions/JavaSourceUDFFactory.java
index 3809bf3..97a08b1 100644
--- a/src/java/org/apache/cassandra/cql3/functions/JavaSourceUDFFactory.java
+++ b/src/java/org/apache/cassandra/cql3/functions/JavaSourceUDFFactory.java
@@ -118,6 +118,7 @@ public final class JavaSourceUDFFactory
                                List<ColumnIdentifier> argNames,
                                List<AbstractType<?>> argTypes,
                                AbstractType<?> returnType,
+                               boolean calledOnNullInput,
                                String body)
     throws InvalidRequestException
     {
@@ -126,7 +127,7 @@ public final class JavaSourceUDFFactory
         // returnDataType is just the C* internal returnType converted to the Java Driver DataType
         DataType returnDataType = UDHelper.driverType(returnType);
         // javaParamTypes is just the Java representation for argTypes resp. argDataTypes
-        Class<?>[] javaParamTypes = UDHelper.javaTypes(argDataTypes);
+        Class<?>[] javaParamTypes = UDHelper.javaTypes(argDataTypes, calledOnNullInput);
         // javaReturnType is just the Java representation for returnType resp. returnDataType
         Class<?> javaReturnType = returnDataType.asJavaClass();
 
@@ -226,11 +227,11 @@ public final class JavaSourceUDFFactory
             MethodType methodType = MethodType.methodType(void.class)
                                               .appendParameterTypes(FunctionName.class, List.class, List.class, DataType[].class,
                                                                     AbstractType.class, DataType.class,
-                                                                    String.class);
+                                                                    boolean.class, String.class);
             MethodHandle ctor = MethodHandles.lookup().findConstructor(cls, methodType);
             return (UDFunction) ctor.invokeWithArguments(name, argNames, argTypes, argDataTypes,
                                                          returnType, returnDataType,
-                                                         body);
+                                                         calledOnNullInput, body);
         }
         catch (InvocationTargetException e)
         {
@@ -309,11 +310,16 @@ public final class JavaSourceUDFFactory
                 // cast to Java type
                 .append("                (").append(javaSourceName(paramTypes[i])).append(") ")
                 // generate object representation of input parameter (call UDFunction.compose)
-                .append("compose(protocolVersion, ").append(i).append(", params.get(").append(i).append("))");
+                .append(composeMethod(paramTypes[i])).append("(protocolVersion, ").append(i).append(", params.get(").append(i).append("))");
         }
         return code.toString();
     }
 
+    private static String composeMethod(Class<?> type)
+    {
+        return (type.isPrimitive()) ? ("compose_" + type.getName()) : "compose";
+    }
+
     // Java source UDFs are a very simple compilation task, which allows us to let one class implement
     // all interfaces required by ECJ.
     static final class EcjCompilationUnit implements ICompilationUnit, ICompilerRequestor, INameEnvironment

http://git-wip-us.apache.org/repos/asf/cassandra/blob/1937bed9/src/java/org/apache/cassandra/cql3/functions/NativeScalarFunction.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/cql3/functions/NativeScalarFunction.java b/src/java/org/apache/cassandra/cql3/functions/NativeScalarFunction.java
index 8f7f221..3ae0607 100644
--- a/src/java/org/apache/cassandra/cql3/functions/NativeScalarFunction.java
+++ b/src/java/org/apache/cassandra/cql3/functions/NativeScalarFunction.java
@@ -29,6 +29,11 @@ public abstract class NativeScalarFunction extends NativeFunction implements Sca
         super(name, returnType, argsType);
     }
 
+    public boolean isCalledOnNullInput()
+    {
+        return true;
+    }
+
     public final boolean isAggregate()
     {
         return false;

http://git-wip-us.apache.org/repos/asf/cassandra/blob/1937bed9/src/java/org/apache/cassandra/cql3/functions/ScalarFunction.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/cql3/functions/ScalarFunction.java b/src/java/org/apache/cassandra/cql3/functions/ScalarFunction.java
index f00faf7..ba258df 100644
--- a/src/java/org/apache/cassandra/cql3/functions/ScalarFunction.java
+++ b/src/java/org/apache/cassandra/cql3/functions/ScalarFunction.java
@@ -27,6 +27,8 @@ import org.apache.cassandra.exceptions.InvalidRequestException;
  */
 public interface ScalarFunction extends Function
 {
+    public boolean isCalledOnNullInput();
+
     /**
      * Applies this function to the specified parameter.
      *

http://git-wip-us.apache.org/repos/asf/cassandra/blob/1937bed9/src/java/org/apache/cassandra/cql3/functions/ScriptBasedUDF.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/cql3/functions/ScriptBasedUDF.java b/src/java/org/apache/cassandra/cql3/functions/ScriptBasedUDF.java
index b38f483..319c948 100644
--- a/src/java/org/apache/cassandra/cql3/functions/ScriptBasedUDF.java
+++ b/src/java/org/apache/cassandra/cql3/functions/ScriptBasedUDF.java
@@ -66,11 +66,12 @@ public class ScriptBasedUDF extends UDFunction
                    List<ColumnIdentifier> argNames,
                    List<AbstractType<?>> argTypes,
                    AbstractType<?> returnType,
+                   boolean calledOnNullInput,
                    String language,
                    String body)
     throws InvalidRequestException
     {
-        super(name, argNames, argTypes, returnType, language, body);
+        super(name, argNames, argTypes, returnType, calledOnNullInput, language, body);
 
         Compilable scriptEngine = scriptEngines.get(language);
         if (scriptEngine == null)
@@ -88,7 +89,7 @@ public class ScriptBasedUDF extends UDFunction
         }
     }
 
-    public ByteBuffer execute(int protocolVersion, List<ByteBuffer> parameters) throws InvalidRequestException
+    public ByteBuffer executeUserDefined(int protocolVersion, List<ByteBuffer> parameters) throws InvalidRequestException
     {
         Object[] params = new Object[argTypes.size()];
         for (int i = 0; i < params.length; i++)

http://git-wip-us.apache.org/repos/asf/cassandra/blob/1937bed9/src/java/org/apache/cassandra/cql3/functions/UDAggregate.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/cql3/functions/UDAggregate.java b/src/java/org/apache/cassandra/cql3/functions/UDAggregate.java
index f5a1af0..f153aed 100644
--- a/src/java/org/apache/cassandra/cql3/functions/UDAggregate.java
+++ b/src/java/org/apache/cassandra/cql3/functions/UDAggregate.java
@@ -149,17 +149,27 @@ public class UDAggregate extends AbstractFunction implements AggregateFunction
 
             public void addInput(int protocolVersion, List<ByteBuffer> values) throws InvalidRequestException
             {
-                List<ByteBuffer> copy = new ArrayList<>(values.size() + 1);
-                copy.add(state);
-                copy.addAll(values);
-                state = stateFunction.execute(protocolVersion, copy);
+                List<ByteBuffer> fArgs = new ArrayList<>(values.size() + 1);
+                fArgs.add(state);
+                fArgs.addAll(values);
+                if (stateFunction instanceof UDFunction)
+                {
+                    UDFunction udf = (UDFunction)stateFunction;
+                    if (udf.isCallableWrtNullable(fArgs))
+                        state = udf.executeUserDefined(protocolVersion, fArgs);
+                }
+                else
+                {
+                    state = stateFunction.execute(protocolVersion, fArgs);
+                }
             }
 
             public ByteBuffer compute(int protocolVersion) throws InvalidRequestException
             {
                 if (finalFunction == null)
                     return state;
-                return finalFunction.execute(protocolVersion, Collections.singletonList(state));
+                List<ByteBuffer> fArgs = Collections.singletonList(state);
+                return finalFunction.execute(protocolVersion, fArgs);
             }
 
             public void reset()

http://git-wip-us.apache.org/repos/asf/cassandra/blob/1937bed9/src/java/org/apache/cassandra/cql3/functions/UDFunction.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/cql3/functions/UDFunction.java b/src/java/org/apache/cassandra/cql3/functions/UDFunction.java
index a56af6e..873a1f2 100644
--- a/src/java/org/apache/cassandra/cql3/functions/UDFunction.java
+++ b/src/java/org/apache/cassandra/cql3/functions/UDFunction.java
@@ -49,16 +49,18 @@ public abstract class UDFunction extends AbstractFunction implements ScalarFunct
 
     protected final DataType[] argDataTypes;
     protected final DataType returnDataType;
+    protected final boolean calledOnNullInput;
 
     protected UDFunction(FunctionName name,
                          List<ColumnIdentifier> argNames,
                          List<AbstractType<?>> argTypes,
                          AbstractType<?> returnType,
+                         boolean calledOnNullInput,
                          String language,
                          String body)
     {
         this(name, argNames, argTypes, UDHelper.driverTypes(argTypes), returnType,
-             UDHelper.driverType(returnType), language, body);
+             UDHelper.driverType(returnType), calledOnNullInput, language, body);
     }
 
     protected UDFunction(FunctionName name,
@@ -67,6 +69,7 @@ public abstract class UDFunction extends AbstractFunction implements ScalarFunct
                          DataType[] argDataTypes,
                          AbstractType<?> returnType,
                          DataType returnDataType,
+                         boolean calledOnNullInput,
                          String language,
                          String body)
     {
@@ -77,20 +80,22 @@ public abstract class UDFunction extends AbstractFunction implements ScalarFunct
         this.body = body;
         this.argDataTypes = argDataTypes;
         this.returnDataType = returnDataType;
+        this.calledOnNullInput = calledOnNullInput;
     }
 
     public static UDFunction create(FunctionName name,
                                     List<ColumnIdentifier> argNames,
                                     List<AbstractType<?>> argTypes,
                                     AbstractType<?> returnType,
+                                    boolean calledOnNullInput,
                                     String language,
                                     String body)
     throws InvalidRequestException
     {
         switch (language)
         {
-            case "java": return JavaSourceUDFFactory.buildUDF(name, argNames, argTypes, returnType, body);
-            default: return new ScriptBasedUDF(name, argNames, argTypes, returnType, language, body);
+            case "java": return JavaSourceUDFFactory.buildUDF(name, argNames, argTypes, returnType, calledOnNullInput, body);
+            default: return new ScriptBasedUDF(name, argNames, argTypes, returnType, calledOnNullInput, language, body);
         }
     }
 
@@ -107,13 +112,14 @@ public abstract class UDFunction extends AbstractFunction implements ScalarFunct
                                                   List<ColumnIdentifier> argNames,
                                                   List<AbstractType<?>> argTypes,
                                                   AbstractType<?> returnType,
+                                                  boolean calledOnNullInput,
                                                   String language,
                                                   String body,
                                                   final InvalidRequestException reason)
     {
-        return new UDFunction(name, argNames, argTypes, returnType, language, body)
+        return new UDFunction(name, argNames, argTypes, returnType, calledOnNullInput, language, body)
         {
-            public ByteBuffer execute(int protocolVersion, List<ByteBuffer> parameters) throws InvalidRequestException
+            public ByteBuffer executeUserDefined(int protocolVersion, List<ByteBuffer> parameters) throws InvalidRequestException
             {
                 throw new InvalidRequestException(String.format("Function '%s' exists but hasn't been loaded successfully "
                                                                 + "for the following reason: %s. Please see the server log for details",
@@ -123,6 +129,23 @@ public abstract class UDFunction extends AbstractFunction implements ScalarFunct
         };
     }
 
+    public final ByteBuffer execute(int protocolVersion, List<ByteBuffer> parameters) throws InvalidRequestException
+    {
+        if (!isCallableWrtNullable(parameters))
+            return null;
+        return executeUserDefined(protocolVersion, parameters);
+    }
+
+    public boolean isCallableWrtNullable(List<ByteBuffer> parameters)
+    {
+        if (!calledOnNullInput)
+            for (ByteBuffer parameter : parameters)
+                if (parameter == null || parameter.remaining() == 0)
+                    return false;
+        return true;
+    }
+
+    protected abstract ByteBuffer executeUserDefined(int protocolVersion, List<ByteBuffer> parameters) throws InvalidRequestException;
 
     public boolean isAggregate()
     {
@@ -134,6 +157,11 @@ public abstract class UDFunction extends AbstractFunction implements ScalarFunct
         return false;
     }
 
+    public boolean isCalledOnNullInput()
+    {
+        return calledOnNullInput;
+    }
+
     public List<ColumnIdentifier> argNames()
     {
         return argNames;
@@ -162,6 +190,36 @@ public abstract class UDFunction extends AbstractFunction implements ScalarFunct
         return value == null ? null : argDataTypes[argIndex].deserialize(value, ProtocolVersion.fromInt(protocolVersion));
     }
 
+    // do not remove - used by generated Java UDFs
+    protected float compose_float(int protocolVersion, int argIndex, ByteBuffer value)
+    {
+        return value == null ? 0f : (float)DataType.cfloat().deserialize(value, ProtocolVersion.fromInt(protocolVersion));
+    }
+
+    // do not remove - used by generated Java UDFs
+    protected double compose_double(int protocolVersion, int argIndex, ByteBuffer value)
+    {
+        return value == null ? 0d : (double)DataType.cdouble().deserialize(value, ProtocolVersion.fromInt(protocolVersion));
+    }
+
+    // do not remove - used by generated Java UDFs
+    protected int compose_int(int protocolVersion, int argIndex, ByteBuffer value)
+    {
+        return value == null ? 0 : (int)DataType.cint().deserialize(value, ProtocolVersion.fromInt(protocolVersion));
+    }
+
+    // do not remove - used by generated Java UDFs
+    protected long compose_long(int protocolVersion, int argIndex, ByteBuffer value)
+    {
+        return value == null ? 0L : (long)DataType.bigint().deserialize(value, ProtocolVersion.fromInt(protocolVersion));
+    }
+
+    // do not remove - used by generated Java UDFs
+    protected boolean compose_boolean(int protocolVersion, int argIndex, ByteBuffer value)
+    {
+        return value != null && (boolean) DataType.cboolean().deserialize(value, ProtocolVersion.fromInt(protocolVersion));
+    }
+
     /**
      * Used by UDF implementations (both Java code generated by {@link org.apache.cassandra.cql3.functions.JavaSourceUDFFactory}
      * and script executor {@link org.apache.cassandra.cql3.functions.ScriptBasedUDF}) to convert the Java

http://git-wip-us.apache.org/repos/asf/cassandra/blob/1937bed9/src/java/org/apache/cassandra/cql3/functions/UDHelper.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/cql3/functions/UDHelper.java b/src/java/org/apache/cassandra/cql3/functions/UDHelper.java
index af62c5a..55a0888 100644
--- a/src/java/org/apache/cassandra/cql3/functions/UDHelper.java
+++ b/src/java/org/apache/cassandra/cql3/functions/UDHelper.java
@@ -20,22 +20,17 @@ package org.apache.cassandra.cql3.functions;
 import java.lang.invoke.MethodHandle;
 import java.lang.invoke.MethodHandles;
 import java.lang.reflect.Method;
-import java.util.*;
-
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import java.util.List;
 
 import com.datastax.driver.core.DataType;
-import org.apache.cassandra.cql3.*;
-import org.apache.cassandra.db.marshal.*;
+import org.apache.cassandra.cql3.CQL3Type;
+import org.apache.cassandra.db.marshal.AbstractType;
 
 /**
  * Helper class for User Defined Functions + Aggregates.
  */
 public final class UDHelper
 {
-    protected static final Logger logger = LoggerFactory.getLogger(UDHelper.class);
-
     // TODO make these c'tors and methods public in Java-Driver - see https://datastax-oss.atlassian.net/browse/JAVA-502
     static final MethodHandle methodParseOne;
     static
@@ -56,14 +51,36 @@ public final class UDHelper
     /**
      * Construct an array containing the Java classes for the given Java Driver {@link com.datastax.driver.core.DataType}s.
      *
-     * @param dataTypes array with UDF argument types
+     * @param dataTypes  array with UDF argument types
+     * @param calledOnNullInput whether to allow {@code null} as an argument value
      * @return array of same size with UDF arguments
      */
-    public static Class<?>[] javaTypes(DataType[] dataTypes)
+    public static Class<?>[] javaTypes(DataType[] dataTypes, boolean calledOnNullInput)
     {
         Class<?>[] paramTypes = new Class[dataTypes.length];
         for (int i = 0; i < paramTypes.length; i++)
-            paramTypes[i] = dataTypes[i].asJavaClass();
+        {
+            Class<?> clazz = dataTypes[i].asJavaClass();
+            if (!calledOnNullInput)
+            {
+                // only care about classes that can be used in a data type
+                if (clazz == Integer.class)
+                    clazz = int.class;
+                else if (clazz == Long.class)
+                    clazz = long.class;
+                else if (clazz == Byte.class)
+                    clazz = byte.class;
+                else if (clazz == Short.class)
+                    clazz = short.class;
+                else if (clazz == Float.class)
+                    clazz = float.class;
+                else if (clazz == Double.class)
+                    clazz = double.class;
+                else if (clazz == Boolean.class)
+                    clazz = boolean.class;
+            }
+            paramTypes[i] = clazz;
+        }
         return paramTypes;
     }
 

http://git-wip-us.apache.org/repos/asf/cassandra/blob/1937bed9/src/java/org/apache/cassandra/cql3/statements/CreateAggregateStatement.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/cql3/statements/CreateAggregateStatement.java b/src/java/org/apache/cassandra/cql3/statements/CreateAggregateStatement.java
index 74cc521..2524a7c 100644
--- a/src/java/org/apache/cassandra/cql3/statements/CreateAggregateStatement.java
+++ b/src/java/org/apache/cassandra/cql3/statements/CreateAggregateStatement.java
@@ -199,6 +199,9 @@ public final class CreateAggregateStatement extends SchemaAlteringStatement
                                                                 functionName, returnType.asCQL3Type(), old.returnType().asCQL3Type()));
         }
 
+        if (!stateFunction.isCalledOnNullInput() && initcond == null)
+            throw new InvalidRequestException(String.format("Cannot create aggregate %s without INITCOND because state function %s does not accept 'null' arguments", functionName, stateFunc));
+
         udAggregate = new UDAggregate(functionName, argTypes, returnType,
                                                   stateFunction,
                                                   finalFunction,
@@ -220,7 +223,7 @@ public final class CreateAggregateStatement extends SchemaAlteringStatement
         return sb.toString();
     }
 
-    private List<AbstractType<?>> stateArguments(AbstractType<?> stateType, List<AbstractType<?>> argTypes)
+    private static List<AbstractType<?>> stateArguments(AbstractType<?> stateType, List<AbstractType<?>> argTypes)
     {
         List<AbstractType<?>> r = new ArrayList<>(argTypes.size() + 1);
         r.add(stateType);

http://git-wip-us.apache.org/repos/asf/cassandra/blob/1937bed9/src/java/org/apache/cassandra/cql3/statements/CreateFunctionStatement.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/cql3/statements/CreateFunctionStatement.java b/src/java/org/apache/cassandra/cql3/statements/CreateFunctionStatement.java
index faab043..3cef6e4 100644
--- a/src/java/org/apache/cassandra/cql3/statements/CreateFunctionStatement.java
+++ b/src/java/org/apache/cassandra/cql3/statements/CreateFunctionStatement.java
@@ -49,6 +49,7 @@ public final class CreateFunctionStatement extends SchemaAlteringStatement
     private final List<ColumnIdentifier> argNames;
     private final List<CQL3Type.Raw> argRawTypes;
     private final CQL3Type.Raw rawReturnType;
+    private final boolean calledOnNullInput;
 
     private List<AbstractType<?>> argTypes;
     private AbstractType<?> returnType;
@@ -61,6 +62,7 @@ public final class CreateFunctionStatement extends SchemaAlteringStatement
                                    List<ColumnIdentifier> argNames,
                                    List<CQL3Type.Raw> argRawTypes,
                                    CQL3Type.Raw rawReturnType,
+                                   boolean calledOnNullInput,
                                    boolean orReplace,
                                    boolean ifNotExists)
     {
@@ -70,6 +72,7 @@ public final class CreateFunctionStatement extends SchemaAlteringStatement
         this.argNames = argNames;
         this.argRawTypes = argRawTypes;
         this.rawReturnType = rawReturnType;
+        this.calledOnNullInput = calledOnNullInput;
         this.orReplace = orReplace;
         this.ifNotExists = ifNotExists;
     }
@@ -152,13 +155,16 @@ public final class CreateFunctionStatement extends SchemaAlteringStatement
                 throw new InvalidRequestException(String.format("Function %s already exists", old));
             if (!(old instanceof ScalarFunction))
                 throw new InvalidRequestException(String.format("Function %s can only replace a function", old));
+            if (calledOnNullInput != ((ScalarFunction) old).isCalledOnNullInput())
+                throw new InvalidRequestException(String.format("Function %s can only be replaced with %s", old,
+                                                                calledOnNullInput ? "CALLED ON NULL INPUT" : "RETURNS NULL ON NULL INPUT"));
 
             if (!Functions.typeEquals(old.returnType(), returnType))
                 throw new InvalidRequestException(String.format("Cannot replace function %s, the new return type %s is not compatible with the return type %s of existing function",
                                                                 functionName, returnType.asCQL3Type(), old.returnType().asCQL3Type()));
         }
 
-        this.udFunction = UDFunction.create(functionName, argNames, argTypes, returnType, language, body);
+        this.udFunction = UDFunction.create(functionName, argNames, argTypes, returnType, calledOnNullInput, language, body);
         this.replaced = old != null;
 
         MigrationManager.announceNewFunction(udFunction, isLocalOnly);

http://git-wip-us.apache.org/repos/asf/cassandra/blob/1937bed9/src/java/org/apache/cassandra/schema/LegacySchemaTables.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/schema/LegacySchemaTables.java b/src/java/org/apache/cassandra/schema/LegacySchemaTables.java
index 4eb800b..720f309 100644
--- a/src/java/org/apache/cassandra/schema/LegacySchemaTables.java
+++ b/src/java/org/apache/cassandra/schema/LegacySchemaTables.java
@@ -162,6 +162,7 @@ public class LegacySchemaTables
                 + "body text,"
                 + "language text,"
                 + "return_type text,"
+                + "called_on_null_input boolean,"
                 + "PRIMARY KEY ((keyspace_name), function_name, signature))");
 
     private static final CFMetaData Aggregates =
@@ -1285,6 +1286,7 @@ public class LegacySchemaTables
         adder.add("body", function.body());
         adder.add("language", function.language());
         adder.add("return_type", function.returnType().toString());
+        adder.add("called_on_null_input", function.isCalledOnNullInput());
     }
 
     public static Mutation makeDropFunctionMutation(KSMetaData keyspace, UDFunction function, long timestamp)
@@ -1333,15 +1335,16 @@ public class LegacySchemaTables
 
         String language = row.getString("language");
         String body = row.getString("body");
+        boolean calledOnNullInput = row.getBoolean("called_on_null_input");
 
         try
         {
-            return UDFunction.create(name, argNames, argTypes, returnType, language, body);
+            return UDFunction.create(name, argNames, argTypes, returnType, calledOnNullInput, language, body);
         }
         catch (InvalidRequestException e)
         {
             logger.error(String.format("Cannot load function '%s' from schema: this function won't be available (on this node)", name), e);
-            return UDFunction.createBrokenFunction(name, argNames, argTypes, returnType, language, body, e);
+            return UDFunction.createBrokenFunction(name, argNames, argTypes, returnType, calledOnNullInput, language, body, e);
         }
     }
 

http://git-wip-us.apache.org/repos/asf/cassandra/blob/1937bed9/src/resources/org/apache/cassandra/cql3/functions/JavaSourceUDF.txt
----------------------------------------------------------------------
diff --git a/src/resources/org/apache/cassandra/cql3/functions/JavaSourceUDF.txt b/src/resources/org/apache/cassandra/cql3/functions/JavaSourceUDF.txt
index 985c8ff..f57b01e 100644
--- a/src/resources/org/apache/cassandra/cql3/functions/JavaSourceUDF.txt
+++ b/src/resources/org/apache/cassandra/cql3/functions/JavaSourceUDF.txt
@@ -13,12 +13,12 @@ import org.apache.cassandra.exceptions.InvalidRequestException;
 public final class #class_name# extends org.apache.cassandra.cql3.functions.UDFunction
 {
     public #class_name#(FunctionName name, List<ColumnIdentifier> argNames, List<AbstractType<?>> argTypes,
-                        DataType[] argDataTypes, AbstractType<?> returnType, DataType returnDataType, String body)
+                        DataType[] argDataTypes, AbstractType<?> returnType, DataType returnDataType, boolean calledOnNullInput, String body)
     {
-        super(name, argNames, argTypes, argDataTypes, returnType, returnDataType, "java", body);
+        super(name, argNames, argTypes, argDataTypes, returnType, returnDataType, calledOnNullInput, "java", body);
     }
 
-    public ByteBuffer execute(int protocolVersion, List<ByteBuffer> params) throws InvalidRequestException
+    protected ByteBuffer executeUserDefined(int protocolVersion, List<ByteBuffer> params) throws InvalidRequestException
     {
         try
         {

http://git-wip-us.apache.org/repos/asf/cassandra/blob/1937bed9/test/unit/org/apache/cassandra/cql3/AggregationTest.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/cql3/AggregationTest.java b/test/unit/org/apache/cassandra/cql3/AggregationTest.java
index 7fe665d..7bec8a1 100644
--- a/test/unit/org/apache/cassandra/cql3/AggregationTest.java
+++ b/test/unit/org/apache/cassandra/cql3/AggregationTest.java
@@ -106,9 +106,10 @@ public class AggregationTest extends CQLTester
         String copySign = createFunction(KEYSPACE,
                                          "double, double",
                                          "CREATE OR REPLACE FUNCTION %s(magnitude double, sign double) " +
+                                         "RETURNS NULL ON NULL INPUT " +
                                          "RETURNS double " +
                                          "LANGUAGE JAVA " +
-                                         "AS 'return Double.valueOf(Math.copySign(magnitude.doubleValue(), sign.doubleValue()));';");
+                                         "AS 'return Double.valueOf(Math.copySign(magnitude, sign));';");
 
         assertColumnNames(execute("SELECT max(a), max(unixTimestampOf(b)) FROM %s"), "system.max(a)", "system.max(system.unixtimestampof(b))");
         assertRows(execute("SELECT max(a), max(unixTimestampOf(b)) FROM %s"), row(null, null));
@@ -143,6 +144,7 @@ public class AggregationTest extends CQLTester
         String f = createFunction(KEYSPACE,
                                   "double, double",
                                   "CREATE OR REPLACE FUNCTION %s(state double, val double) " +
+                                  "RETURNS NULL ON NULL INPUT " +
                                   "RETURNS double " +
                                   "LANGUAGE javascript " +
                                   "AS '\"string\";';");
@@ -150,6 +152,7 @@ public class AggregationTest extends CQLTester
         createFunctionOverload(f,
                                "double, double",
                                "CREATE OR REPLACE FUNCTION %s(state int, val int) " +
+                               "RETURNS NULL ON NULL INPUT " +
                                "RETURNS int " +
                                "LANGUAGE javascript " +
                                "AS '\"string\";';");
@@ -158,7 +161,8 @@ public class AggregationTest extends CQLTester
                                    "double",
                                    "CREATE OR REPLACE AGGREGATE %s(double) " +
                                    "SFUNC " + shortFunctionName(f) + " " +
-                                   "STYPE double");
+                                   "STYPE double " +
+                                   "INITCOND 0");
 
         assertLastSchemaChange(Event.SchemaChange.Change.CREATED, Event.SchemaChange.Target.AGGREGATE,
                                KEYSPACE, parseFunctionName(a).name,
@@ -166,7 +170,8 @@ public class AggregationTest extends CQLTester
 
         schemaChange("CREATE OR REPLACE AGGREGATE " + a + "(double) " +
                      "SFUNC " + shortFunctionName(f) + " " +
-                     "STYPE double");
+                     "STYPE double " +
+                     "INITCOND 0");
 
         assertLastSchemaChange(Event.SchemaChange.Change.UPDATED, Event.SchemaChange.Target.AGGREGATE,
                                KEYSPACE, parseFunctionName(a).name,
@@ -176,7 +181,8 @@ public class AggregationTest extends CQLTester
                                 "int",
                                 "CREATE OR REPLACE AGGREGATE %s(int) " +
                                 "SFUNC " + shortFunctionName(f) + " " +
-                                "STYPE int");
+                                "STYPE int " +
+                                "INITCOND 0");
 
         assertLastSchemaChange(Event.SchemaChange.Change.CREATED, Event.SchemaChange.Target.AGGREGATE,
                                KEYSPACE, parseFunctionName(a).name,
@@ -195,6 +201,7 @@ public class AggregationTest extends CQLTester
         String f = createFunction(KEYSPACE,
                                   "double, double",
                                   "CREATE OR REPLACE FUNCTION %s(state double, val double) " +
+                                  "RETURNS NULL ON NULL INPUT " +
                                   "RETURNS double " +
                                   "LANGUAGE javascript " +
                                   "AS '\"string\";';");
@@ -202,6 +209,7 @@ public class AggregationTest extends CQLTester
         createFunctionOverload(f,
                                "double, double",
                                "CREATE OR REPLACE FUNCTION %s(state int, val int) " +
+                               "RETURNS NULL ON NULL INPUT " +
                                "RETURNS int " +
                                "LANGUAGE javascript " +
                                "AS '\"string\";';");
@@ -214,12 +222,14 @@ public class AggregationTest extends CQLTester
                                    "double",
                                    "CREATE OR REPLACE AGGREGATE %s(double) " +
                                    "SFUNC " + shortFunctionName(f) + " " +
-                                   "STYPE double");
+                                   "STYPE double " +
+                                   "INITCOND 0");
         createAggregateOverload(a,
                                 "int",
                                 "CREATE OR REPLACE AGGREGATE %s(int) " +
                                 "SFUNC " + shortFunctionName(f) + " " +
-                                "STYPE int");
+                                "STYPE int " +
+                                "INITCOND 0");
 
         // DROP FUNCTION must not succeed against an aggregate
         assertInvalidMessage("matches multiple function definitions", "DROP FUNCTION " + a);
@@ -243,6 +253,7 @@ public class AggregationTest extends CQLTester
         String f = createFunction(KEYSPACE,
                                   "double, double",
                                   "CREATE OR REPLACE FUNCTION %s(state double, val double) " +
+                                  "RETURNS NULL ON NULL INPUT " +
                                   "RETURNS double " +
                                   "LANGUAGE javascript " +
                                   "AS '\"string\";';");
@@ -251,7 +262,8 @@ public class AggregationTest extends CQLTester
                                    "double",
                                    "CREATE OR REPLACE AGGREGATE %s(double) " +
                                    "SFUNC " + shortFunctionName(f) + " " +
-                                   "STYPE double");
+                                   "STYPE double " +
+                                   "INITCOND 0");
 
         // DROP FUNCTION must not succeed because the function is still referenced by the aggregate
         assertInvalidMessage("still referenced by", "DROP FUNCTION " + f);
@@ -270,6 +282,7 @@ public class AggregationTest extends CQLTester
         String fState = createFunction(KEYSPACE,
                                        "int, int",
                                        "CREATE FUNCTION %s(a int, b int) " +
+                                       "CALLED ON NULL INPUT " +
                                        "RETURNS int " +
                                        "LANGUAGE java " +
                                        "AS 'return Integer.valueOf((a!=null?a.intValue():0) + b.intValue());'");
@@ -277,6 +290,7 @@ public class AggregationTest extends CQLTester
         String fFinal = createFunction(KEYSPACE,
                                        "int",
                                        "CREATE FUNCTION %s(a int) " +
+                                       "CALLED ON NULL INPUT " +
                                        "RETURNS text " +
                                        "LANGUAGE java " +
                                        "AS 'return a.toString();'");
@@ -307,6 +321,7 @@ public class AggregationTest extends CQLTester
         String fState = createFunction(KEYSPACE,
                                        "int, int",
                                        "CREATE FUNCTION %s(a int, b int) " +
+                                       "CALLED ON NULL INPUT " +
                                        "RETURNS int " +
                                        "LANGUAGE java " +
                                        "AS 'return Integer.valueOf((a!=null?a.intValue():0) + b.intValue());'");
@@ -314,6 +329,7 @@ public class AggregationTest extends CQLTester
         String fFinal = createFunction(KEYSPACE,
                                        "int",
                                        "CREATE FUNCTION %s(a int) " +
+                                       "CALLED ON NULL INPUT " +
                                        "RETURNS text " +
                                        "LANGUAGE java " +
                                        "AS 'return a.toString();'");
@@ -340,6 +356,7 @@ public class AggregationTest extends CQLTester
         String fState = createFunction(KEYSPACE,
                                        "int, int",
                                        "CREATE FUNCTION %s(a int, b int) " +
+                                       "CALLED ON NULL INPUT " +
                                        "RETURNS int " +
                                        "LANGUAGE java " +
                                        "AS 'return Integer.valueOf((a!=null?a.intValue():0) + b.intValue());'");
@@ -347,6 +364,7 @@ public class AggregationTest extends CQLTester
         String fFinal = createFunction(KEYSPACE,
                                        "int",
                                        "CREATE FUNCTION %s(a int) " +
+                                       "CALLED ON NULL INPUT " +
                                        "RETURNS text " +
                                        "LANGUAGE java " +
                                        "AS 'return a.toString();'");
@@ -365,6 +383,7 @@ public class AggregationTest extends CQLTester
         String fState = createFunction(KEYSPACE,
                                        "int, int",
                                        "CREATE FUNCTION %s(a int, b int) " +
+                                       "CALLED ON NULL INPUT " +
                                        "RETURNS int " +
                                        "LANGUAGE java " +
                                        "AS 'return Integer.valueOf((a!=null?a.intValue():0) + b.intValue());'");
@@ -372,6 +391,7 @@ public class AggregationTest extends CQLTester
         String fFinal = createFunction(KEYSPACE,
                                        "int",
                                        "CREATE FUNCTION %s(a int) " +
+                                       "CALLED ON NULL INPUT " +
                                        "RETURNS text " +
                                        "LANGUAGE java " +
                                        "AS 'return a.toString();'");
@@ -379,6 +399,7 @@ public class AggregationTest extends CQLTester
         String fState2 = createFunction(KEYSPACE,
                                         "int, int",
                                         "CREATE FUNCTION %s(a double, b double) " +
+                                        "CALLED ON NULL INPUT " +
                                         "RETURNS double " +
                                         "LANGUAGE java " +
                                         "AS 'return Double.valueOf((a!=null?a.doubleValue():0d) + b.doubleValue());'");
@@ -386,6 +407,7 @@ public class AggregationTest extends CQLTester
         String fFinal2 = createFunction(KEYSPACE,
                                         "int",
                                         "CREATE FUNCTION %s(a double) " +
+                                        "CALLED ON NULL INPUT " +
                                         "RETURNS text " +
                                         "LANGUAGE java " +
                                         "AS 'return a.toString();'");
@@ -433,6 +455,7 @@ public class AggregationTest extends CQLTester
         String fState = createFunction(KEYSPACE,
                                        "int, int",
                                        "CREATE FUNCTION %s(a int, b int) " +
+                                       "CALLED ON NULL INPUT " +
                                        "RETURNS int " +
                                        "LANGUAGE java " +
                                        "AS 'return Integer.valueOf((a!=null?a.intValue():0) + b.intValue());'");
@@ -440,6 +463,7 @@ public class AggregationTest extends CQLTester
         String fFinal = createFunction(KEYSPACE,
                                        "int",
                                        "CREATE FUNCTION %s(a int) " +
+                                       "CALLED ON NULL INPUT " +
                                        "RETURNS text " +
                                        "LANGUAGE java " +
                                        "AS 'return a.toString();'");
@@ -474,13 +498,15 @@ public class AggregationTest extends CQLTester
         String fState = createFunction(KEYSPACE,
                                        "int, int",
                                        "CREATE FUNCTION %s(a int, b int) " +
+                                       "CALLED ON NULL INPUT " +
                                        "RETURNS int " +
                                        "LANGUAGE java " +
-                                       "AS 'throw new RuntimeException();'");
+                                       "AS 'throw new RuntimeException(\"thrown to unit test - not a bug\");'");
 
         String fStateOK = createFunction(KEYSPACE,
                                        "int, int",
                                        "CREATE FUNCTION %s(a int, b int) " +
+                                       "CALLED ON NULL INPUT " +
                                        "RETURNS int " +
                                        "LANGUAGE java " +
                                        "AS 'return Integer.valueOf(42);'");
@@ -488,13 +514,15 @@ public class AggregationTest extends CQLTester
         String fFinal = createFunction(KEYSPACE,
                                        "int",
                                        "CREATE FUNCTION %s(a int) " +
+                                       "CALLED ON NULL INPUT " +
                                        "RETURNS text " +
                                        "LANGUAGE java " +
-                                       "AS 'throw new RuntimeException();'");
+                                       "AS 'throw new RuntimeException(\"thrown to unit test - not a bug\");'");
 
         String fFinalOK = createFunction(KEYSPACE,
                                        "int",
                                        "CREATE FUNCTION %s(a int) " +
+                                       "CALLED ON NULL INPUT " +
                                        "RETURNS text " +
                                        "LANGUAGE java " +
                                        "AS 'return \"foobar\";'");
@@ -537,9 +565,10 @@ public class AggregationTest extends CQLTester
         String f = createFunction(KEYSPACE,
                                   "int, int",
                                   "CREATE FUNCTION %s(a int, b int) " +
+                                  "RETURNS NULL ON NULL INPUT " +
                                   "RETURNS int " +
                                   "LANGUAGE java " +
-                                  "AS 'return Integer.valueOf((a!=null?a.intValue():0) + b.intValue());'");
+                                  "AS 'return Integer.valueOf(a + b);'");
 
         assertInvalidMessage("does not exist or is not a scalar function",
                              "CREATE AGGREGATE " + KEYSPACE + ".jSumFooNE2(int) " +
@@ -561,6 +590,7 @@ public class AggregationTest extends CQLTester
         String fState = createFunction(KEYSPACE,
                                        "int, int",
                                        "CREATE FUNCTION %s(a int, b int) " +
+                                       "CALLED ON NULL INPUT " +
                                        "RETURNS int " +
                                        "LANGUAGE java " +
                                        "AS 'return Integer.valueOf((a!=null?a.intValue():0) + b.intValue());'");
@@ -568,6 +598,7 @@ public class AggregationTest extends CQLTester
         String fFinal = createFunction(KEYSPACE,
                                        "int",
                                        "CREATE FUNCTION %s(a int) " +
+                                       "CALLED ON NULL INPUT " +
                                        "RETURNS text " +
                                        "LANGUAGE java " +
                                        "AS 'return a.toString();'");
@@ -602,6 +633,7 @@ public class AggregationTest extends CQLTester
         String fState = createFunction(KEYSPACE,
                                        "int, int",
                                        "CREATE FUNCTION %s(a int, b int) " +
+                                       "CALLED ON NULL INPUT " +
                                        "RETURNS int " +
                                        "LANGUAGE java " +
                                        "AS 'return Integer.valueOf((a!=null?a.intValue():0) + b.intValue());'");
@@ -637,6 +669,7 @@ public class AggregationTest extends CQLTester
         String fState = createFunction(KEYSPACE,
                                        "frozen<tuple<bigint, int>>, int",
                                        "CREATE FUNCTION %s(a frozen<tuple<bigint, int>>, b int) " +
+                                       "CALLED ON NULL INPUT " +
                                        "RETURNS frozen<tuple<bigint, int>> " +
                                        "LANGUAGE java " +
                                        "AS '" +
@@ -648,6 +681,7 @@ public class AggregationTest extends CQLTester
         String fFinal = createFunction(KEYSPACE,
                                        "frozen<tuple<bigint, int>>",
                                        "CREATE FUNCTION %s(a frozen<tuple<bigint, int>>) " +
+                                       "RETURNS NULL ON NULL INPUT " +
                                        "RETURNS double " +
                                        "LANGUAGE java " +
                                        "AS '" +
@@ -680,6 +714,7 @@ public class AggregationTest extends CQLTester
         String fState = createFunction(KEYSPACE,
                                        "int, int",
                                        "CREATE FUNCTION %s(a int, b int) " +
+                                       "RETURNS NULL ON NULL INPUT " +
                                        "RETURNS int " +
                                        "LANGUAGE javascript " +
                                        "AS 'a + b;'");
@@ -687,6 +722,7 @@ public class AggregationTest extends CQLTester
         String fFinal = createFunction(KEYSPACE,
                                        "int",
                                        "CREATE FUNCTION %s(a int) " +
+                                       "RETURNS NULL ON NULL INPUT " +
                                        "RETURNS text " +
                                        "LANGUAGE javascript " +
                                        "AS '\"\"+a'");
@@ -721,6 +757,7 @@ public class AggregationTest extends CQLTester
         String fState = createFunction(KEYSPACE,
                                        "int, int",
                                        "CREATE FUNCTION %s(a int, b int) " +
+                                       "CALLED ON NULL INPUT " +
                                        "RETURNS int " +
                                        "LANGUAGE javascript " +
                                        "AS 'a + b;'");
@@ -754,6 +791,7 @@ public class AggregationTest extends CQLTester
             String fState = createFunction(otherKS,
                                            "int, int",
                                            "CREATE FUNCTION %s(a int, b int) " +
+                                           "CALLED ON NULL INPUT " +
                                            "RETURNS int " +
                                            "LANGUAGE javascript " +
                                            "AS 'a + b;'");
@@ -796,6 +834,7 @@ public class AggregationTest extends CQLTester
         String fState = createFunction(KEYSPACE,
                                        "int, int",
                                        "CREATE FUNCTION %s(a int, b int) " +
+                                       "CALLED ON NULL INPUT " +
                                        "RETURNS int " +
                                        "LANGUAGE javascript " +
                                        "AS 'a + b;'");
@@ -819,6 +858,140 @@ public class AggregationTest extends CQLTester
     }
 
     @Test
+    public void testCalledOnNullInput() throws Throwable  // aggregates built from CALLED ON NULL INPUT vs RETURNS NULL ON NULL INPUT UDFs
+    {
+        String fStateNonNull = createFunction(KEYSPACE,  // state function declared RETURNS NULL ON NULL INPUT
+                                              "int, int",
+                                              "CREATE OR REPLACE FUNCTION %s(state int, val int) " +
+                                              "RETURNS NULL ON NULL INPUT " +
+                                              "RETURNS int " +
+                                              "LANGUAGE java\n" +
+                                              "AS 'return Integer.valueOf(state + val);';");
+        String fStateNull = createFunction(KEYSPACE,  // state function declared CALLED ON NULL INPUT; treats null args as 0
+                                           "int, int",
+                                           "CREATE OR REPLACE FUNCTION %s(state int, val int) " +
+                                           "CALLED ON NULL INPUT " +
+                                           "RETURNS int " +
+                                           "LANGUAGE java\n" +
+                                           "AS 'return Integer.valueOf(" +
+                                           "   (state != null ? state.intValue() : 0) " +
+                                           "   + (val != null ? val.intValue() : 0));';");
+        String fStateAlwaysNull = createFunction(KEYSPACE,  // CALLED ON NULL INPUT state function that always returns null
+                                           "int, int",
+                                           "CREATE OR REPLACE FUNCTION %s(state int, val int) " +
+                                           "CALLED ON NULL INPUT " +
+                                           "RETURNS int " +
+                                           "LANGUAGE java\n" +
+                                           "AS 'return null;';");
+        String fFinalNonNull = createFunction(KEYSPACE,  // final function declared RETURNS NULL ON NULL INPUT
+                                              "int",
+                                              "CREATE OR REPLACE FUNCTION %s(state int) " +
+                                              "RETURNS NULL ON NULL INPUT " +
+                                              "RETURNS int " +
+                                              "LANGUAGE java\n" +
+                                              "AS 'return Integer.valueOf(state);';");
+        String fFinalNull = createFunction(KEYSPACE,  // final function declared CALLED ON NULL INPUT; returns the state as-is
+                                           "int",
+                                           "CREATE OR REPLACE FUNCTION %s(state int) " +
+                                           "CALLED ON NULL INPUT " +
+                                           "RETURNS int " +
+                                           "LANGUAGE java\n" +
+                                           "AS 'return state;';");
+
+        assertInvalid("CREATE AGGREGATE " + KEYSPACE + ".invAggr(int) " +  // RETURNS NULL ON NULL INPUT state function without INITCOND: expected to be rejected
+                      "SFUNC " + shortFunctionName(fStateNonNull) + " " +
+                      "STYPE int");
+        assertInvalid("CREATE AGGREGATE " + KEYSPACE + ".invAggr(int) " +  // ...and still rejected when a FINALFUNC is added
+                      "SFUNC " + shortFunctionName(fStateNonNull) + " " +
+                      "STYPE int " +
+                      "FINALFUNC " + shortFunctionName(fFinalNonNull));
+
+        String aStateNull = createAggregate(KEYSPACE,  // null-tolerant state function, no INITCOND
+                                               "int",
+                                               "CREATE AGGREGATE %s(int) " +
+                                               "SFUNC " + shortFunctionName(fStateNull) + " " +
+                                               "STYPE int");
+        String aStateNullFinalNull = createAggregate(KEYSPACE,
+                                                        "int",
+                                                        "CREATE AGGREGATE %s(int) " +
+                                                        "SFUNC " + shortFunctionName(fStateNull) + " " +
+                                                        "STYPE int " +
+                                                        "FINALFUNC " + shortFunctionName(fFinalNull));
+        String aStateNullFinalNonNull = createAggregate(KEYSPACE,
+                                                        "int",
+                                                        "CREATE AGGREGATE %s(int) " +
+                                                        "SFUNC " + shortFunctionName(fStateNull) + " " +
+                                                        "STYPE int " +
+                                                        "FINALFUNC " + shortFunctionName(fFinalNonNull));
+        String aStateNonNull = createAggregate(KEYSPACE,  // INITCOND 0 is what makes these fStateNonNull aggregates valid (contrast the assertInvalid cases)
+                                               "int",
+                                               "CREATE AGGREGATE %s(int) " +
+                                               "SFUNC " + shortFunctionName(fStateNonNull) + " " +
+                                               "STYPE int " +
+                                               "INITCOND 0");
+        String aStateNonNullFinalNull = createAggregate(KEYSPACE,
+                                                        "int",
+                                                        "CREATE AGGREGATE %s(int) " +
+                                                        "SFUNC " + shortFunctionName(fStateNonNull) + " " +
+                                                        "STYPE int " +
+                                                        "FINALFUNC " + shortFunctionName(fFinalNull) + " " +
+                                                        "INITCOND 0");
+        String aStateNonNullFinalNonNull = createAggregate(KEYSPACE,
+                                                           "int",
+                                                           "CREATE AGGREGATE %s(int) " +
+                                                           "SFUNC " + shortFunctionName(fStateNonNull) + " " +
+                                                           "STYPE int " +
+                                                           "FINALFUNC " + shortFunctionName(fFinalNonNull) + " " +
+                                                           "INITCOND 0");
+        String aStateAlwaysNullFinalNull = createAggregate(KEYSPACE,  // state is always null, so the final function only ever sees null
+                                                           "int",
+                                                           "CREATE AGGREGATE %s(int) " +
+                                                           "SFUNC " + shortFunctionName(fStateAlwaysNull) + " " +
+                                                           "STYPE int " +
+                                                           "FINALFUNC " + shortFunctionName(fFinalNull));
+        String aStateAlwaysNullFinalNonNull = createAggregate(KEYSPACE,
+                                                           "int",
+                                                           "CREATE AGGREGATE %s(int) " +
+                                                           "SFUNC " + shortFunctionName(fStateAlwaysNull) + " " +
+                                                           "STYPE int " +
+                                                           "FINALFUNC " + shortFunctionName(fFinalNonNull));
+
+        createTable("CREATE TABLE %s (key int PRIMARY KEY, i int)");
+
+        execute("INSERT INTO %s (key, i) VALUES (0, null)");  // key 0 supplies a null input value
+        execute("INSERT INTO %s (key, i) VALUES (1, 1)");
+        execute("INSERT INTO %s (key, i) VALUES (2, 2)");
+        execute("INSERT INTO %s (key, i) VALUES (3, 3)");
+
+        assertRows(execute("SELECT " + aStateNull + "(i) FROM %s WHERE key = 0"), row(0));  // single null input: expected 0 unless the state function always returns null
+        assertRows(execute("SELECT " + aStateNullFinalNull + "(i) FROM %s WHERE key = 0"), row(0));
+        assertRows(execute("SELECT " + aStateNullFinalNonNull + "(i) FROM %s WHERE key = 0"), row(0));
+        assertRows(execute("SELECT " + aStateNonNull + "(i) FROM %s WHERE key = 0"), row(0));
+        assertRows(execute("SELECT " + aStateNonNullFinalNull + "(i) FROM %s WHERE key = 0"), row(0));
+        assertRows(execute("SELECT " + aStateNonNullFinalNonNull + "(i) FROM %s WHERE key = 0"), row(0));
+        assertRows(execute("SELECT " + aStateAlwaysNullFinalNull + "(i) FROM %s WHERE key = 0"), row(new Object[]{null}));  // null wrapped in an array, presumably so a varargs row(Object...) gets one null column rather than a null array
+        assertRows(execute("SELECT " + aStateAlwaysNullFinalNonNull + "(i) FROM %s WHERE key = 0"), row(new Object[]{null}));
+
+        assertRows(execute("SELECT " + aStateNull + "(i) FROM %s WHERE key = 1"), row(1));  // single non-null input
+        assertRows(execute("SELECT " + aStateNullFinalNull + "(i) FROM %s WHERE key = 1"), row(1));
+        assertRows(execute("SELECT " + aStateNullFinalNonNull + "(i) FROM %s WHERE key = 1"), row(1));
+        assertRows(execute("SELECT " + aStateNonNull + "(i) FROM %s WHERE key = 1"), row(1));
+        assertRows(execute("SELECT " + aStateNonNullFinalNull + "(i) FROM %s WHERE key = 1"), row(1));
+        assertRows(execute("SELECT " + aStateNonNullFinalNonNull + "(i) FROM %s WHERE key = 1"), row(1));
+        assertRows(execute("SELECT " + aStateAlwaysNullFinalNull + "(i) FROM %s WHERE key = 1"), row(new Object[]{null}));
+        assertRows(execute("SELECT " + aStateAlwaysNullFinalNonNull + "(i) FROM %s WHERE key = 1"), row(new Object[]{null}));
+
+        assertRows(execute("SELECT " + aStateNull + "(i) FROM %s WHERE key IN (1, 2, 3)"), row(6));  // several non-null inputs: 1 + 2 + 3
+        assertRows(execute("SELECT " + aStateNullFinalNull + "(i) FROM %s WHERE key IN (1, 2, 3)"), row(6));
+        assertRows(execute("SELECT " + aStateNullFinalNonNull + "(i) FROM %s WHERE key IN (1, 2, 3)"), row(6));
+        assertRows(execute("SELECT " + aStateNonNull + "(i) FROM %s WHERE key IN (1, 2, 3)"), row(6));
+        assertRows(execute("SELECT " + aStateNonNullFinalNull + "(i) FROM %s WHERE key IN (1, 2, 3)"), row(6));
+        assertRows(execute("SELECT " + aStateNonNullFinalNonNull + "(i) FROM %s WHERE key IN (1, 2, 3)"), row(6));
+        assertRows(execute("SELECT " + aStateAlwaysNullFinalNull + "(i) FROM %s WHERE key IN (1, 2, 3)"), row(new Object[]{null}));
+        assertRows(execute("SELECT " + aStateAlwaysNullFinalNonNull + "(i) FROM %s WHERE key IN (1, 2, 3)"), row(new Object[]{null}));
+    }
+
+    @Test
     public void testBrokenAggregate() throws Throwable
     {
         createTable("CREATE TABLE %s (key int primary key, val int)");
@@ -827,6 +1000,7 @@ public class AggregationTest extends CQLTester
         String fState = createFunction(KEYSPACE,
                                        "int, int",
                                        "CREATE FUNCTION %s(a int, b int) " +
+                                       "CALLED ON NULL INPUT " +
                                        "RETURNS int " +
                                        "LANGUAGE javascript " +
                                        "AS 'a + b;'");
@@ -855,6 +1029,7 @@ public class AggregationTest extends CQLTester
         String fState = createFunction(KEYSPACE,
                                        "int, int",
                                        "CREATE FUNCTION %s(a int, b int) " +
+                                       "CALLED ON NULL INPUT " +
                                        "RETURNS double " +
                                        "LANGUAGE java " +
                                        "AS 'return Double.valueOf(1.0);'");
@@ -862,6 +1037,7 @@ public class AggregationTest extends CQLTester
         String fFinal = createFunction(KEYSPACE,
                                        "int",
                                        "CREATE FUNCTION %s(a int) " +
+                                       "CALLED ON NULL INPUT " +
                                        "RETURNS int " +
                                        "LANGUAGE java " +
                                        "AS 'return Integer.valueOf(1);';");

http://git-wip-us.apache.org/repos/asf/cassandra/blob/1937bed9/test/unit/org/apache/cassandra/cql3/JsonTest.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/cql3/JsonTest.java b/test/unit/org/apache/cassandra/cql3/JsonTest.java
index 305502e..0380ddd 100644
--- a/test/unit/org/apache/cassandra/cql3/JsonTest.java
+++ b/test/unit/org/apache/cassandra/cql3/JsonTest.java
@@ -72,8 +72,8 @@ public class JsonTest extends CQLTester
         // fromJson() can only be used when the receiver type is known
         assertInvalidMessage("fromJson() cannot be used in the selection clause", "SELECT fromJson(asciival) FROM %s", 0, 0);
 
-        String func1 = createFunction(KEYSPACE, "int", "CREATE FUNCTION %s (a int) RETURNS text LANGUAGE java AS $$ return a.toString(); $$");
-        createFunctionOverload(func1, "int", "CREATE FUNCTION %s (a text) RETURNS text LANGUAGE java AS $$ return new String(a); $$");
+        String func1 = createFunction(KEYSPACE, "int", "CREATE FUNCTION %s (a int) CALLED ON NULL INPUT RETURNS text LANGUAGE java AS $$ return a.toString(); $$");
+        createFunctionOverload(func1, "int", "CREATE FUNCTION %s (a text) CALLED ON NULL INPUT RETURNS text LANGUAGE java AS $$ return new String(a); $$");
 
         assertInvalidMessage("Ambiguous call to function",
                 "INSERT INTO %s (k, textval) VALUES (?, " + func1 + "(fromJson(?)))", 0, "123");

http://git-wip-us.apache.org/repos/asf/cassandra/blob/1937bed9/test/unit/org/apache/cassandra/cql3/PgStringTest.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/cql3/PgStringTest.java b/test/unit/org/apache/cassandra/cql3/PgStringTest.java
index 1870a9a..0a9d702 100644
--- a/test/unit/org/apache/cassandra/cql3/PgStringTest.java
+++ b/test/unit/org/apache/cassandra/cql3/PgStringTest.java
@@ -26,7 +26,7 @@ public class PgStringTest extends CQLTester
     @Test
     public void testPgSyleFunction() throws Throwable
     {
-        execute("create or replace function "+KEYSPACE+".pgfun1 ( input double ) returns text language java\n" +
+        execute("create or replace function "+KEYSPACE+".pgfun1 ( input double ) called on null input returns text language java\n" +
                 "AS $$return \"foobar\";$$");
     }
 
@@ -70,7 +70,7 @@ public class PgStringTest extends CQLTester
     public void testMarkerPgFail() throws Throwable
     {
         // must throw SyntaxException - not StringIndexOutOfBoundsException or similar
-        execute("create function "+KEYSPACE+".pgfun1 ( input double ) returns text language java\n" +
+        execute("create function "+KEYSPACE+".pgfun1 ( input double ) called on null input returns bigint language java\n" +
                 "AS $javasrc$return 0L;$javasrc$;");
     }
 }

http://git-wip-us.apache.org/repos/asf/cassandra/blob/1937bed9/test/unit/org/apache/cassandra/cql3/UFAuthTest.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/cql3/UFAuthTest.java b/test/unit/org/apache/cassandra/cql3/UFAuthTest.java
index 1d63d29..2c36bd1 100644
--- a/test/unit/org/apache/cassandra/cql3/UFAuthTest.java
+++ b/test/unit/org/apache/cassandra/cql3/UFAuthTest.java
@@ -249,6 +249,7 @@ public class UFAuthTest extends CQLTester
         String innerFunctionName = createSimpleFunction();
         String outerFunctionName = createFunction("int",
                                                   "CREATE FUNCTION %s(input int) " +
+                                                  " CALLED ON NULL INPUT" +
                                                   " RETURNS int" +
                                                   " LANGUAGE java" +
                                                   " AS 'return Integer.valueOf(0);'");
@@ -389,6 +390,7 @@ public class UFAuthTest extends CQLTester
     {
         String outerFunc = createFunction("int",
                                           "CREATE FUNCTION %s(input int) " +
+                                          "CALLED ON NULL INPUT " +
                                           "RETURNS int " +
                                           "LANGUAGE java " +
                                           "AS 'return input;'");
@@ -420,6 +422,7 @@ public class UFAuthTest extends CQLTester
     {
         String innerFunc = createFunction("int",
                                           "CREATE FUNCTION %s(input int) " +
+                                          "CALLED ON NULL INPUT " +
                                           "RETURNS int " +
                                           "LANGUAGE java " +
                                           "AS 'return input;'");
@@ -568,6 +571,7 @@ public class UFAuthTest extends CQLTester
     {
         return createFunction("int, int",
                               "CREATE FUNCTION %s(a int, b int) " +
+                              "CALLED ON NULL INPUT " +
                               "RETURNS int " +
                               "LANGUAGE java " +
                               "AS 'return Integer.valueOf( (a != null ? a.intValue() : 0 ) + b.intValue());'");
@@ -577,6 +581,7 @@ public class UFAuthTest extends CQLTester
     {
         return createFunction("int",
                               "CREATE FUNCTION %s(a int) " +
+                              "CALLED ON NULL INPUT " +
                               "RETURNS int " +
                               "LANGUAGE java " +
                               "AS 'return a;'");
@@ -586,6 +591,7 @@ public class UFAuthTest extends CQLTester
     {
         return createFunction("",
                               "CREATE FUNCTION %s() " +
+                              "  CALLED ON NULL INPUT " +
                               "  RETURNS int " +
                               "  LANGUAGE java " +
                               "  AS 'return Integer.valueOf(0);'");

http://git-wip-us.apache.org/repos/asf/cassandra/blob/1937bed9/test/unit/org/apache/cassandra/cql3/UFIdentificationTest.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/cql3/UFIdentificationTest.java b/test/unit/org/apache/cassandra/cql3/UFIdentificationTest.java
index 044c98b..7cac252 100644
--- a/test/unit/org/apache/cassandra/cql3/UFIdentificationTest.java
+++ b/test/unit/org/apache/cassandra/cql3/UFIdentificationTest.java
@@ -368,6 +368,7 @@ public class UFIdentificationTest extends CQLTester
     {
         return createFunction(KEYSPACE, type,
            "CREATE FUNCTION %s(input " + type + ")" +
+           " CALLED ON NULL INPUT" +
            " RETURNS " + type +
            " LANGUAGE java" +
            " AS ' return input;'");