You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@doris.apache.org by zh...@apache.org on 2020/01/10 06:11:27 UTC

[incubator-doris] branch master updated: Convert from arrow to rowbatch (#2723)

This is an automated email from the ASF dual-hosted git repository.

zhaoc pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 18a11f5  Convert from arrow to rowbatch (#2723)
18a11f5 is described below

commit 18a11f5663c792075956364a778ec510f1cf6c6c
Author: Youngwb <ya...@163.com>
AuthorDate: Fri Jan 10 14:11:15 2020 +0800

    Convert from arrow to rowbatch (#2723)
    
    For #2722
    In our test environment, the Doris cluster used 1 FE and 7 BEs (32C+128G). When using the spark-doris connector to query a table containing 67 columns, it took about 1 hour for the query to return 69 million rows of data. After the improvement, the same query condition took 2.5 minutes, so query performance was significantly improved.
---
 .../apache/doris/spark/serialization/RowBatch.java | 136 +++++++++++++++------
 1 file changed, 96 insertions(+), 40 deletions(-)

diff --git a/extension/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java b/extension/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
index 668e72d..d710fbb 100644
--- a/extension/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
+++ b/extension/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
@@ -73,6 +73,7 @@ public class RowBatch {
     private int offsetInOneBatch = 0;
     private int rowCountInOneBatch = 0;
     private int readRowCount = 0;
+    private List<Row> rowBatch = new ArrayList<>();
     private final ArrowStreamReader arrowStreamReader;
     private final VectorSchemaRoot root;
     private List<FieldVector> fieldVectors;
@@ -115,6 +116,11 @@ public class RowBatch {
                     }
                     offsetInOneBatch = 0;
                     rowCountInOneBatch = root.getRowCount();
+                    // init the rowBatch
+                    for (int i = 0; i < rowCountInOneBatch; ++i) {
+                        rowBatch.add(new Row(fieldVectors.size()));
+                    }
+                    convertArrowToRowBatch();
                     return true;
                 }
             } catch (IOException e) {
@@ -128,98 +134,135 @@ public class RowBatch {
         return false;
     }
 
-    public List<Object> next() throws DorisException {
+    private void addValueToRow(int rowIndex, Object obj) {
+        if (rowIndex > rowCountInOneBatch) {
+            String errMsg = "Get row offset: " + rowIndex + " larger than row size: " +
+                    rowCountInOneBatch;
+            logger.error(errMsg);
+            throw new NoSuchElementException(errMsg);
+        }
+        rowBatch.get(rowIndex).put(obj);
+    }
+
+    public void convertArrowToRowBatch() throws DorisException {
         try {
-            if (!hasNext()) {
-                String errMsg = "Get row offset:" + offsetInOneBatch + " larger than row size: " + rowCountInOneBatch;
-                logger.error(errMsg);
-                throw new NoSuchElementException(errMsg);
-            }
-            Row row = new Row(fieldVectors.size());
-            for (int j = 0; j < fieldVectors.size(); j++) {
-                FieldVector curFieldVector = fieldVectors.get(j);
+            for (int col = 0; col < fieldVectors.size(); col++) {
+                FieldVector curFieldVector = fieldVectors.get(col);
                 Types.MinorType mt = curFieldVector.getMinorType();
-                if (curFieldVector.isNull(offsetInOneBatch)) {
-                    row.put(null);
-                    continue;
-                }
 
-                final String currentType = schema.get(j).getType();
+                final String currentType = schema.get(col).getType();
                 switch (currentType) {
                     case "NULL_TYPE":
-                        row.put(null);
+                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                            addValueToRow(rowIndex, null);
+                        }
                         break;
                     case "BOOLEAN":
                         Preconditions.checkArgument(mt.equals(Types.MinorType.BIT),
                                 typeMismatchMessage(currentType, mt));
                         BitVector bitVector = (BitVector) curFieldVector;
-                        int bit = bitVector.get(offsetInOneBatch);
-                        row.put(bit != 0);
+                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                            Object fieldValue = bitVector.isNull(rowIndex) ? null : bitVector.get(rowIndex) != 0;
+                            addValueToRow(rowIndex, fieldValue);
+                        }
                         break;
                     case "TINYINT":
                         Preconditions.checkArgument(mt.equals(Types.MinorType.TINYINT),
                                 typeMismatchMessage(currentType, mt));
                         TinyIntVector tinyIntVector = (TinyIntVector) curFieldVector;
-                        row.put(tinyIntVector.get(offsetInOneBatch));
+                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                            Object fieldValue = tinyIntVector.isNull(rowIndex) ? null : tinyIntVector.get(rowIndex);
+                            addValueToRow(rowIndex, fieldValue);
+                        }
                         break;
                     case "SMALLINT":
                         Preconditions.checkArgument(mt.equals(Types.MinorType.SMALLINT),
                                 typeMismatchMessage(currentType, mt));
                         SmallIntVector smallIntVector = (SmallIntVector) curFieldVector;
-                        row.put(smallIntVector.get(offsetInOneBatch));
+                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                            Object fieldValue = smallIntVector.isNull(rowIndex) ? null : smallIntVector.get(rowIndex);
+                            addValueToRow(rowIndex, fieldValue);
+                        }
                         break;
                     case "INT":
                         Preconditions.checkArgument(mt.equals(Types.MinorType.INT),
                                 typeMismatchMessage(currentType, mt));
                         IntVector intVector = (IntVector) curFieldVector;
-                        row.put(intVector.get(offsetInOneBatch));
+                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                            Object fieldValue = intVector.isNull(rowIndex) ? null : intVector.get(rowIndex);
+                            addValueToRow(rowIndex, fieldValue);
+                        }
                         break;
                     case "BIGINT":
                         Preconditions.checkArgument(mt.equals(Types.MinorType.BIGINT),
                                 typeMismatchMessage(currentType, mt));
                         BigIntVector bigIntVector = (BigIntVector) curFieldVector;
-                        row.put(bigIntVector.get(offsetInOneBatch));
+                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                            Object fieldValue = bigIntVector.isNull(rowIndex) ? null : bigIntVector.get(rowIndex);
+                            addValueToRow(rowIndex, fieldValue);
+                        }
                         break;
                     case "FLOAT":
                         Preconditions.checkArgument(mt.equals(Types.MinorType.FLOAT4),
                                 typeMismatchMessage(currentType, mt));
                         Float4Vector float4Vector = (Float4Vector) curFieldVector;
-                        row.put(float4Vector.get(offsetInOneBatch));
+                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                            Object fieldValue = float4Vector.isNull(rowIndex) ? null : float4Vector.get(rowIndex);
+                            addValueToRow(rowIndex, fieldValue);
+                        }
                         break;
                     case "TIME":
                     case "DOUBLE":
                         Preconditions.checkArgument(mt.equals(Types.MinorType.FLOAT8),
                                 typeMismatchMessage(currentType, mt));
                         Float8Vector float8Vector = (Float8Vector) curFieldVector;
-                        row.put(float8Vector.get(offsetInOneBatch));
+                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                            Object fieldValue = float8Vector.isNull(rowIndex) ? null : float8Vector.get(rowIndex);
+                            addValueToRow(rowIndex, fieldValue);
+                        }
                         break;
                     case "BINARY":
                         Preconditions.checkArgument(mt.equals(Types.MinorType.VARBINARY),
                                 typeMismatchMessage(currentType, mt));
                         VarBinaryVector varBinaryVector = (VarBinaryVector) curFieldVector;
-                        row.put(varBinaryVector.get(offsetInOneBatch));
+                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                            Object fieldValue = varBinaryVector.isNull(rowIndex) ? null : varBinaryVector.get(rowIndex);
+                            addValueToRow(rowIndex, fieldValue);
+                        }
                         break;
                     case "DECIMAL":
                         Preconditions.checkArgument(mt.equals(Types.MinorType.VARCHAR),
                                 typeMismatchMessage(currentType, mt));
                         VarCharVector varCharVectorForDecimal = (VarCharVector) curFieldVector;
-                        String decimalValue = new String(varCharVectorForDecimal.get(offsetInOneBatch));
-                        Decimal decimal = new Decimal();
-                        try {
-                            decimal.set(new scala.math.BigDecimal(new BigDecimal(decimalValue)));
-                        } catch (NumberFormatException e) {
-                            String errMsg = "Decimal response result '" + decimalValue + "' is illegal.";
-                            logger.error(errMsg, e);
-                            throw new DorisException(errMsg);
+                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                            if (varCharVectorForDecimal.isNull(rowIndex)) {
+                                addValueToRow(rowIndex, null);
+                                continue;
+                            }
+                            String decimalValue = new String(varCharVectorForDecimal.get(rowIndex));
+                            Decimal decimal = new Decimal();
+                            try {
+                                decimal.set(new scala.math.BigDecimal(new BigDecimal(decimalValue)));
+                            } catch (NumberFormatException e) {
+                                String errMsg = "Decimal response result '" + decimalValue + "' is illegal.";
+                                logger.error(errMsg, e);
+                                throw new DorisException(errMsg);
+                            }
+                            addValueToRow(rowIndex, decimal);
                         }
-                        row.put(decimal);
                         break;
                     case "DECIMALV2":
                         Preconditions.checkArgument(mt.equals(Types.MinorType.DECIMAL),
                                 typeMismatchMessage(currentType, mt));
                         DecimalVector decimalVector = (DecimalVector) curFieldVector;
-                        Decimal decimalV2 = Decimal.apply(decimalVector.getObject(offsetInOneBatch));
-                        row.put(decimalV2);
+                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                            if (decimalVector.isNull(rowIndex)) {
+                                addValueToRow(rowIndex, null);
+                                continue;
+                            }
+                            Decimal decimalV2 = Decimal.apply(decimalVector.getObject(rowIndex));
+                            addValueToRow(rowIndex, decimalV2);
+                        }
                         break;
                     case "DATE":
                     case "DATETIME":
@@ -229,23 +272,36 @@ public class RowBatch {
                         Preconditions.checkArgument(mt.equals(Types.MinorType.VARCHAR),
                                 typeMismatchMessage(currentType, mt));
                         VarCharVector varCharVector = (VarCharVector) curFieldVector;
-                        String value = new String(varCharVector.get(offsetInOneBatch));
-                        row.put(value);
+                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
+                            if (varCharVector.isNull(rowIndex)) {
+                                addValueToRow(rowIndex, null);
+                                continue;
+                            }
+                            String value = new String(varCharVector.get(rowIndex));
+                            addValueToRow(rowIndex, value);
+                        }
                         break;
                     default:
-                        String errMsg = "Unsupported type " + schema.get(j).getType();
+                        String errMsg = "Unsupported type " + schema.get(col).getType();
                         logger.error(errMsg);
                         throw new DorisException(errMsg);
                 }
             }
-            offsetInOneBatch++;
-            return row.getCols();
         } catch (Exception e) {
             close();
             throw e;
         }
     }
 
+    public List<Object> next() throws DorisException {
+        if (!hasNext()) {
+            String errMsg = "Get row offset:" + offsetInOneBatch + " larger than row size: " + rowCountInOneBatch;
+            logger.error(errMsg);
+            throw new NoSuchElementException(errMsg);
+        }
+        return rowBatch.get(offsetInOneBatch++).getCols();
+    }
+
     private String typeMismatchMessage(final String sparkType, final Types.MinorType arrowType) {
         final String messageTemplate = "Spark type is %1$s, but arrow type is %2$s.";
         return String.format(messageTemplate, sparkType, arrowType.name());


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org