/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <thrift/compiler/whisker/object.h>

#include <thrift/compiler/whisker/detail/overload.h>

#include <cassert>
#include <ostream>
#include <sstream>
#include <type_traits>

#include <fmt/core.h>
#include <fmt/ranges.h>

#include <boost/core/demangle.hpp>

namespace whisker {

namespace {

// Returns a human-readable spelling of `type`. std::type_info::name() yields
// an implementation-defined (often mangled) string, which boost's demangler
// converts back into source-like form.
std::string demangle(const std::type_info& type) {
  const char* const raw_name = type.name();
  return boost::core::demangle(raw_name);
}

// Renders a whisker::object as a tree of lines written through a
// tree_printer::scope.
//
// Scalar alternatives (i64, f64, string, bool, null) are printed on a single
// line. Arrays, maps, native functions, native handles, and prototypes reached
// through native handles are routed through visit_maybe_truncate(). That
// function prints "..." in place of the value once the scope's semantic depth
// equals opts_.max_depth.
//
// NOTE: opts_ is stored by reference, so a visitor must not outlive the
// object_print_options it was constructed with.
class to_string_visitor {
 public:
  explicit to_string_visitor(const object_print_options& opts) : opts_(opts) {}

  // Entry point: dispatches on the object's active alternative.
  //
  // Prevent implicit conversion to whisker::object. Otherwise, we can silently
  // compile an infinitely recursive visit() chain if there is a missing
  // overload for one of the alternatives in the variant.
  template <
      typename T = object,
      typename = std::enable_if_t<std::is_same_v<T, object>>>
  void visit(const T& value, tree_printer::scope scope) const {
    value.visit(
        [&](const array::ptr& a) { visit_maybe_truncate(a, std::move(scope)); },
        [&](const map::ptr& m) { visit_maybe_truncate(m, std::move(scope)); },
        [&](const native_function::ptr& f) {
          visit_maybe_truncate(f, std::move(scope));
        },
        [&](const native_handle<>& h) {
          visit_maybe_truncate(h, std::move(scope));
        },
        // All other types are printed inline so no truncation is necessary.
        [&](auto&& alternative) { visit(alternative, std::move(scope)); });
  }

 private:
  // Prints "..." instead of `value` when `scope` is already at the configured
  // depth limit; otherwise forwards to the matching visit() overload.
  template <typename T>
  void visit_maybe_truncate(const T& value, tree_printer::scope scope) const {
    if (at_max_depth(scope)) {
      scope.println("...");
      return;
    }
    visit(value, scope);
  }

  // Scalar overloads: each prints exactly one line.

  void visit(i64 value, tree_printer::scope scope) const {
    require_within_max_depth(scope);
    scope.println("i64({})", value);
  }

  void visit(f64 value, tree_printer::scope scope) const {
    require_within_max_depth(scope);
    scope.println("f64({})", value);
  }

  // Strings are escaped via tree_printer::escape and wrapped in single quotes.
  void visit(const std::string& value, tree_printer::scope scope) const {
    require_within_max_depth(scope);
    scope.println("'{}'", tree_printer::escape(value));
  }

  void visit(bool value, tree_printer::scope scope) const {
    require_within_max_depth(scope);
    scope.println("{}", value ? "true" : "false");
  }

  void visit(null, tree_printer::scope scope) const {
    require_within_max_depth(scope);
    scope.println("null");
  }

  // Containers and native functions delegate to their own (possibly
  // overridden) print_to() implementations.

  void visit(const map::ptr& m, tree_printer::scope scope) const {
    require_within_max_depth(scope);
    m->print_to(std::move(scope), opts_);
  }

  void visit(const array::ptr& a, tree_printer::scope scope) const {
    require_within_max_depth(scope);
    a->print_to(std::move(scope), opts_);
  }

  void visit(const native_function::ptr& f, tree_printer::scope scope) const {
    require_within_max_depth(scope);
    f->print_to(std::move(scope), opts_);
  }

  // Prints the handle's demangled C++ type. If the handle has a prototype,
  // prints it as a child node, subject to depth truncation.
  void visit(const native_handle<>& handle, tree_printer::scope scope) const {
    require_within_max_depth(scope);
    scope.println("<native_handle type='{}'>", demangle(handle.type()));
    if (const prototype<>::ptr& proto = handle.proto()) {
      visit_maybe_truncate(proto, scope.open_node());
    }
  }

  // Prints a prototype's key names (not their values). If the prototype has a
  // parent, prints the parent chain as nested child nodes.
  void visit(const prototype<>::ptr& proto, tree_printer::scope scope) const {
    require_within_max_depth(scope);
    std::set<std::string> keys = proto->keys();
    scope.println("<prototype (size={})>", keys.size());
    for (const auto& key : keys) {
      auto element_scope = scope.open_transparent_property();
      element_scope.println("'{}'", key);
    }
    if (const prototype<>::ptr& parent = proto->parent()) {
      visit_maybe_truncate(parent, scope.open_node());
    }
  }

  // True when `scope` sits exactly at the configured depth limit. This uses
  // == rather than >=, so it relies on depth never exceeding max_depth. That
  // invariant is checked by require_within_max_depth in debug builds.
  [[nodiscard]] bool at_max_depth(const tree_printer::scope& scope) const {
    return scope.semantic_depth() == opts_.max_depth;
  }

  // Debug-only sanity check. assert() compiles out under NDEBUG, hence the
  // [[maybe_unused]] on the parameter.
  void require_within_max_depth(
      [[maybe_unused]] const tree_printer::scope& scope) const {
    assert(scope.semantic_depth() <= opts_.max_depth);
  }

  const object_print_options& opts_;
};

} // namespace

namespace {
// The built-in map implementation: a thin wrapper around map::raw whose keys
// are always enumerable.
class basic_map final : public map {
 public:
  explicit basic_map(map::raw raw) : raw_(std::move(raw)) {}

  // Returns a copy of the stored value for `identifier`, or std::nullopt if
  // the key is absent.
  std::optional<object> lookup_property(
      std::string_view identifier) const final {
    const auto entry = raw_.find(identifier);
    if (entry == raw_.end()) {
      return std::nullopt;
    }
    return entry->second;
  }

  // Every key of the underlying storage, collected into an ordered set.
  std::optional<std::set<std::string>> keys() const final {
    std::set<std::string> names;
    for (const auto& entry : raw_) {
      names.insert(entry.first);
    }
    return names;
  }

  void print_to(tree_printer::scope scope, const object_print_options& options)
      const final {
    // keys() never returns std::nullopt for basic_map, so this is safe.
    const std::set<std::string> property_names = *keys();
    default_print_to("map", property_names, std::move(scope), options);
  }

  std::string describe_type() const final {
    // The built-in map type does not need to be more descriptive.
    return "map";
  }

 private:
  map::raw raw_;
};
} // namespace
// Wraps `raw` in the built-in map implementation.
/* static */ map::ptr map::of(map::raw raw) {
  auto instance = std::make_shared<basic_map>(std::move(raw));
  return instance;
}

// Base implementation for custom maps. Enumerable maps are printed with their
// properties; non-enumerable maps only get a one-line marker.
void map::print_to(
    tree_printer::scope scope, const object_print_options& options) const {
  const std::optional<std::set<std::string>> names = keys();
  if (names.has_value()) {
    default_print_to("map [custom]", *names, std::move(scope), options);
    return;
  }
  scope.println("map [custom] (not enumerable)");
}

// Describes a custom (user-derived) map by its dynamic C++ type.
// typeid(*this) resolves the most-derived type because map is polymorphic.
//
// The format intentionally matches array::describe_type(). The previous
// string carried a stray trailing '>' with no opening '<'.
std::string map::describe_type() const {
  return fmt::format("map [custom]='{}'", demangle(typeid(*this)));
}

// Structural equality for maps.
//
// Two maps are equal if they are the same object, or if both can enumerate
// their keys, the key sets match exactly, and every key maps to equal values
// on both sides. Maps with non-enumerable keys never compare equal to a
// different map instance.
bool operator==(const map& lhs, const map& rhs) {
  if (std::addressof(lhs) == std::addressof(rhs)) {
    return true;
  }

  auto lhs_keys = lhs.keys();
  auto rhs_keys = rhs.keys();
  const bool keys_equal =
      lhs_keys.has_value() && rhs_keys.has_value() && *lhs_keys == *rhs_keys;
  if (!keys_equal) {
    return false;
  }
  for (const std::string& key : *lhs_keys) {
    std::optional<object> lhs_value = lhs.lookup_property(key);
    std::optional<object> rhs_value = rhs.lookup_property(key);
    // These should always be present because we are only attempting to fetch
    // enumerable keys.
    assert(lhs_value.has_value());
    assert(rhs_value.has_value());
    // Compare the optionals themselves instead of dereferencing them. When
    // both are engaged this is exactly `*lhs_value != *rhs_value`. If a
    // misbehaving custom map breaks the invariant above in a release
    // (NDEBUG) build, this stays well-defined instead of dereferencing an
    // empty optional.
    if (lhs_value != rhs_value) {
      return false;
    }
  }
  return true;
}

void map::default_print_to(
    std::string_view name,
    const std::set<std::string>& property_names,
    tree_printer::scope scope,
    const object_print_options& options) const {
  assert(scope.semantic_depth() <= options.max_depth);
  const auto size = property_names.size();
  scope.println("{} (size={})", name, size);

  for (const std::string& key : property_names) {
    auto cached = lookup_property(key);
    assert(cached.has_value());
    auto element_scope = scope.open_transparent_property();
    element_scope.println("'{}'", key);
    whisker::print_to(*cached, element_scope.open_node(), options);
  }
}

namespace {
// The built-in array implementation: a thin wrapper around array::raw.
class basic_array final : public array {
 public:
  explicit basic_array(array::raw raw) : raw_(std::move(raw)) {}

  std::size_t size() const final {
    return raw_.size();
  }

  // Uses the checked at() accessor of the underlying storage rather than
  // operator[].
  object at(std::size_t index) const final {
    return raw_.at(index);
  }

  void print_to(tree_printer::scope scope, const object_print_options& options)
      const final {
    constexpr std::string_view label = "array";
    default_print_to(label, std::move(scope), options);
  }

  std::string describe_type() const final {
    // The built-in array type does not need to be more descriptive.
    return "array";
  }

 private:
  array::raw raw_;
};
} // namespace
// Wraps `raw` in the built-in array implementation.
/* static */ array::ptr array::of(array::raw raw) {
  auto instance = std::make_shared<basic_array>(std::move(raw));
  return instance;
}

// Base implementation, used by array subclasses that do not override
// print_to. Output is labeled as a custom array.
void array::print_to(
    tree_printer::scope scope, const object_print_options& options) const {
  constexpr std::string_view label = "array [custom]";
  default_print_to(label, std::move(scope), options);
}

// Describes a custom (user-derived) array by its dynamic C++ type.
// typeid(*this) resolves the most-derived type because array is polymorphic.
std::string array::describe_type() const {
  const std::string dynamic_type = demangle(typeid(*this));
  return fmt::format("array [custom]='{}'", dynamic_type);
}

// Structural equality for arrays. Two arrays are equal if they are the same
// object, or if they have the same length and equal elements at every index.
bool operator==(const array& lhs, const array& rhs) {
  if (std::addressof(lhs) == std::addressof(rhs)) {
    return true;
  }

  const std::size_t length = lhs.size();
  if (rhs.size() != length) {
    return false;
  }
  for (std::size_t index = 0; index != length; ++index) {
    // Fetch left then right explicitly to keep the call order deterministic.
    const object left = lhs.at(index);
    const object right = rhs.at(index);
    if (left != right) {
      return false;
    }
  }
  return true;
}

// Shared printer for array implementations. It writes a "<name> (size=N)"
// header, then for each element an "[i]" index line followed by the element
// printed as a child node.
void array::default_print_to(
    std::string_view name,
    tree_printer::scope scope,
    const object_print_options& options) const {
  assert(scope.semantic_depth() <= options.max_depth);

  const std::size_t element_count = size();
  scope.println("{} (size={})", name, element_count);
  for (std::size_t index = 0; index != element_count; ++index) {
    auto index_scope = scope.open_transparent_property();
    index_scope.println("[{}]", index);
    whisker::print_to(at(index), index_scope.open_node(), options);
  }
}

// Native functions are opaque, so only a fixed marker line is printed and the
// print options are unused.
void native_function::print_to(
    tree_printer::scope scope,
    const object_print_options& /* options */) const {
  scope.println("<native_function>");
}

// Describes a native function by the demangled name of its dynamic C++ type.
std::string native_function::describe_type() const {
  const std::string dynamic_type = demangle(typeid(*this));
  return fmt::format("<native_function type='{}'>", dynamic_type);
}

// Builds the default prototype implementation from a set of descriptors and
// an optional parent prototype.
/* static */ prototype<>::ptr prototype<>::from(
    descriptors_map descriptors, prototype::ptr parent) {
  auto created = std::make_shared<basic_prototype<>>(
      std::move(descriptors), std::move(parent));
  return created;
}

// Shared formatter for native handle type descriptions, parameterized on the
// wrapped C++ type.
std::string detail::describe_native_handle_for_type(
    const std::type_info& type) {
  const std::string readable = demangle(type);
  return fmt::format("<native_handle type='{}'>", readable);
}

// Describes this handle using the C++ type reported by type(). Formatting is
// delegated so the description format lives in one place.
std::string native_handle<void>::describe_type() const {
  const auto& held_type = type();
  return detail::describe_native_handle_for_type(held_type);
}
// Class-level description for the void specialization, which names no
// concrete C++ type.
/* static */ std::string native_handle<void>::describe_class_type() {
  return std::string("<native_handle type='void'>");
}

std::string object::describe_type() const {
  return visit(
      [](i64) -> std::string { return "i64"; },
      [](f64) -> std::string { return "f64"; },
      [](const string&) -> std::string { return "string"; },
      [](boolean) -> std::string { return "boolean"; },
      [](null) -> std::string { return "null"; },
      [](const array::ptr& a) -> std::string { return a->describe_type(); },
      [](const map::ptr& m) -> std::string { return m->describe_type(); },
      [](const native_function::ptr& f) -> std::string {
        return f->describe_type();
      },
      [](const native_handle<>& h) -> std::string {
        return h.describe_type();
      });
}

// Renders `obj` into a string by printing it into an in-memory root scope.
std::string to_string(const object& obj, const object_print_options& options) {
  std::ostringstream buffer;
  auto root = tree_printer::scope::make_root(buffer);
  print_to(obj, std::move(root), options);
  return std::move(buffer).str();
}

void print_to(
    const object& obj,
    tree_printer::scope scope,
    const object_print_options& options) {
  to_string_visitor(options).visit(obj, std::move(scope));
}

// Streams the default string rendering of `o`.
std::ostream& operator<<(std::ostream& out, const object& o) {
  const std::string rendered = to_string(o);
  out << rendered;
  return out;
}

} // namespace whisker
