Skip to content

Commit

Permalink
Added Suggested Changes
Browse files: browse the repository at this point in the history
  • Loading branch information
Isaac7777-cpu-school committed Oct 19, 2024
1 parent 3abe49c commit 121a2b2
Showing 1 changed file with 39 additions and 29 deletions.
68 changes: 39 additions & 29 deletions python/pyarrow/tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,9 @@ def test_option_class_equality(request):
for cls in exported_option_classes:
# Timezone database might not be installed on Windows or Emscripten
if (
cls not in classes
and (request.config.pyarrow.is_enabled["timezone_data"])
and cls != pc.AssumeTimezoneOptions
cls not in classes
and (request.config.pyarrow.is_enabled["timezone_data"])
and cls != pc.AssumeTimezoneOptions
):
try:
options.append(cls())
Expand Down Expand Up @@ -396,7 +396,7 @@ def test_sum_decimal_array(arrow_type):

arr = pa.array(
[decimal.Decimal('-0.1'), decimal.Decimal('-0.0'),
decimal.Decimal('0.1'), decimal.Decimal('0.2'), None], type=arrow_type)
decimal.Decimal('0.1'), decimal.Decimal('0.2'), None], type=arrow_type)
assert arr.sum().as_py() == decimal.Decimal('0.20')
assert pc.sum(arr).as_py() == decimal.Decimal('0.20')

Expand Down Expand Up @@ -437,6 +437,16 @@ def test_sum_chunked_decimal_array(arrow_type):
assert pc.sum(arr, min_count=0).as_py() == decimal.Decimal('0.00')


def test_decimal_compute():
    """Check decimal type inference, summation and multiplication.

    Builds a decimal array from Python ``Decimal`` values without an
    explicit type, then asserts the inferred Arrow type, the value of the
    sum, and the type produced by multiplying with an integer array.
    """
    decimals = [decimal.Decimal(text)
                for text in ('1.232345', '1.25', '-100.25')]
    a = pa.array(decimals)
    expected_type = pa.decimal128(9, 6)

    # No type is passed to pa.array, so this checks the inferred type.
    assert a.type == expected_type
    assert a.sum().as_py() == decimal.Decimal('-97.767655')

    # NOTE(review): Arrow's documented decimal arithmetic rules widen the
    # result precision for multiplication; confirm that (9, 6) is really
    # the intended output type for decimal * integer.
    c = pc.multiply(a, pa.array([1, 2, 3]))
    assert c.type == expected_type


def test_mode_array():
# ARROW-9917
arr = pa.array([1, 1, 3, 4, 3, 5], type='int64')
Expand Down Expand Up @@ -1971,13 +1981,13 @@ def largest_scaled_float_not_above(val, scale):
"""
assert val >= 0
assert scale >= 0
float_val = float(val) / 10**scale
if float_val * 10**scale > val:
float_val = float(val) / 10 ** scale
if float_val * 10 ** scale > val:
# Take the float just below... it *should* satisfy
float_val = np.nextafter(float_val, 0.0)
if float_val * 10**scale > val:
if float_val * 10 ** scale > val:
float_val = np.nextafter(float_val, 0.0)
assert float_val * 10**scale <= val
assert float_val * 10 ** scale <= val
return float_val


Expand All @@ -2001,9 +2011,9 @@ def integral_float_to_decimal_cast_cases(float_ty, max_precision):
for scale in range(0, precision, 2):
yield FloatToDecimalCase(precision, scale, 0.0)
yield FloatToDecimalCase(precision, scale, 1.0)
epsilon = 10**max(precision - mantissa_digits, scale)
epsilon = 10 ** max(precision - mantissa_digits, scale)
abs_maxval = largest_scaled_float_not_above(
10**precision - epsilon, scale)
10 ** precision - epsilon, scale)
yield FloatToDecimalCase(precision, scale, abs_maxval)


Expand All @@ -2014,10 +2024,10 @@ def real_float_to_decimal_cast_cases(float_ty, max_precision):
mantissa_digits = 16
for precision in range(1, max_precision, 3):
for scale in range(0, precision, 2):
epsilon = 2 * 10**max(precision - mantissa_digits, 0)
epsilon = 2 * 10 ** max(precision - mantissa_digits, 0)
abs_minval = largest_scaled_float_not_above(epsilon, scale)
abs_maxval = largest_scaled_float_not_above(
10**precision - epsilon, scale)
10 ** precision - epsilon, scale)
yield FloatToDecimalCase(precision, scale, abs_minval)
yield FloatToDecimalCase(precision, scale, abs_maxval)

Expand All @@ -2030,9 +2040,9 @@ def random_float_to_decimal_cast_cases(float_ty, max_precision):
for precision in range(1, max_precision, 6):
for scale in range(0, precision, 4):
for i in range(20):
unscaled = r.randrange(0, 10**precision)
unscaled = r.randrange(0, 10 ** precision)
float_val = scaled_float(unscaled, scale)
assert float_val * 10**scale < 10**precision
assert float_val * 10 ** scale < 10 ** precision
yield FloatToDecimalCase(precision, scale, float_val)


Expand All @@ -2051,7 +2061,7 @@ def check_cast_float_to_decimal(float_ty, float_val, decimal_ty, decimal_ctx,
# Allow the last digit to vary. The tolerance is higher for
# very high precisions as rounding errors can accumulate in
# the iterative algorithm (GH-35576).
diff_digits = abs(actual - expected) * 10**decimal_ty.scale
diff_digits = abs(actual - expected) * 10 ** decimal_ty.scale
limit = 2 if decimal_ty.precision < max_precision - 1 else 4
assert diff_digits <= limit, (
f"float_val = {float_val!r}, precision={decimal_ty.precision}, "
Expand Down Expand Up @@ -2099,7 +2109,7 @@ def test_cast_float_to_decimal_random(float_ty, decimal_traits):
pa.float32(): (-126, 127),
pa.float64(): (-1022, 1023),
}[float_ty]
mantissa_digits = math.floor(math.log10(2**mantissa_bits))
mantissa_digits = math.floor(math.log10(2 ** mantissa_bits))
max_precision = decimal_traits.max_precision

with decimal.localcontext() as ctx:
Expand All @@ -2109,28 +2119,28 @@ def test_cast_float_to_decimal_random(float_ty, decimal_traits):
# 1) it's within bounds for the decimal type
# 2) the floating point exponent is within bounds
min_scale = max(-max_precision,
precision + math.ceil(math.log10(2**float_exp_min)))
precision + math.ceil(math.log10(2 ** float_exp_min)))
max_scale = min(max_precision,
math.floor(math.log10(2**float_exp_max)))
math.floor(math.log10(2 ** float_exp_max)))
for scale in range(min_scale, max_scale):
decimal_ty = decimal_traits.factory(precision, scale)
# We want to random-generate a float from its mantissa bits
# and exponent, and compute the expected value in the
# decimal domain. The float exponent has to ensure the
# expected value doesn't overflow and doesn't lose precision.
float_exp = (-mantissa_bits +
math.floor(math.log2(10**(precision - scale))))
math.floor(math.log2(10 ** (precision - scale))))
assert float_exp_min <= float_exp <= float_exp_max
for i in range(5):
mantissa = r.randrange(0, 2**mantissa_bits)
mantissa = r.randrange(0, 2 ** mantissa_bits)
float_val = np.ldexp(np_float_ty(mantissa), float_exp)
assert isinstance(float_val, np_float_ty)
# Make sure we compute the exact expected value and
# round by half-to-even when converting to the expected precision.
if float_exp >= 0:
expected = decimal.Decimal(mantissa) * 2**float_exp
expected = decimal.Decimal(mantissa) * 2 ** float_exp
else:
expected = decimal.Decimal(mantissa) / 2**-float_exp
expected = decimal.Decimal(mantissa) / 2 ** -float_exp
expected_as_int = round(expected.scaleb(scale))
actual = pc.cast(
pa.scalar(float_val, type=float_ty), decimal_ty).as_py()
Expand Down Expand Up @@ -2361,10 +2371,10 @@ def test_extract_datetime_components(request):
def test_iso_calendar_longer_array(unit):
# https://github.com/apache/arrow/issues/38655
# ensure correct result for array length > 32
arr = pa.array([datetime.datetime(2022, 1, 2, 9)]*50, pa.timestamp(unit))
arr = pa.array([datetime.datetime(2022, 1, 2, 9)] * 50, pa.timestamp(unit))
result = pc.iso_calendar(arr)
expected = pa.StructArray.from_arrays(
[[2021]*50, [52]*50, [7]*50],
[[2021] * 50, [52] * 50, [7] * 50],
names=['iso_year', 'iso_week', 'iso_day_of_week']
)
assert result.equals(expected)
Expand Down Expand Up @@ -2424,7 +2434,7 @@ def test_assume_timezone():

with pytest.raises(ValueError,
match="Timestamp doesn't exist in "
f"timezone '{timezone}'"):
f"timezone '{timezone}'"):
pc.assume_timezone(nonexistent_array,
options=options_nonexistent_raise)

Expand Down Expand Up @@ -3465,8 +3475,8 @@ def create_sample_expressions():
f, # Struct literals lose their field names
a.isin([1, 2, 3]), # isin converts to an or list
pc.field('i64').is_null() # pyarrow always specifies a FunctionOptions
# for is_null which, being the default, is
# dropped on serialization
# for is_null which, being the default, is
# dropped on serialization
]

all_exprs = literal_exprs.copy()
Expand All @@ -3484,6 +3494,7 @@ def create_sample_expressions():
"schema": schema
}


# Tests the Arrow-specific serialization mechanism


Expand All @@ -3498,7 +3509,6 @@ def test_expression_serialization_arrow(pickle_module):
@pytest.mark.numpy
@pytest.mark.substrait
def test_expression_serialization_substrait():

exprs = create_sample_expressions()
schema = exprs["schema"]

Expand Down Expand Up @@ -3661,7 +3671,7 @@ def test_list_slice_output_fixed(start, stop, step, expected, value_type,
else:
result = pc.list_slice(*args)
pylist = result.cast(pa.list_(pa.int8(),
result.type.list_size)).to_pylist()
result.type.list_size)).to_pylist()
assert pylist == [e[::step] if e else e for e in expected]


Expand Down

0 comments on commit 121a2b2

Please sign in to comment.