You are viewing a plain text version of this content. The canonical link for it is here.
Posted to pr@cassandra.apache.org by GitBox <gi...@apache.org> on 2022/11/24 10:13:10 UTC

[GitHub] [cassandra] blerer commented on a diff in pull request #2024: CASSANDRA-18060 trunk: Add aggregation scalar functions on collections

blerer commented on code in PR #2024:
URL: https://github.com/apache/cassandra/pull/2024#discussion_r1031328860


##########
src/java/org/apache/cassandra/serializers/ListSerializer.java:
##########
@@ -255,4 +256,33 @@ public Range<Integer> getIndexesRangeFromSerialized(ByteBuffer collection,
     {
         throw new UnsupportedOperationException();
     }
+
+    @Override
+    public void forEach(ByteBuffer input, ProtocolVersion version, Consumer<ByteBuffer> action)
+    {
+        try
+        {
+            int s = readCollectionSize(input, ByteBufferAccessor.instance, ProtocolVersion.V3);
+            int offset = sizeOfCollectionSize(s, ProtocolVersion.V3);
+
+            for (int i = 0; i < s; i++)
+            {
+                int size = ByteBufferAccessor.instance.getInt(input, offset);
+                if (size < 0)
+                    continue;
+
+                offset += TypeSizes.INT_SIZE;
+
+                ByteBuffer value = ByteBufferAccessor.instance.slice(input, offset, size);
+
+                action.accept(value);
+
+                offset += size;

Review Comment:
   ListSerializer inherits from CollectionSerializer, so it should be possible to use the same logic as the one used for SetSerializer. Am I missing something?



##########
src/java/org/apache/cassandra/cql3/functions/CollectionFcts.java:
##########
@@ -0,0 +1,371 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.cassandra.cql3.functions;
+
+import java.nio.ByteBuffer;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import com.google.common.collect.ImmutableList;
+
+import org.apache.cassandra.cql3.CQL3Type;
+import org.apache.cassandra.db.marshal.AbstractType;
+import org.apache.cassandra.db.marshal.CollectionType;
+import org.apache.cassandra.db.marshal.Int32Type;
+import org.apache.cassandra.db.marshal.ListType;
+import org.apache.cassandra.db.marshal.MapType;
+import org.apache.cassandra.db.marshal.SetType;
+import org.apache.cassandra.transport.ProtocolVersion;
+
+/**
+ * Native CQL functions for collections (sets, list and maps).
+ * <p>
+ * All the functions provided here are {@link NativeScalarFunction}, and they are meant to be applied to single
+ * collection values to perform some kind of aggregation with the elements of the collection argument. When possible,
+ * the implementation of these aggregation functions is based on the accross-rows aggregation functions available on
+ * {@link AggregateFcts}, so both across-rows and within-collection aggregations have the same behaviour.
+ */
+public class CollectionFcts
+{
+    public static void addFunctionsTo(NativeFunctions functions)
+    {
+        functions.add(new FunctionFactory("map_keys", FunctionParameter.anyMap())
+        {
+            @Override
+            protected NativeFunction doGetOrCreateFunction(List<AbstractType<?>> argTypes, AbstractType<?> receiverType)
+            {
+                return makeMapKeysFunction(name.name, (MapType<?, ?>) argTypes.get(0));
+            }
+        });
+
+        functions.add(new FunctionFactory("map_values", FunctionParameter.anyMap())
+        {
+            @Override
+            protected NativeFunction doGetOrCreateFunction(List<AbstractType<?>> argTypes, AbstractType<?> receiverType)
+            {
+                return makeMapValuesFunction(name.name, (MapType<?, ?>) argTypes.get(0));
+            }
+        });
+
+        functions.add(new FunctionFactory("collection_count", FunctionParameter.anyCollection())
+        {
+            @Override
+            protected NativeFunction doGetOrCreateFunction(List<AbstractType<?>> argTypes, AbstractType<?> receiverType)
+            {
+                return makeCollectionCountFunction(name.name, (CollectionType<?>) argTypes.get(0));
+            }
+        });
+
+        functions.add(new FunctionFactory("collection_min", FunctionParameter.setOrList())
+        {
+            @Override
+            protected NativeFunction doGetOrCreateFunction(List<AbstractType<?>> argTypes, AbstractType<?> receiverType)
+            {
+                return makeCollectionMinFunction(name.name, (CollectionType<?>) argTypes.get(0));
+            }
+        });
+
+        functions.add(new FunctionFactory("collection_max", FunctionParameter.setOrList())
+        {
+            @Override
+            protected NativeFunction doGetOrCreateFunction(List<AbstractType<?>> argTypes, AbstractType<?> receiverType)
+            {
+                return makeCollectionMaxFunction(name.name, (CollectionType<?>) argTypes.get(0));
+            }
+        });
+
+        functions.add(new FunctionFactory("collection_sum", FunctionParameter.numericSetOrList())
+        {
+            @Override
+            protected NativeFunction doGetOrCreateFunction(List<AbstractType<?>> argTypes, AbstractType<?> receiverType)
+            {
+                return makeCollectionSumFunction(name.name, (CollectionType<?>) argTypes.get(0));
+            }
+        });
+
+        functions.add(new FunctionFactory("collection_avg", FunctionParameter.numericSetOrList())
+        {
+            @Override
+            protected NativeFunction doGetOrCreateFunction(List<AbstractType<?>> argTypes, AbstractType<?> receiverType)
+            {
+                return makeCollectionAvgFunction(name.name, (CollectionType<?>) argTypes.get(0));
+            }
+        });
+    }
+
+    /**
+     * Returns a native scalar function for getting the keys of a map column, as a set.
+     *
+     * @param name      the name of the function
+     * @param inputType the type of the map argument by the returned function
+     * @param <K>       the class of the map argument keys
+     * @param <V>       the class of the map argument values
+     * @return a function returning a serialized set containing the keys of the map passed as argument
+     */
+    private static <K, V> NativeScalarFunction makeMapKeysFunction(String name, MapType<K, V> inputType)
+    {
+        SetType<K> outputType = SetType.getInstance(inputType.getKeysType(), false);
+
+        return new NativeScalarFunction(name, outputType, inputType)
+        {
+            @Override
+            public ByteBuffer execute(ProtocolVersion protocolVersion, List<ByteBuffer> parameters)
+            {
+                ByteBuffer value = parameters.get(0);
+                if (value == null)
+                    return null;
+
+                Map<K, V> map = inputType.compose(value);
+                Set<K> keys = map.keySet();
+                return outputType.decompose(keys);
+            }
+        };
+    }
+
+    /**
+     * Returns a native scalar function for getting the values of a map column, as a list.
+     *
+     * @param name      the name of the function
+     * @param inputType the type of the map argument accepted by the returned function
+     * @param <K>       the class of the map argument keys
+     * @param <V>       the class of the map argument values
+     * @return a function returning a serialized list containing the values of the map passed as argument
+     */
+    private static <K, V> NativeScalarFunction makeMapValuesFunction(String name, MapType<K, V> inputType)
+    {
+        ListType<V> outputType = ListType.getInstance(inputType.getValuesType(), false);
+
+        return new NativeScalarFunction(name, outputType, inputType)
+        {
+            @Override
+            public ByteBuffer execute(ProtocolVersion protocolVersion, List<ByteBuffer> parameters)
+            {
+                ByteBuffer value = parameters.get(0);
+                if (value == null)
+                    return null;
+
+                Map<K, V> map = inputType.compose(value);
+                List<V> values = ImmutableList.copyOf(map.values());
+                return outputType.decompose(values);
+            }
+        };
+    }
+
+    /**
+     * Returns a native scalar function for getting the number of elements in a collection.
+     *
+     * @param name      the name of the function
+     * @param inputType the type of the collection argument accepted by the returned function
+     * @param <T>       the type of the elements of the collection argument
+     * @return a function returning the number of elements in the collection passed as argument, as a 32-bit integer
+     */
+    private static <T> NativeScalarFunction makeCollectionCountFunction(String name, CollectionType<T> inputType)
+    {
+        return new NativeScalarFunction(name, Int32Type.instance, inputType)
+        {
+            @Override
+            public ByteBuffer execute(ProtocolVersion protocolVersion, List<ByteBuffer> parameters)
+            {
+                ByteBuffer value = parameters.get(0);
+                if (value == null)
+                    return null;
+
+                int size = inputType.size(value);
+                return Int32Type.instance.decompose(size);
+            }
+        };
+    }
+
+    /**
+     * Returns a native scalar function for getting the min element in a collection.
+     *
+     * @param name      the name of the function
+     * @param inputType the type of the collection argument accepted by the returned function
+     * @param <T>       the type of the elements of the collection argument
+     * @return a function returning the min element in the collection passed as argument
+     */
+    private static <T> NativeScalarFunction makeCollectionMinFunction(String name, CollectionType<T> inputType)
+    {
+        AbstractType<?> elementsType = elementsType(inputType);
+        NativeAggregateFunction function = elementsType.isCounter()
+                                           ? AggregateFcts.minFunctionForCounter
+                                           : AggregateFcts.makeMinFunction(elementsType);
+        return new CollectionAggregationFunction(name, inputType, function);
+    }
+
+    /**
+     * Returns a native scalar function for getting the max element in a collection.
+     *
+     * @param name      the name of the function
+     * @param inputType the type of the collection argument accepted by the returned function
+     * @param <T>       the type of the elements of the collection argument
+     * @return a function returning the max element in the collection passed as argument
+     */
+    private static <T> NativeScalarFunction makeCollectionMaxFunction(String name, CollectionType<T> inputType)
+    {
+        AbstractType<?> elementsType = elementsType(inputType);
+        NativeAggregateFunction function = elementsType.isCounter()
+                                           ? AggregateFcts.maxFunctionForCounter
+                                           : AggregateFcts.makeMaxFunction(elementsType);
+        return new CollectionAggregationFunction(name, inputType, function);
+    }
+
+    /**
+     * Returns a native scalar function for getting the sum of the elements in a numeric collection.
+     * </p>
+     * The value returned by the function is of the same type as elements of its input collection, so there is a risk
+     * of overflow if the sum of the values exceeds the maximum value that the type can represent.
+     *
+     * @param name      the name of the function
+     * @param inputType the type of the collection argument accepted by the returned function
+     * @param <T>       the type of the elements of the collection argument
+     * @return a function returning the sum of the elements in the collection passed as argument
+     */
+    private static <T> NativeScalarFunction makeCollectionSumFunction(String name, CollectionType<T> inputType)
+    {
+        CQL3Type elementsType = elementsType(inputType).asCQL3Type();
+        NativeAggregateFunction function = getSumFunction((CQL3Type.Native) elementsType);
+        return new CollectionAggregationFunction(name, inputType, function);
+    }
+
+    private static NativeAggregateFunction getSumFunction(CQL3Type.Native type)
+    {
+        switch (type)
+        {
+            case TINYINT:
+                return AggregateFcts.sumFunctionForByte;
+            case SMALLINT:
+                return AggregateFcts.sumFunctionForShort;
+            case INT:
+                return AggregateFcts.sumFunctionForInt32;
+            case BIGINT:
+                return AggregateFcts.sumFunctionForLong;
+            case FLOAT:
+                return AggregateFcts.sumFunctionForFloat;
+            case DOUBLE:
+                return AggregateFcts.sumFunctionForDouble;
+            case VARINT:
+                return AggregateFcts.sumFunctionForVarint;
+            case DECIMAL:
+                return AggregateFcts.sumFunctionForDecimal;
+            default:
+                throw new AssertionError("Expected numeric collection but found " + type);
+        }
+    }
+
+    /**
+     * Returns a native scalar function for getting the average of the elements in a numeric collection.
+     * </p>
+     * The average of an empty collection returns zero. The value returned by the function is of the same type as the
+     * elements of its input collection, so if those don't have a decimal part then the returned average won't have a
+     * decimal part either.
+     *
+     * @param name      the name of the function
+     * @param inputType the type of the collection argument accepted by the returned function
+     * @param <T>       the type of the elements of the collection argument
+     * @return a function returning the average value of the elements in the collection passed as argument
+     */
+    private static <T> NativeScalarFunction makeCollectionAvgFunction(String name, CollectionType<T> inputType)
+    {
+        CQL3Type elementsType = elementsType(inputType).asCQL3Type();
+        NativeAggregateFunction function = getAvgFunction((CQL3Type.Native) elementsType);
+        return new CollectionAggregationFunction(name, inputType, function);
+    }
+
+    private static NativeAggregateFunction getAvgFunction(CQL3Type.Native type)
+    {
+        switch (type)
+        {
+            case TINYINT:
+                return AggregateFcts.avgFunctionForByte;
+            case SMALLINT:
+                return AggregateFcts.avgFunctionForShort;
+            case INT:
+                return AggregateFcts.avgFunctionForInt32;
+            case BIGINT:
+                return AggregateFcts.avgFunctionForLong;
+            case FLOAT:
+                return AggregateFcts.avgFunctionForFloat;
+            case DOUBLE:
+                return AggregateFcts.avgFunctionForDouble;
+            case VARINT:
+                return AggregateFcts.avgFunctionForVarint;
+            case DECIMAL:
+                return AggregateFcts.avgFunctionForDecimal;
+            default:
+                throw new AssertionError("Expected numeric collection but found " + type);
+        }
+    }
+
+    /**
+     * @return the type of the elements of the specified collection type.
+     */
+    private static AbstractType<?> elementsType(CollectionType<?> type)
+    {
+        if (type.kind == CollectionType.Kind.LIST)
+        {
+            return ((ListType<?>) type).getElementsType();
+        }
+        else if (type.kind == CollectionType.Kind.SET)
+        {
+            return ((SetType<?>) type).getElementsType();
+        }
+        else
+        {
+            throw new AssertionError("Cannot get the element type of: " + type);
+        }

Review Comment:
   Nit: the `else` branches are unnecessary, since each branch ends with a return.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: pr-unsubscribe@cassandra.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: pr-unsubscribe@cassandra.apache.org
For additional commands, e-mail: pr-help@cassandra.apache.org