|
import numpy as np |
|
import seaborn as sns |
|
class PlotHelper:
    """Convenience wrappers around ``seaborn.pairplot``.

    The main addition over plain seaborn is splitting a wide pair plot into
    several smaller figures. Each figure shows a single y-variable against at
    most ``max_per_row`` x-variables.
    """

    @staticmethod
    def _to_list(variables):
        """Normalise a variable selection to a list (None -> [], str -> [str])."""
        if variables is None:
            return []
        if isinstance(variables, str):
            # A bare column name must not be iterated character by character.
            return [variables]
        return list(variables)

    def splitVariables(self, x_vars=None, y_vars=None, max_per_row=10):
        """Split x/y variable selections into per-figure slices.

        For every y-variable, the x-variables other than that y-variable are
        grouped into consecutive chunks of at most ``max_per_row`` names.

        Args:
            x_vars: Column name(s) for the x axis. ``None`` is treated as empty.
            y_vars: Column name(s) for the y axis. ``None`` is treated as empty.
            max_per_row: Maximum number of x-variables per slice; must be >= 1.

        Returns:
            A list of ``[x_chunk, y]`` pairs. ``x_chunk`` is a non-empty list
            of x-variable names and ``y`` is a single y-variable name. Every
            pair uses this same order, which is how ``pairplot`` consumes them
            (``x_vars=pair[0], y_vars=pair[1]``).

        Raises:
            ValueError: If ``max_per_row`` is None or smaller than 1.
        """
        if max_per_row is None or max_per_row < 1:
            raise ValueError(
                "max_per_row must be a positive integer, got {!r}".format(max_per_row))

        x_list = self._to_list(x_vars)
        slices = []
        for y in self._to_list(y_vars):
            # Drop y itself first: a variable is never plotted against itself.
            # Chunking afterwards keeps every row except the last one full and
            # guarantees no slice is empty.
            others = [x for x in x_list if x != y]
            for start in range(0, len(others), max_per_row):
                slices.append([others[start:start + max_per_row], y])
        return slices

    def pairplot(self, data=None, hue=None, hue_order=None, palette=None,
                 vars=None, x_vars=None, y_vars=None, kind='scatter',
                 diag_kind='hist', markers=None, size=2.5, aspect=1,
                 dropna=True, plot_kws=None, diag_kws=None,
                 grid_kws=None, wrap=True, max_per_row=None, reg_line_color=None):
        """Draw a seaborn pair plot, optionally split into several figures.

        Most arguments are forwarded unchanged to ``seaborn.pairplot``.

        Args:
            wrap: Currently unused; kept for backward compatibility.
            max_per_row: When None, a single ``seaborn.pairplot`` call is made.
                Otherwise the x/y selection is split with ``splitVariables``
                and one pair plot is drawn per slice.
            reg_line_color: When ``kind == 'reg'``, the colour of the
                regression line. It is merged into ``plot_kws['line_kws']``
                without discarding other caller-supplied keyword arguments.

        Returns:
            The object returned by ``seaborn.pairplot`` when ``max_per_row``
            is None; otherwise a list with one such object per slice.

        Raises:
            ValueError: If ``max_per_row`` is given but neither ``vars`` nor
                both of ``x_vars``/``y_vars`` are provided, or if
                ``max_per_row`` is smaller than 1.
        """
        if kind == 'reg' and reg_line_color is not None:
            # Copy so neither the caller's plot_kws nor its line_kws is mutated.
            plot_kws = dict(plot_kws or {})
            line_kws = dict(plot_kws.get('line_kws') or {})
            line_kws['color'] = reg_line_color
            plot_kws['line_kws'] = line_kws

        common_kws = dict(data=data, hue=hue, hue_order=hue_order, palette=palette,
                          kind=kind, diag_kind=diag_kind, markers=markers,
                          size=size, aspect=aspect, dropna=dropna,
                          plot_kws=plot_kws, diag_kws=diag_kws, grid_kws=grid_kws)

        if max_per_row is None:
            return sns.pairplot(vars=vars, x_vars=x_vars, y_vars=y_vars, **common_kws)

        # seaborn lets ``vars`` override ``x_vars``/``y_vars``. Apply that
        # precedence here, once, and do NOT forward ``vars`` to the per-slice
        # calls; otherwise seaborn would replace every slice with the full list.
        if vars is not None:
            x_vars = vars
            y_vars = vars
        if x_vars is None or y_vars is None:
            raise ValueError(
                "max_per_row requires either vars or both x_vars and y_vars")

        return [sns.pairplot(x_vars=x_chunk, y_vars=y, **common_kws)
                for x_chunk, y in self.splitVariables(x_vars, y_vars, max_per_row)]
|
|