diff --git a/src/builder.rs b/src/builder.rs index 47aca04..85bea25 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -73,88 +73,200 @@ impl CpModelBuilder { vars: vars.into_iter().map(|v| v.0).collect(), })) } + + /// Add a linear constraint + /// + /// # Exemple + /// + /// ``` + /// # use cp_sat::builder::CpModelBuilder; + /// # use cp_sat::proto::CpSolverStatus; + /// let mut model = CpModelBuilder::default(); + /// let x = model.new_int_var([(0, 100)]); + /// let y = model.new_int_var([(0, 100)]); + /// model.add_linear_constraint([(1, x), (3, y)], [(301, i64::MAX)]); + /// let response = model.solve(); + /// assert_eq!(response.status(), CpSolverStatus::Optimal); + /// assert!(x.solution_value(&response) + 3 * y.solution_value(&response) >= 301); + /// ``` pub fn add_linear_constraint( &mut self, - expr: &LinearExpr, - (begin, end): (i64, i64), + expr: impl Into, + domain: impl IntoIterator, ) -> Constraint { + let expr = expr.into(); + let constant = expr.constant; self.add_cst(CstEnum::Linear(proto::LinearConstraintProto { - vars: expr.vars.clone(), - coeffs: expr.coeffs.clone(), - domain: vec![begin - expr.constant, end - expr.constant], + vars: expr.vars, + coeffs: expr.coeffs, + domain: domain + .into_iter() + .flat_map(|(begin, end)| { + [ + if begin == i64::MIN { + i64::MIN + } else { + begin.saturating_sub(constant) + }, + if end == i64::MAX { + i64::MAX + } else { + end.saturating_sub(constant) + }, + ] + }) + .collect(), })) } + + /// Add an equality constraint between 2 linear expressions. + /// + /// # Exemple + /// + /// ``` + /// # use cp_sat::builder::{CpModelBuilder, LinearExpr}; + /// # use cp_sat::proto::CpSolverStatus; + /// let mut model = CpModelBuilder::default(); + /// let x = model.new_int_var([(0, 50)]); + /// let y = model.new_int_var([(53, 100)]); + /// model.add_eq(y, LinearExpr::from(x) + 3); + /// let response = model.solve(); + /// assert_eq!(response.status(), CpSolverStatus::Optimal); + /// assert_eq!(y.solution_value(&response), x.solution_value(&response) + 3); + /// assert_eq!(50, x.solution_value(&response)); + /// assert_eq!(53, y.solution_value(&response)); + /// ``` pub fn add_eq, U: Into>( &mut self, lhs: T, rhs: U, ) -> Constraint { - self.add_eq_by_ref(&lhs.into(), &rhs.into()) - } - pub fn add_eq_by_ref(&mut self, lhs: &LinearExpr, rhs: &LinearExpr) -> Constraint { - let (mut cst, val) = create_linear_cst_proto(lhs, rhs); - cst.domain.extend([val, val]); - self.add_cst(CstEnum::Linear(cst)) + self.add_linear_constraint(lhs.into() - rhs.into(), [(0, 0)]) } + + /// Add a greater or equal constraint between 2 linear expressions. + /// + /// # Exemple + /// + /// ``` + /// # use cp_sat::builder::{CpModelBuilder, LinearExpr}; + /// # use cp_sat::proto::CpSolverStatus; + /// let mut model = CpModelBuilder::default(); + /// let x = model.new_int_var([(0, 50)]); + /// let y = model.new_int_var([(50, 100)]); + /// model.add_ge(x, y); + /// let response = model.solve(); + /// assert_eq!(response.status(), CpSolverStatus::Optimal); + /// assert!(x.solution_value(&response) >= y.solution_value(&response)); + /// assert_eq!(50, x.solution_value(&response)); + /// assert_eq!(50, y.solution_value(&response)); + /// ``` pub fn add_ge, U: Into>( &mut self, lhs: T, rhs: U, ) -> Constraint { - self.add_ge_by_ref(&lhs.into(), &rhs.into()) - } - pub fn add_ge_by_ref(&mut self, lhs: &LinearExpr, rhs: &LinearExpr) -> Constraint { - let (mut cst, val) = create_linear_cst_proto(lhs, rhs); - cst.domain.extend([val, i64::MAX]); - self.add_cst(CstEnum::Linear(cst)) + self.add_linear_constraint(lhs.into() - rhs.into(), [(0, i64::MAX)]) } + + /// Add a lesser or equal constraint between 2 linear expressions. + /// + /// # Exemple + /// + /// ``` + /// # use cp_sat::builder::{CpModelBuilder, LinearExpr}; + /// # use cp_sat::proto::CpSolverStatus; + /// let mut model = CpModelBuilder::default(); + /// let x = model.new_int_var([(50, 100)]); + /// let y = model.new_int_var([(0, 50)]); + /// model.add_le(x, y); + /// let response = model.solve(); + /// assert_eq!(response.status(), CpSolverStatus::Optimal); + /// assert!(x.solution_value(&response) <= y.solution_value(&response)); + /// assert_eq!(50, x.solution_value(&response)); + /// assert_eq!(50, y.solution_value(&response)); + /// ``` pub fn add_le, U: Into>( &mut self, lhs: T, rhs: U, ) -> Constraint { - self.add_le_by_ref(&lhs.into(), &rhs.into()) - } - pub fn add_le_by_ref(&mut self, lhs: &LinearExpr, rhs: &LinearExpr) -> Constraint { - let (mut cst, val) = create_linear_cst_proto(lhs, rhs); - cst.domain.extend([i64::MIN, val]); - self.add_cst(CstEnum::Linear(cst)) + self.add_linear_constraint(lhs.into() - rhs.into(), [(i64::MIN, 0)]) } + + /// Add a stricly greater constraint between 2 linear expressions. + /// + /// # Exemple + /// + /// ``` + /// # use cp_sat::builder::{CpModelBuilder, LinearExpr}; + /// # use cp_sat::proto::CpSolverStatus; + /// let mut model = CpModelBuilder::default(); + /// let x = model.new_int_var([(0, 51)]); + /// let y = model.new_int_var([(50, 100)]); + /// model.add_gt(x, y); + /// let response = model.solve(); + /// assert_eq!(response.status(), CpSolverStatus::Optimal); + /// assert!(x.solution_value(&response) > y.solution_value(&response)); + /// assert_eq!(51, x.solution_value(&response)); + /// assert_eq!(50, y.solution_value(&response)); + /// ``` pub fn add_gt, U: Into>( &mut self, lhs: T, rhs: U, ) -> Constraint { - self.add_gt_by_ref(&lhs.into(), &rhs.into()) - } - pub fn add_gt_by_ref(&mut self, lhs: &LinearExpr, rhs: &LinearExpr) -> Constraint { - let (mut cst, val) = create_linear_cst_proto(lhs, rhs); - cst.domain.extend([val + 1, i64::MAX]); - self.add_cst(CstEnum::Linear(cst)) + self.add_linear_constraint(lhs.into() - rhs.into(), [(1, i64::MAX)]) } + + /// Add a strictly lesser constraint between 2 linear expressions. + /// + /// # Exemple + /// + /// ``` + /// # use cp_sat::builder::{CpModelBuilder, LinearExpr}; + /// # use cp_sat::proto::CpSolverStatus; + /// let mut model = CpModelBuilder::default(); + /// let x = model.new_int_var([(50, 100)]); + /// let y = model.new_int_var([(0, 51)]); + /// model.add_lt(x, y); + /// let response = model.solve(); + /// assert_eq!(response.status(), CpSolverStatus::Optimal); + /// assert!(x.solution_value(&response) < y.solution_value(&response)); + /// assert_eq!(50, x.solution_value(&response)); + /// assert_eq!(51, y.solution_value(&response)); + /// ``` pub fn add_lt, U: Into>( &mut self, lhs: T, rhs: U, ) -> Constraint { - self.add_lt_by_ref(&lhs.into(), &rhs.into()) - } - pub fn add_lt_by_ref(&mut self, lhs: &LinearExpr, rhs: &LinearExpr) -> Constraint { - let (mut cst, val) = create_linear_cst_proto(lhs, rhs); - cst.domain.extend([i64::MIN, val - 1]); - self.add_cst(CstEnum::Linear(cst)) + self.add_linear_constraint(lhs.into() - rhs.into(), [(i64::MIN, -1)]) } + + /// Add a not equal constraint between 2 linear expressions. + /// + /// # Exemple + /// + /// ``` + /// # use cp_sat::builder::{CpModelBuilder, LinearExpr}; + /// # use cp_sat::proto::CpSolverStatus; + /// let mut model = CpModelBuilder::default(); + /// let x = model.new_int_var([(50, 100)]); + /// let y = model.new_int_var([(0, 51)]); + /// model.add_lt(x, y); + /// let response = model.solve(); + /// assert_eq!(response.status(), CpSolverStatus::Optimal); + /// assert!(x.solution_value(&response) < y.solution_value(&response)); + /// assert_eq!(50, x.solution_value(&response)); + /// assert_eq!(51, y.solution_value(&response)); + /// ``` pub fn add_ne, U: Into>( &mut self, lhs: T, rhs: U, ) -> Constraint { - self.add_ne_by_ref(&lhs.into(), &rhs.into()) - } - pub fn add_ne_by_ref(&mut self, lhs: &LinearExpr, rhs: &LinearExpr) -> Constraint { - let (mut cst, val) = create_linear_cst_proto(lhs, rhs); - cst.domain.extend([i64::MIN, val - 1, val + 1, i64::MAX]); - self.add_cst(CstEnum::Linear(cst)) + self.add_linear_constraint(lhs.into() - rhs.into(), [(i64::MIN, -1), (1, i64::MAX)]) } pub fn add_min_eq( &mut self, @@ -192,7 +304,6 @@ impl CpModelBuilder { /// ``` /// # use cp_sat::builder::CpModelBuilder; /// # use cp_sat::proto::{CpSolverStatus, SatParameters}; - /// # use cp_sat::proto::sat_parameters::SearchBranching; /// let mut model = CpModelBuilder::default(); /// let x = model.new_int_var([(0, 100)]); /// let y = model.new_bool_var(); @@ -201,7 +312,6 @@ impl CpModelBuilder { /// model.add_ge([(1, x), (3, y.into())], 3); /// model.maximize(y); /// let response = model.solve(); - /// println!("{:#?}", response); /// assert_eq!(response.status(), CpSolverStatus::Optimal); /// ``` pub fn add_hint(&mut self, var: impl Into, value: i64) { @@ -311,6 +421,31 @@ pub struct LinearExpr { coeffs: Vec, constant: i64, } +impl std::ops::AddAssign for LinearExpr { + fn add_assign(&mut self, mut rhs: Self) { + if self.vars.len() < rhs.vars.len() { + std::mem::swap(self, &mut rhs); + } + self.vars.extend_from_slice(&rhs.vars); + self.coeffs.extend_from_slice(&rhs.coeffs); + self.constant += rhs.constant; + } +} +impl std::ops::Neg for LinearExpr { + type Output = LinearExpr; + fn neg(mut self) -> Self::Output { + for c in &mut self.coeffs { + *c = -*c; + } + self.constant = -self.constant; + self + } +} +impl> std::ops::SubAssign for LinearExpr { + fn sub_assign(&mut self, rhs: L) { + *self += -rhs.into(); + } +} impl std::ops::AddAssign for LinearExpr { fn add_assign(&mut self, rhs: i64) { self.constant += rhs; @@ -395,6 +530,16 @@ where self } } +impl std::ops::Sub for LinearExpr +where + LinearExpr: std::ops::SubAssign, +{ + type Output = LinearExpr; + fn sub(mut self, rhs: T) -> Self::Output { + self -= rhs; + self + } +} impl From for proto::LinearExpressionProto { fn from(expr: LinearExpr) -> Self { proto::LinearExpressionProto { @@ -404,27 +549,3 @@ impl From for proto::LinearExpressionProto { } } } - -fn create_linear_cst_proto( - left: &LinearExpr, - right: &LinearExpr, -) -> (proto::LinearConstraintProto, i64) { - ( - proto::LinearConstraintProto { - vars: left - .vars - .iter() - .copied() - .chain(right.vars.iter().copied()) - .collect(), - coeffs: left - .coeffs - .iter() - .copied() - .chain(right.coeffs.iter().map(|&c| -c)) - .collect(), - domain: vec![], - }, - right.constant - left.constant, - ) -}