summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--main.c168
1 files changed, 131 insertions, 37 deletions
diff --git a/main.c b/main.c
index 31e78cc..9d42425 100644
--- a/main.c
+++ b/main.c
@@ -18,6 +18,7 @@
#include <errno.h>
#include <luajit.h>
#include <mosquitto.h>
+#include <search.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
@@ -30,13 +31,17 @@
#include "server.h"
#define UNCONST(op) ((void*)(uintptr_t)(op))
-#define STRIOVEC(op) \
- { .iov_base = UNCONST(op), .iov_len = sizeof(op) - 1 }
-#define FLUSHARG(op) \
- { STRIOVEC(op), {.iov_base = NULL, .iov_len = 0}, }
+
+struct Message {
+ char* topic;
+ void* payload;
+ size_t payloadlen;
+ void* handler;
+};
struct Context {
struct mosquitto* mosq;
+ void* messages;
};
static volatile sig_atomic_t g_shutdown;
@@ -46,8 +51,22 @@ static void OnSignal(int num) {
g_shutdown = 1;
}
-static bool Flush(int fd, const struct iovec iov[2]) {
- // TODO(mburakov): Change to iterative writing.
+static int CompareMessages(const void* a, const void* b) {
+ return strcmp(((const struct Message*)a)->topic,
+ ((const struct Message*)b)->topic);
+}
+
+static bool Flush(int fd, const char* status, const void* body,
+ size_t body_size) {
+ char buffer[64];
+ int length = snprintf(buffer, sizeof(buffer),
+ "HTTP/1.1 %s\r\n"
+ "Content-Length: %zu\r\n"
+ "\r\n",
+ status, body_size);
+ // TODO(mburakov): Verify length is valid at this point.
+ struct iovec iov[] = {{.iov_base = buffer, .iov_len = (size_t)length},
+ {.iov_base = UNCONST(body), .iov_len = body_size}};
ssize_t result = writev(fd, iov, 2);
if (result != (ssize_t)(iov[0].iov_len + iov[1].iov_len)) {
Log("Failed to write complete reply (%s)", strerror(errno));
@@ -56,43 +75,112 @@ static bool Flush(int fd, const struct iovec iov[2]) {
return true;
}
+static void CollectMatchingMessages(const void* nodep, VISIT which,
+ void* closure) {
+ struct {
+ const char* topic;
+ size_t topic_len;
+ char* buffer;
+ size_t buffer_size;
+ }* arg = closure;
+ const struct Message* const* it = nodep;
+ if (which == preorder || which == endorder ||
+ strncmp((*it)->topic, arg->topic, arg->topic_len))
+ return;
+ size_t topic_len = strlen((*it)->topic);
+ size_t buffer_size = arg->buffer_size + topic_len + 1;
+ char* buffer = realloc(arg->buffer, buffer_size);
+ if (!buffer) {
+ Log("Failed to realloc buffer (%s)", strerror(errno));
+ return;
+ }
+ memcpy(buffer + arg->buffer_size, (*it)->topic, topic_len);
+ buffer[buffer_size - 1] = '\n';
+ arg->buffer = buffer;
+ arg->buffer_size = buffer_size;
+}
+
static bool HandleRequest(void* user, int fd, const char* method,
const char* target, const void* body,
size_t body_size) {
struct Context* context = user;
- if (strcmp(method, "POST")) {
- static const struct iovec kStatus405[2] = FLUSHARG(
- "HTTP/1.1 405 Method Not Allowed\r\n"
- "Content-Length: 0\r\n"
- "\r\n");
- return Flush(fd, kStatus405);
- }
- if (!body_size || !target[1]) {
- static const struct iovec kStatus400[2] = FLUSHARG(
- "HTTP/1.1 400 Bad Request\r\n"
- "Content-Length: 0\r\n"
- "\r\n");
- return Flush(fd, kStatus400);
+ if (!strcmp(method, "GET")) {
+ struct {
+ const char* topic;
+ size_t topic_len;
+ char* buffer;
+ size_t buffer_size;
+ } arg = {.topic = target + 1,
+ .topic_len = strlen(target) - 1,
+ .buffer = NULL,
+ .buffer_size = 0};
+ twalk_r(context->messages, CollectMatchingMessages, &arg);
+ return Flush(fd, "200 OK", arg.buffer, arg.buffer_size);
}
+ if (strcmp(method, "POST"))
+ return Flush(fd, "405 Method Not Allowed", NULL, 0);
+ if (!body_size || !target[1]) return Flush(fd, "400 Bad Request", NULL, 0);
int mosq_errno = mosquitto_publish(context->mosq, NULL, target + 1,
(int)body_size, body, 0, false);
- if (mosq_errno == MOSQ_ERR_SUCCESS) {
- static const struct iovec kStatus200[2] = FLUSHARG(
- "HTTP/1.1 200 OK\r\n"
- "Content-Length: 0\r\n"
- "\r\n");
- return Flush(fd, kStatus200);
- }
- char buffer[64];
+ if (mosq_errno == MOSQ_ERR_SUCCESS) return Flush(fd, "200 OK", NULL, 0);
const char* error = mosquitto_strerror(mosq_errno);
- snprintf(buffer, sizeof(buffer),
- "HTTP/1.1 500 Internal Server Error\r\n"
- "Content-Length: %zu\r\n\r\n",
- strlen(error));
- const struct iovec iov[2] = {
- {.iov_base = buffer, .iov_len = strlen(buffer)},
- {.iov_base = UNCONST(error), .iov_len = strlen(error)}};
- return Flush(fd, iov);
+ return Flush(fd, "500 Internal Server Error", error, strlen(error));
+}
+
+static void StoreMessagePayload(void** messages,
+ const struct mosquitto_message* message) {
+ struct Message pred = {.topic = message->topic};
+ struct Message** it = tsearch(&pred, messages, CompareMessages);
+ if (!it) {
+ Log("Failed to add message to the map");
+ return;
+ }
+ if (*it != &pred) {
+ size_t payloadlen = (size_t)message->payloadlen;
+ void* buffer = malloc(payloadlen);
+ if (!buffer) {
+ Log("Failed to copy payload (%s)", strerror(errno));
+ return;
+ }
+ memcpy(buffer, message->payload, payloadlen);
+ free((*it)->payload);
+ (*it)->payload = buffer;
+ (*it)->payloadlen = payloadlen;
+ return;
+ }
+ struct Message* added = calloc(1, sizeof(struct Message));
+ if (!added) {
+ Log("Failed to allocate new message (%s)", strerror(errno));
+ goto rollback_tsearch;
+ }
+ added->topic = strdup(message->topic);
+ if (!added->topic) {
+ Log("Failed to copy topic (%s)", strerror(errno));
+ goto rollback_calloc;
+ }
+ size_t payloadlen = (size_t)message->payloadlen;
+ added->payload = malloc(payloadlen);
+ if (!added->payload) {
+ Log("Failed to copy payload (%s)", strerror(errno));
+ goto rollback_strdup;
+ }
+ memcpy(added->payload, message->payload, payloadlen);
+ added->payloadlen = payloadlen;
+ *it = added;
+ return;
+rollback_strdup:
+ free(added->topic);
+rollback_calloc:
+ free(added);
+rollback_tsearch:
+ tdelete(&pred, messages, CompareMessages);
+}
+
+static void HandleMosquitto(struct mosquitto* mosq, void* user,
+ const struct mosquitto_message* message) {
+ (void)mosq;
+ struct Context* context = user;
+ StoreMessagePayload(&context->messages, message);
}
int main(int argc, char* argv[]) {
@@ -103,7 +191,7 @@ int main(int argc, char* argv[]) {
int epfd = epoll_create(1);
if (epfd == -1) Terminate("Failed to create epoll (%s)", strerror(errno));
int result = EXIT_FAILURE;
- struct Context context;
+ struct Context context = {.mosq = NULL, .messages = NULL};
struct Server* server = ServerCreate(epfd, HandleRequest, &context);
if (!server) {
Log("Failed to create server");
@@ -115,16 +203,22 @@ int main(int argc, char* argv[]) {
mosquitto_strerror(mosq_errno));
goto rollback_server_create;
}
- context.mosq = mosquitto_new(NULL, true, NULL);
+ context.mosq = mosquitto_new(NULL, true, &context);
if (!context.mosq) {
Log("Failed to create mosquitto (%s)", strerror(errno));
goto rollback_mosquitto_lib_init;
}
+ mosquitto_message_callback_set(context.mosq, HandleMosquitto);
mosq_errno = mosquitto_connect(context.mosq, argv[1], port, 60);
if (mosq_errno != MOSQ_ERR_SUCCESS) {
Log("Failed to connect mosquitto (%s)", mosquitto_strerror(mosq_errno));
goto rollback_mosquitto_new;
}
+ mosq_errno = mosquitto_subscribe(context.mosq, NULL, "+/#", 0);
+ if (mosq_errno != MOSQ_ERR_SUCCESS) {
+ Log("Failed to subscribe mosquitto (%s)", mosquitto_strerror(mosq_errno));
+ goto rollback_mosquitto_connect;
+ }
int mosq_sock = mosquitto_socket(context.mosq);
if (mosq_sock == -1) {
Log("Failed to get mosquitto socket");