diff --git a/src/codegen/codegen_neuron_cpp_visitor.cpp b/src/codegen/codegen_neuron_cpp_visitor.cpp index d1c0073f6..dc8c39266 100644 --- a/src/codegen/codegen_neuron_cpp_visitor.cpp +++ b/src/codegen/codegen_neuron_cpp_visitor.cpp @@ -290,12 +290,10 @@ void CodegenNeuronCppVisitor::print_hoc_py_wrapper_function_body( return; } const auto block_name = function_or_procedure_block->get_node_name(); - if (info.point_process) { - printer->fmt_push_block("static double _hoc_{}(void* _vptr)", block_name); - } else if (wrapper_type == InterpreterWrapper::HOC) { - printer->fmt_push_block("static void _hoc_{}(void)", block_name); + if (wrapper_type == InterpreterWrapper::HOC) { + printer->fmt_push_block("{}", hoc_function_signature(block_name)); } else { - printer->fmt_push_block("static double _npy_{}(Prop* _prop)", block_name); + printer->fmt_push_block("{}", py_function_signature(block_name)); } printer->add_multi_line(R"CODE( double _r{}; @@ -416,20 +414,24 @@ void CodegenNeuronCppVisitor::print_hoc_py_wrapper_function_definitions() { } // HOC - printer->fmt_push_block("static void {}()", hoc_function_name(name)); - printer->fmt_line("hoc_retpushx({}({}));", method_name(name), fmt::join(args, ", ")); + std::string return_statement = info.point_process ? "return _ret;" : "hoc_retpushx(_ret);"; + + printer->fmt_push_block("{}", hoc_function_signature(name)); + printer->fmt_line("double _ret = {}({});", method_name(name), fmt::join(args, ", ")); + printer->add_line(return_statement); printer->pop_block(); - printer->fmt_push_block("static void {}()", hoc_function_name(table_name)); - printer->fmt_line("hoc_retpushx({}());", method_name(table_name)); + printer->fmt_push_block("{}", hoc_function_signature(table_name)); + printer->fmt_line("double _ret = {}();", method_name(table_name)); + printer->add_line(return_statement); printer->pop_block(); // Python - printer->fmt_push_block("static double {}(Prop* _prop)", py_function_name(name)); + printer->fmt_push_block("{}", py_function_signature(name)); printer->fmt_line("return {}({});", method_name(name), fmt::join(args, ", ")); printer->pop_block(); - printer->fmt_push_block("static double {}(Prop* _prop)", py_function_name(table_name)); + printer->fmt_push_block("{}", py_function_signature(table_name)); printer->fmt_line("return {}();", method_name(table_name)); printer->pop_block(); } @@ -604,10 +606,10 @@ std::string CodegenNeuronCppVisitor::hoc_function_name( std::string CodegenNeuronCppVisitor::hoc_function_signature( const std::string& function_or_procedure_name) const { - return fmt::format("static {} {}(void{})", + return fmt::format("static {} {}({})", info.point_process ? "double" : "void", hoc_function_name(function_or_procedure_name), - info.point_process ? "*" : ""); + info.point_process ? "void * _vptr" : ""); } @@ -619,7 +621,8 @@ std::string CodegenNeuronCppVisitor::py_function_name( std::string CodegenNeuronCppVisitor::py_function_signature( const std::string& function_or_procedure_name) const { - return fmt::format("static double {}(Prop*)", py_function_name(function_or_procedure_name)); + return fmt::format("static double {}(Prop* _prop)", + py_function_name(function_or_procedure_name)); } @@ -1218,16 +1221,26 @@ void CodegenNeuronCppVisitor::print_global_variables_for_hoc() { printer->add_line("{nullptr, nullptr}"); printer->decrease_indent(); printer->add_line("};"); + + + auto print_py_callable_reg = [this](const auto& callables, auto get_name) { + for (const auto& callable: callables) { + const auto name = get_name(callable); + printer->fmt_line("{{\"{}\", {}}},", name, py_function_name(name)); + } + }; + if (!info.point_process) { printer->push_block("static NPyDirectMechFunc npy_direct_func_proc[] ="); - for (const auto& procedure: info.procedures) { - const auto proc_name = procedure->get_node_name(); - printer->fmt_line("{{\"{}\", {}}},", proc_name, py_function_name(proc_name)); - } - for (const auto& function: info.functions) { - const auto func_name = function->get_node_name(); - printer->fmt_line("{{\"{}\", {}}},", func_name, py_function_name(func_name)); - } + print_py_callable_reg(info.procedures, + [](const auto& callable) { return callable->get_node_name(); }); + print_py_callable_reg(info.functions, + [](const auto& callable) { return callable->get_node_name(); }); + print_py_callable_reg(info.function_tables, + [](const auto& callable) { return callable->get_node_name(); }); + print_py_callable_reg(info.function_tables, [](const auto& callable) { + return "table_" + callable->get_node_name(); + }); printer->add_line("{nullptr, nullptr}"); printer->pop_block(";"); } diff --git a/test/usecases/function_table/art_function_table.mod b/test/usecases/function_table/art_function_table.mod new file mode 100644 index 000000000..c330a4de7 --- /dev/null +++ b/test/usecases/function_table/art_function_table.mod @@ -0,0 +1,6 @@ +NEURON { + ARTIFICIAL_CELL art_function_table +} + +INCLUDE "function_table.inc" + diff --git a/test/usecases/function_table/function_table.inc b/test/usecases/function_table/function_table.inc new file mode 100644 index 000000000..6f842b259 --- /dev/null +++ b/test/usecases/function_table/function_table.inc @@ -0,0 +1,8 @@ +FUNCTION_TABLE cnst1(v) +FUNCTION_TABLE cnst2(v, x) +FUNCTION_TABLE tau1(v) +FUNCTION_TABLE tau2(v, x) + +FUNCTION use_tau2(v, x) { + use_tau2 = tau2(v, x) +} diff --git a/test/usecases/function_table/function_table.mod b/test/usecases/function_table/function_table.mod index 5000c147f..d5ace275b 100644 --- a/test/usecases/function_table/function_table.mod +++ b/test/usecases/function_table/function_table.mod @@ -2,11 +2,5 @@ NEURON { SUFFIX function_table } -FUNCTION_TABLE cnst1(v) -FUNCTION_TABLE cnst2(v, x) -FUNCTION_TABLE tau1(v) -FUNCTION_TABLE tau2(v, x) +INCLUDE "function_table.inc" -FUNCTION use_tau2(v, x) { - use_tau2 = tau2(v, x) -} diff --git a/test/usecases/function_table/point_function_table.mod b/test/usecases/function_table/point_function_table.mod new file mode 100644 index 000000000..5d9ffc73e --- /dev/null +++ b/test/usecases/function_table/point_function_table.mod @@ -0,0 +1,6 @@ +NEURON { + POINT_PROCESS point_function_table +} + +INCLUDE "function_table.inc" + diff --git a/test/usecases/function_table/test_function_table.py b/test/usecases/function_table/test_function_table.py index cd8dfb7ea..bb24b98ea 100644 --- a/test/usecases/function_table/test_function_table.py +++ b/test/usecases/function_table/test_function_table.py @@ -4,43 +4,80 @@ import scipy -def test_constant_1d(): +def make_callable(inst, name, mech_name): + if inst is None: + return getattr(h, f"{name}_{mech_name}") + else: + return getattr(inst, f"{name}") + + +def make_callbacks(inst, name, mech_name): + set_table = make_callable(inst, f"table_{name}", mech_name) + eval_table = make_callable(inst, name, mech_name) + + return set_table, eval_table + + +def check_constant_1d(make_inst, mech_name): s = h.Section() s.insert("function_table") + inst = make_inst(s) + set_table, eval_table = make_callbacks(inst, "cnst1", mech_name) + c = 42.0 - h.table_cnst1_function_table(c) + set_table(c) for vv in np.linspace(-10.0, 10.0, 14): - np.testing.assert_equal(h.cnst1_function_table(vv), c) + np.testing.assert_equal(eval_table(vv), c) -def test_constant_2d(): +def check_constant_2d(make_inst, mech_name): s = h.Section() s.insert("function_table") + inst = make_inst(s) + set_table, eval_table = make_callbacks(inst, "cnst2", mech_name) + c = 42.0 - h.table_cnst2_function_table(c) + set_table(c) for vv in np.linspace(-10.0, 10.0, 7): for xx in np.linspace(-20.0, 10.0, 9): - np.testing.assert_equal(h.cnst2_function_table(vv, xx), c) + np.testing.assert_equal(eval_table(vv, xx), c) + + +def check_1d(make_inst, mech_name): + s = h.Section() + s.insert("function_table") + inst = make_inst(s) + set_table, eval_table = make_callbacks(inst, "tau1", mech_name) -def test_1d(): v = np.array([0.0, 1.0]) tau1 = np.array([1.0, 2.0]) - h.table_tau1_function_table(h.Vector(tau1), h.Vector(v)) + set_table(h.Vector(tau1), h.Vector(v)) for vv in np.linspace(v[0], v[-1], 20): expected = np.interp(vv, v, tau1) - actual = h.tau1_function_table(vv) + actual = eval_table(vv) np.testing.assert_approx_equal(actual, expected, significant=11) -def test_2d(): +def check_2d(make_inst, mech_name): + s = h.Section() + s.insert("function_table") + + inst = make_inst(s) + set_table, eval_table = make_callbacks(inst, "tau2", mech_name) + eval_use_table = make_callable(inst, "use_tau2", mech_name) + + if inst is None: + setdata = getattr(h, f"setdata_{mech_name}") + setdata(s(0.5)) + v = np.array([0.0, 1.0]) x = np.array([1.0, 2.0, 3.0]) @@ -50,36 +87,33 @@ def test_2d(): hoc_tau2 = h.Matrix(*tau2.shape) hoc_tau2.from_vector(h.Vector(tau2.transpose().reshape(-1))) - h.table_tau2_function_table( - hoc_tau2._ref_x[0][0], v.size, v[0], v[-1], x.size, x[0], x[-1] - ) + set_table(hoc_tau2._ref_x[0][0], v.size, v[0], v[-1], x.size, x[0], x[-1]) for vv in np.linspace(v[0], v[-1], 20): for xx in np.linspace(x[0], x[-1], 20): expected = scipy.interpolate.interpn((v, x), tau2, (vv, xx)) - actual = h.tau2_function_table(vv, xx) + actual = eval_table(vv, xx) + actual_indirect = eval_use_table(vv, xx) np.testing.assert_approx_equal(actual, expected, significant=11) + np.testing.assert_approx_equal(actual_indirect, expected, significant=11) -def test_use_table(): - s = h.Section() - s.insert("function_table") - - h.setdata_function_table(s(0.5)) +def test_function_table(): + variations = [ + (lambda s: None, "function_table"), + (lambda s: s(0.5).function_table, "function_table"), + (lambda s: h.point_function_table(s(0.5)), "point_function_table"), + (lambda s: h.art_function_table(s(0.5)), "art_function_table"), + ] - vv, xx = 0.33, 2.24 + for make_instance, mech_name in variations: + check_constant_1d(make_instance, mech_name) + check_constant_2d(make_instance, mech_name) - expected = h.tau2_function_table(vv, xx) - actual = h.use_tau2_function_table(vv, xx) - np.testing.assert_approx_equal(actual, expected, significant=11) + check_1d(make_instance, mech_name) + check_2d(make_instance, mech_name) if __name__ == "__main__": - test_constant_1d() - test_constant_2d() - - test_1d() - test_2d() - - test_use_table() + test_function_table()