You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@orc.apache.org by do...@apache.org on 2021/11/11 19:37:39 UTC

[orc] branch main updated: ORC-1047: Handle quoted field names during string schema parsing (#959)

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/orc.git


The following commit(s) were added to refs/heads/main by this push:
     new 93882ea  ORC-1047: Handle quoted field names during string schema parsing (#959)
93882ea is described below

commit 93882eabe6fc45be0ad84c74192713dd2ff4447a
Author: noirello <no...@gmail.com>
AuthorDate: Thu Nov 11 20:37:35 2021 +0100

    ORC-1047: Handle quoted field names during string schema parsing (#959)
    
    ### What changes were proposed in this pull request?
    
    Improve the parsing of schema strings in `Type::buildTypeFromString` so that it handles quoted field names and applies stricter validation.
    
    ### Why are the changes needed?
    
    The current implementation cannot handle quoted field names, and it accepts schema strings that the Java implementation would reject (e.g. `struct<bigint>`, `map(boolean,float)`). It also cannot parse a schema whose root type is `timestamp with local time zone`.
    
    ### How was this patch tested?
    
    Ran the existing test suites locally with the newly added tests for quoted field names and invalid schemas.
---
 c++/src/TypeImpl.cc    | 273 +++++++++++++++++++++++++++++++++----------------
 c++/src/TypeImpl.hh    |  12 ++-
 c++/test/TestType.cc   |  60 +++++++++++
 c++/test/TestWriter.cc |   4 +-
 4 files changed, 259 insertions(+), 90 deletions(-)

diff --git a/c++/src/TypeImpl.cc b/c++/src/TypeImpl.cc
index 4d5a5a9..d65b084 100644
--- a/c++/src/TypeImpl.cc
+++ b/c++/src/TypeImpl.cc
@@ -183,6 +183,15 @@ namespace orc {
     return this;
   }
 
+  bool isUnquotedFieldName(std::string fieldName) {
+    for (auto &ch : fieldName) {
+        if (!isalnum(ch) && ch != '_') {
+          return false;
+        }
+    }
+    return true;
+  }
+
   std::string TypeImpl::toString() const {
     switch (static_cast<int64_t>(kind)) {
     case BOOLEAN:
@@ -218,7 +227,19 @@ namespace orc {
         if (i != 0) {
           result += ",";
         }
-        result += fieldNames[i];
+        if (isUnquotedFieldName(fieldNames[i])) {
+          result += fieldNames[i];
+        } else {
+          std::string name(fieldNames[i]);
+          size_t pos = 0;
+          while ((pos = name.find("`", pos)) != std::string::npos) {
+            name.replace(pos, 1, "``");
+            pos += 2;
+          }
+          result += "`";
+          result += name;
+          result += "`";
+        }
         result += ":";
         result += subTypes[i]->toString();
       }
@@ -554,12 +575,13 @@ namespace orc {
   }
 
   ORC_UNIQUE_PTR<Type> Type::buildTypeFromString(const std::string& input) {
-    std::vector<std::pair<std::string, ORC_UNIQUE_PTR<Type> > > res =
-      TypeImpl::parseType(input, 0, input.size());
-    if (res.size() != 1) {
+    size_t size = input.size();
+    std::pair<ORC_UNIQUE_PTR<Type>, size_t> res =
+      TypeImpl::parseType(input, 0, size);
+    if (res.second != size) {
       throw std::logic_error("Invalid type string.");
     }
-    return std::move(res[0].second);
+    return std::move(res.first);
   }
 
   std::unique_ptr<Type> TypeImpl::parseArrayType(const std::string &input,
@@ -567,45 +589,107 @@ namespace orc {
                                                  size_t end) {
     TypeImpl* arrayType = new TypeImpl(LIST);
     std::unique_ptr<Type> return_value = std::unique_ptr<Type>(arrayType);
-    std::vector<std::pair<std::string, ORC_UNIQUE_PTR<Type> > > v =
-      TypeImpl::parseType(input, start, end);
-    if (v.size() != 1) {
-      throw std::logic_error("Array type must contain exactly one sub type.");
+    if (input[start] != '<') {
+      throw std::logic_error("Missing < after array.");
+    }
+    std::pair<ORC_UNIQUE_PTR<Type>, size_t> res =
+      TypeImpl::parseType(input, start + 1, end);
+    if (res.second != end) {
+      throw std::logic_error(
+        "Array type must contain exactly one sub type.");
     }
-    arrayType->addChildType(std::move(v[0].second));
+    arrayType->addChildType(std::move(res.first));
     return return_value;
   }
 
   std::unique_ptr<Type> TypeImpl::parseMapType(const std::string &input,
                                                size_t start,
                                                size_t end) {
-    TypeImpl * mapType = new TypeImpl(MAP);
+    TypeImpl* mapType = new TypeImpl(MAP);
     std::unique_ptr<Type> return_value = std::unique_ptr<Type>(mapType);
-    std::vector<std::pair<std::string, ORC_UNIQUE_PTR<Type> > > v =
-      TypeImpl::parseType(input, start, end);
-    if (v.size() != 2) {
+    if (input[start] != '<') {
+      throw std::logic_error("Missing < after map.");
+    }
+    std::pair<ORC_UNIQUE_PTR<Type>, size_t> key =
+      TypeImpl::parseType(input, start + 1, end);
+    if (input[key.second] != ',') {
+      throw std::logic_error("Missing comma after key.");
+    }
+    std::pair<ORC_UNIQUE_PTR<Type>, size_t> val =
+      TypeImpl::parseType(input, key.second + 1, end);
+    if (val.second != end) {
       throw std::logic_error(
         "Map type must contain exactly two sub types.");
     }
-    mapType->addChildType(std::move(v[0].second));
-    mapType->addChildType(std::move(v[1].second));
+    mapType->addChildType(std::move(key.first));
+    mapType->addChildType(std::move(val.first));
     return return_value;
   }
 
+  std::pair<std::string, size_t> TypeImpl::parseName(const std::string &input,
+                                                     const size_t start,
+                                                     const size_t end) {
+    size_t pos = start;
+    if (input[pos] == '`') {
+      bool closed = false;
+      std::ostringstream oss;
+      while (pos < end) {
+        char ch = input[++pos];
+        if (ch == '`') {
+          if (pos < end && input[pos+1] == '`') {
+            ++pos;
+            oss.put('`');
+          } else {
+            closed = true;
+            break;
+          }
+        } else {
+          oss.put(ch);
+        }
+      }
+      if (!closed) {
+        throw std::logic_error("Invalid field name. Unmatched quote");
+      }
+      if (oss.tellp() == std::streamoff(0)) {
+        throw std::logic_error("Empty quoted field name.");
+      }
+      return std::make_pair(oss.str(), pos + 1);
+    } else {
+      while (pos < end && (isalnum(input[pos]) || input[pos] == '_')) {
+        ++pos;
+      }
+      if (pos == start) {
+        throw std::logic_error("Missing field name.");
+      }
+      return std::make_pair(input.substr(start, pos - start), pos);
+    }
+  }
+
   std::unique_ptr<Type> TypeImpl::parseStructType(const std::string &input,
                                                   size_t start,
                                                   size_t end) {
     TypeImpl* structType = new TypeImpl(STRUCT);
     std::unique_ptr<Type> return_value = std::unique_ptr<Type>(structType);
-    std::vector<std::pair<std::string, ORC_UNIQUE_PTR<Type>> > v =
-      TypeImpl::parseType(input, start, end);
-    if (v.size() == 0) {
-      throw std::logic_error(
-        "Struct type must contain at least one sub type.");
+    size_t pos = start + 1;
+    if (input[start] != '<') {
+      throw std::logic_error("Missing < after struct.");
     }
-    for (size_t i = 0; i < v.size(); ++i) {
-      structType->addStructField(v[i].first, std::move(v[i].second));
+    while (pos < end) {
+      std::pair<std::string, size_t> nameRes = parseName(input, pos, end);
+      pos = nameRes.second;
+      if (input[pos] != ':') {
+        throw std::logic_error("Invalid struct type. No field name set.");
+      }
+      std::pair<ORC_UNIQUE_PTR<Type>, size_t> typeRes =
+        TypeImpl::parseType(input, ++pos, end);
+      structType->addStructField(nameRes.first, std::move(typeRes.first));
+      pos = typeRes.second;
+      if (pos != end && input[pos] != ',') {
+        throw std::logic_error("Missing comma after field.");
+      }
+      ++pos;
     }
+
     return return_value;
   }
 
@@ -614,56 +698,89 @@ namespace orc {
                                                  size_t end) {
     TypeImpl* unionType = new TypeImpl(UNION);
     std::unique_ptr<Type> return_value = std::unique_ptr<Type>(unionType);
-    std::vector<std::pair<std::string, ORC_UNIQUE_PTR<Type> > > v =
-      TypeImpl::parseType(input, start, end);
-    if (v.size() == 0) {
-      throw std::logic_error("Union type must contain at least one sub type.");
+    size_t pos = start + 1;
+    if (input[start] != '<') {
+      throw std::logic_error("Missing < after uniontype.");
     }
-    for (size_t i = 0; i < v.size(); ++i) {
-      unionType->addChildType(std::move(v[i].second));
+    while (pos < end) {
+      std::pair<ORC_UNIQUE_PTR<Type>, size_t> res =
+        TypeImpl::parseType(input, pos, end);
+      unionType->addChildType(std::move(res.first));
+      pos = res.second;
+      if (pos != end && input[pos] != ',') {
+        throw std::logic_error("Missing comma after union sub type.");
+      }
+      ++pos;
     }
+
     return return_value;
   }
 
   std::unique_ptr<Type> TypeImpl::parseDecimalType(const std::string &input,
                                                    size_t start,
                                                    size_t end) {
-    size_t sep = input.find(',', start);
+    if (input[start] != '(') {
+      throw std::logic_error("Missing ( after decimal.");
+    }
+    size_t pos = start + 1;
+    size_t sep = input.find(',', pos);
     if (sep + 1 >= end || sep == std::string::npos) {
       throw std::logic_error("Decimal type must specify precision and scale.");
     }
     uint64_t precision =
-      static_cast<uint64_t>(atoi(input.substr(start, sep - start).c_str()));
+      static_cast<uint64_t>(atoi(input.substr(pos, sep - pos).c_str()));
     uint64_t scale =
       static_cast<uint64_t>(atoi(input.substr(sep + 1, end - sep - 1).c_str()));
     return std::unique_ptr<Type>(new TypeImpl(DECIMAL, precision, scale));
   }
 
+  void validatePrimitiveType(std::string category,
+                             const std::string &input,
+                             const size_t pos) {
+    if (input[pos] == '<' || input[pos] == '(') {
+      std::ostringstream oss;
+      oss << "Invalid " << input[pos] << " after "
+        << category << " type.";
+      throw std::logic_error(oss.str());
+    }
+  }
+
   std::unique_ptr<Type> TypeImpl::parseCategory(std::string category,
                                                 const std::string &input,
                                                 size_t start,
                                                 size_t end) {
     if (category == "boolean") {
+      validatePrimitiveType(category, input, start);
       return std::unique_ptr<Type>(new TypeImpl(BOOLEAN));
     } else if (category == "tinyint") {
+      validatePrimitiveType(category, input, start);
       return std::unique_ptr<Type>(new TypeImpl(BYTE));
     } else if (category == "smallint") {
+      validatePrimitiveType(category, input, start);
       return std::unique_ptr<Type>(new TypeImpl(SHORT));
     } else if (category == "int") {
+      validatePrimitiveType(category, input, start);
       return std::unique_ptr<Type>(new TypeImpl(INT));
     } else if (category == "bigint") {
+      validatePrimitiveType(category, input, start);
       return std::unique_ptr<Type>(new TypeImpl(LONG));
     } else if (category == "float") {
+      validatePrimitiveType(category, input, start);
       return std::unique_ptr<Type>(new TypeImpl(FLOAT));
     } else if (category == "double") {
+      validatePrimitiveType(category, input, start);
       return std::unique_ptr<Type>(new TypeImpl(DOUBLE));
     } else if (category == "string") {
+      validatePrimitiveType(category, input, start);
       return std::unique_ptr<Type>(new TypeImpl(STRING));
     } else if (category == "binary") {
+      validatePrimitiveType(category, input, start);
       return std::unique_ptr<Type>(new TypeImpl(BINARY));
     } else if (category == "timestamp") {
+      validatePrimitiveType(category, input, start);
       return std::unique_ptr<Type>(new TypeImpl(TIMESTAMP));
     } else if (category == "timestamp with local time zone") {
+      validatePrimitiveType(category, input, start);
       return std::unique_ptr<Type>(new TypeImpl(TIMESTAMP_INSTANT));
     } else if (category == "array") {
       return parseArrayType(input, start, end);
@@ -676,81 +793,63 @@ namespace orc {
     } else if (category == "decimal") {
       return parseDecimalType(input, start, end);
     } else if (category == "date") {
+      validatePrimitiveType(category, input, start);
       return std::unique_ptr<Type>(new TypeImpl(DATE));
     } else if (category == "varchar") {
+      if (input[start] != '(') {
+        throw std::logic_error("Missing ( after varchar.");
+      }
       uint64_t maxLength = static_cast<uint64_t>(
-        atoi(input.substr(start, end - start).c_str()));
+        atoi(input.substr(start + 1, end - start + 1).c_str()));
       return std::unique_ptr<Type>(new TypeImpl(VARCHAR, maxLength));
     } else if (category == "char") {
+      if (input[start] != '(') {
+        throw std::logic_error("Missing ( after char.");
+      }
       uint64_t maxLength = static_cast<uint64_t>(
-        atoi(input.substr(start, end - start).c_str()));
+        atoi(input.substr(start + 1, end - start + 1).c_str()));
       return std::unique_ptr<Type>(new TypeImpl(CHAR, maxLength));
     } else {
       throw std::logic_error("Unknown type " + category);
     }
   }
 
-  std::vector<std::pair<std::string, ORC_UNIQUE_PTR<Type> > > TypeImpl::parseType(
-                                                       const std::string &input,
-                                                       size_t start,
-                                                       size_t end) {
-    std::vector<std::pair<std::string, ORC_UNIQUE_PTR<Type> > > res;
+  std::pair<ORC_UNIQUE_PTR<Type>, size_t> TypeImpl::parseType(const std::string &input, size_t start, size_t end) {
     size_t pos = start;
-
-    while (pos < end) {
-      size_t endPos = pos;
-      while (endPos < end && (isalnum(input[endPos]) || input[endPos] == '_')) {
-        ++endPos;
-      }
-
-      std::string fieldName;
-      if (input[endPos] == ':') {
-        fieldName = input.substr(pos, endPos - pos);
-        pos = ++endPos;
-        while (endPos < end && (isalpha(input[endPos]) || input[endPos] == ' ')) {
-          ++endPos;
+    while (pos < end && (isalpha(input[pos]) || input[pos] == ' ')) {
+      ++pos;
+    }
+    size_t endPos = pos;
+    size_t nextPos = pos + 1;
+    if (input[pos] == '<') {
+      int count = 1;
+      while (nextPos < end) {
+        if (input[nextPos] == '<') {
+          ++count;
+        } else if (input[nextPos] == '>') {
+          --count;
         }
-      }
-
-      size_t nextPos = endPos + 1;
-      if (input[endPos] == '<') {
-        int count = 1;
-        while (nextPos < end) {
-          if (input[nextPos] == '<') {
-            ++count;
-          } else if (input[nextPos] == '>') {
-            --count;
-          }
-          if (count == 0) {
-            break;
-          }
-          ++nextPos;
-        }
-        if (nextPos == end) {
-          throw std::logic_error("Invalid type string. Cannot find closing >");
-        }
-      } else if (input[endPos] == '(') {
-        while (nextPos < end && input[nextPos] != ')') {
-          ++nextPos;
+        if (count == 0) {
+          break;
         }
-        if (nextPos == end) {
-          throw std::logic_error("Invalid type string. Cannot find closing )");
-        }
-      } else if (input[endPos] != ',' && endPos != end) {
-        throw std::logic_error("Unrecognized character.");
+        ++nextPos;
       }
-
-      std::string category = input.substr(pos, endPos - pos);
-      res.push_back(std::make_pair(fieldName, parseCategory(category, input, endPos + 1, nextPos)));
-
-      if (nextPos < end && (input[nextPos] == ')' || input[nextPos] == '>')) {
-        pos = nextPos + 2;
-      } else {
-        pos = nextPos;
+      if (nextPos == end) {
+        throw std::logic_error("Invalid type string. Cannot find closing >");
+      }
+      endPos = nextPos + 1;
+    } else if (input[pos] == '(') {
+      while (nextPos < end && input[nextPos] != ')') {
+        ++nextPos;
+      }
+      if (nextPos == end) {
+        throw std::logic_error("Invalid type string. Cannot find closing )");
       }
+      endPos = nextPos + 1;
     }
 
-    return res;
+    std::string category = input.substr(start, pos - start);
+    return std::make_pair(parseCategory(category, input, pos, nextPos), endPos);
   }
 
 }
diff --git a/c++/src/TypeImpl.hh b/c++/src/TypeImpl.hh
index 18a3e71..88c4737 100644
--- a/c++/src/TypeImpl.hh
+++ b/c++/src/TypeImpl.hh
@@ -109,7 +109,7 @@ namespace orc {
      */
     void addChildType(std::unique_ptr<Type> childType);
 
-    static std::vector<std::pair<std::string, std::unique_ptr<Type> > > parseType(
+    static std::pair<ORC_UNIQUE_PTR<Type>, size_t> parseType(
       const std::string &input,
       size_t start,
       size_t end);
@@ -148,6 +148,16 @@ namespace orc {
                                               size_t end);
 
     /**
+     * Parse field name from string
+     * @param input the input string of a field name
+     * @param start start position of the input string
+     * @param end end position of the input string
+     */
+    static std::pair<std::string, size_t> parseName(const std::string &input,
+                                                    const size_t start,
+                                                    const size_t end);
+
+    /**
      * Parse struct type from string
      * @param input the input string of a struct type
      * @param start start position of the input string
diff --git a/c++/test/TestType.cc b/c++/test/TestType.cc
index 3d6f2d1..1473462 100644
--- a/c++/test/TestType.cc
+++ b/c++/test/TestType.cc
@@ -277,6 +277,18 @@ namespace orc {
     EXPECT_EQ(13, cutType->getSubtype(1)->getMaximumColumnId());
   }
 
+  void expectLogicErrorDuringParse(std::string typeStr, const char* errMsg) {
+    try {
+      ORC_UNIQUE_PTR<Type> type = Type::buildTypeFromString(typeStr);
+      FAIL() << "'" << typeStr << "'"
+          << " should throw std::logic_error for invalid schema";
+    } catch (std::logic_error& e) {
+      EXPECT_EQ(e.what(), std::string(errMsg));
+    } catch (...) {
+      FAIL() << "Should only throw std::logic_error for invalid schema";
+    }
+  }
+
   TEST(TestType, buildTypeFromString) {
     std::string typeStr = "struct<a:int,b:string,c:decimal(10,2),d:varchar(5)>";
     ORC_UNIQUE_PTR<Type> type = Type::buildTypeFromString(typeStr);
@@ -303,6 +315,54 @@ namespace orc {
       "struct<a:bigint,b:struct<a:binary,b:timestamp>,c:map<double,tinyint>>";
     type = Type::buildTypeFromString(typeStr);
     EXPECT_EQ(typeStr, type->toString());
+
+    typeStr = "timestamp with local time zone";
+    type = Type::buildTypeFromString(typeStr);
+    EXPECT_EQ(typeStr, type->toString());
+
+    expectLogicErrorDuringParse("foobar",
+        "Unknown type foobar");
+    expectLogicErrorDuringParse("struct<col0:int>other",
+        "Invalid type string.");
+    expectLogicErrorDuringParse("array<>",
+        "Unknown type ");
+    expectLogicErrorDuringParse("array<int,string>",
+        "Array type must contain exactly one sub type.");
+    expectLogicErrorDuringParse("map<int,string,double>",
+        "Map type must contain exactly two sub types.");
+    expectLogicErrorDuringParse("int<>","Invalid < after int type.");
+    expectLogicErrorDuringParse("array(int)", "Missing < after array.");
+    expectLogicErrorDuringParse("struct<struct<bigint>>",
+        "Invalid struct type. No field name set.");
+    expectLogicErrorDuringParse("struct<a:bigint;b:string>",
+        "Missing comma after field.");
+  }
+
+  TEST(TestType, quotedFieldNames) {
+    ORC_UNIQUE_PTR<Type> type = createStructType();
+    type->addStructField("foo bar", createPrimitiveType(INT));
+    type->addStructField("`some`thing`", createPrimitiveType(INT));
+    type->addStructField("1234567890_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ",
+        createPrimitiveType(INT));
+    type->addStructField("'!@#$%^&*()-=_+", createPrimitiveType(INT));
+    EXPECT_EQ("struct<`foo bar`"
+        ":int,```some``thing```:int,"
+        "1234567890_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ:int,"
+        "`'!@#$%^&*()-=_+`:int>", type->toString());
+
+    std::string typeStr =
+      "struct<`foo bar`:int,```quotes```:double,`abc``def````ghi`:float>";
+    type = Type::buildTypeFromString(typeStr);
+    EXPECT_EQ(3, type->getSubtypeCount());
+    EXPECT_EQ("foo bar", type->getFieldName(0));
+    EXPECT_EQ("`quotes`", type->getFieldName(1));
+    EXPECT_EQ("abc`def``ghi", type->getFieldName(2));
+    EXPECT_EQ(typeStr, type->toString());
+
+    expectLogicErrorDuringParse("struct<``:int>",
+        "Empty quoted field name.");
+    expectLogicErrorDuringParse("struct<`col0:int>",
+        "Invalid field name. Unmatched quote");
   }
 
   void testCorruptHelper(const proto::Type& type,
diff --git a/c++/test/TestWriter.cc b/c++/test/TestWriter.cc
index 506887e..c2e50c1 100644
--- a/c++/test/TestWriter.cc
+++ b/c++/test/TestWriter.cc
@@ -1572,7 +1572,7 @@ namespace orc {
     MemoryOutputStream memStream(DEFAULT_MEM_STREAM_SIZE);
     MemoryPool * pool = getDefaultPool();
     std::unique_ptr<Type> type(Type::buildTypeFromString(
-      "struct<struct<bigint>>"));
+      "struct<col0:struct<col1:bigint>>"));
 
     uint64_t stripeSize = 1024;
     uint64_t compressionBlockSize = 1024;
@@ -1662,7 +1662,7 @@ namespace orc {
     MemoryOutputStream memStream(DEFAULT_MEM_STREAM_SIZE);
     MemoryPool * pool = getDefaultPool();
     std::unique_ptr<Type> type(Type::buildTypeFromString(
-      "struct<struct<bigint>>"));
+      "struct<col0:struct<col1:bigint>>"));
 
     uint64_t stripeSize = 1024;
     uint64_t compressionBlockSize = 1024;