//
// Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto.  Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.
//

#pragma once

#include "nscq/logger.hpp"
#include "nscq/nscq.h"
#include "nscq/warn.hpp"

#include <array>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <typeinfo>

#if (defined(__GNUC__) || defined(__clang__))
// Clang and GCC implement the IA64 C++ ABI, and can use the same symbol demangling code to get
// pretty symbol names. This can be used for logging.
#define NSCQ_IA64_ABI_DEMANGLE
#include <cxxabi.h>
#endif

namespace nscq {

// Exception for functionality that has no implementation. It reuses every std::logic_error
// constructor; exception_to_rc() reports it as NSCQ_RC_ERROR_NOT_IMPLEMENTED.
struct not_implemented : std::logic_error {
    using logic_error::logic_error;
};

// Packages a bare result code into the wrapper's return type TResult.
// This should only need to be used with exception_to_rc().
//
// TValue is the type the wrapped callable would have produced:
//   - void:    TResult is built from the result code alone;
//   - pointer: TResult is brace-initialized as {rc, nullptr};
//   - other:   TResult is brace-initialized as {rc, TValue()}.
//
// The last case requires TValue to be nothrow default constructible. Any other TValue is
// rejected at compile time: previously no branch was selected for such a type and control
// fell off the end of a value-returning function, which is undefined behavior.
template <typename TResult, typename TValue>
auto rc_to_result_type(nscq_rc_t rc) -> TResult {
    static_assert(std::is_void_v<TValue> || std::is_pointer_v<TValue> ||
                      std::is_nothrow_default_constructible_v<TValue>,
                  "rc_to_result_type: TValue must be void, a pointer type, or nothrow "
                  "default constructible");

    if constexpr (std::is_void_v<TValue>) {
        return rc;
    } else if constexpr (std::is_pointer_v<TValue>) {
        return {rc, nullptr};
    } else {
        // Guaranteed nothrow default constructible by the static_assert above.
        return {rc, TValue()};
    }
}

// Detection helper for a nested `api_t` alias.
//
// Primary template: chosen when TInternal has no usable `api_t` member; reports void.
template <typename TInternal, typename TEnable = void>
struct api_t_wrapper {
    using api_t = void;
};

// Partial specialization: chosen when `typename TInternal::api_t` is well-formed; forwards it.
template <typename TInternal>
struct api_t_wrapper<TInternal, std::void_t<typename TInternal::api_t>> {
    using api_t = typename TInternal::api_t;
};

// Shorthand: TInternal::api_t if that member exists, otherwise void.
template <typename TInternal>
using api_t = typename api_t_wrapper<TInternal>::api_t;

// Overload for value types that publish an api_t alias: the object is exported by casting it
// to that alias, so the type has to provide a matching conversion operator. The overload is
// disabled when api_t<TInternal> is void (i.e. the type has no api_t member).
template <typename TInternal, typename = std::enable_if_t<!std::is_void_v<api_t<TInternal>>>>
inline auto api_safe_repr(TInternal& t) -> api_t<TInternal> {
    using exported_t = api_t<TInternal>;
    return static_cast<exported_t>(t);
}

// Overload for trivial types (primitives, simple structs): these need no translation, so the
// argument is handed back unchanged. Only participates when std::is_trivial_v<TInternal> holds.
template <typename TInternal, typename = std::enable_if_t<std::is_trivial_v<TInternal>>>
inline auto api_safe_repr(TInternal t) -> TInternal {
    return t;
}

// Overload for std::string: exports the string's own NUL-terminated character buffer. The
// returned pointer refers to storage owned by `str`; no copy is made.
inline auto api_safe_repr(const std::string& str) -> const char* {
    const char* const chars = str.data();
    return chars;
}

// Overload for values held in a std::shared_ptr; exports them as a raw pointer.
//   - Trivial pointee types: returns the stored pointer (ptr.get()) as-is.
//   - Any other pointee type must define an api_t alias and an explicit conversion to
//     `const api_t*`; that conversion is applied to *ptr.
//
// The return type is deduced from the selected branch instead of being spelled with
// std::conditional_t. conditional_t substitutes both of its type arguments, so naming
// `typename TInternal::api_t` in the signature was a substitution failure for trivial types
// without an api_t member, which removed this overload for e.g. std::shared_ptr<int>.
//
// NOTE: in the non-trivial case ptr is dereferenced, so it must not be empty.
template <typename TInternal>
inline auto api_safe_repr(const std::shared_ptr<TInternal>& ptr) {
    if constexpr (std::is_trivial_v<TInternal>) {
        return ptr.get();
    } else {
        return static_cast<const typename TInternal::api_t*>(*ptr);
    }
}

// Holds a TVal that was validated against the closed interval [TMin, TMax] on construction.
// Converts implicitly from and to TVal, which makes it convenient for translating between
// enums: a value outside the interval is reported through NSCQ_THROW as std::out_of_range,
// with the offending value (via std::to_string) as the message argument.
template <typename TVal, TVal TMin, TVal TMax>
class range_check {
  public:
    using api_t = TVal;

    // Inclusive bounds of the accepted interval.
    static const TVal min = TMin;
    static const TVal max = TMax;

    // A default-constructed instance holds the lower bound.
    range_check() = default;

    // Implicit on purpose, so a raw TVal can be passed wherever a checked value is expected.
    // NOLINTNEXTLINE(google-explicit-constructor,hicpp-explicit-conversions)
    range_check(TVal value) : m_value(value) {
        const bool in_range = (min <= value) && (value <= max);
        if (!in_range) {
            NSCQ_THROW(std::out_of_range, std::to_string(value));
        }
    }

    // Implicit on purpose, so a checked value can be used wherever a raw TVal is expected.
    // NOLINTNEXTLINE(google-explicit-constructor,hicpp-explicit-conversions)
    operator TVal() const { return m_value; }

  private:
    TVal m_value = min;
};

// Wrapper around the C API struct nscq_fabric_state_t, which carries a `driver` and a `device`
// state member. Member functions are only declared here; their definitions live elsewhere.
class fabric_state {
  public:
    // C API type this class converts to (see api_safe_repr()).
    using api_t = nscq_fabric_state_t;
    // Range-checked form of the api_t::driver member type; accepts values from
    // NSCQ_DRIVER_FABRIC_STATE_UNKNOWN through NSCQ_DRIVER_FABRIC_STATE_MANAGER_ERROR.
    using checked_driver_state_t =
        range_check<decltype(api_t::driver), NSCQ_DRIVER_FABRIC_STATE_UNKNOWN,
                    NSCQ_DRIVER_FABRIC_STATE_MANAGER_ERROR>;
    // Range-checked form of the api_t::device member type; accepts values from
    // NSCQ_DEVICE_FABRIC_STATE_UNKNOWN through NSCQ_DEVICE_FABRIC_STATE_BLACKLISTED.
    using checked_device_state_t =
        range_check<decltype(api_t::device), NSCQ_DEVICE_FABRIC_STATE_UNKNOWN,
                    NSCQ_DEVICE_FABRIC_STATE_BLACKLISTED>;

    // Default state: both members are the respective *_UNKNOWN value (see m_val below).
    fabric_state() = default;
    // Takes range-checked states. Because range_check converts implicitly from the raw type,
    // an out-of-range raw value passed here is rejected while the argument is being built.
    // NOTE(review): presumably stores both values in m_val - confirm in the definition.
    fabric_state(checked_driver_state_t driver_state, checked_device_state_t device_state);

    // Explicit conversion to the C API struct, returned by value.
    explicit operator api_t() const;

  private:
    // Initialized with NSCQ_DRIVER_FABRIC_STATE_UNKNOWN, NSCQ_DEVICE_FABRIC_STATE_UNKNOWN.
    api_t m_val{NSCQ_DRIVER_FABRIC_STATE_UNKNOWN, NSCQ_DEVICE_FABRIC_STATE_UNKNOWN};
};

// Range-checked wrappers for C API enums. Each alias names the underlying enum type followed
// by the smallest and largest enumerator value that range_check accepts.

using blacklist_reason_t = range_check<nscq_blacklist_reason_t,
                                       NSCQ_DEVICE_BLACKLIST_REASON_UNKNOWN,
                                       NSCQ_DEVICE_BLACKLIST_REASON_UNSPEC_DEVICE_FAILURE_PEER>;

using nvlink_state_t = range_check<nscq_nvlink_state_t,
                                   NSCQ_NVLINK_STATE_UNKNOWN,
                                   NSCQ_NVLINK_STATE_SLEEP>;

using nvlink_rx_sublink_state_t = range_check<nscq_nvlink_rx_sublink_state_t,
                                              NSCQ_NVLINK_STATUS_SUBLINK_RX_STATE_UNKNOWN,
                                              NSCQ_NVLINK_STATUS_SUBLINK_RX_STATE_OFF>;

using nvlink_tx_sublink_state_t = range_check<nscq_nvlink_tx_sublink_state_t,
                                              NSCQ_NVLINK_STATUS_SUBLINK_TX_STATE_UNKNOWN,
                                              NSCQ_NVLINK_STATUS_SUBLINK_TX_STATE_OFF>;


// Wrapper around the C API type nscq_label_t. Member functions are only declared here; their
// definitions live elsewhere.
class label {
  public:
    // C API type this class is exported as (see api_safe_repr()).
    using api_t = nscq_label_t;

    // Default state: m_val is initialized with {0}.
    label() = default;
    // Builds a label from a string.
    // NOTE(review): presumably goes through copy_to_label() and can therefore report
    // `overflow` for over-long input - confirm in the definition.
    explicit label(const std::string& str);

    // Explicit conversion to a pointer to the C API struct.
    // NOTE(review): unlike uuid's equivalent operator this one is not const-qualified, so it
    // cannot be called on a const label - confirm whether that is intentional.
    explicit operator const api_t*();

  private:
    api_t m_val{0};
};

// Wrapper around the C API type nscq_uuid_t that additionally carries a std::string member.
// Member functions are only declared here; their definitions live elsewhere.
class uuid {
  public:
    // C API type this class is exported as (see api_safe_repr()).
    using api_t = nscq_uuid_t;

    // Size in bytes of the `bytes` member of nscq_uuid_t.
    static const auto size_bytes = sizeof(nscq_uuid_t::bytes);

    // Default state: m_val is initialized with {0} and m_str is empty.
    uuid() = default;
    // Builds a uuid from exactly size_bytes raw bytes.
    explicit uuid(const std::array<uint8_t, size_bytes>& bytes);
    // Builds a uuid from a C API value.
    // NOTE(review): null handling for `val` is not visible here - confirm in the definition.
    explicit uuid(const api_t* val);
    // Builds a uuid from a string; the argument is taken by value.
    // NOTE(review): the accepted string format is defined by the out-of-line implementation.
    explicit uuid(const std::string str);

    // Explicit conversion to a pointer to the C API struct.
    explicit operator const api_t*() const;
    // Explicit conversion to a string reference (used as std::string(id) by the exception
    // types that embed a uuid in their message).
    explicit operator const std::string &() const;
    // Equality and ordering against another uuid or against a raw C API value.
    auto operator==(const uuid& other) const -> bool;
    auto operator==(const api_t* other) const -> bool;
    auto operator<(const uuid& other) const -> bool;
    auto operator<(const api_t* other) const -> bool;

  private:
    // The std::less specialization for std::shared_ptr<uuid> compares m_val directly.
    friend struct std::less<std::shared_ptr<uuid>>;
    api_t m_val{0};
    // NOTE(review): presumably the string form of m_val produced by encode_str() - confirm.
    std::string m_str;

    // Produces a string from a C API uuid value.
    static auto encode_str(const api_t& val) -> std::string;
};

// Exception for a uuid that does not resolve to a resource. The message has the form
// "No resource found for <uuid> at: <where>"; exception_to_rc() reports this exception as
// NSCQ_RC_ERROR_INVALID_UUID.
struct invalid_uuid : std::invalid_argument {
    // where: text appended after "at: "; id: the uuid embedded via its string conversion.
    invalid_uuid(const char* where, const uuid& id)
    : std::invalid_argument(make_string(where, id)) {}

  private:
    // Assembles the exception message piece by piece.
    static auto make_string(const char* where, const uuid& id) -> std::string {
        std::stringstream message;
        message << "No resource found for ";
        message << std::string(id);
        message << " at: ";
        message << where;
        return message.str();
    }
};

// Exception for a resource that cannot be mounted. The message has the form
// "Resource <uuid> is not mountable: <where>"; exception_to_rc() reports this exception as
// NSCQ_RC_ERROR_RESOURCE_NOT_MOUNTABLE.
struct resource_not_mountable : std::invalid_argument {
    // where: text appended after the colon; id: the uuid embedded via its string conversion.
    resource_not_mountable(const char* where, const uuid& id)
    : std::invalid_argument(make_string(where, id)) {}

  private:
    // Assembles the exception message piece by piece.
    static auto make_string(const char* where, const uuid& id) -> std::string {
        std::stringstream message;
        message << "Resource ";
        message << std::string(id);
        message << " is not mountable: ";
        message << where;
        return message.str();
    }
};

// Exception for data that does not fit its destination (raised by copy_to_label() with the
// offending string as the message). exception_to_rc() reports it as NSCQ_RC_ERROR_OVERFLOW.
struct overflow : std::length_error {
    explicit overflow(const std::string& what) : std::length_error{what} {}
};

// Exception carrying the fixed message "Unsupported driver". exception_to_rc() reports it as
// NSCQ_RC_ERROR_UNSUPPORTED_DRV.
struct unsupported_drv : std::runtime_error {
    explicit unsupported_drv() : std::runtime_error{"Unsupported driver"} {}
};

// Exception carrying a numeric driver error code. The message is "Driver error: " followed by
// the code in decimal; exception_to_rc() reports this exception as NSCQ_RC_ERROR_DRV.
struct drv_error : std::runtime_error {
    explicit drv_error(uint32_t what) : std::runtime_error(make_message(what)) {}

  private:
    // Formats the message text for the given code.
    static auto make_message(uint32_t code) -> std::string {
        std::string message("Driver error: ");
        message += std::to_string(code);
        return message;
    }
};

// Exception carrying the fixed message "Call timed out". exception_to_rc() reports it as
// NSCQ_RC_ERROR_TIMEOUT.
struct timeout_error : std::runtime_error {
    explicit timeout_error() : std::runtime_error{"Call timed out"} {}
};

// Copies `str` into the fixed-size character buffer label->data.
//
// A string whose length is sizeof(label->data) - 1 or more is rejected through
// NSCQ_THROW(overflow, str). Otherwise std::strncpy copies the characters and zero-fills the
// rest of the buffer, so the stored text is always NUL-terminated.
// NOTE(review): the bound also rejects a string of exactly sizeof(data) - 1 characters, which
// would still fit together with its terminator - confirm whether that margin is intentional.
inline auto copy_to_label(const std::string& str, nscq_label_t* label) -> void {
    const auto capacity = sizeof(label->data);
    const bool fits = str.length() < (capacity - 1);
    if (!fits) {
        NSCQ_THROW(overflow, str);
    }
    std::strncpy(&label->data[0], str.c_str(), capacity);
}

// Empty tag type that carries a list of types purely in its template arguments. Passing a
// `type<...>` value lets otherwise identical function overloads be told apart by an explicit
// type parameter.
template <typename...>
struct type {};

// Generalized template function for converting internal API implementations which throw exceptions
// into result codes that can be passed back through the C API. The argument is a lambda function;
// if its return type is void, the wrapper will return a single status code. If the lambda returns
// an actual value, the wrapper will compose and return a struct of type R.
template <typename TResult, typename TFunc>
auto exception_to_rc(TFunc&& func) -> TResult {
    try {
        if constexpr (std::is_void_v<decltype(func())>) {
            func();
            return static_cast<nscq_rc_t>(consume_warning());
        } else {
            auto ret = func();
            return {static_cast<nscq_rc_t>(consume_warning()), ret};
        }
    } catch (const not_implemented& e) {
        return rc_to_result_type<TResult, decltype(func())>(NSCQ_RC_ERROR_NOT_IMPLEMENTED);
    } catch (const invalid_uuid& e) {
        return rc_to_result_type<TResult, decltype(func())>(NSCQ_RC_ERROR_INVALID_UUID);
    } catch (const resource_not_mountable& e) {
        return rc_to_result_type<TResult, decltype(func())>(NSCQ_RC_ERROR_RESOURCE_NOT_MOUNTABLE);
    } catch (const overflow& e) {
        return rc_to_result_type<TResult, decltype(func())>(NSCQ_RC_ERROR_OVERFLOW);
    } catch (const std::out_of_range& e) {
        return rc_to_result_type<TResult, decltype(func())>(NSCQ_RC_ERROR_UNEXPECTED_VALUE);
    } catch (const unsupported_drv& e) {
        return rc_to_result_type<TResult, decltype(func())>(NSCQ_RC_ERROR_UNSUPPORTED_DRV);
    } catch (const drv_error& e) {
        return rc_to_result_type<TResult, decltype(func())>(NSCQ_RC_ERROR_DRV);
    } catch (const timeout_error& e) {
        return rc_to_result_type<TResult, decltype(func())>(NSCQ_RC_ERROR_TIMEOUT);
    } catch (...) {
        return rc_to_result_type<TResult, decltype(func())>(NSCQ_RC_ERROR_UNSPECIFIED);
    }
}

#if defined(NSCQ_IA64_ABI_DEMANGLE)
// abi::__cxa_demangle is part of the IA64 C++ ABI. We use this to get the pretty names for C++
// symbols, for logging purposes.
//
// Returns the demangled name of TType. If demangling fails (non-zero status or a null result)
// the mangled name from typeid is returned instead; constructing a std::string from the null
// pointer would be undefined behavior.
template <typename TType>
auto type_name() -> std::string {
    // __cxa_demangle hands back a malloc()-allocated buffer owned by the caller. Holding it in
    // a unique_ptr releases it on every path, including when building the string throws.
    struct free_deleter {
        void operator()(char* buffer) const noexcept {
            std::free(buffer); // NOLINT(cppcoreguidelines-no-malloc,hicpp-no-malloc)
        }
    };

    const char* mangled = typeid(TType).name();
    int status = 0;
    const std::unique_ptr<char, free_deleter> demangled(
        abi::__cxa_demangle(mangled, nullptr, nullptr, &status));

    if ((status != 0) || (demangled == nullptr)) {
        return std::string(mangled);
    }
    return std::string(demangled.get());
}
#else
// Fallback used when no demangler is available: returns the implementation-defined name
// reported by typeid for TType, without any post-processing.
template <typename TType>
auto type_name() -> std::string {
    const char* raw_name = typeid(TType).name();
    return std::string(raw_name);
}
#endif

} // namespace nscq

namespace std {

// Template specialization of std::less for std::shared_ptr<uuid>.
// This will be picked up by std::map<std::shared_ptr<uuid>, T> types as the default key_compare.
//
// Keys are ordered by a bytewise comparison of the wrapped C API value, not by pointer
// identity, so two distinct shared_ptr instances holding equal uuids are equivalent keys.
// Empty pointers compare equivalent to each other and sort before every non-empty pointer;
// dereferencing them, as the comparison otherwise would, is undefined behavior.
template <>
struct less<shared_ptr<nscq::uuid>> {
    auto operator()(const shared_ptr<nscq::uuid>& a, const shared_ptr<nscq::uuid>& b) const
        -> bool {
        if ((a == nullptr) || (b == nullptr)) {
            // Only "empty < non-empty" holds; every other combination is "not less".
            return (a == nullptr) && (b != nullptr);
        }
        return memcmp(&a->m_val, &b->m_val, sizeof(nscq::uuid::api_t)) < 0;
    }
};

} // namespace std
