Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add blpop and brpop cmd #438

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions src/base_cmd.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,91 @@ BaseCmd* BaseCmdGroup::GetSubCmd(const std::string& cmdName) {
return subCmd->second.get();
}

void BaseCmd::BlockThisClientToWaitLRPush(std::vector<std::string>& keys, int64_t expire_time, PClient* client,
BlockedConnNode::Type type) {
std::unique_lock<std::shared_mutex> latch(g_pikiwidb->GetBlockMtx());
auto& key_to_conns = g_pikiwidb->GetMapFromKeyToConns();
std::shared_ptr<std::atomic<bool>> is_done = std::make_shared<std::atomic<bool>>(false);
for (auto key : keys) {
pikiwidb::BlockKey blpop_key{client->GetCurrentDB(), key};

auto it = key_to_conns.find(blpop_key);
if (it == key_to_conns.end()) {
key_to_conns.emplace(blpop_key, std::make_unique<std::list<BlockedConnNode>>());
it = key_to_conns.find(blpop_key);
}
auto& wait_list_of_this_key = it->second;
wait_list_of_this_key->emplace_back(expire_time, client, type, is_done);
}
}

void BaseCmd::ServeAndUnblockConns(PClient* client) {
pikiwidb::BlockKey key{client->GetCurrentDB(), client->Key()};

std::shared_lock<std::shared_mutex> read_latch(g_pikiwidb->GetBlockMtx());
auto& key_to_conns = g_pikiwidb->GetMapFromKeyToConns();
auto it = key_to_conns.find(key);
if (it == key_to_conns.end()) {
// no client is waitting for this key
return;
}
read_latch.unlock();

std::unique_lock<std::shared_mutex> write_lock(g_pikiwidb->GetBlockMtx());
auto& waitting_list = it->second;
std::vector<std::string> elements;
storage::Status s;

// traverse this list from head to tail(in the order of adding sequence) ,means "first blocked, first get served“
for (auto conn_blocked = waitting_list->begin(); conn_blocked != waitting_list->end();) {
if (conn_blocked->is_done_->exchange(true)) {
conn_blocked = waitting_list->erase(conn_blocked);
continue;
}

PClient* BlockedClient = (*conn_blocked).GetBlockedClient();

if (BlockedClient->State() == ClientState::kClosed) {
conn_blocked = waitting_list->erase(conn_blocked);
continue;
}

switch (conn_blocked->GetCmdType()) {
case BlockedConnNode::Type::BLPop:
s = PSTORE.GetBackend(client->GetCurrentDB())->GetStorage()->LPop(client->Key(), 1, &elements);
break;
case BlockedConnNode::Type::BRPop:
s = PSTORE.GetBackend(client->GetCurrentDB())->GetStorage()->RPop(client->Key(), 1, &elements);
break;
}

if (s.ok()) {
BlockedClient->AppendArrayLen(2);
BlockedClient->AppendString(client->Key());
BlockedClient->AppendString(elements[0]);
} else if (s.IsNotFound()) {
// this key has no more elements to serve more blocked conn.
break;
} else {
BlockedClient->SetRes(CmdRes::kErrOther, s.ToString());
}
BlockedClient->SendPacket();
conn_blocked = waitting_list->erase(conn_blocked); // remove this conn from current waiting list
}
}

bool BlockedConnNode::IsExpired() {
if (expire_time_ == 0) {
return false;
}
auto now = std::chrono::system_clock::now();
int64_t now_in_ms = std::chrono::time_point_cast<std::chrono::milliseconds>(now).time_since_epoch().count();
if (expire_time_ <= now_in_ms) {
return true;
}
return false;
}

bool BaseCmdGroup::DoInitial(PClient* client) {
client->SetSubCmdName(client->argv_[1]);
if (!subCmds_.contains(client->SubCmdName())) {
Expand Down
37 changes: 37 additions & 0 deletions src/base_cmd.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ const std::string kCmdNameRPush = "rpush";
const std::string kCmdNameRPushx = "rpushx";
const std::string kCmdNameLPop = "lpop";
const std::string kCmdNameRPop = "rpop";
const std::string kCmdNameBLPop = "blpop";
const std::string kCmdNameBRPop = "brpop";
const std::string kCmdNameLRem = "lrem";
const std::string kCmdNameLRange = "lrange";
const std::string kCmdNameLTrim = "ltrim";
Expand Down Expand Up @@ -208,6 +210,23 @@ enum AclCategory {
kAclCategoryRaft = (1 << 21),
};

class BlockedConnNode {
public:
enum Type { BLPop = 0, BRPop };
virtual ~BlockedConnNode() {}
BlockedConnNode(int64_t expire_time, PClient* client, Type type, std::shared_ptr<std::atomic<bool>> is_done)
: expire_time_(expire_time), client_(client), type_(type), is_done_(is_done) {}
bool IsExpired();
PClient* GetBlockedClient() { return client_; }
std::shared_ptr<std::atomic<bool>> is_done_;
Type GetCmdType() { return type_; }

private:
Type type_;
int64_t expire_time_;
PClient* client_;
};

/**
* @brief Base class for all commands
* BaseCmd, as the base class for all commands, mainly implements some common functions
Expand Down Expand Up @@ -317,6 +336,11 @@ class BaseCmd : public std::enable_shared_from_this<BaseCmd> {

uint32_t GetCmdID() const;

void ServeAndUnblockConns(PClient* client);

void BlockThisClientToWaitLRPush(std::vector<std::string>& keys, int64_t expire_time, PClient* client,
BlockedConnNode::Type type);

protected:
// Execute a specific command
virtual void DoCmd(PClient* client) = 0;
Expand Down Expand Up @@ -363,4 +387,17 @@ class BaseCmdGroup : public BaseCmd {
private:
std::map<std::string, std::unique_ptr<BaseCmd>> subCmds_;
};

struct BlockKey { // this data struct is made for the scenario of multi dbs in pika.
int db_id;
std::string key;
bool operator==(const BlockKey& p) const { return p.db_id == db_id && p.key == key; }
};

struct BlockKeyHash {
std::size_t operator()(const BlockKey& k) const {
return std::hash<int>{}(k.db_id) ^ std::hash<std::string>{}(k.key);
}
};

} // namespace pikiwidb
2 changes: 2 additions & 0 deletions src/cmd_admin.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,8 @@ void SortCmd::DoCmd(PClient* client) {
storage::Status s = PSTORE.GetBackend(client->GetCurrentDB())->GetStorage()->RPush(store_key_, ret_, &reply_num);
if (s.ok()) {
client->AppendInteger(reply_num);
client->SetKey(store_key_);
ServeAndUnblockConns(client);
} else {
client->SetRes(CmdRes::kErrOther, s.ToString());
}
Expand Down
2 changes: 2 additions & 0 deletions src/cmd_keys.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ void RenameCmd::DoCmd(PClient* client) {
storage::Status s = PSTORE.GetBackend(client->GetCurrentDB())->GetStorage()->Rename(client->Key(), client->argv_[2]);
if (s.ok()) {
client->SetRes(CmdRes::kOK);
client->SetKey(client->argv_[2]);
ServeAndUnblockConns(client);
} else if (s.IsNotFound()) {
client->SetRes(CmdRes::kNotFound, s.ToString());
} else {
Expand Down
93 changes: 93 additions & 0 deletions src/cmd_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
*/

#include "cmd_list.h"
#include "pikiwidb.h"
#include "pstd_string.h"
#include "store.h"

Expand All @@ -25,6 +26,7 @@ void LPushCmd::DoCmd(PClient* client) {
PSTORE.GetBackend(client->GetCurrentDB())->GetStorage()->LPush(client->Key(), list_values, &reply_num);
if (s.ok()) {
client->AppendInteger(reply_num);
ServeAndUnblockConns(client);
} else if (s.IsInvalidArgument()) {
client->SetRes(CmdRes::kMultiKey);
} else {
Expand Down Expand Up @@ -72,6 +74,8 @@ void RPoplpushCmd::DoCmd(PClient* client) {
storage::Status s = PSTORE.GetBackend(client->GetCurrentDB())->GetStorage()->RPoplpush(source_, receiver_, &value);
if (s.ok()) {
client->AppendString(value);
client->SetKey(receiver_);
ServeAndUnblockConns(client);
} else if (s.IsNotFound()) {
client->AppendStringLen(-1);
} else if (s.IsInvalidArgument()) {
Expand All @@ -96,6 +100,7 @@ void RPushCmd::DoCmd(PClient* client) {
PSTORE.GetBackend(client->GetCurrentDB())->GetStorage()->RPush(client->Key(), list_values, &reply_num);
if (s.ok()) {
client->AppendInteger(reply_num);
ServeAndUnblockConns(client);
} else if (s.IsInvalidArgument()) {
client->SetRes(CmdRes::kMultiKey);
} else {
Expand Down Expand Up @@ -125,6 +130,94 @@ void RPushxCmd::DoCmd(PClient* client) {
}
}

BLPopCmd::BLPopCmd(const std::string& name, int16_t arity)
: BaseCmd(name, arity, kCmdFlagsWrite, kAclCategoryWrite | kAclCategoryList) {}

bool BLPopCmd::DoInitial(PClient* client) {
client->SetKey(client->argv_[1]);
int64_t timeout = 0;
if (!pstd::String2int(client->argv_.back().data(), client->argv_.back().size(), &timeout)) {
client->SetRes(CmdRes::kInvalidInt);
return false;
}
constexpr int64_t seconds_of_ten_years = 10 * 365 * 24 * 3600;
if (timeout < 0 || timeout > seconds_of_ten_years) {
client->SetRes(CmdRes::kErrOther,
"timeout can't be a negative value and can't exceed the number of seconds in 10 years");
return false;
}

if (timeout > 0) {
auto now = std::chrono::system_clock::now();
expire_time_ =
std::chrono::time_point_cast<std::chrono::milliseconds>(now).time_since_epoch().count() + timeout * 1000;
}
return true;
}

void BLPopCmd::DoCmd(PClient* client) {
std::vector<std::string> elements;
std::vector<std::string> list_keys(client->argv_.begin() + 1, client->argv_.end() - 1);
for (auto key : list_keys) {
storage::Status s = PSTORE.GetBackend(client->GetCurrentDB())->GetStorage()->LPop(key, 1, &elements);
if (s.ok()) {
client->AppendArrayLen(2);
client->AppendString(key);
client->AppendString(elements[0]);
return;
} else if (s.IsNotFound()) {
} else {
client->SetRes(CmdRes::kErrOther, s.ToString());
return;
}
}
BlockThisClientToWaitLRPush(list_keys, expire_time_, client, BlockedConnNode::Type::BLPop);
}

BRPopCmd::BRPopCmd(const std::string& name, int16_t arity)
: BaseCmd(name, arity, kCmdFlagsWrite, kAclCategoryWrite | kAclCategoryList) {}

bool BRPopCmd::DoInitial(PClient* client) {
client->SetKey(client->argv_[1]);
int64_t timeout = 0;
if (!pstd::String2int(client->argv_.back().data(), client->argv_.back().size(), &timeout)) {
client->SetRes(CmdRes::kInvalidInt);
return false;
}
constexpr int64_t seconds_of_ten_years = 10 * 365 * 24 * 3600;
if (timeout < 0 || timeout > seconds_of_ten_years) {
client->SetRes(CmdRes::kErrOther,
"timeout can't be a negative value and can't exceed the number of seconds in 10 years");
return false;
}

if (timeout > 0) {
auto now = std::chrono::system_clock::now();
expire_time_ =
std::chrono::time_point_cast<std::chrono::milliseconds>(now).time_since_epoch().count() + timeout * 1000;
}
return true;
}

void BRPopCmd::DoCmd(PClient* client) {
std::vector<std::string> elements;
std::vector<std::string> list_keys(client->argv_.begin() + 1, client->argv_.end() - 1);
for (auto key : list_keys) {
storage::Status s = PSTORE.GetBackend(client->GetCurrentDB())->GetStorage()->RPop(key, 1, &elements);
if (s.ok()) {
client->AppendArrayLen(2);
client->AppendString(key);
client->AppendString(elements[0]);
return;
} else if (s.IsNotFound()) {
} else {
client->SetRes(CmdRes::kErrOther, s.ToString());
return;
}
}
BlockThisClientToWaitLRPush(list_keys, expire_time_, client, BlockedConnNode::Type::BRPop);
}

LPopCmd::LPopCmd(const std::string& name, int16_t arity)
: BaseCmd(name, arity, kCmdFlagsWrite, kAclCategoryWrite | kAclCategoryList) {}

Expand Down
26 changes: 26 additions & 0 deletions src/cmd_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,32 @@ class RPushCmd : public BaseCmd {
void DoCmd(PClient* client) override;
};

class BLPopCmd : public BaseCmd {
public:
BLPopCmd(const std::string& name, int16_t arity);

protected:
bool DoInitial(PClient* client) override;

private:
void DoCmd(PClient* client) override;

int64_t expire_time_{0};
};

class BRPopCmd : public BaseCmd {
public:
BRPopCmd(const std::string& name, int16_t arity);

protected:
bool DoInitial(PClient* client) override;

private:
void DoCmd(PClient* client) override;

int64_t expire_time_{0};
};

class RPopCmd : public BaseCmd {
public:
RPopCmd(const std::string& name, int16_t arity);
Expand Down
2 changes: 2 additions & 0 deletions src/cmd_table_manager.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ void CmdTableManager::InitCmdTable() {
ADD_COMMAND(LPushx, -3);
ADD_COMMAND(RPushx, -3);
ADD_COMMAND(LPop, 2);
ADD_COMMAND(BLPop, -3);
ADD_COMMAND(BRPop, -3);
ADD_COMMAND(LIndex, 3);
ADD_COMMAND(LLen, 2);
ADD_COMMAND(RPoplpush, 3);
Expand Down
24 changes: 24 additions & 0 deletions src/pikiwidb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,26 @@ void PikiwiDB::OnNewConnection(uint64_t connId, std::shared_ptr<pikiwidb::PClien
client->OnConnect();
}

void PikiwiDB::ScanEvictedBlockedConnsOfBlrpop() {
std::unique_lock<std::shared_mutex> latch(block_mtx_);
auto& key_to_blocked_conns = g_pikiwidb->GetMapFromKeyToConns();
for (auto& it : key_to_blocked_conns) {
auto& conns_list = it.second;
for (auto conn_node = conns_list->begin(); conn_node != conns_list->end();) {
if (conn_node->is_done_->exchange(true) || conn_node->GetBlockedClient()->State() == ClientState::kClosed) {
conn_node = conns_list->erase(conn_node);
} else if (conn_node->IsExpired()) {
PClient* conn_ptr = conn_node->GetBlockedClient();
conn_ptr->AppendString("");
conn_ptr->SendPacket();
conn_node = conns_list->erase(conn_node);
} else {
conn_node++;
}
}
}
}

bool PikiwiDB::Init() {
char runid[kRunidSize + 1] = "";
getRandomHexChars(runid, kRunidSize);
Expand Down Expand Up @@ -174,6 +194,10 @@ bool PikiwiDB::Init() {
timerTask->SetCallback([]() { PREPL.Cron(); });
event_server_->AddTimerTask(timerTask);

auto BLRPopTimerTask = std::make_shared<net::CommonTimerTask>(250);
BLRPopTimerTask->SetCallback(std::bind(&PikiwiDB::ScanEvictedBlockedConnsOfBlrpop, this));
event_server_->AddTimerTask(BLRPopTimerTask);

return true;
}

Expand Down
Loading
Loading