// SPDX-License-Identifier: GPL-2.0 OR CDDL-1.0
/*
 * CDDL HEADER START
 *
 * The contents of this file are subject to the terms of the
 * Common Development and Distribution License (the "License").
 * You may not use this file except in compliance with the License.
 *
 * You can obtain a copy of the license at usr/src/OPENSOLARIS.LICENSE
 * or https://opensource.org/licenses/CDDL-1.0.
 * See the License for the specific language governing permissions
 * and limitations under the License.
 *
 * When distributing Covered Code, include this CDDL HEADER in each
 * file and include the License file at usr/src/OPENSOLARIS.LICENSE.
 * If applicable, add the following below this CDDL HEADER, with the
 * fields enclosed by brackets "[]" replaced with your own identifying
 * information: Portions Copyright [yyyy] [name of copyright owner]
 *
 * CDDL HEADER END
 */

/* Copyright (c) 2024-2025 Chilledheart  */

#include "net/doh_request.hpp"

#include "net/dns_addrinfo_helper.hpp"
#include "net/dns_message_request.hpp"
#include "net/dns_message_response_parser.hpp"
#include "net/http_parser.hpp"

namespace net {

using namespace dns_message;

// Logs the teardown and closes the request. close() returns immediately when
// the request has already been closed, so calling it here is always safe.
DoHRequest::~DoHRequest() {
  VLOG(1) << "DoH Request freed memory";

  this->close();
}

void DoHRequest::close() {
  if (closed_) {
    return;
  }
  closed_ = true;
  cb_ = nullptr;
  if (ssl_socket_) {
    ssl_socket_->Disconnect();
  } else if (socket_.is_open()) {
    asio::error_code ec;
    socket_.close(ec);
  }
}

// Starts resolving `host` for record type `dns_type` and reports the outcome
// through `cb` (see OnDoneRequest() for how the callback is delivered).
//
// - When is_localhost(host) is true no network traffic is generated: a
//   loopback addrinfo for `port` is delivered from a task posted to
//   io_context_.
// - Otherwise the DNS query is wrapped into an HTTP/1.1 POST with
//   "Content-Type: application/dns-message", stored in send_buf_, and a TCP
//   connection to endpoint_ is started. OnSocketConnect() continues from
//   there.
//
// Failures detected before the connection attempt (query cannot be built,
// socket cannot be opened or switched to non-blocking mode) invoke `cb`
// synchronously, from inside this call.
void DoHRequest::DoRequest(dns_message::DNStype dns_type, const std::string& host, int port, AsyncResolveCallback cb) {
  dns_type_ = dns_type;
  host_ = host;
  port_ = port;
  cb_ = std::move(cb);

  if (is_localhost(host_)) {
    VLOG(3) << "DoH Request: localhost host: " << host_;
    // `self` keeps the request alive until the posted task has run.
    scoped_refptr<DoHRequest> self(this);
    asio::post(io_context_, [this, self]() {
      struct addrinfo* addrinfo = addrinfo_loopback(dns_type_ == dns_message::DNS_TYPE_AAAA, port_);
      OnDoneRequest({}, addrinfo);
    });
    return;
  }

  dns_message::request msg;
  if (!msg.init(host, dns_type)) {
    OnDoneRequest(asio::error::host_unreachable, nullptr);
    return;
  }
  auto buf = gurl_base::MakeRefCounted<GrowableIOBuffer>();

  // The total payload size is needed up front for the Content-Length header.
  size_t payload_size = 0;
  for (const auto& buffer : msg.buffers()) {
    payload_size += buffer.size();
  }

  {
    std::string request_header = absl::StrFormat(
        "POST %s HTTP/1.1\r\n"
        "Host: %s:%d\r\n"
        "Accept: */*\r\n"
        "Content-Type: application/dns-message\r\n"
        "Content-Length: %d\r\n"
        "\r\n",
        doh_path_, doh_host_, doh_port_, payload_size);
    buf->appendBytesAtEnd(request_header.c_str(), request_header.size());
  }

  // The DNS message itself follows the header as the request body.
  for (const auto& buffer : msg.buffers()) {
    buf->appendBytesAtEnd(buffer.data(), buffer.size());
  }

  send_buf_ = buf;

  asio::error_code ec;
  socket_.open(endpoint_.protocol(), ec);
  if (ec) {
    OnDoneRequest(ec, nullptr);
    return;
  }
  socket_.non_blocking(true, ec);
  if (ec) {
    // The read path treats try_again/would_block as "no data yet", which only
    // makes sense on a non-blocking socket, so do not continue without it.
    asio::error_code close_ec;
    socket_.close(close_ec);
    OnDoneRequest(ec, nullptr);
    return;
  }
  // `self` keeps the request alive until the connect handler has run.
  scoped_refptr<DoHRequest> self(this);
  socket_.async_connect(endpoint_, [this, self](asio::error_code ec) {
    // Cancelled, safe to ignore
    if (UNLIKELY(ec == asio::error::bad_descriptor || ec == asio::error::operation_aborted)) {
      return;
    }
    if (ec) {
      OnDoneRequest(ec, nullptr);
      return;
    }
    VLOG(3) << "DoH Remote Server Connected: " << endpoint_;
    // tcp socket connected
    OnSocketConnect();
  });
}

void DoHRequest::OnSocketConnect() {
  scoped_refptr<DoHRequest> self(this);
  asio::error_code ec;
  SetTCPCongestion(socket_.native_handle(), ec);
  SetTCPKeepAlive(socket_.native_handle(), ec);
  SetSocketTcpNoDelay(&socket_, ec);
  ssl_socket_ = SSLSocket::Create(ssl_socket_data_index_, nullptr, &io_context_, &socket_, ssl_ctx_,
                                  /*https_fallback*/ true, doh_host_, doh_port_);

  ssl_socket_->Connect([this, self](int rv) {
    asio::error_code ec;
    if (rv < 0) {
      ec = asio::error::connection_refused;
      OnDoneRequest(ec, nullptr);
      return;
    }
    VLOG(3) << "DoH Remote SSL Server Connected: " << endpoint_;
    // ssl socket connected
    OnSSLConnect();
  });
}

void DoHRequest::OnSSLConnect() {
  scoped_refptr<DoHRequest> self(this);

  // Also queue a ConfirmHandshake. It should also be blocked on ServerHello.
  absl::AnyInvocable<void(int)> cb = [this, self](int rv) {
    if (rv < 0) {
      asio::error_code ec = asio::error::connection_refused;
      OnDoneRequest(ec, nullptr);
    }
  };
  ssl_socket_->ConfirmHandshake(std::move(cb));

  if (!cb_) {
    return;
  }

  recv_buf_ = gurl_base::MakeRefCounted<GrowableIOBuffer>();
  ssl_socket_->WaitWrite([this, self](asio::error_code ec) { OnSSLWritable(ec); });
  ssl_socket_->WaitRead([this, self](asio::error_code ec) { OnSSLReadable(ec); });
}

// Write-readiness handler: pushes as much of send_buf_ as the TLS socket
// accepts and re-arms itself until the whole request has been written.
//
// `ec` is the result of the wait itself; any error there fails the request.
void DoHRequest::OnSSLWritable(asio::error_code ec) {
  if (ec) {
    OnDoneRequest(ec, nullptr);
    return;
  }
  size_t written = ssl_socket_->Write(send_buf_.get(), ec);
  if (ec == asio::error::try_again || ec == asio::error::would_block) {
    // The socket could not take the data right now. This mirrors the read
    // path, which also treats these two codes as "not ready" rather than as a
    // failure: leave send_buf_ untouched and wait for the next chance.
    scoped_refptr<DoHRequest> self(this);
    ssl_socket_->WaitWrite([this, self](asio::error_code ec) { OnSSLWritable(ec); });
    return;
  }
  if (ec) {
    OnDoneRequest(ec, nullptr);
    return;
  }
  // Advance past the bytes that were accepted.
  send_buf_->set_offset(send_buf_->offset() + written);
  VLOG(3) << "DoH Request Sent: " << written << " bytes Remaining: " << send_buf_->RemainingCapacity() << " bytes";
  if (UNLIKELY(!send_buf_->empty())) {
    // Partial write: wait until the socket is writable again.
    scoped_refptr<DoHRequest> self(this);
    ssl_socket_->WaitWrite([this, self](asio::error_code ec) { OnSSLWritable(ec); });
    return;
  }
  VLOG(3) << "DoH Request Fully Sent";
}

// Read-readiness handler: reads one chunk from the TLS socket, appends it to
// recv_buf_ and hands control to the parser for the current read state.
//
// `ec` is the result of the wait itself; any error there fails the request.
void DoHRequest::OnSSLReadable(asio::error_code ec) {
  if (UNLIKELY(ec)) {
    OnDoneRequest(ec, nullptr);
    return;
  }
  size_t read = 0;
  auto buf = gurl_base::MakeRefCounted<GrowableIOBuffer>();
  buf->SetCapacity(UINT16_MAX);
  // Retry for as long as the read is merely interrupted. (A `continue` inside
  // `do { } while (false)` jumps to the false condition and never retries.)
  do {
    ec = asio::error_code();
    read = ssl_socket_->Read(buf.get(), ec);
  } while (ec == asio::error::interrupted);

  const bool not_ready = (ec == asio::error::try_again || ec == asio::error::would_block);
  if (UNLIKELY(ec && !not_ready)) {
    OnDoneRequest(ec, nullptr);
    return;
  }
  if (not_ready && read == 0) {
    // The wake-up delivered no bytes, so recv_buf_ is unchanged and there is
    // nothing new to parse: keep waiting for data.
    scoped_refptr<DoHRequest> self(this);
    ssl_socket_->WaitRead([this, self](asio::error_code ec) { OnSSLReadable(ec); });
    return;
  }
  // append buf to the end of recv_buf
  int previous_capacity = recv_buf_->capacity();
  recv_buf_->SetCapacity(previous_capacity + read);
  memcpy(recv_buf_->StartOfBuffer() + previous_capacity, buf->data(), read);

  VLOG(3) << "DoH Response Received: " << read << " bytes";

  switch (read_state_) {
    case Read_Header:
      OnReadHeader();
      break;
    case Read_Body:
      OnReadBody();
      break;
  }
}

// Parses the HTTP response header from the data received so far and validates
// it: the status must be 200, the content type must be
// application/dns-message and the content length must be non-zero and below
// UINT16_MAX. On success the parsed header bytes are skipped in recv_buf_, the
// state switches to Read_Body and OnReadBody() is entered; every rejection
// fails the request with operation_not_supported.
//
// NOTE(review): a fresh parser is run over everything in recv_buf_; whether
// Parse() reports `ok` for a header that is still incomplete is not visible
// from here - confirm against HttpResponseParser.
void DoHRequest::OnReadHeader() {
  DCHECK_EQ(read_state_, Read_Header);
  HttpResponseParser parser;

  // Start from "not ok" so the check below never reads an indeterminate value.
  bool ok = false;
  const int nparsed = parser.Parse(recv_buf_->span(), &ok);
  // Only a positive count is a usable length for the logged view.
  if (nparsed > 0) {
    VLOG(3) << "Connection (doh resolver) "
            << " http: " << std::string_view(reinterpret_cast<const char*>(recv_buf_->data()), nparsed);
  }
  if (!ok) {
    LOG(WARNING) << "DoH Response Invalid HTTP Response";
    OnDoneRequest(asio::error::operation_not_supported, nullptr);
    return;
  }

  VLOG(3) << "DoH Response Header Parsed: " << nparsed << " bytes";
  // Skip the header so that data() points at the start of the body.
  recv_buf_->set_offset(recv_buf_->offset() + nparsed);

  if (UNLIKELY(parser.status_code() != 200)) {
    LOG(WARNING) << "DoH Response Unexpected HTTP Response Status Code: " << parser.status_code();
    OnDoneRequest(asio::error::operation_not_supported, nullptr);
    return;
  }

  if (UNLIKELY(parser.content_type() != "application/dns-message")) {
    LOG(WARNING) << "DoH Response Expected Type: application/dns-message but received: " << parser.content_type();
    OnDoneRequest(asio::error::operation_not_supported, nullptr);
    return;
  }

  if (UNLIKELY(parser.content_length() == 0)) {
    LOG(WARNING) << "DoH Response Missing Content Length";
    OnDoneRequest(asio::error::operation_not_supported, nullptr);
    return;
  }

  if (UNLIKELY(parser.content_length() >= UINT16_MAX)) {
    LOG(WARNING) << "DoH Response Too Large: " << parser.content_length() << " bytes";
    OnDoneRequest(asio::error::operation_not_supported, nullptr);
    return;
  }

  read_state_ = Read_Body;
  body_length_ = parser.content_length();

  OnReadBody();
}

void DoHRequest::OnReadBody() {
  DCHECK_EQ(read_state_, Read_Body);
  if (UNLIKELY(recv_buf_->RemainingCapacity() < body_length_)) {
    VLOG(3) << "DoH Response Expected Data: " << body_length_ << " bytes Current: " << recv_buf_->size() << " bytes";

    scoped_refptr<DoHRequest> self(this);
    ssl_socket_->WaitRead([this, self](asio::error_code ec) { OnSSLReadable(ec); });
    return;
  }

  OnParseDnsResponse();
}

void DoHRequest::OnParseDnsResponse() {
  DCHECK_EQ(read_state_, Read_Body);
  DCHECK_GE(recv_buf_->RemainingCapacity(), body_length_);

  dns_message::response_parser response_parser;
  dns_message::response response;

  dns_message::response_parser::result_type result;
  std::tie(result, std::ignore) =
      response_parser.parse(response, recv_buf_->data(), recv_buf_->data(), recv_buf_->data() + body_length_);
  if (result != dns_message::response_parser::good) {
    LOG(WARNING) << "DoH Response Bad Format";
    OnDoneRequest(asio::error::operation_not_supported, {});
    return;
  }
  VLOG(3) << "DoH Response Body Parsed: " << body_length_ << " bytes";
  recv_buf_->set_offset(recv_buf_->offset() + body_length_);

  struct addrinfo* addrinfo = addrinfo_dup(dns_type_ == dns_message::DNS_TYPE_AAAA, response, port_);

  OnDoneRequest({}, addrinfo);
}

// Completes the request exactly once.
//
// If a callback is still pending it is invoked with `ec` and `addrinfo`, and
// receives the `addrinfo` pointer as-is. If no callback is pending (the
// request was closed or has already completed) the result is released here
// instead.
void DoHRequest::OnDoneRequest(asio::error_code ec, struct addrinfo* addrinfo) {
  // Take the callback out of the member and clear the member explicitly: the
  // state of a moved-from callable is not something to rely on, and other
  // code tests `cb_` to find out whether the request is still pending.
  // Clearing before the call also leaves a callback installed from inside
  // `cb` untouched.
  auto cb = std::move(cb_);
  cb_ = nullptr;

  if (cb) {
    cb(ec, addrinfo);
    return;
  }
  // Nobody is interested any more; free the result if there is one.
  if (addrinfo != nullptr) {
    addrinfo_freedup(addrinfo);
  }
}

}  // namespace net
