diff --git a/cpp/src/arrow/acero/asof_join_node.cc b/cpp/src/arrow/acero/asof_join_node.cc index 23c07b8acb95f..d19d2db299cba 100644 --- a/cpp/src/arrow/acero/asof_join_node.cc +++ b/cpp/src/arrow/acero/asof_join_node.cc @@ -549,15 +549,16 @@ class BackpressureController : public BackpressureControl { class BackpressureHandler { private: - BackpressureHandler(size_t low_threshold, size_t high_threshold, + BackpressureHandler(ExecNode* input, size_t low_threshold, size_t high_threshold, std::unique_ptr backpressure_control) - : low_threshold_(low_threshold), + : input_(input), + low_threshold_(low_threshold), high_threshold_(high_threshold), backpressure_control_(std::move(backpressure_control)) {} public: static Result Make( - size_t low_threshold, size_t high_threshold, + ExecNode* input, size_t low_threshold, size_t high_threshold, std::unique_ptr backpressure_control) { if (low_threshold >= high_threshold) { return Status::Invalid("low threshold (", low_threshold, @@ -566,7 +567,7 @@ class BackpressureHandler { if (backpressure_control == NULLPTR) { return Status::Invalid("null backpressure control parameter"); } - BackpressureHandler backpressure_handler(low_threshold, high_threshold, + BackpressureHandler backpressure_handler(input, low_threshold, high_threshold, std::move(backpressure_control)); return std::move(backpressure_handler); } @@ -579,7 +580,16 @@ class BackpressureHandler { } } + Status ForceShutdown() { + // It may be unintuitive to call Resume() here, but this is to avoid a deadlock. + // Since acero's executor won't terminate if any one node is paused, we need to + // force resume the node before stopping production. + backpressure_control_->Resume(); + return input_->StopProducing(); + } + private: + ExecNode* input_; size_t low_threshold_; size_t high_threshold_; std::unique_ptr backpressure_control_; @@ -629,6 +639,8 @@ class BackpressureConcurrentQueue : public ConcurrentQueue { return ConcurrentQueue::TryPopUnlocked(); } + Status ForceShutdown() { return handler_.ForceShutdown(); } + private: BackpressureHandler handler_; }; @@ -672,9 +684,9 @@ class InputState { std::unique_ptr backpressure_control = std::make_unique( /*node=*/asof_input, /*output=*/asof_node, backpressure_counter); - ARROW_ASSIGN_OR_RAISE(auto handler, - BackpressureHandler::Make(low_threshold, high_threshold, - std::move(backpressure_control))); + ARROW_ASSIGN_OR_RAISE( + auto handler, BackpressureHandler::Make(asof_input, low_threshold, high_threshold, + std::move(backpressure_control))); return std::make_unique(index, tolerance, must_hash, may_rehash, key_hasher, asof_node, std::move(handler), schema, time_col_index, key_col_index); @@ -930,6 +942,12 @@ class InputState { total_batches_ = n; } + Status ForceShutdown() { + // Force the upstream input node to unpause. Necessary to avoid deadlock when we + // terminate the process thread + return queue_.ForceShutdown(); + } + private: // Pending record batches. The latest is the front. Batches cannot be empty. BackpressureConcurrentQueue> queue_; @@ -1323,6 +1341,9 @@ class AsofJoinNode : public ExecNode { if (st.ok()) { st = output_->InputFinished(this, batches_produced_); } + for (const auto& s : state_) { + st &= s->ForceShutdown(); + } })); } @@ -1679,6 +1700,15 @@ class AsofJoinNode : public ExecNode { const Ordering& ordering() const override { return ordering_; } Status InputReceived(ExecNode* input, ExecBatch batch) override { + // InputReceived may be called after execution was finished. Pushing it to the + // InputState is unnecessary since we're done (and anyway may cause the + // BackPressureController to pause the input, causing a deadlock), so drop it. + if (process_task_.is_finished()) { + DEBUG_SYNC(this, "Input received while done. Short circuiting.", + DEBUG_MANIP(std::endl)); + return Status::OK(); + } + // Get the input ARROW_DCHECK(std_has(inputs_, input)); size_t k = std_find(inputs_, input) - inputs_.begin(); @@ -1687,6 +1717,7 @@ class AsofJoinNode : public ExecNode { auto rb = *batch.ToRecordBatch(input->output_schema()); DEBUG_SYNC(this, "received batch from input ", k, ":", DEBUG_MANIP(std::endl), rb->ToString(), DEBUG_MANIP(std::endl)); + ARROW_RETURN_NOT_OK(state_.at(k)->Push(rb)); process_.Push(true); return Status::OK(); diff --git a/cpp/src/arrow/acero/asof_join_node_test.cc b/cpp/src/arrow/acero/asof_join_node_test.cc index 96c00e6a4bd59..df3172b2a09bc 100644 --- a/cpp/src/arrow/acero/asof_join_node_test.cc +++ b/cpp/src/arrow/acero/asof_join_node_test.cc @@ -1424,7 +1424,8 @@ AsyncGenerator> GetGen(BatchesWithSchema bws) { } template -void TestBackpressure(BatchesMaker maker, int num_batches, int batch_size) { +void TestBackpressure(BatchesMaker maker, int batch_size, int num_l_batches, + int num_r0_batches, int num_r1_batches, bool slow_r0) { auto l_schema = schema({field("time", int32()), field("key", int32()), field("l_value", int32())}); auto r0_schema = @@ -1432,16 +1433,17 @@ void TestBackpressure(BatchesMaker maker, int num_batches, int batch_size) { auto r1_schema = schema({field("time", int32()), field("key", int32()), field("r1_value", int32())}); - auto make_shift = [&maker, num_batches, batch_size]( - const std::shared_ptr& schema, int shift) { + auto make_shift = [&maker, batch_size](int num_batches, + const std::shared_ptr& schema, + int shift) { return maker({[](int row) -> int64_t { return row; }, [num_batches](int row) -> int64_t { return row / num_batches; }, [shift](int row) -> int64_t { return row * 10 + shift; }}, schema, num_batches, batch_size); }; - ASSERT_OK_AND_ASSIGN(auto l_batches, make_shift(l_schema, 0)); - ASSERT_OK_AND_ASSIGN(auto r0_batches, make_shift(r0_schema, 1)); - ASSERT_OK_AND_ASSIGN(auto r1_batches, make_shift(r1_schema, 2)); + ASSERT_OK_AND_ASSIGN(auto l_batches, make_shift(num_l_batches, l_schema, 0)); + ASSERT_OK_AND_ASSIGN(auto r0_batches, make_shift(num_r0_batches, r0_schema, 1)); + ASSERT_OK_AND_ASSIGN(auto r1_batches, make_shift(num_r1_batches, r1_schema, 2)); BackpressureCountingNode::Register(); RegisterTestNodes(); // for GatedNode @@ -1449,6 +1451,7 @@ void TestBackpressure(BatchesMaker maker, int num_batches, int batch_size) { struct BackpressureSourceConfig { std::string name_prefix; bool is_gated; + bool is_delayed; std::shared_ptr schema; decltype(l_batches) batches; @@ -1463,9 +1466,9 @@ void TestBackpressure(BatchesMaker maker, int num_batches, int batch_size) { // Two ungated and one gated std::vector source_configs = { - {"0", false, l_schema, l_batches}, - {"1", true, r0_schema, r0_batches}, - {"2", false, r1_schema, r1_batches}, + {"0", false, false, l_schema, l_batches}, + {"1", true, slow_r0, r0_schema, r0_batches}, + {"2", false, false, r1_schema, r1_batches}, }; std::vector bp_counters(source_configs.size()); @@ -1474,9 +1477,16 @@ void TestBackpressure(BatchesMaker maker, int num_batches, int batch_size) { std::vector bp_decls; for (size_t i = 0; i < source_configs.size(); i++) { const auto& config = source_configs[i]; - - src_decls.emplace_back("source", - SourceNodeOptions(config.schema, GetGen(config.batches))); + if (config.is_delayed) { + src_decls.emplace_back( + "source", + SourceNodeOptions(config.schema, MakeDelayedGen(config.batches, "slow_source", + /*delay_sec=*/0.5, + /*noisy=*/false))); + } else { + src_decls.emplace_back("source", + SourceNodeOptions(config.schema, GetGen(config.batches))); + } bp_options.push_back( std::make_shared(&bp_counters[i])); std::shared_ptr options = bp_options.back(); @@ -1486,11 +1496,12 @@ void TestBackpressure(BatchesMaker maker, int num_batches, int batch_size) { if (config.is_gated) { bp_decl = {std::string{GatedNodeOptions::kName}, {bp_decl}, gate_options}; } - bp_decls.push_back(bp_decl); + bp_decls.emplace_back(bp_decl); } - Declaration asofjoin = {"asofjoin", bp_decls, - GetRepeatedOptions(source_configs.size(), "time", {"key"}, 0)}; + auto opts = GetRepeatedOptions(source_configs.size(), "time", {"key"}, 0); + + Declaration asofjoin = {"asofjoin", bp_decls, opts}; ASSERT_OK_AND_ASSIGN(std::shared_ptr tpool, internal::ThreadPool::Make(1)); @@ -1512,14 +1523,14 @@ void TestBackpressure(BatchesMaker maker, int num_batches, int batch_size) { return true; }; - BusyWait(10.0, has_bp_been_applied); + BusyWait(60.0, has_bp_been_applied); ASSERT_TRUE(has_bp_been_applied()); gate.ReleaseAllBatches(); ASSERT_FINISHES_OK_AND_ASSIGN(BatchesWithCommonSchema batches, batches_fut); - // One of the inputs is gated. The other two will eventually be resumed by the asof - // join node + // One of the inputs is gated and was released. The other two will eventually be resumed + // by the asof join node for (size_t i = 0; i < source_configs.size(); i++) { const auto& counters = bp_counters[i]; if (!source_configs[i].is_gated) { @@ -1529,7 +1540,9 @@ void TestBackpressure(BatchesMaker maker, int num_batches, int batch_size) { } TEST(AsofJoinTest, BackpressureWithBatches) { - return TestBackpressure(MakeIntegerBatches, /*num_batches=*/20, /*batch_size=*/1); + // Give the first right hand table a delay to stress test race conditions + return TestBackpressure(MakeIntegerBatches, /*batch_size=*/1, /*num_l_batches=*/20, + /*num_r0_batches=*/50, /*num_r1_batches=*/20, /*slow_r0=*/true); } template @@ -1595,7 +1608,10 @@ TEST(AsofJoinTest, BackpressureWithBatchesGen) { GTEST_SKIP() << "Skipping - see GH-36331"; int num_batches = GetEnvValue("ARROW_BACKPRESSURE_DEMO_NUM_BATCHES", 20); int batch_size = GetEnvValue("ARROW_BACKPRESSURE_DEMO_BATCH_SIZE", 1); - return TestBackpressure(MakeIntegerBatchGenForTest, num_batches, batch_size); + return TestBackpressure(MakeIntegerBatchGenForTest, /*batch_size=*/batch_size, + /*num_l_batches=*/num_batches, + /*num_r0_batches=*/num_batches, /*num_r1_batches=*/num_batches, + /*slow_r0=*/false); } } // namespace acero