Skip to content

Commit

Permalink
Fixed minor trace plotting bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Christopher Fonnesbeck committed Feb 19, 2013
1 parent c052006 commit f99ee9c
Showing 1 changed file with 28 additions and 28 deletions.
56 changes: 28 additions & 28 deletions pymc/Matplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def wrapper(pymc_obj, *args, **kwargs):


@plotwrapper
def plot(data, name, format='png', suffix='', path='./', common_scale=True, datarange=(None, None),
def plot(data, name, format='png', suffix='', path='./', common_scale=True, datarange=(None, None),
new=True, last=True, rows=1, num=1, fontmap=None, verbose=1):
"""
Generates summary plots for nodes of a given PyMC object.
Expand Down Expand Up @@ -438,7 +438,7 @@ def plot(data, name, format='png', suffix='', path='./', common_scale=True, data
_sqrt_choice = lambda n: sqrt(n)

@plotwrapper
def histogram(data, name, bins='sturges', datarange=(None, None), format='png', suffix='', path='./', rows=1,
def histogram(data, name, bins='sturges', datarange=(None, None), format='png', suffix='', path='./', rows=1,
columns=1, num=1, last=True, fontmap = None, verbose=1):
"""
Generates histogram from an array of data.
Expand All @@ -449,11 +449,11 @@ def histogram(data, name, bins='sturges', datarange=(None, None), format='png',
name: string
The name of the histogram.
bins: int or string
The number of bins, or a preferred binning method. Available methods include
'doanes', 'sturges' and 'sqrt' (defaults to 'doanes').
datarange: tuple or list
Preferred range of histogram (defaults to (None,None)).
Expand All @@ -465,11 +465,11 @@ def histogram(data, name, bins='sturges', datarange=(None, None), format='png',
path (optional): string
Specifies location for saving plots (defaults to local directory).
fontmap (optional): dict
Font map for plot.
"""


# Internal histogram specification for handling nested arrays
try:
Expand Down Expand Up @@ -499,8 +499,8 @@ def histogram(data, name, bins='sturges', datarange=(None, None), format='png',
bins = bins
else:
raise ValueError('Invalid bins argument in histogram')
if isnan(bins):

if isnan(bins):
bins = uniquevals*(uniquevals<=25) or int(4 + 1.5*log(len(data)))
print_('Bins could not be calculated using selected method. Setting bins to %i.' % bins)

Expand Down Expand Up @@ -540,7 +540,7 @@ def histogram(data, name, bins='sturges', datarange=(None, None), format='png',


@plotwrapper
def trace(data, name, format='png', datarange=(None, None), suffix='', path='./', rows=1, columns=1,
def trace(data, name, format='png', datarange=(None, None), suffix='', path='./', rows=1, columns=1,
num=1, last=True, fontmap = None, verbose=1):
"""
Generates trace plot from an array of data.
Expand All @@ -551,7 +551,7 @@ def trace(data, name, format='png', datarange=(None, None), suffix='', path='./'
name: string
The name of the trace.
datarange: tuple or list
Preferred y-range of trace (defaults to (None,None)).
Expand All @@ -563,7 +563,7 @@ def trace(data, name, format='png', datarange=(None, None), suffix='', path='./'
path (optional): string
Specifies location for saving plots (defaults to local directory).
fontmap (optional): dict
Font map for plot.
Expand All @@ -588,10 +588,10 @@ def trace(data, name, format='png', datarange=(None, None), suffix='', path='./'

# Smaller tick labels
tlabels = gca().get_xticklabels()
setp(tlabels, 'fontsize', fontmap[rows/2])
setp(tlabels, 'fontsize', fontmap[max(rows/2,1)])

tlabels = gca().get_yticklabels()
setp(tlabels, 'fontsize', fontmap[rows/2])
setp(tlabels, 'fontsize', fontmap[max(rows/2,1)])

if standalone:
if not os.path.exists(path):
Expand All @@ -603,7 +603,7 @@ def trace(data, name, format='png', datarange=(None, None), suffix='', path='./'
#close()

@plotwrapper
def geweke_plot(data, name, format='png', suffix='-diagnostic', path='./', fontmap = None,
def geweke_plot(data, name, format='png', suffix='-diagnostic', path='./', fontmap = None,
verbose=1):
# Generate Geweke (1992) diagnostic plots

Expand Down Expand Up @@ -635,7 +635,7 @@ def geweke_plot(data, name, format='png', suffix='-diagnostic', path='./', fontm
#close()

@plotwrapper
def discrepancy_plot(data, name='discrepancy', report_p=True, format='png', suffix='-gof', path='./',
def discrepancy_plot(data, name='discrepancy', report_p=True, format='png', suffix='-gof', path='./',
fontmap = None, verbose=1):
# Generate goodness-of-fit deviate scatter plot

Expand Down Expand Up @@ -679,18 +679,18 @@ def discrepancy_plot(data, name='discrepancy', report_p=True, format='png', suff
savefig("%s%s%s.%s" % (path, name, suffix, format))
#close()

def gof_plot(simdata, trueval, name=None, bins=None, format='png', suffix='-gof', path='./',
def gof_plot(simdata, trueval, name=None, bins=None, format='png', suffix='-gof', path='./',
fontmap = None, verbose=1):
"""
Plots histogram of replicated data, indicating the location of the observed data
:Arguments:
simdata: array or PyMC object
Trace of simulated data or the PyMC stochastic object containing trace.
trueval: numeric
True (observed) value of the data
bins: int or string
The number of bins, or a preferred binning method. Available methods include
'doanes', 'sturges' and 'sqrt' (defaults to 'doanes').
Expand All @@ -703,12 +703,12 @@ def gof_plot(simdata, trueval, name=None, bins=None, format='png', suffix='-gof'
path (optional): string
Specifies location for saving plots (defaults to local directory).
fontmap (optional): dict
Font map for plot.
"""


if fontmap is None: fontmap = {1:10, 2:8, 3:6, 4:5, 5:4}
try:
Expand Down Expand Up @@ -741,7 +741,7 @@ def gof_plot(simdata, trueval, name=None, bins=None, format='png', suffix='-gof'
bins = bins
else:
raise ValueError('Invalid bins argument in gof_plot')


# Generate histogram
hist(simdata, bins)
Expand Down Expand Up @@ -769,7 +769,7 @@ def gof_plot(simdata, trueval, name=None, bins=None, format='png', suffix='-gof'
#close()

@plotwrapper
def autocorrelation(data, name, maxlags=100, format='png', suffix='-acf', path='./',
def autocorrelation(data, name, maxlags=100, format='png', suffix='-acf', path='./',
fontmap = None, new=True, last=True, rows=1, columns=1, num=1, verbose=1):
"""
Generate bar plot of the autocorrelation function for a series (usually an MCMC trace).
Expand Down Expand Up @@ -860,7 +860,7 @@ def zplot(pvalue_dict, name='', format='png', path='./', fontmap = None, verbose
print_('\nGenerating model validation plot')

if fontmap is None: fontmap = {1:10, 2:8, 3:6, 4:5, 5:4}

x,y,labels = [],[],[]

for i,var in enumerate(pvalue_dict):
Expand Down Expand Up @@ -921,13 +921,13 @@ def var_str(name, shape):
return names


def summary_plot(pymc_obj, name='model', format='png', suffix='-summary', path='./',
alpha=0.05, quartiles=True, hpd=True, rhat=True, main=None, xlab=None, x_range=None,
def summary_plot(pymc_obj, name='model', format='png', suffix='-summary', path='./',
alpha=0.05, quartiles=True, hpd=True, rhat=True, main=None, xlab=None, x_range=None,
custom_labels=None, chain_spacing=0.05, vline_pos=0):
"""
Model summary plot
Generates a "forest plot" of 100*(1-alpha)% credible intervals for either the
Generates a "forest plot" of 100*(1-alpha)% credible intervals for either the
set of nodes in a given model, or a specified set of nodes.
:Arguments:
Expand Down Expand Up @@ -964,7 +964,7 @@ def summary_plot(pymc_obj, name='model', format='png', suffix='-summary', path=
main (optional): string
Title for main plot. Passing False results in titles being
suppressed; passing False (default) results in default titles.
xlab (optional): string
Label for x-axis. Defaults to no label
Expand Down Expand Up @@ -1167,7 +1167,7 @@ def summary_plot(pymc_obj, name='model', format='png', suffix='-summary', path=
if main is not False:
plot_title = main or str(int((1-alpha)*100)) + "% Credible Intervals"
title(plot_title)

# Add x-axis label
if xlab is not None:
xlabel(xlab)
Expand Down

0 comments on commit f99ee9c

Please sign in to comment.