You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ap...@apache.org on 2021/09/21 11:09:37 UTC

[arrow] branch master updated: ARROW-13573: [C++] Support dictionaries natively in case_when

This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 87e2ad5  ARROW-13573: [C++] Support dictionaries natively in case_when
87e2ad5 is described below

commit 87e2ad5b2f2cdaf1e469b1cda1a2899a747464b6
Author: David Li <li...@gmail.com>
AuthorDate: Tue Sep 21 13:08:09 2021 +0200

    ARROW-13573: [C++] Support dictionaries natively in case_when
    
    This makes case_when support dictionary arguments 'natively'; that is, dictionaries are no longer always unpacked. (If a mix of dictionary and non-dictionary arguments is given, the dictionary arguments are still unpacked.)
    
    For scalar conditions, the output will have the dictionary of whichever input is selected (or no dictionary if the output is null). For array conditions, we unify the dictionaries as we select elements.
    
    Closes #11022 from lidavidm/arrow-13573
    
    Authored-by: David Li <li...@gmail.com>
    Signed-off-by: Antoine Pitrou <an...@python.org>
---
 .github/workflows/cpp.yml                          |   2 +-
 .github/workflows/r.yml                            |   2 +-
 ci/scripts/PKGBUILD                                |   6 +-
 cpp/src/arrow/array/array_test.cc                  |  12 +-
 cpp/src/arrow/array/builder_base.cc                |  10 +-
 cpp/src/arrow/array/builder_base.h                 |  13 +-
 cpp/src/arrow/array/builder_dict.cc                |  39 +--
 cpp/src/arrow/array/builder_dict.h                 | 110 +++++++++
 cpp/src/arrow/builder.cc                           | 269 +++++++++++----------
 cpp/src/arrow/compute/kernels/scalar_if_else.cc    |  37 ++-
 .../arrow/compute/kernels/scalar_if_else_test.cc   | 193 +++++++++++++++
 cpp/src/arrow/compute/kernels/test_util.cc         |  92 ++++++-
 cpp/src/arrow/compute/kernels/test_util.h          |  15 ++
 cpp/src/arrow/ipc/json_simple.cc                   |  19 ++
 cpp/src/arrow/ipc/json_simple.h                    |   5 +
 cpp/src/arrow/ipc/json_simple_test.cc              |  24 ++
 cpp/src/arrow/scalar.cc                            |   3 +-
 cpp/src/arrow/testing/gtest_util.cc                |   9 +
 cpp/src/arrow/testing/gtest_util.h                 |   5 +
 19 files changed, 692 insertions(+), 173 deletions(-)

diff --git a/.github/workflows/cpp.yml b/.github/workflows/cpp.yml
index 086f45d..0f19f73 100644
--- a/.github/workflows/cpp.yml
+++ b/.github/workflows/cpp.yml
@@ -238,7 +238,7 @@ jobs:
     name: AMD64 Windows MinGW ${{ matrix.mingw-n-bits }} C++
     runs-on: windows-latest
     if: ${{ !contains(github.event.pull_request.title, 'WIP') }}
-    timeout-minutes: 45
+    timeout-minutes: 60
     strategy:
       fail-fast: false
       matrix:
diff --git a/.github/workflows/r.yml b/.github/workflows/r.yml
index e160ba8..3886eaf 100644
--- a/.github/workflows/r.yml
+++ b/.github/workflows/r.yml
@@ -53,7 +53,7 @@ jobs:
     name: AMD64 Ubuntu ${{ matrix.ubuntu }} R ${{ matrix.r }}
     runs-on: ubuntu-latest
     if: ${{ !contains(github.event.pull_request.title, 'WIP') }}
-    timeout-minutes: 60
+    timeout-minutes: 75
     strategy:
       fail-fast: false
       matrix:
diff --git a/ci/scripts/PKGBUILD b/ci/scripts/PKGBUILD
index 56d70d8..246b679 100644
--- a/ci/scripts/PKGBUILD
+++ b/ci/scripts/PKGBUILD
@@ -80,9 +80,13 @@ build() {
     export LIBS="-L${MINGW_PREFIX}/libs"
     export ARROW_S3=OFF
     export ARROW_WITH_RE2=OFF
+    # Without this, some dataset functionality segfaults
+    export CMAKE_UNITY_BUILD=ON
   else
     export ARROW_S3=ON
     export ARROW_WITH_RE2=ON
+    # Without this, some compute functionality segfaults in tests
+    export CMAKE_UNITY_BUILD=OFF
   fi
 
   MSYS2_ARG_CONV_EXCL="-DCMAKE_INSTALL_PREFIX=" \
@@ -115,7 +119,7 @@ build() {
     -DARROW_CXXFLAGS="${CPPFLAGS}" \
     -DCMAKE_BUILD_TYPE="release" \
     -DCMAKE_INSTALL_PREFIX=${MINGW_PREFIX} \
-    -DCMAKE_UNITY_BUILD=ON \
+    -DCMAKE_UNITY_BUILD=${CMAKE_UNITY_BUILD} \
     -DCMAKE_VERBOSE_MAKEFILE=ON
 
   make -j3
diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc
index d9617c4..2e3d405 100644
--- a/cpp/src/arrow/array/array_test.cc
+++ b/cpp/src/arrow/array/array_test.cc
@@ -456,7 +456,7 @@ TEST_F(TestArray, TestValidateNullCount) {
 void AssertAppendScalar(MemoryPool* pool, const std::shared_ptr<Scalar>& scalar) {
   std::unique_ptr<arrow::ArrayBuilder> builder;
   auto null_scalar = MakeNullScalar(scalar->type);
-  ASSERT_OK(MakeBuilder(pool, scalar->type, &builder));
+  ASSERT_OK(MakeBuilderExactIndex(pool, scalar->type, &builder));
   ASSERT_OK(builder->AppendScalar(*scalar));
   ASSERT_OK(builder->AppendScalar(*scalar));
   ASSERT_OK(builder->AppendScalar(*null_scalar));
@@ -471,15 +471,18 @@ void AssertAppendScalar(MemoryPool* pool, const std::shared_ptr<Scalar>& scalar)
   ASSERT_EQ(out->length(), 9);
 
   const bool can_check_nulls = internal::HasValidityBitmap(out->type()->id());
+  // For a dictionary builder, the output dictionary won't necessarily be the same
+  const bool can_check_values = !is_dictionary(out->type()->id());
 
   if (can_check_nulls) {
     ASSERT_EQ(out->null_count(), 4);
   }
+
   for (const auto index : {0, 1, 3, 5, 6}) {
     ASSERT_FALSE(out->IsNull(index));
     ASSERT_OK_AND_ASSIGN(auto scalar_i, out->GetScalar(index));
     ASSERT_OK(scalar_i->ValidateFull());
-    AssertScalarsEqual(*scalar, *scalar_i, /*verbose=*/true);
+    if (can_check_values) AssertScalarsEqual(*scalar, *scalar_i, /*verbose=*/true);
   }
   for (const auto index : {2, 4, 7, 8}) {
     ASSERT_EQ(out->IsNull(index), can_check_nulls);
@@ -575,8 +578,6 @@ TEST_F(TestArray, TestMakeArrayFromScalar) {
   }
 
   for (auto scalar : scalars) {
-    // TODO(ARROW-13197): appending dictionary scalars not implemented
-    if (is_dictionary(scalar->type->id())) continue;
     AssertAppendScalar(pool_, scalar);
   }
 }
@@ -634,9 +635,6 @@ TEST_F(TestArray, TestMakeArrayFromMapScalar) {
 TEST_F(TestArray, TestAppendArraySlice) {
   auto scalars = GetScalars();
   for (const auto& scalar : scalars) {
-    // TODO(ARROW-13573): appending dictionary arrays not implemented
-    if (is_dictionary(scalar->type->id())) continue;
-
     ARROW_SCOPED_TRACE(*scalar->type);
     ASSERT_OK_AND_ASSIGN(auto array, MakeArrayFromScalar(*scalar, 16));
     ASSERT_OK_AND_ASSIGN(auto nulls, MakeArrayOfNull(scalar->type, 16));
diff --git a/cpp/src/arrow/array/builder_base.cc b/cpp/src/arrow/array/builder_base.cc
index 2f4e63b..117b9d3 100644
--- a/cpp/src/arrow/array/builder_base.cc
+++ b/cpp/src/arrow/array/builder_base.cc
@@ -22,6 +22,7 @@
 #include <vector>
 
 #include "arrow/array/array_base.h"
+#include "arrow/array/builder_dict.h"
 #include "arrow/array/data.h"
 #include "arrow/array/util.h"
 #include "arrow/buffer.h"
@@ -268,15 +269,6 @@ struct AppendScalarImpl {
 
 }  // namespace
 
-Status ArrayBuilder::AppendScalar(const Scalar& scalar) {
-  if (!scalar.type->Equals(type())) {
-    return Status::Invalid("Cannot append scalar of type ", scalar.type->ToString(),
-                           " to builder for type ", type()->ToString());
-  }
-  std::shared_ptr<Scalar> shared{const_cast<Scalar*>(&scalar), [](Scalar*) {}};
-  return AppendScalarImpl{&shared, &shared + 1, /*n_repeats=*/1, this}.Convert();
-}
-
 Status ArrayBuilder::AppendScalar(const Scalar& scalar, int64_t n_repeats) {
   if (!scalar.type->Equals(type())) {
     return Status::Invalid("Cannot append scalar of type ", scalar.type->ToString(),
diff --git a/cpp/src/arrow/array/builder_base.h b/cpp/src/arrow/array/builder_base.h
index 67203e7..87e39c3 100644
--- a/cpp/src/arrow/array/builder_base.h
+++ b/cpp/src/arrow/array/builder_base.h
@@ -119,9 +119,9 @@ class ARROW_EXPORT ArrayBuilder {
   virtual Status AppendEmptyValues(int64_t length) = 0;
 
   /// \brief Append a value from a scalar
-  Status AppendScalar(const Scalar& scalar);
-  Status AppendScalar(const Scalar& scalar, int64_t n_repeats);
-  Status AppendScalars(const ScalarVector& scalars);
+  Status AppendScalar(const Scalar& scalar) { return AppendScalar(scalar, 1); }
+  virtual Status AppendScalar(const Scalar& scalar, int64_t n_repeats);
+  virtual Status AppendScalars(const ScalarVector& scalars);
 
   /// \brief Append a range of values from an array.
   ///
@@ -282,6 +282,13 @@ ARROW_EXPORT
 Status MakeBuilder(MemoryPool* pool, const std::shared_ptr<DataType>& type,
                    std::unique_ptr<ArrayBuilder>* out);
 
+/// \brief Construct an empty ArrayBuilder corresponding to the data
+/// type, where any top-level or nested dictionary builders return the
+/// exact index type specified by the type.
+ARROW_EXPORT
+Status MakeBuilderExactIndex(MemoryPool* pool, const std::shared_ptr<DataType>& type,
+                             std::unique_ptr<ArrayBuilder>* out);
+
 /// \brief Construct an empty DictionaryBuilder initialized optionally
 /// with a pre-existing dictionary
 /// \param[in] pool the MemoryPool to use for allocations
diff --git a/cpp/src/arrow/array/builder_dict.cc b/cpp/src/arrow/array/builder_dict.cc
index b13f6a2..d247316 100644
--- a/cpp/src/arrow/array/builder_dict.cc
+++ b/cpp/src/arrow/array/builder_dict.cc
@@ -159,23 +159,32 @@ DictionaryMemoTable::DictionaryMemoTable(MemoryPool* pool,
 
 DictionaryMemoTable::~DictionaryMemoTable() = default;
 
-#define GET_OR_INSERT(C_TYPE)                                                       \
-  Status DictionaryMemoTable::GetOrInsert(                                          \
-      const typename CTypeTraits<C_TYPE>::ArrowType*, C_TYPE value, int32_t* out) { \
-    return impl_->GetOrInsert<typename CTypeTraits<C_TYPE>::ArrowType>(value, out); \
+#define GET_OR_INSERT(ARROW_TYPE)                                           \
+  Status DictionaryMemoTable::GetOrInsert(                                  \
+      const ARROW_TYPE*, typename ARROW_TYPE::c_type value, int32_t* out) { \
+    return impl_->GetOrInsert<ARROW_TYPE>(value, out);                      \
   }
 
-GET_OR_INSERT(bool)
-GET_OR_INSERT(int8_t)
-GET_OR_INSERT(int16_t)
-GET_OR_INSERT(int32_t)
-GET_OR_INSERT(int64_t)
-GET_OR_INSERT(uint8_t)
-GET_OR_INSERT(uint16_t)
-GET_OR_INSERT(uint32_t)
-GET_OR_INSERT(uint64_t)
-GET_OR_INSERT(float)
-GET_OR_INSERT(double)
+GET_OR_INSERT(BooleanType)
+GET_OR_INSERT(Int8Type)
+GET_OR_INSERT(Int16Type)
+GET_OR_INSERT(Int32Type)
+GET_OR_INSERT(Int64Type)
+GET_OR_INSERT(UInt8Type)
+GET_OR_INSERT(UInt16Type)
+GET_OR_INSERT(UInt32Type)
+GET_OR_INSERT(UInt64Type)
+GET_OR_INSERT(FloatType)
+GET_OR_INSERT(DoubleType)
+GET_OR_INSERT(DurationType);
+GET_OR_INSERT(TimestampType);
+GET_OR_INSERT(Date32Type);
+GET_OR_INSERT(Date64Type);
+GET_OR_INSERT(Time32Type);
+GET_OR_INSERT(Time64Type);
+GET_OR_INSERT(MonthDayNanoIntervalType);
+GET_OR_INSERT(DayTimeIntervalType);
+GET_OR_INSERT(MonthIntervalType);
 
 #undef GET_OR_INSERT
 
diff --git a/cpp/src/arrow/array/builder_dict.h b/cpp/src/arrow/array/builder_dict.h
index 455cb3d..0637c97 100644
--- a/cpp/src/arrow/array/builder_dict.h
+++ b/cpp/src/arrow/array/builder_dict.h
@@ -37,6 +37,7 @@
 #include "arrow/util/decimal.h"
 #include "arrow/util/macros.h"
 #include "arrow/util/visibility.h"
+#include "arrow/visitor_inline.h"
 
 namespace arrow {
 
@@ -97,6 +98,17 @@ class ARROW_EXPORT DictionaryMemoTable {
   Status GetOrInsert(const UInt16Type*, uint16_t value, int32_t* out);
   Status GetOrInsert(const UInt32Type*, uint32_t value, int32_t* out);
   Status GetOrInsert(const UInt64Type*, uint64_t value, int32_t* out);
+  Status GetOrInsert(const DurationType*, int64_t value, int32_t* out);
+  Status GetOrInsert(const TimestampType*, int64_t value, int32_t* out);
+  Status GetOrInsert(const Date32Type*, int32_t value, int32_t* out);
+  Status GetOrInsert(const Date64Type*, int64_t value, int32_t* out);
+  Status GetOrInsert(const Time32Type*, int32_t value, int32_t* out);
+  Status GetOrInsert(const Time64Type*, int64_t value, int32_t* out);
+  Status GetOrInsert(const MonthDayNanoIntervalType*,
+                     MonthDayNanoIntervalType::MonthDayNanos value, int32_t* out);
+  Status GetOrInsert(const DayTimeIntervalType*,
+                     DayTimeIntervalType::DayMilliseconds value, int32_t* out);
+  Status GetOrInsert(const MonthIntervalType*, int32_t value, int32_t* out);
   Status GetOrInsert(const FloatType*, float value, int32_t* out);
   Status GetOrInsert(const DoubleType*, double value, int32_t* out);
 
@@ -282,6 +294,73 @@ class DictionaryBuilderBase : public ArrayBuilder {
     return indices_builder_.AppendEmptyValues(length);
   }
 
+  Status AppendScalar(const Scalar& scalar, int64_t n_repeats) override {
+    if (!scalar.is_valid) return AppendNulls(n_repeats);
+
+    const auto& dict_ty = internal::checked_cast<const DictionaryType&>(*scalar.type);
+    const DictionaryScalar& dict_scalar =
+        internal::checked_cast<const DictionaryScalar&>(scalar);
+    const auto& dict = internal::checked_cast<const typename TypeTraits<T>::ArrayType&>(
+        *dict_scalar.value.dictionary);
+    ARROW_RETURN_NOT_OK(Reserve(n_repeats));
+    switch (dict_ty.index_type()->id()) {
+      case Type::UINT8:
+        return AppendScalarImpl<UInt8Type>(dict, *dict_scalar.value.index, n_repeats);
+      case Type::INT8:
+        return AppendScalarImpl<Int8Type>(dict, *dict_scalar.value.index, n_repeats);
+      case Type::UINT16:
+        return AppendScalarImpl<UInt16Type>(dict, *dict_scalar.value.index, n_repeats);
+      case Type::INT16:
+        return AppendScalarImpl<Int16Type>(dict, *dict_scalar.value.index, n_repeats);
+      case Type::UINT32:
+        return AppendScalarImpl<UInt32Type>(dict, *dict_scalar.value.index, n_repeats);
+      case Type::INT32:
+        return AppendScalarImpl<Int32Type>(dict, *dict_scalar.value.index, n_repeats);
+      case Type::UINT64:
+        return AppendScalarImpl<UInt64Type>(dict, *dict_scalar.value.index, n_repeats);
+      case Type::INT64:
+        return AppendScalarImpl<Int64Type>(dict, *dict_scalar.value.index, n_repeats);
+      default:
+        return Status::TypeError("Invalid index type: ", dict_ty);
+    }
+    return Status::OK();
+  }
+
+  Status AppendScalars(const ScalarVector& scalars) override {
+    for (const auto& scalar : scalars) {
+      ARROW_RETURN_NOT_OK(AppendScalar(*scalar, /*n_repeats=*/1));
+    }
+    return Status::OK();
+  }
+
+  Status AppendArraySlice(const ArrayData& array, int64_t offset, int64_t length) final {
+    // Visit the indices and insert the unpacked values.
+    const auto& dict_ty = internal::checked_cast<const DictionaryType&>(*array.type);
+    const typename TypeTraits<T>::ArrayType dict(array.dictionary);
+    ARROW_RETURN_NOT_OK(Reserve(length));
+    switch (dict_ty.index_type()->id()) {
+      case Type::UINT8:
+        return AppendArraySliceImpl<uint8_t>(dict, array, offset, length);
+      case Type::INT8:
+        return AppendArraySliceImpl<int8_t>(dict, array, offset, length);
+      case Type::UINT16:
+        return AppendArraySliceImpl<uint16_t>(dict, array, offset, length);
+      case Type::INT16:
+        return AppendArraySliceImpl<int16_t>(dict, array, offset, length);
+      case Type::UINT32:
+        return AppendArraySliceImpl<uint32_t>(dict, array, offset, length);
+      case Type::INT32:
+        return AppendArraySliceImpl<int32_t>(dict, array, offset, length);
+      case Type::UINT64:
+        return AppendArraySliceImpl<uint64_t>(dict, array, offset, length);
+      case Type::INT64:
+        return AppendArraySliceImpl<int64_t>(dict, array, offset, length);
+      default:
+        return Status::TypeError("Invalid index type: ", dict_ty);
+    }
+    return Status::OK();
+  }
+
   /// \brief Insert values into the dictionary's memo, but do not append any
   /// indices. Can be used to initialize a new builder with known dictionary
   /// values
@@ -376,6 +455,37 @@ class DictionaryBuilderBase : public ArrayBuilder {
   }
 
  protected:
+  template <typename c_type>
+  Status AppendArraySliceImpl(const typename TypeTraits<T>::ArrayType& dict,
+                              const ArrayData& array, int64_t offset, int64_t length) {
+    const c_type* values = array.GetValues<c_type>(1) + offset;
+    return VisitBitBlocks(
+        array.buffers[0], array.offset + offset, length,
+        [&](const int64_t position) {
+          const int64_t index = static_cast<int64_t>(values[position]);
+          if (dict.IsValid(index)) {
+            return Append(dict.GetView(index));
+          }
+          return AppendNull();
+        },
+        [&]() { return AppendNull(); });
+  }
+
+  template <typename IndexType>
+  Status AppendScalarImpl(const typename TypeTraits<T>::ArrayType& dict,
+                          const Scalar& index_scalar, int64_t n_repeats) {
+    using ScalarType = typename TypeTraits<IndexType>::ScalarType;
+    const auto index = internal::checked_cast<const ScalarType&>(index_scalar).value;
+    if (index_scalar.is_valid && dict.IsValid(index)) {
+      const auto& value = dict.GetView(index);
+      for (int64_t i = 0; i < n_repeats; i++) {
+        ARROW_RETURN_NOT_OK(Append(value));
+      }
+      return Status::OK();
+    }
+    return AppendNulls(n_repeats);
+  }
+
   Status FinishInternal(std::shared_ptr<ArrayData>* out) override {
     std::shared_ptr<ArrayData> dictionary;
     ARROW_RETURN_NOT_OK(FinishWithDictOffset(/*offset=*/0, out, &dictionary));
diff --git a/cpp/src/arrow/builder.cc b/cpp/src/arrow/builder.cc
index 37cc9e0..115a97e 100644
--- a/cpp/src/arrow/builder.cc
+++ b/cpp/src/arrow/builder.cc
@@ -41,14 +41,10 @@ struct DictionaryBuilderCase {
   }
 
   Status Visit(const NullType&) { return CreateFor<NullType>(); }
-  Status Visit(const BinaryType&) { return Create<BinaryDictionaryBuilder>(); }
-  Status Visit(const StringType&) { return Create<StringDictionaryBuilder>(); }
-  Status Visit(const LargeBinaryType&) {
-    return Create<DictionaryBuilder<LargeBinaryType>>();
-  }
-  Status Visit(const LargeStringType&) {
-    return Create<DictionaryBuilder<LargeStringType>>();
-  }
+  Status Visit(const BinaryType&) { return CreateFor<BinaryType>(); }
+  Status Visit(const StringType&) { return CreateFor<StringType>(); }
+  Status Visit(const LargeBinaryType&) { return CreateFor<LargeBinaryType>(); }
+  Status Visit(const LargeStringType&) { return CreateFor<LargeStringType>(); }
   Status Visit(const FixedSizeBinaryType&) { return CreateFor<FixedSizeBinaryType>(); }
   Status Visit(const Decimal128Type&) { return CreateFor<Decimal128Type>(); }
   Status Visit(const Decimal256Type&) { return CreateFor<Decimal256Type>(); }
@@ -63,19 +59,50 @@ struct DictionaryBuilderCase {
 
   template <typename ValueType>
   Status CreateFor() {
-    return Create<DictionaryBuilder<ValueType>>();
-  }
-
-  template <typename BuilderType>
-  Status Create() {
-    BuilderType* builder;
+    using AdaptiveBuilderType = DictionaryBuilder<ValueType>;
     if (dictionary != nullptr) {
-      builder = new BuilderType(dictionary, pool);
+      out->reset(new AdaptiveBuilderType(dictionary, pool));
+    } else if (exact_index_type) {
+      switch (index_type->id()) {
+        case Type::UINT8:
+          out->reset(new internal::DictionaryBuilderBase<UInt8Builder, ValueType>(
+              value_type, pool));
+          break;
+        case Type::INT8:
+          out->reset(new internal::DictionaryBuilderBase<Int8Builder, ValueType>(
+              value_type, pool));
+          break;
+        case Type::UINT16:
+          out->reset(new internal::DictionaryBuilderBase<UInt16Builder, ValueType>(
+              value_type, pool));
+          break;
+        case Type::INT16:
+          out->reset(new internal::DictionaryBuilderBase<Int16Builder, ValueType>(
+              value_type, pool));
+          break;
+        case Type::UINT32:
+          out->reset(new internal::DictionaryBuilderBase<UInt32Builder, ValueType>(
+              value_type, pool));
+          break;
+        case Type::INT32:
+          out->reset(new internal::DictionaryBuilderBase<Int32Builder, ValueType>(
+              value_type, pool));
+          break;
+        case Type::UINT64:
+          out->reset(new internal::DictionaryBuilderBase<UInt64Builder, ValueType>(
+              value_type, pool));
+          break;
+        case Type::INT64:
+          out->reset(new internal::DictionaryBuilderBase<Int64Builder, ValueType>(
+              value_type, pool));
+          break;
+        default:
+          return Status::TypeError("MakeBuilder: invalid index type ", *index_type);
+      }
     } else {
       auto start_int_size = internal::GetByteWidth(*index_type);
-      builder = new BuilderType(start_int_size, value_type, pool);
+      out->reset(new AdaptiveBuilderType(start_int_size, value_type, pool));
     }
-    out->reset(builder);
     return Status::OK();
   }
 
@@ -85,138 +112,130 @@ struct DictionaryBuilderCase {
   const std::shared_ptr<DataType>& index_type;
   const std::shared_ptr<DataType>& value_type;
   const std::shared_ptr<Array>& dictionary;
+  bool exact_index_type;
   std::unique_ptr<ArrayBuilder>* out;
 };
 
-#define BUILDER_CASE(TYPE_CLASS)                     \
-  case TYPE_CLASS##Type::type_id:                    \
-    out->reset(new TYPE_CLASS##Builder(type, pool)); \
+struct MakeBuilderImpl {
+  template <typename T>
+  enable_if_not_nested<T, Status> Visit(const T&) {
+    out.reset(new typename TypeTraits<T>::BuilderType(type, pool));
     return Status::OK();
+  }
 
-Result<std::vector<std::shared_ptr<ArrayBuilder>>> FieldBuilders(const DataType& type,
-                                                                 MemoryPool* pool) {
-  std::vector<std::shared_ptr<ArrayBuilder>> field_builders;
+  Status Visit(const DictionaryType& dict_type) {
+    DictionaryBuilderCase visitor = {pool,
+                                     dict_type.index_type(),
+                                     dict_type.value_type(),
+                                     /*dictionary=*/nullptr,
+                                     exact_index_type,
+                                     &out};
+    return visitor.Make();
+  }
 
-  for (const auto& field : type.fields()) {
-    std::unique_ptr<ArrayBuilder> builder;
-    RETURN_NOT_OK(MakeBuilder(pool, field->type(), &builder));
-    field_builders.emplace_back(std::move(builder));
+  Status Visit(const ListType& list_type) {
+    std::shared_ptr<DataType> value_type = list_type.value_type();
+    ARROW_ASSIGN_OR_RAISE(auto value_builder, ChildBuilder(value_type));
+    out.reset(new ListBuilder(pool, std::move(value_builder), type));
+    return Status::OK();
   }
 
-  return field_builders;
-}
+  Status Visit(const LargeListType& list_type) {
+    std::shared_ptr<DataType> value_type = list_type.value_type();
+    ARROW_ASSIGN_OR_RAISE(auto value_builder, ChildBuilder(value_type));
+    out.reset(new LargeListBuilder(pool, std::move(value_builder), type));
+    return Status::OK();
+  }
 
-Status MakeBuilder(MemoryPool* pool, const std::shared_ptr<DataType>& type,
-                   std::unique_ptr<ArrayBuilder>* out) {
-  switch (type->id()) {
-    case Type::NA: {
-      out->reset(new NullBuilder(pool));
-      return Status::OK();
-    }
-      BUILDER_CASE(UInt8);
-      BUILDER_CASE(Int8);
-      BUILDER_CASE(UInt16);
-      BUILDER_CASE(Int16);
-      BUILDER_CASE(UInt32);
-      BUILDER_CASE(Int32);
-      BUILDER_CASE(UInt64);
-      BUILDER_CASE(Int64);
-      BUILDER_CASE(Date32);
-      BUILDER_CASE(Date64);
-      BUILDER_CASE(Duration);
-      BUILDER_CASE(Time32);
-      BUILDER_CASE(Time64);
-      BUILDER_CASE(Timestamp);
-      BUILDER_CASE(MonthInterval);
-      BUILDER_CASE(DayTimeInterval);
-      BUILDER_CASE(MonthDayNanoInterval);
-      BUILDER_CASE(Boolean);
-      BUILDER_CASE(HalfFloat);
-      BUILDER_CASE(Float);
-      BUILDER_CASE(Double);
-      BUILDER_CASE(String);
-      BUILDER_CASE(Binary);
-      BUILDER_CASE(LargeString);
-      BUILDER_CASE(LargeBinary);
-      BUILDER_CASE(FixedSizeBinary);
-      BUILDER_CASE(Decimal128);
-      BUILDER_CASE(Decimal256);
-
-    case Type::DICTIONARY: {
-      const auto& dict_type = static_cast<const DictionaryType&>(*type);
-      DictionaryBuilderCase visitor = {pool, dict_type.index_type(),
-                                       dict_type.value_type(), nullptr, out};
-      return visitor.Make();
-    }
+  Status Visit(const MapType& map_type) {
+    ARROW_ASSIGN_OR_RAISE(auto key_builder, ChildBuilder(map_type.key_type()));
+    ARROW_ASSIGN_OR_RAISE(auto item_builder, ChildBuilder(map_type.item_type()));
+    out.reset(
+        new MapBuilder(pool, std::move(key_builder), std::move(item_builder), type));
+    return Status::OK();
+  }
 
-    case Type::LIST: {
-      std::unique_ptr<ArrayBuilder> value_builder;
-      std::shared_ptr<DataType> value_type =
-          internal::checked_cast<const ListType&>(*type).value_type();
-      RETURN_NOT_OK(MakeBuilder(pool, value_type, &value_builder));
-      out->reset(new ListBuilder(pool, std::move(value_builder), type));
-      return Status::OK();
-    }
+  Status Visit(const FixedSizeListType& list_type) {
+    auto value_type = list_type.value_type();
+    ARROW_ASSIGN_OR_RAISE(auto value_builder, ChildBuilder(value_type));
+    out.reset(new FixedSizeListBuilder(pool, std::move(value_builder), type));
+    return Status::OK();
+  }
 
-    case Type::LARGE_LIST: {
-      std::unique_ptr<ArrayBuilder> value_builder;
-      std::shared_ptr<DataType> value_type =
-          internal::checked_cast<const LargeListType&>(*type).value_type();
-      RETURN_NOT_OK(MakeBuilder(pool, value_type, &value_builder));
-      out->reset(new LargeListBuilder(pool, std::move(value_builder), type));
-      return Status::OK();
-    }
+  Status Visit(const StructType& struct_type) {
+    ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool));
+    out.reset(new StructBuilder(type, pool, std::move(field_builders)));
+    return Status::OK();
+  }
 
-    case Type::MAP: {
-      const auto& map_type = internal::checked_cast<const MapType&>(*type);
-      std::unique_ptr<ArrayBuilder> key_builder, item_builder;
-      RETURN_NOT_OK(MakeBuilder(pool, map_type.key_type(), &key_builder));
-      RETURN_NOT_OK(MakeBuilder(pool, map_type.item_type(), &item_builder));
-      out->reset(
-          new MapBuilder(pool, std::move(key_builder), std::move(item_builder), type));
-      return Status::OK();
-    }
+  Status Visit(const SparseUnionType&) {
+    ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool));
+    out.reset(new SparseUnionBuilder(pool, std::move(field_builders), type));
+    return Status::OK();
+  }
 
-    case Type::FIXED_SIZE_LIST: {
-      const auto& list_type = internal::checked_cast<const FixedSizeListType&>(*type);
-      std::unique_ptr<ArrayBuilder> value_builder;
-      auto value_type = list_type.value_type();
-      RETURN_NOT_OK(MakeBuilder(pool, value_type, &value_builder));
-      out->reset(new FixedSizeListBuilder(pool, std::move(value_builder), type));
-      return Status::OK();
-    }
+  Status Visit(const DenseUnionType&) {
+    ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool));
+    out.reset(new DenseUnionBuilder(pool, std::move(field_builders), type));
+    return Status::OK();
+  }
 
-    case Type::STRUCT: {
-      ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool));
-      out->reset(new StructBuilder(type, pool, std::move(field_builders)));
-      return Status::OK();
-    }
+  Status Visit(const ExtensionType&) { return NotImplemented(); }
+  Status Visit(const DataType&) { return NotImplemented(); }
 
-    case Type::SPARSE_UNION: {
-      ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool));
-      out->reset(new SparseUnionBuilder(pool, std::move(field_builders), type));
-      return Status::OK();
-    }
+  Status NotImplemented() {
+    return Status::NotImplemented("MakeBuilder: cannot construct builder for type ",
+                                  type->ToString());
+  }
 
-    case Type::DENSE_UNION: {
-      ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool));
-      out->reset(new DenseUnionBuilder(pool, std::move(field_builders), type));
-      return Status::OK();
-    }
+  Result<std::unique_ptr<ArrayBuilder>> ChildBuilder(
+      const std::shared_ptr<DataType>& type) {
+    MakeBuilderImpl impl{pool, type, exact_index_type, /*out=*/nullptr};
+    RETURN_NOT_OK(VisitTypeInline(*type, &impl));
+    return std::move(impl.out);
+  }
 
-    default:
-      break;
+  Result<std::vector<std::shared_ptr<ArrayBuilder>>> FieldBuilders(const DataType& type,
+                                                                   MemoryPool* pool) {
+    std::vector<std::shared_ptr<ArrayBuilder>> field_builders;
+    for (const auto& field : type.fields()) {
+      std::unique_ptr<ArrayBuilder> builder;
+      MakeBuilderImpl impl{pool, field->type(), exact_index_type, /*out=*/nullptr};
+      RETURN_NOT_OK(VisitTypeInline(*field->type(), &impl));
+      field_builders.emplace_back(std::move(impl.out));
+    }
+    return field_builders;
   }
-  return Status::NotImplemented("MakeBuilder: cannot construct builder for type ",
-                                type->ToString());
+
+  MemoryPool* pool;
+  const std::shared_ptr<DataType>& type;
+  bool exact_index_type;
+  std::unique_ptr<ArrayBuilder> out;
+};
+
+Status MakeBuilder(MemoryPool* pool, const std::shared_ptr<DataType>& type,
+                   std::unique_ptr<ArrayBuilder>* out) {
+  MakeBuilderImpl impl{pool, type, /*exact_index_type=*/false, /*out=*/nullptr};
+  RETURN_NOT_OK(VisitTypeInline(*type, &impl));
+  *out = std::move(impl.out);
+  return Status::OK();
+}
+
+Status MakeBuilderExactIndex(MemoryPool* pool, const std::shared_ptr<DataType>& type,
+                             std::unique_ptr<ArrayBuilder>* out) {
+  MakeBuilderImpl impl{pool, type, /*exact_index_type=*/true, /*out=*/nullptr};
+  RETURN_NOT_OK(VisitTypeInline(*type, &impl));
+  *out = std::move(impl.out);
+  return Status::OK();
 }
 
 Status MakeDictionaryBuilder(MemoryPool* pool, const std::shared_ptr<DataType>& type,
                              const std::shared_ptr<Array>& dictionary,
                              std::unique_ptr<ArrayBuilder>* out) {
   const auto& dict_type = static_cast<const DictionaryType&>(*type);
-  DictionaryBuilderCase visitor = {pool, dict_type.index_type(), dict_type.value_type(),
-                                   dictionary, out};
+  DictionaryBuilderCase visitor = {
+      pool,       dict_type.index_type(),     dict_type.value_type(),
+      dictionary, /*exact_index_type=*/false, out};
   return visitor.Make();
 }
 
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
index 4de04da..35bb624 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -1222,7 +1222,6 @@ struct CaseWhenFunction : ScalarFunction {
     // The first function is a struct of booleans, where the number of fields in the
     // struct is either equal to the number of other arguments or is one less.
     RETURN_NOT_OK(CheckArity(*values));
-    EnsureDictionaryDecoded(values);
     auto first_type = (*values)[0].type;
     if (first_type->id() != Type::STRUCT) {
       return Status::TypeError("case_when: first argument must be STRUCT, not ",
@@ -1243,6 +1242,9 @@ struct CaseWhenFunction : ScalarFunction {
       }
     }
 
+    if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+
+    EnsureDictionaryDecoded(values);
     if (auto type = CommonNumeric(values->data() + 1, values->size() - 1)) {
       for (auto it = values->begin() + 1; it != values->end(); it++) {
         it->type = type;
@@ -1279,6 +1281,15 @@ Status ExecScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out
     return Status::OK();
   }
   ArrayData* output = out->mutable_array();
+  if (is_dictionary_type<Type>::value) {
+    const Datum& dict_from = result.is_value() ? result : batch[1];
+    if (dict_from.is_scalar()) {
+      output->dictionary = checked_cast<const DictionaryScalar&>(*dict_from.scalar())
+                               .value.dictionary->data();
+    } else {
+      output->dictionary = dict_from.array()->dictionary;
+    }
+  }
   if (!result.is_value()) {
     // All conditions false, no 'else' argument
     result = MakeNullScalar(out->type());
@@ -1304,6 +1315,7 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out)
       static_cast<size_t>(conds_array.type->num_fields()) < num_value_args;
   uint8_t* out_valid = output->buffers[0]->mutable_data();
   uint8_t* out_values = output->buffers[1]->mutable_data();
+
   if (have_else_arg) {
     // Copy 'else' value into output
     CopyValues<Type>(batch.values.back(), /*in_offset=*/0, batch.length, out_valid,
@@ -1472,7 +1484,7 @@ static Status ExecVarWidthArrayCaseWhenImpl(
   const bool have_else_arg =
       static_cast<size_t>(conds_array.type->num_fields()) < (batch.values.size() - 1);
   std::unique_ptr<ArrayBuilder> raw_builder;
-  RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), out->type(), &raw_builder));
+  RETURN_NOT_OK(MakeBuilderExactIndex(ctx->memory_pool(), out->type(), &raw_builder));
   RETURN_NOT_OK(raw_builder->Reserve(batch.length));
   RETURN_NOT_OK(reserve_data(raw_builder.get()));
 
@@ -1701,6 +1713,24 @@ struct CaseWhenFunctor<Type, enable_if_union<Type>> {
   }
 };
 
+template <>
+struct CaseWhenFunctor<DictionaryType> {
+  static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    if (batch[0].null_count() > 0) {
+      return Status::Invalid("cond struct must not have outer nulls");
+    }
+    if (batch[0].is_scalar()) {
+      return ExecVarWidthScalarCaseWhen(ctx, batch, out);
+    }
+    return ExecArray(ctx, batch, out);
+  }
+
+  static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    std::function<Status(ArrayBuilder*)> reserve_data = ReserveNoData;
+    return ExecVarWidthArrayCaseWhen(ctx, batch, out, std::move(reserve_data));
+  }
+};
+
 struct CoalesceFunction : ScalarFunction {
   using ScalarFunction::ScalarFunction;
 
@@ -2446,7 +2476,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
   }
   {
     auto func = std::make_shared<CaseWhenFunction>(
-        "case_when", Arity::VarArgs(/*min_args=*/1), &case_when_doc);
+        "case_when", Arity::VarArgs(/*min_args=*/2), &case_when_doc);
     AddPrimitiveCaseWhenKernels(func, NumericTypes());
     AddPrimitiveCaseWhenKernels(func, TemporalTypes());
     AddPrimitiveCaseWhenKernels(func, IntervalTypes());
@@ -2464,6 +2494,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
     AddCaseWhenKernel(func, Type::STRUCT, CaseWhenFunctor<StructType>::Exec);
     AddCaseWhenKernel(func, Type::DENSE_UNION, CaseWhenFunctor<DenseUnionType>::Exec);
     AddCaseWhenKernel(func, Type::SPARSE_UNION, CaseWhenFunctor<SparseUnionType>::Exec);
+    AddCaseWhenKernel(func, Type::DICTIONARY, CaseWhenFunctor<DictionaryType>::Exec);
     DCHECK_OK(registry->AddFunction(std::move(func)));
   }
   {
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
index b3b0f26..8793cac 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
@@ -624,6 +624,187 @@ TYPED_TEST(TestCaseWhenNumeric, ListOfType) {
               ArrayFromJSON(type, R"([null, null, null, [6, null]])"));
 }
 
+template <typename Type>
+class TestCaseWhenDict : public ::testing::Test {};
+
+struct JsonDict {
+  std::shared_ptr<DataType> type;
+  std::string value;
+};
+
+TYPED_TEST_SUITE(TestCaseWhenDict, IntegralArrowTypes);
+
+TYPED_TEST(TestCaseWhenDict, Simple) {
+  auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+  auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+  for (const auto& dict :
+       {JsonDict{utf8(), R"(["a", null, "bc", "def"])"},
+        JsonDict{int64(), "[1, null, 2, 3]"},
+        JsonDict{decimal256(3, 2), R"(["1.23", null, "3.45", "6.78"])"}}) {
+    auto type = dictionary(default_type_instance<TypeParam>(), dict.type);
+    auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict.value);
+    auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", dict.value);
+    auto values2 = DictArrayFromJSON(type, "[2, 1, null, 0]", dict.value);
+
+    // Easy case: all arguments have the same dictionary
+    CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2});
+    CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1});
+    CheckDictionary("case_when",
+                    {MakeStruct({cond1, cond2}), values_null, values2, values1});
+  }
+}
+
+TYPED_TEST(TestCaseWhenDict, Mixed) {
+  auto type = dictionary(default_type_instance<TypeParam>(), utf8());
+  auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+  auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+  auto dict = R"(["a", null, "bc", "def"])";
+  auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict);
+  auto values1_dict = DictArrayFromJSON(type, "[0, null, 3, 1]", dict);
+  auto values1_decoded = ArrayFromJSON(utf8(), R"(["a", null, "def", null])");
+  auto values2_dict = DictArrayFromJSON(type, "[2, 1, null, 0]", dict);
+  auto values2_decoded = ArrayFromJSON(utf8(), R"(["bc", null, null, "a"])");
+
+  // If we have mixed dictionary/non-dictionary arguments, we decode dictionaries
+  CheckDictionary("case_when",
+                  {MakeStruct({cond1, cond2}), values1_dict, values2_decoded},
+                  /*result_is_encoded=*/false);
+  CheckDictionary("case_when",
+                  {MakeStruct({cond1, cond2}), values1_decoded, values2_dict},
+                  /*result_is_encoded=*/false);
+  CheckDictionary(
+      "case_when",
+      {MakeStruct({cond1, cond2}), values1_dict, values2_dict, values1_decoded},
+      /*result_is_encoded=*/false);
+  CheckDictionary(
+      "case_when",
+      {MakeStruct({cond1, cond2}), values_null, values2_dict, values1_decoded},
+      /*result_is_encoded=*/false);
+}
+
+TYPED_TEST(TestCaseWhenDict, NestedSimple) {
+  auto make_list = [](const std::shared_ptr<Array>& indices,
+                      const std::shared_ptr<Array>& backing_array) {
+    EXPECT_OK_AND_ASSIGN(auto result, ListArray::FromArrays(*indices, *backing_array));
+    return result;
+  };
+  auto index_type = default_type_instance<TypeParam>();
+  auto inner_type = dictionary(index_type, utf8());
+  auto type = list(inner_type);
+  auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+  auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+  auto dict = R"(["a", null, "bc", "def"])";
+  auto values_null = make_list(ArrayFromJSON(int32(), "[null, null, null, null, 0]"),
+                               DictArrayFromJSON(inner_type, "[]", dict));
+  auto values1_backing = DictArrayFromJSON(inner_type, "[0, null, 3, 1]", dict);
+  auto values2_backing = DictArrayFromJSON(inner_type, "[2, 1, null, 0]", dict);
+  auto values1 = make_list(ArrayFromJSON(int32(), "[0, 2, 2, 3, 4]"), values1_backing);
+  auto values2 = make_list(ArrayFromJSON(int32(), "[0, 1, 2, 2, 4]"), values2_backing);
+
+  CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+                  /*result_is_encoded=*/false);
+  CheckDictionary(
+      "case_when",
+      {MakeStruct({cond1, cond2}), values1,
+       make_list(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), values2_backing)},
+      /*result_is_encoded=*/false);
+  CheckDictionary(
+      "case_when",
+      {MakeStruct({cond1, cond2}), values1,
+       make_list(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), values2_backing), values1},
+      /*result_is_encoded=*/false);
+
+  CheckDictionary("case_when",
+                  {
+                      Datum(MakeStruct({cond1, cond2})),
+                      Datum(std::make_shared<ListScalar>(
+                          DictArrayFromJSON(inner_type, "[0, 1]", dict))),
+                      Datum(std::make_shared<ListScalar>(
+                          DictArrayFromJSON(inner_type, "[2, 3]", dict))),
+                  },
+                  /*result_is_encoded=*/false);
+
+  CheckDictionary("case_when",
+                  {MakeStruct({Datum(true), Datum(false)}), values1, values2},
+                  /*result_is_encoded=*/false);
+  CheckDictionary("case_when",
+                  {MakeStruct({Datum(false), Datum(true)}), values1, values2},
+                  /*result_is_encoded=*/false);
+  CheckDictionary("case_when", {MakeStruct({Datum(false)}), values1, values2},
+                  /*result_is_encoded=*/false);
+  CheckDictionary("case_when",
+                  {MakeStruct({Datum(false), Datum(false)}), values1, values2},
+                  /*result_is_encoded=*/false);
+}
+
+TYPED_TEST(TestCaseWhenDict, DifferentDictionaries) {
+  auto type = dictionary(default_type_instance<TypeParam>(), utf8());
+  auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+  auto cond2 = ArrayFromJSON(boolean(), "[true, false, null, true]");
+  auto dict1 = R"(["a", null, "bc", "def"])";
+  auto dict2 = R"(["bc", "foo", null, "a"])";
+  auto dict3 = R"(["def", null, "a", "bc"])";
+  auto values1_null = DictArrayFromJSON(type, "[null, null, null, null]", dict1);
+  auto values2_null = DictArrayFromJSON(type, "[null, null, null, null]", dict2);
+  auto values1 = DictArrayFromJSON(type, "[null, 0, 3, 1]", dict1);
+  auto values2 = DictArrayFromJSON(type, "[2, 1, 0, null]", dict2);
+  auto values3 = DictArrayFromJSON(type, "[0, 1, 2, 3]", dict3);
+
+  CheckDictionary("case_when",
+                  {MakeStruct({Datum(true), Datum(false)}), values1, values2});
+  CheckDictionary("case_when",
+                  {MakeStruct({Datum(false), Datum(true)}), values1, values2});
+  CheckDictionary("case_when",
+                  {MakeStruct({Datum(false), Datum(false)}), values1, values2});
+  CheckDictionary("case_when",
+                  {MakeStruct({Datum(false), Datum(false)}), values2, values1});
+
+  CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2});
+  CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1});
+
+  CheckDictionary("case_when",
+                  {MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]")}),
+                   values1, values2});
+  CheckDictionary("case_when",
+                  {MakeStruct({ArrayFromJSON(boolean(), "[true, false, false, true]")}),
+                   values1, values2});
+  CheckDictionary("case_when",
+                  {MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]"),
+                               ArrayFromJSON(boolean(), "[true, false, true, false]")}),
+                   values1, values2});
+  CheckDictionary("case_when",
+                  {MakeStruct({ArrayFromJSON(boolean(), "[false, false, false, false]"),
+                               ArrayFromJSON(boolean(), "[true, true, true, true]")}),
+                   values1, values3});
+  CheckDictionary("case_when",
+                  {MakeStruct({ArrayFromJSON(boolean(), "[null, null, null, true]"),
+                               ArrayFromJSON(boolean(), "[true, true, true, true]")}),
+                   values1, values3});
+  CheckDictionary(
+      "case_when",
+      {
+          MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]")}),
+          DictScalarFromJSON(type, "0", dict1),
+          DictScalarFromJSON(type, "0", dict2),
+      });
+  CheckDictionary(
+      "case_when",
+      {
+          MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]"),
+                      ArrayFromJSON(boolean(), "[false, false, true, true]")}),
+          DictScalarFromJSON(type, "0", dict1),
+          DictScalarFromJSON(type, "0", dict2),
+      });
+  CheckDictionary(
+      "case_when",
+      {
+          MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]"),
+                      ArrayFromJSON(boolean(), "[false, false, true, true]")}),
+          DictScalarFromJSON(type, "null", dict1),
+          DictScalarFromJSON(type, "0", dict2),
+      });
+}
+
 TEST(TestCaseWhen, Null) {
   auto cond_true = ScalarFromJSON(boolean(), "true");
   auto cond_false = ScalarFromJSON(boolean(), "false");
@@ -1489,6 +1670,18 @@ TEST(TestCaseWhen, DispatchBest) {
                 CallFunction("case_when", {MakeStruct({ArrayFromJSON(boolean(), "[]")}),
                                            ArrayFromJSON(int64(), "[]"),
                                            ArrayFromJSON(utf8(), "[]")}));
+
+  // Do not dictionary-decode when we have only dictionary values
+  CheckDispatchBest("case_when",
+                    {struct_({field("", boolean())}), dictionary(int64(), utf8()),
+                     dictionary(int64(), utf8())},
+                    {struct_({field("", boolean())}), dictionary(int64(), utf8()),
+                     dictionary(int64(), utf8())});
+
+  // Dictionary-decode if we have a mix
+  CheckDispatchBest(
+      "case_when", {struct_({field("", boolean())}), dictionary(int64(), utf8()), utf8()},
+      {struct_({field("", boolean())}), utf8(), utf8()});
 }
 
 template <typename Type>
diff --git a/cpp/src/arrow/compute/kernels/test_util.cc b/cpp/src/arrow/compute/kernels/test_util.cc
index 4a92151..cedc036 100644
--- a/cpp/src/arrow/compute/kernels/test_util.cc
+++ b/cpp/src/arrow/compute/kernels/test_util.cc
@@ -24,6 +24,7 @@
 #include "arrow/array.h"
 #include "arrow/array/validate.h"
 #include "arrow/chunked_array.h"
+#include "arrow/compute/cast.h"
 #include "arrow/compute/exec.h"
 #include "arrow/compute/function.h"
 #include "arrow/compute/registry.h"
@@ -46,13 +47,6 @@ DatumVector GetDatums(const std::vector<T>& inputs) {
   return datums;
 }
 
-void CheckScalarNonRecursive(const std::string& func_name, const DatumVector& inputs,
-                             const Datum& expected, const FunctionOptions* options) {
-  ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, inputs, options));
-  ValidateOutput(out);
-  AssertDatumsEqual(expected, out, /*verbose=*/true);
-}
-
 template <typename... SliceArgs>
 DatumVector SliceArrays(const DatumVector& inputs, SliceArgs... slice_args) {
   DatumVector sliced;
@@ -80,6 +74,13 @@ ScalarVector GetScalars(const DatumVector& inputs, int64_t index) {
 
 }  // namespace
 
+void CheckScalarNonRecursive(const std::string& func_name, const DatumVector& inputs,
+                             const Datum& expected, const FunctionOptions* options) {
+  ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, inputs, options));
+  ValidateOutput(out);
+  AssertDatumsEqual(expected, out, /*verbose=*/true);
+}
+
 void CheckScalar(std::string func_name, const ScalarVector& inputs,
                  std::shared_ptr<Scalar> expected, const FunctionOptions* options) {
   ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, GetDatums(inputs), options));
@@ -170,6 +171,83 @@ void CheckScalar(std::string func_name, const DatumVector& inputs, Datum expecte
   }
 }
 
+Datum CheckDictionaryNonRecursive(const std::string& func_name, const DatumVector& args,
+                                  bool result_is_encoded) {
+  EXPECT_OK_AND_ASSIGN(Datum actual, CallFunction(func_name, args));
+  ValidateOutput(actual);
+
+  DatumVector decoded_args;
+  decoded_args.reserve(args.size());
+  for (const auto& arg : args) {
+    if (arg.type()->id() == Type::DICTIONARY) {
+      const auto& to_type = checked_cast<const DictionaryType&>(*arg.type()).value_type();
+      EXPECT_OK_AND_ASSIGN(auto decoded, Cast(arg, to_type));
+      decoded_args.push_back(decoded);
+    } else {
+      decoded_args.push_back(arg);
+    }
+  }
+  EXPECT_OK_AND_ASSIGN(Datum expected, CallFunction(func_name, decoded_args));
+
+  if (result_is_encoded) {
+    EXPECT_EQ(Type::DICTIONARY, actual.type()->id())
+        << "Result should have been dictionary-encoded";
+    // Decode before comparison - we care about equivalent, not identical, results
+    const auto& to_type =
+        checked_cast<const DictionaryType&>(*actual.type()).value_type();
+    EXPECT_OK_AND_ASSIGN(auto decoded, Cast(actual, to_type));
+    AssertDatumsApproxEqual(expected, decoded, /*verbose=*/true);
+  } else {
+    AssertDatumsApproxEqual(expected, actual, /*verbose=*/true);
+  }
+  return actual;
+}
+
+void CheckDictionary(const std::string& func_name, const DatumVector& args,
+                     bool result_is_encoded) {
+  auto actual = CheckDictionaryNonRecursive(func_name, args, result_is_encoded);
+
+  if (actual.is_scalar()) return;
+  ASSERT_TRUE(actual.is_array());
+  ASSERT_GE(actual.length(), 0);
+
+  // Check all scalars
+  for (int64_t i = 0; i < actual.length(); i++) {
+    CheckDictionaryNonRecursive(func_name, GetDatums(GetScalars(args, i)),
+                                result_is_encoded);
+  }
+
+  // Check slices of the input
+  const auto slice_length = actual.length() / 3;
+  if (slice_length > 0) {
+    CheckDictionaryNonRecursive(func_name, SliceArrays(args, 0, slice_length),
+                                result_is_encoded);
+    CheckDictionaryNonRecursive(func_name, SliceArrays(args, slice_length, slice_length),
+                                result_is_encoded);
+    CheckDictionaryNonRecursive(func_name, SliceArrays(args, 2 * slice_length),
+                                result_is_encoded);
+  }
+
+  // Check empty slice
+  CheckDictionaryNonRecursive(func_name, SliceArrays(args, 0, 0), result_is_encoded);
+
+  // Check chunked arrays
+  if (slice_length > 0) {
+    DatumVector chunked_args;
+    chunked_args.reserve(args.size());
+    for (const auto& arg : args) {
+      if (arg.is_array()) {
+        auto arr = arg.make_array();
+        ArrayVector chunks{arr->Slice(0, slice_length), arr->Slice(slice_length)};
+        chunked_args.push_back(std::make_shared<ChunkedArray>(std::move(chunks)));
+      } else {
+        chunked_args.push_back(arg);
+      }
+    }
+    CheckDictionaryNonRecursive(func_name, chunked_args, result_is_encoded);
+  }
+}
+
 void CheckScalarUnary(std::string func_name, Datum input, Datum expected,
                       const FunctionOptions* options) {
   std::vector<Datum> input_vector = {std::move(input)};
diff --git a/cpp/src/arrow/compute/kernels/test_util.h b/cpp/src/arrow/compute/kernels/test_util.h
index 79745b0..25ea577 100644
--- a/cpp/src/arrow/compute/kernels/test_util.h
+++ b/cpp/src/arrow/compute/kernels/test_util.h
@@ -67,6 +67,8 @@ inline std::string CompareOperatorToFunctionName(CompareOperator op) {
   return function_names[op];
 }
 
+// Call the function with the given arguments, as well as slices of
+// the arguments and scalars extracted from the arguments.
 void CheckScalar(std::string func_name, const ScalarVector& inputs,
                  std::shared_ptr<Scalar> expected,
                  const FunctionOptions* options = nullptr);
@@ -74,6 +76,19 @@ void CheckScalar(std::string func_name, const ScalarVector& inputs,
 void CheckScalar(std::string func_name, const DatumVector& inputs, Datum expected,
                  const FunctionOptions* options = nullptr);
 
+// Like CheckScalar, but gets the expected result by
+// dictionary-decoding arguments and calling the function again.
+//
+// result_is_encoded controls whether the result is expected to be a
+// dictionary or not.
+void CheckDictionary(const std::string& func_name, const DatumVector& args,
+                     bool result_is_encoded = true);
+
+// Just call the function with the given arguments.
+void CheckScalarNonRecursive(const std::string& func_name, const DatumVector& inputs,
+                             const Datum& expected,
+                             const FunctionOptions* options = nullptr);
+
 void CheckScalarUnary(std::string func_name, std::shared_ptr<DataType> in_ty,
                       std::string json_input, std::shared_ptr<DataType> out_ty,
                       std::string json_expected,
diff --git a/cpp/src/arrow/ipc/json_simple.cc b/cpp/src/arrow/ipc/json_simple.cc
index 34b0f3f..8347b87 100644
--- a/cpp/src/arrow/ipc/json_simple.cc
+++ b/cpp/src/arrow/ipc/json_simple.cc
@@ -969,6 +969,25 @@ Status ScalarFromJSON(const std::shared_ptr<DataType>& type,
   return Status::OK();
 }
 
+Status DictScalarFromJSON(const std::shared_ptr<DataType>& type,
+                          util::string_view index_json, util::string_view dictionary_json,
+                          std::shared_ptr<Scalar>* out) {
+  if (type->id() != Type::DICTIONARY) {
+    return Status::TypeError("DictScalarFromJSON requires dictionary type, got ", *type);
+  }
+
+  const auto& dictionary_type = checked_cast<const DictionaryType&>(*type);
+
+  std::shared_ptr<Scalar> index;
+  std::shared_ptr<Array> dictionary;
+  RETURN_NOT_OK(ScalarFromJSON(dictionary_type.index_type(), index_json, &index));
+  RETURN_NOT_OK(
+      ArrayFromJSON(dictionary_type.value_type(), dictionary_json, &dictionary));
+
+  *out = DictionaryScalar::Make(std::move(index), std::move(dictionary));
+  return Status::OK();
+}
+
 }  // namespace json
 }  // namespace internal
 }  // namespace ipc
diff --git a/cpp/src/arrow/ipc/json_simple.h b/cpp/src/arrow/ipc/json_simple.h
index 4dd3a66..8269bd6 100644
--- a/cpp/src/arrow/ipc/json_simple.h
+++ b/cpp/src/arrow/ipc/json_simple.h
@@ -55,6 +55,11 @@ ARROW_EXPORT
 Status ScalarFromJSON(const std::shared_ptr<DataType>&, util::string_view json,
                       std::shared_ptr<Scalar>* out);
 
+ARROW_EXPORT
+Status DictScalarFromJSON(const std::shared_ptr<DataType>&, util::string_view index_json,
+                          util::string_view dictionary_json,
+                          std::shared_ptr<Scalar>* out);
+
 }  // namespace json
 }  // namespace internal
 }  // namespace ipc
diff --git a/cpp/src/arrow/ipc/json_simple_test.cc b/cpp/src/arrow/ipc/json_simple_test.cc
index ce2c37b..34c300f 100644
--- a/cpp/src/arrow/ipc/json_simple_test.cc
+++ b/cpp/src/arrow/ipc/json_simple_test.cc
@@ -1385,6 +1385,30 @@ TEST(TestScalarFromJSON, Errors) {
   ASSERT_RAISES(Invalid, ScalarFromJSON(boolean(), "\"true\"", &scalar));
 }
 
+TEST(TestDictScalarFromJSON, Basics) {
+  auto type = dictionary(int32(), utf8());
+  auto dict = R"(["whiskey", "tango", "foxtrot"])";
+  auto expected_dictionary = ArrayFromJSON(utf8(), dict);
+
+  for (auto index : {"null", "2", "1", "0"}) {
+    auto scalar = DictScalarFromJSON(type, index, dict);
+    auto expected_index = ScalarFromJSON(int32(), index);
+    AssertScalarsEqual(*DictionaryScalar::Make(expected_index, expected_dictionary),
+                       *scalar, /*verbose=*/true);
+    ASSERT_OK(scalar->ValidateFull());
+  }
+}
+
+TEST(TestDictScalarFromJSON, Errors) {
+  auto type = dictionary(int32(), utf8());
+  std::shared_ptr<Scalar> scalar;
+
+  ASSERT_RAISES(Invalid,
+                DictScalarFromJSON(type, "\"not a valid index\"", "[\"\"]", &scalar));
+  ASSERT_RAISES(Invalid, DictScalarFromJSON(type, "0", "[1]",
+                                            &scalar));  // dict value isn't string
+}
+
 }  // namespace json
 }  // namespace internal
 }  // namespace ipc
diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc
index 60ba54f..adfc501 100644
--- a/cpp/src/arrow/scalar.cc
+++ b/cpp/src/arrow/scalar.cc
@@ -599,8 +599,9 @@ Result<std::shared_ptr<Scalar>> DictionaryScalar::GetEncodedValue() const {
 std::shared_ptr<DictionaryScalar> DictionaryScalar::Make(std::shared_ptr<Scalar> index,
                                                          std::shared_ptr<Array> dict) {
   auto type = dictionary(index->type, dict->type());
+  auto is_valid = index->is_valid;
   return std::make_shared<DictionaryScalar>(ValueType{std::move(index), std::move(dict)},
-                                            std::move(type));
+                                            std::move(type), is_valid);
 }
 
 namespace {
diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc
index 587154c..24f5edc 100644
--- a/cpp/src/arrow/testing/gtest_util.cc
+++ b/cpp/src/arrow/testing/gtest_util.cc
@@ -446,6 +446,15 @@ std::shared_ptr<Scalar> ScalarFromJSON(const std::shared_ptr<DataType>& type,
   return out;
 }
 
+std::shared_ptr<Scalar> DictScalarFromJSON(const std::shared_ptr<DataType>& type,
+                                           util::string_view index_json,
+                                           util::string_view dictionary_json) {
+  std::shared_ptr<Scalar> out;
+  ABORT_NOT_OK(
+      ipc::internal::json::DictScalarFromJSON(type, index_json, dictionary_json, &out));
+  return out;
+}
+
 std::shared_ptr<Table> TableFromJSON(const std::shared_ptr<Schema>& schema,
                                      const std::vector<std::string>& json) {
   std::vector<std::shared_ptr<RecordBatch>> batches;
diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h
index f0021e0..65ab33c 100644
--- a/cpp/src/arrow/testing/gtest_util.h
+++ b/cpp/src/arrow/testing/gtest_util.h
@@ -339,6 +339,11 @@ std::shared_ptr<Scalar> ScalarFromJSON(const std::shared_ptr<DataType>&,
                                        util::string_view json);
 
 ARROW_TESTING_EXPORT
+std::shared_ptr<Scalar> DictScalarFromJSON(const std::shared_ptr<DataType>&,
+                                           util::string_view index_json,
+                                           util::string_view dictionary_json);
+
+ARROW_TESTING_EXPORT
 std::shared_ptr<Table> TableFromJSON(const std::shared_ptr<Schema>&,
                                      const std::vector<std::string>& json);