/*
* Copyright (C) 2022 Mikhail Burakov. This file is part of toolbox.
*
* toolbox is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* toolbox is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with toolbox. If not, see <https://www.gnu.org/licenses/>.
*/
#include "mqtt.h"
#include
#include
#include
#include
// Element count of a true array object; yields garbage for a pointer.
#define LENGTH(arr) (sizeof(arr) / sizeof((arr)[0]))
// Drops const from a pointer (via an integer round-trip, which avoids
// -Wcast-qual) so read-only buffers can be placed in a struct iovec.
#define UNCONST(ptr) ((void*)(uintptr_t)(ptr))
// Encodes `varint` as an MQTT "Remaining Length": little-endian base-128
// digits, with the high bit of each byte set while more digits follow.
// `buffer` must have room for 4 bytes. Returns the number of bytes written,
// or 0 if the value exceeds 268435455 (0x0FFFFFFF, the four-digit maximum).
static size_t WriteVarint(size_t varint, uint8_t* buffer) {
  if (varint > 268435455) return 0;
  size_t written = 0;
  do {
    uint8_t digit = (uint8_t)(varint & 0x7f);
    varint >>= 7;
    if (varint != 0) digit |= 0x80;  // continuation bit
    buffer[written++] = digit;
  } while (varint != 0);
  return written;
}
// Sends an MQTT 3.1.1 CONNECT packet on file descriptor `mqtt`.
// The packet requests a clean session, uses an empty client identifier and
// carries no will, username or password. `keepalive` is the keep-alive
// interval placed in the variable header (big-endian).
// Returns true only if the whole 14-byte packet was written in one write();
// a short write or an error returns false.
bool MqttConnect(int mqtt, uint16_t keepalive) {
  const uint8_t packet[] = {
      0x10,                             // fixed header: CONNECT, no flags
      12,                               // remaining length
      0x00, 0x04, 'M', 'Q', 'T', 'T',   // protocol name "MQTT"
      4,                                // protocol level (3.1.1)
      0x02,                             // connect flags: clean session
      (uint8_t)(keepalive >> 8),        // keep-alive, high byte
      (uint8_t)(keepalive & 0xff),      // keep-alive, low byte
      0x00, 0x00,                       // client id length: empty
  };
  static_assert(sizeof(packet) == 14,
                "Unexpected connect message size");
  const ssize_t sent = write(mqtt, packet, sizeof(packet));
  return sent == (ssize_t)sizeof(packet);
}
// Sends an MQTT SUBSCRIBE packet for a single topic filter at QoS 0.
// `message_id` is written big-endian as the packet identifier.
// `topic`/`topic_size` give the filter bytes (no NUL terminator is sent).
// Returns false if the remaining length cannot be encoded, or if writev()
// did not transfer the whole packet.
bool MqttSubscribe(int mqtt, uint16_t message_id, const char* topic,
                   uint16_t topic_size) {
  static const uint8_t qos[] = {0};
  // Remaining length = packet id + topic length prefix + topic + QoS byte.
  const size_t remaining =
      sizeof(message_id) + sizeof(topic_size) + topic_size + sizeof(qos);
  // 0x82: SUBSCRIBE with the mandatory 0010 flag bits. Sized for the type
  // byte, up to 4 length digits, packet id (2) and topic length (2).
  uint8_t header[9] = {0x82};
  const size_t length_bytes = WriteVarint(remaining, &header[1]);
  if (length_bytes == 0) return false;
  size_t header_size = 1 + length_bytes;
  header[header_size++] = (uint8_t)(message_id >> 8);
  header[header_size++] = (uint8_t)message_id;
  header[header_size++] = (uint8_t)(topic_size >> 8);
  header[header_size++] = (uint8_t)topic_size;
  struct iovec parts[3] = {
      {.iov_base = header, .iov_len = header_size},
      {.iov_base = UNCONST(topic), .iov_len = topic_size},
      {.iov_base = UNCONST(qos), .iov_len = sizeof(qos)},
  };
  const size_t total = header_size + topic_size + sizeof(qos);
  return writev(mqtt, parts, LENGTH(parts)) == (ssize_t)total;
}
// Sends an MQTT PUBLISH packet with all fixed-header flags clear
// (0x30: DUP = 0, QoS = 0, RETAIN = 0), so no packet identifier is sent.
// `topic`/`topic_size` give the topic name bytes; `payload`/`payload_size`
// give the application message. Returns false if the remaining length
// cannot be encoded, or if writev() did not transfer the whole packet.
bool MqttPublish(int mqtt, const char* topic, uint16_t topic_size,
                 const void* payload, size_t payload_size) {
  // Remaining length = topic length prefix + topic + payload.
  const size_t remaining = sizeof(topic_size) + topic_size + payload_size;
  // Sized for the type byte, up to 4 length digits and the topic length (2).
  uint8_t header[7] = {0x30};
  const size_t length_bytes = WriteVarint(remaining, &header[1]);
  if (length_bytes == 0) return false;
  size_t header_size = 1 + length_bytes;
  header[header_size++] = (uint8_t)(topic_size >> 8);
  header[header_size++] = (uint8_t)topic_size;
  struct iovec parts[3] = {
      {.iov_base = header, .iov_len = header_size},
      {.iov_base = UNCONST(topic), .iov_len = topic_size},
      {.iov_base = UNCONST(payload), .iov_len = payload_size},
  };
  const size_t total = header_size + topic_size + payload_size;
  return writev(mqtt, parts, LENGTH(parts)) == (ssize_t)total;
}
// Sends an MQTT PINGREQ packet (control packet type 12) on `mqtt`. This is
// the client's keep-alive probe; the broker answers with PINGRESP (type 13).
// The original byte 0xd0 was PINGRESP, which a broker rejects when it comes
// from a client, so the correct type byte is 0xc0.
// Returns true only if both bytes were written in one write().
bool MqttPing(int mqtt) {
  struct __attribute__((__packed__)) {
    uint8_t packet_type;
    uint8_t message_length;
  } ping_message = {
      .packet_type = 0xc0,  // PINGREQ (12 << 4), no flags
      .message_length = 0,  // no variable header or payload
  };
  static_assert(sizeof(ping_message) == 2, "Unexpected ping message size");
  return write(mqtt, &ping_message, sizeof(ping_message)) ==
         (ssize_t)sizeof(ping_message);
}
// Sends an MQTT DISCONNECT packet (control packet type 14) on `mqtt`.
// DISCONNECT has no variable header or payload, so the packet is just the
// type byte followed by a zero remaining length.
// Returns true only if both bytes were written in one write().
bool MqttDisconnect(int mqtt) {
  static const uint8_t kDisconnect[] = {0xe0, 0x00};
  static_assert(sizeof(kDisconnect) == 2,
                "Unexpected disconnect message size");
  const ssize_t sent = write(mqtt, kDisconnect, sizeof(kDisconnect));
  return sent == (ssize_t)sizeof(kDisconnect);
}