Browse Source

Adds latexipy to make nicer looking plot for use w/ latex

Caleb Fangmeier 7 years ago
parent
commit
5a896f9374
2 changed files with 42 additions and 25 deletions
  1. 41 25
      filval/plotter.py
  2. 1 0
      requirements.txt

+ 41 - 25
filval/plotter.py

@@ -1,24 +1,45 @@
 #!/usr/bin/env python3
 from collections import namedtuple
-import matplotlib as mpl
+
 import numpy as np
+from markdown import Markdown
+import latexipy as lp
+
 from filval.histogram_utils import (hist, hist2d, hist_bin_centers, hist_fit,
                                     hist_normalize)
-# mpl.rc('text', usetex=True)
-# mpl.rc('figure', dpi=200)
-# mpl.rc('savefig', dpi=200)
+__all__ = ['make_plot',
+           'plot_registry',
+           'hist_plot',
+           'hist2d_plot']
 
 plot_registry = {}
 Plot = namedtuple('Plot', ['name', 'filename', 'title', 'desc', 'args'])
 
-
-def make_plot(filename=None, title='', scale=1):
+MD = Markdown(extensions=['mdx_math'],
+              extension_configs={'mdx_math': {'enable_dollar_delimiter': True}})
+
+lp.latexify(params={'pgf.texsystem': 'xelatex',
+                    'text.usetex': True,
+                    'font.family': 'serif',
+                    'pgf.preamble': [r'\usepackage[utf8x]{inputenc}',
+                                     r'\usepackage[T1]{fontenc}'],
+                    'font.size': 8,
+                    'axes.labelsize': 8,
+                    'axes.titlesize': 8,
+                    'legend.fontsize': 8,
+                    'xtick.labelsize': 8,
+                    'ytick.labelsize': 8,
+                    'figure.dpi': 150,
+                    'savefig.transparent': True,
+                    },
+            new_backend='TkAgg')
+
+
+def make_plot(plot_name=None, title='', scale=1, exts=['png', 'pgf']):
     import matplotlib.pyplot as plt
     from functools import wraps
     from os.path import join
-    from os import makedirs
     from inspect import signature, getdoc
-    from markdown import Markdown
 
     def fn_call_to_dict(fn, *args, **kwargs):
         pnames = list(signature(fn).parameters)
@@ -28,31 +49,26 @@ def make_plot(filename=None, title='', scale=1):
     def process_docs(fn):
         raw = getdoc(fn)
         if raw:
-            md = Markdown(extensions=['mdx_math'],
-                          extension_configs={'mdx_math': {'enable_dollar_delimiter': True}})
-            return md.convert(raw)
+            return MD.convert(raw)
         else:
             return None
 
     def wrap(fn):
         @wraps(fn)
         def f(*args, **kwargs):
-            nonlocal filename
-            plt.clf()
-            plt.gcf().set_size_inches(scale*10, scale*10)
-            fn(*args, **kwargs)
+            nonlocal plot_name
             pdict = fn_call_to_dict(fn, *args, **kwargs)
-            if filename is None:
+            if plot_name is None:
                 pstr = ','.join('{}:{}'.format(pname, pval)
                                 for pname, pval in pdict.items())
-                filename = fn.__name__ + '::' + pstr
-                filename = filename.replace('/', '_').replace('.', '_')+".png"
-            plt.tight_layout()
-            try:
-                makedirs('output/figures')
-            except FileExistsError:
-                pass
-            plt.savefig(join('output/figures', filename))
+                plot_name = fn.__name__ + '::' + pstr
+                plot_name = plot_name.replace('/', '_').replace('.', '_')
+            with lp.figure(plot_name, directory='output/figures',
+                           exts=exts,
+                           size=(scale*10, scale*10)):
+                fn(*args, **kwargs)
+                plt.tight_layout()
+            filename = plot_name+".png"
             plot_registry[fn.__name__] = Plot(fn.__name__, join('figures', filename),
                                               title, process_docs(fn), pdict)
         return f
@@ -128,7 +144,7 @@ def hist_plot(h, *args, axes=None, norm=None, include_errors=False,
         label_txt = "\n".join('{:7s}={: 0.2G}'.format(label, value)
                               for label, value in zip(arglabels, popt))
         axes.text(0.60, 0.95, label_txt, va='top', transform=axes.transAxes,
-                  fontsize='x-small', family='monospace', usetex=False)
+                  fontsize='medium', family='monospace', usetex=False)
     axes.grid()
 
 

+ 1 - 0
requirements.txt

@@ -1,5 +1,6 @@
 numpy
 matplotlib
+latexipy
 ipython
 scipy
 jupyter