Skip to content

Commit

Permalink
r
Browse files Browse the repository at this point in the history
  • Loading branch information
a-zakir committed Feb 20, 2025
1 parent fd94a20 commit cdabd9e
Show file tree
Hide file tree
Showing 12 changed files with 44 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -239,15 +239,15 @@ class EvalVisitor: public NodeVisitor<EvaluationResult>
* @brief Constructs an evaluation visitor with the specified context.
*
* @param context The evaluation context.
* @param dataSeriesKeys
* @param fillContext
*/
explicit EvalVisitor(EvaluationContext context,
Optimisation::LinearProblemApi::DataSeriesKeys dataSeriesKeys);
Optimisation::LinearProblemApi::FillContext fillContext);
std::string name() const override;

private:
const EvaluationContext context_;
Optimisation::LinearProblemApi::DataSeriesKeys dataSeriesKeys_;
Optimisation::LinearProblemApi::FillContext fillContext_;
EvaluationResult visit(const Nodes::SumNode* node) override;
EvaluationResult visit(const Nodes::SubtractionNode* node) override;
EvaluationResult visit(const Nodes::MultiplicationNode* node) override;
Expand Down
15 changes: 6 additions & 9 deletions src/expressions/visitors/EvalVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
namespace Antares::Expressions::Visitors
{
EvalVisitor::EvalVisitor(EvaluationContext context,
Optimisation::LinearProblemApi::DataSeriesKeys dataSeriesKeys):
Optimisation::LinearProblemApi::FillContext fillContext):
context_(std::move(context)),
dataSeriesKeys_(std::move(dataSeriesKeys))
fillContext_(std::move(fillContext))
{
}

Expand Down Expand Up @@ -89,15 +89,12 @@ EvaluationResult EvalVisitor::visit(const Nodes::ParameterNode* node)
else
{
std::vector<double> params;
params.reserve(dataSeriesKeys_.fillContext.getNumberOfTimestep());
for (auto timeStep = dataSeriesKeys_.fillContext.getFirstTimeStep();
timeStep <= dataSeriesKeys_.fillContext.getLastTimeStep();
params.reserve(fillContext_.getNumberOfTimestep());
for (auto timeStep = fillContext_.getFirstTimeStep();
timeStep <= fillContext_.getLastTimeStep();
++timeStep)
{
params.emplace_back(context_.getParameterValue(node->value(),
dataSeriesKeys_.scenarioGroup,
dataSeriesKeys_.scenario,
timeStep));
params.emplace_back(context_.getParameterValue(node->value(), "", 0, timeStep));
}
return EvaluationResult{params};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,6 @@ class FillContext
unsigned lastTimeStep = 0;
};

struct DataSeriesKeys
{
FillContext fillContext;
std::string scenarioGroup;
unsigned scenario;
};

class ILinearProblemData
{
Expand Down
15 changes: 3 additions & 12 deletions src/solver/optim-model-filler/ComponentFiller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,7 @@ void ComponentFiller::addVariables(Optimisation::LinearProblemApi::ILinearProble
{},
data);

Expressions::Visitors::EvalVisitor evaluator(evaluationContext,
{.fillContext = ctx,
.scenarioGroup = component_.getScenarioGroupId(),
.scenario = 0});
Expressions::Visitors::EvalVisitor evaluator(evaluationContext, ctx);
for (const auto& variable: component_.getModel()->Variables() | std::views::values)
{
const auto& lb = evaluator.dispatch(variable.LowerBound().RootNode());
Expand Down Expand Up @@ -256,10 +253,7 @@ void ComponentFiller::addConstraints(Optimisation::LinearProblemApi::ILinearProb
Expressions::Visitors::EvaluationContext evaluationContext(component_.getParameterValues(),
{},
data);
ReadLinearConstraintVisitor visitor(evaluationContext,
{.fillContext = ctx,
.scenarioGroup = component_.getScenarioGroupId(),
.scenario = 0});
ReadLinearConstraintVisitor visitor(evaluationContext, ctx);
for (const auto& constraint: component_.getModel()->getConstraints() | std::views::values)
{
auto* root_node = constraint.expression().RootNode();
Expand Down Expand Up @@ -292,10 +286,7 @@ void ComponentFiller::addObjective(Optimisation::LinearProblemApi::ILinearProble
{},
data);

ReadLinearExpressionVisitor visitor(evaluationContext,
{.fillContext = ctx,
.scenarioGroup = component_.getScenarioGroupId(),
.scenario = 0});
ReadLinearExpressionVisitor visitor(evaluationContext, ctx);

auto linear_expressions = visitor.dispatch(model->Objective().RootNode())
.GetLinearExpressions();
Expand Down
4 changes: 2 additions & 2 deletions src/solver/optim-model-filler/ReadLinearConstraintVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ namespace Antares::Optimization

ReadLinearConstraintVisitor::ReadLinearConstraintVisitor(
Expressions::Visitors::EvaluationContext context,
const Optimisation::LinearProblemApi::DataSeriesKeys& dataSeriesKeys):
linear_expression_visitor_(std::move(context), dataSeriesKeys)
const Optimisation::LinearProblemApi::FillContext& fillContext):
linear_expression_visitor_(std::move(context), fillContext)
{
}

Expand Down
22 changes: 9 additions & 13 deletions src/solver/optim-model-filler/ReadLinearExpressionVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ namespace Antares::Optimization

ReadLinearExpressionVisitor::ReadLinearExpressionVisitor(
Expressions::Visitors::EvaluationContext context,
Optimisation::LinearProblemApi::DataSeriesKeys dataSeriesKeys):
Optimisation::LinearProblemApi::FillContext fillContext):
context_(std::move(context)),
dataSeriesKeys_(std::move(dataSeriesKeys))
fillContext_(std::move(fillContext))
{
}

Expand All @@ -49,7 +49,7 @@ TimeDependentLinearExpression ReadLinearExpressionVisitor::visit(const SumNode*
auto operands = node->getOperands();
return std::accumulate(std::begin(operands),
std::end(operands),
TimeDependentLinearExpression(dataSeriesKeys_.fillContext),
TimeDependentLinearExpression(fillContext_),
[this](TimeDependentLinearExpression sum, Node* operand)
{ return sum + dispatch(operand); });
}
Expand Down Expand Up @@ -91,8 +91,7 @@ TimeDependentLinearExpression ReadLinearExpressionVisitor::visit(const NegationN

TimeDependentLinearExpression ReadLinearExpressionVisitor::visit(const VariableNode* node)
{
return TimeDependentLinearExpression(dataSeriesKeys_.fillContext,
LinearExpression(0, {{node->value(), 1}}));
return TimeDependentLinearExpression(fillContext_, LinearExpression(0, {{node->value(), 1}}));
}

TimeDependentLinearExpression ReadLinearExpressionVisitor::visit(const ParameterNode* node)
Expand All @@ -108,22 +107,19 @@ TimeDependentLinearExpression ReadLinearExpressionVisitor::visit(const Parameter
else if (systemParameter.type == Expressions::Visitors::ParameterType::CONSTANT)
{
return TimeDependentLinearExpression(
dataSeriesKeys_.fillContext,
fillContext_,
LinearExpression(context_.getSystemParameterValueAsDouble(node->value()), {}));
}
else // only timedepent
{
std::map<unsigned int, LinearExpression> linearExpressions;

for (auto timeStep = dataSeriesKeys_.fillContext.getFirstTimeStep();
timeStep <= dataSeriesKeys_.fillContext.getLastTimeStep();
for (auto timeStep = fillContext_.getFirstTimeStep();
timeStep <= fillContext_.getLastTimeStep();
++timeStep)
{
linearExpressions[timeStep] = LinearExpression(
context_.getParameterValue(node->value(),
dataSeriesKeys_.scenarioGroup,
dataSeriesKeys_.scenario,
timeStep),
context_.getParameterValue(node->value(), "", 0, timeStep),
{});
}
return TimeDependentLinearExpression(linearExpressions);
Expand All @@ -132,7 +128,7 @@ TimeDependentLinearExpression ReadLinearExpressionVisitor::visit(const Parameter

TimeDependentLinearExpression ReadLinearExpressionVisitor::visit(const LiteralNode* node)
{
return TimeDependentLinearExpression(dataSeriesKeys_.fillContext,
return TimeDependentLinearExpression(fillContext_,
LinearExpression(node->value(), {}));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class ReadLinearConstraintVisitor
ReadLinearConstraintVisitor() = default;
explicit ReadLinearConstraintVisitor(
Expressions::Visitors::EvaluationContext context,
const Optimisation::LinearProblemApi::DataSeriesKeys& dataSeriesKeys);
const Optimisation::LinearProblemApi::FillContext& fillContext);
std::string name() const override;

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class ReadLinearExpressionVisitor
public:
explicit ReadLinearExpressionVisitor(
Expressions::Visitors::EvaluationContext context,
Optimisation::LinearProblemApi::DataSeriesKeys dataSeriesKeys);
Optimisation::LinearProblemApi::FillContext fillContext);

ReadLinearExpressionVisitor() = default;
std::string name() const override;
Expand Down Expand Up @@ -69,6 +69,6 @@ class ReadLinearExpressionVisitor
TimeDependentLinearExpression visit(
const Expressions::Nodes::ComponentParameterNode* node) override;

Optimisation::LinearProblemApi::DataSeriesKeys dataSeriesKeys_;
Optimisation::LinearProblemApi::FillContext fillContext_;
};
} // namespace Antares::Optimization
3 changes: 1 addition & 2 deletions src/tests/src/expressions/test_DeepWideTrees.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ struct MyDummyFixture: Registry<Node>
{
Antares::Optimisation::LinearProblemDataImpl::LinearProblemData data;
EvaluationContext evaluationContext{{}, {}, data};
EvalVisitor evalVisitor{evaluationContext,
{.fillContext = {0, 0}, .scenarioGroup = "", .scenario = 0}};
EvalVisitor evalVisitor{evaluationContext, {0, 0}};
};

BOOST_FIXTURE_TEST_CASE(deep_tree_even, MyDummyFixture)
Expand Down
25 changes: 7 additions & 18 deletions src/tests/src/expressions/test_PrintAndEvalNodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,8 @@ struct MyDummyFixture: Registry<Node>
{
Antares::Optimisation::LinearProblemDataImpl::LinearProblemData data;
EvaluationContext evaluationContext{{}, {}, data};
Antares::Optimisation::LinearProblemApi::DataSeriesKeys keys{.fillContext = {0, 0},
.scenarioGroup = "",
.scenario = 0};
EvalVisitor evalVisitor{evaluationContext, keys};
Antares::Optimisation::LinearProblemApi::FillContext fillContext{0, 0};
EvalVisitor evalVisitor{evaluationContext, fillContext};
};

BOOST_AUTO_TEST_CASE(print_single_literal)
Expand Down Expand Up @@ -412,7 +410,7 @@ BOOST_FIXTURE_TEST_CASE(evaluate_param, MyDummyFixture)
const std::string value = "221.3";
EvaluationContext context({build_context_parameter_with("my-param", value)}, {}, data);

EvalVisitor evalVisitor(context, keys);
EvalVisitor evalVisitor(context, fillContext);
const double eval = evalVisitor.dispatch(&root).valueAsDouble();

BOOST_CHECK_EQUAL(std::stod(value), eval);
Expand Down Expand Up @@ -441,10 +439,7 @@ BOOST_FIXTURE_TEST_CASE(evaluate_time_dependent_param, MyDummyFixture)

unsigned hour_0 = 0;
unsigned hour_1 = 1;
EvalVisitor evalVisitor(context,
{.fillContext = {hour_0, hour_1 /*two hours*/},
.scenarioGroup = "",
.scenario = 0});
EvalVisitor evalVisitor(context, {hour_0, hour_1 /*two hours*/});
const auto eval = evalVisitor.dispatch(&root).valuesAsVector();

BOOST_CHECK_EQUAL(eval[0], hour_0);
Expand All @@ -465,10 +460,7 @@ BOOST_FIXTURE_TEST_CASE(evaluate_time_dependent_multiplication, MyDummyFixture)

unsigned hour_0 = 0;
unsigned hour_1 = 1;
EvalVisitor evalVisitor(context,
{.fillContext = {hour_0, hour_1 /*two hours*/},
.scenarioGroup = "",
.scenario = 0});
EvalVisitor evalVisitor(context, {hour_0, hour_1 /*two hours*/});
const auto eval = evalVisitor.dispatch(&root).valuesAsVector();

BOOST_CHECK_EQUAL(eval[0], hour_0 * literal.value());
Expand Down Expand Up @@ -519,10 +511,7 @@ void evaluate_time_dependent_operation()
dummy_data);
unsigned hour_0 = 0;
unsigned hour_1 = 1;
EvalVisitor evalVisitor(context,
{.fillContext = {hour_0, hour_1 /*two hours*/},
.scenarioGroup = "",
.scenario = 0});
EvalVisitor evalVisitor(context, {hour_0, hour_1 /*two hours*/});
const auto eval = evalVisitor.dispatch(&root).valuesAsVector();

BOOST_CHECK_EQUAL(eval[0], evalExpected<BinaryNode>(hour_0, literal.value()));
Expand All @@ -545,7 +534,7 @@ BOOST_FIXTURE_TEST_CASE(evaluate_variable, MyDummyFixture)
const double value = 221.3;
EvaluationContext context({}, {{"my-variable", value}}, data);

EvalVisitor evalVisitor(context, keys);
EvalVisitor evalVisitor(context, fillContext);
const double eval = evalVisitor.dispatch(&root).valueAsDouble();

BOOST_CHECK_EQUAL(value, eval);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ struct MyDummyFixture: Registry<Node>
{
Antares::Optimisation::LinearProblemDataImpl::LinearProblemData data;
EvaluationContext evaluationContext{{}, {}, data};
ReadLinearConstraintVisitor visitor{evaluationContext, {.fillContext = {0, 0}}};
ReadLinearConstraintVisitor visitor{evaluationContext, {0, 0}};
};

BOOST_FIXTURE_TEST_CASE(test_name, MyDummyFixture)
Expand All @@ -68,7 +68,7 @@ BOOST_FIXTURE_TEST_CASE(test_visit_equal_node, MyDummyFixture)
create<NegationNode>(create<ParameterNode>("param1")));
Node* node = create<EqualNode>(lhs, rhs);
EvaluationContext context({build_context_parameter_with("param1", "9.")}, {}, data);
ReadLinearConstraintVisitor visitor(context, {.fillContext = {0, 0}});
ReadLinearConstraintVisitor visitor(context, {0, 0});
auto constraint = visitor.dispatch(node)[0];
BOOST_CHECK_EQUAL(constraint.lb, -14.);
BOOST_CHECK_EQUAL(constraint.ub, -14.);
Expand All @@ -87,7 +87,7 @@ BOOST_FIXTURE_TEST_CASE(test_visit_less_than_or_equal_node, MyDummyFixture)
create<NegationNode>(create<ParameterNode>("param1")));
Node* node = create<LessThanOrEqualNode>(lhs, rhs);
EvaluationContext context({build_context_parameter_with("param1", "10.")}, {}, data);
ReadLinearConstraintVisitor visitor(context, {.fillContext = {0, 0}});
ReadLinearConstraintVisitor visitor(context, {0, 0});
auto constraint = visitor.dispatch(node)[0];
BOOST_CHECK_EQUAL(constraint.lb, -std::numeric_limits<double>::infinity());
BOOST_CHECK_EQUAL(constraint.ub, -1.);
Expand All @@ -107,7 +107,7 @@ BOOST_FIXTURE_TEST_CASE(test_visit_greater_than_or_equal_node, MyDummyFixture)
create<NegationNode>(create<ParameterNode>("param1")));
Node* node = create<GreaterThanOrEqualNode>(lhs, rhs);
EvaluationContext context({build_context_parameter_with("param1", "9.")}, {}, data);
ReadLinearConstraintVisitor visitor(context, {.fillContext = {0, 0}});
ReadLinearConstraintVisitor visitor(context, {0, 0});
auto constraint = visitor.dispatch(node)[0];
BOOST_CHECK_EQUAL(constraint.lb, -14);
BOOST_CHECK_EQUAL(constraint.ub, std::numeric_limits<double>::infinity());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ struct MyDummyFixture: Registry<Node>
{
Antares::Optimisation::LinearProblemDataImpl::LinearProblemData data;
EvaluationContext evaluationContext{{}, {}, data};
ReadLinearExpressionVisitor visitor{evaluationContext, {.fillContext = {0, 0}}};
ReadLinearExpressionVisitor visitor{evaluationContext, {0, 0}};
};

BOOST_FIXTURE_TEST_CASE(name, MyDummyFixture)
Expand Down Expand Up @@ -71,7 +71,7 @@ BOOST_FIXTURE_TEST_CASE(visit_literal_plus_param, MyDummyFixture)
// 5 + param(3) = 8
Node* sum = create<SumNode>(create<LiteralNode>(5.), create<ParameterNode>("param"));
EvaluationContext evaluation_context({{build_context_parameter_with("param", "3.")}}, {}, data);
ReadLinearExpressionVisitor visitor(evaluation_context, {.fillContext = {0, 0}});
ReadLinearExpressionVisitor visitor(evaluation_context, {0, 0});
auto linear_expression = visitor.dispatch(sum).GetLinearExpressions().at(0);
BOOST_CHECK_EQUAL(linear_expression.offset(), 8.);
BOOST_CHECK(linear_expression.coefPerVar().empty());
Expand All @@ -86,7 +86,7 @@ BOOST_FIXTURE_TEST_CASE(visit_literal_plus_param_plus_var, MyDummyFixture)
EvaluationContext evaluation_context({{build_context_parameter_with("param", "-5.")}},
{},
data);
ReadLinearExpressionVisitor visitor(evaluation_context, {.fillContext = {0, 0}});
ReadLinearExpressionVisitor visitor(evaluation_context, {0, 0});
auto linear_expression = visitor.dispatch(sum).GetLinearExpressions().at(0);
BOOST_CHECK_EQUAL(linear_expression.offset(), 55.);
BOOST_CHECK_EQUAL(linear_expression.coefPerVar().size(), 1);
Expand Down Expand Up @@ -119,7 +119,7 @@ BOOST_FIXTURE_TEST_CASE(visit_literal_plus_time_dependent_param_plus_var, MyDumm

unsigned hour_0 = 0;
unsigned hour_1 = 1;
ReadLinearExpressionVisitor visitor(evaluation_context, {.fillContext = {hour_0, hour_1}});
ReadLinearExpressionVisitor visitor(evaluation_context, {hour_0, hour_1});
auto linear_expressions = visitor.dispatch(sum).GetLinearExpressions();
BOOST_CHECK_EQUAL(linear_expressions.at(0).offset(), 60.);
BOOST_CHECK_EQUAL(linear_expressions.at(1).offset(), 61.);
Expand All @@ -137,7 +137,7 @@ BOOST_FIXTURE_TEST_CASE(visit_param_declared_const_in_library_but_time_dep_in_sy
{},
data);

ReadLinearExpressionVisitor visitor(evaluation_context, {.fillContext = {0, 1}});
ReadLinearExpressionVisitor visitor(evaluation_context, {0, 1});
BOOST_CHECK_THROW(visitor.dispatch(&p), std::invalid_argument);
}

Expand Down Expand Up @@ -189,7 +189,7 @@ BOOST_FIXTURE_TEST_CASE(visit_complex_expression, MyDummyFixture)
build_context_parameter_with("param2", "8.")},
{},
data);
ReadLinearExpressionVisitor visitor(evaluation_context, {.fillContext = {0, 0}});
ReadLinearExpressionVisitor visitor(evaluation_context, {0, 0});
auto linear_expression = visitor.dispatch(big_sum).GetLinearExpressions().at(0);
BOOST_CHECK_EQUAL(linear_expression.offset(), 10.);
BOOST_CHECK_EQUAL(linear_expression.coefPerVar().size(), 2);
Expand Down

0 comments on commit cdabd9e

Please sign in to comment.