Skip to content

Commit

Permalink
Merge pull request Pyomo#3206 from eslickj/AMPLFUNC_dup_fix
Browse files Browse the repository at this point in the history
AMPL solver duplicate funcadd fix
  • Loading branch information
mrmundt authored Apr 9, 2024
2 parents 5716566 + 942df95 commit 43ff36c
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 19 deletions.
12 changes: 3 additions & 9 deletions pyomo/contrib/pynumero/interfaces/pyomo_nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pyomo.core.base as pyo
from pyomo.common.collections import ComponentMap
from pyomo.common.env import CtypesEnviron
from pyomo.solvers.amplfunc_merge import amplfunc_merge
from ..sparse.block_matrix import BlockMatrix
from pyomo.contrib.pynumero.interfaces.ampl_nlp import AslNLP
from pyomo.contrib.pynumero.interfaces.nlp import NLP
Expand Down Expand Up @@ -92,15 +93,8 @@ def __init__(self, pyomo_model, nl_file_options=None):
# The NL writer advertises the external function libraries
# through the PYOMO_AMPLFUNC environment variable; merge it
# with any preexisting AMPLFUNC definitions
amplfunc = "\n".join(
filter(
None,
(
os.environ.get('AMPLFUNC', None),
os.environ.get('PYOMO_AMPLFUNC', None),
),
)
)
amplfunc = amplfunc_merge(os.environ)

with CtypesEnviron(AMPLFUNC=amplfunc):
super(PyomoNLP, self).__init__(nl_file)

Expand Down
32 changes: 32 additions & 0 deletions pyomo/solvers/amplfunc_merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# ___________________________________________________________________________
#
# Pyomo: Python Optimization Modeling Objects
# Copyright (c) 2008-2024
# National Technology and Engineering Solutions of Sandia, LLC
# Under the terms of Contract DE-NA0003525 with National Technology and
# Engineering Solutions of Sandia, LLC, the U.S. Government retains certain
# rights in this software.
# This software is distributed under the 3-clause BSD License.
# ___________________________________________________________________________


def amplfunc_string_merge(amplfunc, pyomo_amplfunc):
"""Merge two AMPLFUNC variable strings eliminating duplicate lines"""
# Assume that the strings amplfunc and pyomo_amplfunc don't contain duplicates
# Assume that the path separator is correct for the OS so we don't need to
# worry about comparing Unix and Windows paths.
amplfunc_lines = amplfunc.split("\n")
existing = set(amplfunc_lines)
for line in pyomo_amplfunc.split("\n"):
# Skip lines we already have
if line not in existing:
amplfunc_lines.append(line)
# Remove empty lines which could happen if one or both of the strings is
# empty or there are two new lines in a row for whatever reason.
amplfunc_lines = [s for s in amplfunc_lines if s != ""]
return "\n".join(amplfunc_lines)


def amplfunc_merge(env):
"""Merge AMPLFUNC and PYOMO_AMPLFUNC in an environment var dict"""
return amplfunc_string_merge(env.get("AMPLFUNC", ""), env.get("PYOMO_AMPLFUNC", ""))
9 changes: 4 additions & 5 deletions pyomo/solvers/plugins/solvers/ASL.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pyomo.opt.solver import SystemCallSolver
from pyomo.core.kernel.block import IBlock
from pyomo.solvers.mockmip import MockMIP
from pyomo.solvers.amplfunc_merge import amplfunc_merge
from pyomo.core import TransformationFactory

import logging
Expand Down Expand Up @@ -158,11 +159,9 @@ def create_command_line(self, executable, problem_files):
# Pyomo/Pyomo) with any user-specified external function
# libraries
#
if 'PYOMO_AMPLFUNC' in env:
if 'AMPLFUNC' in env:
env['AMPLFUNC'] += "\n" + env['PYOMO_AMPLFUNC']
else:
env['AMPLFUNC'] = env['PYOMO_AMPLFUNC']
amplfunc = amplfunc_merge(env)
if amplfunc:
env['AMPLFUNC'] = amplfunc

cmd = [executable, problem_files[0], '-AMPL']
if self._timer:
Expand Down
10 changes: 5 additions & 5 deletions pyomo/solvers/plugins/solvers/IPOPT.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from pyomo.opt.results import SolverStatus, SolverResults, TerminationCondition
from pyomo.opt.solver import SystemCallSolver

from pyomo.solvers.amplfunc_merge import amplfunc_merge

import logging

logger = logging.getLogger('pyomo.solvers')
Expand Down Expand Up @@ -119,11 +121,9 @@ def create_command_line(self, executable, problem_files):
# Pyomo/Pyomo) with any user-specified external function
# libraries
#
if 'PYOMO_AMPLFUNC' in env:
if 'AMPLFUNC' in env:
env['AMPLFUNC'] += "\n" + env['PYOMO_AMPLFUNC']
else:
env['AMPLFUNC'] = env['PYOMO_AMPLFUNC']
amplfunc = amplfunc_merge(env)
if amplfunc:
env['AMPLFUNC'] = amplfunc

cmd = [executable, problem_files[0], '-AMPL']
if self._timer:
Expand Down
162 changes: 162 additions & 0 deletions pyomo/solvers/tests/checks/test_amplfunc_merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# ___________________________________________________________________________
#
# Pyomo: Python Optimization Modeling Objects
# Copyright (c) 2008-2024
# National Technology and Engineering Solutions of Sandia, LLC
# Under the terms of Contract DE-NA0003525 with National Technology and
# Engineering Solutions of Sandia, LLC, the U.S. Government retains certain
# rights in this software.
# This software is distributed under the 3-clause BSD License.
# ___________________________________________________________________________

import pyomo.common.unittest as unittest
from pyomo.solvers.amplfunc_merge import amplfunc_string_merge, amplfunc_merge


class TestAMPLFUNCStringMerge(unittest.TestCase):
def test_merge_no_dup(self):
s1 = "my/place/l1.so\nanother/place/l1.so"
s2 = "my/place/l2.so"
sm = amplfunc_string_merge(s1, s2)
sm_list = sm.split("\n")
self.assertEqual(len(sm_list), 3)
# The order of lines should be maintained with the second string
# following the first
self.assertEqual(sm_list[0], "my/place/l1.so")
self.assertEqual(sm_list[1], "another/place/l1.so")
self.assertEqual(sm_list[2], "my/place/l2.so")

def test_merge_empty1(self):
s1 = ""
s2 = "my/place/l2.so"
sm = amplfunc_string_merge(s1, s2)
sm_list = sm.split("\n")
self.assertEqual(len(sm_list), 1)
self.assertEqual(sm_list[0], "my/place/l2.so")

def test_merge_empty2(self):
s1 = "my/place/l2.so"
s2 = ""
sm = amplfunc_string_merge(s1, s2)
sm_list = sm.split("\n")
self.assertEqual(len(sm_list), 1)
self.assertEqual(sm_list[0], "my/place/l2.so")

def test_merge_empty_both(self):
s1 = ""
s2 = ""
sm = amplfunc_string_merge(s1, s2)
sm_list = sm.split("\n")
self.assertEqual(len(sm_list), 1)
self.assertEqual(sm_list[0], "")

def test_merge_bad_type(self):
self.assertRaises(AttributeError, amplfunc_string_merge, "", 3)
self.assertRaises(AttributeError, amplfunc_string_merge, 3, "")
self.assertRaises(AttributeError, amplfunc_string_merge, 3, 3)
self.assertRaises(AttributeError, amplfunc_string_merge, None, "")
self.assertRaises(AttributeError, amplfunc_string_merge, "", None)
self.assertRaises(AttributeError, amplfunc_string_merge, 2.3, "")
self.assertRaises(AttributeError, amplfunc_string_merge, "", 2.3)

def test_merge_duplicate1(self):
s1 = "my/place/l1.so\nanother/place/l1.so"
s2 = "my/place/l1.so\nanother/place/l1.so"
sm = amplfunc_string_merge(s1, s2)
sm_list = sm.split("\n")
self.assertEqual(len(sm_list), 2)
# The order of lines should be maintained with the second string
# following the first
self.assertEqual(sm_list[0], "my/place/l1.so")
self.assertEqual(sm_list[1], "another/place/l1.so")

def test_merge_duplicate2(self):
s1 = "my/place/l1.so\nanother/place/l1.so"
s2 = "my/place/l1.so"
sm = amplfunc_string_merge(s1, s2)
sm_list = sm.split("\n")
self.assertEqual(len(sm_list), 2)
# The order of lines should be maintained with the second string
# following the first
self.assertEqual(sm_list[0], "my/place/l1.so")
self.assertEqual(sm_list[1], "another/place/l1.so")

def test_merge_extra_linebreaks(self):
s1 = "\nmy/place/l1.so\nanother/place/l1.so\n"
s2 = "\nmy/place/l1.so\n\n"
sm = amplfunc_string_merge(s1, s2)
sm_list = sm.split("\n")
self.assertEqual(len(sm_list), 2)
# The order of lines should be maintained with the second string
# following the first
self.assertEqual(sm_list[0], "my/place/l1.so")
self.assertEqual(sm_list[1], "another/place/l1.so")


class TestAMPLFUNCMerge(unittest.TestCase):
def test_merge_no_dup(self):
env = {
"AMPLFUNC": "my/place/l1.so\nanother/place/l1.so",
"PYOMO_AMPLFUNC": "my/place/l2.so",
}
sm = amplfunc_merge(env)
sm_list = sm.split("\n")
self.assertEqual(len(sm_list), 3)
self.assertEqual(sm_list[0], "my/place/l1.so")
self.assertEqual(sm_list[1], "another/place/l1.so")
self.assertEqual(sm_list[2], "my/place/l2.so")

def test_merge_empty1(self):
env = {"AMPLFUNC": "", "PYOMO_AMPLFUNC": "my/place/l2.so"}
sm = amplfunc_merge(env)
sm_list = sm.split("\n")
self.assertEqual(len(sm_list), 1)
self.assertEqual(sm_list[0], "my/place/l2.so")

def test_merge_empty2(self):
env = {"AMPLFUNC": "my/place/l2.so", "PYOMO_AMPLFUNC": ""}
sm = amplfunc_merge(env)
sm_list = sm.split("\n")
self.assertEqual(len(sm_list), 1)
self.assertEqual(sm_list[0], "my/place/l2.so")

def test_merge_empty_both(self):
env = {"AMPLFUNC": "", "PYOMO_AMPLFUNC": ""}
sm = amplfunc_merge(env)
sm_list = sm.split("\n")
self.assertEqual(len(sm_list), 1)
self.assertEqual(sm_list[0], "")

def test_merge_duplicate1(self):
env = {
"AMPLFUNC": "my/place/l1.so\nanother/place/l1.so",
"PYOMO_AMPLFUNC": "my/place/l1.so\nanother/place/l1.so",
}
sm = amplfunc_merge(env)
sm_list = sm.split("\n")
self.assertEqual(len(sm_list), 2)
self.assertEqual(sm_list[0], "my/place/l1.so")
self.assertEqual(sm_list[1], "another/place/l1.so")

def test_merge_no_pyomo(self):
env = {"AMPLFUNC": "my/place/l1.so\nanother/place/l1.so"}
sm = amplfunc_merge(env)
sm_list = sm.split("\n")
self.assertEqual(len(sm_list), 2)
self.assertEqual(sm_list[0], "my/place/l1.so")
self.assertEqual(sm_list[1], "another/place/l1.so")

def test_merge_no_user(self):
env = {"PYOMO_AMPLFUNC": "my/place/l1.so\nanother/place/l1.so"}
sm = amplfunc_merge(env)
sm_list = sm.split("\n")
self.assertEqual(len(sm_list), 2)
self.assertEqual(sm_list[0], "my/place/l1.so")
self.assertEqual(sm_list[1], "another/place/l1.so")

def test_merge_nothing(self):
env = {}
sm = amplfunc_merge(env)
sm_list = sm.split("\n")
self.assertEqual(len(sm_list), 1)
self.assertEqual(sm_list[0], "")

0 comments on commit 43ff36c

Please sign in to comment.