Slow Index Building #226
@clayscode: I am attempting to do the same with an RTX 3090, but I am getting errors like "RuntimeError: CUDA error: device-side assert triggered". Did you have to do anything special to make it work with your GPU?
Two possibilities: 1) The FSM takes a long time to build; we can address that more or less easily. It would be amazing if you could run
```
Timer unit: 1e-06 s

Total time: 35.5477 s
File: test.py
Function: main at line 31

Line #      Hits         Time    Per Hit   % Time  Line Contents
==============================================================
    31                                             @profile
    32                                             def main():
    33         1    2222021.2  2222021.2      6.3      model = models.transformers("gpt2", device=torch.device('cuda'))
    34         1   33325550.6 33325550.6     93.7      sequence = generate.json(model, Character)("Give me a character description")
    35         1         22.4       22.4      0.0      print(sequence)
    36                                                 # {
    37                                                 #   "name": "ranbelt",
    38                                                 #   "age": 26,
    39                                                 #   "armor": "chainmail",
    40                                                 #   "weapon": "bow",
    41                                                 #   "strength": 5
    42                                                 # }
    43
    44         1         52.3       52.3      0.0      parsed = Character.model_validate_json(sequence)
    45         1         36.8       36.8      0.0      print(parsed)
```
Thank you! Could you add
I ran the profiling with the following code:

```python
from enum import Enum

from pydantic import BaseModel, constr

import outlines
import outlines.models as models
import outlines.text.generate as generate
import line_profiler


class Weapon(str, Enum):
    sword = "sword"
    axe = "axe"
    mace = "mace"
    spear = "spear"
    bow = "bow"
    crossbow = "crossbow"


class Armor(str, Enum):
    leather = "leather"
    chainmail = "chainmail"
    plate = "plate"


class Character(BaseModel):
    name: constr(max_length=10)
    age: int
    armor: Armor
    weapon: Weapon
    strength: int


def fn():
    model = models.transformers("gpt2")
    sequence = generate.json(model, Character)("Give me a character description")
    print(sequence)


# Profile the generation step and the two index-building helpers.
profile = line_profiler.LineProfiler()
profile.add_function(outlines.text.generate.sequence.Sequence.step)
profile.add_function(outlines.text.parsing.map_partial_states_to_vocab)
profile.add_function(outlines.text.parsing.find_partial_matches)
profile(fn)()
profile.print_stats()
```

And here are the results (profiling stats):

```
{
  "name": "ranbelt",
  "age": 26,
  "armor": "chainmail",
  "weapon": "bow",
  "strength": 5
}
Timer unit: 1e-09 s

Total time: 38.6415 s
Total time: 2.83256 s
Total time: 23.5358 s
Total time: 23.5642 s
```
Unsurprisingly, the bottleneck is the index building. For reference, here is the corresponding regex:

```python
import json

regex = outlines.text.json_schema.build_regex_from_schema(json.dumps(Character.model_json_schema()))
print(regex)
# \{
# "name": ".{,10}",
# "age": (0|[1-9][0-9]*),
# "armor": ("leather"|"chainmail"|"plate"),
# "weapon": ("sword"|"axe"|"mace"|"spear"|"bow"|"crossbow"),
# "strength": (0|[1-9][0-9]*)
# \}
```

We can also look at the corresponding FSM, which contains ~160 states:

```python
import interegular

fsm = interegular.parse_pattern(regex).to_fsm()
print(fsm)
```

(Output: the FSM's transition table, with states numbered 0 to 164. State 0 is the initial state and state 164 is the only final state. The columns are the characters that occur in the regex: JSON punctuation, digits, and the letters used in the field names and enum values, plus `anything_else`.)

There is a very simple optimization for this, which is to hard-code the tokens that correspond to the JSON structure and field names, and memoize
We should also clarify that the index construction only needs to occur once for a given regular expression and vocabulary. As @rlouf said, we haven't set up automatic caching for this, but we can.
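The index depends only on the regular expression and the vocabulary. An in-memory cache keyed on those two values is therefore enough to pay the construction cost once per process. Here is a minimal sketch using `functools.lru_cache`, with some assumptions. `build_index` is a hypothetical stand-in for the expensive construction; here it is a trivial per-token `re.fullmatch` check. The vocabulary is passed as a tuple because cache keys must be hashable.

```python
import functools
import re
from typing import Dict, Tuple

# Counts how often the (hypothetical) expensive construction actually runs.
BUILD_CALLS = {"count": 0}


def build_index(regex: str, vocabulary: Tuple[str, ...]) -> Dict[str, bool]:
    """Hypothetical stand-in for the slow regex/vocabulary index construction."""
    BUILD_CALLS["count"] += 1
    pattern = re.compile(regex)
    # Toy "index": whether each token, on its own, fully matches the regex.
    return {token: pattern.fullmatch(token) is not None for token in vocabulary}


@functools.lru_cache(maxsize=None)
def get_index(regex: str, vocabulary: Tuple[str, ...]) -> Dict[str, bool]:
    # lru_cache keys on (regex, vocabulary), so both must be hashable.
    return build_index(regex, vocabulary)


vocabulary = ("0", "12", "a", "007")
integer_regex = r"(0|[1-9][0-9]*)"  # the integer pattern used for "age" and "strength"

first = get_index(integer_regex, vocabulary)
second = get_index(integer_regex, vocabulary)  # served from the cache
print(BUILD_CALLS["count"])  # 1
print(first)  # {'0': True, '12': True, 'a': False, '007': False}
```

The cached dict is shared between callers, so it should be treated as read-only. Reusing the index across separate runs would need an on-disk cache instead.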
The above code snippet runs very slowly on my machine. I have a 7900 XTX, and before I added `device=torch.device('cuda')` it was defaulting to CPU-only inference. It seems to be running on my GPU now, but inference is still very slow (it takes ~30 seconds to run the above). This could just be a ROCm thing; I'm not entirely sure. I've installed both Torch and TensorFlow with ROCm support. Any idea what might be going on?