You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@arrow.apache.org by em...@apache.org on 2019/08/16 02:43:55 UTC

[arrow] branch master updated: ARROW-6212: [Java] Support vector rank operation

This is an automated email from the ASF dual-hosted git repository.

emkornfield pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 71b32b9  ARROW-6212: [Java] Support vector rank operation
71b32b9 is described below

commit 71b32b9b87fa9825d2112644c7ce15d6f71b9174
Author: liyafan82 <fa...@foxmail.com>
AuthorDate: Thu Aug 15 19:43:36 2019 -0700

    ARROW-6212: [Java] Support vector rank operation
    
    Given an unsorted vector, we want to get the index of the ith smallest element in the vector. The rank operation provides this.
    
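    We provide an implementation that gets the index with the desired rank without sorting the vector (the vector is left intact). It takes expected O(n) time, where n is the vector length; because partitioning pivots on the first element of each range, the worst case (e.g. already-sorted input) is O(n^2).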
    
    Closes #5066 from liyafan82/fly_0812_rank and squashes the following commits:
    
    623b08531 <liyafan82>  Support vector rank operation
    
    Authored-by: liyafan82 <fa...@foxmail.com>
    Signed-off-by: Micah Kornfield <em...@gmail.com>
---
 .../apache/arrow/algorithm/rank/VectorRank.java    |  89 +++++++++++++
 .../apache/arrow/algorithm/sort/IndexSorter.java   |  16 ++-
 .../arrow/algorithm/rank/TestVectorRank.java       | 146 +++++++++++++++++++++
 3 files changed, 249 insertions(+), 2 deletions(-)

diff --git a/java/algorithm/src/main/java/org/apache/arrow/algorithm/rank/VectorRank.java b/java/algorithm/src/main/java/org/apache/arrow/algorithm/rank/VectorRank.java
new file mode 100644
index 0000000..43c9a5b
--- /dev/null
+++ b/java/algorithm/src/main/java/org/apache/arrow/algorithm/rank/VectorRank.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.algorithm.rank;
+
+import java.util.stream.IntStream;
+
+import org.apache.arrow.algorithm.sort.IndexSorter;
+import org.apache.arrow.algorithm.sort.VectorValueComparator;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.util.Preconditions;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.ValueVector;
+
+/**
+ * Utility for calculating ranks of vector elements.
+ * Instances keep no per-call state: every call allocates, uses and releases its own temporary
+ * index vector.
+ * @param <V> the vector type
+ */
+public class VectorRank<V extends ValueVector> {
+
+  private final BufferAllocator allocator;
+
+  /**
+   * Constructs a vector rank utility.
+   * @param allocator the allocator used for the temporary index vector of each call.
+   */
+  public VectorRank(BufferAllocator allocator) {
+    this.allocator = allocator;
+  }
+
+  /**
+   * Given a rank r, gets the index of the element that is the rth smallest (0-based) in the vector.
+   * The operation is performed without changing the vector: a temporary vector of element indices
+   * is partitioned instead (quickselect). Because each partition step pivots on the first element
+   * of the current range, the expected running time is O(n) for input in random order, but the
+   * worst case is O(n^2) (for example, already-sorted input); n is the length of the vector.
+   * @param vector the vector from which to get the element index.
+   * @param comparator the criteria for vector element comparison; it is attached to {@code vector}.
+   * @param rank the 0-based rank to determine.
+   * @return the element index with the given rank.
+   * @throws IllegalArgumentException if {@code rank} is negative or not less than the value count.
+   */
+  public int indexAtRank(V vector, VectorValueComparator<V> comparator, int rank) {
+    final int valueCount = vector.getValueCount();
+    Preconditions.checkArgument(rank >= 0 && rank < valueCount,
+        "rank must be in [0, %s), but was %s", valueCount, rank);
+    // A local try-with-resources variable guarantees the index vector is released, and a failed
+    // construction can no longer cause a NullPointerException in cleanup that masks the real error.
+    try (IntVector indices = new IntVector("index vector", allocator)) {
+      indices.allocateNew(valueCount);
+      IntStream.range(0, valueCount).forEach(i -> indices.set(i, i));
+
+      comparator.attachVector(vector);
+      int pos = getRank(0, valueCount - 1, rank, indices, comparator);
+      return indices.get(pos);
+    }
+  }
+
+  /**
+   * Repeatedly partitions positions {@code low..high} of {@code indices} until the split point
+   * equals {@code rank}. Written as a loop rather than recursion so that the worst case (up to one
+   * partition step per element, e.g. on already-sorted input) cannot overflow the call stack.
+   * @return the position in {@code indices} that holds the index of the element with that rank.
+   */
+  private int getRank(int low, int high, int rank, IntVector indices,
+      VectorValueComparator<V> comparator) {
+    int lo = low;
+    int hi = high;
+    while (true) {
+      int mid = IndexSorter.partition(lo, hi, indices, comparator);
+      if (mid < rank) {
+        lo = mid + 1;
+      } else if (mid > rank) {
+        hi = mid - 1;
+      } else {
+        return mid;
+      }
+    }
+  }
+}
diff --git a/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/IndexSorter.java b/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/IndexSorter.java
index d85eb6f..0f03e5c 100644
--- a/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/IndexSorter.java
+++ b/java/algorithm/src/main/java/org/apache/arrow/algorithm/sort/IndexSorter.java
@@ -60,13 +60,31 @@ public class IndexSorter<V extends ValueVector> {
 
+  /**
+   * Sorts positions {@code low..high} (inclusive) of the {@code indices} vector by quicksort.
+   * NOTE(review): recursion depth grows with partition imbalance; because {@link #partition}
+   * pivots on the first index of each range, already-ordered input presumably nests close to
+   * {@code high - low} calls deep, which may overflow the stack for long vectors.
+   */
   private void quickSort(int low, int high) {
     if (low < high) {
-      int mid = partition(low, high);
+      int mid = partition(low, high, indices, comparator);
       quickSort(low, mid - 1);
       quickSort(mid + 1, high);
     }
   }
 
-  private int partition(int low, int high) {
+  /**
+   * Partition a range of values in a vector into two parts, with elements in one part smaller than
+   * elements from the other part. The partition is based on the element indices, so it does
+   * not modify the underlying vector.
+   * @param low the lower bound of the range.
+   * @param high the upper bound of the range.
+   * @param indices vector element indices.
+   * @param comparator criteria for comparison.
+   * @param <T> the vector type.
+   * @return the index of the split point.
+   */
+  public static <T extends ValueVector> int partition(
+          int low, int high, IntVector indices, VectorValueComparator<T> comparator) {
     int pivotIndex = indices.get(low);
 
     while (low < high) {
diff --git a/java/algorithm/src/test/java/org/apache/arrow/algorithm/rank/TestVectorRank.java b/java/algorithm/src/test/java/org/apache/arrow/algorithm/rank/TestVectorRank.java
new file mode 100644
index 0000000..f139b2e
--- /dev/null
+++ b/java/algorithm/src/test/java/org/apache/arrow/algorithm/rank/TestVectorRank.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.algorithm.rank;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+import java.nio.charset.StandardCharsets;
+
+import org.apache.arrow.algorithm.sort.DefaultVectorComparators;
+import org.apache.arrow.algorithm.sort.VectorValueComparator;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.VarCharVector;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+/**
+ * Test cases for {@link org.apache.arrow.algorithm.rank.VectorRank}.
+ */
+public class TestVectorRank {
+
+  private BufferAllocator allocator;
+
+  private static final int VECTOR_LENGTH = 10;
+
+  /**
+   * Distinct test values in unsorted order; {@code VALUES[i]} is stored at vector index i.
+   */
+  private static final int[] VALUES = {1, 5, 3, 7, 9, 8, 2, 0, 4, 6};
+
+  /**
+   * {@code EXPECTED_INDICES[r]} is the vector index of the element of rank r in {@link #VALUES}.
+   */
+  private static final int[] EXPECTED_INDICES = {7, 0, 6, 2, 8, 1, 9, 3, 5, 4};
+
+  @Before
+  public void prepare() {
+    allocator = new RootAllocator(1024 * 1024);
+  }
+
+  @After
+  public void shutdown() {
+    allocator.close();
+  }
+
+  /**
+   * Allocates {@code vector} for {@code values.length} elements and stores the values in order.
+   */
+  private static void fillIntVector(IntVector vector, int[] values) {
+    vector.allocateNew(values.length);
+    vector.setValueCount(values.length);
+    for (int i = 0; i < values.length; i++) {
+      vector.set(i, values[i]);
+    }
+  }
+
+  @Test
+  public void testFixedWidthRank() {
+    VectorRank<IntVector> rank = new VectorRank<>(allocator);
+    try (IntVector vector = new IntVector("int vec", allocator)) {
+      fillIntVector(vector, VALUES);
+
+      VectorValueComparator<IntVector> comparator =
+              DefaultVectorComparators.createDefaultComparator(vector);
+      for (int r = 0; r < VECTOR_LENGTH; r++) {
+        assertEquals(EXPECTED_INDICES[r], rank.indexAtRank(vector, comparator, r));
+      }
+    }
+  }
+
+  @Test
+  public void testVariableWidthRank() {
+    VectorRank<VarCharVector> rank = new VectorRank<>(allocator);
+    try (VarCharVector vector = new VarCharVector("varchar vec", allocator)) {
+      vector.allocateNew(VECTOR_LENGTH * 5, VECTOR_LENGTH);
+      vector.setValueCount(VECTOR_LENGTH);
+      for (int i = 0; i < VECTOR_LENGTH; i++) {
+        // single-digit strings, so lexicographic order matches numeric order
+        vector.set(i, String.valueOf(VALUES[i]).getBytes(StandardCharsets.UTF_8));
+      }
+
+      VectorValueComparator<VarCharVector> comparator =
+              DefaultVectorComparators.createDefaultComparator(vector);
+      for (int r = 0; r < VECTOR_LENGTH; r++) {
+        assertEquals(EXPECTED_INDICES[r], rank.indexAtRank(vector, comparator, r));
+      }
+    }
+  }
+
+  @Test
+  public void testOrderedInput() {
+    // Already-ordered input is the unbalanced-partition case for first-element pivots.
+    VectorRank<IntVector> rank = new VectorRank<>(allocator);
+    int[] ascending = new int[VECTOR_LENGTH];
+    int[] descending = new int[VECTOR_LENGTH];
+    for (int i = 0; i < VECTOR_LENGTH; i++) {
+      ascending[i] = i;
+      descending[i] = VECTOR_LENGTH - 1 - i;
+    }
+    try (IntVector asc = new IntVector("asc vec", allocator);
+        IntVector desc = new IntVector("desc vec", allocator)) {
+      fillIntVector(asc, ascending);
+      fillIntVector(desc, descending);
+
+      VectorValueComparator<IntVector> ascComparator =
+              DefaultVectorComparators.createDefaultComparator(asc);
+      VectorValueComparator<IntVector> descComparator =
+              DefaultVectorComparators.createDefaultComparator(desc);
+      for (int r = 0; r < VECTOR_LENGTH; r++) {
+        assertEquals(r, rank.indexAtRank(asc, ascComparator, r));
+        assertEquals(VECTOR_LENGTH - 1 - r, rank.indexAtRank(desc, descComparator, r));
+      }
+    }
+  }
+
+  @Test
+  public void testRankNegative() {
+    VectorRank<IntVector> rank = new VectorRank<>(allocator);
+    try (IntVector vector = new IntVector("int vec", allocator)) {
+      fillIntVector(vector, VALUES);
+
+      VectorValueComparator<IntVector> comparator =
+              DefaultVectorComparators.createDefaultComparator(vector);
+
+      // valid ranks are [0, VECTOR_LENGTH): check just below, at and beyond the bounds
+      assertThrows(IllegalArgumentException.class, () -> rank.indexAtRank(vector, comparator, -1));
+      assertThrows(IllegalArgumentException.class,
+          () -> rank.indexAtRank(vector, comparator, VECTOR_LENGTH));
+      assertThrows(IllegalArgumentException.class,
+          () -> rank.indexAtRank(vector, comparator, VECTOR_LENGTH + 1));
+    }
+  }
+}