From 86b2ccc6577579f7abbc22d73721858f15e51ac3 Mon Sep 17 00:00:00 2001 From: Mikhail Burakov Date: Sat, 7 Jan 2023 14:52:15 +0100 Subject: Implement subscribing to topics from lua --- main.c | 71 +++++++++++++++++++++++++++++++++++++++++++++++++++++---------- message.c | 2 ++ message.h | 1 + 3 files changed, 63 insertions(+), 11 deletions(-) diff --git a/main.c b/main.c index 9b760eb..dac79fb 100644 --- a/main.c +++ b/main.c @@ -192,6 +192,13 @@ static void GatherMessages(const void* nodep, VISIT which, int depth) { g_twalk_closure.ptr++->iov_len = 8; } +static void UnrefLuaCallbacks(const void* nodep, VISIT which, int depth) { + (void)depth; + struct Message* message = *(void* const*)nodep; + if (which == preorder || which == endorder) return; + luaL_unref(g_service.lua_state, LUA_REGISTRYINDEX, message->lua_callback); +} + static bool ServeHttpGet(int fd, const char* target) { if (!strcmp(target, "/favicon.ico")) return SendHttpReply(fd, "404 Not Found", NULL, NULL, 0); @@ -368,9 +375,9 @@ static void OnMqttConnectAck(void* user, bool success) { static int LuaPublish(lua_State* lua_state) { int result = 0; - size_t topic_length; - const char* topic = luaL_checklstring(lua_state, 1, &topic_length); - if (!topic || topic_length > UINT16_MAX) { + size_t topic_size; + const char* topic = luaL_checklstring(lua_state, 1, &topic_size); + if (!topic || topic_size > UINT16_MAX) { LOGW("Invalid topic argument for publish call"); goto bail_out; } @@ -380,8 +387,8 @@ static int LuaPublish(lua_State* lua_state) { LOGW("Invalid payload argument for publish call"); goto bail_out; } - LOGD("%.*s -> %.*s", (int)topic_length, topic, (int)payload_length, payload); - if (!MqttPublish(g_service.mqtt, topic, (uint16_t)topic_length, payload, + LOGD("%.*s -> %.*s", (int)topic_size, topic, (int)payload_length, payload); + if (!MqttPublish(g_service.mqtt, topic, (uint16_t)topic_size, payload, payload_length)) { LOGW("Failed to publish mqtt message (%s)", strerror(errno)); goto bail_out; @@ -394,8 +401,44 @@ bail_out: } static int LuaSubscribe(lua_State* lua_state) { - // TODO(mburakov): Implement me! - return 0; + int result = 0; + size_t topic_size; + const char* topic = luaL_checklstring(lua_state, 1, &topic_size); + if (!topic) { + LOGW("Invalid topic argument for subscribe call"); + goto bail_out; + } + luaL_checkany(lua_state, 2); + if (!lua_isfunction(lua_state, -1)) { + LOGW("Invalid callback argument for subscribe call"); + goto bail_out; + } + struct Message key = { + .topic = *(void**)(void*)&topic, + .topic_size = topic_size, + }; + struct Message** node = tsearch(&key, &g_service.messages, MessageCompare); + if (!node) { + LOGW("Failed to create message node (%s)", strerror(errno)); + goto bail_out; + } + if (*node == &key) { + struct Message* message = MessageCreate(topic, topic_size); + if (!message) { + LOGW("Failed to create message (%s)", strerror(errno)); + tdelete(&key, &g_service.messages, MessageCompare); + goto bail_out; + } + *node = message; + g_service.messages_count++; + } + luaL_unref(lua_state, LUA_REGISTRYINDEX, (*node)->lua_callback); + (*node)->lua_callback = luaL_ref(lua_state, LUA_REGISTRYINDEX); + result = 1; + +bail_out: + lua_pushboolean(lua_state, result); + return 1; } static void OnMqttSubscribeAck(void* user, bool success) { @@ -414,10 +457,8 @@ static void OnMqttSubscribeAck(void* user, bool success) { } luaL_openlibs(g_service.lua_state); - lua_pushcfunction(g_service.lua_state, LuaPublish); - lua_setglobal(g_service.lua_state, "publish"); - lua_pushcfunction(g_service.lua_state, LuaSubscribe); - lua_setglobal(g_service.lua_state, "subscribe"); + lua_register(g_service.lua_state, "publish", LuaPublish); + lua_register(g_service.lua_state, "subscribe", LuaSubscribe); DIR* current_dir = opendir("."); if (!current_dir) { @@ -471,6 +512,13 @@ static void OnMqttPublish(void* user, const char* topic, size_t topic_size, free((*node)->payload); (*node)->payload = payload_copy; (*node)->payload_size = payload_size; + if ((*node)->lua_callback != LUA_REFNIL) { + // TODO(mburakov): Handle lua errors. + lua_rawgeti(g_service.lua_state, LUA_REGISTRYINDEX, (*node)->lua_callback); + lua_pushlstring(g_service.lua_state, (*node)->payload, + (*node)->payload_size); + lua_pcall(g_service.lua_state, 1, 0, 0); + } return; delete_node: @@ -591,6 +639,7 @@ int main(int argc, char* argv[]) { OnSignal(SIGABRT); } } + twalk(g_service.messages, UnrefLuaCallbacks); tdestroy(g_service.messages, MessageDestroy); while (g_service.clients) DropClientContext(g_service.clients); if (g_service.lua_state) lua_close(g_service.lua_state); diff --git a/message.c b/message.c index 657c759..dcee14e 100644 --- a/message.c +++ b/message.c @@ -18,6 +18,7 @@ #include "message.h" #include +#include #include #include @@ -38,6 +39,7 @@ struct Message* MessageCreate(const char* topic, size_t topic_size) { message->topic_size = topic_size; message->payload = NULL; message->payload_size = 0; + message->lua_callback = LUA_REFNIL; return message; free_message: diff --git a/message.h b/message.h index 6ba9e1a..fb4a7ad 100644 --- a/message.h +++ b/message.h @@ -25,6 +25,7 @@ struct Message { size_t topic_size; void* payload; size_t payload_size; + int lua_callback; }; struct Message* MessageCreate(const char* topic, size_t topic_size); -- cgit v1.2.3