Heterogeneous (AKA multi-view) factor modeling in pytorch.
'''
Heterogeneous factor modeling.

This model fits a heterogeneous factor model where columns may be:

1) Binary
2) Categorical
3) Gaussian

Everything is fit via alternating minimization and stochastic gradient descent.
The code relies on pytorch for SGD and a demo is included.

Author: Wesley Tansey
Date: 8/9/2019
'''
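# Model summary (editor's paraphrase of the code below): every row i gets a
# shared embedding r_i in R^k; each column j gets view-specific parameters.
# The per-column likelihoods are:
#   binary:      P(x_ij = 1) = sigmoid(r_i . c_j)
#   categorical: P(x_ij = m) = softmax(C_j r_i)_m
#   Gaussian:    columns are standardized, and the density is evaluated as
#                N(r_i . m_j, 1) on the standardized scale (see GaussianFactorModel.prob)
# Alternating minimization toggles gradient updates between the row embeddings
# (even epochs) and the column parameters (odd epochs).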
import numpy as np
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
from scipy.stats import norm
from utils import batches, ilogit  # ilogit is used by BinaryFactorModel.prob and the demo
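# Editor's note: the `utils` module is not included in this gist. Based on how
# the two helpers are called below, `batches` is assumed to yield (optionally
# shuffled) minibatches of indices and `ilogit` the logistic sigmoid. A minimal
# sketch under those assumptions, in case `utils` is unavailable:
#
#   def batches(indices, batchsize, shuffle=True):
#       order = np.random.permutation(indices) if shuffle else np.asarray(indices)
#       for start in range(0, len(order), batchsize):
#           yield order[start:start + batchsize]
#
#   def ilogit(x):
#       return 1. / (1. + np.exp(-x))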
class HomogeneousFactorModel(nn.Module):
    def __init__(self, X, k, row_embeddings):
        super(HomogeneousFactorModel, self).__init__()
        # Handle missing data
        if np.ma.is_masked(X):
            self.present = torch.BoolTensor(~X.mask)
        else:
            self.present = torch.BoolTensor(np.ones(X.shape, dtype=bool))

    def forward(self, tidx):
        raise NotImplementedError

    def row_mode(self):
        raise NotImplementedError

    def col_mode(self):
        raise NotImplementedError

    def prob(self, i, j, vals):
        raise NotImplementedError

class GaussianFactorModel(HomogeneousFactorModel):
    def __init__(self, X, k, row_embeddings):
        super(GaussianFactorModel, self).__init__(X, k, row_embeddings)
        self.row_embeddings = row_embeddings
        self.mean_embeddings = nn.Embedding(X.shape[1], k)
        self.std_embeddings = nn.Embedding(X.shape[1], k)
        # self.std_embeddings = nn.Embedding.from_pretrained(torch.FloatTensor(np.random.normal(0, 1/np.sqrt(k), size=(X.shape[1], k))))
        self.softplus = nn.Softplus()
        self.means = torch.FloatTensor(X.mean(axis=0))
        self.stds = torch.FloatTensor(X.std(axis=0))
        self.labels = (torch.FloatTensor(X) - self.means[None]) / self.stds[None]

    def forward(self, tidx):
        '''Return the mean and standard deviation of the tidx entries.'''
        return (((self.row_embeddings(tidx)[:,None] * self.mean_embeddings.weight[None]).sum(dim=2) + self.means,
                 (self.softplus(self.row_embeddings(tidx)[:,None]) * self.softplus(self.std_embeddings.weight[None])).sum(dim=2)),  # + self.stds[None])),
                self.present[tidx])

    def row_mode(self):
        # Freeze the column parameters. Set on .weight so the requires_grad
        # filter in fit_factor_model actually excludes them from the optimizer.
        self.mean_embeddings.weight.requires_grad = False
        self.std_embeddings.weight.requires_grad = False

    def col_mode(self):
        self.mean_embeddings.weight.requires_grad = True
        self.std_embeddings.weight.requires_grad = True

    def prob(self, i, j, vals):
        # Standardize
        vals = (vals - self.means.data.numpy()[j]) / self.stds.data.numpy()[j]
        # Get the mean embedding
        r = self.row_embeddings.weight.data.numpy()[i]
        mu = r.dot(self.mean_embeddings.weight.data.numpy()[j])
        # std_c = self.std_embeddings.weight.data.numpy()[j]
        # std_offset = self.stds.data.numpy()[j]
        # std = np.log1p(np.exp(r.dot(std_c + std_offset)))
        std = 1  # Assume standard normal after standardizing
        return norm.pdf(vals, mu, scale=std)

class BinaryFactorModel(HomogeneousFactorModel):
    def __init__(self, X, k, row_embeddings):
        super(BinaryFactorModel, self).__init__(X, k, row_embeddings)
        self.row_embeddings = row_embeddings
        self.col_embeddings = nn.Embedding(X.shape[1], k)
        self.labels = torch.FloatTensor(X)

    def forward(self, tidx):
        '''Return the logits for the tidx entries.'''
        return ((self.row_embeddings(tidx)[:,None] * self.col_embeddings.weight[None]).sum(dim=2),
                self.present[tidx])

    def row_mode(self):
        self.col_embeddings.weight.requires_grad = False

    def col_mode(self):
        self.col_embeddings.weight.requires_grad = True

    def prob(self, i, j, vals):
        p = ilogit(self.row_embeddings.weight.data.numpy()[i].dot(
            self.col_embeddings.weight.data.numpy()[j]))
        return p * vals + (1 - p) * (1 - vals)

class CategoricalFactorModel(HomogeneousFactorModel):
    def __init__(self, X, k, row_embeddings):
        super(CategoricalFactorModel, self).__init__(X, k, row_embeddings)
        self.row_embeddings = row_embeddings
        self.k = k
        self.d = max([len(np.ma.unique(X[~X.mask[:,i],i])) for i in range(X.shape[1])])
        self.col_embeddings = nn.Parameter(torch.FloatTensor(np.random.normal(size=(X.shape[1], self.d, self.k))))
        self.labels = torch.LongTensor(X)

    def forward(self, tidx):
        '''Return the softmax logits for the tidx entries.'''
        # return ((self.row_embeddings(tidx)[:,None,None] * self.col_embeddings.weight.view(-1, self.d, self.k)[None]).sum(dim=3),
        #         self.present[tidx])
        return ((self.row_embeddings(tidx)[:,None,None] * self.col_embeddings[None]).sum(dim=3),
                self.present[tidx])

    def row_mode(self):
        self.col_embeddings.requires_grad = False

    def col_mode(self):
        self.col_embeddings.requires_grad = True

    def prob(self, i, j, vals):
        logits = self.col_embeddings.data.numpy()[j].dot(self.row_embeddings.weight.data.numpy()[i])
        p = np.exp(logits) / np.sum(np.exp(logits))
        return p[vals]

class HeterogeneousFactorModel(nn.Module):
    def __init__(self, X, k, min_continuous=10):
        super(HeterogeneousFactorModel, self).__init__()
        X = np.ma.array(X)
        self.row_embeddings = nn.Embedding(X.shape[0], k)

        # Count the unique values in each column
        self.vals = [np.ma.sort(np.ma.unique(X[~X.mask[:,i],i])) for i in range(X.shape[1])]
        self.nvals = np.array([len(v) for v in self.vals])
        for j in range(X.shape[1]):
            if np.any(np.isnan(self.vals[j])):
                raise Exception('Column {} contains unmasked NaN values.'.format(j))

        # Find the binary columns (2 values)
        self.bin_mask = self.nvals <= 2
        self.bin_cols = np.arange(X.shape[1])[self.bin_mask]
        if len(self.bin_cols) > 0:
            self.X_bin = np.ma.array([X[:,i] == X[:,i].max() for i in range(X.shape[1]) if self.bin_mask[i]],
                                     mask=[X.mask[:,i] for i in range(X.shape[1]) if self.bin_mask[i]]).T
            self.factor_bin = BinaryFactorModel(self.X_bin, k, self.row_embeddings)
        else:
            print('No binary columns found.')

        # Find the categorical columns (2 < d < min_continuous values)
        self.cat_mask = (self.nvals > 2) & (self.nvals < min_continuous)
        self.cat_cols = np.arange(X.shape[1])[self.cat_mask]
        if len(self.cat_cols) > 0:
            self.X_cat = np.ma.array([(X[:,i:i+1] > self.vals[i][None]).sum(axis=1) for i in range(X.shape[1]) if self.cat_mask[i]],
                                     mask=[X.mask[:,i] for i in range(X.shape[1]) if self.cat_mask[i]]).T.astype(int)
            self.factor_cat = CategoricalFactorModel(self.X_cat, k, self.row_embeddings)
        else:
            print('No categorical columns found.')

        # Find the continuous (Gaussian) columns (>= min_continuous values)
        self.con_mask = self.nvals >= min_continuous
        self.con_cols = np.arange(X.shape[1])[self.con_mask]
        if len(self.con_cols) > 0:
            self.X_con = np.ma.array(X[:,self.con_mask], mask=X.mask[:,self.con_mask])
            self.factor_con = GaussianFactorModel(self.X_con, k, self.row_embeddings)
        else:
            print('No gaussian columns found.')

    def row_mode(self):
        self.row_embeddings.weight.requires_grad = True
        self.factor_bin.row_mode()
        self.factor_cat.row_mode()
        self.factor_con.row_mode()

    def col_mode(self):
        self.row_embeddings.weight.requires_grad = False
        self.factor_bin.col_mode()
        self.factor_cat.col_mode()
        self.factor_con.col_mode()

    def forward(self, tidx):
        bin_logits = self.factor_bin(tidx) if len(self.bin_cols) > 0 else None
        cat_logits = self.factor_cat(tidx) if len(self.cat_cols) > 0 else None
        con_logits = self.factor_con(tidx) if len(self.con_cols) > 0 else None
        return (bin_logits, cat_logits, con_logits)

    def prob(self, i, j, vals):
        if self.bin_mask[j]:
            c = self.bin_mask[:j].sum()
            v = np.argmax(self.vals[j][None] == vals[:,None], axis=1)
            return self.factor_bin.prob(i, c, v)
        if self.cat_mask[j]:
            c = self.cat_mask[:j].sum()
            v = np.argmax(self.vals[j][None] == vals[:,None], axis=1)
            return self.factor_cat.prob(i, c, v)
        if self.con_mask[j]:
            c = self.con_mask[:j].sum()
            return self.factor_con.prob(i, c, vals)
        raise Exception('Why did this not qualify as any valid type??')

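# Editor's note: `prob` routes an original column index j to its view-specific
# model, converting j to the column's position within that view and, for the
# discrete views, mapping the requested values to indices into the sorted
# unique values of column j. A hypothetical call, assuming column 0 is binary
# with observed values {0, 1}:
#
#   model = HeterogeneousFactorModel(X, k=5)
#   p = model.prob(0, 0, np.array([0., 1.]))  # -> [P(x_00 = 0), P(x_00 = 1)]
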
def fit_factor_model(X, k, mf_epochs=5000, lr=1e-1, weight_decay=0,
                     lr_decay=0.96, lr_step=50, batchsize=10,
                     verbose=True,
                     min_continuous=10, con_weight=0.01,
                     **kwargs):
    import sys

    # Create the model
    model = HeterogeneousFactorModel(X, k, min_continuous=min_continuous)

    # Setup the different losses
    bin_loss_raw = nn.BCEWithLogitsLoss(reduction='none')
    cat_loss_raw = nn.CrossEntropyLoss(reduction='none')
    bin_loss = lambda predicted, target, present: bin_loss_raw(predicted[present], target[present]).mean()  # (bin_loss_raw(predicted, target)*present).sum() / present.sum()
    cat_loss = lambda predicted, target, present: cat_loss_raw(predicted[present], target[present]).mean()  # (cat_loss_raw(predicted, target)*present).sum() / present.sum()
    con_loss = lambda loc, scale, target, present: -(torch.distributions.Normal(loc[present], scale).log_prob(target[present])).mean() * con_weight  # -(torch.distributions.Normal(loc, scale).log_prob(target) * present).sum() / present.sum()

    # Sample stochastically over rows
    train_indices = np.arange(X.shape[0])

    # Track progress
    losses = np.zeros(mf_epochs)

    # Train the model
    for epoch in range(mf_epochs*2):
        if verbose and (epoch % 2) == 1:
            print('\t\tEpoch {}'.format(epoch//2+1))
            sys.stdout.flush()

        # Train the row embeddings on even epochs and the columns on odd epochs
        if epoch % 2 == 0:
            model.row_mode()
        else:
            model.col_mode()

        # Setup the SGD method
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=weight_decay)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

        train_loss = torch.Tensor([0])
        for batch_idx, batch in enumerate(batches(train_indices, batchsize, shuffle=True)):
            if verbose and (batch_idx % 100 == 0):
                print('\t\t\tBatch {}'.format(batch_idx))

            tidx = autograd.Variable(torch.LongTensor(batch), requires_grad=False)

            # Set the model to training mode
            model.train()

            # Reset the gradient
            model.zero_grad()

            # Get the model predictions for the rows in this batch
            (bin_logits, bin_present), (cat_logits, cat_present), (con_out, con_present) = model(tidx)
            loss = bin_loss(bin_logits, model.factor_bin.labels[tidx], bin_present)
            loss += cat_loss(cat_logits.view(-1, model.factor_cat.d), model.factor_cat.labels[tidx].view(-1), cat_present.view(-1))
            loss += con_loss(con_out[0], 1, model.factor_con.labels[tidx], con_present)

            # Calculate gradients
            loss.backward()

            # Apply the update
            optimizer.step()

            # Track the loss
            train_loss += loss.data

        # Track the total loss
        losses[epoch // 2] += train_loss.numpy()

        scheduler.step()

        if verbose and (epoch % 2) == 1:
            print('Loss: {}'.format(train_loss))

        if (epoch % 2) == 1 and ((epoch // 2) % lr_step) == 0:
            lr *= lr_decay

    return model

if __name__ == '__main__':
    import matplotlib.pyplot as plt
    from utils import ilogit

    # Generate some fake data from something similar to the model
    nbin = 10
    ncat = 11
    ncon = 12
    N = 100
    M = 4
    P = nbin + ncat + ncon
    K = 6

    # Create the embeddings
    print('Creating embeddings')
    bin_embeds = np.random.normal(0, 1, size=(nbin, K))
    cat_embeds = np.random.normal(0, 1/np.sqrt(M), size=(ncat, M, K))
    con_embeds = np.random.normal(0, 0.5, size=(ncon, 2, K))
    row_embeds = np.random.normal(0, 1, size=(N, K))

    ########## Create the data ##########
    print('Creating data')
    X = np.zeros((N, P))

    # Binary samples
    print('\tCreating binary samples')
    logits = np.einsum('nk,mk->nm', row_embeds, bin_embeds)
    bin_probs = ilogit(logits)
    X[:,:nbin] = np.random.random(size=(N, nbin)) <= bin_probs

    # Categorical samples
    print('\tCreating categorical samples')
    logits = np.einsum('nk,cmk->ncm', row_embeds, cat_embeds)
    cat_probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
    for i in range(N):
        for j in range(nbin, nbin+ncat):
            X[i,j] = np.random.choice(M, p=cat_probs[i,j-nbin])

    # Continuous samples
    print('\tCreating Gaussian samples')
    congits = np.einsum('nk,cmk->ncm', row_embeds, con_embeds)
    con_probs = norm.pdf(np.linspace(-5,5,100)[None,None], congits[:,:,0:1], scale=np.log1p(np.exp(congits[:,:,1:2])) + 0.1)
    for j in range(nbin+ncat, nbin+ncat+ncon):
        X[:,j] = np.random.normal(congits[:,j-ncat-nbin,0], np.log1p(np.exp(congits[:,j-ncat-nbin,1])) + 0.1)

    print('Masking some random bits of the data')
    X_mask = np.random.choice(X.shape[0], size=3), np.random.choice(X.shape[1], size=3)
    mask = np.zeros(X.shape, dtype='bool')
    X[X_mask[0], X_mask[1]] = np.nan
    X[0,0] = np.nan
    X = np.ma.array(X, mask=np.isnan(X))

    ########### Fit a factor model ###########
    print('Fitting factor model')
    factor_model = fit_factor_model(X, K, min_continuous=M+1, mf_epochs=5000)

    ########### Plot some example results ###########
    print('Plotting results')
    fig, axarr = plt.subplots(8, 12, figsize=(60,40), sharex=False, sharey=False)
    for i in range(axarr.shape[0]):
        for j in range(axarr.shape[1]):
            ax = axarr[i,j]
            if j < 4:
                # Binary
                ax.bar(np.arange(2)+0.3, [(1-bin_probs[i,j]), bin_probs[i,j]], width=0.3, color='black')
                ax.bar(np.arange(2)+0.65, factor_model.prob(i, j, np.arange(2)), width=0.3, color='orange')
                ax.axvline(X[i,j]+0.5, color='red', ls='--')
                ax.set_xlim([0,2])
                ax.set_ylim([0,1])
            elif j < 8:
                # Categorical
                ax.bar(np.arange(M)+1-0.7, cat_probs[i,j-4], width=0.3, color='black')
                ax.bar(np.arange(M)+1-0.35, factor_model.prob(i, j-4+nbin, np.arange(M)), width=0.3, color='orange')
                ax.axvline(X[i,j+nbin-4]+0.5, color='red', ls='--')
                ax.set_xlim([0, M])
                ax.set_ylim([0,1])
            else:
                # Continuous
                x_min, x_max = X[:,nbin+ncat:nbin+ncat+4].min()*1.1, X[:,nbin+ncat:nbin+ncat+4].max()*1.1
                ax.plot(np.linspace(x_min, x_max, 100), con_probs[i,j-8], color='black')
                ax.plot(np.linspace(x_min, x_max, 100), factor_model.prob(i, j-8+nbin+ncat, np.linspace(x_min, x_max, 100)), color='orange')
                ax.axvline(X[i,j+nbin+ncat-8], color='red', ls='--')
                ax.set_xlim([x_min, x_max])
    plt.savefig('plots/factor-demo.pdf', bbox_inches='tight')
    plt.close()