Skip to content

Commit

Permalink
Merge pull request #339 from cppalliance/fma
Browse files Browse the repository at this point in the history
  • Loading branch information
mborland authored Nov 20, 2023
2 parents 5037e2c + 4c200e4 commit 270204a
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 14 deletions.
6 changes: 6 additions & 0 deletions include/boost/decimal/cmath.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
// NaN constant for the decimal types; expands to the signaling NaN of decimal64.
// NOTE(review): the C23 DEC_NAN macro is specified as a quiet NaN - confirm that a
// signaling NaN is really what is wanted here.
#define BOOST_DECIMAL_DEC_NAN std::numeric_limits<boost::decimal::decimal64>::signaling_NaN()
// Feature flags, each defined to 1, one per fmadNN function (32-, 64- and 128-bit).
// Presumably modeled on the FP_FAST_FMA* macros of <cmath> - TODO confirm.
#define BOOST_DECIMAL_FP_FAST_FMAD32 1
#define BOOST_DECIMAL_FP_FAST_FMAD64 1
#define BOOST_DECIMAL_FP_FAST_FMAD128 1

namespace boost { namespace decimal {

Expand Down Expand Up @@ -119,6 +120,11 @@ constexpr auto fma(decimal64 x, decimal64 y, decimal64 z) noexcept -> decimal64
return fmad64(x, y, z);
}

// decimal128 overload of fma: forwards its three arguments unchanged to fmad128
// and returns that function's result.
constexpr auto fma(decimal128 x, decimal128 y, decimal128 z) noexcept -> decimal128
{
return fmad128(x, y, z);
}

constexpr auto samequantum(decimal32 lhs, decimal32 rhs) noexcept -> bool
{
return samequantumd32(lhs, rhs);
Expand Down
102 changes: 98 additions & 4 deletions include/boost/decimal/decimal128.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ class decimal128 final
friend constexpr auto copysignd128(decimal128 mag, decimal128 sgn) noexcept -> decimal128;
friend constexpr auto scalblnd128(decimal128 num, long exp) noexcept -> decimal128;
friend constexpr auto scalbnd128(decimal128 num, int exp) noexcept -> decimal128;
friend constexpr auto fmad128(decimal128 x, decimal128 y, decimal128 z) noexcept -> decimal128;
};

#if !defined(BOOST_DECIMAL_DISABLE_IOSTREAM)
Expand Down Expand Up @@ -1802,11 +1803,21 @@ constexpr auto operator*(decimal128 lhs, decimal128 rhs) noexcept -> decimal128

auto lhs_sig {lhs.full_significand()};
auto lhs_exp {lhs.biased_exponent()};
detail::normalize<decimal128>(lhs_sig, lhs_exp);

while (lhs_sig % 10 == 0 && lhs_sig != 0)
{
lhs_sig /= 10;
++lhs_exp;
}

auto rhs_sig {rhs.full_significand()};
auto rhs_exp {rhs.biased_exponent()};
detail::normalize<decimal128>(rhs_sig, rhs_exp);

while (rhs_sig % 10 == 0 && rhs_sig != 0)
{
rhs_sig /= 10;
++rhs_exp;
}

const auto result {d128_mul_impl(lhs_sig, lhs_exp, lhs.isneg(),
rhs_sig, rhs_exp, rhs.isneg())};
Expand All @@ -1825,12 +1836,20 @@ constexpr auto operator*(decimal128 lhs, Integer rhs) noexcept

auto lhs_sig {lhs.full_significand()};
auto lhs_exp {lhs.biased_exponent()};
detail::normalize<decimal128>(lhs_sig, lhs_exp);
while (lhs_sig % 10 == 0 && lhs_sig != 0)
{
lhs_sig /= 10;
++lhs_exp;
}
auto lhs_components {detail::decimal128_components{lhs_sig, lhs_exp, lhs.isneg()}};

auto rhs_sig {static_cast<detail::uint128>(detail::make_positive_unsigned(rhs))};
std::int32_t rhs_exp {0};
detail::normalize<decimal128>(rhs_sig, rhs_exp);
while (rhs_sig % 10 == 0 && rhs_sig != 0)
{
rhs_sig /= 10;
++rhs_exp;
}
auto unsigned_sig_rhs {detail::make_positive_unsigned(rhs_sig)};
auto rhs_components {detail::decimal128_components{unsigned_sig_rhs, rhs_exp, (rhs < 0)}};

Expand Down Expand Up @@ -2260,6 +2279,81 @@ constexpr auto scalbnd128(decimal128 num, int expval) noexcept -> decimal128
return scalblnd128(num, static_cast<long>(expval));
}

// Fused multiply-add for decimal128: evaluates (x * y) + z.
//
// The product is formed from significand/exponent pairs by d128_mul_impl and is
// then passed, as components, to d128_add_impl or d128_sub_impl instead of going
// through operator* followed by operator+. A decimal128 built from the product is
// used only for the non-finite check and for ordering the two addends.
//
// x, y : the factors
// z    : the addend
// Returns the decimal128 constructed from the resulting (sig, exp, sign) triple,
// or the value reported by detail::check_non_finite when that value is not zero.
constexpr auto fmad128(decimal128 x, decimal128 y, decimal128 z) noexcept -> decimal128
{
    constexpr decimal128 zero {0, 0};

    // A result different from zero is returned as-is (presumably NaN/infinity
    // handling - confirm against detail::check_non_finite).
    const auto factors_check {detail::check_non_finite(x, y)};
    if (factors_check != zero)
    {
        return factors_check;
    }

    // Strip trailing decimal zeros from x, raising the exponent to compensate
    auto x_sig {x.full_significand()};
    auto x_exp {x.biased_exponent()};
    for (; x_sig % 10 == 0 && x_sig != 0; ++x_exp)
    {
        x_sig /= 10;
    }

    // Same treatment for y
    auto y_sig {y.full_significand()};
    auto y_exp {y.biased_exponent()};
    for (; y_sig % 10 == 0 && y_sig != 0; ++y_exp)
    {
        y_sig /= 10;
    }

    auto lhs_parts {d128_mul_impl(x_sig, x_exp, x.isneg(), y_sig, y_exp, y.isneg())};
    const decimal128 product_value {lhs_parts.sig, lhs_parts.exp, lhs_parts.sign};

    const auto addends_check {detail::check_non_finite(product_value, z)};
    if (addends_check != zero)
    {
        return addends_check;
    }

    // Decide the operand order before the components are normalized.
    // For two negative values the comparison is inverted.
    const bool both_negative {product_value.isneg() && z.isneg()};
    const bool product_greater {product_value > z};
    const bool product_first {both_negative ? !product_greater : product_greater};
    bool first_abs_bigger {abs(product_value) > abs(z)};

    detail::normalize<decimal128>(lhs_parts.sig, lhs_parts.exp);

    auto z_sig {z.full_significand()};
    auto z_exp {z.biased_exponent()};
    detail::normalize<decimal128>(z_sig, z_exp);
    detail::decimal128_components rhs_parts {z_sig, z_exp, z.isneg()};

    // Put z first when the product does not lead; the magnitude flag follows the swap
    if (!product_first)
    {
        detail::swap(lhs_parts, rhs_parts);
        first_abs_bigger = !first_abs_bigger;
    }

    detail::decimal128_components sum {};

    // Subtraction is used only for (non-negative first operand, negative second operand)
    if (lhs_parts.sign || !rhs_parts.sign)
    {
        sum = d128_add_impl(lhs_parts.sig, lhs_parts.exp, lhs_parts.sign,
                            rhs_parts.sig, rhs_parts.exp, rhs_parts.sign);
    }
    else
    {
        sum = d128_sub_impl(lhs_parts.sig, lhs_parts.exp, lhs_parts.sign,
                            rhs_parts.sig, rhs_parts.exp, rhs_parts.sign,
                            first_abs_bigger);
    }

    return {sum.sig, sum.exp, sum.sign};
}

} //namespace decimal
} //namespace boost

Expand Down
34 changes: 26 additions & 8 deletions include/boost/decimal/detail/emulated256.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,17 +420,35 @@ constexpr uint256_t operator%(uint256_t lhs, std::uint64_t rhs) noexcept
}

// Get the 256-bit result of multiplication of two 128-bit unsigned integers
//
// The operands are passed as 64-bit halves:
//   a = a_high * 2^64 + a_low,   b = b_high * 2^64 + b_low
// so that
//   a * b = high_product * 2^128 + (mid_product1 + mid_product2) * 2^64 + low_product
//
// Returns a uint256_t whose .low member holds bits [0, 128) of the product and
// whose .high member holds bits [128, 256).
constexpr uint256_t umul256_impl(std::uint64_t a_high, std::uint64_t a_low, std::uint64_t b_high, std::uint64_t b_low) noexcept
{
    const auto low_product {static_cast<uint128>(a_low) * b_low};
    const auto mid_product1 {static_cast<uint128>(a_low) * b_high};
    const auto mid_product2 {static_cast<uint128>(a_high) * b_low};
    const auto high_product {static_cast<uint128>(a_high) * b_high};

    // Accumulates everything that has to be carried into the high 128 bits
    uint128 carry {};

    const auto mid_combined {mid_product1 + mid_product2};
    if (mid_combined < mid_product1)
    {
        // The 128-bit sum wrapped, dropping 2^128 from a term that is weighted by 2^64.
        // That is 2^192 in the full product, i.e. 2^64 (not 1) in the high word.
        carry = static_cast<uint128>(std::uint64_t{1}) << 64;
    }

    // Split the middle term across the two output words
    const auto mid_combined_high {mid_combined >> 64};
    const auto mid_combined_low {mid_combined << 64};

    const auto low_sum {low_product + mid_combined_low};
    if (low_sum < low_product)
    {
        // The low word wrapped: carry 2^128, which is 1 in the high word
        carry += 1;
    }

    uint256_t result {};
    result.low = low_sum;
    result.high = high_product + mid_combined_high + carry;

    return result;
}

template<typename T>
Expand Down
11 changes: 9 additions & 2 deletions test/test_cmath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,14 @@ void test_copysign()
template <typename Dec>
void test_fma()
{
if (!BOOST_TEST_EQ(Dec(1, -1) * Dec(1, 1), Dec(1, 0)))
{
std::cerr << std::setprecision(std::numeric_limits<Dec>::digits10)
<< " Mul: " << Dec(1, -1) * Dec(1, 1)
<< "\nActual: " << Dec(1, 0) << std::endl;
}

BOOST_TEST_EQ(Dec(1, 0) + Dec(1, 0, true), Dec(0, 0));
BOOST_TEST_EQ(fma(Dec(1, -1), Dec(1, 1), Dec(1, 0, true)), Dec(0, 0));

std::uniform_real_distribution<double> dist(-1e10, 1e10);
Expand Down Expand Up @@ -1382,15 +1390,14 @@ int main()
test_copysign<decimal32>();
test_copysign<decimal64>();

#if (defined(__clang__) || defined(_MSC_VER) || !defined(__GNUC__) || (defined(__GNUC__) && __GNUC__ > 6))
test_fma<decimal32>();
test_fma<decimal64>();
test_fma<decimal128>();

test_sin<decimal32>();
test_cos<decimal32>();
test_sin<decimal64>();
test_cos<decimal64>();
#endif

test_modf<decimal32>();
test_modf<decimal64>();
Expand Down

0 comments on commit 270204a

Please sign in to comment.