/*++
Copyright (c) 2013 Microsoft Corporation

Module Name:

    smt_farkas_util.cpp

Abstract:

    Utility for combining inequalities using coefficients obtained from Farkas lemmas.
    
Author:

    Nikolaj Bjorner (nbjorner) 2013-11-2.

Revision History:

    NB. This utility is specialized to proofs generated by the arithmetic solvers.

--*/

#include "smt_farkas_util.h"
#include "ast_pp.h"
#include "th_rewriter.h"
#include "bool_rewriter.h"


namespace smt {

    // Create a Farkas-combination utility bound to ast manager m.
    // m_split_literals starts false, so get() returns one linear combination
    // rather than a disjunction of per-partition consequences. m_time starts at 0
    // and is bumped by partition_ineqs() before the union-find is used.
    // m_normalize_factor is not set here; get() assigns it on every call.
    farkas_util::farkas_util(ast_manager& m):
        m(m),
        a(m),
        m_ineqs(m),
        m_split_literals(false),
        m_time(0) {
    }

    // Bring a mixed int/real operand pair to a common sort by wrapping the
    // integer-sorted side in to_real. Pairs of equal sort are left untouched.
    void farkas_util::mk_coerce(expr*& e1, expr*& e2) {
        if (a.is_int(e1)) {
            if (a.is_real(e2)) {
                e1 = a.mk_to_real(e1);
            }
        }
        else if (a.is_real(e1) && a.is_int(e2)) {
            e2 = a.mk_to_real(e2);
        }
    }

    // TBD: arith_decl_util now supports coercion, so this should be deprecated.
    app* farkas_util::mk_add(expr* e1, expr* e2) {
        // Align operand sorts (int promoted to real when mixed), then build e1 + e2.
        mk_coerce(e1, e2);
        app* sum = a.mk_add(e1, e2);
        return sum;
    }

    app* farkas_util::mk_mul(expr* e1, expr* e2) {
        // Align operand sorts (int promoted to real when mixed), then build e1 * e2.
        mk_coerce(e1, e2);
        app* product = a.mk_mul(e1, e2);
        return product;
    }

    app* farkas_util::mk_le(expr* e1, expr* e2) {
        // Align operand sorts (int promoted to real when mixed), then build e1 <= e2.
        mk_coerce(e1, e2);
        app* atom = a.mk_le(e1, e2);
        return atom;
    }

    // Build the non-strict atom e1 >= e2, coercing an integer operand to real
    // when the other operand is real.
    // fix_sign relies on this for two negations: !(x < y) becomes x >= y, and,
    // over the integers, !(x <= y) becomes x >= y + 1.
    app* farkas_util::mk_ge(expr* e1, expr* e2) {
        mk_coerce(e1, e2);
        // Must be non-strict. Building e1 > e2 here over-strengthens both
        // fix_sign negations, so the Farkas combination would be computed from
        // hypotheses stronger than the literals actually supplied.
        return a.mk_ge(e1, e2);
    }

    app* farkas_util::mk_gt(expr* e1, expr* e2) {
        // Align operand sorts (int promoted to real when mixed), then build e1 > e2.
        mk_coerce(e1, e2);
        app* atom = a.mk_gt(e1, e2);
        return atom;
    }

    app* farkas_util::mk_lt(expr* e1, expr* e2) {
        // Align operand sorts (int promoted to real when mixed), then build e1 < e2.
        mk_coerce(e1, e2);
        app* atom = a.mk_lt(e1, e2);
        return atom;
    }

    // Accumulate c * e into res (res := res + c * e).
    // A unit coefficient adds e itself. Otherwise the coefficient numeral is
    // integer-sorted only when c is integral and e is integer-sorted.
    void farkas_util::mul(rational const& c, expr* e, expr_ref& res) {
        expr_ref term(m);
        if (!c.is_one()) {
            bool const int_numeral = c.is_int() && a.is_int(e);
            term = mk_mul(a.mk_numeral(c, int_numeral), e);
        }
        else {
            term = e;
        }
        res = mk_add(res, term);
    }

    // True iff the comparison/equality c relates integer-sorted terms.
    // The sort is read from the first argument.
    bool farkas_util::is_int_sort(app* c) {
        SASSERT(m.is_eq(c) || a.is_le(c) || a.is_lt(c) || a.is_gt(c) || a.is_ge(c));
        expr* lhs = c->get_arg(0);
        SASSERT(a.is_int(lhs) || a.is_real(lhs));
        return a.is_int(lhs);
    }

    bool farkas_util::is_int_sort() {
        SASSERT(!m_ineqs.empty());
        return is_int_sort(m_ineqs[0].get());
    }

    void farkas_util::normalize_coeffs() {
        rational l(1);
        for (unsigned i = 0; i < m_coeffs.size(); ++i) {
            l = lcm(l, denominator(m_coeffs[i]));
        }
        if (!l.is_one()) {
            for (unsigned i = 0; i < m_coeffs.size(); ++i) {
                m_coeffs[i] *= l;
            }
        }
        m_normalize_factor = l;
    }

    // Integer-sorted numeral 1, used to tighten strict integer inequalities.
    app* farkas_util::mk_one() {
        rational const one(1);
        return a.mk_numeral(one, true);
    }

    // Turn the literal c (positive when is_pos, otherwise its negation) into a
    // single un-negated comparison.
    //  - Positive: c is returned unchanged, except that a strict integer
    //    inequality x < y (or y > x) is tightened to x + 1 <= y.
    //  - Negative: the comparison is flipped. Over the integers, !(x <= y)
    //    (or !(y >= x)) becomes x >= y + 1.
    // A negated equality has no case here and reaches UNREACHABLE().
    app* farkas_util::fix_sign(bool is_pos, app* c) {
        expr* x, *y;
        SASSERT(m.is_eq(c) || a.is_le(c) || a.is_lt(c) || a.is_gt(c) || a.is_ge(c));
        bool const is_int = is_int_sort(c);
        if (is_pos) {
            bool const strict = a.is_lt(c, x, y) || a.is_gt(c, y, x);
            if (is_int && strict) {
                return mk_le(mk_add(x, mk_one()), y);
            }
            return c;
        }
        if (is_int && (a.is_le(c, x, y) || a.is_ge(c, y, x))) {
            // !(x <= y) <=> x > y <=> x >= y + 1
            return mk_ge(x, mk_add(y, mk_one()));
        }
        if (a.is_le(c, x, y)) {
            return mk_gt(x, y);
        }
        if (a.is_lt(c, x, y)) {
            return mk_ge(x, y);
        }
        if (a.is_ge(c, x, y)) {
            return mk_lt(x, y);
        }
        if (a.is_gt(c, x, y)) {
            return mk_le(x, y);
        }
        UNREACHABLE();
        return c;
    }

    // Group the collected literals into classes that are connected through
    // shared uninterpreted subterms. process_term() merges each literal's id
    // with those subterms in the union-find (find/merge).
    // On return, m_ineqs, m_coeffs and m_reps have been permuted in lockstep so
    // that each class occupies a contiguous index range, and m_his holds the
    // exclusive end index of each range in order. get() walks the ranges as
    // [lo, hi) pairs starting from lo = 0.
    void farkas_util::partition_ineqs() {
        m_reps.reset();
        m_his.reset();
        // Start a new union-find epoch: find() treats entries stamped with an
        // older m_time as fresh singletons, so no explicit clearing is needed.
        ++m_time;
        for (unsigned i = 0; i < m_ineqs.size(); ++i) {
            m_reps.push_back(process_term(m_ineqs[i].get()));
        }
        unsigned head = 0;
        while (head < m_ineqs.size()) {
            // Class of the first literal not yet placed in a range.
            unsigned r = find(m_reps[head]);
            unsigned tail = head;
            // Move every later literal in the same class to just after tail,
            // keeping the three parallel arrays aligned.
            for (unsigned i = head+1; i < m_ineqs.size(); ++i) {
                if (find(m_reps[i]) == r) {
                    ++tail;
                    if (tail != i) {
                        SASSERT(tail < i);
                        std::swap(m_reps[tail], m_reps[i]);
                        app_ref tmp(m);
                        tmp = m_ineqs[i].get();
                        m_ineqs[i] = m_ineqs[tail].get();
                        m_ineqs[tail] = tmp;
                        std::swap(m_coeffs[tail], m_coeffs[i]);
                    }
                }
            }
            // [previous head, tail] is one class; record its exclusive end.
            head = tail + 1;
            m_his.push_back(head);
        }
    }

    unsigned farkas_util::find(unsigned idx) {
        if (m_ts.size() <= idx) {
            m_roots.resize(idx+1);
            m_size.resize(idx+1);
            m_ts.resize(idx+1); 
            m_roots[idx] = idx;
            m_ts[idx] = m_time;
            m_size[idx] = 1;                
            return idx;
        }
        if (m_ts[idx] != m_time) {
            m_size[idx] = 1;
            m_ts[idx]    = m_time;
            m_roots[idx] = idx;
            return idx;
        }
        while (true) {
            if (m_roots[idx] == idx) {
                return idx;
            }
            idx = m_roots[idx];
        }
    }

    // Union the classes of i and j by size: the root of the smaller tree is
    // attached under the larger one. On ties, i's root goes under j's.
    void farkas_util::merge(unsigned i, unsigned j) {
        unsigned const root_i = find(i);
        unsigned const root_j = find(j);
        if (root_i == root_j) {
            return;
        }
        bool const i_larger = m_size[root_i] > m_size[root_j];
        unsigned const child  = i_larger ? root_j : root_i;
        unsigned const parent = i_larger ? root_i : root_j;
        m_roots[child] = parent;
        m_size[parent] += m_size[child];
    }
    // Walk every distinct subterm of e (iterative DFS, each node visited once)
    // and merge e's id with the id of every uninterpreted subterm found.
    // Returns e's id, which callers use as the literal's union-find handle.
    unsigned farkas_util::process_term(expr* e) {
        unsigned const root_id = e->get_id();
        ptr_vector<expr> stack;
        ast_mark visited;
        stack.push_back(e);
        while (!stack.empty()) {
            expr* cur = stack.back();
            stack.pop_back();
            if (visited.is_marked(cur)) {
                continue;
            }
            visited.mark(cur, true);
            if (is_uninterp(cur)) {
                merge(root_id, cur->get_id());
            }
            if (!is_app(cur)) {
                continue;
            }
            // Named 'ap' rather than 'a' to avoid shadowing the arith_util member.
            app* ap = to_app(cur);
            unsigned const num_args = ap->get_num_args();
            for (unsigned k = 0; k < num_args; ++k) {
                stack.push_back(ap->get_arg(k));
            }
        }
        return root_id;
    }
    // Form the linear combination of m_ineqs[lo, hi) weighted by m_coeffs[lo, hi)
    // and return the negation of the resulting single comparison.
    //
    // Each literal relating x and y contributes coeff*x - coeff*y to a running
    // sum; (y > x) and (y >= x) are read as x < y and x <= y. The combined
    // relation is '=' when every literal in the range is an equality, '<' when
    // at least one literal is strict, and '<=' otherwise. The formula
    // not(sum op 0) is simplified by th_rewriter with gcd_rounding enabled and
    // then passed through fix_dl. Literals matching none of the eq/lt/gt/le/ge
    // patterns contribute nothing to the sum.
    expr_ref farkas_util::extract_consequence(unsigned lo, unsigned hi) {
        // NOTE(review): the initial sort comes from m_ineqs[0], not m_ineqs[lo];
        // presumably all collected literals share one sort — confirm.
        bool is_int = is_int_sort();
        app_ref zero(a.mk_numeral(rational::zero(), is_int), m);
        expr_ref res(m);
        res = zero;
        bool is_strict = false;
        // Stays true only if every literal in the range is an equality.
        bool is_eq     = true;
        expr* x, *y;
        for (unsigned i = lo; i < hi; ++i) {
            app* c = m_ineqs[i].get();
            if (m.is_eq(c, x, y)) {
                mul(m_coeffs[i],  x, res);
                mul(-m_coeffs[i], y, res);
            }
            // x < y, written either as (x < y) or as (y > x).
            if (a.is_lt(c, x, y) || a.is_gt(c, y, x)) {
                mul(m_coeffs[i],  x, res);
                mul(-m_coeffs[i], y, res);
                is_strict = true;
                is_eq = false;
            }
            // x <= y, written either as (x <= y) or as (y >= x).
            if (a.is_le(c, x, y) || a.is_ge(c, y, x)) {
                mul(m_coeffs[i],  x, res);
                mul(-m_coeffs[i], y, res);
                is_eq = false;
            }
        }
        
        // Rebuild zero in the sort of the accumulated sum, which can differ from
        // the initial guess once mk_coerce has promoted operands to real.
        zero = a.mk_numeral(rational::zero(), a.is_int(res));
        if (is_eq) {
            res = m.mk_eq(res, zero);
        }
        else if (is_strict) {
            res = mk_lt(res, zero);
        }
        else {
            res = mk_le(res, zero);
        }            
        res = m.mk_not(res);
        th_rewriter rw(m);
        params_ref params;
        params.set_bool("gcd_rounding", true);
        rw.updt_params(params);
        proof_ref pr(m);
        expr_ref result(m);
        rw(res, result, pr);
        fix_dl(result);
        return result;            
    }

    // Cosmetic reordering of a (possibly negated) comparison or equality.
    // When the left-hand side has the form (mono + rest) and mono is a product,
    // the sum is rebuilt as (rest + mono) under the same relation. Negations
    // are descended into and re-applied; anything else is left untouched.
    void farkas_util::fix_dl(expr_ref& r) {
        expr* inner;
        if (m.is_not(r, inner)) {
            r = inner;
            fix_dl(r);
            r = m.mk_not(r);
            return;
        }
        expr* lhs, *rhs;
        bool const is_cmp =
            m.is_eq(r, lhs, rhs) ||
            a.is_lt(r, lhs, rhs) || a.is_gt(r, lhs, rhs) ||
            a.is_le(r, lhs, rhs) || a.is_ge(r, lhs, rhs);
        if (!is_cmp) {
            return;
        }
        expr* mono, *rest;
        if (!a.is_add(lhs, mono, rest) || !a.is_mul(mono)) {
            return;
        }
        r = m.mk_app(to_app(r)->get_decl(), a.mk_add(rest, mono), rhs);
    }

    // Discard all collected (coefficient, literal) pairs.
    void farkas_util::reset() {
        m_coeffs.reset();
        m_ineqs.reset();
    }
    
    // Record the weighted literal coef * c. Leading negations are peeled off,
    // and their parity decides the polarity passed to fix_sign. Pairs with a
    // zero coefficient, or whose stripped literal is 'true', are ignored.
    void farkas_util::add(rational const & coef, app * c) {
        unsigned num_negations = 0;
        expr* body;
        while (m.is_not(c, body)) {
            ++num_negations;
            c = to_app(body);
        }
        if (coef.is_zero() || m.is_true(c)) {
            return;
        }
        m_coeffs.push_back(coef);
        m_ineqs.push_back(fix_sign(num_negations % 2 == 0, c));
    }
    
    // Produce the consequence of the collected weighted literals.
    // With no literals the result is 'false'. For integer-sorted literals the
    // coefficients are first scaled to integers (see m_normalize_factor).
    // By default a single combined consequence is returned. With
    // m_split_literals, the literals are partitioned into classes connected by
    // shared uninterpreted symbols, and the disjunction of the per-class
    // consequences is returned.
    expr_ref farkas_util::get() {
        m_normalize_factor = rational::one();
        expr_ref res(m);
        if (m_coeffs.empty()) {
            res = m.mk_false();
            return res;
        }
        if (is_int_sort()) {
            normalize_coeffs();
        }

        if (!m_split_literals) {
            res = extract_consequence(0, m_coeffs.size());
        }
        else {
            partition_ineqs();
            expr_ref_vector lits(m);
            unsigned lo = 0;
            for (unsigned part = 0; part < m_his.size(); ++part) {
                unsigned const hi = m_his[part];
                lits.push_back(extract_consequence(lo, hi));
                lo = hi;
            }
            bool_rewriter(m).mk_or(lits.size(), lits.c_ptr(), res);
            IF_VERBOSE(2, { if (lits.size() > 1) { verbose_stream() << "combined lemma: " << mk_pp(res, m) << "\n"; } });
        }

        TRACE("arith", 
              for (unsigned i = 0; i < m_coeffs.size(); ++i) {
                  tout << m_coeffs[i] << " * (" << mk_pp(m_ineqs[i].get(), m) << ") ";
              }
              tout << "\n";
              tout << res << "\n";
              );

        return res;
    }
}

