You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@cassandra.apache.org by sn...@apache.org on 2015/07/12 21:22:58 UTC

[2/3] cassandra git commit: sum() and avg() functions missing for smallint and tinyint types

sum() and avg() functions missing for smallint and tinyint types

patch by Robert Stupp; reviewed by Aleksey Yeschenko for CASSANDRA-9671


Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo
Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/01f3d0a1
Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/01f3d0a1
Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/01f3d0a1

Branch: refs/heads/trunk
Commit: 01f3d0a15476ccada7cefeb3c4fbbc157404fc8b
Parents: fc202a7
Author: Robert Stupp <sn...@snazy.de>
Authored: Sun Jul 12 21:19:20 2015 +0200
Committer: Robert Stupp <sn...@snazy.de>
Committed: Sun Jul 12 21:19:20 2015 +0200

----------------------------------------------------------------------
 CHANGES.txt                                     |   1 +
 .../cassandra/cql3/functions/AggregateFcts.java | 158 +++++++++++++++++++
 .../cassandra/cql3/functions/Functions.java     |   4 +
 .../validation/operations/AggregationTest.java  |  45 ++++--
 4 files changed, 195 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/cassandra/blob/01f3d0a1/CHANGES.txt
----------------------------------------------------------------------
diff --git a/CHANGES.txt b/CHANGES.txt
index 15796c4..03458b2 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
 2.2.0-rc3
+ * sum() and avg() functions missing for smallint and tinyint types (CASSANDRA-9671)
  * Revert CASSANDRA-9542 (allow native functions in UDA) (CASSANDRA-9771)
 Merged from 2.1:
  * Fix clientutil jar and tests (CASSANDRA-9760)

http://git-wip-us.apache.org/repos/asf/cassandra/blob/01f3d0a1/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java b/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java
index 865dfbf..1b22da6 100644
--- a/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java
+++ b/src/java/org/apache/cassandra/cql3/functions/AggregateFcts.java
@@ -23,12 +23,14 @@ import java.nio.ByteBuffer;
 import java.util.List;
 
 import org.apache.cassandra.db.marshal.AbstractType;
+import org.apache.cassandra.db.marshal.ByteType;
 import org.apache.cassandra.db.marshal.DecimalType;
 import org.apache.cassandra.db.marshal.DoubleType;
 import org.apache.cassandra.db.marshal.FloatType;
 import org.apache.cassandra.db.marshal.Int32Type;
 import org.apache.cassandra.db.marshal.IntegerType;
 import org.apache.cassandra.db.marshal.LongType;
+import org.apache.cassandra.db.marshal.ShortType;
 
 /**
  * Factory methods for aggregate functions.
@@ -228,6 +230,162 @@ public abstract class AggregateFcts
+    /**
+     * The SUM function for tinyint (byte) values; the result is a tinyint and wraps on overflow.
+     */
+    public static final AggregateFunction sumFunctionForByte =
+            new NativeAggregateFunction("sum", ByteType.instance, ByteType.instance)
+            {
+                public Aggregate newAggregate()
+                {
+                    return new Aggregate()
+                    {
+                        private byte sum; // running total; byte arithmetic wraps silently past 127/-128
+
+                        public void reset()
+                        {
+                            sum = 0;
+                        }
+
+                        public ByteBuffer compute(int protocolVersion)
+                        {
+                            return ((ByteType) returnType()).decompose(sum);
+                        }
+
+                        public void addInput(int protocolVersion, List<ByteBuffer> values)
+                        {
+                            ByteBuffer value = values.get(0);
+
+                            if (value == null) // null inputs are ignored
+                                return;
+
+                            Number number = ((Number) argTypes().get(0).compose(value));
+                            sum += number.byteValue();
+                        }
+                    };
+                }
+            };
+
+    /**
+     * The AVG function for tinyint (byte) values: the integer mean truncated toward zero, or 0 if there is no non-null input.
+     */
+    public static final AggregateFunction avgFunctionForByte =
+            new NativeAggregateFunction("avg", ByteType.instance, ByteType.instance)
+            {
+                public Aggregate newAggregate()
+                {
+                    return new Aggregate()
+                    {
+                        // Accumulate in longs: a byte sum wraps after e.g. 100 + 100, corrupting the mean.
+                        private long sum;
+                        private long count;
+
+                        public void reset()
+                        {
+                            count = 0;
+                            sum = 0;
+                        }
+
+                        public ByteBuffer compute(int protocolVersion)
+                        {
+                            long avg = count == 0 ? 0 : sum / count;
+                            // A mean of byte values lies within the byte range, so the narrowing cast is lossless.
+                            return ((ByteType) returnType()).decompose((byte) avg);
+                        }
+
+                        public void addInput(int protocolVersion, List<ByteBuffer> values)
+                        {
+                            ByteBuffer value = values.get(0);
+
+                            if (value == null) // null inputs are ignored
+                                return;
+
+                            count++;
+                            Number number = ((Number) argTypes().get(0).compose(value));
+                            sum += number.byteValue();
+                        }
+                    };
+                }
+            };
+
+    /**
+     * The SUM function for smallint (short) values; the result is a smallint and wraps on overflow.
+     */
+    public static final AggregateFunction sumFunctionForShort =
+            new NativeAggregateFunction("sum", ShortType.instance, ShortType.instance)
+            {
+                public Aggregate newAggregate()
+                {
+                    return new Aggregate()
+                    {
+                        private short sum; // running total; short arithmetic wraps silently past 32767/-32768
+
+                        public void reset()
+                        {
+                            sum = 0;
+                        }
+
+                        public ByteBuffer compute(int protocolVersion)
+                        {
+                            return ((ShortType) returnType()).decompose(sum);
+                        }
+
+                        public void addInput(int protocolVersion, List<ByteBuffer> values)
+                        {
+                            ByteBuffer value = values.get(0);
+
+                            if (value == null) // null inputs are ignored
+                                return;
+
+                            Number number = ((Number) argTypes().get(0).compose(value));
+                            sum += number.shortValue();
+                        }
+                    };
+                }
+            };
+
+    /**
+     * The AVG function for smallint (short) values: the integer mean truncated toward zero, or 0 if there is no non-null input.
+     */
+    public static final AggregateFunction avgFunctionForShort =
+            new NativeAggregateFunction("avg", ShortType.instance, ShortType.instance)
+            {
+                public Aggregate newAggregate()
+                {
+                    return new Aggregate()
+                    {
+                        // Accumulate in longs: a short sum wraps after e.g. 20000 + 20000, corrupting the mean.
+                        private long sum;
+                        private long count;
+
+                        public void reset()
+                        {
+                            count = 0;
+                            sum = 0;
+                        }
+
+                        public ByteBuffer compute(int protocolVersion)
+                        {
+                            long avg = count == 0 ? 0 : sum / count;
+                            // A mean of short values lies within the short range, so the narrowing cast is lossless.
+                            return ((ShortType) returnType()).decompose((short) avg);
+                        }
+
+                        public void addInput(int protocolVersion, List<ByteBuffer> values)
+                        {
+                            ByteBuffer value = values.get(0);
+
+                            if (value == null) // null inputs are ignored
+                                return;
+
+                            count++;
+                            Number number = ((Number) argTypes().get(0).compose(value));
+                            sum += number.shortValue();
+                        }
+                    };
+                }
+            };
+
     /**
      * The SUM function for int32 values.
      */
     public static final AggregateFunction sumFunctionForInt32 =
             new NativeAggregateFunction("sum", Int32Type.instance, Int32Type.instance)
             {

http://git-wip-us.apache.org/repos/asf/cassandra/blob/01f3d0a1/src/java/org/apache/cassandra/cql3/functions/Functions.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/cql3/functions/Functions.java b/src/java/org/apache/cassandra/cql3/functions/Functions.java
index 85f2817..e31fc9f 100644
--- a/src/java/org/apache/cassandra/cql3/functions/Functions.java
+++ b/src/java/org/apache/cassandra/cql3/functions/Functions.java
@@ -83,12 +83,16 @@ public abstract class Functions
                 declare(AggregateFcts.makeMinFunction(type.getType()));
             }
         }
+        declare(AggregateFcts.sumFunctionForByte);
+        declare(AggregateFcts.sumFunctionForShort);
         declare(AggregateFcts.sumFunctionForInt32);
         declare(AggregateFcts.sumFunctionForLong);
         declare(AggregateFcts.sumFunctionForFloat);
         declare(AggregateFcts.sumFunctionForDouble);
         declare(AggregateFcts.sumFunctionForDecimal);
         declare(AggregateFcts.sumFunctionForVarint);
+        declare(AggregateFcts.avgFunctionForByte);
+        declare(AggregateFcts.avgFunctionForShort);
         declare(AggregateFcts.avgFunctionForInt32);
         declare(AggregateFcts.avgFunctionForLong);
         declare(AggregateFcts.avgFunctionForFloat);

http://git-wip-us.apache.org/repos/asf/cassandra/blob/01f3d0a1/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java
----------------------------------------------------------------------
diff --git a/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java b/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java
index 7455dbc..62461b8 100644
--- a/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java
+++ b/test/unit/org/apache/cassandra/cql3/validation/operations/AggregationTest.java
@@ -42,27 +42,46 @@ public class AggregationTest extends CQLTester
     @Test
     public void testFunctions() throws Throwable
     {
-        createTable("CREATE TABLE %s (a int, b int, c double, d decimal, primary key (a, b))");
+        createTable("CREATE TABLE %s (a int, b int, c double, d decimal, e smallint, f tinyint, primary key (a, b))"); // e (smallint) and f (tinyint) exercise the new aggregates
 
         // Test with empty table
         assertColumnNames(execute("SELECT COUNT(*) FROM %s"), "count");
         assertRows(execute("SELECT COUNT(*) FROM %s"), row(0L));
-        assertColumnNames(execute("SELECT max(b), min(b), sum(b), avg(b) , max(c), sum(c), avg(c), sum(d), avg(d) FROM %s"),
-                          "system.max(b)", "system.min(b)", "system.sum(b)", "system.avg(b)", "system.max(c)", "system.sum(c)", "system.avg(c)", "system.sum(d)", "system.avg(d)");
-        assertRows(execute("SELECT max(b), min(b), sum(b), avg(b) , max(c), sum(c), avg(c), sum(d), avg(d) FROM %s"),
-                   row(null, null, 0, 0, null, 0.0, 0.0, new BigDecimal("0"), new BigDecimal("0")));
-
-        execute("INSERT INTO %s (a, b, c, d) VALUES (1, 1, 11.5, 11.5)");
-        execute("INSERT INTO %s (a, b, c, d) VALUES (1, 2, 9.5, 1.5)");
-        execute("INSERT INTO %s (a, b, c, d) VALUES (1, 3, 9.0, 2.0)");
-
-        assertRows(execute("SELECT max(b), min(b), sum(b), avg(b) , max(c), sum(c), avg(c), sum(d), avg(d) FROM %s"),
-                   row(3, 1, 6, 2, 11.5, 30.0, 10.0, new BigDecimal("15.0"), new BigDecimal("5.0")));
+        assertColumnNames(execute("SELECT max(b), min(b), sum(b), avg(b)," +
+                                  "max(c), sum(c), avg(c)," +
+                                  "sum(d), avg(d)," +
+                                  "max(e), min(e), sum(e), avg(e)," +
+                                  "max(f), min(f), sum(f), avg(f) FROM %s"),
+                          "system.max(b)", "system.min(b)", "system.sum(b)", "system.avg(b)",
+                          "system.max(c)", "system.sum(c)", "system.avg(c)",
+                          "system.sum(d)", "system.avg(d)",
+                          "system.max(e)", "system.min(e)", "system.sum(e)", "system.avg(e)",
+                          "system.max(f)", "system.min(f)", "system.sum(f)", "system.avg(f)");
+        assertRows(execute("SELECT max(b), min(b), sum(b), avg(b)," +
+                           "max(c), sum(c), avg(c)," +
+                           "sum(d), avg(d)," +
+                           "max(e), min(e), sum(e), avg(e)," +
+                           "max(f), min(f), sum(f), avg(f) FROM %s"),
+                   row(null, null, 0, 0, null, 0.0, 0.0, new BigDecimal("0"), new BigDecimal("0"), // empty table: max/min are null, sum/avg are zero
+                       null, null, (short)0, (short)0, // smallint aggregates return short values
+                       null, null, (byte)0, (byte)0)); // tinyint aggregates return byte values
+
+        execute("INSERT INTO %s (a, b, c, d, e, f) VALUES (1, 1, 11.5, 11.5, 1, 1)"); // e and f mirror b: 1, 2, 3
+        execute("INSERT INTO %s (a, b, c, d, e, f) VALUES (1, 2, 9.5, 1.5, 2, 2)");
+        execute("INSERT INTO %s (a, b, c, d, e, f) VALUES (1, 3, 9.0, 2.0, 3, 3)");
+
+        assertRows(execute("SELECT max(b), min(b), sum(b), avg(b) , max(c), sum(c), avg(c), sum(d), avg(d)," +
+                           "max(e), min(e), sum(e), avg(e)," +
+                           "max(f), min(f), sum(f), avg(f)" +
+                           " FROM %s"),
+                   row(3, 1, 6, 2, 11.5, 30.0, 10.0, new BigDecimal("15.0"), new BigDecimal("5.0"),
+                       (short)3, (short)1, (short)6, (short)2, // max, min, sum, avg of e = {1, 2, 3}
+                       (byte)3, (byte)1, (byte)6, (byte)2)); // max, min, sum, avg of f = {1, 2, 3}
 
         execute("INSERT INTO %s (a, b, d) VALUES (1, 5, 1.0)");
         assertRows(execute("SELECT COUNT(*) FROM %s"), row(4L));
         assertRows(execute("SELECT COUNT(1) FROM %s"), row(4L));
-        assertRows(execute("SELECT COUNT(b), count(c) FROM %s"), row(4L, 3L));
+        assertRows(execute("SELECT COUNT(b), count(c), count(e), count(f) FROM %s"), row(4L, 3L, 3L, 3L)); // the b = 5 row leaves c, e and f unset
     }
 
     @Test