diff --git a/velox/exec/MergeJoin.cpp b/velox/exec/MergeJoin.cpp index 9c7ac58c5fbc..6551d60d3122 100644 --- a/velox/exec/MergeJoin.cpp +++ b/velox/exec/MergeJoin.cpp @@ -89,10 +89,14 @@ void MergeJoin::initialize() { initializeFilter(joinNode_->filter(), leftType, rightType); if (joinNode_->isLeftJoin() || joinNode_->isAntiJoin() || - joinNode_->isRightJoin() || joinNode_->isFullJoin()) { + joinNode_->isRightJoin() || joinNode_->isFullJoin() || + joinNode_->isLeftSemiFilterJoin() || + joinNode_->isRightSemiFilterJoin()) { joinTracker_ = JoinTracker(outputBatchSize_, pool()); } - } else if (joinNode_->isAntiJoin()) { + } else if ( + joinNode_->isAntiJoin() || joinNode_->isLeftSemiFilterJoin() || + joinNode_->isRightSemiFilterJoin() || joinNode_->isFullJoin()) { // Anti join needs to track the left side rows that have no match on the // right. joinTracker_ = JoinTracker(outputBatchSize_, pool()); @@ -338,7 +342,8 @@ void MergeJoin::addOutputRow( const RowVectorPtr& left, vector_size_t leftIndex, const RowVectorPtr& right, - vector_size_t rightIndex) { + vector_size_t rightIndex, + bool isRightJoinForFullOuter) { // All left side projections share the same dictionary indices (leftIndices_). rawLeftIndices_[outputSize_] = leftIndex; @@ -358,22 +363,33 @@ void MergeJoin::addOutputRow( copyRow(right, rightIndex, filterInput_, outputSize_, filterRightInputs_); if (joinTracker_) { - if (isRightJoin(joinType_)) { + if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_) || + (isFullJoin(joinType_) && isRightJoinForFullOuter)) { // Record right-side row with a match on the left-side. - joinTracker_->addMatch(right, rightIndex, outputSize_); + joinTracker_->addMatch( + right, rightIndex, outputSize_, isRightJoinForFullOuter); } else { // Record left-side row with a match on the right-side. - joinTracker_->addMatch(left, leftIndex, outputSize_); + joinTracker_->addMatch( + left, leftIndex, outputSize_, isRightJoinForFullOuter); } } - } - - // Anti join needs to track the left side rows that have no match on the - // right. - if (isAntiJoin(joinType_)) { + } else if ( + isAntiJoin(joinType_) || isLeftSemiFilterJoin(joinType_) || + (isFullJoin(joinType_) && !isRightJoinForFullOuter)) { + // Anti join needs to track the left side rows that have no match on the + // right. + VELOX_CHECK(joinTracker_); + // Record left-side row with a match on the right-side. + joinTracker_->addMatch( + left, leftIndex, outputSize_, isRightJoinForFullOuter); + } else if ( + isRightSemiFilterJoin(joinType_) || + (isFullJoin(joinType_) && isRightJoinForFullOuter)) { VELOX_CHECK(joinTracker_); // Record left-side row with a match on the right-side. - joinTracker_->addMatch(left, leftIndex, outputSize_); + joinTracker_->addMatch( + right, rightIndex, outputSize_, isRightJoinForFullOuter); } ++outputSize_; @@ -390,14 +406,14 @@ bool MergeJoin::prepareOutput( return true; } - if (isRightJoin(joinType_) && right != currentRight_) { - return true; - } - // If there is a new right, we need to flatten the dictionary. if (!isRightFlattened_ && right && currentRight_ != right) { flattenRightProjections(); } + + if (right != currentRight_) { + return true; + } return false; } @@ -509,6 +525,39 @@ bool MergeJoin::prepareOutput( bool MergeJoin::addToOutput() { if (isRightJoin(joinType_) || isRightSemiFilterJoin(joinType_)) { return addToOutputForRightJoin(); + } else if (isFullJoin(joinType_) && filter_) { + if (!leftForRightJoinMatch_) { + leftForRightJoinMatch_ = leftMatch_; + rightForRightJoinMatch_ = rightMatch_; + } + + if (leftMatch_ && rightMatch_ && !leftJoinForFullFinished_) { + auto left = addToOutputForLeftJoin(); + if (!leftMatch_) { + leftJoinForFullFinished_ = true; + } + if (left) { + if (!leftMatch_) { + leftMatch_ = leftForRightJoinMatch_; + rightMatch_ = rightForRightJoinMatch_; + } + + return true; + } + } + + if (!leftMatch_ && !rightJoinForFullFinished_) { + leftMatch_ = leftForRightJoinMatch_; + rightMatch_ = rightForRightJoinMatch_; + rightJoinForFullFinished_ = true; + } + + auto right = addToOutputForRightJoin(); + + leftForRightJoinMatch_ = leftMatch_; + rightForRightJoinMatch_ = rightMatch_; + + return right; } else { return addToOutputForLeftJoin(); } @@ -560,7 +609,7 @@ bool MergeJoin::addToOutputForLeftJoin() { // one match on the other side, we could explore specialized algorithms // or data structures that short-circuit the join process once a match // is found. - if (isLeftSemiFilterJoin(joinType_)) { + if (isLeftSemiFilterJoin(joinType_) && !filter_) { // LeftSemiFilter produce each row from the left at most once. rightEnd = rightStart + 1; } @@ -578,6 +627,10 @@ bool MergeJoin::addToOutputForLeftJoin() { } addOutputRow(left, i, right, j); } + + if (isLeftSemiFilterJoin(joinType_) && !filter_) { + break; + } } } } @@ -643,7 +696,7 @@ bool MergeJoin::addToOutputForRightJoin() { // one match on the other side, we could explore specialized algorithms // or data structures that short-circuit the join process once a match // is found. - if (isRightSemiFilterJoin(joinType_)) { + if (isRightSemiFilterJoin(joinType_) && !filter_) { // RightSemiFilter produce each row from the right at most once. leftEnd = leftStart + 1; } @@ -659,7 +712,16 @@ bool MergeJoin::addToOutputForRightJoin() { leftMatch_->setCursor(l, j); return true; } - addOutputRow(left, j, right, i); + + if (isFullJoin(joinType_)) { + addOutputRow(left, j, right, i, true); + } else { + addOutputRow(left, j, right, i); + } + } + + if (isRightSemiFilterJoin(joinType_) && !filter_) { + break; } } } @@ -698,6 +760,37 @@ vector_size_t firstNonNull( } } // namespace +RowVectorPtr MergeJoin::filterOutputForSemiJoin(const RowVectorPtr& output) { + const auto numRows = output->size(); + const auto& matchedRows = joinTracker_->matchingRows(numRows); + const auto numPassed = matchedRows.countSelected(); + if (numPassed == 0) { + return nullptr; + } + + BufferPtr indices = allocateIndices(numPassed, pool()); + auto* rawIndices = indices->asMutable(); + size_t index{0}; + + // If all matches for a given left-side row fail the filter, add a row to + // the output with nulls for the right-side columns. + auto onMiss = [&](auto row, bool flag) {}; + + auto onMatch = [&](auto row) { + if (isLeftSemiFilterJoin(joinType_) || isRightSemiFilterJoin(joinType_)) { + rawIndices[index++] = row; + } + }; + for (auto i = 0; i < numRows; ++i) { + if (matchedRows.isValid(i)) { + joinTracker_->processFilterResult(i, true, onMiss, onMatch); + } + } + + // Some, but not all rows passed. + return wrap(index, indices, output); +} + RowVectorPtr MergeJoin::filterOutputForAntiJoin(const RowVectorPtr& output) { const auto numRows = output->size(); const auto& filterRows = joinTracker_->matchingRows(numRows); @@ -750,7 +843,16 @@ RowVectorPtr MergeJoin::getOutput() { continue; } else if (isAntiJoin(joinType_)) { output = filterOutputForAntiJoin(output); - if (output) { + if (output != nullptr && output->size() > 0) { + return output; + } + + // No rows survived the filter for anti join. Get more rows. + continue; + } else if ( + isLeftSemiFilterJoin(joinType_) || isRightSemiFilterJoin(joinType_)) { + output = filterOutputForSemiJoin(output); + if (output != nullptr && output->size() > 0) { return output; } @@ -776,22 +878,19 @@ RowVectorPtr MergeJoin::getOutput() { } if (rightInput_) { - if (isFullJoin(joinType_)) { - rightIndex_ = 0; - } else { - auto firstNonNullIndex = firstNonNull(rightInput_, rightKeys_); - if (isRightJoin(joinType_) && firstNonNullIndex > 0) { - prepareOutput(nullptr, rightInput_); - for (auto i = 0; i < firstNonNullIndex; ++i) { - addOutputRowForRightJoin(rightInput_, i); - } - } - rightIndex_ = firstNonNullIndex; - if (finishedRightBatch()) { - // Ran out of rows on the right side. - rightInput_ = nullptr; + auto firstNonNullIndex = firstNonNull(rightInput_, rightKeys_); + if ((isRightJoin(joinType_) || isFullJoin(joinType_)) && + firstNonNullIndex > 0) { + prepareOutput(nullptr, rightInput_); + for (auto i = 0; i < firstNonNullIndex; ++i) { + addOutputRowForRightJoin(rightInput_, i); } } + rightIndex_ = firstNonNullIndex; + if (finishedRightBatch()) { + // Ran out of rows on the right side. + rightInput_ = nullptr; + } } else { noMoreRightInput_ = true; } @@ -813,6 +912,8 @@ RowVectorPtr MergeJoin::doGetOutput() { // results from the current match. if (addToOutput()) { return std::move(output_); + } else { + previousLeftMatch_ = leftMatch_; } } @@ -872,6 +973,8 @@ RowVectorPtr MergeJoin::doGetOutput() { if (addToOutput()) { return std::move(output_); + } else { + previousLeftMatch_ = leftMatch_; } } @@ -1009,7 +1112,7 @@ RowVectorPtr MergeJoin::doGetOutput() { isFullJoin(joinType_)) { // If output_ is currently wrapping a different buffer, return it // first. - if (prepareOutput(input_, nullptr)) { + if (prepareOutput(input_, rightInput_)) { output_->resize(outputSize_); return std::move(output_); } @@ -1035,7 +1138,7 @@ RowVectorPtr MergeJoin::doGetOutput() { if (isRightJoin(joinType_) || isFullJoin(joinType_)) { // If output_ is currently wrapping a different buffer, return it // first. - if (prepareOutput(nullptr, rightInput_)) { + if (prepareOutput(input_, rightInput_)) { output_->resize(outputSize_); return std::move(output_); } @@ -1084,6 +1187,8 @@ RowVectorPtr MergeJoin::doGetOutput() { endRightIndex < rightInput_->size(), std::nullopt}; + leftJoinForFullFinished_ = false; + rightJoinForFullFinished_ = false; if (!leftMatch_->complete || !rightMatch_->complete) { if (!leftMatch_->complete) { // Need to continue looking for the end of match. @@ -1098,6 +1203,7 @@ RowVectorPtr MergeJoin::doGetOutput() { } index_ = endIndex; + if (isFullJoin(joinType_)) { rightIndex_ = endRightIndex; } else { @@ -1111,6 +1217,8 @@ RowVectorPtr MergeJoin::doGetOutput() { if (addToOutput()) { return std::move(output_); + } else { + previousLeftMatch_ = leftMatch_; } if (!rightInput_) { @@ -1127,8 +1235,6 @@ RowVectorPtr MergeJoin::doGetOutput() { RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { const auto numRows = output->size(); - RowVectorPtr fullOuterOutput = nullptr; - BufferPtr indices = allocateIndices(numRows, pool()); auto rawIndices = indices->asMutable(); vector_size_t numPassed = 0; @@ -1145,69 +1251,23 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { // If all matches for a given left-side row fail the filter, add a row to // the output with nulls for the right-side columns. - auto onMiss = [&](auto row) { - if (!isAntiJoin(joinType_)) { + auto onMiss = [&](auto row, bool flag) { + if (!isLeftSemiFilterJoin(joinType_) && + !isRightSemiFilterJoin(joinType_)) { rawIndices[numPassed++] = row; - if (isFullJoin(joinType_)) { - // For filtered rows, it is necessary to insert additional data - // to ensure the result set is complete. Specifically, we - // need to generate two records: one record containing the - // columns from the left table along with nulls for the - // right table, and another record containing the columns - // from the right table along with nulls for the left table. - // For instance, the current output is filtered based on the condition - // t > 1. - - // 1, 1 - // 2, 2 - // 3, 3 - - // In this scenario, we need to additionally insert a record 1, 1. - // Subsequently, we will set the values of the columns on the left to - // null and the values of the columns on the right to null as well. By - // doing so, we will obtain the final result set. - - // 1, null - // null, 1 - // 2, 2 - // 3, 3 - fullOuterOutput = BaseVector::create( - output->type(), output->size() + 1, pool()); - - for (auto i = 0; i < row + 1; i++) { - for (auto j = 0; j < output->type()->size(); j++) { - fullOuterOutput->childAt(j)->copy( - output->childAt(j).get(), i, i, 1); + if (!isRightJoin(joinType_)) { + if (isFullJoin(joinType_) && flag) { + for (auto& projection : leftProjections_) { + auto target = output->childAt(projection.outputChannel); + target->setNull(row, true); } - } - - for (auto j = 0; j < output->type()->size(); j++) { - fullOuterOutput->childAt(j)->copy( - output->childAt(j).get(), row + 1, row, 1); - } - - for (auto i = row + 1; i < output->size(); i++) { - for (auto j = 0; j < output->type()->size(); j++) { - fullOuterOutput->childAt(j)->copy( - output->childAt(j).get(), i + 1, i, 1); + } else { + for (auto& projection : rightProjections_) { + auto target = output->childAt(projection.outputChannel); + target->setNull(row, true); } } - - for (auto& projection : leftProjections_) { - auto target = fullOuterOutput->childAt(projection.outputChannel); - target->setNull(row, true); - } - - for (auto& projection : rightProjections_) { - auto target = fullOuterOutput->childAt(projection.outputChannel); - target->setNull(row + 1, true); - } - } else if (!isRightJoin(joinType_)) { - for (auto& projection : rightProjections_) { - auto target = output->childAt(projection.outputChannel); - target->setNull(row, true); - } } else { for (auto& projection : leftProjections_) { auto target = output->childAt(projection.outputChannel); @@ -1217,19 +1277,22 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { } }; + auto onMatch = [&](auto row) { + if (isLeftSemiFilterJoin(joinType_) || isRightSemiFilterJoin(joinType_)) { + rawIndices[numPassed++] = row; + } + }; + for (auto i = 0; i < numRows; ++i) { if (filterRows.isValid(i)) { const bool passed = !decodedFilterResult_.isNullAt(i) && decodedFilterResult_.valueAt(i); - joinTracker_->processFilterResult(i, passed, onMiss); + joinTracker_->processFilterResult(i, passed, onMiss, onMatch); - if (isAntiJoin(joinType_)) { - if (!passed) { - rawIndices[numPassed++] = i; - } - } else { - if (passed) { + if (!isAntiJoin(joinType_) && !isLeftSemiFilterJoin(joinType_) && + !isRightSemiFilterJoin(joinType_)) { + if (passed && !joinTracker_->isRightJoinForFullOuter(i)) { rawIndices[numPassed++] = i; } } @@ -1242,19 +1305,19 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { // Every time we start a new left key match, `processFilterResult()` will // check if at least one row from the previous match passed the filter. If - // none did, it calls onMiss to add a record with null right projections to - // the output. + // none did, it calls onMiss to add a record with null right projections + // to the output. // // Before we leave the current buffer, since we may not have seen the next - // left key match yet, the last key match may still be pending to produce a - // row (because `processFilterResult()` was not called yet). + // left key match yet, the last key match may still be pending to produce + // a row (because `processFilterResult()` was not called yet). // // To handle this, we need to call `noMoreFilterResults()` unless the - // same current left key match may continue in the next buffer. So there are - // two cases to check: + // same current left key match may continue in the next buffer. So there + // are two cases to check: // - // 1. If leftMatch_ is nullopt, there for sure the next buffer will contain - // a different key match. + // 1. If leftMatch_ is nullopt, there for sure the next buffer will + // contain a different key match. // // 2. leftMatch_ may not be nullopt, but may be related to a different // (subsequent) left key. So we check if the last row in the batch has the @@ -1262,6 +1325,10 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { if (!leftMatch_ || !joinTracker_->isCurrentLeftMatch(numRows - 1)) { joinTracker_->noMoreFilterResults(onMiss); } + + if (leftMatch_ && !previousLeftMatch_) { + joinTracker_->noMoreFilterResults(onMiss); + } } else { filterRows_.resize(numRows); filterRows_.setAll(); @@ -1283,17 +1350,10 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { if (numPassed == numRows) { // All rows passed. - if (fullOuterOutput) { - return fullOuterOutput; - } return output; } // Some, but not all rows passed. - if (fullOuterOutput) { - return wrap(numPassed, indices, fullOuterOutput); - } - return wrap(numPassed, indices, output); } diff --git a/velox/exec/MergeJoin.h b/velox/exec/MergeJoin.h index a429d270199e..d62768efb7c8 100644 --- a/velox/exec/MergeJoin.h +++ b/velox/exec/MergeJoin.h @@ -207,7 +207,8 @@ class MergeJoin : public Operator { const RowVectorPtr& left, vector_size_t leftIndex, const RowVectorPtr& right, - vector_size_t rightIndex); + vector_size_t rightIndex, + bool isRightJoinForFullOuter = false); // If the right side projected columns in the current output vector happen to // span more than one vector from the right side, they cannot be simply @@ -267,6 +268,8 @@ class MergeJoin : public Operator { /// rows from the left side that have a match on the right. RowVectorPtr filterOutputForAntiJoin(const RowVectorPtr& output); + RowVectorPtr filterOutputForSemiJoin(const RowVectorPtr& output); + /// As we populate the results of the join, we track whether a given /// output row is a result of a match between left and right sides or a miss. /// We use JoinTracker::addMatch and addMiss methods for that. @@ -297,6 +300,9 @@ class MergeJoin : public Operator { : matchingRows_{numRows, false} { leftRowNumbers_ = AlignedBuffer::allocate(numRows, pool); rawLeftRowNumbers_ = leftRowNumbers_->asMutable(); + + rightJoinRows_ = AlignedBuffer::allocate(numRows, pool); + rawRightJoinRows_ = rightJoinRows_->asMutable(); } /// Records a row of output that corresponds to a match between a left-side @@ -307,7 +313,8 @@ class MergeJoin : public Operator { void addMatch( const VectorPtr& left, vector_size_t leftIndex, - vector_size_t outputIndex) { + vector_size_t outputIndex, + bool rightJoinForFullOuter = false) { matchingRows_.setValid(outputIndex, true); if (lastVector_ != left || lastIndex_ != leftIndex) { @@ -318,6 +325,7 @@ class MergeJoin : public Operator { } rawLeftRowNumbers_[outputIndex] = lastLeftRowNumber_; + rawRightJoinRows_[outputIndex] = rightJoinForFullOuter; } /// Returns a subset of "match" rows in [0, numRows) range that were @@ -351,25 +359,32 @@ class MergeJoin : public Operator { /// rows that correspond to a single left-side row. Use /// 'noMoreFilterResults' to make sure 'onMiss' is called for the last /// left-side row. - template + template void processFilterResult( vector_size_t outputIndex, bool passed, - TOnMiss onMiss) { + TOnMiss onMiss, + TOnMatch onMatch) { auto rowNumber = rawLeftRowNumbers_[outputIndex]; if (currentLeftRowNumber_ != rowNumber) { if (currentRow_ != -1 && !currentRowPassed_) { - onMiss(currentRow_); + onMiss(currentRow_, rawRightJoinRows_[currentRow_]); } currentRow_ = outputIndex; currentLeftRowNumber_ = rowNumber; currentRowPassed_ = false; + firstMatched_ = false; } else { currentRow_ = outputIndex; } if (passed) { currentRowPassed_ = true; + + if (!firstMatched_) { + onMatch(outputIndex); + firstMatched_ = true; + } } } @@ -385,12 +400,17 @@ class MergeJoin : public Operator { /// filter failed for all matches of that row. template void noMoreFilterResults(TOnMiss onMiss) { - if (!currentRowPassed_) { - onMiss(currentRow_); + if (!currentRowPassed_ && currentRow_ >= 0) { + onMiss(currentRow_, rawRightJoinRows_[currentRow_]); } currentRow_ = -1; currentRowPassed_ = false; + firstMatched_ = false; + } + + bool isRightJoinForFullOuter(vector_size_t row) { + return rawRightJoinRows_[row]; } private: @@ -412,6 +432,9 @@ class MergeJoin : public Operator { BufferPtr leftRowNumbers_; vector_size_t* rawLeftRowNumbers_; + BufferPtr rightJoinRows_; + bool* rawRightJoinRows_; + // Synthetic number assigned to the last added "match" row or zero if no row // has been added yet. vector_size_t lastLeftRowNumber_{0}; @@ -425,6 +448,8 @@ class MergeJoin : public Operator { // True if at least one row in a block of output rows corresponding a single // left-side row identified by 'currentRowNumber' passed the filter. bool currentRowPassed_{false}; + + bool firstMatched_{false}; }; /// Used to record both left and right join. @@ -505,6 +530,10 @@ class MergeJoin : public Operator { // A set of rows with matching keys on the left side. std::optional leftMatch_; + std::optional previousLeftMatch_ = + Match{{}, -1, -1, false, std::nullopt}; + ; + // A set of rows with matching keys on the right side. std::optional rightMatch_; @@ -518,5 +547,13 @@ class MergeJoin : public Operator { // True if all the right side data has been received. bool noMoreRightInput_{false}; + + bool leftJoinForFullFinished_{false}; + + bool rightJoinForFullFinished_{false}; + + std::optional leftForRightJoinMatch_; + + std::optional rightForRightJoinMatch_; }; } // namespace facebook::velox::exec diff --git a/velox/exec/MergeSource.cpp b/velox/exec/MergeSource.cpp index eb5653d22cb0..a90d21967214 100644 --- a/velox/exec/MergeSource.cpp +++ b/velox/exec/MergeSource.cpp @@ -243,8 +243,10 @@ BlockingReason MergeJoinSource::next( common::testutil::TestValue::adjust( "facebook::velox::exec::MergeSource::next", this); return state_.withWLock([&](auto& state) { - if (state.data != nullptr) { - *data = std::move(state.data); + if (!state.dataQueue.empty()) { + *data = std::move(state.dataQueue.front()); + state.dataQueue.pop(); + notify(producerPromise_); return BlockingReason::kNotBlocked; } @@ -284,11 +286,7 @@ BlockingReason MergeJoinSource::enqueue( return BlockingReason::kNotBlocked; } - if (state.data != nullptr) { - return waitForConsumer(future); - } - - state.data = std::move(data); + state.dataQueue.push(std::move(data)); notify(consumerPromise_); return waitForConsumer(future); diff --git a/velox/exec/MergeSource.h b/velox/exec/MergeSource.h index c15fb7182892..31b03058415b 100644 --- a/velox/exec/MergeSource.h +++ b/velox/exec/MergeSource.h @@ -74,6 +74,7 @@ class MergeJoinSource { struct State { bool atEnd = false; RowVectorPtr data; + std::queue dataQueue; }; folly::Synchronized state_; diff --git a/velox/exec/tests/MergeJoinTest.cpp b/velox/exec/tests/MergeJoinTest.cpp index 6e627dadd25a..42f3cb4d9763 100644 --- a/velox/exec/tests/MergeJoinTest.cpp +++ b/velox/exec/tests/MergeJoinTest.cpp @@ -838,6 +838,190 @@ TEST_F(MergeJoinTest, semiJoin) { core::JoinType::kRightSemiFilter); } +TEST_F(MergeJoinTest, semiJoinWithMultiMatchedRowsInDifferentBatches) { + auto left = + makeRowVector({"t0"}, {makeNullableFlatVector({2, 2, 2, 2, 2})}); + + auto right = makeRowVector( + {"u0"}, {makeNullableFlatVector({2, 2, 2, 2, 2, 2})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + auto testSemiJoin = [&](const std::string& filter, + const std::string& sql, + const std::vector& outputLayout, + core::JoinType joinType) { + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values(split(left, 2)) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values(split(right, 2)) + .planNode(), + filter, + outputLayout, + joinType) + .planNode(); + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kPreferredOutputBatchRows, "2") + .config(core::QueryConfig::kMaxOutputBatchRows, "2") + .assertResults(sql); + }; + + // Without filter + testSemiJoin( + "", + "SELECT t0 FROM t where t0 IN (SELECT u0 from u)", + {"t0"}, + core::JoinType::kLeftSemiFilter); + testSemiJoin( + "", + "SELECT u0 FROM u where u0 IN (SELECT t0 from t)", + {"u0"}, + core::JoinType::kRightSemiFilter); + + // With filter + testSemiJoin( + "t0 >1", + "SELECT t0 FROM t where t0 IN (SELECT u0 from u) and t0 > 1", + {"t0"}, + core::JoinType::kLeftSemiFilter); + testSemiJoin( + "u0 > 1", + "SELECT u0 FROM u where u0 IN (SELECT t0 from t) and u0 > 1", + {"u0"}, + core::JoinType::kRightSemiFilter); +} + +TEST_F(MergeJoinTest, leftJoinWithFilter) { + auto left = makeRowVector({"t0"}, {makeNullableFlatVector({1, 1})}); + + auto right = makeRowVector({"u0"}, {makeNullableFlatVector({1, 1})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + // Anti join. + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values({split(left, 2)}) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values(split(right, 2)) + .planNode(), + "t0 > 2", + {"t0"}, + core::JoinType::kLeft) + .planNode(); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kPreferredOutputBatchRows, "2") + .config(core::QueryConfig::kMaxOutputBatchRows, "2") + .assertResults( + "SELECT t0 FROM t WHERE NOT exists (select 1 from u where t0 = u0 AND t.t0 > 2 ) "); +} + +TEST_F( + MergeJoinTest, + antiJoinWithFilterWithMultiMatchedRowsInDifferentBatches) { + auto left = + makeRowVector({"t0"}, {makeNullableFlatVector({1, 2, 3})}); + + auto right = + makeRowVector({"u0"}, {makeNullableFlatVector({1, 2, 3})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + // Anti join. + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values({split(left, 2)}) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator) + .values(split(right, 2)) + .planNode(), + "t0 > 2", + {"t0"}, + core::JoinType::kAnti) + .planNode(); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .config(core::QueryConfig::kPreferredOutputBatchRows, "2") + .config(core::QueryConfig::kMaxOutputBatchRows, "2") + .assertResults( + "SELECT t0 FROM t WHERE NOT exists (select 1 from u where t0 = u0 AND t.t0 > 2 ) "); +} + +TEST_F(MergeJoinTest, antiJoinWithFilterWithMultiMatchedRows) { + auto left = makeRowVector({"t0"}, {makeNullableFlatVector({1, 2})}); + + auto right = + makeRowVector({"u0"}, {makeNullableFlatVector({1, 2, 2, 2})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + // Anti join. + auto planNodeIdGenerator = std::make_shared(); + auto plan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + "t0 > 2", + {"t0"}, + core::JoinType::kAnti) + .planNode(); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .assertResults( + "SELECT t0 FROM t WHERE NOT exists (select 1 from u where t0 = u0 AND t.t0 > 2 ) "); +} + +TEST_F(MergeJoinTest, antiJoinWithTwoJoinKeysInDifferentBatch) { + auto left = makeRowVector( + {"a", "b"}, + {makeNullableFlatVector({1, 1, 1, 1}), + makeNullableFlatVector({3.0, 3.0, 3.0, 3.0})}); + + auto right = makeRowVector( + {"c", "d"}, + {makeNullableFlatVector({1, 1, 1}), + makeNullableFlatVector({2.0, 2.0, 4.0})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + // Anti join. + auto planNodeIdGenerator = std::make_shared(); + auto plan = PlanBuilder(planNodeIdGenerator) + .values({split(left, 2)}) + .mergeJoin( + {"a"}, + {"c"}, + PlanBuilder(planNodeIdGenerator) + .values({split(right, 2)}) + .planNode(), + "b < d", + {"a", "b"}, + core::JoinType::kAnti) + .planNode(); + + AssertQueryBuilder(plan, duckDbQueryRunner_) + .assertResults( + "SELECT * FROM t WHERE NOT exists (select * from u where t.a = u.c and t.b < u.d)"); +} + TEST_F(MergeJoinTest, rightJoin) { auto left = makeRowVector( {"t0"}, @@ -1128,6 +1312,45 @@ TEST_F(MergeJoinTest, fullOuterJoin) { "SELECT * FROM t FULL OUTER JOIN u ON t.t0 = u.u0 AND t.t0 > 2"); } +TEST_F(MergeJoinTest, fullOuterJoinWithDuplicateMatch) { + // Each row on the left side has at most one match on the right side. + auto left = makeRowVector( + {"a", "b"}, + { + makeNullableFlatVector({1, 2, 2, 2, 3, 5, 6, std::nullopt}), + makeNullableFlatVector( + {2.0, 100.0, 1.0, 1.0, 3.0, 1.0, 6.0, std::nullopt}), + }); + + auto right = makeRowVector( + {"c", "d"}, + { + makeNullableFlatVector( + {0, 2, 2, 2, 2, 3, 4, 5, 7, std::nullopt}), + makeNullableFlatVector( + {0.0, 3.0, -1.0, -1.0, 3.0, 2.0, 1.0, 3.0, 7.0, std::nullopt}), + }); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + auto planNodeIdGenerator = std::make_shared(); + + auto rightPlan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"a"}, + {"c"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + "b < d", + {"a", "b", "c", "d"}, + core::JoinType::kFull) + .planNode(); + AssertQueryBuilder(rightPlan, duckDbQueryRunner_) + .assertResults("SELECT * from t FULL OUTER JOIN u ON a = c AND b < d"); +} + TEST_F(MergeJoinTest, fullOuterJoinNoFilter) { auto left = makeRowVector( {"t0", "t1", "t2", "t3"},