Mamba NLP in PyTorch

This repository contains a basic PyTorch implementation of the Mamba model for NLP. The design is inspired by transformer architectures, but centers on a Selective State Space Model (SSM) and a MambaBlock for improved sequence modeling.

Components

  1. ssm.py: Defines the SelectiveSSM module, which models sequences with selective state-space transformations, offering an alternative to standard transformer attention for certain tasks (a minimal sketch of such a module follows this list).
  2. block.py: Implements the MambaBlock, which wraps the SelectiveSSM module with feedforward layers and residual connections for robust sequence processing.
  3. encode.py: Contains the PositionalEncoding class, which adds position information to token embeddings, critical for sequence understanding.
  4. nlp.py: Defines the overall MambaForNLP model, which combines the positional encoding, a stack of MambaBlocks, and a final output layer that produces predictions (see the second sketch below).
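
The exact interfaces in ssm.py and block.py are not reproduced here. The sketch below is only an illustration, assuming SelectiveSSM computes input-dependent (selective) state-space parameters and MambaBlock wraps it with a feedforward layer and residual connections; all other names, shapes, and defaults are hypothetical.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelectiveSSM(nn.Module):
    # Simplified selective state-space layer: per-token ("selective") step sizes
    # and B/C projections modulate a diagonal linear recurrence over the sequence.
    def __init__(self, d_model, d_state=16):
        super().__init__()
        self.d_state = d_state
        self.delta_proj = nn.Linear(d_model, d_model)   # input-dependent step size
        self.B_proj = nn.Linear(d_model, d_state)       # input-dependent input matrix
        self.C_proj = nn.Linear(d_model, d_state)       # input-dependent output matrix
        # Log-parameterized state transition A, kept negative for stability.
        self.A_log = nn.Parameter(
            torch.log(torch.arange(1, d_state + 1).float()).repeat(d_model, 1)
        )

    def forward(self, x):                               # x: (batch, seq_len, d_model)
        b, l, d = x.shape
        delta = F.softplus(self.delta_proj(x))          # (b, l, d)
        B = self.B_proj(x)                              # (b, l, d_state)
        C = self.C_proj(x)                              # (b, l, d_state)
        A = -torch.exp(self.A_log)                      # (d, d_state)
        h = x.new_zeros(b, d, self.d_state)             # recurrent hidden state
        ys = []
        for t in range(l):                              # naive sequential scan
            dt = delta[:, t].unsqueeze(-1)              # (b, d, 1)
            h = torch.exp(dt * A) * h + dt * B[:, t].unsqueeze(1) * x[:, t].unsqueeze(-1)
            ys.append((h * C[:, t].unsqueeze(1)).sum(-1))   # read out: (b, d)
        return torch.stack(ys, dim=1)                   # (b, l, d)

class MambaBlock(nn.Module):
    # SelectiveSSM wrapped with a feedforward layer, layer norms, and residuals.
    def __init__(self, d_model, d_state=16, d_ff=None):
        super().__init__()
        d_ff = d_ff or 4 * d_model
        self.norm1 = nn.LayerNorm(d_model)
        self.ssm = SelectiveSSM(d_model, d_state)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        x = x + self.ssm(self.norm1(x))                 # residual around the SSM
        x = x + self.ff(self.norm2(x))                  # residual around the feedforward
        return x

The sequential scan above is written for clarity only; the Mamba paper uses a hardware-aware parallel scan for speed.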

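Along the same lines, a hedged sketch of encode.py and nlp.py, building on the MambaBlock sketch above: the positional encoding is the standard sinusoidal formulation, and MambaForNLP is assumed to be embedding, positional encoding, stacked MambaBlocks, and a linear output head. Names and defaults are assumptions, not the repository's actual code.

import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    # Standard sinusoidal positional encoding added to token embeddings
    # (assumes an even d_model).
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe)

    def forward(self, x):                               # x: (batch, seq_len, d_model)
        return x + self.pe[: x.size(1)]

class MambaForNLP(nn.Module):
    # Token embedding + positional encoding, a stack of MambaBlocks,
    # and a linear head producing per-position vocabulary logits.
    def __init__(self, vocab_size, d_model=256, n_layers=4, d_state=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos = PositionalEncoding(d_model)
        self.blocks = nn.ModuleList(
            [MambaBlock(d_model, d_state) for _ in range(n_layers)]
        )
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, tokens):                          # tokens: (batch, seq_len) int64
        x = self.pos(self.embed(tokens))
        for block in self.blocks:
            x = block(x)
        return self.head(self.norm(x))                  # (batch, seq_len, vocab_size)
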
Usage

To train the Mamba model, simply run the following command:

python train.py

This will train the model on a text corpus, with the option to visualize the training loss and analyze prediction probabilities using the sample_text defined in the file. Try updating it with a larger corpus for better results.
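
train.py itself is not reproduced here; the following is a minimal sketch of a character-level training loop in that spirit, reusing the MambaForNLP sketch above. The sample_text, hyperparameters, and loss reporting are illustrative assumptions only.

import torch
import torch.nn as nn

# Hypothetical character-level setup; the repository's train.py may differ.
sample_text = "hello mamba " * 200                      # tiny toy corpus
chars = sorted(set(sample_text))
stoi = {c: i for i, c in enumerate(chars)}
data = torch.tensor([stoi[c] for c in sample_text])

model = MambaForNLP(vocab_size=len(chars), d_model=64, n_layers=2)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
loss_fn = nn.CrossEntropyLoss()

seq_len, batch = 32, 16
for step in range(200):
    # Sample random windows; targets are the inputs shifted by one character.
    idx = torch.randint(0, len(data) - seq_len - 1, (batch,)).tolist()
    x = torch.stack([data[i : i + seq_len] for i in idx])
    y = torch.stack([data[i + 1 : i + seq_len + 1] for i in idx])
    logits = model(x)
    loss = loss_fn(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
    opt.zero_grad()
    loss.backward()
    opt.step()
    if step % 50 == 0:
        print(step, loss.item())                        # training loss to plot or inspect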

Citation

@article{mamba,
  title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
  author={Gu, Albert and Dao, Tri},
  journal={arXiv preprint arXiv:2312.00752},
  year={2023}
}
