
birajpandey/KernelODETransport


Kernel Ordinary Differential Equations (KODE)

This code reproduces the experiments in the following paper:

Biraj Pandey*, Bamdad Hosseini, Pau Batlle, and Houman Owhadi. "Diffeomorphic Measure Matching with Kernels for Generative Modeling". [arxiv]

Setup:

  1. Clone the repository.

    git clone https://github.com/birajpandey/KernelODETransport.git
    
  2. Change into the cloned repository directory.

    cd path/to/KernelODETransport
    
  3. Run make. This creates a conda environment called kode_env, downloads the required packages and datasets, and runs the tests.

    make 
    
  4. Activate the conda environment.

    conda activate kode_env
    
  5. Install the kode package in editable mode.

    pip install -e .
    
  6. Run the files in scripts/ to reproduce our published results.
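
To check that step 5 worked, you can confirm from inside the activated kode_env environment that kode and the libraries imported in the example (jax, numpy, matplotlib) can be found. This is a small convenience sketch that uses only the standard library's importlib; is_installed is a hypothetical helper and is not part of the repository.

```python
import importlib.util


def is_installed(package_name):
    """Return True if the top-level package can be imported in the current environment."""
    return importlib.util.find_spec(package_name) is not None


if __name__ == "__main__":
    # Run this with the kode_env environment active.
    for name in ("kode", "jax", "numpy", "matplotlib"):
        status = "ok" if is_installed(name) else "MISSING"
        print(f"{name}: {status}")
```

If kode is reported as missing, re-run step 5 from the repository root.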

Remark: This project structure is based on the cookiecutter data science project template. I also drew heavily on The Good Research Code Handbook by Patrick J Mineault.

Example:

Here we fit the KODE model to sample from the pinwheel distribution.

import numpy as np
import jax.random as jrandom
import matplotlib.pyplot as plt

from kode.data import load_dataset
from kode.models import transporter, kernels, losses, utils
from kode.visualization import visualize

# generate data
Y = load_dataset.two_dimensional_data('pinwheel', batch_size=5000, rng=None)
X = np.random.normal(size=(5000, 2))

# find inducing points
num_inducing_points = 100
inducing_points, median_distance = utils.find_inducing_points(
    X, Y, num_inducing_points, random_state=20)

# define model params 
model_params = {'length_scale': [0.1 * median_distance]}
num_discrete_steps, num_solver_steps = 5, 10
model_kernel = kernels.get_kernel('rbf', model_params)

# define loss function and parameters
loss_params = {'length_scale': [0.15 * median_distance]}
loss_kernel = kernels.get_kernel('laplace', loss_params)
mmd_loss_fun = losses.MMDLoss(loss_kernel)

# initialize the optimizer
optimizer = utils.get_adam_with_exp_decay()

# initialize the model
key = jrandom.PRNGKey(20)
transport_model = transporter.Transporter(inducing_points, model_kernel,
                                        num_discrete_steps, num_solver_steps, key)
gradient_mask = transport_model.get_gradient_mask()

# train
num_epochs = 501
rkhs_strength, h1_strength = 1e-6, 1e-6
batch_size = 5000
transport_model.fit(X, Y, num_epochs, mmd_loss_fun, rkhs_strength,
                    h1_strength, batch_size, optimizer)

# transform using the model
Y_pred, Y_traj = transport_model.transform(X, num_steps=20, trajectory=True)


# Plot results
fig = plt.figure(figsize=(12, 4))
ax1, ax2, ax3 = visualize.plot_2d_distributions(fig, X, Y, Y_pred)
plt.show()
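
Roughly speaking, KODE moves samples along the flow of an ODE whose velocity field is expanded in the model kernel around the inducing points. As a mental model only, the NumPy sketch below assumes a velocity v(x) = sum_j k(x, z_j) c_j with an RBF kernel, coefficients held constant on each of num_discrete_steps time intervals, and num_solver_steps forward Euler substeps per interval. This is our simplified reading of those parameter names, not the kode implementation: the coefficients here are random rather than fitted by minimizing the MMD loss with the RKHS and H1 penalties, and Transporter's solver and time parameterization may differ. The helper names (rbf_kernel, kernel_velocity, transport) are hypothetical.

```python
import numpy as np


def rbf_kernel(x, z, length_scale):
    """Gaussian (RBF) kernel matrix with entries exp(-|x_i - z_j|^2 / (2 * length_scale^2))."""
    sq_dists = np.sum((x[:, None, :] - z[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dists / (2.0 * length_scale ** 2))


def kernel_velocity(x, inducing_points, coefficients, length_scale):
    """Velocity v(x) = sum_j k(x, z_j) c_j, expanded around the inducing points z_j."""
    return rbf_kernel(x, inducing_points, length_scale) @ coefficients


def transport(x, inducing_points, coefficients_per_step, length_scale, num_solver_steps):
    """Push samples x along dx/dt = v_t(x) for t in [0, 1] using forward Euler.

    coefficients_per_step has shape (num_discrete_steps, num_inducing_points, dim):
    the velocity field is constant on each of the num_discrete_steps time intervals,
    and each interval is integrated with num_solver_steps Euler substeps.
    """
    num_discrete_steps = coefficients_per_step.shape[0]
    dt = 1.0 / (num_discrete_steps * num_solver_steps)
    for coefficients in coefficients_per_step:
        for _ in range(num_solver_steps):
            x = x + dt * kernel_velocity(x, inducing_points, coefficients, length_scale)
    return x


# Usage with random, untrained coefficients (purely illustrative).
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 2))                         # source samples
Z = X[rng.choice(len(X), size=50, replace=False)]     # 50 "inducing points"
C = 0.1 * rng.normal(size=(5, 50, 2))                 # 5 time intervals
Y_pred = transport(X, Z, C, length_scale=0.5, num_solver_steps=10)
print(Y_pred.shape)  # (500, 2)
```

Because the velocity is a finite kernel expansion, each Euler step costs one kernel evaluation between the current samples and the inducing points, so the cost grows with the number of inducing points rather than the number of training samples.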

Reproducing experiments

Each set of experiments has its own training script. To list all command-line options of a script, pass the -h flag.

Hyperparameters

To see the hyperparameters used in each experiment, use the load_hyperparameters module:

from kode.models import load_hyperparameters
parameters = load_hyperparameters.two_dimensional_data('pinwheel')
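
In the example above, both kernel length scales are set as multiples of median_distance (0.1 for the model kernel, 0.15 for the loss kernel), a form of the common median heuristic. The sketch below computes a median pairwise Euclidean distance with NumPy so you can see what such a scale looks like for your own data. It is an assumption that utils.find_inducing_points computes median_distance this way; it may, for example, pool X and Y or subsample differently. median_pairwise_distance is a hypothetical helper, not part of kode.

```python
import numpy as np


def median_pairwise_distance(points, max_points=1000, seed=0):
    """Median Euclidean distance over all distinct pairs of (a random subset of) points."""
    rng = np.random.default_rng(seed)
    if len(points) > max_points:
        # Subsample so the pairwise-distance matrix stays small.
        points = points[rng.choice(len(points), size=max_points, replace=False)]
    diffs = points[:, None, :] - points[None, :, :]
    dists = np.sqrt(np.sum(diffs ** 2, axis=-1))
    upper = np.triu_indices(len(points), k=1)  # each unordered pair once, no self-pairs
    return float(np.median(dists[upper]))


X = np.random.default_rng(1).normal(size=(5000, 2))
median_distance = median_pairwise_distance(X)
model_length_scale = 0.1 * median_distance   # multiplier used for the RBF model kernel
loss_length_scale = 0.15 * median_distance   # multiplier used for the Laplace loss kernel
print(median_distance, model_length_scale, loss_length_scale)
```

Expressing length scales relative to the data's own spread keeps the same multipliers meaningful across datasets of different scale.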

2d benchmarks

To reproduce our experiments for 2d benchmarks, run:

python scripts/train_2d_benchmarks.py --dataset pinwheel --save-name pinwheel_experiment

You can override model parameters from the command line, for example:

python scripts/train_2d_benchmarks.py --dataset pinwheel --num-inducing-points 500 --num-epochs 101 --save-name modified_pinwheel_experiment
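
To sweep over a parameter, you can generate these commands from Python. This sketch uses only the --dataset, --save-name, --num-inducing-points and --num-epochs flags shown above; check the -h output before passing any other option. The helper names and the save-name pattern are hypothetical, and the commands must be run from the repository root. With dry_run=True nothing is executed and the commands are only returned.

```python
import subprocess
import sys


def build_2d_command(dataset, save_name, **overrides):
    """Build an argument list for scripts/train_2d_benchmarks.py.

    Keyword overrides such as num_inducing_points=500 become --num-inducing-points 500.
    """
    cmd = [sys.executable, "scripts/train_2d_benchmarks.py",
           "--dataset", dataset, "--save-name", save_name]
    for key, value in overrides.items():
        cmd += ["--" + key.replace("_", "-"), str(value)]
    return cmd


def run_sweep(dataset, inducing_point_counts, num_epochs, dry_run=True):
    """Train one model per inducing-point count; with dry_run=True only return the commands."""
    commands = [
        build_2d_command(dataset, f"{dataset}_{m}_inducing",
                         num_inducing_points=m, num_epochs=num_epochs)
        for m in inducing_point_counts
    ]
    if not dry_run:
        for cmd in commands:
            subprocess.run(cmd, check=True)  # stop at the first failing run
    return commands


for cmd in run_sweep("pinwheel", [100, 500], num_epochs=101):
    print(" ".join(cmd[1:]))
```

Passing the arguments as a list (rather than a single shell string) avoids quoting problems, and check=True makes a failed run raise instead of silently continuing.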

To evaluate the trained model, run:

python scripts/evaluate_2d_benchmarks.py --dataset pinwheel --file-name pinwheel_experiment
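
One standard way to compare generated samples with target samples is the maximum mean discrepancy (MMD), the kind of discrepancy the training example uses as its loss (losses.MMDLoss with a Laplace kernel). This is not necessarily what evaluate_2d_benchmarks.py computes. Below is a minimal NumPy sketch of the biased MMD² estimate, assuming a Laplace kernel exp(-||x - y|| / length_scale) with Euclidean distance; kode's kernel definition and estimator may differ, and the helper names are hypothetical.

```python
import numpy as np


def laplace_kernel(x, y, length_scale):
    """Laplace kernel matrix with entries exp(-||x_i - y_j|| / length_scale) (Euclidean norm)."""
    dists = np.sqrt(np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1))
    return np.exp(-dists / length_scale)


def mmd_squared(x, y, length_scale):
    """Biased (V-statistic) estimate of MMD^2 between the samples x and y."""
    k_xx = laplace_kernel(x, x, length_scale).mean()
    k_yy = laplace_kernel(y, y, length_scale).mean()
    k_xy = laplace_kernel(x, y, length_scale).mean()
    return float(k_xx + k_yy - 2.0 * k_xy)


rng = np.random.default_rng(0)
a = rng.normal(size=(500, 2))              # target-like samples
b = rng.normal(size=(500, 2))              # same distribution as a
c = rng.normal(loc=2.0, size=(500, 2))     # shifted distribution
print(mmd_squared(a, b, 1.0), mmd_squared(a, c, 1.0))
```

A smaller value means the two samples are harder to tell apart; comparing Y_pred against held-out samples of Y in this way gives a rough check of the fitted transport.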

High-dimensional benchmarks

To reproduce our results on high-dimensional benchmarks, run:

python scripts/train_high_dimensional_benchmarks.py --dataset power --save-name power_experiment

Conditional 2d experiments

To reproduce our results on the conditional 2d experiments, run:

python scripts/train_conditional_2d_benchmarks.py --dataset pinwheel --save-name conditional_pinwheel_experiment

Lotka-Volterra experiment

To reproduce our results on the Lotka-Volterra experiment, run:

python scripts/train_lotka_volterra.py --save-name lv_experiment
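
For readers unfamiliar with the model, the Lotka-Volterra equations describe predator-prey dynamics: dx/dt = alpha x - beta x y for the prey x and dy/dt = delta x y - gamma y for the predators y. The sketch below simulates them with a classical fourth-order Runge-Kutta step. The parameter values and initial state are illustrative textbook-style choices, not the settings used by scripts/train_lotka_volterra.py, and the sketch does not reproduce the experiment itself; the helper names are hypothetical.

```python
import numpy as np


def lotka_volterra_rhs(state, alpha, beta, gamma, delta):
    """Right-hand side of the classical predator-prey equations.

    dx/dt = alpha * x - beta * x * y   (prey)
    dy/dt = delta * x * y - gamma * y  (predators)
    """
    x, y = state
    return np.array([alpha * x - beta * x * y, delta * x * y - gamma * y])


def simulate(state0, params, t_final, num_steps):
    """Integrate with classical RK4; returns an array of shape (num_steps + 1, 2)."""
    dt = t_final / num_steps
    traj = np.empty((num_steps + 1, 2))
    traj[0] = state0
    for i in range(num_steps):
        s = traj[i]
        k1 = lotka_volterra_rhs(s, *params)
        k2 = lotka_volterra_rhs(s + 0.5 * dt * k1, *params)
        k3 = lotka_volterra_rhs(s + 0.5 * dt * k2, *params)
        k4 = lotka_volterra_rhs(s + dt * k3, *params)
        traj[i + 1] = s + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)
    return traj


alpha, beta, gamma, delta = 1.0, 0.1, 1.5, 0.075   # illustrative values only
params = (alpha, beta, gamma, delta)
traj = simulate([10.0, 5.0], params, t_final=20.0, num_steps=2000)
print(traj[-1])
```

A quick sanity check is that V = delta x - gamma ln x + beta y - alpha ln y stays nearly constant along the simulated trajectory, since the exact dynamics conserve it.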

Acknowledgements

This material is in part based upon work supported by the US National Science Foundation Grant DMS-208535, NSF Graduate Fellowship grant DGE-1762114, and US AFOSR Grant FA9550-20-1-0358. Any opinions, findings, and conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the funding agencies.
