
Create hello.py #1

Open · wants to merge 1 commit into base: check_compile

Conversation

ydshieh2

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ydshieh2 (Author)

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda"
ckpt = "google/gemma-2b"

model = AutoModelForCausalLM.from_pretrained(ckpt)
#config = model.config
#config.num_hidden_layers = 1
#config.vocab_size = 16
#config.intermediate_size = 16
#config.num_attention_heads = 2
#config.num_key_value_heads = 2
#config.head_dim = 16
#config.max_length = 16

#model = type(model)(config=config)
model = model.to(device)
model.eval()


tokenizer = AutoTokenizer.from_pretrained(ckpt)

sequence = "Hey what's the plan" * 1
inputs = tokenizer.encode(sequence, return_tensors='pt').to(device)
#inputs = torch.zeros_like(inputs, device=device)

#model_forward = model.forward
#model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
#for i in range(10):
#    o = model_forward(inputs)


#breakpoint()

N_WORKERS = 20
N_ITER = 1

streams = [torch.cuda.Stream(device=device) for _  in range(N_WORKERS)]

x = torch.rand(size=(128*1, 128*1)).to(device)
w = torch.rand(size=(128*1, 128*1)).to(device)

#breakpoint()

#model.config.cache_implementation = "static"
#model.generation_config.cache_implementation = "static"


import os
os.environ["TOKENIZERS_PARALLELISM"] = "0"

#def model_forward(model, *args, **kwargs):
#    return model.forward(*args, **kwargs)


from transformers import StaticCache
cache = StaticCache(config=model.config, batch_size=1, max_cache_len=64, device=device)
past_key_values = cache
seq_length = inputs.size()[-1]
cache_position = torch.arange(seq_length, device=device)

#breakpoint()

#model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
#for i in range(4):
#    o = model_forward(model, inputs, past_key_values=past_key_values, cache_position=cache_position)

#breakpoint()

compiled = {idx: None for idx in range(N_WORKERS)}

def foo(idx):

    if compiled[idx] is None:
        def model_forward(model, *args, **kwargs):
            return model(*args, **kwargs)

        model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)

        for i in range(4):
            o = model_forward(model, inputs, past_key_values=past_key_values, cache_position=cache_position)

        compiled[idx] = model_forward

    model_forward = compiled[idx]

    for j in range(1):

        st = datetime.datetime.now()

        for idx in [idx]:
            outputs = dict()
            s = streams[idx]
            with torch.cuda.stream(s):
                o = 0
                with torch.no_grad():
                    for i in range(N_ITER):
                        torch.cuda.nvtx.range_push('iter{}'.format(i))
                        #out = torch.matmul(x, w)
                        out = model_forward(model, inputs, past_key_values=past_key_values, cache_position=cache_position)
                        #out = model_forward(model, inputs)
                        #out = model(inputs)
                        o = o + out.logits
                        outputs[idx] = o
                        #out = model(inputs)
                        #o = o + out.logits
                        torch.cuda.nvtx.range_pop()  # close the NVTX range opened above

        [s.synchronize() for s in streams]

        d = (datetime.datetime.now() - st).total_seconds()
        print(f'idx: {idx} = {"%.9f" % d}')


import threading


import datetime

for i in range(20):
    s = datetime.datetime.now()

    for idx in range(N_WORKERS):
        t = threading.Thread(target=foo, args=(idx,))
        t.start()
        t.join()

    d = (datetime.datetime.now() - s).total_seconds()
    print(d)
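
A note on the timing loop above: starting each thread and immediately joining it runs the workers one after another rather than concurrently. To actually overlap the workers on their separate CUDA streams, the threads can be started first and joined afterwards, as the later snippets in this thread do, e.g.:

    threads = [threading.Thread(target=foo, args=(idx,)) for idx in range(N_WORKERS)]
    [t.start() for t in threads]
    [t.join() for t in threads]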

@ydshieh2 (Author)

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda"
ckpt = "google/gemma-2b"

model = AutoModelForCausalLM.from_pretrained(ckpt)
#config = model.config
#config.num_hidden_layers = 1
#config.vocab_size = 16
#config.intermediate_size = 16
#config.num_attention_heads = 2
#config.num_key_value_heads = 2
#config.head_dim = 16
#config.max_length = 16

#model = type(model)(config=config)
model = model.to(device)
model.eval()


tokenizer = AutoTokenizer.from_pretrained(ckpt)

sequence = "Hey what's the plan" * 1
inputs = tokenizer.encode(sequence, return_tensors='pt').to(device)
#inputs = torch.zeros_like(inputs, device=device)

#model_forward = model.forward
#model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
#for i in range(10):
#    o = model_forward(inputs)


#breakpoint()

N_WORKERS = 1
N_ITER = 4

streams = [torch.cuda.Stream(device=device) for _  in range(N_ITER)]

x = torch.rand(size=(128*1, 128*1)).to(device)
w = torch.rand(size=(128*1, 128*1)).to(device)

#breakpoint()

#model.config.cache_implementation = "static"
#model.generation_config.cache_implementation = "static"


import os
os.environ["TOKENIZERS_PARALLELISM"] = "0"

#def model_forward(model, *args, **kwargs):
#    return model.forward(*args, **kwargs)


from transformers import StaticCache
cache = StaticCache(config=model.config, batch_size=1, max_cache_len=64, device=device)
past_key_values = cache
seq_length = inputs.size()[-1]
cache_position = torch.arange(seq_length, device=device)

#breakpoint()

#model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
#for i in range(4):
#    o = model_forward(model, inputs, past_key_values=past_key_values, cache_position=cache_position)

#breakpoint()

compiled = {idx: None for idx in range(N_WORKERS)}
#compiled = {idx: None for idx in range(N_ITER)}


def foo(idx):

    if compiled[idx] is None:
        def model_forward(model, *args, **kwargs):
            return model(*args, **kwargs)

        #model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)

        for i in range(4):
            o = model_forward(model, inputs, past_key_values=past_key_values, cache_position=cache_position)

        compiled[idx] = model_forward

    model_forward = compiled[idx]

    st = datetime.datetime.now()

    for j in range(N_ITER):

        #st = datetime.datetime.now()

        for idx in [idx]:
            outputs = dict()
            s = streams[j]

            with torch.cuda.stream(s):
                o = 0
                with torch.no_grad():
                    for i in range(1):
                        #torch.cuda.nvtx.range_push('iter{}'.format(i))
                        #out = torch.matmul(x, w)
                        #out = model_forward(model, inputs, past_key_values=past_key_values, cache_position=cache_position)
                        #out = model_forward(model, inputs)
                        out = model(inputs)
                        #o = o + out.logits
                        #outputs[idx] = o
                        #out = model(inputs)
                        #o = o + out.logits
                    #torch.cuda.nvtx.range_pop()

        #[s.synchronize() for s in [streams[j]]]

        d = (datetime.datetime.now() - st).total_seconds()
        print(f'idx: {idx} = {"%.9f" % d}')


import threading


import datetime

for i in range(20):
    s = datetime.datetime.now()

    for idx in range(N_WORKERS):
        t = threading.Thread(target=foo, args=(idx,))
        t.start()
        t.join()

    d = (datetime.datetime.now() - s).total_seconds()
    print(d)

@ydshieh2 (Author)

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda"
ckpt = "google/gemma-2b"

model = AutoModelForCausalLM.from_pretrained(ckpt)
#config = model.config
#config.num_hidden_layers = 1
#config.vocab_size = 16
#config.intermediate_size = 16
#config.num_attention_heads = 2
#config.num_key_value_heads = 2
#config.head_dim = 16
#config.max_length = 16

#model = type(model)(config=config)
model = model.to(device)
model.eval()


tokenizer = AutoTokenizer.from_pretrained(ckpt)

sequence = "Hey what's the plan" * 1
inputs = tokenizer.encode(sequence, return_tensors='pt').to(device)
#inputs = torch.zeros_like(inputs, device=device)

#model_forward = model.forward
#model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
#for i in range(10):
#    o = model_forward(inputs)


#breakpoint()

N_WORKERS = 2
N_ITER = 1

streams = [torch.cuda.Stream(device=device) for _  in range(N_WORKERS)]

x = torch.rand(size=(128*1, 128*1)).to(device)
w = torch.rand(size=(128*1, 128*1)).to(device)

#breakpoint()

#model.config.cache_implementation = "static"
#model.generation_config.cache_implementation = "static"


import os
os.environ["TOKENIZERS_PARALLELISM"] = "0"

#def model_forward(model, *args, **kwargs):
#    return model.forward(*args, **kwargs)


from transformers import StaticCache
cache = StaticCache(config=model.config, batch_size=1, max_cache_len=64, device=device)
past_key_values = cache
seq_length = inputs.size()[-1]
cache_position = torch.arange(seq_length, device=device)

#breakpoint()

#model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
#for i in range(4):
#    o = model_forward(model, inputs, past_key_values=past_key_values, cache_position=cache_position)

#breakpoint()

compiled = {idx: None for idx in range(N_WORKERS)}
#compiled = {idx: None for idx in range(N_ITER)}


def foo(idx):

    if compiled[idx] is None:
        def model_forward(model, *args, **kwargs):
            return model(*args, **kwargs)

        #model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)

        for i in range(4):
            o = model_forward(model, inputs, past_key_values=past_key_values, cache_position=cache_position)

        compiled[idx] = model_forward

    model_forward = compiled[idx]

    st = datetime.datetime.now()

    for idx in [idx]:
        outputs = dict()
        s = streams[idx]
        with torch.cuda.stream(s):
            o = 0
            with torch.no_grad():
                for i in range(N_ITER):
                    out = model(inputs)

    [s.synchronize() for s in [streams[idx]]]

    d = (datetime.datetime.now() - st).total_seconds()
    print(f'idx: {idx} = {"%.9f" % d}')


import threading


import datetime

for i in range(20):
    s = datetime.datetime.now()

    for idx in range(N_WORKERS):
        t = threading.Thread(target=foo, args=(idx,))
        t.start()
        t.join()

    d = (datetime.datetime.now() - s).total_seconds()
    print(d)

@ydshieh2 (Author)

import threading
import datetime


N_WORKERS = 2
N_ITER = 1


def foo(idx):


    st = datetime.datetime.now()

    for idx in [idx]:
        for i in range(N_ITER):
            for L in range(10000000):
                L = L * L // 999999

    d = (datetime.datetime.now() - st).total_seconds()
    print(f'idx: {idx} = {"%.9f" % d}')


if __name__ == "__main__":

    import sys
    print(sys._is_gil_enabled())

    for i in range(20):
        s = datetime.datetime.now()

        for idx in range(N_WORKERS):
            t = threading.Thread(target=foo, args=(idx,))
            t.start()
            t.join()

        d = (datetime.datetime.now() - s).total_seconds()
        print(d)

@ydshieh2 (Author)

import threading
import datetime


N_WORKERS = 4
N_ITER = 1


def foo(idx):


    #st = datetime.datetime.now()

    for i in range(N_ITER):
        for L in range(10000000):
            L = L * L // 999999

    #d = (datetime.datetime.now() - st).total_seconds()
    #print(f'idx: {idx} = {"%.9f" % d}')


if __name__ == "__main__":

    import sys
    print(sys._is_gil_enabled())

    for i in range(20):
        s = datetime.datetime.now()

        threads = [threading.Thread(target=foo, args=(idx,)) for idx in range(N_WORKERS)]
        [t.start() for t in threads]
        [t.join() for t in threads]

        d = (datetime.datetime.now() - s).total_seconds()
        print(d)

@ydshieh2 (Author)

import torch

from transformers import AutoModelForCausalLM

device = "cuda"
ckpt = "google/gemma-2b"
ckpt = "ydshieh-gemma-2b"

model = AutoModelForCausalLM.from_pretrained(ckpt)

config = model.config
config.num_hidden_layers = 1
config.vocab_size = 16
config.intermediate_size = 16
config.num_attention_heads = 2
config.num_key_value_heads = 2
config.head_dim = 16
config.max_length = 16

model = type(model)(config=config)

model = model.to(device)
model.eval()

#model.save_pretrained("ydshieh-gemma-2b", safe_serialization=False)


inputs = torch.tensor([[0] * 6], dtype=torch.int32)
inputs = inputs.to(device)


N_WORKERS = 3
N_ITER = 1

if device == "cuda":
    streams = [torch.cuda.Stream(device=device) for _  in range(N_WORKERS)]



import os
os.environ["TOKENIZERS_PARALLELISM"] = "0"



from transformers import StaticCache
cache = StaticCache(config=model.config, batch_size=1, max_cache_len=64, device=device)
past_key_values = cache
seq_length = inputs.size()[-1]
cache_position = torch.arange(seq_length, device=device)

#breakpoint()

#model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
#for i in range(4):
#    o = model_forward(model, inputs, past_key_values=past_key_values, cache_position=cache_position)

#breakpoint()

compiled = {idx: None for idx in range(N_WORKERS)}
#compiled = {idx: None for idx in range(N_ITER)}

def foo(idx):


    st = datetime.datetime.now()

    for idx in [idx]:
        if device == "cuda":
            s = streams[idx]
            with torch.cuda.stream(s):
                with torch.no_grad():
                    for i in range(N_ITER):
                        #for L in range(10000000):
                        #    L = L * L // 999999
                        out = model(inputs)
        else:
            with torch.no_grad():
                for i in range(N_ITER):
                    #for L in range(10000000):
                    #    L = L * L // 999999
                    out = model(inputs)

    if device == "cuda":
        [s.synchronize() for s in [streams[idx]]]

    d = (datetime.datetime.now() - st).total_seconds()
    print(f'idx: {idx} = {"%.9f" % d}')


import threading


import datetime

for i in range(20):
    s = datetime.datetime.now()

    threads = [threading.Thread(target=foo, args=(idx,)) for idx in range(N_WORKERS)]
    [t.start() for t in threads]
    [t.join() for t in threads]

    d = (datetime.datetime.now() - s).total_seconds()
    print(d)

@ydshieh2 (Author)

import torch

from transformers import AutoModelForCausalLM

device = "cuda"
ckpt = "google/gemma-2b"
ckpt = "ydshieh-gemma-2b"

model = AutoModelForCausalLM.from_pretrained(ckpt)

config = model.config
config.num_hidden_layers = 8
config.vocab_size = 16
config.intermediate_size = 16
config.num_attention_heads = 2
config.num_key_value_heads = 2
config.head_dim = 16
config.max_length = 16

model = type(model)(config=config)

model = model.to(device)
model.eval()

#model.save_pretrained("ydshieh-gemma-2b", safe_serialization=False)


inputs = torch.tensor([[0] * 6], dtype=torch.int32)
inputs = inputs.to(device)


N_WORKERS = 8
N_ITER = 1

if device == "cuda":
    streams = [torch.cuda.Stream(device=device) for _  in range(N_WORKERS)]



import os
os.environ["TOKENIZERS_PARALLELISM"] = "0"



from transformers import StaticCache
cache = StaticCache(config=model.config, batch_size=1, max_cache_len=64, device=device)
past_key_values = cache
seq_length = inputs.size()[-1]
cache_position = torch.arange(seq_length, device=device)

#breakpoint()

#model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
#for i in range(4):
#    o = model_forward(model, inputs, past_key_values=past_key_values, cache_position=cache_position)

#breakpoint()

compiled = {idx: None for idx in range(N_WORKERS)}
#compiled = {idx: None for idx in range(N_ITER)}


def foo(idx):

    if idx is None:
        return

    st = datetime.datetime.now()

    for idx in [idx]:
        if device == "cuda":
            s = streams[idx]
            with torch.cuda.stream(s):
                with torch.no_grad():
                    for i in range(N_ITER):
                        #for L in range(10000000):
                        #    L = L * L // 999999
                        out = model(inputs)
        else:
            with torch.no_grad():
                for i in range(N_ITER):
                    #for L in range(10000000):
                    #    L = L * L // 999999
                    out = model(inputs)

    if device == "cuda":
        [s.synchronize() for s in [streams[idx]]]

    d = (datetime.datetime.now() - st).total_seconds()
    print(f'idx: {idx} = {"%.9f" % d}')


import threading


import datetime

times = []
for i in range(20):
    s = datetime.datetime.now()

    threads = [threading.Thread(target=foo, args=(None,)) for idx in range(N_WORKERS)]
    [t.start() for t in threads]
    [t.join() for t in threads]

    d = (datetime.datetime.now() - s).total_seconds()
    print(d)
    times.append(d)

_avg_time = sum(times) / len(times)



times = []
for i in range(20):
    s = datetime.datetime.now()

    threads = [threading.Thread(target=foo, args=(idx,)) for idx in range(N_WORKERS)]
    [t.start() for t in threads]
    [t.join() for t in threads]

    d = (datetime.datetime.now() - s).total_seconds()
    print(d)
    times.append(d)


avg_time = sum(times) / len(times)

time_adjusted = [x - _avg_time for x in times]
avg_time_adjusted = sum(time_adjusted) / len(time_adjusted)

print(f"avg_time: do nothing = {_avg_time}")
print(f"avg_time: do something = {avg_time}")
print(f"avg_time_adjusted: do something = {avg_time_adjusted}")
print(f"avg_time_perworker: do something = {avg_time_adjusted / N_WORKERS}")

print(f"N_WORKERS :{N_WORKERS}")





@ydshieh2 (Author)

import torch

from transformers import AutoModelForCausalLM

device = "cuda"
ckpt = "google/gemma-2b"
ckpt = "ydshieh-gemma-2b"

model = AutoModelForCausalLM.from_pretrained(ckpt)

config = model.config
config.num_hidden_layers = 8
config.vocab_size = 16
config.intermediate_size = 16
config.num_attention_heads = 2
config.num_key_value_heads = 2
config.head_dim = 16
config.max_length = 16

model = type(model)(config=config)

model = model.to(device)
model.eval()

#model.save_pretrained("ydshieh-gemma-2b", safe_serialization=False)


inputs = torch.tensor([[0] * 6], dtype=torch.int32)
inputs = inputs.to(device)


N_WORKERS = 32
N_ITER = 1

if device == "cuda":
    streams = [torch.cuda.Stream(device=device) for _  in range(N_WORKERS)]



import os
os.environ["TOKENIZERS_PARALLELISM"] = "0"



from transformers import StaticCache
cache = StaticCache(config=model.config, batch_size=1, max_cache_len=64, device=device)
past_key_values = cache
seq_length = inputs.size()[-1]
cache_position = torch.arange(seq_length, device=device)

#breakpoint()

#model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
#for i in range(4):
#    o = model_forward(model, inputs, past_key_values=past_key_values, cache_position=cache_position)

#breakpoint()

compiled = {idx: None for idx in range(N_WORKERS)}
#compiled = {idx: None for idx in range(N_ITER)}

worker_time = {idx: [] for idx in range(N_WORKERS)}


def foo(idx):

    if idx is None:
        return

    st = datetime.datetime.now()

    for idx in [idx]:
        if device == "cuda":
            s = streams[idx]
            with torch.cuda.stream(s):
                with torch.no_grad():
                    for i in range(N_ITER):
                        #for L in range(10000000):
                        #    L = L * L // 999999
                        out = model(inputs)
        else:
            with torch.no_grad():
                for i in range(N_ITER):
                    #for L in range(10000000):
                    #    L = L * L // 999999
                    out = model(inputs)

    if device == "cuda":
        [s.synchronize() for s in [streams[idx]]]

    d = (datetime.datetime.now() - st).total_seconds()
    print(f'idx: {idx} = {"%.9f" % d}')

    worker_time[idx].append(d)

import threading


import datetime

times = []
for i in range(20):
    s = datetime.datetime.now()

    threads = [threading.Thread(target=foo, args=(None,)) for idx in range(N_WORKERS)]
    [t.start() for t in threads]
    [t.join() for t in threads]

    d = (datetime.datetime.now() - s).total_seconds()
    print(d)
    times.append(d)

_avg_time = sum(times) / len(times)


worker_time = {idx: [] for idx in range(N_WORKERS)}
times = []
for i in range(5 + 20):
    s = datetime.datetime.now()

    threads = [threading.Thread(target=foo, args=(idx,)) for idx in range(N_WORKERS)]
    [t.start() for t in threads]
    [t.join() for t in threads]

    d = (datetime.datetime.now() - s).total_seconds()
    print(d)
    if i > 4:
        times.append(d)
avg_time = sum(times) / len(times)

time_adjusted = [x - _avg_time for x in times]
avg_time_adjusted = sum(time_adjusted) / len(time_adjusted)

all_worker_time = []
for idx in range(N_WORKERS):
    all_worker_time.extend(worker_time[idx][5:])
avg_worker_time = sum(all_worker_time) / len(all_worker_time)

print(f"avg_time: do nothing = {_avg_time}")
print(f"avg_time: do something = {avg_time}")
print(f"avg_time_adjusted: do something = {avg_time_adjusted}")
print(f"avg_worker_time: do something = {avg_worker_time}")
print(f"computed_avg_worker_time: do something = {avg_time_adjusted / N_WORKERS}")

print(f"N_WORKERS :{N_WORKERS}")

@ydshieh2 (Author)

def train(model):
    # Construct data_loader, optimizer, etc.
    print("hello")
    import time
    while True:
        #import torch
        time.sleep(1)
        #inputs = torch.tensor([[0] * 6], dtype=torch.int32)
        #o = model(inputs.to("cuda"))
        #print(o)

if __name__ == '__main__':

    import torch

    from transformers import AutoModelForCausalLM, DistilBertForMaskedLM, GemmaForCausalLM

    device = "cuda"
    ckpt = "google/gemma-2b"
    #ckpt = "ydshieh-gemma-2b"

    #breakpoint()
    import torch.multiprocessing as mp
    mp.set_start_method("spawn")


    num_processes = 2
    #breakpoint()
    model = GemmaForCausalLM.from_pretrained(ckpt)

    config = model.config
    #config.num_hidden_layers = 1
    config.vocab_size = 16
    config.intermediate_size = 16
    config.num_attention_heads = 2
    config.num_key_value_heads = 2
    config.head_dim = 16
    config.max_length = 16

    model = type(model)(config=config)

    #model = model.to(device)
    model.eval()
    #breakpoint()
    model = model.to("cuda")

    #breakpoint()
    # NOTE: share_memory() is required for the ``fork`` start method; this script uses ``spawn``
    model.share_memory()



    #breakpoint()
    processes = []

    for rank in range(num_processes):
        p = mp.Process(target=train, args=(model,))
        #breakpoint()
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

    #exit(0)
    #breakpoint()

@ydshieh2 (Author)

apt-get install wget nano
wget https://repo.anaconda.com/miniconda/Miniconda3-py312_24.9.2-0-Linux-x86_64.sh
bash Miniconda3-py312_24.9.2-0-Linux-x86_64.sh -b -p $HOME/miniconda
source $HOME/miniconda/bin/activate
conda init
. ~/.bashrc
conda create -y -n py13 --override-channels -c conda-forge python-freethreading
conda activate py13
python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
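
To confirm that the resulting interpreter is actually free-threaded (the same check used in the Python snippets above), a quick sanity check is:

    import sys
    # On a free-threaded build running with the GIL disabled this prints False;
    # it prints True when the GIL is active (e.g. when forced on via PYTHON_GIL=1).
    print(sys._is_gil_enabled())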

@ydshieh2 (Author)

import torch

from transformers import AutoModelForCausalLM

device = "cuda"
ckpt = "google/gemma-2b"
ckpt = "ydshieh-gemma-2b"

model = AutoModelForCausalLM.from_pretrained(ckpt)

config = model.config
config.num_hidden_layers = 8
config.vocab_size = 16
config.intermediate_size = 16
config.num_attention_heads = 2
config.num_key_value_heads = 2
config.head_dim = 16
config.max_length = 16

model = type(model)(config=config)

model = model.to(device)
model.eval()

#model.save_pretrained("ydshieh-gemma-2b", safe_serialization=False)


inputs = torch.tensor([[0] * 6], dtype=torch.int32)
inputs = inputs.to(device)



N_ITER = 1



import os
os.environ["TOKENIZERS_PARALLELISM"] = "0"



from transformers import StaticCache
cache = StaticCache(config=model.config, batch_size=1, max_cache_len=64, device=device)
past_key_values = cache
seq_length = inputs.size()[-1]
cache_position = torch.arange(seq_length, device=device)

#breakpoint()

#model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
#for i in range(4):
#    o = model_forward(model, inputs, past_key_values=past_key_values, cache_position=cache_position)

#breakpoint()

# compiled = {idx: None for idx in range(N_WORKERS)}
#compiled = {idx: None for idx in range(N_ITER)}


def foo(idx):

    if idx is None:
        return

    st = datetime.datetime.now()

    for idx in [idx]:
        if device == "cuda":
            s = streams[idx]
            with torch.cuda.stream(s):
                with torch.no_grad():
                    for i in range(N_ITER):
                        #for L in range(10000000):
                        #    L = L * L // 999999
                        out = model(inputs)
        else:
            with torch.no_grad():
                for i in range(N_ITER):
                    #for L in range(10000000):
                    #    L = L * L // 999999
                    out = model(inputs)

    if device == "cuda":
        [s.synchronize() for s in [streams[idx]]]

    d = (datetime.datetime.now() - st).total_seconds()
    # print(f'idx: {idx} = {"%.9f" % d}')

    worker_time[idx].append(d)

import threading


import datetime

ALL_N_WORKERS = list(range(1, 1 + 64))

for N_WORKERS in ALL_N_WORKERS:

    if device == "cuda":
        streams = [torch.cuda.Stream(device=device) for _ in range(N_WORKERS)]

    worker_time = {idx: [] for idx in range(N_WORKERS)}

    times = []
    for i in range(20):
        s = datetime.datetime.now()

        threads = [threading.Thread(target=foo, args=(None,)) for idx in range(N_WORKERS)]
        [t.start() for t in threads]
        [t.join() for t in threads]

        d = (datetime.datetime.now() - s).total_seconds()
        #print(d)
        times.append(d)

    _avg_time = sum(times) / len(times)


    worker_time = {idx: [] for idx in range(N_WORKERS)}
    times = []
    for i in range(5 + 20):
        s = datetime.datetime.now()

        threads = [threading.Thread(target=foo, args=(idx,)) for idx in range(N_WORKERS)]
        [t.start() for t in threads]
        [t.join() for t in threads]

        d = (datetime.datetime.now() - s).total_seconds()
        #print(d)
        if i > 4:
            times.append(d)
    avg_time = sum(times) / len(times)

    time_adjusted = [x - _avg_time for x in times]
    avg_time_adjusted = sum(time_adjusted) / len(time_adjusted)

    all_worker_time = []
    for idx in range(N_WORKERS):
        all_worker_time.extend(worker_time[idx][5:])
    avg_worker_time = sum(all_worker_time) / len(all_worker_time)

    print(f"avg_time: do nothing = {_avg_time}")
    print(f"avg_time: do something = {avg_time}")
    print(f"avg_time_adjusted: do something = {avg_time_adjusted}")
    print(f"avg_worker_time: do something = {avg_worker_time}")
    print(f"computed_avg_worker_time: do something = {avg_time_adjusted / N_WORKERS}")

    print(f"N_WORKERS :{N_WORKERS}")
    print("============================")

@ydshieh2 (Author)

def train(rank):
    import torch
    import datetime

    device = "cuda"
    x = torch.rand(size=(128 * 32, 128 * 32)).to(device)
    w = torch.rand(size=(128 * 32, 128 * 32)).to(device)

    print("hello")
    import time
    while True:


        #time.sleep(1)

        s = datetime.datetime.now()
        for _ in range(100):
            out = torch.matmul(x, w)
        d = (datetime.datetime.now() - s).total_seconds()
        print(f"rank {rank}: {d}")


if __name__ == '__main__':

    import torch
    import datetime


    device = "cuda"

    #breakpoint()
    import torch.multiprocessing as mp
    mp.set_start_method("spawn")


    num_processes = 4



    #breakpoint()
    processes = []

    for rank in range(num_processes):
        p = mp.Process(target=train, args=(rank,))
        #breakpoint()
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

    #exit(0)
    #breakpoint()
