thanks (it's 10x faster than JAX)! #25

I've been trying to get dalle-playground running performantly on M1, but there's a lot of work remaining to make the JAX model work via IREE/Vulkan.

so, I tried out your pytorch model with a recent nightly of pytorch… and it's 10x faster at dalle-mega than dalle-playground was on JAX/XLA!

using dalle-mega full: generating 1 image took 27 mins on dalle-playground (using 117% CPU), whereas this pytorch model runs in 2.7 mins (using 145% CPU)!

the GPU looks less-than-half utilized; I haven't checked whether pytorch is the process that's using the GPU. these measurements are from an M1 Max.

bonus


"crystal maiden and lina enjoying a pint together at a tavern"

Comments
Awesome!
I think the model runs on CPU by default. I tried to move all models and tensors to the `mps` device.
@pcuenca wait, you got it running on-GPU? and it was faster? that's massively different from the result I got. here's how I made it run on MPS: I just slapped `.to('mps')` onto the models and tensors. what I found was that it ran way slower: I left it overnight and it didn't finish generating even 1 image (it got to the 145th token of 255, something like that). did I do something wrong?
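for context, this is the generic pattern I mean (a minimal sketch with a toy module, not the actual min-dalle code):

```python
import torch
from torch import nn

# opt into the MPS backend when it's available (needs a recent PyTorch
# build on Apple Silicon); fall back to CPU otherwise
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model = nn.Linear(16, 16).to(device)   # weights go to the GPU
x = torch.randn(1, 16, device=device)  # inputs must live on the same device
with torch.no_grad():
    y = model(x)
print(y.cpu())  # results come back to CPU for inspection/saving
```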
it's worth knowing that the MPS backend does have some silent errors, where it will produce incorrect output (or at least transfer the wrong result to CPU). that's the really wacky phenomenon that I found.
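one way to catch that kind of silent wrongness is to run identical weights on both devices and diff the outputs (a minimal sketch with a toy module, not the actual model):

```python
import copy

import torch
from torch import nn

# identical weights on CPU and MPS; if the backend (or the device-to-host
# copy) were silently wrong, the difference below would be large, not ~1e-6
torch.manual_seed(0)
model_cpu = nn.Linear(16, 16).eval()
model_mps = copy.deepcopy(model_cpu).to("mps")

x = torch.randn(4, 16)
with torch.no_grad():
    out_cpu = model_cpu(x)
    out_mps = model_mps(x.to("mps")).cpu()

print((out_cpu - out_mps).abs().max())
```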
@Birch-san These are my changes so far: main...pcuenca:min-dalle:mps-device. I tried to use workarounds for the unsupported ops, with one exception. I may have introduced a problem somewhere, but if you disable the MPS device (by returning the CPU device from the selection code instead), generation works as it did before.
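The general shape of those workarounds, and of the device toggle, is roughly this (a sketch; `cumsum` here is just a stand-in for an op the MPS backend doesn't support, not necessarily one of the ops patched in the branch):

```python
import torch

def get_device(disable_mps: bool = False) -> torch.device:
    # centralized device selection: "disabling MPS" is just returning CPU
    if not disable_mps and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

def cumsum_workaround(t: torch.Tensor, dim: int) -> torch.Tensor:
    # hop to CPU for the one unsupported op, then move back
    if t.device.type == "mps":
        return t.cpu().cumsum(dim).to(t.device)
    return t.cumsum(dim)
```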
That's very interesting. I'll try to debug generation tomorrow. Thanks!
I was also looking into getting this model running on the phone. Apple says that for transformers in PyTorch, the dimensions aren't in the optimal order for the neural engine: https://machinelearning.apple.com/research/neural-engine-transformers. They also convert all the linear layers to convolutions and use a different einsum pattern.
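A minimal sketch of that linear-to-conv rewrite (toy sizes; Apple's reference code also keeps activations in a (batch, channels, 1, seq) layout, which is what makes their attention einsum pattern different):

```python
import torch
from torch import nn

# a 1x1 Conv2d over a (B, C, 1, S) tensor computes the same affine map
# as nn.Linear over the channel dimension
linear = nn.Linear(512, 512)
conv = nn.Conv2d(512, 512, kernel_size=1)
with torch.no_grad():
    conv.weight.copy_(linear.weight[:, :, None, None])
    conv.bias.copy_(linear.bias)

x = torch.randn(1, 512, 1, 77)         # (batch, channels, 1, seq)
y_conv = conv(x)                       # (1, 512, 1, 77)
y_lin = linear(x.permute(0, 2, 3, 1))  # channels-last view for Linear
print(torch.allclose(y_conv.permute(0, 2, 3, 1), y_lin, atol=1e-5))  # True
```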
that's just the neural engine, though. PyTorch's MPS backend targets the GPU, and JAX's IREE/Vulkan backend does too. dunno what TensorFlow targets. but I'll definitely take "targeting 48 GPU cores" as a step up from "targeting 10 CPU cores". it sounds like the Neural Engine is not suitable for training anyway, only inference.
The neural engine is much faster than the GPU, so it makes sense to apply those optimizations. Not all operations are supported, however, and it's hard to know whether the system decided to run your model on the neural engine or the GPU. I wasn't trying to do that yet, though; I just wanted to test inference on the MPS backend (GPU) of my M1 Mac to see how it compares with the CPU and with NVIDIA GPUs. If we did a conversion to Core ML, we would then be able to compare neural engine inference speed against PyTorch+MPS performance.
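For completeness, that conversion would look roughly like this (a sketch with a toy module via coremltools; `compute_units` controls whether Core ML may schedule work on the neural engine):

```python
import coremltools as ct
import torch
from torch import nn

# hypothetical toy module standing in for the real model
model = nn.Linear(16, 16).eval()
example = torch.randn(1, 16)
traced = torch.jit.trace(model, example)

mlmodel = ct.convert(
    traced,
    inputs=[ct.TensorType(name="x", shape=example.shape)],
    convert_to="mlprogram",
    compute_units=ct.ComputeUnit.ALL,  # allow CPU, GPU, and neural engine
)
mlmodel.save("model.mlpackage")
```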
If it is indeed a problem with the transfer from MPS to CPU, then we should try @qqaatw's idea of transferring it as contiguous memory.
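i.e., something along these lines before the device-to-host copy (a minimal sketch):

```python
import torch

x = torch.randn(3, 256, 256, device="mps")
img = x.permute(1, 2, 0)          # a non-contiguous view (HWC image layout)
img_cpu = img.contiguous().cpu()  # force a contiguous buffer before the copy
```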
@pcuenca if I slap `.contiguous()` onto the tensor before transferring it to CPU, the image still comes out black. I also tried some other workarounds for the transfer; still black.
Even faster these days: you get a 4x4 grid instead of a 3x3 grid on Replicate, after the same duration. However, this is based on Dall-E MEGA instead of Dall-E Mini, so results might differ. Not sure if better or worse.