Leveraging Task Structures for Improved Identifiability in Neural Network Representations (TMLR 2024)

Paper: https://openreview.net/forum?id=WLcPrq6pu0

This package provides the implementation used to produce the results in the paper: Leveraging Task Structures for Improved Identifiability in Neural Network Representations (published at TMLR 2024).

Installation

  1. Create a conda environment with Python 3.9:
    conda create -n mtlcm python=3.9
  2. Install poetry, which we use for dependency management.
  3. Activate the conda environment:
    conda activate mtlcm
  4. Within the project's directory, run:
    poetry install
  5. If using CUDA, manually install the CUDA build of dgl via pip by running the command below (replace cu121 with your CUDA version); a consolidated sketch of all the steps follows this list.
    poetry remove dgl; pip install dgl -f https://data.dgl.ai/wheels/torch-2.3/cu121/repo.html
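
Putting the steps above together, a full setup session might look like the following (a sketch, assuming poetry is installed via pip and a CUDA 12.1 machine for the optional dgl step):

    # Create and activate the environment
    conda create -n mtlcm python=3.9
    conda activate mtlcm

    # Install poetry for dependency management (pip is one option; the official installer also works)
    pip install poetry

    # Install the project's dependencies from the repository root
    poetry install

    # Optional, CUDA only: swap in the CUDA build of dgl (replace cu121 with your CUDA version)
    poetry remove dgl
    pip install dgl -f https://data.dgl.ai/wheels/torch-2.3/cu121/repo.html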
    

Running the experiments

The experiments are run via mtlcm/run.py and configured through YAML files. Below are sample commands for each experiment; replace the configuration files or entries to try different hyperparameters.

Linear synthetic data

python mtlcm/run.py linear_synthetic mtlcm/experiments/linear_identifiability/configs/config.yaml

Non-linear synthetic data

python mtlcm/run.py multitask_synthetic mtlcm/experiments/synthetic_multitask/configs/exp_config.yaml

QM9 data

python mtlcm/run.py qm9 mtlcm/experiments/qm9/configs/latent_7/config_0.yaml

Superconductivity data

For the superconductivity experiment, first download the files train.csv and unique_m.csv from this link into the mtlcm/data/superconduct/ directory.

Then run

python mtlcm/run.py superconduct mtlcm/experiments/superconduct/configs/exp_config.yaml
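
A minimal sketch of the data setup plus run, assuming train.csv and unique_m.csv have already been downloaded to the current directory:

    # Place the downloaded files where the experiment expects them
    mkdir -p mtlcm/data/superconduct/
    mv train.csv unique_m.csv mtlcm/data/superconduct/

    # Run the superconductivity experiment
    python mtlcm/run.py superconduct mtlcm/experiments/superconduct/configs/exp_config.yaml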

Results

Results for the synthetic data experiments are available immediately in the exp_outputs/ directory. For the real-world data, the above command should be run multiple times with different seeds (e.g. latent_7/config_[1,2,3,4].yaml in the QM9 example above). The results can then be computed by comparing the representations across seeds:

python mtlcm/results.py "exp_outputs/qm9/full_config_latent7"
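
For example, the full QM9 workflow could be scripted as below (a sketch; the per-seed configs config_0.yaml through config_4.yaml and the output directory come from the commands above):

    # Run the QM9 experiment once per seed-specific config
    for i in 0 1 2 3 4; do
        python mtlcm/run.py qm9 mtlcm/experiments/qm9/configs/latent_7/config_${i}.yaml
    done

    # Aggregate results by comparing representations across seeds
    python mtlcm/results.py "exp_outputs/qm9/full_config_latent7"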

Citation

If you find our paper and/or code useful for your research, please consider citing:

@article{chen2024leveraging,
title={Leveraging Task Structures for Improved Identifiability in Neural Network Representations},
author={Wenlin Chen and Julien Horwood and Juyeon Heo and Jos{\'e} Miguel Hern{\'a}ndez-Lobato},
journal={Transactions on Machine Learning Research},
issn={2835-8856},
year={2024},
url={https://openreview.net/forum?id=WLcPrq6pu0}
}
