You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2022/07/20 13:16:44 UTC

[GitHub] [arrow] pitrou commented on a diff in pull request #13654: ARROW-17135: [C++] Reduce code size in compute/kernels/scalar_compare.cc

pitrou commented on code in PR #13654:
URL: https://github.com/apache/arrow/pull/13654#discussion_r925579583


##########
cpp/src/arrow/compute/kernels/scalar_compare.cc:
##########
@@ -158,11 +158,183 @@ struct Maximum {
 
 // Implement Less, LessEqual by flipping arguments to Greater, GreaterEqual
 
-template <typename OutType, typename ArgType, typename Op>
-struct CompareTimestamps
-    : public applicator::ScalarBinaryEqualTypes<OutType, ArgType, Op> {
-  using Base = applicator::ScalarBinaryEqualTypes<OutType, ArgType, Op>;
+template <int batch_size>
+void PackBits(const int* values, uint8_t* out) {
+  for (int i = 0; i < batch_size / 8; ++i) {
+    *out++ = (values[0] | values[1] << 1 | values[2] << 2 | values[3] << 3 |
+              values[4] << 4 | values[5] << 5 | values[6] << 6 | values[7] << 7);
+    values += 8;
+  }
+}
+
+template <typename T, typename Op>
+struct ComparePrimitive {
+  static void Exec(const void* left_values_void, const void* right_values_void,
+                   int64_t length, void* out_bitmap_void) {
+    const T* left_values = reinterpret_cast<const T*>(left_values_void);
+    const T* right_values = reinterpret_cast<const T*>(right_values_void);
+    uint8_t* out_bitmap = reinterpret_cast<uint8_t*>(out_bitmap_void);
+    static constexpr int kBatchSize = 32;
+    int64_t num_batches = length / kBatchSize;
+    int temp_output[kBatchSize];

Review Comment:
   The element type of `temp_output` could even be `uint8_t`, though I'm not sure it would make much of a difference in practice.



##########
cpp/src/arrow/compute/kernels/scalar_compare.cc:
##########
@@ -158,11 +158,183 @@ struct Maximum {
 
 // Implement Less, LessEqual by flipping arguments to Greater, GreaterEqual
 
-template <typename OutType, typename ArgType, typename Op>
-struct CompareTimestamps
-    : public applicator::ScalarBinaryEqualTypes<OutType, ArgType, Op> {
-  using Base = applicator::ScalarBinaryEqualTypes<OutType, ArgType, Op>;
+template <int batch_size>
+void PackBits(const int* values, uint8_t* out) {
+  for (int i = 0; i < batch_size / 8; ++i) {
+    *out++ = (values[0] | values[1] << 1 | values[2] << 2 | values[3] << 3 |
+              values[4] << 4 | values[5] << 5 | values[6] << 6 | values[7] << 7);
+    values += 8;
+  }
+}
+
+template <typename T, typename Op>
+struct ComparePrimitive {
+  static void Exec(const void* left_values_void, const void* right_values_void,
+                   int64_t length, void* out_bitmap_void) {
+    const T* left_values = reinterpret_cast<const T*>(left_values_void);
+    const T* right_values = reinterpret_cast<const T*>(right_values_void);
+    uint8_t* out_bitmap = reinterpret_cast<uint8_t*>(out_bitmap_void);
+    static constexpr int kBatchSize = 32;
+    int64_t num_batches = length / kBatchSize;
+    int temp_output[kBatchSize];
+    for (int64_t j = 0; j < num_batches; ++j) {
+      for (int i = 0; i < kBatchSize; ++i) {
+        temp_output[i] = Op::template Call<bool, T, T>(nullptr, *left_values++,
+                                                       *right_values++, nullptr);
+      }
+      PackBits<kBatchSize>(temp_output, out_bitmap);
+      out_bitmap += kBatchSize / 8;
+    }
+    int64_t bit_index = 0;
+    for (int64_t j = kBatchSize * num_batches; j < length; ++j) {
+      bit_util::SetBitTo(out_bitmap, bit_index++,
+                         Op::template Call<bool, T, T>(nullptr, *left_values++,
+                                                       *right_values++, nullptr));
+    }
+  }
+};
+
+template <typename T, typename Op>
+struct ComparePrimitiveAS {
+  static void Exec(const void* left_values_void, const void* right_value_void,
+                   int64_t length, void* out_bitmap_void) {
+    const T* left_values = reinterpret_cast<const T*>(left_values_void);
+    const T right_value = *reinterpret_cast<const T*>(right_value_void);
+    uint8_t* out_bitmap = reinterpret_cast<uint8_t*>(out_bitmap_void);
+    static constexpr int kBatchSize = 32;
+    int64_t num_batches = length / kBatchSize;
+    int temp_output[kBatchSize];
+    for (int64_t j = 0; j < num_batches; ++j) {
+      for (int i = 0; i < kBatchSize; ++i) {
+        temp_output[i] =
+            Op::template Call<bool, T, T>(nullptr, *left_values++, right_value, nullptr);
+      }
+      PackBits<kBatchSize>(temp_output, out_bitmap);
+      out_bitmap += kBatchSize / 8;
+    }
+    int64_t bit_index = 0;
+    for (int64_t j = kBatchSize * num_batches; j < length; ++j) {
+      bit_util::SetBitTo(
+          out_bitmap, bit_index++,
+          Op::template Call<bool, T, T>(nullptr, *left_values++, right_value, nullptr));
+    }
+  }
+};
+
+template <typename T, typename Op>
+struct ComparePrimitiveSA {

Review Comment:
   ```suggestion
   struct ComparePrimitiveScalarArray {
   ```



##########
cpp/src/arrow/compute/kernels/scalar_compare.cc:
##########
@@ -158,11 +158,183 @@ struct Maximum {
 
 // Implement Less, LessEqual by flipping arguments to Greater, GreaterEqual
 
-template <typename OutType, typename ArgType, typename Op>
-struct CompareTimestamps
-    : public applicator::ScalarBinaryEqualTypes<OutType, ArgType, Op> {
-  using Base = applicator::ScalarBinaryEqualTypes<OutType, ArgType, Op>;
+template <int batch_size>
+void PackBits(const int* values, uint8_t* out) {
+  for (int i = 0; i < batch_size / 8; ++i) {
+    *out++ = (values[0] | values[1] << 1 | values[2] << 2 | values[3] << 3 |
+              values[4] << 4 | values[5] << 5 | values[6] << 6 | values[7] << 7);
+    values += 8;
+  }
+}
+
+template <typename T, typename Op>
+struct ComparePrimitive {
+  static void Exec(const void* left_values_void, const void* right_values_void,
+                   int64_t length, void* out_bitmap_void) {
+    const T* left_values = reinterpret_cast<const T*>(left_values_void);
+    const T* right_values = reinterpret_cast<const T*>(right_values_void);
+    uint8_t* out_bitmap = reinterpret_cast<uint8_t*>(out_bitmap_void);
+    static constexpr int kBatchSize = 32;
+    int64_t num_batches = length / kBatchSize;
+    int temp_output[kBatchSize];
+    for (int64_t j = 0; j < num_batches; ++j) {
+      for (int i = 0; i < kBatchSize; ++i) {
+        temp_output[i] = Op::template Call<bool, T, T>(nullptr, *left_values++,
+                                                       *right_values++, nullptr);
+      }
+      PackBits<kBatchSize>(temp_output, out_bitmap);
+      out_bitmap += kBatchSize / 8;
+    }
+    int64_t bit_index = 0;
+    for (int64_t j = kBatchSize * num_batches; j < length; ++j) {
+      bit_util::SetBitTo(out_bitmap, bit_index++,
+                         Op::template Call<bool, T, T>(nullptr, *left_values++,
+                                                       *right_values++, nullptr));
+    }
+  }
+};
+
+template <typename T, typename Op>
+struct ComparePrimitiveAS {
+  static void Exec(const void* left_values_void, const void* right_value_void,
+                   int64_t length, void* out_bitmap_void) {
+    const T* left_values = reinterpret_cast<const T*>(left_values_void);
+    const T right_value = *reinterpret_cast<const T*>(right_value_void);
+    uint8_t* out_bitmap = reinterpret_cast<uint8_t*>(out_bitmap_void);
+    static constexpr int kBatchSize = 32;
+    int64_t num_batches = length / kBatchSize;
+    int temp_output[kBatchSize];
+    for (int64_t j = 0; j < num_batches; ++j) {
+      for (int i = 0; i < kBatchSize; ++i) {
+        temp_output[i] =
+            Op::template Call<bool, T, T>(nullptr, *left_values++, right_value, nullptr);
+      }
+      PackBits<kBatchSize>(temp_output, out_bitmap);
+      out_bitmap += kBatchSize / 8;
+    }
+    int64_t bit_index = 0;
+    for (int64_t j = kBatchSize * num_batches; j < length; ++j) {
+      bit_util::SetBitTo(
+          out_bitmap, bit_index++,
+          Op::template Call<bool, T, T>(nullptr, *left_values++, right_value, nullptr));
+    }
+  }
+};
+
+template <typename T, typename Op>
+struct ComparePrimitiveSA {
+  static void Exec(const void* left_value_void, const void* right_values_void,
+                   int64_t length, void* out_bitmap_void) {
+    const T left_value = *reinterpret_cast<const T*>(left_value_void);
+    const T* right_values = reinterpret_cast<const T*>(right_values_void);
+    uint8_t* out_bitmap = reinterpret_cast<uint8_t*>(out_bitmap_void);
+    static constexpr int kBatchSize = 32;
+    int64_t num_batches = length / kBatchSize;
+    int temp_output[kBatchSize];
+    for (int64_t j = 0; j < num_batches; ++j) {
+      for (int i = 0; i < kBatchSize; ++i) {
+        temp_output[i] =
+            Op::template Call<bool, T, T>(nullptr, left_value, *right_values++, nullptr);
+      }
+      PackBits<kBatchSize>(temp_output, out_bitmap);
+      out_bitmap += kBatchSize / 8;
+    }
+    int64_t bit_index = 0;
+    for (int64_t j = kBatchSize * num_batches; j < length; ++j) {
+      bit_util::SetBitTo(
+          out_bitmap, bit_index++,
+          Op::template Call<bool, T, T>(nullptr, left_value, *right_values++, nullptr));
+    }
+  }
+};
+
+using BinaryKernel = void (*)(const void*, const void*, int64_t, void*);
+
+struct CompareData : public KernelState {
+  BinaryKernel func_aa;
+  BinaryKernel func_sa;
+  BinaryKernel func_as;
+  CompareData(BinaryKernel func_aa, BinaryKernel func_sa, BinaryKernel func_as)
+      : func_aa(func_aa), func_sa(func_sa), func_as(func_as) {}

Review Comment:
   This trivial constructor probably doesn't need to be defined explicitly, as it will be synthesized automatically.



##########
cpp/src/arrow/compute/kernels/scalar_compare.cc:
##########
@@ -171,22 +343,28 @@ struct CompareTimestamps
           "Cannot compare timestamp with timezone to timestamp without timezone, got: ",
           lhs, " and ", rhs);
     }
-    return Base::Exec(ctx, batch, out);
+    return CompareKernel<Int64Type>::Exec(ctx, batch, out);
   }
 };
 
 template <typename Op>
-void AddIntegerCompare(const std::shared_ptr<DataType>& ty, ScalarFunction* func) {
-  auto exec =
-      GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(*ty);
-  DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec)));
+ScalarKernel GetCompareKernel(InputType ty, Type::type compare_type,
+                              ArrayKernelExec exec) {
+  ScalarKernel kernel;
+  kernel.signature = KernelSignature::Make({ty, ty}, boolean());
+  BinaryKernel func_aa = GetBinaryKernel<ComparePrimitive, Op>(compare_type);
+  BinaryKernel func_sa = GetBinaryKernel<ComparePrimitiveSA, Op>(compare_type);
+  BinaryKernel func_as = GetBinaryKernel<ComparePrimitiveAS, Op>(compare_type);
+  kernel.data = std::make_shared<CompareData>(func_aa, func_sa, func_as);
+  kernel.exec = exec;
+  return kernel;
 }
 
-template <typename InType, typename Op>
-void AddGenericCompare(const std::shared_ptr<DataType>& ty, ScalarFunction* func) {
-  DCHECK_OK(
-      func->AddKernel({ty, ty}, boolean(),
-                      applicator::ScalarBinaryEqualTypes<BooleanType, InType, Op>::Exec));
+template <typename Op>
+void AddPrimitiveCompare(const std::shared_ptr<DataType>& ty, ScalarFunction* func) {
+  ArrayKernelExec exec = GeneratePhysicalNumeric<CompareKernel>(ty);
+  ScalarKernel kernel = GetCompareKernel<Op>(ty, ty->id(), exec);

Review Comment:
   So we're using both `GeneratePhysicalNumeric` and `GetBinaryKernel` which ultimately do the same thing? Is it possible to streamline this and avoid redundancies?



##########
cpp/src/arrow/compute/kernels/scalar_compare.cc:
##########
@@ -158,11 +158,183 @@ struct Maximum {
 
 // Implement Less, LessEqual by flipping arguments to Greater, GreaterEqual
 
-template <typename OutType, typename ArgType, typename Op>
-struct CompareTimestamps
-    : public applicator::ScalarBinaryEqualTypes<OutType, ArgType, Op> {
-  using Base = applicator::ScalarBinaryEqualTypes<OutType, ArgType, Op>;
+template <int batch_size>
+void PackBits(const int* values, uint8_t* out) {
+  for (int i = 0; i < batch_size / 8; ++i) {
+    *out++ = (values[0] | values[1] << 1 | values[2] << 2 | values[3] << 3 |
+              values[4] << 4 | values[5] << 5 | values[6] << 6 | values[7] << 7);
+    values += 8;
+  }
+}
+
+template <typename T, typename Op>
+struct ComparePrimitive {
+  static void Exec(const void* left_values_void, const void* right_values_void,
+                   int64_t length, void* out_bitmap_void) {
+    const T* left_values = reinterpret_cast<const T*>(left_values_void);
+    const T* right_values = reinterpret_cast<const T*>(right_values_void);
+    uint8_t* out_bitmap = reinterpret_cast<uint8_t*>(out_bitmap_void);
+    static constexpr int kBatchSize = 32;
+    int64_t num_batches = length / kBatchSize;
+    int temp_output[kBatchSize];
+    for (int64_t j = 0; j < num_batches; ++j) {
+      for (int i = 0; i < kBatchSize; ++i) {
+        temp_output[i] = Op::template Call<bool, T, T>(nullptr, *left_values++,
+                                                       *right_values++, nullptr);
+      }
+      PackBits<kBatchSize>(temp_output, out_bitmap);
+      out_bitmap += kBatchSize / 8;
+    }
+    int64_t bit_index = 0;
+    for (int64_t j = kBatchSize * num_batches; j < length; ++j) {
+      bit_util::SetBitTo(out_bitmap, bit_index++,
+                         Op::template Call<bool, T, T>(nullptr, *left_values++,
+                                                       *right_values++, nullptr));
+    }
+  }
+};
+
+template <typename T, typename Op>
+struct ComparePrimitiveAS {
+  static void Exec(const void* left_values_void, const void* right_value_void,
+                   int64_t length, void* out_bitmap_void) {
+    const T* left_values = reinterpret_cast<const T*>(left_values_void);
+    const T right_value = *reinterpret_cast<const T*>(right_value_void);
+    uint8_t* out_bitmap = reinterpret_cast<uint8_t*>(out_bitmap_void);
+    static constexpr int kBatchSize = 32;
+    int64_t num_batches = length / kBatchSize;
+    int temp_output[kBatchSize];
+    for (int64_t j = 0; j < num_batches; ++j) {
+      for (int i = 0; i < kBatchSize; ++i) {
+        temp_output[i] =
+            Op::template Call<bool, T, T>(nullptr, *left_values++, right_value, nullptr);
+      }
+      PackBits<kBatchSize>(temp_output, out_bitmap);
+      out_bitmap += kBatchSize / 8;
+    }
+    int64_t bit_index = 0;
+    for (int64_t j = kBatchSize * num_batches; j < length; ++j) {
+      bit_util::SetBitTo(
+          out_bitmap, bit_index++,
+          Op::template Call<bool, T, T>(nullptr, *left_values++, right_value, nullptr));
+    }
+  }
+};
+
+template <typename T, typename Op>
+struct ComparePrimitiveSA {
+  static void Exec(const void* left_value_void, const void* right_values_void,
+                   int64_t length, void* out_bitmap_void) {
+    const T left_value = *reinterpret_cast<const T*>(left_value_void);
+    const T* right_values = reinterpret_cast<const T*>(right_values_void);
+    uint8_t* out_bitmap = reinterpret_cast<uint8_t*>(out_bitmap_void);
+    static constexpr int kBatchSize = 32;
+    int64_t num_batches = length / kBatchSize;
+    int temp_output[kBatchSize];
+    for (int64_t j = 0; j < num_batches; ++j) {
+      for (int i = 0; i < kBatchSize; ++i) {
+        temp_output[i] =
+            Op::template Call<bool, T, T>(nullptr, left_value, *right_values++, nullptr);
+      }
+      PackBits<kBatchSize>(temp_output, out_bitmap);
+      out_bitmap += kBatchSize / 8;
+    }
+    int64_t bit_index = 0;
+    for (int64_t j = kBatchSize * num_batches; j < length; ++j) {
+      bit_util::SetBitTo(
+          out_bitmap, bit_index++,
+          Op::template Call<bool, T, T>(nullptr, left_value, *right_values++, nullptr));
+    }
+  }
+};
+
+using BinaryKernel = void (*)(const void*, const void*, int64_t, void*);
+
+struct CompareData : public KernelState {
+  BinaryKernel func_aa;
+  BinaryKernel func_sa;
+  BinaryKernel func_as;
+  CompareData(BinaryKernel func_aa, BinaryKernel func_sa, BinaryKernel func_as)
+      : func_aa(func_aa), func_sa(func_sa), func_as(func_as) {}
+};
+
+template <template <typename...> class Generator, typename Op>
+BinaryKernel GetBinaryKernel(Type::type type) {
+  switch (type) {
+    case Type::INT8:
+      return Generator<int8_t, Op>::Exec;
+    case Type::INT16:
+      return Generator<int16_t, Op>::Exec;
+    case Type::INT32:
+    case Type::DATE32:
+      return Generator<int32_t, Op>::Exec;
+    case Type::INT64:
+    case Type::DURATION:
+    case Type::TIMESTAMP:
+    case Type::DATE64:
+      return Generator<int64_t, Op>::Exec;
+    case Type::UINT8:
+      return Generator<uint8_t, Op>::Exec;
+    case Type::UINT16:
+      return Generator<uint16_t, Op>::Exec;
+    case Type::UINT32:
+      return Generator<uint32_t, Op>::Exec;
+    case Type::UINT64:
+      return Generator<uint64_t, Op>::Exec;
+    case Type::FLOAT:
+      return Generator<float, Op>::Exec;
+    case Type::DOUBLE:
+      return Generator<double, Op>::Exec;
+    default:
+      return nullptr;
+  }
+}
+
+template <typename Type>
+struct CompareKernel {
+  using T = typename Type::c_type;
+
+  static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
+    const auto kernel = static_cast<const ScalarKernel*>(ctx->kernel());
+    DCHECK(kernel);
+    const auto kernel_data = static_cast<const CompareData*>(kernel->data.get());
+
+    ArraySpan* out_arr = out->array_span();
+
+    // TODO: implement path for offset not multiple of 8
+    const bool out_is_byte_aligned = out_arr->offset % 8 == 0;

Review Comment:
   Hmm... is there any situation where the output offset is not zero?



##########
cpp/src/arrow/compute/kernels/scalar_compare.cc:
##########
@@ -158,11 +158,183 @@ struct Maximum {
 
 // Implement Less, LessEqual by flipping arguments to Greater, GreaterEqual
 
-template <typename OutType, typename ArgType, typename Op>
-struct CompareTimestamps
-    : public applicator::ScalarBinaryEqualTypes<OutType, ArgType, Op> {
-  using Base = applicator::ScalarBinaryEqualTypes<OutType, ArgType, Op>;
+template <int batch_size>
+void PackBits(const int* values, uint8_t* out) {
+  for (int i = 0; i < batch_size / 8; ++i) {
+    *out++ = (values[0] | values[1] << 1 | values[2] << 2 | values[3] << 3 |
+              values[4] << 4 | values[5] << 5 | values[6] << 6 | values[7] << 7);
+    values += 8;
+  }
+}
+
+template <typename T, typename Op>
+struct ComparePrimitive {
+  static void Exec(const void* left_values_void, const void* right_values_void,
+                   int64_t length, void* out_bitmap_void) {
+    const T* left_values = reinterpret_cast<const T*>(left_values_void);
+    const T* right_values = reinterpret_cast<const T*>(right_values_void);
+    uint8_t* out_bitmap = reinterpret_cast<uint8_t*>(out_bitmap_void);
+    static constexpr int kBatchSize = 32;
+    int64_t num_batches = length / kBatchSize;
+    int temp_output[kBatchSize];
+    for (int64_t j = 0; j < num_batches; ++j) {
+      for (int i = 0; i < kBatchSize; ++i) {
+        temp_output[i] = Op::template Call<bool, T, T>(nullptr, *left_values++,
+                                                       *right_values++, nullptr);
+      }
+      PackBits<kBatchSize>(temp_output, out_bitmap);
+      out_bitmap += kBatchSize / 8;
+    }
+    int64_t bit_index = 0;
+    for (int64_t j = kBatchSize * num_batches; j < length; ++j) {
+      bit_util::SetBitTo(out_bitmap, bit_index++,
+                         Op::template Call<bool, T, T>(nullptr, *left_values++,
+                                                       *right_values++, nullptr));
+    }
+  }
+};
+
+template <typename T, typename Op>
+struct ComparePrimitiveAS {

Review Comment:
   ```suggestion
   struct ComparePrimitiveArrayScalar {
   ```



##########
cpp/src/arrow/compute/kernels/scalar_compare.cc:
##########
@@ -158,11 +158,183 @@ struct Maximum {
 
 // Implement Less, LessEqual by flipping arguments to Greater, GreaterEqual
 
-template <typename OutType, typename ArgType, typename Op>
-struct CompareTimestamps
-    : public applicator::ScalarBinaryEqualTypes<OutType, ArgType, Op> {
-  using Base = applicator::ScalarBinaryEqualTypes<OutType, ArgType, Op>;
+template <int batch_size>
+void PackBits(const int* values, uint8_t* out) {
+  for (int i = 0; i < batch_size / 8; ++i) {
+    *out++ = (values[0] | values[1] << 1 | values[2] << 2 | values[3] << 3 |
+              values[4] << 4 | values[5] << 5 | values[6] << 6 | values[7] << 7);
+    values += 8;
+  }
+}
+
+template <typename T, typename Op>
+struct ComparePrimitive {
+  static void Exec(const void* left_values_void, const void* right_values_void,
+                   int64_t length, void* out_bitmap_void) {
+    const T* left_values = reinterpret_cast<const T*>(left_values_void);
+    const T* right_values = reinterpret_cast<const T*>(right_values_void);
+    uint8_t* out_bitmap = reinterpret_cast<uint8_t*>(out_bitmap_void);
+    static constexpr int kBatchSize = 32;
+    int64_t num_batches = length / kBatchSize;
+    int temp_output[kBatchSize];
+    for (int64_t j = 0; j < num_batches; ++j) {
+      for (int i = 0; i < kBatchSize; ++i) {
+        temp_output[i] = Op::template Call<bool, T, T>(nullptr, *left_values++,
+                                                       *right_values++, nullptr);
+      }
+      PackBits<kBatchSize>(temp_output, out_bitmap);
+      out_bitmap += kBatchSize / 8;
+    }
+    int64_t bit_index = 0;
+    for (int64_t j = kBatchSize * num_batches; j < length; ++j) {
+      bit_util::SetBitTo(out_bitmap, bit_index++,
+                         Op::template Call<bool, T, T>(nullptr, *left_values++,
+                                                       *right_values++, nullptr));
+    }
+  }
+};
+
+template <typename T, typename Op>
+struct ComparePrimitiveAS {
+  static void Exec(const void* left_values_void, const void* right_value_void,
+                   int64_t length, void* out_bitmap_void) {
+    const T* left_values = reinterpret_cast<const T*>(left_values_void);
+    const T right_value = *reinterpret_cast<const T*>(right_value_void);
+    uint8_t* out_bitmap = reinterpret_cast<uint8_t*>(out_bitmap_void);
+    static constexpr int kBatchSize = 32;
+    int64_t num_batches = length / kBatchSize;
+    int temp_output[kBatchSize];
+    for (int64_t j = 0; j < num_batches; ++j) {
+      for (int i = 0; i < kBatchSize; ++i) {
+        temp_output[i] =
+            Op::template Call<bool, T, T>(nullptr, *left_values++, right_value, nullptr);
+      }
+      PackBits<kBatchSize>(temp_output, out_bitmap);
+      out_bitmap += kBatchSize / 8;
+    }
+    int64_t bit_index = 0;
+    for (int64_t j = kBatchSize * num_batches; j < length; ++j) {
+      bit_util::SetBitTo(
+          out_bitmap, bit_index++,
+          Op::template Call<bool, T, T>(nullptr, *left_values++, right_value, nullptr));
+    }
+  }
+};
+
+template <typename T, typename Op>
+struct ComparePrimitiveSA {
+  static void Exec(const void* left_value_void, const void* right_values_void,
+                   int64_t length, void* out_bitmap_void) {
+    const T left_value = *reinterpret_cast<const T*>(left_value_void);
+    const T* right_values = reinterpret_cast<const T*>(right_values_void);
+    uint8_t* out_bitmap = reinterpret_cast<uint8_t*>(out_bitmap_void);
+    static constexpr int kBatchSize = 32;
+    int64_t num_batches = length / kBatchSize;
+    int temp_output[kBatchSize];
+    for (int64_t j = 0; j < num_batches; ++j) {
+      for (int i = 0; i < kBatchSize; ++i) {
+        temp_output[i] =
+            Op::template Call<bool, T, T>(nullptr, left_value, *right_values++, nullptr);
+      }
+      PackBits<kBatchSize>(temp_output, out_bitmap);
+      out_bitmap += kBatchSize / 8;
+    }
+    int64_t bit_index = 0;
+    for (int64_t j = kBatchSize * num_batches; j < length; ++j) {
+      bit_util::SetBitTo(
+          out_bitmap, bit_index++,
+          Op::template Call<bool, T, T>(nullptr, left_value, *right_values++, nullptr));
+    }
+  }
+};
+
+using BinaryKernel = void (*)(const void*, const void*, int64_t, void*);
+
+struct CompareData : public KernelState {
+  BinaryKernel func_aa;
+  BinaryKernel func_sa;
+  BinaryKernel func_as;
+  CompareData(BinaryKernel func_aa, BinaryKernel func_sa, BinaryKernel func_as)
+      : func_aa(func_aa), func_sa(func_sa), func_as(func_as) {}
+};
+
+template <template <typename...> class Generator, typename Op>
+BinaryKernel GetBinaryKernel(Type::type type) {
+  switch (type) {
+    case Type::INT8:
+      return Generator<int8_t, Op>::Exec;
+    case Type::INT16:
+      return Generator<int16_t, Op>::Exec;
+    case Type::INT32:
+    case Type::DATE32:
+      return Generator<int32_t, Op>::Exec;
+    case Type::INT64:
+    case Type::DURATION:
+    case Type::TIMESTAMP:
+    case Type::DATE64:
+      return Generator<int64_t, Op>::Exec;
+    case Type::UINT8:
+      return Generator<uint8_t, Op>::Exec;
+    case Type::UINT16:
+      return Generator<uint16_t, Op>::Exec;
+    case Type::UINT32:
+      return Generator<uint32_t, Op>::Exec;
+    case Type::UINT64:
+      return Generator<uint64_t, Op>::Exec;
+    case Type::FLOAT:
+      return Generator<float, Op>::Exec;
+    case Type::DOUBLE:
+      return Generator<double, Op>::Exec;
+    default:
+      return nullptr;
+  }
+}
+
+template <typename Type>
+struct CompareKernel {
+  using T = typename Type::c_type;
+
+  static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
+    const auto kernel = static_cast<const ScalarKernel*>(ctx->kernel());
+    DCHECK(kernel);
+    const auto kernel_data = static_cast<const CompareData*>(kernel->data.get());

Review Comment:
   Can we use `checked_cast` here?



##########
cpp/src/arrow/compute/kernels/scalar_compare.cc:
##########
@@ -158,11 +158,183 @@ struct Maximum {
 
 // Implement Less, LessEqual by flipping arguments to Greater, GreaterEqual
 
-template <typename OutType, typename ArgType, typename Op>
-struct CompareTimestamps
-    : public applicator::ScalarBinaryEqualTypes<OutType, ArgType, Op> {
-  using Base = applicator::ScalarBinaryEqualTypes<OutType, ArgType, Op>;
+template <int batch_size>
+void PackBits(const int* values, uint8_t* out) {
+  for (int i = 0; i < batch_size / 8; ++i) {
+    *out++ = (values[0] | values[1] << 1 | values[2] << 2 | values[3] << 3 |
+              values[4] << 4 | values[5] << 5 | values[6] << 6 | values[7] << 7);
+    values += 8;
+  }
+}
+
+template <typename T, typename Op>
+struct ComparePrimitive {
+  static void Exec(const void* left_values_void, const void* right_values_void,
+                   int64_t length, void* out_bitmap_void) {
+    const T* left_values = reinterpret_cast<const T*>(left_values_void);
+    const T* right_values = reinterpret_cast<const T*>(right_values_void);
+    uint8_t* out_bitmap = reinterpret_cast<uint8_t*>(out_bitmap_void);
+    static constexpr int kBatchSize = 32;
+    int64_t num_batches = length / kBatchSize;
+    int temp_output[kBatchSize];
+    for (int64_t j = 0; j < num_batches; ++j) {
+      for (int i = 0; i < kBatchSize; ++i) {
+        temp_output[i] = Op::template Call<bool, T, T>(nullptr, *left_values++,
+                                                       *right_values++, nullptr);
+      }
+      PackBits<kBatchSize>(temp_output, out_bitmap);
+      out_bitmap += kBatchSize / 8;
+    }
+    int64_t bit_index = 0;
+    for (int64_t j = kBatchSize * num_batches; j < length; ++j) {
+      bit_util::SetBitTo(out_bitmap, bit_index++,
+                         Op::template Call<bool, T, T>(nullptr, *left_values++,
+                                                       *right_values++, nullptr));
+    }
+  }
+};
+
+template <typename T, typename Op>
+struct ComparePrimitiveAS {
+  static void Exec(const void* left_values_void, const void* right_value_void,
+                   int64_t length, void* out_bitmap_void) {
+    const T* left_values = reinterpret_cast<const T*>(left_values_void);
+    const T right_value = *reinterpret_cast<const T*>(right_value_void);
+    uint8_t* out_bitmap = reinterpret_cast<uint8_t*>(out_bitmap_void);
+    static constexpr int kBatchSize = 32;
+    int64_t num_batches = length / kBatchSize;
+    int temp_output[kBatchSize];
+    for (int64_t j = 0; j < num_batches; ++j) {
+      for (int i = 0; i < kBatchSize; ++i) {
+        temp_output[i] =
+            Op::template Call<bool, T, T>(nullptr, *left_values++, right_value, nullptr);
+      }
+      PackBits<kBatchSize>(temp_output, out_bitmap);
+      out_bitmap += kBatchSize / 8;
+    }
+    int64_t bit_index = 0;
+    for (int64_t j = kBatchSize * num_batches; j < length; ++j) {
+      bit_util::SetBitTo(
+          out_bitmap, bit_index++,
+          Op::template Call<bool, T, T>(nullptr, *left_values++, right_value, nullptr));
+    }
+  }
+};
+
+template <typename T, typename Op>
+struct ComparePrimitiveSA {
+  static void Exec(const void* left_value_void, const void* right_values_void,
+                   int64_t length, void* out_bitmap_void) {
+    const T left_value = *reinterpret_cast<const T*>(left_value_void);
+    const T* right_values = reinterpret_cast<const T*>(right_values_void);
+    uint8_t* out_bitmap = reinterpret_cast<uint8_t*>(out_bitmap_void);
+    static constexpr int kBatchSize = 32;
+    int64_t num_batches = length / kBatchSize;
+    int temp_output[kBatchSize];
+    for (int64_t j = 0; j < num_batches; ++j) {
+      for (int i = 0; i < kBatchSize; ++i) {
+        temp_output[i] =
+            Op::template Call<bool, T, T>(nullptr, left_value, *right_values++, nullptr);
+      }
+      PackBits<kBatchSize>(temp_output, out_bitmap);
+      out_bitmap += kBatchSize / 8;
+    }
+    int64_t bit_index = 0;
+    for (int64_t j = kBatchSize * num_batches; j < length; ++j) {
+      bit_util::SetBitTo(
+          out_bitmap, bit_index++,
+          Op::template Call<bool, T, T>(nullptr, left_value, *right_values++, nullptr));
+    }
+  }
+};
+
+using BinaryKernel = void (*)(const void*, const void*, int64_t, void*);
+
+struct CompareData : public KernelState {
+  BinaryKernel func_aa;
+  BinaryKernel func_sa;
+  BinaryKernel func_as;
+  CompareData(BinaryKernel func_aa, BinaryKernel func_sa, BinaryKernel func_as)
+      : func_aa(func_aa), func_sa(func_sa), func_as(func_as) {}
+};
+
+template <template <typename...> class Generator, typename Op>
+BinaryKernel GetBinaryKernel(Type::type type) {

Review Comment:
   Can we give this a more descriptive name?



##########
cpp/src/arrow/compute/kernels/scalar_compare.cc:
##########
@@ -310,30 +480,37 @@ std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name, FunctionDo
   return func;
 }
 
-struct FlippedData : public KernelState {
+struct FlippedData : public CompareData {
   ArrayKernelExec unflipped_exec;
-  explicit FlippedData(ArrayKernelExec unflipped_exec) : unflipped_exec(unflipped_exec) {}
+  explicit FlippedData(ArrayKernelExec unflipped_exec, BinaryKernel func_aa = nullptr,
+                       BinaryKernel func_sa = nullptr, BinaryKernel func_as = nullptr)
+      : CompareData(func_aa, func_sa, func_as), unflipped_exec(unflipped_exec) {}
 };
 
-Status FlippedBinaryExec(KernelContext* ctx, const ExecSpan& span, ExecResult* out) {
+Status FlippedCompare(KernelContext* ctx, const ExecSpan& span, ExecResult* out) {
   const auto kernel = static_cast<const ScalarKernel*>(ctx->kernel());
-  DCHECK(kernel);
   const auto kernel_data = static_cast<const FlippedData*>(kernel->data.get());
-
   ExecSpan flipped_span = span;
   std::swap(flipped_span.values[0], flipped_span.values[1]);
   return kernel_data->unflipped_exec(ctx, flipped_span, out);
 }
 
-std::shared_ptr<ScalarFunction> MakeFlippedFunction(std::string name,
-                                                    const ScalarFunction& func,
-                                                    FunctionDoc doc) {
+std::shared_ptr<ScalarFunction> MakeFlippedCompare(std::string name,
+                                                   const ScalarFunction& func,
+                                                   FunctionDoc doc) {
   auto flipped_func =
       std::make_shared<CompareFunction>(name, Arity::Binary(), std::move(doc));
   for (const ScalarKernel* kernel : func.kernels()) {
     ScalarKernel flipped_kernel = *kernel;
-    flipped_kernel.data = std::make_shared<FlippedData>(kernel->exec);
-    flipped_kernel.exec = FlippedBinaryExec;
+    if (kernel->data) {
+      auto compare_data = static_cast<const CompareData*>(kernel->data.get());

Review Comment:
   `checked_cast` perhaps?



##########
cpp/src/arrow/compute/kernels/scalar_compare.cc:
##########
@@ -171,22 +343,28 @@ struct CompareTimestamps
           "Cannot compare timestamp with timezone to timestamp without timezone, got: ",
           lhs, " and ", rhs);
     }
-    return Base::Exec(ctx, batch, out);
+    return CompareKernel<Int64Type>::Exec(ctx, batch, out);
   }
 };
 
 template <typename Op>
-void AddIntegerCompare(const std::shared_ptr<DataType>& ty, ScalarFunction* func) {
-  auto exec =
-      GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(*ty);
-  DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec)));
+ScalarKernel GetCompareKernel(InputType ty, Type::type compare_type,
+                              ArrayKernelExec exec) {
+  ScalarKernel kernel;
+  kernel.signature = KernelSignature::Make({ty, ty}, boolean());
+  BinaryKernel func_aa = GetBinaryKernel<ComparePrimitive, Op>(compare_type);
+  BinaryKernel func_sa = GetBinaryKernel<ComparePrimitiveSA, Op>(compare_type);
+  BinaryKernel func_as = GetBinaryKernel<ComparePrimitiveAS, Op>(compare_type);
+  kernel.data = std::make_shared<CompareData>(func_aa, func_sa, func_as);

Review Comment:
   Nit: precomputing these seems a bit futile, since `CompareKernel` already has the right physical type, so deducing the right `ComparePrimitive` functions there should be trivial.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org