// SPDX-License-Identifier: GPL-2.0 OR CDDL-1.0
/*
 * CDDL HEADER START
 *
 * The contents of this file are subject to the terms of the
 * Common Development and Distribution License (the "License").
 * You may not use this file except in compliance with the License.
 *
 * You can obtain a copy of the license at usr/src/OPENSOLARIS.LICENSE
 * or https://opensource.org/licenses/CDDL-1.0.
 * See the License for the specific language governing permissions
 * and limitations under the License.
 *
 * When distributing Covered Code, include this CDDL HEADER in each
 * file and include the License file at usr/src/OPENSOLARIS.LICENSE.
 * If applicable, add the following below this CDDL HEADER, with the
 * fields enclosed by brackets "[]" replaced with your own identifying
 * information: Portions Copyright [yyyy] [name of copyright owner]
 *
 * CDDL HEADER END
 */

/* Copyright (c) 2019-2025 Chilledheart  */

#include "net/cipher.hpp"

#include <memory>

#include <base/rand_util.h>
#include "third_party/boringssl/src/include/openssl/base64.h"
#include "third_party/boringssl/src/include/openssl/md5.h"

#include "core/logging.hpp"
#include "crypto/decrypter.hpp"
#include "crypto/encrypter.hpp"
#include "net/hkdf_sha1.hpp"

// Size in bytes of the length prefix that precedes every AEAD payload chunk.
#define CHUNK_SIZE_LEN 2U
// Mask applied to the 16-bit length prefix: only the low 14 bits carry the
// payload length, so one chunk holds at most 0x3FFF bytes of plaintext.
#define CHUNK_SIZE_MASK 0x3FFFU

namespace net {

class cipher_impl {
 public:
  cipher_impl(enum cipher_method method, bool enc) {
    DCHECK_GT(method, CRYPTO_INVALID);
    if (enc) {
      encrypter = crypto::Encrypter::CreateFromCipherSuite(method);
    } else {
      decrypter = crypto::Decrypter::CreateFromCipherSuite(method);
    }
  }

  static int parse_key(const std::string& key, uint8_t* skey, size_t skey_len) {
    const char* base64 = key.c_str();
    const size_t base64_len = key.size();
    size_t out_len;

    if (!EVP_DecodedLength(&out_len, base64_len)) {
      LOG(WARNING) << "Invalid base64 len: " << base64_len;
      return 0;
    }
    std::unique_ptr<uint8_t[]> out = std::make_unique<uint8_t[]>(out_len);

    out_len = EVP_DecodeBase64(out.get(), &out_len, out_len, reinterpret_cast<const uint8_t*>(base64), base64_len);
    if (out_len > 0 && out_len >= skey_len) {
      memcpy(skey, out.get(), skey_len);
      return skey_len;
    }
    return 0;
  }

  /// The master key can be input directly from user or generated from a
  /// password. The key derivation is still following EVP_BytesToKey(3) in
  /// OpenSSL.
  static int derive_key(const std::string& key, uint8_t* skey, size_t skey_len) {
    const char* pass = key.c_str();
    size_t datal = key.size();
    MD5_CTX c;
    uint8_t md_buf[MD5_DIGEST_LENGTH];
    int addmd;
    unsigned int i, j, mds;

    if (key.empty()) {
      return skey_len;
    }

    mds = 16;
    MD5_Init(&c);

    for (j = 0, addmd = 0; j < skey_len; addmd++) {
      MD5_Init(&c);
      if (addmd) {
        MD5_Update(&c, md_buf, mds);
      }
      MD5_Update(&c, reinterpret_cast<const uint8_t*>(pass), datal);
      MD5_Final(md_buf, &c);

      for (i = 0; i < mds; i++, j++) {
        if (j >= skey_len)
          break;
        skey[j] = md_buf[i];
      }
    }

    return skey_len;
  }

  /// EncryptPacket is a function that takes a secret key,
  /// a non-secret nonce, a message, and produces ciphertext and authentication
  /// tag. Nonce (NoncePrefix + packet_number, or vice versa) must be unique for
  /// a given key in each invocation.
  int EncryptPacket(uint64_t packet_number, uint8_t* c, size_t* clen, const uint8_t* m, size_t mlen) {
    int err = 0;

    if (!encrypter->EncryptPacket(packet_number, nullptr, 0U, reinterpret_cast<const char*>(m), mlen,
                                  reinterpret_cast<char*>(c), clen, *clen)) {
      err = -1;
    }

    return err;
  }

  /// DecryptPacket is a function that takes a secret key,
  /// non-secret nonce, ciphertext, authentication tag, and produces original
  /// message. If any of the input is tampered with, decryption will fail.
  int DecryptPacket(uint64_t packet_number, uint8_t* p, size_t* plen, const uint8_t* m, size_t mlen) {
    int err = 0;

    if (!decrypter->DecryptPacket(packet_number, nullptr, 0U, reinterpret_cast<const char*>(m), mlen,
                                  reinterpret_cast<char*>(p), plen, *plen)) {
      err = -1;
    }

    return err;
  }

  bool SetKey(const uint8_t* key, size_t key_len) {
    if (encrypter) {
      return encrypter->SetKey(reinterpret_cast<const char*>(key), key_len);
    }
    if (decrypter) {
      return decrypter->SetKey(reinterpret_cast<const char*>(key), key_len);
    }
    return false;
  }

  bool SetNoncePrefix(const uint8_t* nonce_prefix, size_t nonce_prefix_len) {
    if (encrypter) {
      return encrypter->SetNoncePrefix(reinterpret_cast<const char*>(nonce_prefix),
                                       std::min(nonce_prefix_len, encrypter->GetNoncePrefixSize()));
    }
    if (decrypter) {
      return decrypter->SetNoncePrefix(reinterpret_cast<const char*>(nonce_prefix),
                                       std::min(nonce_prefix_len, decrypter->GetNoncePrefixSize()));
    }
    return false;
  }

  bool SetIV(const uint8_t* iv, size_t iv_len) {
    if (encrypter) {
      return encrypter->SetIV(reinterpret_cast<const char*>(iv), iv_len);
    }
    if (decrypter) {
      return decrypter->SetIV(reinterpret_cast<const char*>(iv), iv_len);
    }
    return false;
  }

  size_t GetKeySize() const { return encrypter ? encrypter->GetKeySize() : decrypter->GetKeySize(); }

  size_t GetNoncePrefixSize() const {
    return encrypter ? encrypter->GetNoncePrefixSize() : decrypter->GetNoncePrefixSize();
  }

  size_t GetIVSize() const { return encrypter ? encrypter->GetIVSize() : decrypter->GetIVSize(); }

  size_t GetTagSize() const { return encrypter ? encrypter->GetTagSize() : decrypter->GetTagSize(); }

  const uint8_t* GetKey() const { return encrypter ? encrypter->GetKey() : decrypter->GetKey(); }

  const uint8_t* GetIV() const { return encrypter ? encrypter->GetIV() : decrypter->GetIV(); }

  const uint8_t* GetNoncePrefix() const {
    return encrypter ? encrypter->GetNoncePrefix() : decrypter->GetNoncePrefix();
  }

  uint32_t cipher_id() const {
    if (encrypter) {
      return encrypter->cipher_id();
    }
    if (decrypter) {
      return decrypter->cipher_id();
    }
    return CRYPTO_INVALID;
  }

  std::unique_ptr<crypto::Encrypter> encrypter;
  std::unique_ptr<crypto::Decrypter> decrypter;
};

/// Builds a cipher for one direction of a connection.
///
/// The master key comes from |key| (base64) when it is non-empty, otherwise it
/// is derived from |password|. |enc| selects the encrypting or decrypting
/// backend; |visitor| receives decrypted data and protocol errors.
///
/// NOTE(review): the VLOG(3) line below prints the key and password in clear
/// text -- confirm this is acceptable for verbose builds.
cipher::cipher(const std::string& key,
               const std::string& password,
               enum cipher_method method,
               cipher_visitor_interface* visitor,
               bool enc)
    : salt_(), key_(), counter_(), init_(false), visitor_(visitor) {
  DCHECK(is_valid_cipher_method(method));
  VLOG(3) << "cipher: " << (enc ? "encoder" : "decoder") << " create with key \"" << key << "\" password \"" << password
          << "\" cipher_method: " << to_cipher_method_str(method);

  impl_ = std::make_unique<cipher_impl>(method, enc);
  key_bitlen_ = impl_->GetKeySize() * 8;

  // An explicit key wins over the password.
  if (key.empty()) {
    key_len_ = cipher_impl::derive_key(password, key_, key_bitlen_ / 8);
  } else {
    key_len_ = cipher_impl::parse_key(key, key_, key_bitlen_ / 8);
  }

  DumpHex("cipher: KEY", key_, key_len_);

  tag_len_ = impl_->GetTagSize();
}

// Defaulted destructor. Presumably defined out-of-line so that cipher_impl is
// a complete type where impl_ is destroyed -- confirm against the header.
cipher::~cipher() = default;

/// Feeds received ciphertext into the decoder.
///
/// Unconsumed bytes from a previous call are joined with |ciphertext| into a
/// fresh buffer. The first key_len_ bytes of the stream are consumed as the
/// salt/nonce, then frames are decrypted one by one and delivered through
/// visitor_->on_received_data() until the data runs out, a frame is
/// incomplete (-EAGAIN), a decrypt error occurs (reported through
/// visitor_->on_protocol_error()), or the visitor returns false.
void cipher::process_bytes(GrowableIOBuffer* ciphertext) {
  if (!chunk_) {
    chunk_ = ciphertext;
  } else {
    // Concatenate the leftover bytes with the new input.
    auto leftover = chunk_;
    const int merged_size = leftover->RemainingCapacity() + ciphertext->size();
    chunk_ = gurl_base::MakeRefCounted<GrowableIOBuffer>();
    chunk_->SetCapacity(merged_size);
    memcpy(chunk_->data(), leftover->data(), leftover->RemainingCapacity());
    memcpy(chunk_->data() + leftover->RemainingCapacity(), ciphertext->data(), ciphertext->RemainingCapacity());
  }

  if (!init_) {
    // Wait until the whole salt has arrived.
    if (chunk_->size() < key_len_) {
      return;
    }
    decrypt_salt(chunk_.get());

    init_ = true;
  }

  while (!chunk_->empty()) {
    auto decoded = gurl_base::MakeRefCounted<GrowableIOBuffer>();
    decoded->SetCapacity(SOCKET_BUF_SIZE);

    // Work on a copy so counter_ only advances once a whole frame succeeded.
    uint64_t next_counter = counter_;

    const int rc = chunk_decrypt_frame(&next_counter, decoded.get(), chunk_.get());

    if (rc < 0) {
      // -EAGAIN just means "need more bytes"; anything else is fatal.
      if (rc != -EAGAIN) {
        visitor_->on_protocol_error();
      }
      break;
    }

    counter_ = next_counter;

    // ready to deliver plaintext
    decoded->SetCapacity(decoded->offset());
    decoded->set_offset(0);

    // DISCARD
    const bool keep_going = visitor_->on_received_data(decoded.get());
    if (!keep_going) {
      break;
    }
  }
}

void cipher::encrypt(const uint8_t* plaintext_data, size_t plaintext_size, GrowableIOBuffer* ciphertext) {
  DCHECK(ciphertext);

  if (!init_) {
    encrypt_salt(ciphertext);
    init_ = true;
  }

  uint64_t counter = counter_;

  // TBD better to apply MTU-like things
  int ret = chunk_encrypt_frame(&counter, plaintext_data, plaintext_size, ciphertext);
  if (ret < 0) {
    visitor_->on_protocol_error();
    return;
  }

  counter_ = counter;
}

/// Consumes the leading salt (AEAD ciphers) or IV (stream ciphers, mbedTLS
/// builds only) from |chunk| and keys the decrypter with it.
///
/// Must be called exactly once, before init_ is set, on a buffer whose read
/// offset is still 0. Advances the offset past the consumed bytes.
void cipher::decrypt_salt(GrowableIOBuffer* chunk) {
  DCHECK(!init_);
  DCHECK_EQ(chunk->offset(), 0);

#ifdef HAVE_MBEDTLS
  const auto method_id = impl_->cipher_id();
  if (method_id >= CRYPTO_AES_128_CFB && method_id <= CRYPTO_CAMELLIA_256_CFB) {
    // Stream ciphers: the stream starts with an IV-sized nonce.
    const size_t iv_bytes = impl_->GetIVSize();
    VLOG(4) << "decrypt: nonce: " << iv_bytes;
    uint8_t iv[MAX_NONCE_LENGTH] = {};
    memcpy(iv, chunk->data(), iv_bytes);
    chunk->set_offset(chunk->offset() + iv_bytes);
    set_key_stream(iv, iv_bytes);
    DumpHex("DE-NONCE", iv, iv_bytes);
    return;
  }
#endif

  // AEAD ciphers: the salt has the same length as the key.
  const size_t salt_bytes = key_len_;
  VLOG(4) << "decrypt: salt: " << salt_bytes;

  memcpy(salt_, chunk->data(), salt_bytes);
  chunk->set_offset(chunk->offset() + salt_bytes);
  set_key_aead(salt_, salt_bytes);

  DumpHex("DE-SALT", salt_, salt_bytes);
}

/// Generates a random salt (AEAD ciphers) or IV (stream ciphers, mbedTLS
/// builds only), appends it to |chunk| and keys the encrypter with it.
///
/// Must be called exactly once, before init_ is set, on a buffer whose offset
/// is 0. The buffer is grown by the salt/IV length and the new bytes are
/// written at the old end of the buffer, so existing content is preserved.
void cipher::encrypt_salt(GrowableIOBuffer* chunk) {
  DCHECK(!init_);
  DCHECK_EQ(chunk->offset(), 0);

#ifdef HAVE_MBEDTLS
  if (impl_->cipher_id() >= CRYPTO_AES_128_CFB && impl_->cipher_id() <= CRYPTO_CAMELLIA_256_CFB) {
    const size_t nonce_len = impl_->GetIVSize();
    VLOG(4) << "encrypt: nonce: " << nonce_len;
    uint8_t nonce[MAX_NONCE_LENGTH] = {};
    gurl_base::RandBytes(nonce, nonce_len);
    const auto previous_capacity = chunk->capacity();
    chunk->SetCapacity(previous_capacity + nonce_len);
    // Append after any bytes already in the buffer (same as the AEAD path
    // below) instead of writing at the start, which would clobber them.
    memcpy(chunk->StartOfBuffer() + previous_capacity, nonce, nonce_len);
    set_key_stream(nonce, nonce_len);
    DumpHex("EN-NONCE", nonce, nonce_len);
    return;
  }
#endif

  // AEAD ciphers: the salt has the same length as the key.
  const size_t salt_len = key_len_;
  VLOG(4) << "encrypt: salt: " << salt_len;
  gurl_base::RandBytes(salt_, key_len_);
  const auto previous_capacity = chunk->capacity();
  chunk->SetCapacity(previous_capacity + salt_len);
  memcpy(chunk->StartOfBuffer() + previous_capacity, salt_, salt_len);
  set_key_aead(salt_, salt_len);

  DumpHex("EN-SALT", salt_, salt_len);
}

/// Decrypts the next frame from |ciphertext| into |plaintext|.
///
/// Dispatches to the stream implementation for the CFB cipher range (only in
/// mbedTLS builds) and to the AEAD implementation otherwise. Returns whatever
/// the selected implementation returns.
int cipher::chunk_decrypt_frame(uint64_t* counter, GrowableIOBuffer* plaintext, GrowableIOBuffer* ciphertext) const {
#ifdef HAVE_MBEDTLS
  const auto method_id = impl_->cipher_id();
  const bool is_stream_cipher = method_id >= CRYPTO_AES_128_CFB && method_id <= CRYPTO_CAMELLIA_256_CFB;
  if (is_stream_cipher) {
    return chunk_decrypt_frame_stream(counter, plaintext, ciphertext);
  }
#endif
  return chunk_decrypt_frame_aead(counter, plaintext, ciphertext);
}

/// Decrypts one AEAD frame: [sealed 2-byte length][sealed payload].
///
/// Each sealed part carries its own authentication tag and uses its own
/// counter value, so a complete frame advances |*counter| by two.
///
/// Returns 0 on success, -EAGAIN when |ciphertext| does not yet hold a whole
/// frame (its offset is restored to the start of the frame), and -EBADMSG
/// when authentication fails or the decoded length is zero.
int cipher::chunk_decrypt_frame_aead(uint64_t* counter,
                                     GrowableIOBuffer* plaintext,
                                     GrowableIOBuffer* ciphertext) const {
  const int tag_bytes = tag_len_;
  int sealed_len = CHUNK_SIZE_LEN + tag_bytes;

  VLOG(4) << "decrypt: 1st chunk: origin: " << CHUNK_SIZE_LEN << " encrypted: " << sealed_len
          << " actual: " << ciphertext->size();

  // Need at least the sealed length header plus the tag of the payload part.
  if (ciphertext->size() < tag_bytes + static_cast<int>(CHUNK_SIZE_LEN) + tag_bytes) {
    return -EAGAIN;
  }

  union {
    uint8_t buf[2];
    uint16_t cover;
  } header;

  static_assert(sizeof(header) == CHUNK_SIZE_LEN, "Chunk Size not matched");

  size_t opened_len = sizeof(header.cover);

  if (impl_->DecryptPacket(*counter, header.buf, &opened_len, ciphertext->bytes(), sealed_len)) {
    return -EBADMSG;
  }

  DCHECK_EQ(opened_len, CHUNK_SIZE_LEN);

  // The length travels in network byte order; only the masked bits count.
  const int payload_len = ntohs(header.cover) & CHUNK_SIZE_MASK;

  if (payload_len == 0) {
    return -EBADMSG;
  }

  // Puts the read offset back onto the length header so the frame can be
  // retried from its beginning.
  const auto rewind_header = [&]() {
    ciphertext->set_offset(ciphertext->offset() - CHUNK_SIZE_LEN - tag_bytes);
  };

  ciphertext->set_offset(ciphertext->offset() + sealed_len);
  plaintext->SetCapacity(plaintext->offset() + payload_len);

  sealed_len = tag_bytes + payload_len;

  VLOG(4) << "decrypt: 2nd chunk: origin: " << payload_len << " encrypted: " << sealed_len
          << " actual: " << ciphertext->size();

  if (ciphertext->size() < sealed_len) {
    rewind_header();
    return -EAGAIN;
  }

  ++(*counter);

  opened_len = plaintext->capacity();
  if (impl_->DecryptPacket(*counter, plaintext->bytes(), &opened_len, ciphertext->bytes(), sealed_len)) {
    rewind_header();
    return -EBADMSG;
  }

  DCHECK_EQ(static_cast<int>(opened_len), payload_len);

  ++(*counter);

  ciphertext->set_offset(ciphertext->offset() + sealed_len);
  plaintext->set_offset(plaintext->offset() + opened_len);

  return 0;
}

/// Decrypts everything currently readable in |ciphertext| as one stream
/// chunk and appends it to |plaintext|.
///
/// Returns 0 on success (both offsets advanced, |*counter| incremented) and
/// -EBADMSG when the backend reports a failure.
int cipher::chunk_decrypt_frame_stream(uint64_t* counter,
                                       GrowableIOBuffer* plaintext,
                                       GrowableIOBuffer* ciphertext) const {
  const auto input_len = ciphertext->size();
  size_t produced = input_len;

  // Make room for as many output bytes as there are input bytes.
  plaintext->SetCapacity(plaintext->offset() + input_len);

  VLOG(4) << "decrypt: stream chunk: " << produced << " bytes";

  if (impl_->DecryptPacket(*counter, plaintext->bytes(), &produced, ciphertext->bytes(), input_len)) {
    return -EBADMSG;
  }

  plaintext->set_offset(plaintext->offset() + produced);
  ciphertext->set_offset(ciphertext->offset() + input_len);
  ++(*counter);
  return 0;
}

/// Encrypts one frame of plaintext and appends it to |ciphertext|.
///
/// Dispatches to the stream implementation for the CFB cipher range (only in
/// mbedTLS builds) and to the AEAD implementation otherwise. Returns whatever
/// the selected implementation returns.
int cipher::chunk_encrypt_frame(uint64_t* counter,
                                const uint8_t* plaintext_data,
                                size_t plaintext_size,
                                GrowableIOBuffer* ciphertext) const {
#ifdef HAVE_MBEDTLS
  const auto method_id = impl_->cipher_id();
  const bool is_stream_cipher = method_id >= CRYPTO_AES_128_CFB && method_id <= CRYPTO_CAMELLIA_256_CFB;
  if (is_stream_cipher) {
    return chunk_encrypt_frame_stream(counter, plaintext_data, plaintext_size, ciphertext);
  }
#endif
  return chunk_encrypt_frame_aead(counter, plaintext_data, plaintext_size, ciphertext);
}

int cipher::chunk_encrypt_frame_aead(uint64_t* counter,
                                     const uint8_t* plaintext_data,
                                     size_t plaintext_size,
                                     GrowableIOBuffer* ciphertext) const {
  const int tlen = tag_len_;
  const int c_total_len = 2 * tlen + CHUNK_SIZE_LEN + plaintext_size;
  DCHECK_LE(plaintext_size, CHUNK_SIZE_MASK);

  int err;
  int previous_capacity = ciphertext->capacity();
  size_t clen = CHUNK_SIZE_LEN + tlen;
  int headroom = ciphertext->RemainingCapacity();

  ciphertext->SetCapacity(previous_capacity + c_total_len);

  union {
    uint8_t buf[2];
    uint16_t cover;
  } len;

  static_assert(sizeof(len) == CHUNK_SIZE_LEN, "Chunk Size not matched");

  len.cover = htons(plaintext_size & CHUNK_SIZE_MASK);

  VLOG(4) << "encrypt: 1st chunk: origin: " << CHUNK_SIZE_LEN << " encrypted: " << clen;

  DCHECK_GE(c_total_len, static_cast<int>(clen));

  err = impl_->EncryptPacket(*counter, ciphertext->bytes() + headroom, &clen, len.buf, CHUNK_SIZE_LEN);
  if (err) {
    ciphertext->SetCapacity(previous_capacity);
    return -EBADMSG;
  }

  DCHECK_EQ(clen, CHUNK_SIZE_LEN + tlen);
  headroom += clen;

  (*counter)++;

  clen = plaintext_size + tlen;

  VLOG(4) << "encrypt: 2nd chunk: origin: " << plaintext_size << " encrypted: " << clen;

  DCHECK_GE(ciphertext->RemainingCapacity(), static_cast<int>(clen));
  // FIXME it is a bug with crypto layer
  memset(ciphertext->bytes() + headroom, 0, clen);

  err = impl_->EncryptPacket(*counter, ciphertext->bytes() + headroom, &clen, plaintext_data, plaintext_size);
  if (err) {
    ciphertext->SetCapacity(previous_capacity);
    return -EBADMSG;
  }

  headroom += clen;
  DCHECK_EQ(clen, plaintext_size + tlen);

  DCHECK_EQ(ciphertext->offset() + headroom, ciphertext->capacity());

  (*counter)++;

  return 0;
}

/// Encrypts |plaintext_size| bytes as one stream chunk and appends the result
/// to |ciphertext|.
///
/// Returns 0 on success (|*counter| incremented) and -EBADMSG when the
/// backend fails, in which case the buffer capacity is restored.
int cipher::chunk_encrypt_frame_stream(uint64_t* counter,
                                       const uint8_t* plaintext_data,
                                       size_t plaintext_size,
                                       GrowableIOBuffer* ciphertext) const {
  const int base_capacity = ciphertext->capacity();
  // Where the new bytes start, relative to bytes().
  const int write_pos = base_capacity - ciphertext->offset();
  size_t out_len = plaintext_size;

  ciphertext->SetCapacity(base_capacity + out_len);

  VLOG(4) << "encrypt: stream chunk: origin: " << plaintext_size << " actual: " << out_len;

  if (impl_->EncryptPacket(*counter, ciphertext->bytes() + write_pos, &out_len, plaintext_data, plaintext_size)) {
    ciphertext->SetCapacity(base_capacity);
    return -EBADMSG;
  }
  DCHECK_EQ(ciphertext->capacity(), base_capacity + static_cast<int>(out_len));
  // nop: ciphertext->SetCapacity(previous_capacity + clen);
  ++(*counter);

  return 0;
}

void cipher::set_key_stream(const uint8_t* nonce, int nonce_len) {
  counter_ = 0;

  if (!impl_->SetIV(nonce, nonce_len)) {
    LOG(WARNING) << "SetIV Failed";
  }

  if (!impl_->SetKey(key_, key_len_)) {
    LOG(WARNING) << "SetKey Failed";
  }

  DumpHex("KEY", impl_->GetKey(), impl_->GetKeySize());
  DumpHex("IV", impl_->GetIV(), impl_->GetIVSize());
}

void cipher::set_key_aead(const uint8_t* salt, int salt_len) {
  DCHECK_EQ(salt_len, key_len_);
  uint8_t skey[MAX_KEY_LENGTH] = {};
  int err = crypto_hkdf(salt, salt_len, key_, key_len_, reinterpret_cast<const uint8_t*>(SUBKEY_INFO),
                        sizeof(SUBKEY_INFO) - 1, skey, key_len_);
  if (err) {
    LOG(FATAL) << "Unable to generate subkey";
  }

  counter_ = 0;
  uint8_t nonce[MAX_NONCE_LENGTH] = {};

  if (!impl_->SetKey(skey, key_len_)) {
    LOG(WARNING) << "SetKey Failed";
  }

  if (!impl_->SetNoncePrefix(nonce, impl_->GetNoncePrefixSize())) {
    LOG(WARNING) << "SetNoncePrefix Failed";
  }

  DumpHex("SKEY", impl_->GetKey(), impl_->GetKeySize());
  DumpHex("NONCE_PREFIX", impl_->GetNoncePrefix(), impl_->GetNoncePrefixSize());
}

}  // namespace net
