diff --git a/pymc/printing.py b/pymc/printing.py index b89d342b621..03179f4e985 100644 --- a/pymc/printing.py +++ b/pymc/printing.py @@ -58,7 +58,7 @@ def str_for_dist( if "latex" in formatting: if print_name is not None: - print_name = r"\text{" + _latex_escape(dist.name) + "}" + print_name = r"\text{" + _latex_escape(dist.name.strip("$")) + "}" op_name = ( dist.owner.op._print_name[1] @@ -67,9 +67,11 @@ def str_for_dist( ) if include_params: if print_name: - return r"${} \sim {}({})$".format(print_name, op_name, ",~".join(dist_args)) + return r"${} \sim {}({})$".format( + print_name, op_name, ",~".join([d.strip("$") for d in dist_args]) + ) else: - return r"${}({})$".format(op_name, ",~".join(dist_args)) + return r"${}({})$".format(op_name, ",~".join([d.strip("$") for d in dist_args])) else: if print_name: @@ -138,7 +140,7 @@ def str_for_potential_or_deterministic( LaTeX or plain, optionally with distribution parameter values included.""" print_name = var.name if var.name is not None else "" if "latex" in formatting: - print_name = r"\text{" + _latex_escape(print_name) + "}" + print_name = r"\text{" + _latex_escape(print_name.strip("$")) + "}" if include_params: return rf"${print_name} \sim \operatorname{{{dist_name}}}({_str_for_expression(var, formatting=formatting)})$" else: @@ -182,7 +184,7 @@ def _str_for_input_rv(var: Variable, formatting: str) -> str: else str_for_dist(var, formatting=formatting, include_params=True) ) if "latex" in formatting: - return r"\text{" + _latex_escape(_str) + "}" + return _latex_text_format(_latex_escape(_str.strip("$"))) else: return _str @@ -215,9 +217,20 @@ def _expand(x): names = [x.name for x in parents] if "latex" in formatting: - return r"f(" + ",~".join([r"\text{" + _latex_escape(n) + "}" for n in names]) + ")" + return ( + r"f(" + + ",~".join([_latex_text_format(_latex_escape(n.strip("$"))) for n in names]) + + ")" + ) + else: + return r"f(" + ", ".join([n.strip("$") for n in names]) + ")" + + +def _latex_text_format(text: str) -> str: + if r"\operatorname{" in text: + return text else: - return r"f(" + ", ".join(names) + ")" + return r"\text{" + text + "}" def _latex_escape(text: str) -> str: diff --git a/pymc/tests/test_printing.py b/pymc/tests/test_printing.py index 5966a33a132..bcd582a3227 100644 --- a/pymc/tests/test_printing.py +++ b/pymc/tests/test_printing.py @@ -1,6 +1,6 @@ import numpy as np -from pymc import Bernoulli, Censored, Mixture +from pymc import Bernoulli, Censored, HalfCauchy, Mixture, StudentT from pymc.aesaraf import floatX from pymc.distributions import ( Dirichlet, @@ -130,12 +130,12 @@ def setup_class(self): r"$\text{beta} \sim \operatorname{N}(0,~10)$", r"$\text{Z} \sim \operatorname{N}(f(),~f())$", r"$\text{nb_with_p_n} \sim \operatorname{NB}(10,~\text{nbp})$", - r"$\text{zip} \sim \operatorname{MarginalMixture}(f(),~\text{\$\operatorname{DiracDelta}(0)\$},~\text{\$\operatorname{Pois}(5)\$})$", + r"$\text{zip} \sim \operatorname{MarginalMixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Pois}(5))$", r"$\text{w} \sim \operatorname{Dir}(\text{})$", ( r"$\text{nested_mix} \sim \operatorname{MarginalMixture}(\text{w}," - r"~\text{\$\operatorname{MarginalMixture}(f(),~\text{\\$\operatorname{DiracDelta}(0)\\$},~\text{\\$\operatorname{Pois}(5)\\$})\$}," - r"~\text{\$\operatorname{Censored}(\text{\\$\operatorname{Bern}(0.5)\\$},~-1,~1)\$})$" + r"~\operatorname{MarginalMixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Pois}(5))," + r"~\operatorname{Censored}(\operatorname{Bern}(0.5),~-1,~1))$" ), r"$\text{Y_obs} \sim \operatorname{N}(\text{mu},~\text{sigma})$", r"$\text{pot} \sim \operatorname{Potential}(f(\text{beta},~\text{alpha}))$", @@ -178,3 +178,43 @@ def test_str_repr(self): assert segment in model_text else: assert text in model_text + + +def test_model_latex_repr_three_levels_model(): + with Model() as censored_model: + mu = Normal("mu", 0.0, 5.0) + sigma = HalfCauchy("sigma", 2.5) + normal_dist = Normal.dist(mu=mu, sigma=sigma) + censored_normal = Censored( + "censored_normal", normal_dist, lower=-2.0, upper=2.0, observed=[1, 0, 0.5] + ) + + latex_repr = censored_model.str_repr(formatting="latex") + expected = [ + "$$", + "\\begin{array}{rcl}", + "\\text{mu} &\\sim & \\operatorname{N}(0,~5)\\\\\\text{sigma} &\\sim & " + "\\operatorname{C^{+}}(0,~2.5)\\\\\\text{censored_normal} &\\sim & " + "\\operatorname{Censored}(\\operatorname{N}(\\text{mu},~\\text{sigma}),~-2,~2)", + "\\end{array}", + "$$", + ] + assert [line.strip() for line in latex_repr.split("\n")] == expected + + +def test_model_latex_repr_mixture_model(): + with Model() as mix_model: + w = Dirichlet("w", [1, 1]) + mix = Mixture("mix", w=w, comp_dists=[Normal.dist(0.0, 5.0), StudentT.dist(7.0)]) + + latex_repr = mix_model.str_repr(formatting="latex") + expected = [ + "$$", + "\\begin{array}{rcl}", + "\\text{w} &\\sim & " + "\\operatorname{Dir}(\\text{})\\\\\\text{mix} &\\sim & " + "\\operatorname{MarginalMixture}(\\text{w},~\\operatorname{N}(0,~5),~\\operatorname{StudentT}(7,~0,~1))", + "\\end{array}", + "$$", + ] + assert [line.strip() for line in latex_repr.split("\n")] == expected