diff options
| author | Mikhail Burakov <mburakov@mailbox.org> | 2022-12-26 12:15:30 +0100 |
|---|---|---|
| committer | Mikhail Burakov <mburakov@mailbox.org> | 2022-12-26 12:15:30 +0100 |
| commit | b2ccb9ff8fa7e7ad14143cfbf7e26ec79bfbfdc8 (patch) | |
| tree | 1c372c1ef20f3dfa9931f908837ac8f94e7a826c /main.c | |
| parent | 2fcd2038a47b1990532dde4078d71308dfcc58ca (diff) | |
Initial commit for version 2
Diffstat (limited to 'main.c')
| -rw-r--r-- | main.c | 428 |
1 files changed, 188 insertions, 240 deletions
@@ -15,289 +15,236 @@ * along with MQhTTp. If not, see <https://www.gnu.org/licenses/>. */ -#include <dirent.h> +//#include <dirent.h> #include <errno.h> -#include <lauxlib.h> -#include <lualib.h> -#include <mosquitto.h> -#include <search.h> +//#include <lauxlib.h> +//#include <lualib.h> +//#include <search.h> +#include <arpa/inet.h> +#include <netinet/in.h> #include <signal.h> -#include <stdio.h> +#include <stdbool.h> +#include <stdint.h> #include <stdlib.h> #include <string.h> -#include <sys/epoll.h> -#include <sys/uio.h> +#include <sys/socket.h> #include <unistd.h> -#include "logging.h" -#include "server.h" +#include "toolbox/buffer.h" +#include "toolbox/io_muxer.h" +#include "toolbox/mqtt.h" +#include "toolbox/mqtt_parser.h" +#include "toolbox/utils.h" -#define UNCONST(op) ((void*)(uintptr_t)(op)) - -struct Message { - char* topic; - void* payload; - size_t payload_size; - int handler; -}; - -struct Context { - struct mosquitto* mosq; - void* messages; - char* buffer; - size_t size; - size_t alloc; - lua_State* lua_state; +struct ServiceContext { + int http; + int mqtt; + struct IoMuxer io_muxer; + struct Buffer mqtt_buffer; }; -static struct { - struct Context* context; - const char* topic; - size_t topic_len; -} g_twalk_context; +static volatile sig_atomic_t g_signal; -static volatile sig_atomic_t g_shutdown; +static void OnSignal(int signal) { g_signal = signal; } -static void OnSignal(int num) { - (void)num; - g_shutdown = 1; +static struct sockaddr_in GetHttpAddr() { + const char* sport = getenv("HTTP_PORT"); + int port = sport ? atoi(sport) : 8080; + if (0 > port || port > UINT16_MAX) { + LOGE("Invalid http port"); + } + const char* saddr = getenv("HTTP_ADDR"); + in_addr_t addr = inet_addr(saddr ? saddr : "0.0.0.0"); + if (addr == INADDR_NONE) { + LOGE("Invalid http addr"); + exit(EXIT_FAILURE); + } + struct sockaddr_in result = { + .sin_family = AF_INET, + .sin_port = htons((uint16_t)port), + .sin_addr.s_addr = addr, + }; + return result; } -static int CompareMessages(const void* a, const void* b) { - return strcmp(((const struct Message*)a)->topic, - ((const struct Message*)b)->topic); +static struct sockaddr_in GetMqttAddr(int argc, char* argv[]) { + int port = argc > 2 ? atoi(argv[2]) : 1883; + if (0 > port || port > UINT16_MAX) { + LOGE("Invalid mqtt port"); + exit(EXIT_FAILURE); + } + in_addr_t addr = inet_addr(argc > 1 ? argv[1] : "127.0.0.1"); + if (addr == INADDR_NONE) { + LOGE("Invalid mqtt addr"); + exit(EXIT_FAILURE); + } + struct sockaddr_in result = { + .sin_family = AF_INET, + .sin_port = htons((uint16_t)port), + .sin_addr.s_addr = addr, + }; + return result; } -static void FreeMessage(void* nodep) { - struct Message* message = nodep; - free(message->topic); - free(message->payload); +static void OnHttpRead(void* user) { + // TODO(mburakov): Implement me!!! } -static bool Flush(int fd, const char* status, const char* type, - const void* body, size_t body_size) { - char buffer[256]; - int length; - if (type) { - length = snprintf(buffer, sizeof(buffer), - "HTTP/1.1 %s\r\n" - "Content-Type: %s\r\n" - "Content-Length: %zu\r\n" - "\r\n", - status, type, body_size); - } else { - length = snprintf(buffer, sizeof(buffer), - "HTTP/1.1 %s\r\n" - "Content-Length: %zu\r\n" - "\r\n", - status, body_size); +static void OnMqttConnectAck(void* user, bool success) { + struct ServiceContext* context = user; + if (!success) { + LOGW("Mqtt broker rejected connection"); + OnSignal(SIGTERM); + return; } - // TODO(mburakov): Verify length is valid at this point. - struct iovec iov[] = {{.iov_base = buffer, .iov_len = (size_t)length}, - {.iov_base = UNCONST(body), .iov_len = body_size}}; - ssize_t result = writev(fd, iov, 2); - if (result != (ssize_t)(iov[0].iov_len + iov[1].iov_len)) { - Log("Failed to write complete reply (%s)", strerror(errno)); - return false; + if (!MqttSubscribe(context->mqtt, 1, "+/#", 3)) { + LOGE("Failed to subscribe to mqtt topic (%s)", strerror(errno)); + OnSignal(SIGABRT); + return; } - return true; } -static bool BufferAppend(struct Context* context, const char* data, - size_t size) { - size_t alloc = context->size + size; - if (context->alloc < alloc) { - char* buffer = realloc(context->buffer, alloc); - if (!buffer) { - Log("Failed to reallocate buffer (%s)", strerror(errno)); - return false; - } - context->buffer = buffer; - context->alloc = alloc; +static void OnMqttSubscribeAck(void* user, bool success) { + (void)user; + if (!success) { + LOGW("Mqtt broker rejected subscription"); + OnSignal(SIGTERM); + return; } - memcpy(context->buffer + context->size, data, size); - context->size += size; - return true; } -static void CollectMatchingMessages(const void* nodep, VISIT which, int level) { - (void)level; - const struct Message* const* it = nodep; - if (which == preorder || which == endorder || - strncmp((*it)->topic, g_twalk_context.topic, g_twalk_context.topic_len)) - return; - static const char kPre[] = "<a href=\"/"; - BufferAppend(g_twalk_context.context, kPre, strlen(kPre)); - BufferAppend(g_twalk_context.context, (*it)->topic, strlen((*it)->topic)); - static const char kInt[] = "\">/"; - BufferAppend(g_twalk_context.context, kInt, strlen(kInt)); - BufferAppend(g_twalk_context.context, (*it)->topic, strlen((*it)->topic)); - static const char kPost[] = "<br>"; - BufferAppend(g_twalk_context.context, kPost, strlen(kPost)); +static void OnMqttPublish(void* user, const char* topic, size_t topic_size, + const void* payload, size_t payload_size) { + (void)user; + LOGD("%.*s <- %.*s", (int)topic_size, topic, (int)payload_size, + (const char*)payload); + // TODO(mburakov): Implement me! } -static bool HandleGetRequest(struct Context* context, int fd, - const char* target) { - if (!strcmp(target, "/favicon.ico")) - return Flush(fd, "404 Not Found", NULL, NULL, 0); - struct Message pred = {.topic = UNCONST(target + 1)}; - struct Message** it = tfind(&pred, &context->messages, CompareMessages); - if (it) { - return Flush(fd, "200 OK", "application/json", (*it)->payload, - (*it)->payload_size); - } - static const char kHeader[] = - "<!DOCTYPE html>" - "<html>" - "<head>" - "<title>MQhTTp</title>" - "</head>" - "<body>"; - static const char kFooter[] = - "</body>" - "</html>"; - context->size = 0; - g_twalk_context.context = context; - g_twalk_context.topic = target + 1; - g_twalk_context.topic_len = strlen(target) - 1; - if (!BufferAppend(context, kHeader, sizeof(kHeader) - 1)) return false; - twalk(context->messages, CollectMatchingMessages); - if (!BufferAppend(context, kFooter, sizeof(kFooter) - 1)) return false; - return Flush(fd, "200 OK", "text/html", context->buffer, context->size); +static void OnMqttFinished(void* user, size_t offset) { + struct ServiceContext* context = user; + BufferDiscard(&context->mqtt_buffer, offset); } -static bool HandleRequest(void* user, int fd, const char* method, - const char* target, const void* body, - size_t body_size) { - struct Context* context = user; - if (!strcmp(method, "GET")) return HandleGetRequest(context, fd, target); - if (strcmp(method, "POST")) - return Flush(fd, "405 Method Not Allowed", NULL, NULL, 0); - if (!body_size || !target[1]) - return Flush(fd, "400 Bad Request", NULL, NULL, 0); - int mosq_errno = mosquitto_publish(context->mosq, NULL, target + 1, - (int)body_size, body, 0, false); - if (mosq_errno == MOSQ_ERR_SUCCESS) return Flush(fd, "200 OK", NULL, NULL, 0); - const char* error = mosquitto_strerror(mosq_errno); - return Flush(fd, "500 Internal Server Error", "text/plain", error, - strlen(error)); +static void OnMqttRead(void* user) { + struct ServiceContext* context = user; + switch (BufferAppendFrom(&context->mqtt_buffer, context->mqtt)) { + case -1: + LOGE("Failed to read mqtt socket (%s)", strerror(errno)); + OnSignal(SIGABRT); + return; + case 0: + LOGW("Mqtt broker closed connection"); + OnSignal(SIGTERM); + return; + default: + break; + } + static const struct MqttParserCallbacks mqtt_callbacks = { + .on_connect_ack = OnMqttConnectAck, + .on_subscribe_ack = OnMqttSubscribeAck, + .on_publish = OnMqttPublish, + .on_finished = OnMqttFinished, + }; + for (;;) { + enum MqttParserResult result = + MqttParserParse(context->mqtt_buffer.data, context->mqtt_buffer.size, + &mqtt_callbacks, context); + switch (result) { + case kMqttParserResultFinished: + continue; + case kMqttParserResultWantMore: + if (!IoMuxerOnRead(&context->io_muxer, context->mqtt, OnMqttRead, + context)) { + LOGE("Failed to schedule mqtt read (%s)", strerror(errno)); + OnSignal(SIGABRT); + } + return; + case kMqttParserResultError: + LOGE("Failed to parse mqtt message"); + OnSignal(SIGABRT); + return; + default: + __builtin_unreachable(); + } + } } -static struct Message* GetMessage(void** messages, const char* topic) { - struct Message pred = {.topic = UNCONST(topic)}; - struct Message** it = tsearch(&pred, messages, CompareMessages); - if (!it) { - Log("Failed to add message to the map"); - return NULL; +int main(int argc, char* argv[]) { + struct sockaddr_in http_addr = GetHttpAddr(); + struct sockaddr_in mqtt_addr = GetMqttAddr(argc, argv); + + struct ServiceContext context; + IoMuxerCreate(&context.io_muxer); + context.http = socket(AF_INET, SOCK_STREAM, 0); + if (context.http == -1) { + LOGE("Failed to create http socket (%s)", strerror(errno)); + return EXIT_FAILURE; } - if (*it != &pred) return *it; - struct Message* message = calloc(1, sizeof(struct Message)); - if (!message) { - Log("Failed to allocate new message (%s)", strerror(errno)); - goto rollback_tsearch; + + int one = 1; + if (setsockopt(context.http, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one))) { + LOGE("Failed to reuse http socket (%s)", strerror(errno)); + goto close_http; } - message->topic = strdup(topic); - if (!message->topic) { - Log("Failed to copy topic (%s)", strerror(errno)); - goto rollback_calloc; + if (bind(context.http, (struct sockaddr*)&http_addr, sizeof(http_addr))) { + LOGE("Failed to bind http socket (%s)", strerror(errno)); + goto close_http; } - *it = message; - return message; -rollback_calloc: - free(message); -rollback_tsearch: - tdelete(&pred, messages, CompareMessages); - return NULL; -} - -static void HandleMosquitto(struct mosquitto* mosq, void* user, - const struct mosquitto_message* mosq_msg) { - (void)mosq; - struct Context* context = user; - struct Message* message = GetMessage(&context->messages, mosq_msg->topic); - if (!message) { - Log("Failed to get message"); - return; + if (listen(context.http, SOMAXCONN)) { + LOGE("Failed to listen http socket (%s)", strerror(errno)); + goto close_http; } - size_t payload_size = (size_t)mosq_msg->payloadlen; - void* buffer = malloc(payload_size); - if (!buffer) { - Log("Failed to copy payload (%s)", strerror(errno)); - return; + context.mqtt = socket(AF_INET, SOCK_STREAM, 0); + if (context.mqtt == -1) { + LOGE("Failed to create mqtt socket (%s)", strerror(errno)); + goto close_http; } - memcpy(buffer, mosq_msg->payload, payload_size); - free(message->payload); - message->payload = buffer; - message->payload_size = payload_size; - if (message->handler) { - // TODO(mburakov): Handle lua errors. - lua_rawgeti(context->lua_state, LUA_REGISTRYINDEX, message->handler); - lua_pushlstring(context->lua_state, message->payload, - message->payload_size); - lua_pcall(context->lua_state, 1, 0, 0); + + if (connect(context.mqtt, (struct sockaddr*)&mqtt_addr, sizeof(mqtt_addr))) { + LOGE("Failed to connect mqtt socket (%s)", strerror(errno)); + goto close_mqtt; + } + // TODO(mburakov): Implement keepalive + if (!MqttConnect(context.mqtt, 65535)) { + LOGE("Failed to connect mqtt (%s)", strerror(errno)); + goto close_mqtt; } -} -static void SourceCurrentDir(lua_State* lua_state) { - DIR* current_dir = opendir("."); - if (!current_dir) Terminate("Failed to open current dir"); - for (struct dirent* item; (item = readdir(current_dir));) { - if (item->d_type != DT_REG) continue; - size_t length = strlen(item->d_name); - static const char kLuaExt[] = {'.', 'l', 'u', 'a'}; - if (length < sizeof(kLuaExt)) continue; - const char* ext = item->d_name + length - sizeof(kLuaExt); - if (memcmp(ext, kLuaExt, sizeof(kLuaExt))) continue; - Log("Sourcing %s...", item->d_name); - if (luaL_dofile(lua_state, item->d_name)) - Log("%s", lua_tostring(lua_state, -1)); + if (signal(SIGINT, OnSignal) == SIG_ERR || + signal(SIGTERM, OnSignal) == SIG_ERR) { + LOGE("Failed to set signal handlers (%s)", strerror(errno)); + goto disconnect_mqtt; } - closedir(current_dir); -} -static int LuaSubscribe(lua_State* lua_state) { - // TODO(mburakov): Handle lua errors. -#if 0 - // mburakov: Userdata is broken on AArch64 - struct Context* context = lua_touserdata(lua_state, lua_upvalueindex(1)); -#else - const void* ctx = lua_tolstring(lua_state, lua_upvalueindex(1), NULL); - struct Context* context = *(void* const*)ctx; -#endif - const char* topic = lua_tolstring(lua_state, -2, NULL); - struct Message* message = GetMessage(&context->messages, topic); - if (!message) { - Log("Failed to get message"); - return 0; + IoMuxerCreate(&context.io_muxer); + if (!IoMuxerOnRead(&context.io_muxer, context.http, OnHttpRead, &context) || + !IoMuxerOnRead(&context.io_muxer, context.mqtt, OnMqttRead, &context)) { + LOGE("Failed to init iomuxer (%s)", strerror(errno)); + goto destroy_iomuxer; } - message->handler = luaL_ref(lua_state, LUA_REGISTRYINDEX); - return 0; -} -static int LuaPublish(lua_State* lua_state) { - // TODO(mburakov): Handle lua errors. -#if 0 - // mburakov: Userdata is broken on AArch64 - struct Context* context = lua_touserdata(lua_state, lua_upvalueindex(1)); -#else - const void* ctx = lua_tolstring(lua_state, lua_upvalueindex(1), NULL); - struct Context* context = *(void* const*)ctx; -#endif - char* buffer = context->buffer; - size_t topic_size, payload_size; - const char* topic = lua_tolstring(lua_state, -2, &topic_size); - const char* payload = lua_tolstring(lua_state, -1, &payload_size); - // mburakov: Not allowed to publish from mosquitto callback. - if (!BufferAppend(context, topic, topic_size + 1) || - !BufferAppend(context, payload, payload_size + 1)) { - Log("Failed to schedule mosquitto publish"); - context->buffer = buffer; - return 0; + BufferCreate(&context.mqtt_buffer); + while (!g_signal) { + enum IoMuxerResult result = IoMuxerIterate(&context.io_muxer, -1); + if (result == kIoMuxerResultError && errno != EINTR) { + LOGE("Failed to iterate iomuxer (%s)", strerror(errno)); + OnSignal(SIGABRT); + } } - return 0; -} -int main(int argc, char* argv[]) { +destroy_iomuxer: + IoMuxerDestroy(&context.io_muxer); +disconnect_mqtt: + MqttDisconnect(context.mqtt); +close_mqtt: + close(context.mqtt); +close_http: + close(context.http); + bool success = g_signal == SIGINT || g_signal == SIGTERM; + return success ? EXIT_SUCCESS : EXIT_FAILURE; + +#if 0 if (argc < 3) Terminate("Usage: %s <host> <port>", argv[0]); int port = atoi(argv[2]); if (0 >= port || port >= 65536) @@ -433,4 +380,5 @@ rollback_epoll_create: free(context.buffer); close(epfd); return result; +#endif } |
