#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "jwt.h"
#include "dev_name.h"
#include "crypto.h"

#include "esp_log.h"

// Log tag for this module; the pointer itself is const so it cannot be reassigned.
static const char *const TAG = "jwt";

char *jwt_gen(cJSON *payload)
{
    // Generate header
    char *kid = dev_name();
    if (!kid)
    {
        ESP_LOGE(TAG, "Failed to allocated memory to store device name!");
        return NULL;
    }

    cJSON *header_json = cJSON_CreateObject();
    if (!header_json)
        return NULL;
    cJSON_AddStringToObject(header_json, "alg", "ES256");
    cJSON_AddStringToObject(header_json, "typ", "JWT");
    cJSON_AddStringToObject(header_json, "kid", kid);

    char *header = cJSON_PrintUnformatted(header_json);
    free(kid);
    cJSON_Delete(header_json);
    if (!header)
    {
        ESP_LOGE(TAG, "Failed to generate JSON header!");
        return NULL;
    }

    char *header_b64 = crypto_encode_base64_safe_url(header, strlen(header));
    free(header);

    if (!header_b64)
    {
        ESP_LOGE(TAG, "Failed to encode header to base64!");
        return NULL;
    }

    // Encode body to JSON
    char *body_json = cJSON_PrintUnformatted(payload);
    if (!body_json)
    {
        ESP_LOGE(TAG, "Failed to encode body to JSON!");
        free(header_b64);
        return NULL;
    }

    char *body_b64 = crypto_encode_base64_safe_url(body_json, strlen(body_json));
    free(body_json);
    if (!body_b64)
    {
        ESP_LOGE(TAG, "Failed to encode body to base64!");
        free(header_b64);
        return NULL;
    }

    // Assemble unsigned JWT parts
    char *unsigned_jwt = calloc(1, strlen(header_b64) + strlen(body_b64) + 2);
    if (!unsigned_jwt)
    {
        ESP_LOGE(TAG, "Failed to allocate memory to store unsigned JWT!");
        free(header_b64);
        free(body_b64);
        return NULL;
    }

    sprintf(unsigned_jwt, "%s.%s", header_b64, body_b64);
    free(header_b64);
    free(body_b64);

    size_t sig_len = 0;
    char *sig = crypto_sign_sha256_payload(unsigned_jwt, strlen(unsigned_jwt), &sig_len);

    if (!sig || sig_len == 0)
    {
        ESP_LOGE(TAG, "Failed to sign JWT!");
        if (sig)
            free(sig);
        free(unsigned_jwt);
        return NULL;
    }

    char *sig_b64 = crypto_encode_base64_safe_url(sig, sig_len);
    free(sig);
    if (!sig_b64)
    {
        ESP_LOGE(TAG, "Failed to encode base64 signature to base64!");
        free(unsigned_jwt);
        return NULL;
    }

    char *jwt = calloc(1, 1 + strlen(unsigned_jwt) + 1 + strlen(sig_b64));
    if (!jwt)
    {
        ESP_LOGE(TAG, "Failed to allocate memory to store final JWT!");
        free(unsigned_jwt);
        free(sig_b64);
        return NULL;
    }
    sprintf(jwt, "%s.%s", unsigned_jwt, sig_b64);

    free(unsigned_jwt);
    free(sig_b64);

    return jwt;
}