﻿/**
\private
*/
#pragma once
#include "constraint.h"
#include "general_constraint.h"
#include "linear_expression.h"
#include "quad_expression.h"
#include "general_expression.h"
#include "variable.h"
#include "../arhiplex_itf.h"

namespace arhiplex
{
/// Collection of operator overloads (defined as inline friend functions) and static
/// helpers that combine variables, numeric constants and linear / quadratic / general
/// expressions into new expressions and constraints.
///
/// Conventions used by the comparison overloads visible in this class:
///  - `expr OP constant` is stored as `{expr, sense, constant}`;
///  - `constant OP expr` is stored with the expression on the left and the sense mirrored;
///  - `lhs OP rhs` with two non-constant operands is stored as `{lhs - rhs, sense, 0.0}`
///    (or `{rhs - lhs, mirrored sense, 0.0}` when only `rhs - lhs` is available).
class Operators
{
  public:
    /// Returns the result of `GetLinearPart()` called on the object held by `expr`.
    friend LinearExpression GetLinearExpression(const QuadExpression& expr)
    {
        return expr.Get()->GetLinearPart();
    }

    /// Returns the result of `GetQuadPart()` called on the object held by `expr`.
    static QuadExpression GetQuadExpression(const LinearExpression& expr)
    {
        return expr.Get()->GetQuadPart();
    }

    /// Computes `(expr * expr_coef) * var` as a quadratic expression.
    ///
    /// \param expr       linear expression (terms plus constant) to multiply
    /// \param var        variable every term of `expr` is multiplied by
    /// \param expr_coef  scalar applied to every coefficient of `expr` (default 1.0)
    /// \return quadratic expression holding one `var * term_variable` product per term
    ///         with a non-zero scaled coefficient, plus the linear term
    ///         `var * (constant * expr_coef)` when that factor is non-zero
    static QuadExpression multiply(const LinearExpression& expr, const Variable& var, const double expr_coef = 1.0)
    {
        QuadExpression res;
        for (int i = 0; i < expr.GetTermsCount(); ++i)
        {
            const auto var1 = expr.GetTermVariable(i);
            // Compute the scaled coefficient once; zero products are not stored.
            const auto result_coef = expr.GetTermCoeff(i) * expr_coef;
            if (result_coef != 0.0)
                res.AddTerm(var, var1, result_coef);
        }

        // The constant of `expr` times `var` contributes a purely linear term.
        const auto const_coef = expr.GetConstant() * expr_coef;
        if (const_coef != 0.0)
            res.AddExpression(var * const_coef);

        return res;
    }

    /// Product of two variables: the quadratic expression `var1 * var2`.
    friend QuadExpression operator*(const Variable &var1, const Variable &var2)
    {
        return QuadExpression(var1, var2);
    }

    /// Product `expr * var`; delegates to multiply() with a unit scale factor.
    friend QuadExpression operator*(const LinearExpression& expr, const Variable& var)
    {
        return multiply(expr, var);
    }

    /// Product of two linear expressions, expanded term by term.
    ///
    /// With `expr1 = sum_i a_i*x_i + c1` and `expr2 = sum_j b_j*y_j + c2` the result holds
    ///  - quadratic terms `a_i*b_j * x_i*y_j`,
    ///  - linear terms `c1*b_j * y_j` and `c2*a_i * x_i`,
    ///  - the constant `c1*c2`.
    /// Contributions whose coefficient evaluates to exactly zero are not added.
    friend QuadExpression operator*(const LinearExpression& expr1, const LinearExpression& expr2)
    {
        QuadExpression res;

        const double c1 = expr1.GetConstant();
        const double c2 = expr2.GetConstant();

        // quadratic part from linear terms
        for (int i = 0; i < expr1.GetTermsCount(); ++i) {
            const auto v1 = expr1.GetTermVariable(i);
            const double a = expr1.GetTermCoeff(i);
            if (a == 0.0) {
                // the whole inner row would be zero
                continue;
            }
            for (int j = 0; j < expr2.GetTermsCount(); ++j) {
                const auto v2 = expr2.GetTermVariable(j);
                const double b = expr2.GetTermCoeff(j);
                const double coef = a * b;
                if (coef != 0.0) {
                    res.AddTerm(v1, v2, coef);
                }
            }
        }

        // linear part: constant of expr1 times the terms of expr2
        if (c1 != 0.0) {
            for (int j = 0; j < expr2.GetTermsCount(); ++j) {
                const auto v2 = expr2.GetTermVariable(j);
                const double b = expr2.GetTermCoeff(j);
                const double coef = c1 * b;
                if (coef != 0.0) {
                    res.AddExpression(v2 * coef);
                }
            }
        }
        // linear part: constant of expr2 times the terms of expr1
        if (c2 != 0.0) {
            for (int i = 0; i < expr1.GetTermsCount(); ++i) {
                const auto v1 = expr1.GetTermVariable(i);
                const double a = expr1.GetTermCoeff(i);
                const double coef = c2 * a;
                if (coef != 0.0) {
                    res.AddExpression(v1 * coef);
                }
            }
        }

        // constant term
        if (c1 != 0.0 && c2 != 0.0) {
            // QuadExpression has no constant API, but its linear part shares the same Expression offset.
            res.AddExpression(LinearExpression(c1 * c2));
        }

        return res;
    }

    /// Product `var * expr`; forwards to `expr * var`.
    friend QuadExpression operator*(const Variable& var, const LinearExpression& expr)
    {
        return expr * var;
    }

    /// Linear expression `var + a`.
    friend LinearExpression operator+(const Variable &var, double a)
    {
        LinearExpression tmp(var, 1.0);
        tmp.SetConstant(a);
        return tmp;
    }

    /// Linear expression `a + var`.
    friend LinearExpression operator+(double a, const Variable &var)
    {
        LinearExpression tmp(var, 1.0);
        tmp.SetConstant(a);
        return tmp;
    }

    /// Linear expression `left + right`.
    friend LinearExpression operator+(const Variable &left, const Variable &right)
    {
        LinearExpression tmp(left, 1.0);
        tmp.AddTerm(right, 1.0);
        return tmp;
    }

    /// Linear expression `var - a`.
    friend LinearExpression operator-(const Variable &var, double a)
    {
        LinearExpression tmp(var, 1.0);
        tmp.SetConstant(-a);
        return tmp;
    }

    /// Linear expression `a - var` (coefficient -1 on `var`, constant `a`).
    friend LinearExpression operator-(double a, const Variable &var)
    {
        LinearExpression tmp(var, -1.0);
        tmp.SetConstant(a);
        return tmp;
    }

    /// Unary minus: linear expression `-var`, built as `0.0 - var`.
    friend LinearExpression operator-(const Variable &var)
    {
        return 0.0 - var;
    }

    /// Linear expression `left - right`.
    friend LinearExpression operator-(const Variable &left, const Variable &right)
    {
        LinearExpression tmp(left, 1.0);
        tmp.AddTerm(right, -1.0);
        return tmp;
    }

    /// Linear expression `a * var`.
    friend LinearExpression operator*(double a, const Variable &var)
    {
        return LinearExpression(var, a);
    }

    /// Linear expression `var * a`.
    friend LinearExpression operator*(const Variable &var, double a)
    {
        return LinearExpression(var, a);
    }

    // The overloads below wrap the object behind the linear operand in a
    // GeneralExpression and defer to the corresponding GeneralExpression operator.
    // NOTE(review): whether GeneralExpression(expr.Get()) copies or shares the state of
    // the linear operand is not visible in this file - confirm the operand is not mutated.

    /// General expression `expr1 * expr2` (linear times quadratic).
    friend GeneralExpression operator*(const LinearExpression &expr1, const QuadExpression &expr2) {
        GeneralExpression tmp(expr1.Get());
        return tmp * expr2;
    }

    /// General expression `expr1 * expr2` (linear times general).
    friend GeneralExpression operator*(const LinearExpression &expr1, const GeneralExpression &expr2) {
        GeneralExpression tmp(expr1.Get());
        return tmp * expr2;
    }

    /// General expression `expr / var`.
    friend GeneralExpression operator/(const LinearExpression &expr, const Variable &var) {
        GeneralExpression tmp(expr.Get());
        return tmp / var;
    }

    /// General expression `expr1 / expr2` (linear over linear).
    friend GeneralExpression operator/(const LinearExpression &expr1, const LinearExpression &expr2) {
        GeneralExpression tmp(expr1.Get());
        return tmp / expr2;
    }

    /// General expression `expr1 / expr2` (linear over quadratic).
    friend GeneralExpression operator/(const LinearExpression &expr1, const QuadExpression &expr2) {
        GeneralExpression tmp(expr1.Get());
        return tmp / expr2;
    }

    /// General expression `expr1 / expr2` (linear over general).
    friend GeneralExpression operator/(const LinearExpression &expr1, const GeneralExpression &expr2) {
        GeneralExpression tmp(expr1.Get());
        return tmp / expr2;
    }

    /// General expression `left - right` (linear minus general).
    friend GeneralExpression operator-(const LinearExpression &left, const GeneralExpression &right) {
        GeneralExpression tmp(left.Get());
        return tmp - right;
    }

    /// General expression `left + right` (linear plus general).
    friend GeneralExpression operator+(const LinearExpression &left, const GeneralExpression &right) {
        GeneralExpression tmp(left.Get());
        return tmp + right;
    }

    /// Linear expression `var / a`, stored as `var` with coefficient `1.0 / a`.
    /// No check is made for `a == 0.0`; the coefficient is then whatever `1.0 / 0.0` yields.
    friend LinearExpression operator/(const Variable &var, double a)
    {
        return LinearExpression(var, 1.0 / a);
    }

    // ---- Variable / constant comparisons ----

    /// Constraint `right == a`.
    friend Constraint operator==(double a, const Variable &right)
    {
        return {right, constraint_sense::equal, a};
    }

    /// Constraint `left == a`.
    friend Constraint operator==(const Variable &left, double a)
    {
        return {left, constraint_sense::equal, a};
    }

    /// Constraint `left - right == 0`.
    friend Constraint operator==(const Variable &left, const Variable &right)
    {
        return {left - right, constraint_sense::equal, 0.0};
    }

    /// Constraint `a >= right`, stored mirrored as `right <= a`.
    friend Constraint operator>=(double a, const Variable &right)
    {
        return {right, constraint_sense::less_equal, a};
    }

    /// Constraint `left >= a`.
    friend Constraint operator>=(const Variable &left, double a)
    {
        return {left, constraint_sense::greater_equal, a};
    }

    /// Constraint `left - right >= 0`.
    friend Constraint operator>=(const Variable &left, const Variable &right)
    {
        return {left - right, constraint_sense::greater_equal, 0.0};
    }

    /// Constraint `a <= right`, stored mirrored as `right >= a`.
    friend Constraint operator<=(double a, const Variable &right)
    {
        return {right, constraint_sense::greater_equal, a};
    }

    /// Constraint `left <= a`.
    friend Constraint operator<=(const Variable &left, double a)
    {
        return {left, constraint_sense::less_equal, a};
    }

    /// Constraint `left - right <= 0`.
    friend Constraint operator<=(const Variable &left, const Variable &right)
    {
        return {left - right, constraint_sense::less_equal, 0.0};
    }

    // ---- LinearExpression comparisons ----

    /// Constraint `left >= a`.
    friend Constraint operator>=(const LinearExpression &left, double a)
    {
        return {left, constraint_sense::greater_equal, a};
    }

    /// Constraint `a >= right`, stored mirrored as `right <= a`.
    friend Constraint operator>=(double a, const LinearExpression &right)
    {
        return {right, constraint_sense::less_equal, a};
    }

    /// Constraint `left - var >= 0`.
    friend Constraint operator>=(const LinearExpression &left, const Variable &var)
    {
        return {left - var, constraint_sense::greater_equal, 0.0};
    }

    /// Constraint `var >= right`, stored mirrored as `right - var <= 0`.
    friend Constraint operator>=(const Variable &var, const LinearExpression &right)
    {
        return {right - var, constraint_sense::less_equal, 0.0};
    }

    /// Constraint `left - right >= 0`.
    friend Constraint operator>=(const LinearExpression &left, const LinearExpression &right)
    {
        return {left - right, constraint_sense::greater_equal, 0.0};
    }

    /// Constraint `left <= a`.
    friend Constraint operator<=(const LinearExpression &left, double a)
    {
        return {left, constraint_sense::less_equal, a};
    }

    /// Constraint `a <= right`, stored mirrored as `right >= a`.
    friend Constraint operator<=(double a, const LinearExpression &right)
    {
        return {right, constraint_sense::greater_equal, a};
    }

    /// Constraint `left - var <= 0`.
    friend Constraint operator<=(const LinearExpression &left, const Variable &var)
    {
        return {left - var, constraint_sense::less_equal, 0.0};
    }

    /// Constraint `var <= right`, stored mirrored as `right - var >= 0`.
    friend Constraint operator<=(const Variable &var, const LinearExpression &right)
    {
        return {right - var, constraint_sense::greater_equal, 0.0};
    }

    /// Constraint `left - right <= 0`.
    friend Constraint operator<=(const LinearExpression &left, const LinearExpression &right)
    {
        return {left - right, constraint_sense::less_equal, 0.0};
    }

    /// Constraint `left == a`.
    friend Constraint operator==(const LinearExpression &left, double a)
    {
        return {left, constraint_sense::equal, a};
    }

    /// Constraint `right == a`.
    friend Constraint operator==(double a, const LinearExpression &right)
    {
        return {right, constraint_sense::equal, a};
    }

    /// Constraint `left - var == 0`.
    friend Constraint operator==(const LinearExpression &left, const Variable &var)
    {
        return {left - var, constraint_sense::equal, 0.0};
    }

    /// Constraint `right - var == 0`.
    friend Constraint operator==(const Variable &var, const LinearExpression &right)
    {
        return {right - var, constraint_sense::equal, 0.0};
    }

    /// Constraint `left - right == 0`.
    friend Constraint operator==(const LinearExpression &left, const LinearExpression &right)
    {
        return {left - right, constraint_sense::equal, 0.0};
    }

    /// Friend free-function counterpart of the static member GetQuadExpression();
    /// returns the result of `GetQuadPart()` called on the object held by `expr`.
    /// Both entry points are kept so existing callers of either form keep working.
    friend QuadExpression GetQuadExpression(const LinearExpression &expr)
    {
        return expr.Get()->GetQuadPart();
    }

    // ---- LinearExpression combined with QuadExpression ----

    /// `left - right`: works on a free copy of `left` and subtracts `right` through the
    /// copy's quadratic part.
    /// NOTE(review): relies on `quad` aliasing the state of `ret` (handle semantics);
    /// if GetQuadPart() returned a detached value the subtraction would be lost - confirm.
    friend LinearExpression operator-(const LinearExpression &left, const QuadExpression &right)
    {
        LinearExpression ret = left.CreateFreeCopy();
        QuadExpression quad(GetQuadExpression(ret));
        quad -= right;
        return ret;
    }

    /// `left + right`: works on a free copy of `left` and adds `right` through the
    /// copy's quadratic part.
    /// NOTE(review): same aliasing assumption as operator-(LinearExpression, QuadExpression).
    friend LinearExpression operator+(const LinearExpression &left, const QuadExpression &right)
    {
        LinearExpression ret = left.CreateFreeCopy();
        QuadExpression quad(GetQuadExpression(ret));
        quad += right;
        return ret;
    }

    // ---- QuadExpression comparisons ----

    /// Constraint `left - right == 0`.
    friend Constraint operator==(const QuadExpression &left, const LinearExpression &right)
    {
        return {left - right, constraint_sense::equal, 0.0};
    }

    /// Constraint `left - var == 0`, built on a free copy of `left` whose linear part has
    /// `var` subtracted.
    /// NOTE(review): relies on `linear` aliasing the state of `ret` - confirm.
    friend Constraint operator==(const QuadExpression &left, const Variable &var)
    {
        QuadExpression ret = left.CreateFreeCopy();
        LinearExpression linear(ret.GetLinearPart());
        linear -= var;
        return {std::move(ret), constraint_sense::equal, 0.0};
    }

    /// Constraint `left - right == 0`.
    friend Constraint operator==(const QuadExpression &left, const QuadExpression &right)
    {
        return {left - right, constraint_sense::equal, 0.0};
    }

    /// Constraint `left - c <= 0`, built on a free copy of `left` whose linear part has
    /// `c` subtracted.
    /// NOTE(review): relies on `linear` aliasing the state of `ret` - confirm.
    friend Constraint operator<=(const QuadExpression& left, double c)
    {
        QuadExpression ret = left.CreateFreeCopy();
        LinearExpression linear(ret.GetLinearPart());
        linear -= c;
        return { std::move(ret), constraint_sense::less_equal, 0.0 };
    }

    /// Constraint `left - c == 0`, built on a free copy of `left` whose linear part has
    /// `c` subtracted.
    /// NOTE(review): relies on `linear` aliasing the state of `ret` - confirm.
    friend Constraint operator==(const QuadExpression& left, double c)
    {
        QuadExpression ret = left.CreateFreeCopy();
        LinearExpression linear(ret.GetLinearPart());
        linear -= c;
        return { std::move(ret), constraint_sense::equal, 0.0 };
    }

    /// Constraint `left - right >= 0`.
    friend Constraint operator>=(const QuadExpression& left, const QuadExpression& right)
    {
        return { left - right, constraint_sense::greater_equal, 0.0 };
    }

    /// Constraint `left - right <= 0`.
    friend Constraint operator<=(const QuadExpression& left, const QuadExpression& right)
    {
        // `left <= right` must use less_equal, matching every other `<=` overload here.
        return { left - right, constraint_sense::less_equal, 0.0 };
    }

    /// Constraint `left - right >= 0`.
    friend Constraint operator>=(const QuadExpression& left, const LinearExpression& right)
    {
        return { left - right, constraint_sense::greater_equal, 0.0 };
    }

    /// Constraint `left - right <= 0`.
    friend Constraint operator<=(const QuadExpression& left, const LinearExpression& right)
    {
        return { left - right, constraint_sense::less_equal, 0.0 };
    }

    /// Constraint `left - var >= 0`.
    friend Constraint operator>=(const QuadExpression& left, const Variable& var)
    {
        return { left - var, constraint_sense::greater_equal, 0.0 };
    }

    /// Constraint `left - var <= 0`.
    friend Constraint operator<=(const QuadExpression& left, const Variable& var)
    {
        return { left - var, constraint_sense::less_equal, 0.0 };
    }

    /// Constraint `left - c >= 0`.
    friend Constraint operator>=(const QuadExpression& left, double c)
    {
        return { left - c, constraint_sense::greater_equal, 0.0 };
    }

    // ---- GeneralExpression helpers ----

    /// General constraint constructed from the pair `(left, right)`.
    friend GeneralConstraint operator==(const GeneralExpression &left, const Variable &right) {
        return GeneralConstraint(left, right);
    }

    /// General expression `left - right` (quadratic minus general).
    /// NOTE(review): `expr` is built from `left.Get()`; whether it copies or shares the
    /// state of `left` is not visible in this file.
    friend GeneralExpression operator-(const QuadExpression &left, const GeneralExpression &right) {
        GeneralExpression expr(left.Get());
        return expr - right;
    }

    /// General expression `left + right` (quadratic plus general).
    /// NOTE(review): `expr` is built from `left.Get()`; whether it copies or shares the
    /// state of `left` is not visible in this file.
    friend GeneralExpression operator+(const QuadExpression &left, const GeneralExpression &right) {
        GeneralExpression expr(left.Get());
        return expr + right;
    }

};

}
