Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2021/06/15 09:56:21 UTC

[GitHub] [arrow] pitrou commented on a change in pull request #10520: ARROW-12709: [C++] Add binary_join_element_wise

pitrou commented on a change in pull request #10520:
URL: https://github.com/apache/arrow/pull/10520#discussion_r651627477
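
For context, here is a minimal usage sketch of the kernel this PR adds. It assumes the kernel ends up registered as `binary_join_element_wise` (per the PR title) and reuses `JoinOptions` with the `null_handling`/`null_replacement` fields referenced in the diff below; the exact registration name and option defaults in the merged version may differ.

```cpp
#include <iostream>

#include <arrow/api.h>
#include <arrow/compute/api.h>

int main() {
  namespace cp = arrow::compute;

  // The last argument is the separator, for consistency with binary_join.
  arrow::Datum a(arrow::MakeScalar("foo"));
  arrow::Datum b(arrow::MakeNullScalar(arrow::utf8()));
  arrow::Datum c(arrow::MakeScalar("bar"));
  arrow::Datum sep(arrow::MakeScalar("-"));

  // Replace null inputs instead of emitting a null result.
  cp::JoinOptions options;
  options.null_handling = cp::JoinOptions::REPLACE;
  options.null_replacement = "?";

  auto maybe_joined =
      cp::CallFunction("binary_join_element_wise", {a, b, c, sep}, &options);
  if (maybe_joined.ok()) {
    // Should print something like: foo-?-bar
    std::cout << maybe_joined->scalar()->ToString() << std::endl;
  }
  return 0;
}
```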



##########
File path: cpp/src/arrow/compute/kernels/scalar_string.cc
##########
@@ -3344,12 +3344,231 @@ struct BinaryJoin {
   }
 };
 
+using BinaryJoinElementWiseState = OptionsWrapper<JoinOptions>;
+
+template <typename Type>
+struct BinaryJoinElementWise {
+  using ArrayType = typename TypeTraits<Type>::ArrayType;
+  using BuilderType = typename TypeTraits<Type>::BuilderType;
+  using offset_type = typename Type::offset_type;
+
+  static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    JoinOptions options = BinaryJoinElementWiseState::Get(ctx);
+    // Last argument is the separator (for consistency with binary_join)
+    if (std::all_of(batch.values.begin(), batch.values.end(),
+                    [](const Datum& d) { return d.is_scalar(); })) {
+      return ExecOnlyScalar(ctx, options, batch, out);
+    }
+    return ExecContainingArrays(ctx, options, batch, out);
+  }
+
+  static Status ExecOnlyScalar(KernelContext* ctx, const JoinOptions& options,
+                               const ExecBatch& batch, Datum* out) {
+    BaseBinaryScalar* output = checked_cast<BaseBinaryScalar*>(out->scalar().get());
+    const size_t num_args = batch.values.size();
+    if (num_args == 1) {
+      // Only separator, no values
+      ARROW_ASSIGN_OR_RAISE(output->value, ctx->Allocate(0));
+      output->is_valid = batch.values[0].scalar()->is_valid;
+      return Status::OK();
+    }
+
+    int64_t final_size = CalculateRowSize(options, batch, 0);
+    if (final_size < 0) {
+      ARROW_ASSIGN_OR_RAISE(output->value, ctx->Allocate(0));
+      output->is_valid = false;
+      return Status::OK();
+    }
+    ARROW_ASSIGN_OR_RAISE(output->value, ctx->Allocate(final_size));
+    const auto separator = UnboxScalar<Type>::Unbox(*batch.values.back().scalar());
+    uint8_t* buf = output->value->mutable_data();
+    bool first = true;
+    for (size_t i = 0; i < num_args - 1; i++) {
+      const Scalar& scalar = *batch[i].scalar();
+      util::string_view s;
+      if (scalar.is_valid) {
+        s = UnboxScalar<Type>::Unbox(scalar);
+      } else {
+        switch (options.null_handling) {
+          case JoinOptions::EMIT_NULL:
+            // Handled by CalculateRowSize
+            DCHECK(false) << "unreachable";
+            break;
+          case JoinOptions::SKIP:
+            continue;
+          case JoinOptions::REPLACE:
+            s = options.null_replacement;
+            break;
+        }
+      }
+      if (!first) {
+        buf = std::copy(separator.begin(), separator.end(), buf);
+      }
+      first = false;
+      buf = std::copy(s.begin(), s.end(), buf);
+    }
+    output->is_valid = true;
+    return Status::OK();
+  }
+
+  static Status ExecContainingArrays(KernelContext* ctx, const JoinOptions& options,
+                                     const ExecBatch& batch, Datum* out) {
+    // Presize data to avoid reallocations
+    int64_t final_size = 0;
+    for (int64_t i = 0; i < batch.length; i++) {
+      auto size = CalculateRowSize(options, batch, i);
+      if (size > 0) final_size += size;
+    }
+    BuilderType builder(ctx->memory_pool());
+    RETURN_NOT_OK(builder.Reserve(batch.length));
+    RETURN_NOT_OK(builder.ReserveData(final_size));
+
+    std::vector<const Datum*> valid_cols(batch.values.size());
+    for (size_t row = 0; row < static_cast<size_t>(batch.length); row++) {
+      size_t num_valid = 0;  // Not counting separator
+      for (size_t col = 0; col < batch.values.size(); col++) {
+        bool valid = false;
+        if (batch[col].is_scalar()) {
+          valid = batch[col].scalar()->is_valid;
+        } else {
+          const ArrayData& array = *batch[col].array();
+          valid = !array.MayHaveNulls() ||
+                  BitUtil::GetBit(array.buffers[0]->data(), array.offset + row);
+        }
+        if (valid) {
+          valid_cols[col] = &batch[col];
+          if (col < batch.values.size() - 1) num_valid++;
+        } else {
+          valid_cols[col] = nullptr;
+        }
+      }
+
+      if (!valid_cols.back()) {
+        // Separator is null
+        builder.UnsafeAppendNull();
+        continue;
+      } else if (batch.values.size() == 1) {
+        // Only given separator
+        builder.UnsafeAppendEmptyValue();
+        continue;
+      } else if (num_valid < batch.values.size() - 1) {
+        // We had some nulls
+        if (options.null_handling == JoinOptions::EMIT_NULL) {
+          builder.UnsafeAppendNull();
+          continue;
+        }
+      }
+      const auto separator = Lookup(*valid_cols.back(), row);
+      bool first = true;
+      for (size_t col = 0; col < batch.values.size() - 1; col++) {
+        const Datum* datum = valid_cols[col];
+        util::string_view value;
+        if (!datum) {
+          switch (options.null_handling) {
+            case JoinOptions::EMIT_NULL:
+              DCHECK(false) << "unreachable";
+              break;
+            case JoinOptions::SKIP:
+              continue;
+            case JoinOptions::REPLACE:
+              value = options.null_replacement;
+              break;
+          }
+        } else {
+          value = Lookup(*datum, row);
+        }
+        if (first) {
+          builder.UnsafeAppend(value);
+          first = false;
+          continue;
+        }
+        builder.UnsafeExtendCurrent(separator);
+        builder.UnsafeExtendCurrent(value);
+      }
+    }
+
+    std::shared_ptr<Array> string_array;
+    RETURN_NOT_OK(builder.Finish(&string_array));
+    *out = *string_array->data();
+    out->mutable_array()->type = batch[0].type();
+    DCHECK_EQ(batch.length, out->array()->length);

Review comment:
       Also `DCHECK_EQ(final_size, checked_cast<const ArrayType&>(*string_array).total_values_length())`?
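
In context, the suggested check would sit right after the existing length check. A sketch against the names in the diff, assuming `CalculateRowSize` returns exactly the number of data bytes a non-null row ends up occupying:

```cpp
    std::shared_ptr<Array> string_array;
    RETURN_NOT_OK(builder.Finish(&string_array));
    *out = *string_array->data();
    out->mutable_array()->type = batch[0].type();
    DCHECK_EQ(batch.length, out->array()->length);
    // The presizing loop summed CalculateRowSize() over all non-null rows,
    // so the builder's total data length should match exactly.
    DCHECK_EQ(final_size,
              checked_cast<const ArrayType&>(*string_array).total_values_length());
```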

##########
File path: cpp/src/arrow/compute/kernels/scalar_string.cc
##########
@@ -3344,12 +3344,231 @@ struct BinaryJoin {
   }
 };
 
+using BinaryJoinElementWiseState = OptionsWrapper<JoinOptions>;
+
+template <typename Type>
+struct BinaryJoinElementWise {
+  using ArrayType = typename TypeTraits<Type>::ArrayType;
+  using BuilderType = typename TypeTraits<Type>::BuilderType;
+  using offset_type = typename Type::offset_type;
+
+  static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    JoinOptions options = BinaryJoinElementWiseState::Get(ctx);
+    // Last argument is the separator (for consistency with binary_join)
+    if (std::all_of(batch.values.begin(), batch.values.end(),
+                    [](const Datum& d) { return d.is_scalar(); })) {
+      return ExecOnlyScalar(ctx, options, batch, out);
+    }
+    return ExecContainingArrays(ctx, options, batch, out);
+  }
+
+  static Status ExecOnlyScalar(KernelContext* ctx, const JoinOptions& options,
+                               const ExecBatch& batch, Datum* out) {
+    BaseBinaryScalar* output = checked_cast<BaseBinaryScalar*>(out->scalar().get());
+    const size_t num_args = batch.values.size();
+    if (num_args == 1) {
+      // Only separator, no values
+      ARROW_ASSIGN_OR_RAISE(output->value, ctx->Allocate(0));
+      output->is_valid = batch.values[0].scalar()->is_valid;
+      return Status::OK();
+    }
+
+    int64_t final_size = CalculateRowSize(options, batch, 0);
+    if (final_size < 0) {
+      ARROW_ASSIGN_OR_RAISE(output->value, ctx->Allocate(0));
+      output->is_valid = false;
+      return Status::OK();
+    }
+    ARROW_ASSIGN_OR_RAISE(output->value, ctx->Allocate(final_size));
+    const auto separator = UnboxScalar<Type>::Unbox(*batch.values.back().scalar());
+    uint8_t* buf = output->value->mutable_data();
+    bool first = true;
+    for (size_t i = 0; i < num_args - 1; i++) {
+      const Scalar& scalar = *batch[i].scalar();
+      util::string_view s;
+      if (scalar.is_valid) {
+        s = UnboxScalar<Type>::Unbox(scalar);
+      } else {
+        switch (options.null_handling) {
+          case JoinOptions::EMIT_NULL:
+            // Handled by CalculateRowSize
+            DCHECK(false) << "unreachable";
+            break;
+          case JoinOptions::SKIP:
+            continue;
+          case JoinOptions::REPLACE:
+            s = options.null_replacement;
+            break;
+        }
+      }
+      if (!first) {
+        buf = std::copy(separator.begin(), separator.end(), buf);
+      }
+      first = false;
+      buf = std::copy(s.begin(), s.end(), buf);
+    }
+    output->is_valid = true;

Review comment:
       Add something like `DCHECK_EQ(buf - output->value->mutable_data(), final_size)`?
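
Placed in context, that check would go right before the return. A sketch reusing the diff's local names, assuming `CalculateRowSize` counts exactly the bytes copied by the loop above:

```cpp
      buf = std::copy(s.begin(), s.end(), buf);
    }
    output->is_valid = true;
    // Every byte reserved via CalculateRowSize() should have been written above.
    DCHECK_EQ(buf - output->value->mutable_data(), final_size);
    return Status::OK();
```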




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org