diff --git a/fbpcf/mpc_std_lib/oram/DummyDifferenceCalculator_impl.h b/fbpcf/mpc_std_lib/oram/DummyDifferenceCalculator_impl.h index 12095a0e..b35111c8 100644 --- a/fbpcf/mpc_std_lib/oram/DummyDifferenceCalculator_impl.h +++ b/fbpcf/mpc_std_lib/oram/DummyDifferenceCalculator_impl.h @@ -7,6 +7,7 @@ #pragma once +#include #include "fbpcf/mpc_std_lib/util/util.h" namespace fbpcf::mpc_std_lib::oram::insecure { @@ -61,7 +62,14 @@ DummyDifferenceCalculator::calculateDifferenceBatch( subtrahend = subtrahendShares.at(i) - subtrahend; T minuend = util::Adapters::convertFromBits(boolBuffer.at(i)); - rst[i] = indicator * (minuend - subtrahend); + if (indicator == 1) { + rst[i] = minuend - subtrahend; + } else if (indicator == -1) { + rst[i] = subtrahend - minuend; + } else { + throw std::runtime_error("invalid indicator!"); + } + } else { agent_->sendSingleT(indicatorShares.at(i)); agent_->sendSingleT(subtrahendShares.at(i)); diff --git a/fbpcf/mpc_std_lib/oram/WriteOnlyOram_impl.h b/fbpcf/mpc_std_lib/oram/WriteOnlyOram_impl.h index 963e0511..2dc7fafe 100644 --- a/fbpcf/mpc_std_lib/oram/WriteOnlyOram_impl.h +++ b/fbpcf/mpc_std_lib/oram/WriteOnlyOram_impl.h @@ -103,11 +103,12 @@ std::vector> WriteOnlyOram::generateMasks( auto difference = calculator_->calculateDifferenceBatch( indicatorShares, values, subtrahendShares); - for (size_t i = 0; i < batchSize; i++) { for (size_t j = 0; j < size_; j++) { - rst[i][j] = - rst[i][j] + indicatorKeyPairs.at(i).first.at(j) * difference.at(i); + bool indicator = indicatorKeyPairs.at(i).first.at(j); + if (indicator) { + rst[i][j] = rst[i][j] + difference.at(i); + } } } return rst; diff --git a/fbpcf/mpc_std_lib/oram/test/WriteOnlyORAMTest.cpp b/fbpcf/mpc_std_lib/oram/test/WriteOnlyORAMTest.cpp index 8bbf9818..59bdfbab 100644 --- a/fbpcf/mpc_std_lib/oram/test/WriteOnlyORAMTest.cpp +++ b/fbpcf/mpc_std_lib/oram/test/WriteOnlyORAMTest.cpp @@ -57,7 +57,7 @@ void testWriteOnlyOram( std::unique_ptr> factory0, std::unique_ptr> factory1, size_t oramSize) { - size_t batchSize = 16384; + size_t batchSize = oramSize * 30; auto [input0, input1, expectedValue] = util::generateRandomValuesToAdd(oramSize, batchSize); @@ -113,7 +113,7 @@ void runOramTestWithDummyComponents() { insecure::DummyDifferenceCalculatorFactory>( false, 0, *factories[1])); - size_t oramSize = 150; + size_t oramSize = 10; // use a smaller number due to performance issue. testWriteOnlyOram(std::move(factory0), std::move(factory1), oramSize); } diff --git a/fbpcf/mpc_std_lib/util/aggregationValue_impl.h b/fbpcf/mpc_std_lib/util/aggregationValue_impl.h index 16c8d500..47231992 100644 --- a/fbpcf/mpc_std_lib/util/aggregationValue_impl.h +++ b/fbpcf/mpc_std_lib/util/aggregationValue_impl.h @@ -54,18 +54,6 @@ inline AggregationValue operator-(const AggregationValue& v) { return AggregationValue{-v.conversionCount, -v.conversionValue}; } -inline AggregationValue operator*(int sign, const AggregationValue& v1) { - if (sign == 1) { - return v1; - } else if (sign == 0) { - return AggregationValue(0); - } else if (sign == -1) { - return -v1; - } else { - throw std::invalid_argument("can only multiply with -1, 0, 1"); - } -} - inline void operator+=(AggregationValue& v1, const AggregationValue& v2) { v1.conversionCount += v2.conversionCount; v1.conversionValue += v2.conversionValue;