#include "balanceUtils.hpp"
#include <iostream>
#include <utility>
#include <mpi.h>
#include <stk_mesh/base/BulkData.hpp>
#include <stk_mesh/base/Entity.hpp>
#include <stk_topology/topology.hpp>
#include "stk_mesh/base/Field.hpp"  // for field_data
#include "stk_mesh/base/FieldBase.hpp"  // for field_data
#include "FaceSearchTolerance.hpp"

namespace stk
{
namespace balance
{

//////////////////////////////////////////////////////////////////////////

// Default number of shared nodes required for two elements to be connected
// in the graph: always one, whatever the two topologies are.
size_t BalanceSettings::getNumNodesRequiredForConnection(stk::topology element1Topology, stk::topology element2Topology) const
{
    const size_t singleSharedNode = 1;
    return singleSharedNode;
}

// Default graph edge weight: uniform (1) for every topology pair.
double BalanceSettings::getGraphEdgeWeight(stk::topology element1Topology, stk::topology element2Topology) const
{
    return 1.0;
}

// Default per-topology vertex weight: uniform (1) for every topology.
int BalanceSettings::getGraphVertexWeight(stk::topology type) const
{
    const int uniformWeight = 1;
    return uniformWeight;
}

// Default per-entity vertex weight: uniform (1) for every entity and criterion.
double BalanceSettings::getGraphVertexWeight(stk::mesh::Entity entity, int criteria_index) const
{
    return 1.0;
}

// Default graph option: LOADBALANCE.
BalanceSettings::GraphOption BalanceSettings::getGraphOption() const
{
    return LOADBALANCE;
}

// Base settings do not add search results to the graph.
bool BalanceSettings::includeSearchResultsInGraph() const
{
    return false;
}

// Base settings use a zero face-search tolerance for every face.
double BalanceSettings::getToleranceForFaceSearch(const stk::mesh::BulkData & mesh, const stk::mesh::FieldBase & coordField, const stk::mesh::EntityVector & faceNodes) const
{
    const double zeroTolerance = 0.0;
    return zeroTolerance;
}

// No-op in the base class: the tolerance is always zero (see
// getToleranceForFaceSearch), so the supplied function is ignored.
void BalanceSettings::setToleranceFunctionForFaceSearch(std::shared_ptr<stk::balance::FaceSearchTolerance> faceSearchTolerance)
{
    // intentionally empty
}

// Base settings use a zero particle-search tolerance.
double BalanceSettings::getToleranceForParticleSearch() const
{
    const double zeroTolerance = 0.0;
    return zeroTolerance;
}

// Default weight for graph edges created from search results.
double BalanceSettings::getGraphEdgeWeightForSearch() const
{
    return 1.0;
}

// Base settings do not build particle edges through search.
bool BalanceSettings::getEdgesForParticlesUsingSearch() const
{
    return false;
}

// Default multiplier applied to vertex weights of vertices found in search.
double BalanceSettings::getVertexWeightMultiplierForVertexInSearch() const
{
    return 15.0;
}

// Base settings request a full (non-incremental) rebalance.
bool BalanceSettings::isIncrementalRebalance() const
{
    return false;
}

// Base settings balance on a single criterion.
bool BalanceSettings::isMultiCriteriaRebalance() const
{
    return false;
}

// Vertex weights are not supplied through a vector by default.
bool BalanceSettings::areVertexWeightsProvidedInAVector() const
{
    return false;
}

// Vertex weights are not supplied through fields by default.
bool BalanceSettings::areVertexWeightsProvidedViaFields() const
{
    return false;
}

// No vector-supplied vertex weights: returns an empty vector.
std::vector<double> BalanceSettings::getVertexWeightsViaVector() const
{
    std::vector<double> noWeights;
    return noWeights;
}

// Default imbalance tolerance (1.01, i.e. 1% above perfect balance).
double BalanceSettings::getImbalanceTolerance() const
{
    const double defaultTolerance = 1.01;
    return defaultTolerance;
}

// No-op in the base class: the decomposition method is fixed
// (see getDecompMethod).
void BalanceSettings::setDecompMethod(const std::string& method)
{
    // intentionally empty
}

// Base settings always decompose with "parmetis".
std::string BalanceSettings::getDecompMethod() const
{
    return "parmetis";
}

// Name of the coordinate field used by the balancer.
std::string BalanceSettings::getCoordinateFieldName() const
{
    return "coordinates";
}

// Base settings do not print balance metrics.
bool BalanceSettings::shouldPrintMetrics() const
{
    return false;
}

// Base settings use a single balance criterion.
int BalanceSettings::getNumCriteria() const
{
    const int singleCriterion = 1;
    return singleCriterion;
}

// Hook for adjusting a computed decomposition; the base class leaves it untouched.
void BalanceSettings::modifyDecomposition(DecompositionChangeList & decomp) const
{
    // intentionally empty
}

// Default radius for every particle entity.
double BalanceSettings::getParticleRadius(stk::mesh::Entity particle) const
{
    const double defaultRadius = 0.5;
    return defaultRadius;
}

// Vertex weights are not derived from adjacency counts by default.
bool BalanceSettings::setVertexWeightsBasedOnNumberAdjacencies() const
{
    return false;
}

// Applies to graph-based methods (parmetis) only.
// Vertex weights may be adjusted for small meshes by default.
bool BalanceSettings::allowModificationOfVertexWeightsForSmallMeshes() const
{
    return true;
}

// Applies to graph-based methods (parmetis) only.
// Base settings do not fix mechanisms.
bool BalanceSettings::shouldFixMechanisms() const
{
    return false;
}


//////////////////////////////////////

// Number of nodes two elements must share to be connected in the graph.
// Rows and columns are the category index from getConnectionTableIndex():
//   0 = 0-dim, 1 = 1-dim, 2 = 2-dim linear, 3 = 3-dim linear,
//   4 = 2-dim higher-order, 5 = 3-dim higher-order, 6 = super element.
// Any pair involving a super element maps to the sentinel 1000 ("no connection").
// Throws (via getConnectionTableIndex) for unsupported topologies.
size_t GraphCreationSettings::getNumNodesRequiredForConnection(stk::topology element1Topology, stk::topology element2Topology) const
{
    static const int NONE = 1000;   // sentinel meaning "no connection"
    static const int nodesRequired[7][7] = {
        //0d 1d  2L  3L  2H  3H  super
        { 1,  1,  1,  1,  1,  1, NONE}, // 0 dim
        { 1,  1,  1,  1,  1,  1, NONE}, // 1 dim
        { 1,  1,  2,  3,  2,  3, NONE}, // 2 dim linear
        { 1,  1,  3,  3,  3,  3, NONE}, // 3 dim linear
        { 1,  1,  2,  3,  3,  4, NONE}, // 2 dim higher-order
        { 1,  1,  3,  3,  4,  4, NONE}, // 3 dim higher-order
        {NONE, NONE, NONE, NONE, NONE, NONE, NONE}  // super element
    };

    // Look up the first topology before the second, as before (either call may throw).
    const int row = getConnectionTableIndex(element1Topology);
    const int col = getConnectionTableIndex(element2Topology);
    return nodesRequired[row][col];
}

double GraphCreationSettings::getGraphEdgeWeightForSearch() const
{
    return edgeWeightForSearch;
}

// Graph edge weight between two element topologies.
// Indexed by the categories from getConnectionTableIndex():
//   0 = 0-dim, 1 = 1-dim, 2 = 2-dim linear, 3 = 3-dim linear,
//   4 = 2-dim higher-order, 5 = 3-dim higher-order, 6 = super element.
// Throws (via getConnectionTableIndex) for unsupported topologies.
double GraphCreationSettings::getGraphEdgeWeight(stk::topology element1Topology, stk::topology element2Topology) const
{
    static const double L = 5.0;  // pair involving a 0- or 1-dim element
    static const double q = 5.0;  // pair involving a 2-dim element (no 0/1-dim)
    static const double D = 1.0;  // 3-dim with 3-dim
    static const double s = 0.0;  // pair involving a super element: no connection
    static const double edgeWeights[7][7] = {
        //0d 1d 2L 3L 2H 3H super
        {L, L, L, L, L, L, s}, // 0 dim
        {L, L, L, L, L, L, s}, // 1 dim
        {L, L, q, q, q, q, s}, // 2 dim linear
        {L, L, q, D, q, D, s}, // 3 dim linear
        {L, L, q, q, q, q, s}, // 2 dim higher-order
        {L, L, q, D, q, D, s}, // 3 dim higher-order
        {s, s, s, s, s, s, s}  // super element
    };

    // Look up the first topology before the second (either call may throw).
    const int row = getConnectionTableIndex(element1Topology);
    const int col = getConnectionTableIndex(element2Topology);
    return edgeWeights[row][col];
}

// Per-entity vertex weight: uniform (1) for every entity and criterion.
double GraphCreationSettings::getGraphVertexWeight(stk::mesh::Entity entity, int criteria_index) const
{
    const double uniformWeight = 1.0;
    return uniformWeight;
}

// Vertex weight assigned to an element of the given topology.
// Super elements weigh 10. Any other topology not listed below is rejected
// by throwing the string literal "Invalid Element Type In WeightsOfElement"
// (type const char*).
int GraphCreationSettings::getGraphVertexWeight(stk::topology type) const
{
    switch(type)
    {
        case stk::topology::PARTICLE:
        case stk::topology::LINE_2:
        case stk::topology::BEAM_2:
        case stk::topology::TETRAHEDRON_4:
            return 1;
        case stk::topology::WEDGE_6:
            return 2;
        case stk::topology::SHELL_TRIANGLE_3:
        case stk::topology::HEXAHEDRON_8:
        case stk::topology::TETRAHEDRON_10:
            return 3;
        case stk::topology::SHELL_TRIANGLE_6:
        case stk::topology::SHELL_QUADRILATERAL_4:
            return 6;
        case stk::topology::SHELL_QUADRILATERAL_8:
        case stk::topology::HEXAHEDRON_20:
        case stk::topology::WEDGE_15:
            return 12;
        default:
            break;
    }

    if(type.is_superelement())
    {
        return 10;
    }
    throw("Invalid Element Type In WeightsOfElement");
}

BalanceSettings::GraphOption GraphCreationSettings::getGraphOption() const
{
    return BalanceSettings::LOADBALANCE;
}

bool GraphCreationSettings::includeSearchResultsInGraph() const
{
    return true;
}

double GraphCreationSettings::getToleranceForParticleSearch() const
{
    return mToleranceForParticleSearch;
}

// Installs a callable that computes the face-search tolerance per face.
// Afterwards, getToleranceForFaceSearch() uses this function instead of the
// constant tolerance.
// The shared_ptr is taken by value (a sink) and moved into the member, so
// ownership is shared without an extra atomic refcount increment/decrement.
void GraphCreationSettings::setToleranceFunctionForFaceSearch(std::shared_ptr<stk::balance::FaceSearchTolerance> faceSearchTolerance)
{
    m_faceSearchToleranceFunction = std::move(faceSearchTolerance);
    m_UseConstantToleranceForFaceSearch = false;
}

// Face-search tolerance for the face defined by faceNodes.
// Returns the constant tolerance (mToleranceForFaceSearch) unless a tolerance
// function was installed via setToleranceFunctionForFaceSearch(); in that
// case the function's compute() result is returned.
// A null function (e.g. installed by passing nullptr) also falls back to the
// constant tolerance instead of dereferencing a null shared_ptr.
double GraphCreationSettings::getToleranceForFaceSearch(const stk::mesh::BulkData & mesh, const stk::mesh::FieldBase & coordField, const stk::mesh::EntityVector & faceNodes) const
{
    if (m_UseConstantToleranceForFaceSearch || !m_faceSearchToleranceFunction) {
        return mToleranceForFaceSearch;
    }
    return m_faceSearchToleranceFunction->compute(mesh, coordField, faceNodes);
}

bool GraphCreationSettings::getEdgesForParticlesUsingSearch() const
{
    return false;
}

double GraphCreationSettings::getVertexWeightMultiplierForVertexInSearch() const
{
    return vertexWeightMultiplierForVertexInSearch;
}

std::string GraphCreationSettings::getDecompMethod() const
{
    return method;
}

void GraphCreationSettings::setDecompMethod(const std::string& input_method)
{
    method = input_method;
}
void GraphCreationSettings::setToleranceForFaceSearch(double tol)
{
    mToleranceForFaceSearch = tol;
}
void GraphCreationSettings::setToleranceForParticleSearch(double tol)
{
    mToleranceForParticleSearch = tol;
}

// Maps an element topology to its row/column in the connection and
// edge-weight tables:
//   0 = particle (0-dim)
//   1 = line/beam/shell-line (1-dim)
//   2 = linear 2-dim (2D tri/quad, linear shells)
//   3 = linear 3-dim (tet4, pyramid5, wedge6, hex8)
//   4 = higher-order 2-dim
//   5 = higher-order 3-dim
//   6 = super element
// Any other topology is written to std::cerr and rejected by throwing the
// string literal "Invalid Element Type in GetDimOfElement" (type const char*).
int GraphCreationSettings::getConnectionTableIndex(stk::topology elementTopology) const
{
    switch(elementTopology)
    {
        case stk::topology::PARTICLE:
            return 0;

        case stk::topology::LINE_2:
        case stk::topology::LINE_2_1D:
        case stk::topology::LINE_3_1D:
        case stk::topology::BEAM_2:
        case stk::topology::BEAM_3:
        case stk::topology::SHELL_LINE_2:
        case stk::topology::SHELL_LINE_3:
            return 1;

        case stk::topology::TRI_3_2D:
        case stk::topology::TRI_4_2D:
        case stk::topology::QUAD_4_2D:
        case stk::topology::SHELL_TRI_3:
        case stk::topology::SHELL_TRI_4:
        case stk::topology::SHELL_QUAD_4:
            return 2;

        case stk::topology::TET_4:
        case stk::topology::PYRAMID_5:
        case stk::topology::WEDGE_6:
        case stk::topology::HEX_8:
            return 3;

        case stk::topology::TRI_6_2D:
        case stk::topology::QUAD_8_2D:
        case stk::topology::QUAD_9_2D:
        case stk::topology::SHELL_TRI_6:
        case stk::topology::SHELL_QUAD_8:
        case stk::topology::SHELL_QUAD_9:
            return 4;

        case stk::topology::TET_8:
        case stk::topology::TET_10:
        case stk::topology::TET_11:
        case stk::topology::PYRAMID_13:
        case stk::topology::PYRAMID_14:
        case stk::topology::WEDGE_15:
        case stk::topology::WEDGE_18:
        case stk::topology::HEX_20:
        case stk::topology::HEX_27:
            return 5;

        default:
            break;
    }

    if(!elementTopology.is_superelement())
    {
        std::cerr << "Topology is " << elementTopology << std::endl;
        throw("Invalid Element Type in GetDimOfElement");
    }
    return 6;
}

// Unlike the base BalanceSettings, these settings request that mechanisms be fixed.
bool GraphCreationSettings::shouldFixMechanisms() const
{
    const bool fixMechanisms = true;
    return fixMechanisms;
}


}
}
