You are viewing a plain-text version of this content. The link to the canonical version (originally labelled "here") was lost in the plain-text conversion.
Posted to commits@arrow.apache.org by ap...@apache.org on 2020/11/30 13:12:24 UTC
[arrow] branch master updated: ARROW-10720: [C++] Add Rescale support for BasicDecimal256
This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 64f9b3f ARROW-10720: [C++] Add Rescale support for BasicDecimal256
64f9b3f is described below
commit 64f9b3fbe9ef4c718449a735435b53ab992ca852
Author: Bei <be...@gmail.com>
AuthorDate: Mon Nov 30 14:10:52 2020 +0100
ARROW-10720: [C++] Add Rescale support for BasicDecimal256
Closes #8763 from Bei-z/rescale
Authored-by: Bei <be...@gmail.com>
Signed-off-by: Antoine Pitrou <an...@python.org>
---
cpp/src/arrow/python/python_test.cc | 33 +++++-
cpp/src/arrow/util/basic_decimal.cc | 147 +++++++++++++++++++++++----
cpp/src/arrow/util/basic_decimal.h | 3 +
cpp/src/arrow/util/decimal_test.cc | 69 ++++++++++++-
python/pyarrow/tests/test_convert_builtin.py | 7 +-
python/pyarrow/tests/test_scalars.py | 5 +-
6 files changed, 234 insertions(+), 30 deletions(-)
diff --git a/cpp/src/arrow/python/python_test.cc b/cpp/src/arrow/python/python_test.cc
index 11d3593..33e0ee9 100644
--- a/cpp/src/arrow/python/python_test.cc
+++ b/cpp/src/arrow/python/python_test.cc
@@ -360,7 +360,8 @@ TEST_F(DecimalTest, FromPythonDecimalRescaleNotTruncateable) {
// lower scale
DecimalTestFromPythonDecimalRescale<Decimal128>(::arrow::decimal128(10, 2),
this->CreatePythonDecimal("1.001"), {});
- // TODO: Test Decimal256 after implementing scaling.
+ DecimalTestFromPythonDecimalRescale<Decimal256>(::arrow::decimal256(10, 2),
+ this->CreatePythonDecimal("1.001"), {});  // 1.001 cannot drop to scale 2 losslessly
}
TEST_F(DecimalTest, FromPythonDecimalRescaleTruncateable) {
@@ -368,13 +369,15 @@ TEST_F(DecimalTest, FromPythonDecimalRescaleTruncateable) {
// difference between the scales, e.g., 1.000 -> 1.00
DecimalTestFromPythonDecimalRescale<Decimal128>(
::arrow::decimal128(10, 2), this->CreatePythonDecimal("1.000"), 100);
- // TODO: Test Decimal256 after implementing scaling.
+ DecimalTestFromPythonDecimalRescale<Decimal256>(
+ ::arrow::decimal256(10, 2), this->CreatePythonDecimal("1.000"), 100);  // 1.000 -> 1.00
}
TEST_F(DecimalTest, FromPythonNegativeDecimalRescale) {
DecimalTestFromPythonDecimalRescale<Decimal128>(
::arrow::decimal128(10, 9), this->CreatePythonDecimal("-1.000"), -1000000000);
- // TODO: Test Decimal256 after implementing scaling.
+ DecimalTestFromPythonDecimalRescale<Decimal256>(  // -1.000 upscaled to scale 9
+ ::arrow::decimal256(10, 9), this->CreatePythonDecimal("-1.000"), -1000000000);
}
TEST_F(DecimalTest, Decimal128FromPythonInteger) {
@@ -386,7 +389,14 @@ TEST_F(DecimalTest, Decimal128FromPythonInteger) {
ASSERT_EQ(4200, value);
}
-// TODO: Test Decimal256 from python after implementing scaling.
+TEST_F(DecimalTest, Decimal256FromPythonInteger) {  // 42 at scale 2 -> unscaled 4200
+ Decimal256 value;
+ OwnedRef python_long(PyLong_FromLong(42));
+ auto type = ::arrow::decimal256(10, 2);
+ const auto& decimal_type = checked_cast<const DecimalType&>(*type);
+ ASSERT_OK(internal::DecimalFromPyObject(python_long.obj(), decimal_type, &value));
+ ASSERT_EQ(4200, value);
+}
TEST_F(DecimalTest, TestDecimal128OverflowFails) {
Decimal128 value;
@@ -403,7 +413,20 @@ TEST_F(DecimalTest, TestDecimal128OverflowFails) {
decimal_type, &value));
}
-// TODO: Test Decimal256 overflow after implementing scaling.
+TEST_F(DecimalTest, TestDecimal256OverflowFails) {  // 76 digits at scale 1 must not fit decimal(76, 76)
+ Decimal256 value;
+ OwnedRef python_decimal(this->CreatePythonDecimal(
+ "999999999999999999999999999999999999999999999999999999999999999999999999999.9"));
+ internal::DecimalMetadata metadata;
+ ASSERT_OK(metadata.Update(python_decimal.obj()));
+ ASSERT_EQ(76, metadata.precision());
+ ASSERT_EQ(1, metadata.scale());
+
+ auto type = ::arrow::decimal(76, 76);
+ const auto& decimal_type = checked_cast<const DecimalType&>(*type);
+ ASSERT_RAISES(Invalid, internal::DecimalFromPythonDecimal(python_decimal.obj(),
+ decimal_type, &value));
+}
TEST_F(DecimalTest, TestNoneAndNaN) {
OwnedRef list_ref(PyList_New(4));
diff --git a/cpp/src/arrow/util/basic_decimal.cc b/cpp/src/arrow/util/basic_decimal.cc
index 652d70c..1abcc2c 100644
--- a/cpp/src/arrow/util/basic_decimal.cc
+++ b/cpp/src/arrow/util/basic_decimal.cc
@@ -120,6 +120,112 @@ static const BasicDecimal128 ScaleMultipliersHalf[] = {
BasicDecimal128(271050543121376108LL, 9257742014424809472ULL),
BasicDecimal128(2710505431213761085LL, 343699775700336640ULL)};
+static const BasicDecimal256 ScaleMultipliersDecimal256[] = {  // 10^0 .. 10^76, least-significant 64-bit limb first
+ BasicDecimal256({1ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256({10ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256({100ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256({1000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256({10000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256({100000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256({1000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256({10000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256({100000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256({1000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256({10000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256({100000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256({1000000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256({10000000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256({100000000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256({1000000000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256({10000000000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256({100000000000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256({1000000000000000000ULL, 0ULL, 0ULL, 0ULL}),
+ BasicDecimal256({10000000000000000000ULL, 0ULL, 0ULL, 0ULL}),  // 10^19
+ BasicDecimal256({7766279631452241920ULL, 5ULL, 0ULL, 0ULL}),  // 10^20 = 5 * 2^64 + 7766279631452241920
+ BasicDecimal256({3875820019684212736ULL, 54ULL, 0ULL, 0ULL}),
+ BasicDecimal256({1864712049423024128ULL, 542ULL, 0ULL, 0ULL}),
+ BasicDecimal256({200376420520689664ULL, 5421ULL, 0ULL, 0ULL}),
+ BasicDecimal256({2003764205206896640ULL, 54210ULL, 0ULL, 0ULL}),
+ BasicDecimal256({1590897978359414784ULL, 542101ULL, 0ULL, 0ULL}),
+ BasicDecimal256({15908979783594147840ULL, 5421010ULL, 0ULL, 0ULL}),
+ BasicDecimal256({11515845246265065472ULL, 54210108ULL, 0ULL, 0ULL}),
+ BasicDecimal256({4477988020393345024ULL, 542101086ULL, 0ULL, 0ULL}),
+ BasicDecimal256({7886392056514347008ULL, 5421010862ULL, 0ULL, 0ULL}),
+ BasicDecimal256({5076944270305263616ULL, 54210108624ULL, 0ULL, 0ULL}),
+ BasicDecimal256({13875954555633532928ULL, 542101086242ULL, 0ULL, 0ULL}),
+ BasicDecimal256({9632337040368467968ULL, 5421010862427ULL, 0ULL, 0ULL}),
+ BasicDecimal256({4089650035136921600ULL, 54210108624275ULL, 0ULL, 0ULL}),
+ BasicDecimal256({4003012203950112768ULL, 542101086242752ULL, 0ULL, 0ULL}),
+ BasicDecimal256({3136633892082024448ULL, 5421010862427522ULL, 0ULL, 0ULL}),
+ BasicDecimal256({12919594847110692864ULL, 54210108624275221ULL, 0ULL, 0ULL}),
+ BasicDecimal256({68739955140067328ULL, 542101086242752217ULL, 0ULL, 0ULL}),
+ BasicDecimal256({687399551400673280ULL, 5421010862427522170ULL, 0ULL, 0ULL}),  // 10^38
+ BasicDecimal256({6873995514006732800ULL, 17316620476856118468ULL, 2ULL, 0ULL}),
+ BasicDecimal256({13399722918938673152ULL, 7145508105175220139ULL, 29ULL, 0ULL}),
+ BasicDecimal256({4870020673419870208ULL, 16114848830623546549ULL, 293ULL, 0ULL}),
+ BasicDecimal256({11806718586779598848ULL, 13574535716559052564ULL, 2938ULL, 0ULL}),
+ BasicDecimal256({7386721425538678784ULL, 6618148649623664334ULL, 29387ULL, 0ULL}),
+ BasicDecimal256({80237960548581376ULL, 10841254275107988496ULL, 293873ULL, 0ULL}),
+ BasicDecimal256({802379605485813760ULL, 16178822382532126880ULL, 2938735ULL, 0ULL}),
+ BasicDecimal256({8023796054858137600ULL, 14214271235644855872ULL, 29387358ULL, 0ULL}),
+ BasicDecimal256(
+ {6450984253743169536ULL, 13015503840481697412ULL, 293873587ULL, 0ULL}),
+ BasicDecimal256(
+ {9169610316303040512ULL, 1027829888850112811ULL, 2938735877ULL, 0ULL}),
+ BasicDecimal256(
+ {17909126868192198656ULL, 10278298888501128114ULL, 29387358770ULL, 0ULL}),
+ BasicDecimal256(
+ {13070572018536022016ULL, 10549268516463523069ULL, 293873587705ULL, 0ULL}),
+ BasicDecimal256(
+ {1578511669393358848ULL, 13258964796087472617ULL, 2938735877055ULL, 0ULL}),
+ BasicDecimal256(
+ {15785116693933588480ULL, 3462439444907864858ULL, 29387358770557ULL, 0ULL}),
+ BasicDecimal256(
+ {10277214349659471872ULL, 16177650375369096972ULL, 293873587705571ULL, 0ULL}),
+ BasicDecimal256(
+ {10538423128046960640ULL, 14202551164014556797ULL, 2938735877055718ULL, 0ULL}),
+ BasicDecimal256(
+ {13150510911921848320ULL, 12898303124178706663ULL, 29387358770557187ULL, 0ULL}),
+ BasicDecimal256(
+ {2377900603251621888ULL, 18302566799529756941ULL, 293873587705571876ULL, 0ULL}),
+ BasicDecimal256(
+ {5332261958806667264ULL, 17004971331911604867ULL, 2938735877055718769ULL, 0ULL}),
+ BasicDecimal256(  // 10^58: first entry using the top limb
+ {16429131440647569408ULL, 4029016655730084128ULL, 10940614696847636083ULL, 1ULL}),
+ BasicDecimal256({16717361816799281152ULL, 3396678409881738056ULL,
+ 17172426599928602752ULL, 15ULL}),
+ BasicDecimal256({1152921504606846976ULL, 15520040025107828953ULL,
+ 5703569335900062977ULL, 159ULL}),
+ BasicDecimal256({11529215046068469760ULL, 7626447661401876602ULL,
+ 1695461137871974930ULL, 1593ULL}),
+ BasicDecimal256({4611686018427387904ULL, 2477500319180559562ULL,
+ 16954611378719749304ULL, 15930ULL}),
+ BasicDecimal256({9223372036854775808ULL, 6328259118096044006ULL,
+ 3525417123811528497ULL, 159309ULL}),
+ BasicDecimal256({0ULL, 7942358959831785217ULL, 16807427164405733357ULL, 1593091ULL}),
+ BasicDecimal256({0ULL, 5636613303479645706ULL, 2053574980671369030ULL, 15930919ULL}),
+ BasicDecimal256({0ULL, 1025900813667802212ULL, 2089005733004138687ULL, 159309191ULL}),
+ BasicDecimal256(
+ {0ULL, 10259008136678022120ULL, 2443313256331835254ULL, 1593091911ULL}),
+ BasicDecimal256(
+ {0ULL, 10356360998232463120ULL, 5986388489608800929ULL, 15930919111ULL}),
+ BasicDecimal256(
+ {0ULL, 11329889613776873120ULL, 4523652674959354447ULL, 159309191113ULL}),
+ BasicDecimal256(
+ {0ULL, 2618431695511421504ULL, 8343038602174441244ULL, 1593091911132ULL}),
+ BasicDecimal256(
+ {0ULL, 7737572881404663424ULL, 9643409726906205977ULL, 15930919111324ULL}),
+ BasicDecimal256(
+ {0ULL, 3588752519208427776ULL, 4200376900514301694ULL, 159309191113245ULL}),
+ BasicDecimal256(
+ {0ULL, 17440781118374726144ULL, 5110280857723913709ULL, 1593091911132452ULL}),
+ BasicDecimal256(
+ {0ULL, 8387114520361296896ULL, 14209320429820033867ULL, 15930919111324522ULL}),
+ BasicDecimal256(
+ {0ULL, 10084168908774762496ULL, 12965995782233477362ULL, 159309191113245227ULL}),
+ BasicDecimal256(  // 10^76, the last entry
+ {0ULL, 8607968719199866880ULL, 532749306367912313ULL, 1593091911132452277ULL})};
+
#ifdef ARROW_USE_NATIVE_INT128
static constexpr uint64_t kInt64Mask = 0xFFFFFFFFFFFFFFFF;
#else
@@ -794,13 +900,13 @@ BasicDecimal128 operator%(const BasicDecimal128& left, const BasicDecimal128& ri
return remainder;
}
-static bool RescaleWouldCauseDataLoss(const BasicDecimal128& value, int32_t delta_scale,
- int32_t abs_delta_scale, BasicDecimal128* result) {
- BasicDecimal128 multiplier(ScaleMultipliers[abs_delta_scale]);
-
+template <class DecimalClass>  // BasicDecimal128 or BasicDecimal256
+static bool RescaleWouldCauseDataLoss(const DecimalClass& value, int32_t delta_scale,
+ const DecimalClass& multiplier,  // expected to be 10^|delta_scale|
+ DecimalClass* result) {  // receives the rescaled value
if (delta_scale < 0) {
DCHECK_NE(multiplier, 0);
- BasicDecimal128 remainder;
+ DecimalClass remainder;
auto status = value.Divide(multiplier, result, &remainder);
DCHECK_EQ(status, DecimalStatus::kSuccess);
return remainder != 0;
@@ -810,24 +916,23 @@ static bool RescaleWouldCauseDataLoss(const BasicDecimal128& value, int32_t delt
return (value < 0) ? *result > value : *result < value;
}
-DecimalStatus BasicDecimal128::Rescale(int32_t original_scale, int32_t new_scale,
- BasicDecimal128* out) const {
+template <class DecimalClass>  // shared by BasicDecimal128::Rescale and BasicDecimal256::Rescale
+DecimalStatus DecimalRescale(const DecimalClass& value, int32_t original_scale,
+ int32_t new_scale, DecimalClass* out) {
DCHECK_NE(out, nullptr);
if (original_scale == new_scale) {
- *out = *this;
+ *out = value;
return DecimalStatus::kSuccess;
}
const int32_t delta_scale = new_scale - original_scale;
const int32_t abs_delta_scale = std::abs(delta_scale);
- DCHECK_GE(abs_delta_scale, 1);
- DCHECK_LE(abs_delta_scale, 38);
+ DecimalClass multiplier = DecimalClass::GetScaleMultiplier(abs_delta_scale);
- BasicDecimal128 result(*this);
const bool rescale_would_cause_data_loss =
- RescaleWouldCauseDataLoss(result, delta_scale, abs_delta_scale, out);
+ RescaleWouldCauseDataLoss(value, delta_scale, multiplier, out);
// Fail if we overflow or truncate
if (ARROW_PREDICT_FALSE(rescale_would_cause_data_loss)) {
@@ -837,6 +942,11 @@ DecimalStatus BasicDecimal128::Rescale(int32_t original_scale, int32_t new_scale
return DecimalStatus::kSuccess;
}
+DecimalStatus BasicDecimal128::Rescale(int32_t original_scale, int32_t new_scale,
+ BasicDecimal128* out) const {
+ return DecimalRescale(*this, original_scale, new_scale, out);  // delegate to shared template
+}
+
void BasicDecimal128::GetWholeAndFraction(int scale, BasicDecimal128* whole,
BasicDecimal128* fraction) const {
DCHECK_GE(scale, 0);
@@ -978,11 +1088,14 @@ DecimalStatus BasicDecimal256::Divide(const BasicDecimal256& divisor,
DecimalStatus BasicDecimal256::Rescale(int32_t original_scale, int32_t new_scale,
BasicDecimal256* out) const {
- if (original_scale == new_scale) {
- return DecimalStatus::kSuccess;
- }
- // TODO: implement.
- return DecimalStatus::kRescaleDataLoss;
+ return DecimalRescale(*this, original_scale, new_scale, out);  // same algorithm as BasicDecimal128
+}
+
+const BasicDecimal256& BasicDecimal256::GetScaleMultiplier(int32_t scale) {
+ DCHECK_GE(scale, 0);
+ DCHECK_LE(scale, 76);  // ScaleMultipliersDecimal256 holds 10^0 .. 10^76
+
+ return ScaleMultipliersDecimal256[scale];  // i.e. 10^scale
}
BasicDecimal256 operator*(const BasicDecimal256& left, const BasicDecimal256& right) {
diff --git a/cpp/src/arrow/util/basic_decimal.h b/cpp/src/arrow/util/basic_decimal.h
index 54f5ca3..af58ee1 100644
--- a/cpp/src/arrow/util/basic_decimal.h
+++ b/cpp/src/arrow/util/basic_decimal.h
@@ -238,6 +238,9 @@ class ARROW_EXPORT BasicDecimal256 {
std::array<uint8_t, 32> ToBytes() const;
void ToBytes(uint8_t* out) const;
+ /// \brief Scale multiplier for given scale value: 10^scale, valid for scale in [0, 76].
+ static const BasicDecimal256& GetScaleMultiplier(int32_t scale);
+
/// \brief Convert BasicDecimal128 from one scale to another
DecimalStatus Rescale(int32_t original_scale, int32_t new_scale,
BasicDecimal256* out) const;
diff --git a/cpp/src/arrow/util/decimal_test.cc b/cpp/src/arrow/util/decimal_test.cc
index 197bf1a..06731be 100644
--- a/cpp/src/arrow/util/decimal_test.cc
+++ b/cpp/src/arrow/util/decimal_test.cc
@@ -897,7 +897,8 @@ TEST(Decimal128Test, TestToInteger) {
template <typename ArrowType, typename CType = typename ArrowType::c_type>
std::vector<CType> GetRandomNumbers(int32_t size) {
auto rand = random::RandomArrayGenerator(0x5487655);
- auto x_array = rand.Numeric<ArrowType>(size, 0, std::numeric_limits<CType>::max(), 0);
+ auto x_array = rand.Numeric<ArrowType>(size, static_cast<CType>(0),  // min has same type as max
+ std::numeric_limits<CType>::max(), 0);
auto x_ptr = x_array->data()->template GetValues<CType>(1);
std::vector<CType> ret;
@@ -985,6 +986,39 @@ TEST(Decimal128Test, Divide) {
}
}
+TEST(Decimal128Test, Rescale) {
+ ASSERT_OK_AND_EQ(Decimal128(11100), Decimal128(111).Rescale(0, 2));
+ ASSERT_OK_AND_EQ(Decimal128(111), Decimal128(11100).Rescale(2, 0));
+ ASSERT_OK_AND_EQ(Decimal128(5), Decimal128(500000).Rescale(6, 1));
+ ASSERT_OK_AND_EQ(Decimal128(500000), Decimal128(5).Rescale(1, 6));
+ ASSERT_RAISES(Invalid, Decimal128(555555).Rescale(6, 1));  // would drop nonzero digits
+
+ // Round-trip random int32 values through scale deltas 0..28 (at most 38 digits).
+ for (auto original_scale : GetRandomNumbers<Int16Type>(16)) {
+ for (auto value : GetRandomNumbers<Int32Type>(16)) {
+ Decimal128 unscaled_value = Decimal128(value);
+ Decimal128 scaled_value = unscaled_value;
+ for (int32_t new_scale = original_scale; new_scale < original_scale + 29;
+ new_scale++, scaled_value *= Decimal128(10)) {
+ ASSERT_OK_AND_EQ(scaled_value, unscaled_value.Rescale(original_scale, new_scale));
+ ASSERT_OK_AND_EQ(unscaled_value, scaled_value.Rescale(new_scale, original_scale));
+ }
+ }
+ }
+
+ for (auto original_scale : GetRandomNumbers<Int16Type>(16)) {  // +/-10^k for k = 0..38
+ Decimal128 value(1);
+ for (int32_t new_scale = original_scale; new_scale < original_scale + 39;
+ new_scale++, value *= Decimal128(10)) {
+ Decimal128 negative_value = value * -1;
+ ASSERT_OK_AND_EQ(value, Decimal128(1).Rescale(original_scale, new_scale));
+ ASSERT_OK_AND_EQ(negative_value, Decimal128(-1).Rescale(original_scale, new_scale));
+ ASSERT_OK_AND_EQ(Decimal128(1), value.Rescale(new_scale, original_scale));
+ ASSERT_OK_AND_EQ(Decimal128(-1), negative_value.Rescale(new_scale, original_scale));
+ }
+ }
+}
+
TEST(Decimal128Test, Mod) {
ASSERT_EQ(Decimal128(234), Decimal128(20100) % Decimal128(301));
@@ -1352,6 +1386,39 @@ TEST(Decimal256Test, Divide) {
}
}
+TEST(Decimal256Test, Rescale) {
+ ASSERT_OK_AND_EQ(Decimal256(11100), Decimal256(111).Rescale(0, 2));
+ ASSERT_OK_AND_EQ(Decimal256(111), Decimal256(11100).Rescale(2, 0));
+ ASSERT_OK_AND_EQ(Decimal256(5), Decimal256(500000).Rescale(6, 1));
+ ASSERT_OK_AND_EQ(Decimal256(500000), Decimal256(5).Rescale(1, 6));
+ ASSERT_RAISES(Invalid, Decimal256(555555).Rescale(6, 1));  // would drop nonzero digits
+
+ // Round-trip random int32 values through scale deltas 0..67 (at most 77 digits).
+ for (auto original_scale : GetRandomNumbers<Int16Type>(16)) {
+ for (auto value : GetRandomNumbers<Int32Type>(16)) {
+ Decimal256 unscaled_value = Decimal256(value);
+ Decimal256 scaled_value = unscaled_value;
+ for (int32_t new_scale = original_scale; new_scale < original_scale + 68;
+ new_scale++, scaled_value *= Decimal256(10)) {
+ ASSERT_OK_AND_EQ(scaled_value, unscaled_value.Rescale(original_scale, new_scale));
+ ASSERT_OK_AND_EQ(unscaled_value, scaled_value.Rescale(new_scale, original_scale));
+ }
+ }
+ }
+
+ for (auto original_scale : GetRandomNumbers<Int16Type>(16)) {  // +/-10^k for k = 0..76
+ Decimal256 value(1);
+ for (int32_t new_scale = original_scale; new_scale < original_scale + 77;
+ new_scale++, value *= Decimal256(10)) {
+ Decimal256 negative_value = value * -1;
+ ASSERT_OK_AND_EQ(value, Decimal256(1).Rescale(original_scale, new_scale));
+ ASSERT_OK_AND_EQ(negative_value, Decimal256(-1).Rescale(original_scale, new_scale));
+ ASSERT_OK_AND_EQ(Decimal256(1), value.Rescale(new_scale, original_scale));
+ ASSERT_OK_AND_EQ(Decimal256(-1), negative_value.Rescale(new_scale, original_scale));
+ }
+ }
+}
+
class Decimal256ToStringTest : public ::testing::TestWithParam<ToStringTestParam> {};
TEST_P(Decimal256ToStringTest, ToString) {
diff --git a/python/pyarrow/tests/test_convert_builtin.py b/python/pyarrow/tests/test_convert_builtin.py
index 0690fe9..d35d44a 100644
--- a/python/pyarrow/tests/test_convert_builtin.py
+++ b/python/pyarrow/tests/test_convert_builtin.py
@@ -1481,10 +1481,9 @@ def test_sequence_decimal_large_integer():
def test_sequence_decimal_from_integers():
data = [0, 1, -39402950693754869342983]
expected = [decimal.Decimal(x) for x in data]
- # TODO: update this test after scaling implementation.
- type = pa.decimal128(precision=28, scale=5)
- arr = pa.array(data, type=type)
- assert arr.to_pylist() == expected
+ for type in [pa.decimal128, pa.decimal256]:  # same integers must round-trip for both widths
+ arr = pa.array(data, type=type(precision=28, scale=5))
+ assert arr.to_pylist() == expected
def test_sequence_decimal_too_high_precision():
diff --git a/python/pyarrow/tests/test_scalars.py b/python/pyarrow/tests/test_scalars.py
index f516afd..6e17526 100644
--- a/python/pyarrow/tests/test_scalars.py
+++ b/python/pyarrow/tests/test_scalars.py
@@ -206,9 +206,8 @@ def test_decimal256():
v = decimal.Decimal("1.1234")
with pytest.raises(pa.ArrowInvalid):
pa.scalar(v, type=pa.decimal256(4, scale=3))
- # TODO: Add the following after implementing Decimal256 scaling.
- # with pytest.raises(pa.ArrowInvalid):
- # pa.scalar(v, type=pa.decimal256(5, scale=3))
+ with pytest.raises(pa.ArrowInvalid):  # 1.1234 cannot drop to scale 3 without losing a digit
+ pa.scalar(v, type=pa.decimal256(5, scale=3))
s = pa.scalar(v, type=pa.decimal256(5, scale=4))
assert isinstance(s, pa.Decimal256Scalar)