You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ta...@apache.org on 2017/08/02 05:08:45 UTC
[22/59] beam git commit: rename package org.apache.beam.dsls.sql to
org.apache.beam.sdk.extensions.sql
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b5482d/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/rel/BeamValuesRelTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/rel/BeamValuesRelTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/rel/BeamValuesRelTest.java
new file mode 100644
index 0000000..ace1a3e
--- /dev/null
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/rel/BeamValuesRelTest.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.sdk.extensions.sql.rel;
+
+import java.sql.Types;
+import org.apache.beam.sdk.extensions.sql.BeamSqlCli;
+import org.apache.beam.sdk.extensions.sql.BeamSqlEnv;
+import org.apache.beam.sdk.extensions.sql.TestUtils;
+import org.apache.beam.sdk.extensions.sql.mock.MockedBoundedTable;
+import org.apache.beam.sdk.extensions.sql.schema.BeamSqlRow;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.values.PCollection;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+
+/**
+ * Test for {@code BeamValuesRel}.
+ */
+public class BeamValuesRelTest {
+ static BeamSqlEnv sqlEnv = new BeamSqlEnv();
+
+ @Rule
+ public final TestPipeline pipeline = TestPipeline.create();
+
+ @BeforeClass
+ public static void prepare() {
+ sqlEnv.registerTable("string_table",
+ MockedBoundedTable.of(
+ Types.VARCHAR, "name",
+ Types.VARCHAR, "description"
+ )
+ );
+ sqlEnv.registerTable("int_table",
+ MockedBoundedTable.of(
+ Types.INTEGER, "c0",
+ Types.INTEGER, "c1"
+ )
+ );
+ }
+
+ @Test
+ public void testValues() throws Exception {
+ String sql = "insert into string_table(name, description) values "
+ + "('hello', 'world'), ('james', 'bond')";
+ PCollection<BeamSqlRow> rows = BeamSqlCli.compilePipeline(sql, pipeline, sqlEnv);
+ PAssert.that(rows).containsInAnyOrder(
+ TestUtils.RowsBuilder.of(
+ Types.VARCHAR, "name",
+ Types.VARCHAR, "description"
+ ).addRows(
+ "hello", "world",
+ "james", "bond"
+ ).getRows()
+ );
+ pipeline.run();
+ }
+
+ @Test
+ public void testValues_castInt() throws Exception {
+ String sql = "insert into int_table (c0, c1) values(cast(1 as int), cast(2 as int))";
+ PCollection<BeamSqlRow> rows = BeamSqlCli.compilePipeline(sql, pipeline, sqlEnv);
+ PAssert.that(rows).containsInAnyOrder(
+ TestUtils.RowsBuilder.of(
+ Types.INTEGER, "c0",
+ Types.INTEGER, "c1"
+ ).addRows(
+ 1, 2
+ ).getRows()
+ );
+ pipeline.run();
+ }
+
+ @Test
+ public void testValues_onlySelect() throws Exception {
+ String sql = "select 1, '1'";
+ PCollection<BeamSqlRow> rows = BeamSqlCli.compilePipeline(sql, pipeline, sqlEnv);
+ PAssert.that(rows).containsInAnyOrder(
+ TestUtils.RowsBuilder.of(
+ Types.INTEGER, "EXPR$0",
+ Types.CHAR, "EXPR$1"
+ ).addRows(
+ 1, "1"
+ ).getRows()
+ );
+ pipeline.run();
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b5482d/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/rel/CheckSize.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/rel/CheckSize.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/rel/CheckSize.java
new file mode 100644
index 0000000..f369076
--- /dev/null
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/rel/CheckSize.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.sdk.extensions.sql.rel;
+
+import org.apache.beam.sdk.extensions.sql.schema.BeamSqlRow;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.junit.Assert;
+
+/**
+ * Utility class to check size of BeamSQLRow iterable.
+ */
+public class CheckSize implements SerializableFunction<Iterable<BeamSqlRow>, Void> {
+ private int size;
+ public CheckSize(int size) {
+ this.size = size;
+ }
+ @Override public Void apply(Iterable<BeamSqlRow> input) {
+ int count = 0;
+ for (BeamSqlRow row : input) {
+ count++;
+ }
+ Assert.assertEquals(size, count);
+ return null;
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b5482d/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/schema/BeamSqlRowCoderTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/schema/BeamSqlRowCoderTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/schema/BeamSqlRowCoderTest.java
new file mode 100644
index 0000000..553420b
--- /dev/null
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/schema/BeamSqlRowCoderTest.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.sdk.extensions.sql.schema;
+
+import java.math.BigDecimal;
+import java.util.Date;
+import java.util.GregorianCalendar;
+import org.apache.beam.sdk.extensions.sql.utils.CalciteUtils;
+import org.apache.beam.sdk.testing.CoderProperties;
+import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rel.type.RelDataTypeSystem;
+import org.apache.calcite.rel.type.RelProtoDataType;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.junit.Test;
+
+/**
+ * Tests for BeamSqlRowCoder.
+ */
+public class BeamSqlRowCoderTest {
+
+ @Test
+ public void encodeAndDecode() throws Exception {
+ final RelProtoDataType protoRowType = new RelProtoDataType() {
+ @Override
+ public RelDataType apply(RelDataTypeFactory a0) {
+ return a0.builder()
+ .add("col_tinyint", SqlTypeName.TINYINT)
+ .add("col_smallint", SqlTypeName.SMALLINT)
+ .add("col_integer", SqlTypeName.INTEGER)
+ .add("col_bigint", SqlTypeName.BIGINT)
+ .add("col_float", SqlTypeName.FLOAT)
+ .add("col_double", SqlTypeName.DOUBLE)
+ .add("col_decimal", SqlTypeName.DECIMAL)
+ .add("col_string_varchar", SqlTypeName.VARCHAR)
+ .add("col_time", SqlTypeName.TIME)
+ .add("col_timestamp", SqlTypeName.TIMESTAMP)
+ .add("col_boolean", SqlTypeName.BOOLEAN)
+ .build();
+ }
+ };
+
+ BeamSqlRowType beamSQLRowType = CalciteUtils.toBeamRowType(
+ protoRowType.apply(new JavaTypeFactoryImpl(
+ RelDataTypeSystem.DEFAULT)));
+ BeamSqlRow row = new BeamSqlRow(beamSQLRowType);
+ row.addField("col_tinyint", Byte.valueOf("1"));
+ row.addField("col_smallint", Short.valueOf("1"));
+ row.addField("col_integer", 1);
+ row.addField("col_bigint", 1L);
+ row.addField("col_float", 1.1F);
+ row.addField("col_double", 1.1);
+ row.addField("col_decimal", BigDecimal.ZERO);
+ row.addField("col_string_varchar", "hello");
+ GregorianCalendar calendar = new GregorianCalendar();
+ calendar.setTime(new Date());
+ row.addField("col_time", calendar);
+ row.addField("col_timestamp", new Date());
+ row.addField("col_boolean", true);
+
+
+ BeamSqlRowCoder coder = new BeamSqlRowCoder(beamSQLRowType);
+ CoderProperties.coderDecodeEncodeEqual(coder, row);
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b5482d/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/schema/kafka/BeamKafkaCSVTableTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/schema/kafka/BeamKafkaCSVTableTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/schema/kafka/BeamKafkaCSVTableTest.java
new file mode 100644
index 0000000..4eccc44
--- /dev/null
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/schema/kafka/BeamKafkaCSVTableTest.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.sdk.extensions.sql.schema.kafka;
+
+import java.io.Serializable;
+import org.apache.beam.sdk.extensions.sql.planner.BeamQueryPlanner;
+import org.apache.beam.sdk.extensions.sql.schema.BeamSqlRow;
+import org.apache.beam.sdk.extensions.sql.schema.BeamSqlRowType;
+import org.apache.beam.sdk.extensions.sql.utils.CalciteUtils;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rel.type.RelProtoDataType;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.commons.csv.CSVFormat;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+
+/**
+ * Test for BeamKafkaCSVTable.
+ */
+public class BeamKafkaCSVTableTest {
+ @Rule
+ public TestPipeline pipeline = TestPipeline.create();
+ public static BeamSqlRow row1 = new BeamSqlRow(genRowType());
+ public static BeamSqlRow row2 = new BeamSqlRow(genRowType());
+
+ @BeforeClass
+ public static void setUp() {
+ row1.addField(0, 1L);
+ row1.addField(1, 1);
+ row1.addField(2, 1.0);
+
+ row2.addField(0, 2L);
+ row2.addField(1, 2);
+ row2.addField(2, 2.0);
+ }
+
+ @Test public void testCsvRecorderDecoder() throws Exception {
+ PCollection<BeamSqlRow> result = pipeline
+ .apply(
+ Create.of("1,\"1\",1.0", "2,2,2.0")
+ )
+ .apply(ParDo.of(new String2KvBytes()))
+ .apply(
+ new BeamKafkaCSVTable.CsvRecorderDecoder(genRowType(), CSVFormat.DEFAULT)
+ );
+
+ PAssert.that(result).containsInAnyOrder(row1, row2);
+
+ pipeline.run();
+ }
+
+ @Test public void testCsvRecorderEncoder() throws Exception {
+ PCollection<BeamSqlRow> result = pipeline
+ .apply(
+ Create.of(row1, row2)
+ )
+ .apply(
+ new BeamKafkaCSVTable.CsvRecorderEncoder(genRowType(), CSVFormat.DEFAULT)
+ ).apply(
+ new BeamKafkaCSVTable.CsvRecorderDecoder(genRowType(), CSVFormat.DEFAULT)
+ );
+
+ PAssert.that(result).containsInAnyOrder(row1, row2);
+
+ pipeline.run();
+ }
+
+ private static BeamSqlRowType genRowType() {
+ return CalciteUtils.toBeamRowType(new RelProtoDataType() {
+
+ @Override public RelDataType apply(RelDataTypeFactory a0) {
+ return a0.builder().add("order_id", SqlTypeName.BIGINT)
+ .add("site_id", SqlTypeName.INTEGER)
+ .add("price", SqlTypeName.DOUBLE).build();
+ }
+ }.apply(BeamQueryPlanner.TYPE_FACTORY));
+ }
+
+ private static class String2KvBytes extends DoFn<String, KV<byte[], byte[]>>
+ implements Serializable {
+ @ProcessElement
+ public void processElement(ProcessContext ctx) {
+ ctx.output(KV.of(new byte[] {}, ctx.element().getBytes()));
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b5482d/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/schema/text/BeamTextCSVTableTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/schema/text/BeamTextCSVTableTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/schema/text/BeamTextCSVTableTest.java
new file mode 100644
index 0000000..9dc599f
--- /dev/null
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/schema/text/BeamTextCSVTableTest.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.sdk.extensions.sql.schema.text;
+
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.PrintStream;
+import java.nio.file.FileVisitResult;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.SimpleFileVisitor;
+import java.nio.file.attribute.BasicFileAttributes;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.beam.sdk.extensions.sql.planner.BeamQueryPlanner;
+import org.apache.beam.sdk.extensions.sql.schema.BeamSqlRow;
+import org.apache.beam.sdk.extensions.sql.schema.BeamSqlRowType;
+import org.apache.beam.sdk.extensions.sql.utils.CalciteUtils;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rel.type.RelProtoDataType;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.commons.csv.CSVFormat;
+import org.apache.commons.csv.CSVPrinter;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+
+/**
+ * Tests for {@code BeamTextCSVTable}.
+ */
+public class BeamTextCSVTableTest {
+
+ @Rule public TestPipeline pipeline = TestPipeline.create();
+ @Rule public TestPipeline pipeline2 = TestPipeline.create();
+
+ /**
+ * testData.
+ *
+ * <p>
+ * The types of the csv fields are:
+ * integer,bigint,float,double,string
+ * </p>
+ */
+ private static Object[] data1 = new Object[] { 1, 1L, 1.1F, 1.1, "james" };
+ private static Object[] data2 = new Object[] { 2, 2L, 2.2F, 2.2, "bond" };
+
+ private static List<Object[]> testData = Arrays.asList(data1, data2);
+ private static List<BeamSqlRow> testDataRows = new ArrayList<BeamSqlRow>() {{
+ for (Object[] data : testData) {
+ add(buildRow(data));
+ }
+ }};
+
+ private static Path tempFolder;
+ private static File readerSourceFile;
+ private static File writerTargetFile;
+
+ @Test public void testBuildIOReader() {
+ PCollection<BeamSqlRow> rows = new BeamTextCSVTable(buildBeamSqlRowType(),
+ readerSourceFile.getAbsolutePath()).buildIOReader(pipeline);
+ PAssert.that(rows).containsInAnyOrder(testDataRows);
+ pipeline.run();
+ }
+
+ @Test public void testBuildIOWriter() {
+ new BeamTextCSVTable(buildBeamSqlRowType(),
+ readerSourceFile.getAbsolutePath()).buildIOReader(pipeline)
+ .apply(new BeamTextCSVTable(buildBeamSqlRowType(), writerTargetFile.getAbsolutePath())
+ .buildIOWriter());
+ pipeline.run();
+
+ PCollection<BeamSqlRow> rows = new BeamTextCSVTable(buildBeamSqlRowType(),
+ writerTargetFile.getAbsolutePath()).buildIOReader(pipeline2);
+
+ // confirm the two reads match
+ PAssert.that(rows).containsInAnyOrder(testDataRows);
+ pipeline2.run();
+ }
+
+ @BeforeClass public static void setUp() throws IOException {
+ tempFolder = Files.createTempDirectory("BeamTextTableTest");
+ readerSourceFile = writeToFile(testData, "readerSourceFile.txt");
+ writerTargetFile = writeToFile(testData, "writerTargetFile.txt");
+ }
+
+ @AfterClass public static void teardownClass() throws IOException {
+ Files.walkFileTree(tempFolder, new SimpleFileVisitor<Path>() {
+
+ @Override public FileVisitResult visitFile(Path file, BasicFileAttributes attrs)
+ throws IOException {
+ Files.delete(file);
+ return FileVisitResult.CONTINUE;
+ }
+
+ @Override public FileVisitResult postVisitDirectory(Path dir, IOException exc)
+ throws IOException {
+ Files.delete(dir);
+ return FileVisitResult.CONTINUE;
+ }
+ });
+ }
+
+ private static File writeToFile(List<Object[]> rows, String filename) throws IOException {
+ File file = tempFolder.resolve(filename).toFile();
+ OutputStream output = new FileOutputStream(file);
+ writeToStreamAndClose(rows, output);
+ return file;
+ }
+
+ /**
+ * Helper that writes the given lines (adding a newline in between) to a stream, then closes the
+ * stream.
+ */
+ private static void writeToStreamAndClose(List<Object[]> rows, OutputStream outputStream) {
+ try (PrintStream writer = new PrintStream(outputStream)) {
+ CSVPrinter printer = CSVFormat.DEFAULT.print(writer);
+ for (Object[] row : rows) {
+ for (Object field : row) {
+ printer.print(field);
+ }
+ printer.println();
+ }
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+
+ private RelProtoDataType buildRowType() {
+ return new RelProtoDataType() {
+
+ @Override public RelDataType apply(RelDataTypeFactory a0) {
+ return a0.builder().add("id", SqlTypeName.INTEGER).add("order_id", SqlTypeName.BIGINT)
+ .add("price", SqlTypeName.FLOAT).add("amount", SqlTypeName.DOUBLE)
+ .add("user_name", SqlTypeName.VARCHAR).build();
+ }
+ };
+ }
+
+ private static RelDataType buildRelDataType() {
+ return BeamQueryPlanner.TYPE_FACTORY.builder().add("id", SqlTypeName.INTEGER)
+ .add("order_id", SqlTypeName.BIGINT).add("price", SqlTypeName.FLOAT)
+ .add("amount", SqlTypeName.DOUBLE).add("user_name", SqlTypeName.VARCHAR).build();
+ }
+
+ private static BeamSqlRowType buildBeamSqlRowType() {
+ return CalciteUtils.toBeamRowType(buildRelDataType());
+ }
+
+ private static BeamSqlRow buildRow(Object[] data) {
+ return new BeamSqlRow(buildBeamSqlRowType(), Arrays.asList(data));
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b5482d/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/schema/transform/BeamAggregationTransformTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/schema/transform/BeamAggregationTransformTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/schema/transform/BeamAggregationTransformTest.java
new file mode 100644
index 0000000..571c8ef
--- /dev/null
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/schema/transform/BeamAggregationTransformTest.java
@@ -0,0 +1,453 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.schema.transform;
+
+import java.text.ParseException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.beam.sdk.coders.IterableCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.extensions.sql.planner.BeamQueryPlanner;
+import org.apache.beam.sdk.extensions.sql.schema.BeamSqlRow;
+import org.apache.beam.sdk.extensions.sql.schema.BeamSqlRowCoder;
+import org.apache.beam.sdk.extensions.sql.schema.BeamSqlRowType;
+import org.apache.beam.sdk.extensions.sql.transform.BeamAggregationTransforms;
+import org.apache.beam.sdk.extensions.sql.utils.CalciteUtils;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.WithKeys;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.type.RelDataTypeFactory.FieldInfoBuilder;
+import org.apache.calcite.rel.type.RelDataTypeSystem;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.fun.SqlAvgAggFunction;
+import org.apache.calcite.sql.fun.SqlCountAggFunction;
+import org.apache.calcite.sql.fun.SqlMinMaxAggFunction;
+import org.apache.calcite.sql.fun.SqlSumAggFunction;
+import org.apache.calcite.sql.type.BasicSqlType;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.junit.Rule;
+import org.junit.Test;
+
+/**
+ * Unit tests for {@link BeamAggregationTransforms}.
+ *
+ */
+public class BeamAggregationTransformTest extends BeamTransformBaseTest{
+
+ @Rule
+ public TestPipeline p = TestPipeline.create();
+
+ private List<AggregateCall> aggCalls;
+
+ private BeamSqlRowType keyType;
+ private BeamSqlRowType aggPartType;
+ private BeamSqlRowType outputType;
+
+ private BeamSqlRowCoder inRecordCoder;
+ private BeamSqlRowCoder keyCoder;
+ private BeamSqlRowCoder aggCoder;
+ private BeamSqlRowCoder outRecordCoder;
+
  /**
   * This step equals to below query.
   * <pre>
   * SELECT `f_int`
   * , COUNT(*) AS `size`
   * , SUM(`f_long`) AS `sum1`, AVG(`f_long`) AS `avg1`
   * , MAX(`f_long`) AS `max1`, MIN(`f_long`) AS `min1`
   * , SUM(`f_short`) AS `sum2`, AVG(`f_short`) AS `avg2`
   * , MAX(`f_short`) AS `max2`, MIN(`f_short`) AS `min2`
   * , SUM(`f_byte`) AS `sum3`, AVG(`f_byte`) AS `avg3`
   * , MAX(`f_byte`) AS `max3`, MIN(`f_byte`) AS `min3`
   * , SUM(`f_float`) AS `sum4`, AVG(`f_float`) AS `avg4`
   * , MAX(`f_float`) AS `max4`, MIN(`f_float`) AS `min4`
   * , SUM(`f_double`) AS `sum5`, AVG(`f_double`) AS `avg5`
   * , MAX(`f_double`) AS `max5`, MIN(`f_double`) AS `min5`
   * , MAX(`f_timestamp`) AS `max7`, MIN(`f_timestamp`) AS `min7`
   * ,SUM(`f_int2`) AS `sum8`, AVG(`f_int2`) AS `avg8`
   * , MAX(`f_int2`) AS `max8`, MIN(`f_int2`) AS `min8`
   * FROM TABLE_NAME
   * GROUP BY `f_int`
   * </pre>
   *
   * <p>Asserts the intermediate output of each aggregation sub-transform
   * (group-by-key extraction, combine, and record merge) in turn.
   *
   * @throws ParseException if preparing the expected rows fails
   */
  @Test
  public void testCountPerElementBasic() throws ParseException {
    setupEnvironment();

    // inputRows/inputRowType come from the BeamTransformBaseTest base class.
    PCollection<BeamSqlRow> input = p.apply(Create.of(inputRows));

    //1. extract fields in group-by key part
    // Group on field 0 (`f_int`); -1 appears to mean "no windowing field" —
    // TODO(review): confirm against AggregationGroupByKeyFn.
    PCollection<KV<BeamSqlRow, BeamSqlRow>> exGroupByStream = input.apply("exGroupBy",
        WithKeys
            .of(new BeamAggregationTransforms.AggregationGroupByKeyFn(-1, ImmutableBitSet.of(0))))
        .setCoder(KvCoder.<BeamSqlRow, BeamSqlRow>of(keyCoder, inRecordCoder));

    //2. apply a GroupByKey.
    PCollection<KV<BeamSqlRow, Iterable<BeamSqlRow>>> groupedStream = exGroupByStream
        .apply("groupBy", GroupByKey.<BeamSqlRow, BeamSqlRow>create())
        .setCoder(KvCoder.<BeamSqlRow, Iterable<BeamSqlRow>>of(keyCoder,
            IterableCoder.<BeamSqlRow>of(inRecordCoder)));

    //3. run aggregation functions
    // aggCalls was populated by setupEnvironment() -> prepareAggregationCalls().
    PCollection<KV<BeamSqlRow, BeamSqlRow>> aggregatedStream = groupedStream.apply("aggregation",
        Combine.<BeamSqlRow, BeamSqlRow, BeamSqlRow>groupedValues(
            new BeamAggregationTransforms.AggregationAdaptor(aggCalls, inputRowType)))
        .setCoder(KvCoder.<BeamSqlRow, BeamSqlRow>of(keyCoder, aggCoder));

    //4. flat KV to a single record
    PCollection<BeamSqlRow> mergedStream = aggregatedStream.apply("mergeRecord",
        ParDo.of(new BeamAggregationTransforms.MergeAggregationRecord(outputType, aggCalls, -1)));
    mergedStream.setCoder(outRecordCoder);

    //assert function BeamAggregationTransform.AggregationGroupByKeyFn
    PAssert.that(exGroupByStream).containsInAnyOrder(prepareResultOfAggregationGroupByKeyFn());

    //assert BeamAggregationTransform.AggregationCombineFn
    PAssert.that(aggregatedStream).containsInAnyOrder(prepareResultOfAggregationCombineFn());

    //assert BeamAggregationTransform.MergeAggregationRecord
    PAssert.that(mergedStream).containsInAnyOrder(prepareResultOfMergeAggregationRecord());

    p.run();
  }
+
  /** Prepares the aggregation calls and the row types/coders used by the test. */
  private void setupEnvironment() {
    prepareAggregationCalls();
    prepareTypeAndCoder();
  }
+
+ /**
+ * create list of all {@link AggregateCall}.
+ */
+ @SuppressWarnings("deprecation")
+ private void prepareAggregationCalls() {
+ //aggregations for all data type
+ aggCalls = new ArrayList<>();
+ aggCalls.add(
+ new AggregateCall(new SqlCountAggFunction(), false,
+ Arrays.<Integer>asList(),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.BIGINT),
+ "count")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlSumAggFunction(
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.BIGINT)), false,
+ Arrays.<Integer>asList(1),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.BIGINT),
+ "sum1")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlAvgAggFunction(SqlKind.AVG), false,
+ Arrays.<Integer>asList(1),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.BIGINT),
+ "avg1")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MAX), false,
+ Arrays.<Integer>asList(1),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.BIGINT),
+ "max1")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MIN), false,
+ Arrays.<Integer>asList(1),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.BIGINT),
+ "min1")
+ );
+
+ aggCalls.add(
+ new AggregateCall(new SqlSumAggFunction(
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.SMALLINT)), false,
+ Arrays.<Integer>asList(2),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.SMALLINT),
+ "sum2")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlAvgAggFunction(SqlKind.AVG), false,
+ Arrays.<Integer>asList(2),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.SMALLINT),
+ "avg2")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MAX), false,
+ Arrays.<Integer>asList(2),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.SMALLINT),
+ "max2")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MIN), false,
+ Arrays.<Integer>asList(2),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.SMALLINT),
+ "min2")
+ );
+
+ aggCalls.add(
+ new AggregateCall(
+ new SqlSumAggFunction(new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.TINYINT)),
+ false,
+ Arrays.<Integer>asList(3),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.TINYINT),
+ "sum3")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlAvgAggFunction(SqlKind.AVG), false,
+ Arrays.<Integer>asList(3),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.TINYINT),
+ "avg3")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MAX), false,
+ Arrays.<Integer>asList(3),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.TINYINT),
+ "max3")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MIN), false,
+ Arrays.<Integer>asList(3),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.TINYINT),
+ "min3")
+ );
+
+ aggCalls.add(
+ new AggregateCall(
+ new SqlSumAggFunction(new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.FLOAT)),
+ false,
+ Arrays.<Integer>asList(4),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.FLOAT),
+ "sum4")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlAvgAggFunction(SqlKind.AVG), false,
+ Arrays.<Integer>asList(4),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.FLOAT),
+ "avg4")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MAX), false,
+ Arrays.<Integer>asList(4),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.FLOAT),
+ "max4")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MIN), false,
+ Arrays.<Integer>asList(4),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.FLOAT),
+ "min4")
+ );
+
+ aggCalls.add(
+ new AggregateCall(
+ new SqlSumAggFunction(new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DOUBLE)),
+ false,
+ Arrays.<Integer>asList(5),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DOUBLE),
+ "sum5")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlAvgAggFunction(SqlKind.AVG), false,
+ Arrays.<Integer>asList(5),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DOUBLE),
+ "avg5")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MAX), false,
+ Arrays.<Integer>asList(5),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DOUBLE),
+ "max5")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MIN), false,
+ Arrays.<Integer>asList(5),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.DOUBLE),
+ "min5")
+ );
+
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MAX), false,
+ Arrays.<Integer>asList(7),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.TIMESTAMP),
+ "max7")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MIN), false,
+ Arrays.<Integer>asList(7),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.TIMESTAMP),
+ "min7")
+ );
+
+ aggCalls.add(
+ new AggregateCall(
+ new SqlSumAggFunction(new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.INTEGER)),
+ false,
+ Arrays.<Integer>asList(8),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.INTEGER),
+ "sum8")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlAvgAggFunction(SqlKind.AVG), false,
+ Arrays.<Integer>asList(8),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.INTEGER),
+ "avg8")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MAX), false,
+ Arrays.<Integer>asList(8),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.INTEGER),
+ "max8")
+ );
+ aggCalls.add(
+ new AggregateCall(new SqlMinMaxAggFunction(SqlKind.MIN), false,
+ Arrays.<Integer>asList(8),
+ new BasicSqlType(RelDataTypeSystem.DEFAULT, SqlTypeName.INTEGER),
+ "min8")
+ );
+ }
+
+ /**
+ * Prepares the row types and {@link BeamSqlRowCoder}s used in each aggregation step:
+ * the raw input rows, the GROUP BY key, the intermediate (partial) aggregation row,
+ * and the final merged output row.
+ */
+ private void prepareTypeAndCoder() {
+ // Coder for the raw input rows (inputRowType holds the test input schema).
+ inRecordCoder = new BeamSqlRowCoder(inputRowType);
+
+ // The GROUP BY key is the single column f_int.
+ keyType = initTypeOfSqlRow(Arrays.asList(KV.of("f_int", SqlTypeName.INTEGER)));
+ keyCoder = new BeamSqlRowCoder(keyType);
+
+ // Intermediate row with one field per registered aggregation call: COUNT, then
+ // SUM/AVG/MAX/MIN per numeric column, and only MAX/MIN for the TIMESTAMP
+ // column (index 7). Field order must match the aggCalls registration order.
+ aggPartType = initTypeOfSqlRow(
+ Arrays.asList(KV.of("count", SqlTypeName.BIGINT),
+
+ KV.of("sum1", SqlTypeName.BIGINT), KV.of("avg1", SqlTypeName.BIGINT),
+ KV.of("max1", SqlTypeName.BIGINT), KV.of("min1", SqlTypeName.BIGINT),
+
+ KV.of("sum2", SqlTypeName.SMALLINT), KV.of("avg2", SqlTypeName.SMALLINT),
+ KV.of("max2", SqlTypeName.SMALLINT), KV.of("min2", SqlTypeName.SMALLINT),
+
+ KV.of("sum3", SqlTypeName.TINYINT), KV.of("avg3", SqlTypeName.TINYINT),
+ KV.of("max3", SqlTypeName.TINYINT), KV.of("min3", SqlTypeName.TINYINT),
+
+ KV.of("sum4", SqlTypeName.FLOAT), KV.of("avg4", SqlTypeName.FLOAT),
+ KV.of("max4", SqlTypeName.FLOAT), KV.of("min4", SqlTypeName.FLOAT),
+
+ KV.of("sum5", SqlTypeName.DOUBLE), KV.of("avg5", SqlTypeName.DOUBLE),
+ KV.of("max5", SqlTypeName.DOUBLE), KV.of("min5", SqlTypeName.DOUBLE),
+
+ KV.of("max7", SqlTypeName.TIMESTAMP), KV.of("min7", SqlTypeName.TIMESTAMP),
+
+ KV.of("sum8", SqlTypeName.INTEGER), KV.of("avg8", SqlTypeName.INTEGER),
+ KV.of("max8", SqlTypeName.INTEGER), KV.of("min8", SqlTypeName.INTEGER)
+ ));
+ aggCoder = new BeamSqlRowCoder(aggPartType);
+
+ // Final output row: the group key followed by every aggregate field.
+ outputType = prepareFinalRowType();
+ outRecordCoder = new BeamSqlRowCoder(outputType);
+ }
+
+ /**
+ * Expected results after {@link BeamAggregationTransforms.AggregationGroupByKeyFn}:
+ * each input row becomes a KV whose key is a one-column row holding the value of the
+ * row's first column (f_int), and whose value is the unchanged input row.
+ */
+ private List<KV<BeamSqlRow, BeamSqlRow>> prepareResultOfAggregationGroupByKeyFn() {
+ return Arrays.asList(
+ KV.of(new BeamSqlRow(keyType, Arrays.<Object>asList(inputRows.get(0).getInteger(0))),
+ inputRows.get(0)),
+ KV.of(new BeamSqlRow(keyType, Arrays.<Object>asList(inputRows.get(1).getInteger(0))),
+ inputRows.get(1)),
+ KV.of(new BeamSqlRow(keyType, Arrays.<Object>asList(inputRows.get(2).getInteger(0))),
+ inputRows.get(2)),
+ KV.of(new BeamSqlRow(keyType, Arrays.<Object>asList(inputRows.get(3).getInteger(0))),
+ inputRows.get(3)));
+ }
+
+ /**
+ * Expected results after {@link BeamAggregationTransforms.AggregationCombineFn}.
+ *
+ * <p>All four input rows share the same key (f_int == 1), so combining yields a
+ * single KV: the key row paired with one partial-aggregation row containing COUNT
+ * and SUM/AVG/MAX/MIN per aggregated column, in aggPartType field order.
+ */
+ private List<KV<BeamSqlRow, BeamSqlRow>> prepareResultOfAggregationCombineFn()
+ throws ParseException {
+ return Arrays.asList(
+ KV.of(new BeamSqlRow(keyType, Arrays.<Object>asList(inputRows.get(0).getInteger(0))),
+ new BeamSqlRow(aggPartType, Arrays.<Object>asList(
+ 4L, // count of the four input rows
+ 10000L, 2500L, 4000L, 1000L, // sum/avg/max/min over f_long (1000..4000)
+ (short) 10, (short) 2, (short) 4, (short) 1, // over f_short (1..4)
+ (byte) 10, (byte) 2, (byte) 4, (byte) 1, // over f_byte (1..4)
+ 10.0F, 2.5F, 4.0F, 1.0F, // over f_float (1.0..4.0)
+ 10.0, 2.5, 4.0, 1.0, // over f_double (1.0..4.0)
+ format.parse("2017-01-01 02:04:03"), format.parse("2017-01-01 01:01:03"), // max/min f_timestamp
+ 10, 2, 4, 1 // over f_int2 (1..4)
+ )))
+ );
+ }
+
+ /**
+ * Builds the row type of the final output row: the group key (f_int) followed by
+ * every aggregate field, in the same order as the partial-aggregation row type.
+ */
+ private BeamSqlRowType prepareFinalRowType() {
+ FieldInfoBuilder builder = BeamQueryPlanner.TYPE_FACTORY.builder();
+ List<KV<String, SqlTypeName>> columnMetadata =
+ Arrays.asList(KV.of("f_int", SqlTypeName.INTEGER), KV.of("count", SqlTypeName.BIGINT),
+
+ KV.of("sum1", SqlTypeName.BIGINT), KV.of("avg1", SqlTypeName.BIGINT),
+ KV.of("max1", SqlTypeName.BIGINT), KV.of("min1", SqlTypeName.BIGINT),
+
+ KV.of("sum2", SqlTypeName.SMALLINT), KV.of("avg2", SqlTypeName.SMALLINT),
+ KV.of("max2", SqlTypeName.SMALLINT), KV.of("min2", SqlTypeName.SMALLINT),
+
+ KV.of("sum3", SqlTypeName.TINYINT), KV.of("avg3", SqlTypeName.TINYINT),
+ KV.of("max3", SqlTypeName.TINYINT), KV.of("min3", SqlTypeName.TINYINT),
+
+ KV.of("sum4", SqlTypeName.FLOAT), KV.of("avg4", SqlTypeName.FLOAT),
+ KV.of("max4", SqlTypeName.FLOAT), KV.of("min4", SqlTypeName.FLOAT),
+
+ KV.of("sum5", SqlTypeName.DOUBLE), KV.of("avg5", SqlTypeName.DOUBLE),
+ KV.of("max5", SqlTypeName.DOUBLE), KV.of("min5", SqlTypeName.DOUBLE),
+
+ KV.of("max7", SqlTypeName.TIMESTAMP), KV.of("min7", SqlTypeName.TIMESTAMP),
+
+ KV.of("sum8", SqlTypeName.INTEGER), KV.of("avg8", SqlTypeName.INTEGER),
+ KV.of("max8", SqlTypeName.INTEGER), KV.of("min8", SqlTypeName.INTEGER)
+ );
+ // Build a Calcite row type from the (name, type) pairs, then convert it to the
+ // Beam SQL row type used by the transforms under test.
+ for (KV<String, SqlTypeName> cm : columnMetadata) {
+ builder.add(cm.getKey(), cm.getValue());
+ }
+ return CalciteUtils.toBeamRowType(builder.build());
+ }
+
+ /**
+ * Expected result after {@link BeamAggregationTransforms.MergeAggregationRecord}:
+ * a single row holding the group key (f_int == 1) followed by all aggregate values,
+ * i.e. the key row and the combined row flattened into outputType field order.
+ */
+ private BeamSqlRow prepareResultOfMergeAggregationRecord() throws ParseException {
+ return new BeamSqlRow(outputType, Arrays.<Object>asList(
+ 1, 4L, // f_int key, count
+ 10000L, 2500L, 4000L, 1000L, // sum/avg/max/min over f_long
+ (short) 10, (short) 2, (short) 4, (short) 1, // over f_short
+ (byte) 10, (byte) 2, (byte) 4, (byte) 1, // over f_byte
+ 10.0F, 2.5F, 4.0F, 1.0F, // over f_float
+ 10.0, 2.5, 4.0, 1.0, // over f_double
+ format.parse("2017-01-01 02:04:03"), format.parse("2017-01-01 01:01:03"), // max/min f_timestamp
+ 10, 2, 4, 1 // over f_int2
+ ));
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/c1b5482d/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/schema/transform/BeamTransformBaseTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/schema/transform/BeamTransformBaseTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/schema/transform/BeamTransformBaseTest.java
new file mode 100644
index 0000000..b2aa6c4
--- /dev/null
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/schema/transform/BeamTransformBaseTest.java
@@ -0,0 +1,97 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.schema.transform;
+
+import java.text.DateFormat;
+import java.text.ParseException;
+import java.text.SimpleDateFormat;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.beam.sdk.extensions.sql.planner.BeamQueryPlanner;
+import org.apache.beam.sdk.extensions.sql.schema.BeamSqlRow;
+import org.apache.beam.sdk.extensions.sql.schema.BeamSqlRowType;
+import org.apache.beam.sdk.extensions.sql.utils.CalciteUtils;
+import org.apache.beam.sdk.values.KV;
+import org.apache.calcite.rel.type.RelDataTypeFactory.FieldInfoBuilder;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.junit.BeforeClass;
+
+/**
+ * shared methods to test PTransforms which execute Beam SQL steps.
+ *
+ */
+public class BeamTransformBaseTest {
+ public static DateFormat format = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
+
+ public static BeamSqlRowType inputRowType;
+ public static List<BeamSqlRow> inputRows;
+
+ @BeforeClass
+ public static void prepareInput() throws NumberFormatException, ParseException{
+ List<KV<String, SqlTypeName>> columnMetadata = Arrays.asList(
+ KV.of("f_int", SqlTypeName.INTEGER), KV.of("f_long", SqlTypeName.BIGINT),
+ KV.of("f_short", SqlTypeName.SMALLINT), KV.of("f_byte", SqlTypeName.TINYINT),
+ KV.of("f_float", SqlTypeName.FLOAT), KV.of("f_double", SqlTypeName.DOUBLE),
+ KV.of("f_string", SqlTypeName.VARCHAR), KV.of("f_timestamp", SqlTypeName.TIMESTAMP),
+ KV.of("f_int2", SqlTypeName.INTEGER)
+ );
+ inputRowType = initTypeOfSqlRow(columnMetadata);
+ inputRows = Arrays.asList(
+ initBeamSqlRow(columnMetadata,
+ Arrays.<Object>asList(1, 1000L, Short.valueOf("1"), Byte.valueOf("1"), 1.0F, 1.0,
+ "string_row1", format.parse("2017-01-01 01:01:03"), 1)),
+ initBeamSqlRow(columnMetadata,
+ Arrays.<Object>asList(1, 2000L, Short.valueOf("2"), Byte.valueOf("2"), 2.0F, 2.0,
+ "string_row2", format.parse("2017-01-01 01:02:03"), 2)),
+ initBeamSqlRow(columnMetadata,
+ Arrays.<Object>asList(1, 3000L, Short.valueOf("3"), Byte.valueOf("3"), 3.0F, 3.0,
+ "string_row3", format.parse("2017-01-01 01:03:03"), 3)),
+ initBeamSqlRow(columnMetadata, Arrays.<Object>asList(1, 4000L, Short.valueOf("4"),
+ Byte.valueOf("4"), 4.0F, 4.0, "string_row4", format.parse("2017-01-01 02:04:03"), 4)));
+ }
+
+ /**
+ * create a {@code BeamSqlRowType} for given column metadata.
+ */
+ public static BeamSqlRowType initTypeOfSqlRow(List<KV<String, SqlTypeName>> columnMetadata){
+ FieldInfoBuilder builder = BeamQueryPlanner.TYPE_FACTORY.builder();
+ for (KV<String, SqlTypeName> cm : columnMetadata) {
+ builder.add(cm.getKey(), cm.getValue());
+ }
+ return CalciteUtils.toBeamRowType(builder.build());
+ }
+
+ /**
+ * Create an empty row with given column metadata.
+ */
+ public static BeamSqlRow initBeamSqlRow(List<KV<String, SqlTypeName>> columnMetadata) {
+ return initBeamSqlRow(columnMetadata, Arrays.asList());
+ }
+
+ /**
+ * Create a row with given column metadata, and values for each column.
+ *
+ */
+ public static BeamSqlRow initBeamSqlRow(List<KV<String, SqlTypeName>> columnMetadata,
+ List<Object> rowValues){
+ BeamSqlRowType rowType = initTypeOfSqlRow(columnMetadata);
+
+ return new BeamSqlRow(rowType, rowValues);
+ }
+
+}