diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GridSearchEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GridSearchEditor.java index 5717059b97..82de5aa7ba 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GridSearchEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/GridSearchEditor.java @@ -37,6 +37,8 @@ import javax.swing.table.TableRowSorter; import javax.swing.text.BadLocationException; import java.awt.*; +import java.awt.event.ActionEvent; +import java.awt.event.ActionListener; import java.io.*; import java.lang.reflect.InvocationTargetException; import java.nio.charset.StandardCharsets; @@ -184,17 +186,11 @@ public GridSearchEditor(GridSearchModel model) { * @param parameters the Parameters object containing the parameter values * @return a map of parameter names to corresponding Box components */ - public static Map createParameterComponents(Set params, Parameters parameters, - boolean listOptionAllowed, boolean bothOptionAllowed) { + public static Map createParameterComponents(Set params, Parameters parameters, boolean listOptionAllowed, boolean bothOptionAllowed) { ParamDescriptions paramDescriptions = ParamDescriptions.getInstance(); - return params.stream() - .collect(Collectors.toMap( - Function.identity(), - e -> createParameterComponent(e, parameters, paramDescriptions.get(e), listOptionAllowed, bothOptionAllowed), - (u, v) -> { - throw new IllegalStateException(String.format("Duplicate key %s.", u)); - }, - TreeMap::new)); + return params.stream().collect(Collectors.toMap(Function.identity(), e -> createParameterComponent(e, parameters, paramDescriptions.get(e), listOptionAllowed, bothOptionAllowed), (u, v) -> { + throw new IllegalStateException(String.format("Duplicate key %s.", u)); + }, TreeMap::new)); } /** @@ -216,8 +212,7 @@ public static Box[] toArray(Map parameterComponents) { } }); - return Stream.concat(otherComps.stream(), boolComps.stream()) - .toArray(Box[]::new); + return Stream.concat(otherComps.stream(), boolComps.stream()).toArray(Box[]::new); } /** @@ -228,8 +223,7 @@ public static Box[] toArray(Map parameterComponents) { * @param paramDesc the ParamDescription object containing information about the parameter * @return a Box component representing the parameter */ - private static Box createParameterComponent(String parameter, Parameters parameters, ParamDescription paramDesc, - boolean listOptionAllowed, boolean bothOptionAllowed) { + private static Box createParameterComponent(String parameter, Parameters parameters, ParamDescription paramDesc, boolean listOptionAllowed, boolean bothOptionAllowed) { JComponent component; Object defaultValue = paramDesc.getDefaultValue(); @@ -288,7 +282,7 @@ private static Box createParameterComponent(String parameter, Parameters paramet } else if (defaultValue instanceof Boolean) { component = getBooleanSelectionBox(parameter, parameters, bothOptionAllowed); } else if (defaultValue instanceof String) { - component = getStringField(parameter, parameters, (String) defaultValue); + component = createStringField(parameter, parameters, (String) defaultValue); } else { throw new IllegalArgumentException("Unexpected type: " + defaultValue.getClass()); } @@ -317,10 +311,8 @@ private static Box createParameterComponent(String parameter, Parameters paramet * @param upperBound the upperbound limit for valid input values in the DoubleTextField * @return a DoubleTextField with the specified parameters */ - public static DoubleTextField getDoubleTextField(String parameter, Parameters parameters, - double defaultValue, double lowerBound, double upperBound) { - DoubleTextField field = new DoubleTextField(defaultValue, - 8, new DecimalFormat("0.####"), new DecimalFormat("0.0#E0"), 0.001); + public static DoubleTextField getDoubleTextField(String parameter, Parameters parameters, double defaultValue, double lowerBound, double upperBound) { + DoubleTextField field = new DoubleTextField(defaultValue, 8, new DecimalFormat("0.####"), new DecimalFormat("0.0#E0"), 0.001); field.setFilter((value, oldValues) -> { if (Double.isNaN(value)) { @@ -357,10 +349,8 @@ public static DoubleTextField getDoubleTextField(String parameter, Parameters pa * @param upperBound the upper bound for the values * @return a ListDoubleTextField component with the specified parameters */ - public static ListDoubleTextField getListDoubleTextField(String parameter, Parameters parameters, - Double[] defaultValues, double lowerBound, double upperBound) { - ListDoubleTextField field = new ListDoubleTextField(defaultValues, - 8, new DecimalFormat("0.####"), new DecimalFormat("0.0#E0"), 0.001); + public static ListDoubleTextField getListDoubleTextField(String parameter, Parameters parameters, Double[] defaultValues, double lowerBound, double upperBound) { + ListDoubleTextField field = new ListDoubleTextField(defaultValues, 8, new DecimalFormat("0.####"), new DecimalFormat("0.0#E0"), 0.001); field.setFilter((values, oldValues) -> { if (values.length == 0) { @@ -413,8 +403,7 @@ public static ListDoubleTextField getListDoubleTextField(String parameter, Param * @param upperBound the upper bound for valid values * @return an IntTextField with the specified parameters */ - public static IntTextField getIntTextField(String parameter, Parameters parameters, - int defaultValue, double lowerBound, double upperBound) { + public static IntTextField getIntTextField(String parameter, Parameters parameters, int defaultValue, double lowerBound, double upperBound) { IntTextField field = new IntTextField(defaultValue, 8); field.setFilter((value, oldValue) -> { @@ -448,8 +437,7 @@ public static IntTextField getIntTextField(String parameter, Parameters paramete * @param upperBound the upper bound for the values * @return a ListIntTextField component with the specified parameters */ - public static ListIntTextField getListIntTextField(String parameter, Parameters parameters, - Integer[] defaultValues, double lowerBound, double upperBound) { + public static ListIntTextField getListIntTextField(String parameter, Parameters parameters, Integer[] defaultValues, double lowerBound, double upperBound) { ListIntTextField field = new ListIntTextField(defaultValues, 8); field.setFilter((values, oldValues) -> { @@ -499,8 +487,7 @@ public static ListIntTextField getListIntTextField(String parameter, Parameters * @param upperBound The upper bound for the LongTextField value. * @return A LongTextField object with the specified parameters. */ - public static LongTextField getLongTextField(String parameter, Parameters parameters, - long defaultValue, long lowerBound, long upperBound) { + public static LongTextField getLongTextField(String parameter, Parameters parameters, long defaultValue, long lowerBound, long upperBound) { LongTextField field = new LongTextField(defaultValue, 8); field.setFilter((value, oldValue) -> { @@ -524,8 +511,7 @@ public static LongTextField getLongTextField(String parameter, Parameters parame return field; } - public static ListLongTextField getListLongTextField(String parameter, Parameters parameters, - Long[] defaultValues, long lowerBound, long upperBound) { + public static ListLongTextField getListLongTextField(String parameter, Parameters parameters, Long[] defaultValues, long lowerBound, long upperBound) { ListLongTextField field = new ListLongTextField(defaultValues, 8); field.setFilter((values, oldValues) -> { @@ -771,14 +757,10 @@ private static String getParameterText(Set paramNamesSet, Parameters par sb.append(", "); } - if (values[i] instanceof Double) - sb.append(nf.format((double) values[i])); - else if (values[i] instanceof Integer) - sb.append((int) values[i]); - else if (values[i] instanceof Long) - sb.append((long) values[i]); - else - sb.append(values[i]); + if (values[i] instanceof Double) sb.append(nf.format((double) values[i])); + else if (values[i] instanceof Integer) sb.append((int) values[i]); + else if (values[i] instanceof Long) sb.append((long) values[i]); + else sb.append(values[i]); } paramText.append("\n\n- ").append(name).append(" = ").append(sb); @@ -810,8 +792,7 @@ public static void scrollToWord(JTextArea textArea, JScrollPane scrollPane, Stri * @param graphIndexComboBox The combo box to update with the graph indices. * @param resultsDir The directory where the graph results are stored. */ - private void updateAlgorithmBoxIndices(JComboBox simulationComboBox, JComboBox algorithmComboBox, - JComboBox graphIndexComboBox, File resultsDir) { + private void updateAlgorithmBoxIndices(JComboBox simulationComboBox, JComboBox algorithmComboBox, JComboBox graphIndexComboBox, File resultsDir) { int savedAlgorithm = model.getSelectedAlgorithm(); Object selectedSimulation = simulationComboBox.getSelectedItem(); @@ -877,8 +858,7 @@ private void updateAlgorithmBoxIndices(JComboBox simulationComboBox, JC * @param graphIndexComboBox The combo box to update with the graph indices. * @param resultsDir The directory where the graph results are stored. */ - private void updateGraphBoxIndices(JComboBox simulationComboBox, JComboBox algorithmComboBox, - JComboBox graphIndexComboBox, File resultsDir) { + private void updateGraphBoxIndices(JComboBox simulationComboBox, JComboBox algorithmComboBox, JComboBox graphIndexComboBox, File resultsDir) { int savedGraphIndex = model.getSelectedGraphIndex(); Object selectedSimulation = simulationComboBox.getSelectedItem(); @@ -935,9 +915,7 @@ private void updateGraphBoxIndices(JComboBox simulationComboBox, JCombo } } - private void updateSelectedGraph(JComboBox simulationComboBox, JComboBox algorithmComboBox, - JComboBox graphIndexComboBox, File resultsDir, - GraphWorkbench workbench) { + private void updateSelectedGraph(JComboBox simulationComboBox, JComboBox algorithmComboBox, JComboBox graphIndexComboBox, File resultsDir, GraphWorkbench workbench) { Object selectedSimulation = simulationComboBox.getSelectedItem(); Object selectedAlgorithm = algorithmComboBox.getSelectedItem(); Object selectedGraphIndex = graphIndexComboBox.getSelectedItem(); @@ -983,10 +961,7 @@ private void refreshGraphSelectionContent(JTabbedPane tabbedPane) { * @throws IllegalAccessException If the graph or simulation constructor or class is inaccessible. */ @NotNull - private edu.cmu.tetrad.algcomparison.simulation.Simulation getSimulation( - Class graphClazz, - Class simulationClazz) - throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException { + private edu.cmu.tetrad.algcomparison.simulation.Simulation getSimulation(Class graphClazz, Class simulationClazz) throws NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException { RandomGraph randomGraph; if (graphClazz == SingleGraph.class) { @@ -1253,8 +1228,7 @@ private void addAlgorithmTab(JTabbedPane tabbedPane) { Set allBootstrapParameters = GridSearchModel.getAllBootstrapParameters(algorithms); Set allScoreParameters = GridSearchModel.getAllScoreParameters(algorithms); - if (allAlgorithmParameters.isEmpty() && allTestParameters.isEmpty() && allBootstrapParameters.isEmpty() - && allScoreParameters.isEmpty()) { + if (allAlgorithmParameters.isEmpty() && allTestParameters.isEmpty() && allBootstrapParameters.isEmpty() && allScoreParameters.isEmpty()) { JLabel noParamLbl = NO_PARAM_LBL; noParamLbl.setBorder(new EmptyBorder(10, 10, 10, 10)); tabbedPane1.addTab("No Parameters", new PaddingPanel(noParamLbl)); @@ -1315,8 +1289,7 @@ private Box getParameterBox(Set params, boolean listOptionAllowed, boole parameterBox.add(noParamLbl, BorderLayout.NORTH); } else { Box parameters = Box.createVerticalBox(); - Box[] paramBoxes = ParameterComponents.toArray( - createParameterComponents(params, model.getParameters(), listOptionAllowed, bothOptionAllowed)); + Box[] paramBoxes = ParameterComponents.toArray(createParameterComponents(params, model.getParameters(), listOptionAllowed, bothOptionAllowed)); int lastIndex = paramBoxes.length - 1; for (int i = 0; i < lastIndex; i++) { parameters.add(paramBoxes[i]); @@ -1374,8 +1347,7 @@ private void addTableColumnsTab(JTabbedPane tabbedPane) { JPanel tableColumnsChoice = new JPanel(); tableColumnsChoice.setLayout(new BorderLayout()); - tableColumnsChoice.add(new JScrollPane(tableColumnsChoiceTextArea, JScrollPane.VERTICAL_SCROLLBAR_AS_NEEDED, - JScrollPane.HORIZONTAL_SCROLLBAR_AS_NEEDED), BorderLayout.CENTER); + tableColumnsChoice.add(new JScrollPane(tableColumnsChoiceTextArea, JScrollPane.VERTICAL_SCROLLBAR_AS_NEEDED, JScrollPane.HORIZONTAL_SCROLLBAR_AS_NEEDED), BorderLayout.CENTER); tableColumnsChoice.add(tableColumnsSelectionBox, BorderLayout.SOUTH); tabbedPane.addTab("Table Columns", tableColumnsChoice); @@ -1392,11 +1364,7 @@ private void addTableColumnsTab(JTabbedPane tabbedPane) { Parameters.serializableInstance().remove("algcomparison." + column.getColumnName()); - ParamDescriptions.getInstance().put("algcomparison." + column.getColumnName(), - new ParamDescription("algcomparison." + column.getColumnName(), - "Utility for " + column.getColumnName() + " in [0, 1]", - "Utility for " + column.getColumnName(), - weight, 0.0, 1.0)); + ParamDescriptions.getInstance().put("algcomparison." + column.getColumnName(), new ParamDescription("algcomparison." + column.getColumnName(), "Utility for " + column.getColumnName() + " in [0, 1]", "Utility for " + column.getColumnName(), weight, 0.0, 1.0)); } Box parameterBox = getParameterBox(params, false, false); @@ -1491,9 +1459,7 @@ private void addComparisonTab(JTabbedPane tabbedPane) { Box horiz5 = Box.createHorizontalBox(); horiz5.add(new JLabel("Parallelism:")); horiz5.add(Box.createHorizontalGlue()); - horiz5.add(getIntTextField("algcomparisonParallelism", model.getParameters(), - model.getParameters().getInt("algcomparisonParallelism", Runtime.getRuntime().availableProcessors()), - 1, 1000)); + horiz5.add(getIntTextField("algcomparisonParallelism", model.getParameters(), model.getParameters().getInt("algcomparisonParallelism", Runtime.getRuntime().availableProcessors()), 1, 1000)); Box horiz6 = Box.createHorizontalBox(); horiz6.add(new JLabel("Comparison Graph Type:")); @@ -1501,11 +1467,49 @@ private void addComparisonTab(JTabbedPane tabbedPane) { JComboBox comparisonGraphTypeComboBox = new JComboBox<>(); Box horiz7 = Box.createHorizontalBox(); - horiz7.add(new JLabel("Markov Checker Alpha:")); + horiz7.add(new JLabel("Markov Checker:")); horiz7.add(Box.createHorizontalGlue()); - horiz7.add(getDoubleTextField(Params.MC_ALPHA, model.getParameters(), - model.getParameters().getDouble(Params.MC_ALPHA), - 0.0, 1.0)); + JButton chooseTest = new JButton("Choose Test"); + + chooseTest.addActionListener(e2 -> { + JComboBox comboBox = new JComboBox<>(); + populateTestTypes(comboBox); + + JOptionPane dialog = new JOptionPane(comboBox, JOptionPane.PLAIN_MESSAGE); + dialog.createDialog("Choose Markov Checker Test)").setVisible(true); + }); + + horiz7.add(chooseTest); + + Box horiz8 = Box.createHorizontalBox(); + horiz8.add(new JLabel("Markov Checker:")); + horiz8.add(Box.createHorizontalGlue()); + JButton configureMarkovChecker = new JButton("Params"); + + configureMarkovChecker.addActionListener(new ActionListener() { + @Override + public void actionPerformed(ActionEvent e) { + JPanel independenceWrapperParamsPanel = createIndependenceWrapperParamsPanel(model.getParameters()); + JOptionPane dialog = new JOptionPane(independenceWrapperParamsPanel, JOptionPane.PLAIN_MESSAGE); + dialog.createDialog("Set Parameters").setVisible(true); + +// setTest(); + } + + private JPanel createParamsPanel(GridSearchModel model, Parameters parameters) { + JPanel panel = new JPanel(); + panel.add(createParamsPanel(model, parameters)); + return panel; + } + }); + + + horiz7.add(configureMarkovChecker); + + +// horiz8.add(getDoubleTextField(Params.MC_ALPHA, model.getParameters(), +// model.getParameters().getDouble(Params.MC_ALPHA), +// 0.0, 1.0)); for (GridSearchModel.ComparisonGraphType comparisonGraphType : GridSearchModel.ComparisonGraphType.values()) { comparisonGraphTypeComboBox.addItem(comparisonGraphType.toString()); @@ -2281,8 +2285,7 @@ public void changedUpdate(DocumentEvent e) { for (int i = 0; i < table.getRowCount(); i++) { GridSearchModel.MyTableColumn myTableColumn = columnSelectionTableModel.getMyTableColumn(i); - if (myTableColumn.getType() == GridSearchModel.MyTableColumn.ColumnType.PARAMETER - && myTableColumn.isSetByUser()) { + if (myTableColumn.getType() == GridSearchModel.MyTableColumn.ColumnType.PARAMETER && myTableColumn.isSetByUser()) { columnSelectionTableModel.selectRow(i); } } @@ -2293,8 +2296,7 @@ public void changedUpdate(DocumentEvent e) { GridSearchModel.MyTableColumn myTableColumn = columnSelectionTableModel.getMyTableColumn(i); List lastStatisticsUsed = model.getLastStatisticsUsed(); - if (myTableColumn.getType() == GridSearchModel.MyTableColumn.ColumnType.STATISTIC - && lastStatisticsUsed.contains(myTableColumn.getColumnName())) { + if (myTableColumn.getType() == GridSearchModel.MyTableColumn.ColumnType.STATISTIC && lastStatisticsUsed.contains(myTableColumn.getColumnName())) { columnSelectionTableModel.selectRow(i); } } @@ -2333,8 +2335,7 @@ private JPanel getButtonPanel(TableColumnSelectionModel columnSelectionTableMode columnSelectionTableModel.setTableRef(null); SwingUtilities.invokeLater(dialog::dispose); - List selectedTableColumns = new ArrayList<>( - columnSelectionTableModel.getSelectedTableColumns()); + List selectedTableColumns = new ArrayList<>(columnSelectionTableModel.getSelectedTableColumns()); for (GridSearchModel.MyTableColumn column : selectedTableColumns) { model.addTableColumn(column); @@ -2591,17 +2592,14 @@ private void setTableColumnsText() { * empty. Otherwise, it sets a message indicating that a comparison has not been run for the selection. */ private void setComparisonText() { - if (model.getSelectedSimulations().getSimulations().isEmpty() || model.getSelectedAlgorithms().isEmpty() - || model.getSelectedTableColumns().isEmpty()) { - comparisonTextArea.setText( - """ - ** You have made an empty selection; look back at the Simulation, Algorithm, and Table Columns tabs ** - """); + if (model.getSelectedSimulations().getSimulations().isEmpty() || model.getSelectedAlgorithms().isEmpty() || model.getSelectedTableColumns().isEmpty()) { + comparisonTextArea.setText(""" + ** You have made an empty selection; look back at the Simulation, Algorithm, and Table Columns tabs ** + """); } else if (comparisonTextArea.getText().isBlank()) { - comparisonTextArea.setText - (""" - ** Your selection is non-empty, but you have not yet run a comparison for it ** - """); + comparisonTextArea.setText(""" + ** Your selection is non-empty, but you have not yet run a comparison for it ** + """); } } @@ -2925,4 +2923,304 @@ public void write(int b) throws IOException { } } + /** + * Creates a parameters panel for the given set of parameters and Parameters object. + * + * @param params The set of parameter names. + * @param parameters The Parameters object containing the parameter values. + * @return The JPanel containing the parameters panel. + */ + public static JPanel createParamsPanel(Set params, Parameters parameters) { + JPanel panel = new JPanel(new BorderLayout()); + panel.setBorder(BorderFactory.createTitledBorder("Parameters")); + + Box paramsBox = Box.createVerticalBox(); + + Box[] boxes = toArray(createParameterComponents(params, parameters)); + int lastIndex = boxes.length - 1; + for (int i = 0; i < lastIndex; i++) { + paramsBox.add(boxes[i]); + paramsBox.add(Box.createVerticalStrut(10)); + } + paramsBox.add(boxes[lastIndex]); + + panel.add(new PaddingPanel(paramsBox), BorderLayout.CENTER); + + return panel; + } + + /** + * Creates a map of parameter components for the given set of parameters and Parameters object. + * + * @param params The set of parameter names. + * @param parameters The Parameters object containing the parameter values. + * @return A map of parameter names to Box components. + */ + private static Map createParameterComponents(Set params, Parameters parameters) { + ParamDescriptions paramDescriptions = ParamDescriptions.getInstance(); + return params.stream() + .collect(Collectors.toMap( + Function.identity(), + e -> createParameterComponent(e, parameters, paramDescriptions.get(e)), + (u, v) -> { + throw new IllegalStateException(String.format("Duplicate key %s.", u)); + }, + TreeMap::new)); + } + + /** + * Creates a parameter component based on the given parameter, Parameters, and ParamDescription. + * + * @param parameter The name of the parameter. + * @param parameters The Parameters object containing the parameter values. + * @param paramDesc The ParamDescription object with information about the parameter. + * @return A Box component representing the parameter component. + * @throws IllegalArgumentException If the default value type is unexpected. + */ + private static Box createParameterComponent(String parameter, Parameters parameters, ParamDescription paramDesc) { + JComponent component; + Object defaultValue = paramDesc.getDefaultValue(); + if (defaultValue instanceof Double) { + double lowerBoundDouble = paramDesc.getLowerBoundDouble(); + double upperBoundDouble = paramDesc.getUpperBoundDouble(); + component = getDoubleField(parameter, parameters, (Double) defaultValue, lowerBoundDouble, upperBoundDouble); + } else if (defaultValue instanceof Integer) { + int lowerBoundInt = paramDesc.getLowerBoundInt(); + int upperBoundInt = paramDesc.getUpperBoundInt(); + component = getIntTextField(parameter, parameters, (Integer) defaultValue, lowerBoundInt, upperBoundInt); + } else if (defaultValue instanceof Long) { + long lowerBoundLong = paramDesc.getLowerBoundLong(); + long upperBoundLong = paramDesc.getUpperBoundLong(); + component = createLongTextField(parameter, parameters, (Long) defaultValue, lowerBoundLong, upperBoundLong); + } else if (defaultValue instanceof Boolean) { + component = createBooleanSelectionBox(parameter, parameters, (Boolean) defaultValue); + } else if (defaultValue instanceof String) { + component = getStringField(parameter, parameters, (String) defaultValue); + } else { + throw new IllegalArgumentException("Unexpected type: " + defaultValue.getClass()); + } + + Box paramRow = Box.createHorizontalBox(); + + JLabel paramLabel = new JLabel(paramDesc.getShortDescription()); + String longDescription = paramDesc.getLongDescription(); + if (longDescription != null) { + paramLabel.setToolTipText(longDescription); + } + paramRow.add(paramLabel); + paramRow.add(Box.createHorizontalGlue()); + paramRow.add(component); + + return paramRow; + } + + /** + * Returns a DoubleTextField with specified parameters. + * + * @param parameter The name of the parameter. + * @param parameters The Parameters object containing the parameter values. + * @param defaultValue The default value for the DoubleTextField. + * @param lowerBound The lower bound for valid values. + * @param upperBound The upper bound for valid values. + * @return A DoubleTextField with the specified parameters. + */ + private static DoubleTextField getDoubleField(String parameter, Parameters parameters, + double defaultValue, double lowerBound, double upperBound) { + return ParameterComponents.getDoubleField(parameter, parameters, defaultValue, lowerBound, upperBound); + } + + /** + * Returns an IntTextField with the specified parameters. + * + * @param parameter The name of the parameter. + * @param parameters The Parameters object containing the parameter values. + * @param defaultValue The default value for the IntTextField. + * @param lowerBound The lower bound for valid values. + * @param upperBound The upper bound for valid values. + * @return An IntTextField with the specified parameters. + */ + private static IntTextField getIntTextField(String parameter, Parameters parameters, + int defaultValue, int lowerBound, int upperBound) { + return ParameterComponents.getIntTextField(parameter, parameters, defaultValue, lowerBound, upperBound); + } + + /** + * Returns a LongTextField object with the specified parameters. + * + * @param parameter The name of the parameter. + * @param parameters The Parameters object containing the parameter values. + * @param defaultValue The default value for the LongTextField. + * @param lowerBound The lower bound for valid values. + * @param upperBound The upper bound for valid values. + * @return A LongTextField object with the specified parameters. + */ + private static LongTextField createLongTextField(String parameter, Parameters parameters, + long defaultValue, long lowerBound, long upperBound) { + LongTextField field = new LongTextField(parameters.getLong(parameter, defaultValue), 8); + + field.setFilter((value, oldValue) -> { + if (value == field.getValue()) { + return oldValue; + } + + if (value < lowerBound) { + return oldValue; + } + + if (value > upperBound) { + return oldValue; + } + + try { + parameters.set(parameter, value); + } catch (Exception e) { + // Ignore. + } + + return value; + }); + + return field; + } + + /** + * Creates a boolean selection box with Yes and No radio buttons. + * + * @param parameter The name of the parameter. + * @param parameters The Parameters object containing the parameter values. + * @param defaultValue The default value for the boolean parameter + */ + private static Box createBooleanSelectionBox(String parameter, Parameters parameters, boolean defaultValue) { + Box selectionBox = Box.createHorizontalBox(); + + JRadioButton yesButton = new JRadioButton("Yes"); + JRadioButton noButton = new JRadioButton("No"); + + // Button group to ensure only one option can be selected + ButtonGroup selectionBtnGrp = new ButtonGroup(); + selectionBtnGrp.add(yesButton); + selectionBtnGrp.add(noButton); + + boolean aBoolean = parameters.getBoolean(parameter, defaultValue); + + // Set default selection + if (aBoolean) { + yesButton.setSelected(true); + } else { + noButton.setSelected(true); + } + + // Add to containing box + selectionBox.add(yesButton); + selectionBox.add(noButton); + + // Event listener + yesButton.addActionListener((e) -> { + JRadioButton button = (JRadioButton) e.getSource(); + if (button.isSelected()) { + parameters.set(parameter, true); + } + }); + + // Event listener + noButton.addActionListener((e) -> { + JRadioButton button = (JRadioButton) e.getSource(); + if (button.isSelected()) { + parameters.set(parameter, false); + } + }); + + return selectionBox; + } + + /** + * Returns a StringTextField object with the specified parameters. + * + * @param parameter The name of the parameter. + * @param parameters The Parameters object containing the parameter values. + * @param defaultValue The default value for the StringTextField. + * @return A StringTextField object with the specified parameters. + */ + private static StringTextField createStringField(String parameter, Parameters parameters, String defaultValue) { + return PathsAction.getStringField(parameter, parameters, defaultValue); + } + + /** + * Creates a parameters panel for the given independence wrapper and parameters. + * + * @param params The parameters for the independence test. + * @return The JPanel containing the parameters panel. + */ + private JPanel createIndependenceWrapperParamsPanel(Parameters params) { + Set testParameters = new HashSet<>(model.getMarkovCheckerIndependenceWrapper().getParameters()); + return createParamsPanel(testParameters, params); + } + + /** + * Refreshes the test list in the GUI. Retrieves the data type of the data set. Removes all items from the test + * combo box. Retrieves the independence test models for the given data type. Adds the independence test models to + * the test combo box. Disables the test combo box if there are no items. Selects the default model for the data + * type. + */ + private void populateTestTypes(JComboBox indTestJComboBox) { + indTestJComboBox.removeAllItems(); + + List models = new ArrayList<>(IndependenceTestModels.getInstance().getModels(DataType.Continuous)); + models.addAll(IndependenceTestModels.getInstance().getModels(DataType.Discrete)); + models.addAll(IndependenceTestModels.getInstance().getModels(DataType.Mixed)); + + for (IndependenceTestModel model : models) { + indTestJComboBox.addItem(model); + } + + IndependenceTestModel selectedIndependenceTestModel = this.model.getSelectedIndependenceTestModel(); + for (IndependenceTestModel model : models) { + if (model.equals(selectedIndependenceTestModel)) { + indTestJComboBox.setSelectedItem(model); + } + } + + if (selectedIndependenceTestModel == null) { + for (IndependenceTestModel model : models) { + if (model.getName().equals("Fisher Z Test")) { + this.model.setSelectedIndependenceTestModel(model); + break; + } + } + } + + indTestJComboBox.addItemListener(e -> { + IndependenceTestModel item = (IndependenceTestModel) e.getItem(); + this.model.setSelectedIndependenceTestModel(item); + Class clazz = (item == null) ? null + : (Class) item.getIndependenceTest().clazz(); + + if (clazz != null) { + try { + IndependenceWrapper independenceWrapper = clazz.getDeclaredConstructor(new Class[0]).newInstance(); + model.setMarkovCheckerIndependenceWrapper(independenceWrapper); + } catch (InstantiationException | IllegalAccessException | InvocationTargetException + | NoSuchMethodException e1) { + TetradLogger.getInstance().log("Error: " + e1.getMessage()); + throw new RuntimeException(e1); + } + } + }); + +// indTestJComboBox.setSelectedItem(selectedIndependenceTestModel); + + Class clazz = (selectedIndependenceTestModel == null) ? null + : (Class) selectedIndependenceTestModel.getIndependenceTest().clazz(); + + if (clazz != null) { + try { + IndependenceWrapper independenceWrapper = clazz.getDeclaredConstructor(new Class[0]).newInstance(); + model.setMarkovCheckerIndependenceWrapper(independenceWrapper); + } catch (InstantiationException | IllegalAccessException | InvocationTargetException + | NoSuchMethodException e1) { + TetradLogger.getInstance().log("Error: " + e1.getMessage()); + throw new RuntimeException(e1); + } + } + } } diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java index 89cf4dd14a..2c75820172 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/editor/MarkovCheckEditor.java @@ -1488,7 +1488,7 @@ private void refreshTestList() { * @param params The parameters for the independence test. * @return The JPanel containing the parameters panel. */ - private JPanel createParamsPanel(IndependenceWrapper independenceWrapper, Parameters params) { + private static JPanel createParamsPanel(IndependenceWrapper independenceWrapper, Parameters params) { Set testParameters = new HashSet<>(independenceWrapper.getParameters()); return createParamsPanel(testParameters, params); } @@ -1500,7 +1500,7 @@ private JPanel createParamsPanel(IndependenceWrapper independenceWrapper, Parame * @param parameters The Parameters object containing the parameter values. * @return The JPanel containing the parameters panel. */ - private JPanel createParamsPanel(Set params, Parameters parameters) { + public static JPanel createParamsPanel(Set params, Parameters parameters) { JPanel panel = new JPanel(new BorderLayout()); panel.setBorder(BorderFactory.createTitledBorder("Parameters")); @@ -1526,7 +1526,7 @@ private JPanel createParamsPanel(Set params, Parameters parameters) { * @param parameters The Parameters object containing the parameter values. * @return A map of parameter names to Box components. */ - private Map createParameterComponents(Set params, Parameters parameters) { + private static Map createParameterComponents(Set params, Parameters parameters) { ParamDescriptions paramDescriptions = ParamDescriptions.getInstance(); return params.stream() .collect(Collectors.toMap( @@ -1547,7 +1547,7 @@ private Map createParameterComponents(Set params, Parameter * @return A Box component representing the parameter component. * @throws IllegalArgumentException If the default value type is unexpected. */ - private Box createParameterComponent(String parameter, Parameters parameters, ParamDescription paramDesc) { + private static Box createParameterComponent(String parameter, Parameters parameters, ParamDescription paramDesc) { JComponent component; Object defaultValue = paramDesc.getDefaultValue(); if (defaultValue instanceof Double) { @@ -1594,34 +1594,9 @@ private Box createParameterComponent(String parameter, Parameters parameters, Pa * @param upperBound The upper bound for valid values. * @return A DoubleTextField with the specified parameters. */ - private DoubleTextField getDoubleField(String parameter, Parameters parameters, + private static DoubleTextField getDoubleField(String parameter, Parameters parameters, double defaultValue, double lowerBound, double upperBound) { - DoubleTextField field = new DoubleTextField(parameters.getDouble(parameter, defaultValue), - 8, new DecimalFormat("0.####"), new DecimalFormat("0.0#E0"), 0.001); - - field.setFilter((value, oldValue) -> { - if (value == field.getValue()) { - return oldValue; - } - - if (value < lowerBound) { - return oldValue; - } - - if (value > upperBound) { - return oldValue; - } - - try { - parameters.set(parameter, value); - } catch (Exception e) { - // Ignore. - } - - return value; - }); - - return field; + return ParameterComponents.getDoubleField(parameter, parameters, defaultValue, lowerBound, upperBound); } /** @@ -1634,33 +1609,9 @@ private DoubleTextField getDoubleField(String parameter, Parameters parameters, * @param upperBound The upper bound for valid values. * @return An IntTextField with the specified parameters. */ - private IntTextField getIntTextField(String parameter, Parameters parameters, + private static IntTextField getIntTextField(String parameter, Parameters parameters, int defaultValue, double lowerBound, double upperBound) { - IntTextField field = new IntTextField(parameters.getInt(parameter, defaultValue), 8); - - field.setFilter((value, oldValue) -> { - if (value == field.getValue()) { - return oldValue; - } - - if (value < lowerBound) { - return oldValue; - } - - if (value > upperBound) { - return oldValue; - } - - try { - parameters.set(parameter, value); - } catch (Exception e) { - // Ignore. - } - - return value; - }); - - return field; + return ParameterComponents.getIntTextField(parameter, parameters, defaultValue, lowerBound, upperBound); } /** @@ -1673,7 +1624,7 @@ private IntTextField getIntTextField(String parameter, Parameters parameters, * @param upperBound The upper bound for valid values. * @return A LongTextField object with the specified parameters. */ - private LongTextField getLongTextField(String parameter, Parameters parameters, + private static LongTextField getLongTextField(String parameter, Parameters parameters, long defaultValue, long lowerBound, long upperBound) { LongTextField field = new LongTextField(parameters.getLong(parameter, defaultValue), 8); @@ -1709,7 +1660,7 @@ private LongTextField getLongTextField(String parameter, Parameters parameters, * @param parameters The Parameters object containing the parameter values. * @param defaultValue The default value for the boolean parameter */ - private Box getBooleanSelectionBox(String parameter, Parameters parameters, boolean defaultValue) { + private static Box getBooleanSelectionBox(String parameter, Parameters parameters, boolean defaultValue) { Box selectionBox = Box.createHorizontalBox(); JRadioButton yesButton = new JRadioButton("Yes"); @@ -1760,24 +1711,8 @@ private Box getBooleanSelectionBox(String parameter, Parameters parameters, bool * @param defaultValue The default value for the StringTextField. * @return A StringTextField object with the specified parameters. */ - private StringTextField getStringField(String parameter, Parameters parameters, String defaultValue) { - StringTextField field = new StringTextField(parameters.getString(parameter, defaultValue), 20); - - field.setFilter((value, oldValue) -> { - if (value.equals(field.getValue().trim())) { - return oldValue; - } - - try { - parameters.set(parameter, value); - } catch (Exception e) { - // Ignore. - } - - return value; - }); - - return field; + private static StringTextField getStringField(String parameter, Parameters parameters, String defaultValue) { + return PathsAction.getStringField(parameter, parameters, defaultValue); } /** diff --git a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GridSearchModel.java b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GridSearchModel.java index ed16fdf49f..594fa2a48e 100644 --- a/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GridSearchModel.java +++ b/tetrad-gui/src/main/java/edu/cmu/tetradapp/model/GridSearchModel.java @@ -169,19 +169,24 @@ public class GridSearchModel implements SessionModel, GraphSource { */ private transient PrintStream verboseOut; /** - * A wrapper for a statistical independence test used by the Markov checker. + * A variable that holds an instance of the IndependenceWrapper implementation used for independence testing. *

- * This instance by default utilizes the Fisher-Z test for determining statistical independence. + * In this case, the implementation is `FisherZ`, which is typically employed for assessing statistical independence + * based on Fisher's Z-transformation. It serves as the primary tool for conditional independence checks in related + * algorithms or workflows. */ private IndependenceWrapper markovCheckerIndependenceWrapper = new FisherZ(); /** - * Represents the type of conditioning set used in a Markov property verification process. This variable identifies - * the specific conditioning set type in the context of the Markov checker. - *

- * The value is set to {@code ConditioningSetType.LOCAL_MARKOV} by default, indicating that the local Markov - * property is being utilized. + * Represents the type of conditioning set used in the Markov checker. The variable defines how the conditioning set + * is categorized or scoped, influencing the analysis process in probabilistic or causal models. It is initialized + * to `ConditioningSetType.LOCAL_MARKOV`, indicating that the default scope pertains to local Markovity. */ private ConditioningSetType markovCheckerConditioningSetType = ConditioningSetType.LOCAL_MARKOV; + /** + * Stores the selected independendence test model for the GridSearchEditor. It needs to be stored here in case + * the user closes the editor and re-opens it. + */ + private IndependenceTestModel selectedIndependenceTestModel = null; /** * Constructs a new GridSearchModel with the specified parameters. @@ -822,7 +827,7 @@ private List getStatisticsNamesFromImplementations(List constructor = column.getStatistic().getConstructor(IndependenceWrapper.class, ConditioningSetType.class); - Statistic statistic = constructor.newInstance(markovCheckerIndependenceWrapper, markovCheckerConditioningSetType); + Statistic statistic = constructor.newInstance(getMarkovCheckerIndependenceWrapper(), getMarkovCheckerConditioningSetType()); selectedStatistics.add(statistic); lastStatisticsUsed.add(statistic); } catch (NoSuchMethodException | InstantiationException | IllegalAccessException | @@ -995,7 +1000,7 @@ public List getAllTableColumns() { allTableColumns.add(column); } else if (MarkovCheckerStatistic.class.isAssignableFrom(statisticClass)) { Statistic _statistic = statisticClass.getConstructor(IndependenceWrapper.class, ConditioningSetType.class) - .newInstance(markovCheckerIndependenceWrapper, markovCheckerConditioningSetType); + .newInstance(getMarkovCheckerIndependenceWrapper(), getMarkovCheckerConditioningSetType()); MyTableColumn column = new MyTableColumn(_statistic.getAbbreviation(), _statistic.getDescription(), statisticClass); allTableColumns.add(column); } @@ -1221,17 +1226,39 @@ public void setSelectedGraphIndex(int selectedGraphIndex) { this.selectedGraphIndex = selectedGraphIndex; } + /** + * A wrapper for a statistical independence test used by the Markov checker. + *

+ * This instance by default utilizes the Fisher-Z test for determining statistical independence. + */ + public IndependenceWrapper getMarkovCheckerIndependenceWrapper() { + return markovCheckerIndependenceWrapper; + } + /** * Sets the Markov Checker Independence Wrapper. * - * @param markovCheckerIndependenceWrapper an instance of IndependenceWrapper to be associated - * with the Markov Checker. + * @param markovCheckerIndependenceWrapper an instance of IndependenceWrapper to be associated with the Markov + * Checker. */ public void setMarkovCheckerIndependenceWrapper(IndependenceWrapper markovCheckerIndependenceWrapper) { - if (markovCheckerIndependenceWrapper != null) { + if (markovCheckerIndependenceWrapper == null) { throw new IllegalArgumentException("markovCheckerIndependenceWrapper cannot be null"); } this.markovCheckerIndependenceWrapper = markovCheckerIndependenceWrapper; + + System.out.println("Setting independence wrapper to " + markovCheckerIndependenceWrapper); + } + + /** + * Represents the type of conditioning set used in a Markov property verification process. This variable identifies + * the specific conditioning set type in the context of the Markov checker. + *

+ * The value is set to {@code ConditioningSetType.LOCAL_MARKOV} by default, indicating that the local Markov + * property is being utilized. + */ + public ConditioningSetType getMarkovCheckerConditioningSetType() { + return markovCheckerConditioningSetType; } /** @@ -1240,12 +1267,20 @@ public void setMarkovCheckerIndependenceWrapper(IndependenceWrapper markovChecke * @param markovCheckerConditioningSetType the conditioning set type to be set for the Markov checker */ public void setMarkovCheckerConditioningSetType(ConditioningSetType markovCheckerConditioningSetType) { - if (markovCheckerConditioningSetType != null) { + if (markovCheckerConditioningSetType == null) { throw new IllegalArgumentException("markovCheckerConditioningSetType cannot be null"); } this.markovCheckerConditioningSetType = markovCheckerConditioningSetType; } + public IndependenceTestModel getSelectedIndependenceTestModel() { + return selectedIndependenceTestModel; + } + + public void setSelectedIndependenceTestModel(IndependenceTestModel selectedIndependenceTestModel) { + this.selectedIndependenceTestModel = selectedIndependenceTestModel; + } + /** * This class represents the comparison graph type for graph-based comparison algorithms. ComparisonGraphType is an * enumeration type that represents different types of comparison graphs. The available types are DAG (Directed diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java index 2439e7b9f1..fb4fd9a007 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java @@ -608,76 +608,4 @@ public void testNonGaussianCPDAGPrecisionRecallForLocalOnMarkovBlanket2() { System.out.println("Accepts size: " + accepts.size()); System.out.println("Rejects size: " + rejects.size()); } - - @Test - public void checkMarkovMethod() { - - DataSet data2 = null; - Knowledge knowledge = null; - try { - data2 = SimpleDataLoader.loadContinuousData(new File("/Users/josephramsey/IdeaProjects/bftools/fun_scripts/psp_c_pivot_v8.csv"), - "#", '\"', "*", true, Delimiter.COMMA, false); - - knowledge = SimpleDataLoader.loadKnowledge(new File("/Users/josephramsey/IdeaProjects/bftools/fun_scripts/mike_knowledge.txt"), - DelimiterType.WHITESPACE, "#"); - } catch (IOException e) { - throw new RuntimeException(e); - } - - for (int i = 0; i < data2.getNumRows(); i++) { - for (int j = 0; j < data2.getNumColumns(); j++) { - data2.setDouble(i, j, log(1 + data2.getDouble(i, j))); - } - } - - Graph graph = null; - try { - SemBicScore score = new SemBicScore(data2, true); - score.setPenaltyDiscount(2); - PermutationSearch search = new PermutationSearch(new Boss(score)); -// search.setKnowledge(knowledge); - graph = search.search(); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - - System.out.println("Test Graph: " + graph); - - ConditioningSetType condType = ConditioningSetType.ORDERED_LOCAL_MARKOV; - MarkovCheck markovCheck = new MarkovCheck(graph, new IndTestFisherZ(data2, 0.00001), condType); - markovCheck.setKnowledge(knowledge); - markovCheck.generateResults(true); - - int numTestsH0 = markovCheck.getNumTests(true); - int numTestsH1 = markovCheck.getNumTests(false); - - double fracDepH0 = markovCheck.getFractionDependent(true); - double fracDepH1 = markovCheck.getFractionDependent(false); - - System.out.println("numTestsH0: " + numTestsH0); - System.out.println("numTestsH1: " + numTestsH1); - - System.out.println("fracDepH0: " + fracDepH0); - System.out.println("fracDepH1: " + fracDepH1); - - McGetNumTestsH0 numTestsH0_2 = new McGetNumTestsH0(new FisherZ(), condType); - McGetNumTestsH1 numTestsH1_2 = new McGetNumTestsH1(new FisherZ(), condType); - - double numTestsH0_2a = numTestsH0_2.getValue(null, graph, data2, new Parameters()); - double numTestsH1_2a = numTestsH1_2.getValue(null, graph, data2, new Parameters()); - - System.out.println("numTestsH0_2a: " + numTestsH0_2a); - System.out.println("numTestsH1_2a: " + numTestsH1_2a); - - MarkovCheckFractionDependentH0 fracDepH0_2 = new MarkovCheckFractionDependentH0(new FisherZ(), condType); - MarkovCheckFractionDependentH1 fracDepH1_2 = new MarkovCheckFractionDependentH1(new FisherZ(), condType); - - double fracDepH0_2a = fracDepH0_2.getValue(null, graph, data2, new Parameters()); - double fracDepH1_2a = fracDepH1_2.getValue(null, graph, data2, new Parameters()); - - System.out.println("fracDepH0_2a: " + fracDepH0_2a); - System.out.println("fracDepH1_2a: " + fracDepH1_2a); - - - } }