summaryrefslogtreecommitdiff
path: root/mqtt.c
blob: 5f61b362d6e569585a1fbd4085493c1bbfcb15f4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
/*
 * Copyright (C) 2022 Mikhail Burakov. This file is part of toolbox.
 *
 * toolbox is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * toolbox is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with toolbox.  If not, see <https://www.gnu.org/licenses/>.
 */

#include "mqtt.h"

#include <arpa/inet.h>
#include <assert.h>
#include <sys/uio.h>
#include <unistd.h>

// Number of elements in a true array (not valid for pointers).
#define LENGTH(array) (sizeof(array) / sizeof((array)[0]))
// Converts a pointer-to-const into a plain void* by going through uintptr_t.
// Used below to store const buffers in struct iovec, whose iov_base is
// non-const; the pointed-to data is only read by writev().
#define UNCONST(ptr) ((void*)(uintptr_t)(ptr))

// Encodes |varint| as a variable-length integer: seven payload bits per byte,
// least significant group first, with the top bit of a byte set whenever more
// bytes follow.
//
// |buffer| must have room for the encoded value (at most 4 bytes).
// Returns the number of bytes written, or 0 (leaving |buffer| untouched) when
// |varint| is larger than 268435455, the biggest value that fits in 4 bytes.
static size_t WriteVarint(size_t varint, uint8_t* buffer) {
  if (varint > 268435455) return 0;
  size_t count = 0;
  do {
    uint8_t digit = varint & 0x7f;
    varint >>= 7;
    // Continuation flag: there are still non-zero bits left to emit.
    if (varint) digit |= 0x80;
    buffer[count++] = digit;
  } while (varint);
  return count;
}

// Writes a fixed-size CONNECT packet to the |mqtt| descriptor.
//
// The packet announces protocol "MQTT" level 4, connect flags 0x02 (the Clean
// Session bit in MQTT 3.1.1), the given |keepalive| in network byte order and
// an empty client id.
//
// Returns true only if the whole packet was written by a single write() call.
bool MqttConnect(int mqtt, uint16_t keepalive) {
  const uint8_t connect_packet[] = {
      0x10,                             // packet type
      12,                               // message length (bytes that follow)
      0x00, 0x04,                       // protocol name length, big-endian
      'M', 'Q', 'T', 'T',               // protocol name
      4,                                // protocol level
      0x02,                             // connect flags
      (uint8_t)(keepalive >> 8),        // keepalive, high byte
      (uint8_t)(keepalive & 0xff),      // keepalive, low byte
      0x00, 0x00,                       // client id length
  };
  static_assert(sizeof(connect_packet) == 14,
                "Unexpected connect message size");
  const ssize_t written = write(mqtt, connect_packet, sizeof(connect_packet));
  return written == (ssize_t)sizeof(connect_packet);
}

// Writes a SUBSCRIBE packet for a single topic filter to the |mqtt|
// descriptor, requesting QoS 0.
//
// |message_id| is sent as the 16-bit packet identifier; |topic| points to
// |topic_size| bytes and does not need to be NUL-terminated.
//
// Returns true only if one writev() call wrote the complete packet.
bool MqttSubscribe(int mqtt, uint16_t message_id, const char* topic,
                   uint16_t topic_size) {
  static const uint8_t requested_qos[] = {0};

  // Room for: type byte, up to 4 length digits, message id, topic size.
  uint8_t header[9] = {0x82};
  const size_t remaining_length = sizeof(message_id) + sizeof(topic_size) +
                                  topic_size + sizeof(requested_qos);
  const size_t length_digits = WriteVarint(remaining_length, header + 1);
  if (length_digits == 0) return false;

  // Big-endian message id and topic size follow the length field.
  size_t header_size = 1 + length_digits;
  header[header_size++] = message_id >> 8 & 0xff;
  header[header_size++] = message_id & 0xff;
  header[header_size++] = topic_size >> 8 & 0xff;
  header[header_size++] = topic_size & 0xff;

  // Gather the pieces so the topic does not need to be copied.
  struct iovec chunks[] = {
      {.iov_base = header, .iov_len = header_size},
      {.iov_base = UNCONST(topic), .iov_len = topic_size},
      {.iov_base = UNCONST(requested_qos), .iov_len = sizeof(requested_qos)},
  };
  const size_t packet_size = header_size + topic_size + sizeof(requested_qos);
  return writev(mqtt, chunks, LENGTH(chunks)) == (ssize_t)packet_size;
}

// Writes a PUBLISH packet (fixed header byte 0x30, no packet identifier) to
// the |mqtt| descriptor.
//
// |topic| points to |topic_size| bytes and |payload| to |payload_size| bytes;
// neither needs to be NUL-terminated, and neither is copied.
//
// Returns false if the packet would exceed the largest encodable length
// (268435455 bytes after the fixed header), or if one writev() call did not
// write the complete packet; returns true otherwise.
bool MqttPublish(int mqtt, const char* topic, uint16_t topic_size,
                 const void* payload, size_t payload_size) {
  // Bound the payload first: without this, the sum below could wrap around
  // for a huge |payload_size| and slip a bogus small length past the limit
  // check inside WriteVarint().
  if (payload_size > 268435455) return false;

  // Room for: type byte, up to 4 length digits, topic size.
  uint8_t prefix[7] = {0x30};
  size_t prefix_digits =
      WriteVarint(sizeof(topic_size) + topic_size + payload_size, prefix + 1);
  if (!prefix_digits) return false;

  // Big-endian topic size follows the length field.
  uint8_t* ptr = prefix + 1 + prefix_digits;
  *ptr++ = topic_size >> 8 & 0xff;
  *ptr++ = topic_size & 0xff;

  struct iovec iov[] = {
      {.iov_base = prefix, .iov_len = (size_t)(ptr - prefix)},
      {.iov_base = UNCONST(topic), .iov_len = topic_size},
      {.iov_base = UNCONST(payload), .iov_len = payload_size},
  };
  return writev(mqtt, iov, LENGTH(iov)) ==
         (ssize_t)(iov[0].iov_len + iov[1].iov_len + iov[2].iov_len);
}

// Writes a two-byte ping request (packet type 0xc0, zero length) to the
// |mqtt| descriptor.
//
// Returns true only if both bytes were written by a single write() call.
bool MqttPing(int mqtt) {
  static const uint8_t ping_request[] = {
      0xc0,  // packet type
      0x00,  // message length
  };
  static_assert(sizeof(ping_request) == 2, "Unexpected ping message size");
  const ssize_t written = write(mqtt, ping_request, sizeof(ping_request));
  return written == (ssize_t)sizeof(ping_request);
}

// Writes a two-byte disconnect notification (packet type 0xe0, zero length)
// to the |mqtt| descriptor. The descriptor itself is not closed here.
//
// Returns true only if both bytes were written by a single write() call.
bool MqttDisconnect(int mqtt) {
  static const uint8_t disconnect_request[] = {
      0xe0,  // packet type
      0x00,  // message length
  };
  static_assert(sizeof(disconnect_request) == 2,
                "Unexpected disconnect message size");
  const ssize_t written =
      write(mqtt, disconnect_request, sizeof(disconnect_request));
  return written == (ssize_t)sizeof(disconnect_request);
}