You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@flink.apache.org by bl...@apache.org on 2019/09/24 17:40:20 UTC

[flink] branch master updated: [FLINK-14129][hive] HiveTableSource should implement ProjectableTableSource

This is an automated email from the ASF dual-hosted git repository.

bli pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 02e2603  [FLINK-14129][hive] HiveTableSource should implement ProjectableTableSource
02e2603 is described below

commit 02e26036e6d74e6c76ef1d793ac141c820715023
Author: Rui Li <li...@apache.org>
AuthorDate: Thu Sep 19 22:15:26 2019 +0800

    [FLINK-14129][hive] HiveTableSource should implement ProjectableTableSource
    
    Implement ProjectableTableSource for HiveTableSource.
    
    This closes #9721.
---
 .../connectors/hive/HiveTableInputFormat.java      | 57 ++++++++++++++--------
 .../flink/connectors/hive/HiveTableSource.java     | 47 ++++++++++++++----
 .../flink/connectors/hive/HiveTableSourceTest.java | 37 ++++++++++++--
 3 files changed, 109 insertions(+), 32 deletions(-)

diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/connectors/hive/HiveTableInputFormat.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/connectors/hive/HiveTableInputFormat.java
index 8a38fb3..f00288f 100644
--- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/connectors/hive/HiveTableInputFormat.java
+++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/connectors/hive/HiveTableInputFormat.java
@@ -52,6 +52,7 @@ import java.io.ObjectOutputStream;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Properties;
+import java.util.stream.IntStream;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.apache.hadoop.mapreduce.lib.input.FileInputFormat.INPUT_DIR;
@@ -73,9 +74,6 @@ public class HiveTableInputFormat extends HadoopInputFormatCommonBase<Row, HiveT
 	protected transient boolean fetched = false;
 	protected transient boolean hasNext;
 
-	// arity of each row, including partition columns
-	private int rowArity;
-
 	//Necessary info to init deserializer
 	private List<String> partitionColNames;
 	//For non-partition hive table, partitions only contains one partition which partitionValues is empty.
@@ -88,17 +86,25 @@ public class HiveTableInputFormat extends HadoopInputFormatCommonBase<Row, HiveT
 	private transient InputFormat mapredInputFormat;
 	private transient HiveTablePartition hiveTablePartition;
 
+	// indices of fields to be returned, with projection applied (if any)
+	// TODO: push projection into underlying input format that supports it
+	private int[] fields;
+	// Remember whether a row instance is reused. No need to set partition fields for reused rows
+	private transient boolean rowReused;
+
 	public HiveTableInputFormat(
 			JobConf jobConf,
 			CatalogTable catalogTable,
-			List<HiveTablePartition> partitions) {
+			List<HiveTablePartition> partitions,
+			int[] projectedFields) {
 		super(jobConf.getCredentials());
 		checkNotNull(catalogTable, "catalogTable can not be null.");
 		this.partitions = checkNotNull(partitions, "partitions can not be null.");
 
 		this.jobConf = new JobConf(jobConf);
 		this.partitionColNames = catalogTable.getPartitionKeys();
-		rowArity = catalogTable.getSchema().getFieldCount();
+		int rowArity = catalogTable.getSchema().getFieldCount();
+		fields = projectedFields != null ? projectedFields : IntStream.range(0, rowArity).toArray();
 	}
 
 	@Override
@@ -137,6 +143,7 @@ public class HiveTableInputFormat extends HadoopInputFormatCommonBase<Row, HiveT
 		} catch (Exception e) {
 			throw new FlinkHiveException("Error happens when deserialize from storage file.", e);
 		}
+		rowReused = false;
 	}
 
 	@Override
@@ -203,30 +210,40 @@ public class HiveTableInputFormat extends HadoopInputFormatCommonBase<Row, HiveT
 	}
 
 	@Override
-	public Row nextRecord(Row ignore) throws IOException {
+	public Row nextRecord(Row reuse) throws IOException {
 		if (reachedEnd()) {
 			return null;
 		}
-		Row row = new Row(rowArity);
 		try {
 			//Use HiveDeserializer to deserialize an object out of a Writable blob
 			Object hiveRowStruct = deserializer.deserialize(value);
-			int index = 0;
-			for (; index < structFields.size(); index++) {
-				StructField structField = structFields.get(index);
-				Object object = HiveInspectors.toFlinkObject(structField.getFieldObjectInspector(),
-						structObjectInspector.getStructFieldData(hiveRowStruct, structField));
-				row.setField(index, object);
-			}
-			for (String partition : partitionColNames){
-				row.setField(index++, hiveTablePartition.getPartitionSpec().get(partition));
+			for (int i = 0; i < fields.length; i++) {
+				// set non-partition columns
+				if (fields[i] < structFields.size()) {
+					StructField structField = structFields.get(fields[i]);
+					Object object = HiveInspectors.toFlinkObject(structField.getFieldObjectInspector(),
+							structObjectInspector.getStructFieldData(hiveRowStruct, structField));
+					reuse.setField(i, object);
+				}
 			}
-		} catch (Exception e){
+		} catch (Exception e) {
 			logger.error("Error happens when converting hive data type to flink data type.");
 			throw new FlinkHiveException(e);
 		}
+		if (!rowReused) {
+			// set partition columns
+			if (!partitionColNames.isEmpty()) {
+				for (int i = 0; i < fields.length; i++) {
+					if (fields[i] >= structFields.size()) {
+						String partition = partitionColNames.get(fields[i] - structFields.size());
+						reuse.setField(i, hiveTablePartition.getPartitionSpec().get(partition));
+					}
+				}
+			}
+			rowReused = true;
+		}
 		this.fetched = false;
-		return row;
+		return reuse;
 	}
 
 	// --------------------------------------------------------------------------------------------
@@ -236,9 +253,9 @@ public class HiveTableInputFormat extends HadoopInputFormatCommonBase<Row, HiveT
 	private void writeObject(ObjectOutputStream out) throws IOException {
 		super.write(out);
 		jobConf.write(out);
-		out.writeObject(rowArity);
 		out.writeObject(partitionColNames);
 		out.writeObject(partitions);
+		out.writeObject(fields);
 	}
 
 	@SuppressWarnings("unchecked")
@@ -253,8 +270,8 @@ public class HiveTableInputFormat extends HadoopInputFormatCommonBase<Row, HiveT
 		if (currentUserCreds != null) {
 			jobConf.getCredentials().addAll(currentUserCreds);
 		}
-		rowArity = (int) in.readObject();
 		partitionColNames = (List<String>) in.readObject();
 		partitions = (List<HiveTablePartition>) in.readObject();
+		fields = (int[]) in.readObject();
 	}
 }
diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/connectors/hive/HiveTableSource.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/connectors/hive/HiveTableSource.java
index 451bd93..cfa1e62 100644
--- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/connectors/hive/HiveTableSource.java
+++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/connectors/hive/HiveTableSource.java
@@ -29,6 +29,7 @@ import org.apache.flink.table.catalog.hive.client.HiveMetastoreClientWrapper;
 import org.apache.flink.table.catalog.hive.descriptors.HiveCatalogValidator;
 import org.apache.flink.table.sources.InputFormatTableSource;
 import org.apache.flink.table.sources.PartitionableTableSource;
+import org.apache.flink.table.sources.ProjectableTableSource;
 import org.apache.flink.table.sources.TableSource;
 import org.apache.flink.table.types.DataType;
 import org.apache.flink.table.types.logical.LogicalTypeRoot;
@@ -45,6 +46,7 @@ import org.slf4j.LoggerFactory;
 
 import java.sql.Date;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -52,7 +54,7 @@ import java.util.Map;
 /**
  * A TableSource implementation to read data from Hive tables.
  */
-public class HiveTableSource extends InputFormatTableSource<Row> implements PartitionableTableSource {
+public class HiveTableSource extends InputFormatTableSource<Row> implements PartitionableTableSource, ProjectableTableSource<Row> {
 
 	private static Logger logger = LoggerFactory.getLogger(HiveTableSource.class);
 
@@ -66,6 +68,7 @@ public class HiveTableSource extends InputFormatTableSource<Row> implements Part
 	private Map<Map<String, String>, HiveTablePartition> partitionSpec2HiveTablePartition = new HashMap<>();
 	private boolean initAllPartitions;
 	private boolean partitionPruned;
+	private int[] projectedFields;
 
 	public HiveTableSource(JobConf jobConf, ObjectPath tablePath, CatalogTable catalogTable) {
 		this.jobConf = Preconditions.checkNotNull(jobConf);
@@ -77,18 +80,23 @@ public class HiveTableSource extends InputFormatTableSource<Row> implements Part
 		partitionPruned = false;
 	}
 
+	// A constructor mainly used to create copies during optimizations like partition pruning and projection push down.
 	private HiveTableSource(JobConf jobConf, ObjectPath tablePath, CatalogTable catalogTable,
 							List<HiveTablePartition> allHivePartitions,
 							String hiveVersion,
-							List<Map<String, String>> partitionList) {
+							List<Map<String, String>> partitionList,
+							boolean initAllPartitions,
+							boolean partitionPruned,
+							int[] projectedFields) {
 		this.jobConf = Preconditions.checkNotNull(jobConf);
 		this.tablePath = Preconditions.checkNotNull(tablePath);
 		this.catalogTable = Preconditions.checkNotNull(catalogTable);
 		this.allHivePartitions = allHivePartitions;
 		this.hiveVersion = hiveVersion;
 		this.partitionList = partitionList;
-		this.initAllPartitions = true;
-		partitionPruned = true;
+		this.initAllPartitions = initAllPartitions;
+		this.partitionPruned = partitionPruned;
+		this.projectedFields = projectedFields;
 	}
 
 	@Override
@@ -96,7 +104,7 @@ public class HiveTableSource extends InputFormatTableSource<Row> implements Part
 		if (!initAllPartitions) {
 			initAllPartitions();
 		}
-		return new HiveTableInputFormat(jobConf, catalogTable, allHivePartitions);
+		return new HiveTableInputFormat(jobConf, catalogTable, allHivePartitions, projectedFields);
 	}
 
 	@Override
@@ -106,7 +114,17 @@ public class HiveTableSource extends InputFormatTableSource<Row> implements Part
 
 	@Override
 	public DataType getProducedDataType() {
-		return getTableSchema().toRowDataType();
+		TableSchema originSchema = getTableSchema();
+		if (projectedFields == null) {
+			return originSchema.toRowDataType();
+		}
+		String[] names = new String[projectedFields.length];
+		DataType[] types = new DataType[projectedFields.length];
+		for (int i = 0; i < projectedFields.length; i++) {
+			names[i] = originSchema.getFieldName(projectedFields[i]).get();
+			types[i] = originSchema.getFieldDataType(projectedFields[i]).get();
+		}
+		return TableSchema.builder().fields(names, types).build().toRowDataType();
 	}
 
 	@Override
@@ -140,7 +158,8 @@ public class HiveTableSource extends InputFormatTableSource<Row> implements Part
 																			"partition spec %s", partitionSpec));
 				remainingHivePartitions.add(hiveTablePartition);
 			}
-			return new HiveTableSource(jobConf, tablePath, catalogTable, remainingHivePartitions, hiveVersion, partitionList);
+			return new HiveTableSource(jobConf, tablePath, catalogTable, remainingHivePartitions,
+					hiveVersion, partitionList, true, true, projectedFields);
 		}
 	}
 
@@ -223,7 +242,17 @@ public class HiveTableSource extends InputFormatTableSource<Row> implements Part
 
 	@Override
 	public String explainSource() {
-		return super.explainSource() + String.format(" TablePath: %s, PartitionPruned: %s, PartitionNums: %d",
-													tablePath.getFullName(), partitionPruned, null == allHivePartitions ? 0 : allHivePartitions.size());
+		String explain = String.format(" TablePath: %s, PartitionPruned: %s, PartitionNums: %d",
+				tablePath.getFullName(), partitionPruned, null == allHivePartitions ? 0 : allHivePartitions.size());
+		if (projectedFields != null) {
+			explain += ", ProjectedFields: " + Arrays.toString(projectedFields);
+		}
+		return super.explainSource() + explain;
+	}
+
+	@Override
+	public TableSource<Row> projectFields(int[] fields) {
+		return new HiveTableSource(jobConf, tablePath, catalogTable, allHivePartitions, hiveVersion,
+				partitionList, initAllPartitions, partitionPruned, fields);
 	}
 }
diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveTableSourceTest.java b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveTableSourceTest.java
index 3920d0f..878cd68 100644
--- a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveTableSourceTest.java
+++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveTableSourceTest.java
@@ -181,9 +181,9 @@ public class HiveTableSourceTest {
 		String abstractSyntaxTree = explain[1];
 		String optimizedLogicalPlan = explain[2];
 		String physicalExecutionPlan = explain[3];
-		assertTrue(abstractSyntaxTree.contains("HiveTableSource(year, value, pt) TablePath: source_db.test_table_pt_1, PartitionPruned: false, PartitionNums: 2]"));
-		assertTrue(optimizedLogicalPlan.contains("HiveTableSource(year, value, pt) TablePath: source_db.test_table_pt_1, PartitionPruned: true, PartitionNums: 1]"));
-		assertTrue(physicalExecutionPlan.contains("HiveTableSource(year, value, pt) TablePath: source_db.test_table_pt_1, PartitionPruned: true, PartitionNums: 1]"));
+		assertTrue(abstractSyntaxTree.contains("HiveTableSource(year, value, pt) TablePath: source_db.test_table_pt_1, PartitionPruned: false, PartitionNums: 2"));
+		assertTrue(optimizedLogicalPlan.contains("HiveTableSource(year, value, pt) TablePath: source_db.test_table_pt_1, PartitionPruned: true, PartitionNums: 1"));
+		assertTrue(physicalExecutionPlan.contains("HiveTableSource(year, value, pt) TablePath: source_db.test_table_pt_1, PartitionPruned: true, PartitionNums: 1"));
 		// second check execute results
 		List<Row> rows = JavaConverters.seqAsJavaListConverter(TableUtil.collect((TableImpl) src)).asJava();
 		assertEquals(2, rows.size());
@@ -191,4 +191,35 @@ public class HiveTableSourceTest {
 		assertArrayEquals(new String[]{"2014,3,0", "2014,4,0"}, rowStrings);
 	}
 
+	@Test
+	public void testProjectionPushDown() throws Exception {
+		hiveShell.execute("create table src(x int,y string) partitioned by (p1 bigint, p2 string)");
+		final String catalogName = "hive";
+		try {
+			hiveShell.insertInto("default", "src")
+					.addRow(1, "a", 2013, "2013")
+					.addRow(2, "b", 2013, "2013")
+					.addRow(3, "c", 2014, "2014")
+					.commit();
+			TableEnvironment tableEnv = HiveTestUtils.createTableEnv();
+			tableEnv.registerCatalog(catalogName, hiveCatalog);
+			Table table = tableEnv.sqlQuery("select p1, count(y) from hive.`default`.src group by p1");
+			String[] explain = tableEnv.explain(table).split("==.*==\n");
+			assertEquals(4, explain.length);
+			String logicalPlan = explain[2];
+			String physicalPlan = explain[3];
+			String expectedExplain =
+					"HiveTableSource(x, y, p1, p2) TablePath: default.src, PartitionPruned: false, PartitionNums: 2, ProjectedFields: [2, 1]";
+			assertTrue(logicalPlan.contains(expectedExplain));
+			assertTrue(physicalPlan.contains(expectedExplain));
+
+			List<Row> rows = JavaConverters.seqAsJavaListConverter(TableUtil.collect((TableImpl) table)).asJava();
+			assertEquals(2, rows.size());
+			Object[] rowStrings = rows.stream().map(Row::toString).sorted().toArray();
+			assertArrayEquals(new String[]{"2013,2", "2014,1"}, rowStrings);
+		} finally {
+			hiveShell.execute("drop table src");
+		}
+	}
+
 }