/*
 * Copyright (C) 2022 Mikhail Burakov. This file is part of toolbox.
 *
 * toolbox is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * toolbox is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with toolbox.  If not, see <https://www.gnu.org/licenses/>.
 */

#include "mqtt.h"

#include <arpa/inet.h>
#include <assert.h>
#include <sys/uio.h>
#include <unistd.h>

// Number of elements in an array. Only valid for true arrays: applied to a
// pointer it silently yields sizeof(pointer) / sizeof(element).
#define LENGTH(op) (sizeof(op) / sizeof *(op))
// Drops const from a pointer via an integer round trip. Used below to place
// read-only buffers into struct iovec, whose iov_base is a non-const void*;
// writev() only reads from those buffers.
#define UNCONST(op) ((void*)(uintptr_t)(op))

// Encodes |varint| as an MQTT variable byte integer into |buffer|. Each byte
// holds 7 data bits, least-significant group first, with the high bit set on
// every byte except the last.
//
// |buffer| must have room for 4 bytes. Returns the number of bytes written
// (1-4). Returns 0, leaving |buffer| untouched, if |varint| exceeds 268435455,
// the largest value representable in four bytes.
static size_t WriteVarint(size_t varint, uint8_t* buffer) {
  if (varint > 268435455) return 0;
  size_t written = 0;
  do {
    uint8_t digit = (uint8_t)(varint & 0x7f);
    varint >>= 7;
    // Continuation bit: more 7-bit groups follow.
    if (varint != 0) digit |= 0x80;
    buffer[written++] = digit;
  } while (varint != 0);
  return written;
}

// Sends an MQTT 3.1.1 CONNECT packet on |mqtt| with the clean-session flag
// set, no will/username/password and an empty client identifier.
// |keepalive| is sent big-endian in the keep-alive field.
// Returns true only if all 14 bytes were written by a single write() call.
bool MqttConnect(int mqtt, uint16_t keepalive) {
  const uint8_t packet[] = {
      0x10,                             // fixed header: CONNECT
      12,                               // remaining length
      0x00, 0x04, 'M', 'Q', 'T', 'T',   // protocol name, length-prefixed
      4,                                // protocol level
      0x02,                             // connect flags: clean session
      (uint8_t)(keepalive >> 8),        // keep alive, high byte
      (uint8_t)(keepalive & 0xff),      // keep alive, low byte
      0x00, 0x00,                       // client identifier length (empty)
  };
  static_assert(sizeof(packet) == 14,
                "Unexpected connect message size");
  const ssize_t written = write(mqtt, packet, sizeof(packet));
  return written == (ssize_t)sizeof(packet);
}

// Sends an MQTT SUBSCRIBE packet for a single topic filter, requesting QoS 0.
//
// |message_id| is the packet identifier (sent big-endian). |topic| points to
// |topic_size| bytes of topic filter. The filter is length-prefixed on the
// wire, so it need not be NUL-terminated. Returns true only if the whole
// packet was written by a single writev() call.
bool MqttSubscribe(int mqtt, uint16_t message_id, const char* topic,
                   uint16_t topic_size) {
  // Requested-QoS byte that follows the topic filter.
  static const uint8_t qos[] = {0};

  // Fixed header byte (SUBSCRIBE with its mandatory 0b0010 flags), up to 4
  // remaining-length bytes, then the 2-byte packet identifier and the 2-byte
  // topic filter length.
  uint8_t header[9] = {0x82};
  const size_t remaining =
      sizeof(message_id) + sizeof(topic_size) + topic_size + sizeof(qos);
  const size_t varint_len = WriteVarint(remaining, &header[1]);
  if (varint_len == 0) return false;

  size_t header_len = 1 + varint_len;
  header[header_len++] = (uint8_t)(message_id >> 8);
  header[header_len++] = (uint8_t)(message_id & 0xff);
  header[header_len++] = (uint8_t)(topic_size >> 8);
  header[header_len++] = (uint8_t)(topic_size & 0xff);

  struct iovec parts[] = {
      {.iov_base = header, .iov_len = header_len},
      {.iov_base = UNCONST(topic), .iov_len = topic_size},
      {.iov_base = UNCONST(qos), .iov_len = sizeof(qos)},
  };
  const ssize_t expected = (ssize_t)(header_len + topic_size + sizeof(qos));
  return writev(mqtt, parts, LENGTH(parts)) == expected;
}

// Sends an MQTT PUBLISH packet with QoS 0 (no packet identifier) and with the
// DUP and RETAIN flags clear.
//
// |topic| points to |topic_size| bytes of topic name. The name is
// length-prefixed on the wire, so it need not be NUL-terminated. |payload|
// points to |payload_size| bytes that are sent verbatim after the topic.
// Returns false if the packet is too large for MQTT's remaining-length field,
// or if the write fails or is short. Returns true only if the whole packet was
// written by a single writev() call.
bool MqttPublish(int mqtt, const char* topic, uint16_t topic_size,
                 const void* payload, size_t payload_size) {
  // Largest value the 4-byte MQTT remaining-length field can encode.
  const size_t max_remaining_length = 268435455;
  // Check before summing. Otherwise a huge payload_size could wrap the size_t
  // addition below to a small value. That value would encode fine but would
  // not match the number of bytes handed to writev().
  if (payload_size > max_remaining_length - sizeof(topic_size) - topic_size)
    return false;

  // Fixed header byte (PUBLISH, flags 0), up to 4 remaining-length bytes,
  // then the 2-byte topic length.
  uint8_t prefix[7] = {0x30};
  size_t prefix_digits =
      WriteVarint(sizeof(topic_size) + topic_size + payload_size, prefix + 1);
  if (!prefix_digits) return false;

  uint8_t* ptr = prefix + 1 + prefix_digits;
  *ptr++ = topic_size >> 8 & 0xff;
  *ptr++ = topic_size & 0xff;

  struct iovec iov[] = {
      {.iov_base = prefix, .iov_len = (size_t)(ptr - prefix)},
      {.iov_base = UNCONST(topic), .iov_len = topic_size},
      {.iov_base = UNCONST(payload), .iov_len = payload_size},
  };
  // The guard above bounds this sum well below SSIZE_MAX, so the cast is safe.
  return writev(mqtt, iov, LENGTH(iov)) ==
         (ssize_t)(iov[0].iov_len + iov[1].iov_len + iov[2].iov_len);
}

// Sends an MQTT PINGREQ packet (type 0xc0, zero remaining length) on |mqtt|.
// Returns true only if both bytes were written by a single write() call.
bool MqttPing(int mqtt) {
  static const uint8_t ping_packet[] = {0xc0, 0x00};
  static_assert(sizeof(ping_packet) == 2, "Unexpected ping message size");
  const ssize_t written = write(mqtt, ping_packet, sizeof(ping_packet));
  return written == (ssize_t)sizeof(ping_packet);
}

// Sends an MQTT DISCONNECT packet (type 0xe0, zero remaining length) on
// |mqtt|. Returns true only if both bytes were written by a single write()
// call.
bool MqttDisconnect(int mqtt) {
  static const uint8_t disconnect_packet[] = {0xe0, 0x00};
  static_assert(sizeof(disconnect_packet) == 2,
                "Unexpected disconnect message size");
  const ssize_t written =
      write(mqtt, disconnect_packet, sizeof(disconnect_packet));
  return written == (ssize_t)sizeof(disconnect_packet);
}