
server: fix tool-call of DeepSeek R1 Qwen, return reasoning_content (Command 7RB & DeepSeek R1) unless --reasoning-format none #11607

Merged 94 commits into ggml-org:master on Feb 13, 2025

Conversation

ochafik
Collaborator

@ochafik ochafik commented Feb 3, 2025

(Non-streaming API only for now.)

Usage

  • Get and build this PR's branch
    git clone https://github.com/ggerganov/llama.cpp
    cd llama.cpp
    git remote add ochafik https://github.com/ochafik/llama.cpp
    git fetch ochafik
    git checkout ochafik/r1-toolcall
    cmake -B build -DLLAMA_CURL=1
    cmake --build build -t llama-server --parallel --config Release
    alias llama-server=./build/bin/llama-server
  • Run with (add --verbose to inspect prompt / grammars used):

    # For DeepSeek only, optional chat template override (please report results w/ and w/o)
    llama-server --jinja --think -hf bartowski/DeepSeek-R1-Distill-Qwen-32B-GGUF:Q6_K_L \
      --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja
    llama-server --jinja --think -hf bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q6_K_L \
      --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja
    
    # Command R7B also has native think tokens:
    llama-server --jinja -fa --think -hf bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L \
      --chat-template-file <( python scripts/get_chat_template.py CohereForAI/c4ai-command-r7b-12-2024 tool_use )
    
    # Try adding --reasoning-format none to disable thoughts extraction
  • Call the API and profit

    curl http://localhost:8080/v1/chat/completions -d '{
      "messages": [
        {
          "role": "system",
          "content": "You are a tool calling agent."
        },
        {
          "role": "user",
          "content": "scrape https://ochafik.com and tell me the number of times it says Olivier"
        }
      ],
      "tools": [
        {
          "type": "function",
          "function": {
            "name": "python",
            "description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
            "parameters": {
              "type": "object",
              "properties": {
                "code": {
                  "type": "string",
                  "description": "The code to run in the ipython interpreter."
                }
              },
              "required": [
                "code"
              ]
            }
          }
        }
      ],
      "temperature": 0.0,
      "top_k": 1,
      "top_p": 1.0
    }' | jq
  • Result w/ `DeepSeek-R1-Distill-Qwen-32B-GGUF:Q6_K_L`:
    {
      "choices": [
        {
          "finish_reason": "tool_calls",
          "index": 0,
          "message": {
            "role": "assistant",
            "reasoning_content": "Okay, so the user wants me to scrape the website https://ochafik.com and find out how many times the name \"Olivier\" appears. Hmm, I need to figure out how to approach this. \n\nFirst, I remember that web scraping can be done using Python libraries like requests and BeautifulSoup. Requests can fetch the webpage content, and BeautifulSoup can parse the HTML to extract text. \n\nI should start by writing a Python script that sends a GET request to the URL. Then, I'll check if the request was successful. If it is, I'll use BeautifulSoup to parse the content. \n\nNext, I'll extract all the text from the parsed HTML. Once I have the text, I can split it into words and count how many times \"Olivier\" appears. But wait, I should make sure the case matches. Maybe the name is in different cases, but the user specifically asked for \"Olivier,\" so I'll keep it case-sensitive.\n\nI also need to handle any potential errors, like if the website doesn't respond or if there's an issue parsing the HTML. But since I'm using a tool that runs the code, I can let it handle those exceptions and just report the result.\n\nPutting it all together, the code will fetch the webpage, parse it, extract the text, split it into words, and count the occurrences of \"Olivier.\" Then, it'll print the result. I think that should do it.",
            "content": null,
            "tool_calls": [
              {
                "type": "function",
                "function": {
                  "name": "python",
                  "arguments": "{\"code\":\"import requests\\nfrom bs4 import BeautifulSoup\\n\\nurl = 'https://ochafik.com'\\nresponse = requests.get(url)\\n\\nif response.status_code == 200:\\n    soup = BeautifulSoup(response.text, 'html.parser')\\n    text = soup.get_text()\\n    words = text.split()\\n    count = words.count('Olivier')\\n    print(f'The name \\\"Olivier\\\" appears {count} times.')\\nelse:\\n    print('Failed to retrieve the webpage.')\"}"
                },
                "id": ""
              }
            ]
          }
        }
      ]
    }

    Which is this code:

    import requests
    from bs4 import BeautifulSoup
    
    url = 'https://ochafik.com'
    response = requests.get(url)
    
    if response.status_code == 200:
        soup = BeautifulSoup(response.text, 'html.parser')
        text = soup.get_text()
        words = text.split()
        count = words.count('Olivier')
        print(f'The name "Olivier" appears {count} times.')
    else:
        print('Failed to retrieve the webpage.')

    Not too bad, but it didn't lower-case the text, and the word splitting is a bit crude.

    Trying again w/ the following extra args to make the sampling greedy:

    curl ... {
      ...
      "temperature": 0.0,
      "top_k": 1,
      "top_p": 1.0
    }
    

    We have a winner:

    import requests
    from bs4 import BeautifulSoup
    
    url = 'https://ochafik.com'
    response = requests.get(url)
    
    if response.status_code == 200:
        soup = BeautifulSoup(response.text, 'html.parser')
        text = soup.get_text()
        count = text.lower().count('olivier')
        print(f'The name "Olivier" appears {count} times.')
    else:
        print('Failed to retrieve the webpage.')

    And the thoughts:

    Okay, so the user wants me to scrape the website https://ochafik.com and find out how many times it mentions "Olivier." Hmm, I need to figure out how to do that. Since I can use the Python function tool, I can write a script that does this.

    First, I should think about how to scrape a website. I remember that using libraries like requests and BeautifulSoup can help. Requests can fetch the webpage content, and BeautifulSoup can parse the HTML to extract text.

    So, the steps would be: send a GET request to the URL, check if the request was successful, parse the HTML content, extract all the text, and then count the occurrences of the word "Olivier."

    Wait, but sometimes the word might be part of a larger word, like "Olivier's" or "Oliviering." The user probably wants exact matches, so I should make sure to count only whole words. Maybe using a regular expression with word boundaries would be better.

    Also, I should handle any potential errors, like if the website doesn't respond or if there's an issue with parsing. But since the user didn't specify handling errors, maybe I can just proceed with the basic script.

    Putting it all together, the Python code would import requests and BeautifulSoup, send the request, parse the content, extract the text, and then count the occurrences. I'll write the code accordingly and use the Python tool to execute it.

Implementation notes

  • Had to work around DeepSeek R1's official jinja template:
    • It doesn't describe the available tools, and the backfill done by Minja wasn't phrased well enough (for the 7B model), so I've added autogenerated tool call examples to minja's revamped "polyfill" behaviour (using a delta template eval).
      sync: minja #11641
    • It ignores message.tool_calls if message.content is not null; updated / tested the server output accordingly (better OpenAI compliance)
    • After a tool result, it leaves the prompt hanging on a <|tool▁output▁end|> or <|tool▁call▁end|> (need to close the list of outputs / calls w/ plural <|tool▁outputs▁end|> / <|tool▁calls▁end|>, respectively, and then missing end of sentence + optional add_generation_prompt)
      • Hacked a workaround so the default template now works well with this branch
      • Added / documented better template (models/templates/llama-cpp-deepseek-r1.jinja)
  • DeepSeek R1 Distill Qwen 8B & 32B models seem to take liberties with their tool call start tag, so accepting variations of the syntax (which then triggers the lazy grammar / full compliance). I'd avoid the 8B (or its lower quants) if possible, it sometimes tries to skip the tools output tag.
  • NEW --reasoning-format flag, which controls output of reasoning_content in the API (see test_thoughts)
  • Updated tests:
    • Added the Q4_K_M quant to some (but not all) slow server tests (had to tell it not to overthink, bit... ironic).
    • Added slow tool result server test test_calc_result (checking models make some use of tool call results, which some struggle a bit with)
    • More corner cases for parsing, esp. DeepSeek R1 & Command R7B

TODOs:

Possible follow ups

  • Document differences between stream & non-stream modes (thought & tool_plan not sent in stream)

  • look at the Llama distill more closely (see Eval bug: trivial grammar crashes (DeepSeek R1 Distill Llama 8B) #11591)

  • Reintroduce forced thinking in generic handler under some --reasoning flag (+ explore @ngxson's idea to support a disabled value that biases thinking tags)

    ```typescript
    // ResponseSchema is json_schema if set, otherwise string
    
    type SchemaToolRequired     =  {thoughts: string} & ToolCallSchema
    type Schema                 = ({thoughts: string} & ToolCallSchema) | {thoughts: string, response: ResponseSchema}
    
    type ToolCallSchema         = SingleToolCallSchema | ParallelToolCallSchema
    type SingleToolCallSchema   = {tool_call: ToolCall}
    type ParallelToolCallSchema = {tool_calls: ToolCall[]} // If parallel_tool_calls is true
    
    // Note: id only defined if parallel_tool_calls is true
    type ToolCall =
          {name: 'tool_1', arguments: Parameters1Schema, id?: string} |
          {name: 'tool_2', arguments: Parameters2Schema, id?: string} |
          ...
    ```
    

@github-actions github-actions bot added testing Everything test related examples python python script changes server labels Feb 3, 2025
@ochafik ochafik mentioned this pull request Feb 9, 2025
@github-actions github-actions bot added the script Script related label Feb 9, 2025
@ochafik
Collaborator Author

ochafik commented Feb 12, 2025

In streaming mode, the output data does not separate 'content' and 'reasoning content' like this:

data: {"choices":[{"finish_reason":null,"index":0,"delta":{"content":"<think>"}}],"created":1739000016,"id":"chatcmpl-QPiD7T4WVir86Qga3YHuhmJ0DO7hNQHK","model":"DeepSeek-R1-UD-IQ1_M","system_fingerprint":"b0-unknown","object":"chat.completion.chunk"}

data: {"choices":[{"finish_reason":null,"index":0,"delta":{"content":"\n"}}],"created":1739000017,"id":"chatcmpl-QPiD7T4WVir86Qga3YHuhmJ0DO7hNQHK","model":"DeepSeek-R1-UD-IQ1_M","system_fingerprint":"b0-unknown","object":"chat.completion.chunk"}

@WangxuP Based on my (limited) understanding of the delta format used by OpenAI (incl. for tool calls), the "correct" way to stream thoughts back would be to hold off on anything that might be an opening <think> tag, then send it as a reasoning_content delta. Hopefully we'll see how OpenAI stream their own thoughts in the near future (I have a few more things to crunch on before implementing streaming anyway).

| nvidia-Llama-3.1-Nemotron-70B-Instruct-HF.jinja | llama 3.x tool calls (w/ builtin tools) |
| openchat-openchat-3.5-0106.jinja | generic tool calls |
| teknium-OpenHermes-2.5-Mistral-7B.jinja | generic tool calls |
| Almawave-Velvet-14B.jinja | Hermes 2 Pro |
Collaborator

Just noting here (no need to take any actions right now), but this README file is now too long and hard to follow for new users. I'm planning to break this into small files (like what we did with docs directory). Potentially we will end up with a main API docs, tool-calling docs and development docs.

Collaborator Author

Sounds great!! happy to help with this (if only reviewing)

// Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag,
// so we accept common variants (then it's all constrained)
builder.add_rule("root",
"( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" ) "
Collaborator

Small nit: if you're doing multiple string concatenations, it's better to use std::ostringstream to reduce the number of copies.

Collaborator Author

Fair point, for now I've been favouring readability but will keep this in mind when doing an optimization pass (depending on how much this all ends up costing, we might want to cache various bits and/or create a grammar DSL that would bypass the string stage altogether; JSON schema conversion has lots of room for optimization & I'd also like to take the llguidance stuff into account: exciting prospects!)

common/chat.cpp Outdated
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool["function"];
Collaborator

Better to use .at() instead of operator[] when it's possible, as explained in https://github.com/nlohmann/json

In function from_json, use function at() to access the object values rather than operator[]. In case a key does not exist, at throws an exception that you can handle, whereas operator[] exhibits undefined behavior

Collaborator Author

Updated, thanks!

@ngxson ngxson requested a review from ggerganov February 12, 2025 22:35
@ochafik ochafik merged commit c7f460a into ggml-org:master Feb 13, 2025
48 checks passed
@ngxson
Collaborator

ngxson commented Feb 13, 2025

That’s next on my list, wanted to get non-streamed logic working well first, then will need to revamp the parsers to accept eg unclosed json list of “parallel” tool calls and stream them back one by one (bit of delta book keeping to do, tool call deltas give updates to the arguments for the current tool call, then move to the next, etc). Medium amount of work but probably gnarly haha.

Yes, I would say this will be a hard approach, especially because each model has its own format, so we can't really rely on regex much in stream mode.

Indeed, I assume that most of the complication will be about moving away from regex and use some kind of "state machine" to keep track of the generated text. From this perspective, I'm wondering, is it worth inventing our own implementation of regex? Regex is just a state machine under the hood, so by doing this we can fully manage the state on our own.

Written as (pseudo-) code, my idea looks like:

struct chat_regex tool_regex;
tool_regex.add_literal("<tool_name>")
tool_regex.add_string()
tool_regex.add_literal("</tool_name><tool_data>")
tool_regex.add_json()
tool_regex.add_literal("</tool_data>")
tool_regex.end()

Which will be compiled into:

flowchart LR

F@{ shape: dbl-circ, label: "end" }

A --<tool_name>--> B
B --string--> C
C --</tool_name><tool_data>--> D
D --json--> E
E --</tool_data>--> F

We create a "state" object each time we want to use it (i.e. store it into the server slot):

slot.chat_parser_state = chat_parser_state(tool_regex); // initial state A
slot.chat_parser_state << slot.generated_text; // with "generated_text" is the "delta" generated content
slot.chat_parser_state.get_matched(); // not sure yet what it should return

@ochafik
Collaborator Author

ochafik commented Feb 13, 2025

From this perspective, I'm wondering, is it worth inventing our own implementation of regex? Regex is just a state machine under the hood, so by doing this we can fully manage the state on our own.

@ngxson Yesss!! 🫡🫡🫡🫡🫡

So, my original dream was to write a recursive descent / backtracking parser based on the existing GBNF grammars, and use a crude naming convention to extract data out of rules:

  • *-tool-calls-N
  • then nested *-tool-call-N
  • then nested *-tool-call-name-N & *-tool-call-arguments-N

(the N is there to allow alternative tool call syntaxes to be extracted).

A bit ad hoc and magic, but very limited modifications needed in the tool call code (just stick to a consistent naming, and delete all the regexp code) and a pretty simple parser to implement (can add some hooks to make it generic wrt/ the naming convention to extract any kind of data).

It would also make it trivial to support partial extractions / streaming by memorizing the parsing stack (& extracted variables) that consumed the longest text (when parsing fails).

(and +1 to keeping the state in the slot, although TBD whether that should be a parsing stack state - first stack that failed because of an EOF? - or just the JSON tree of the last response returned, doing a React-style full-DOM diff at every step; much slower but might be safer, to be investigated)

If we agree to explore this route, I might start by refactoring the grammar parsing code to output an easier intermediate grammar AST that can then be used directly by the recursive descent parser (and be trivially converted to the pointer-heavy sampling grammar structure).

@ochafik
Collaborator Author

ochafik commented Feb 13, 2025

Written as (pseudo-) code, my idea looks like:

struct chat_regex tool_regex;
tool_regex.add_literal("<tool_name>")
tool_regex.add_string()
tool_regex.add_literal("</tool_name><tool_data>")
tool_regex.add_json()
tool_regex.add_literal("</tool_data>")
tool_regex.end()

@ngxson we could also explore this kind of syntax to build a DSL to create the dual-use GBNF grammar (possibly also llguidance grammar)

cc/ @mmoskal @HanClinto

@ngxson
Collaborator

ngxson commented Feb 13, 2025

(and +1 to keeping the state in the slot, although TBD whether that should be a parsing stack state - first stack that failed because of an EOF? - or just the JSON tree of the last response returned, doing a React-style full-DOM diff at every step; much slower but might be safer, to be investigated)

I don't really understand the second part of your phrase about "first stack that failed because of an EOF", but IMO storing the parsing stack is fine. The React-style diff may sound intuitive/safer, but I think nlohmann::json is not performant enough to do that efficiently. I even suspect we may end up with an implementation slower than the JavaScript version used by React.

we could also explore this kind of syntax to build a DSL to create the dual-use GBNF grammar (possibly also llguidance grammar)

I had a quick look at all of the regexes you're currently using in chat.cpp, but I think a DSL is not really needed at the moment because most of your regexes can be expressed in a more intuitive way using my pseudo-code above. Furthermore, the maintenance cost may be high, given that we're only gonna use it internally.

Most of your regex(es) use [\\s\\n\\r], [^something]+, ([\\s\\S\\r\\n]*?), which can be expressed as cpp functions like maybe_space(), match_until(...)

And to make it look even nicer, we can use cpp operator overloading, for example with operator->:

tool_regex -> literal("<tool_name>") -> string() -> literal("</tool_name>");

Another benefit of this approach is that some expressions can also be optimized during compile time.


Edit: on second thought, using -> could be a bad idea because it can be confused with pointer dereference. >> or << would be a better choice. Or maybe just chain calls, tool_regex.literal(...).string().literal(...), for simplicity.

@mmoskal
Collaborator

mmoskal commented Feb 14, 2025

@ochafik right! llguidance already does support streaming and emitting capture groups for subgrammars; it will even know to only emit "foo" when the tokens so far are "foo<", but then emit "<text" when "text" is sampled (and not "tool").

There is also some code in there to support general stop regexes using a (lazy) state machine.

Note that as people develop the tool calling more in models, they are likely to use special tokens for tool calling, JSON mode etc. Not sure gbnf handles that particularly well (that is the difference between "<|foo|>" and "<|" "foo" "|>").

@ggerganov
Member

If we agree to explore this route, I might start by refactoring the grammar parsing code to output an easier intermediate grammar AST that can then be used directly by the recursive descent parser (and be trivially converted to the pointer-heavy sampling grammar structure).

Most of the grammar and tool functionalities are way over my head and I cannot provide very meaningful feedback. But overall I think the llama-grammar module could use some deeper refactoring and maintenance. The main things I would look out for are to keep the implementation simple, no extra dependencies and good performance. The general approach has been to prototype stuff in libcommon and, when it becomes mature enough, to move it to libllama.

One more thought is that long-term we can also think about moving some core functionality about grammars to ggml. At some point I was considering it because I wanted to reuse grammar functionality in whisper.cpp. So it's something to think about, but very low-prio atm.

@mmoskal
Collaborator

mmoskal commented Feb 15, 2025

@ochafik here's how the lazy matching is handled in llguidance, see also docs

import llguidance
import huggingface_hub
import json

lark_grm = """

start: "<tool_name>" name "<tool_data>" data "</tool_data>"
name[capture, suffix="</tool_name>"]: /.*/
data[capture]: %json {
    "properties": {
        "foo": { "type": "string" }
    },
    "required": ["foo"]
}

"""


def main():
    tok_name = huggingface_hub.hf_hub_download(
        "microsoft/Phi-3.5-mini-instruct", "tokenizer.json"
    )
    with open(tok_name, "r") as f:
        text = f.read()
    tok = llguidance.LLTokenizer(text)

    interp = llguidance.LLInterpreter(
        tok,
        json.dumps({"grammars": [{"lark_grammar": lark_grm}]}),
        enable_ff_tokens=False,
        enable_backtrack=False,
        log_level=1,
    )
    interp.start_without_prompt()

    toks = tok.tokenize_str("<tool_name>foo<bar></tool_name><tool_data>{\"foo\": \"bar\"}</tool_data>")
    for t in toks:
        mask, r = interp.compute_mask()
        obj = json.loads(r)
        for p in obj["progress"]:
            if p["object"] != "text":
                print("\n  ", end="")
                print(p)
        # feeding token now
        print(tok.dbg_tokens([t]), end=" ")
        interp.commit_token(t)
    print("\n")

if __name__ == "__main__":
    main()

When you run it, you get:

⟦<⟧ ⟦tool⟧ ⟦_⟧ ⟦name⟧ ⟦>⟧ ⟦foo⟧ ⟦<⟧ ⟦bar⟧ ⟦></⟧ ⟦tool⟧ ⟦_⟧ ⟦name⟧ ⟦><⟧ 
  {'object': 'capture', 'name': 'name', 'str': 'foo<bar>', 'hex': '666f6f3c6261723e', 'log_prob': 0.0}
⟦tool⟧ ⟦_⟧ ⟦data⟧ ⟦>{⟧ ⟦"⟧ ⟦foo⟧ ⟦":⟧ ⟦ "⟧ ⟦bar⟧ ⟦"}⟧ 
  {'object': 'capture', 'name': 'data', 'str': '{"foo": "bar"}', 'hex': '7b22666f6f223a2022626172227d', 'log_prob': 0.0}
⟦</⟧ ⟦tool⟧ ⟦_⟧ ⟦data⟧ ⟦>⟧ 

The captures are generated immediately after getting enough tokens.

If the model uses special tokens, you need to write the grammar slightly differently:

start: <|assistant|> name <|end|> /\s*/ data
name[capture]: /.*/
data[capture]: %json {
    "properties": {
        "foo": { "type": "string" }
    },
    "required": ["foo"]
}

Note lack of suffix= on name - it will extend greedily, until it hits the <|end|> special token. Special tokens are never allowed by regular expressions.
