Circuit tracing for chain-of-thought
Background and motivation
Significant progress has been made in tracing the “thoughts” of an LLM [1]. Circuit-tracing techniques make it possible to identify, step by step, the computations a model performs when it predicts its next token. Recently, `circuit-tracer`, a state-of-the-art library for circuit tracing with sparse transcoders, was open-sourced [2].
While interpreting next-token prediction is an important first step, in real applications models generate long responses that require a sequence of computations. In a long response it may not be obvious which tokens to circuit-trace, because it may not be clear which tokens represent the “core” of the computation that made the response useful. This is especially true for reasoning models: one motivation for reasoning models is that complex reasoning needs to be distributed across many forward passes (and therefore many generated tokens) rather than done in a single forward pass. Empirically, reasoning models also generate many “filler” tokens, backtrack from logical dead-ends, and so on.
The figure below [3] shows an example of step-by-step reasoning. An important computation was performed in generating the highlighted “8” token, but it was not the only important one: the generation of the preceding “0.” tokens is arguably equally important. In general, any multi-step computation will, by definition, involve multiple important tokens.
Moreover, it is known that chain-of-thought reasoning is not always faithful [4, 5]. Circuit analysis of an entire chain-of-thought is therefore necessary: we cannot know ahead of time which tokens were generated unfaithfully, so we must examine all of them [6]. We cannot claim to have solved circuit tracing for LLMs until we can produce a causal graph describing the important circuits used to generate an entire chain-of-thought.
Proposed method
One simple step in this direction is to concatenate the attribution graphs from the generation of each token in the response, and to link the per-token causal graphs by treating the output token sampled at step $t$ as the source of the input token at step $t+1$. Writing $o_t$ for the output-logit node of the token sampled at step $t$ and $i_{t+1}$ for the corresponding input-token node at step $t+1$, we set
\[A(o_t \to i_{t+1}) = 1 \neq 0,\]
where $A$ is the adjacency matrix of the causal graph and $A(u \to v)$ denotes the entry of $A$ giving the weight of the edge from source node $u$ to target node $v$ [7].
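Because the pruning step normalizes each row of the adjacency matrix (rows index target nodes), and the link entry is the only input to $i_{t+1}$, any nonzero link value yields the same normalized graph. A minimal pure-Python sketch of that claim, assuming a dense list-of-lists matrix indexed `A[target][source]` whose rows of absolute values are scaled to sum to 1 (both are assumptions of this sketch, not a description of `circuit-tracer` internals; the toy graph is hypothetical):
```python
def row_normalize(adj):
    """Scale each row of |A| to sum to 1; rows with no incoming edges stay all-zero."""
    normalized = []
    for row in adj:
        total = sum(abs(w) for w in row)
        normalized.append([abs(w) / total if total > 0 else 0.0 for w in row])
    return normalized


def linked_toy_graph(link_weight):
    """Hypothetical 3-node graph indexed A[target][source]:
    0 = output logit o_t, 1 = input token i_{t+1}, 2 = a feature at step t+1."""
    return [
        [0.0, 0.0, 0.0],          # o_t: incoming feature edges omitted for brevity
        [link_weight, 0.0, 0.0],  # i_{t+1}: its only input is the link from o_t
        [0.0, 2.5, 0.0],          # a feature driven by i_{t+1}
    ]


# The link weight drops out after normalization, so 1 is as good as any nonzero value.
print(row_normalize(linked_toy_graph(1.0)) == row_normalize(linked_toy_graph(7.3)))  # True
```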
This construction is conceptually natural and is most powerful when combined with a pruning algorithm [1] that removes all but the most important nodes of the resulting CoT graph. If the token generated at step $t$ is not important for subsequent token generation, it will naturally be pruned as an input token, and so will all the nodes responsible for generating the associated output token. Thus, for a CoT that contains 100 tokens but only 5 critical computational steps, we would hope to end up with a graph containing only the circuits used in those 5 steps, without having to identify those steps manually.
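To make the pruning argument concrete, here is a toy sketch using a simplified influence score in the spirit of the pruning algorithm in [1]: rows of $|A|$ are normalized to give $\hat{A}$, and influence on the weighted target logits is accumulated over all paths, $w^\top(\hat{A} + \hat{A}^2 + \dots)$. This is an illustrative reimplementation on a hypothetical six-node graph with made-up weights, not `circuit-tracer`'s `prune_graph`.
```python
def row_normalize(adj):
    """Scale each row of |A| (the incoming edges of one target) to sum to 1."""
    out = []
    for row in adj:
        total = sum(abs(w) for w in row)
        out.append([abs(w) / total if total > 0 else 0.0 for w in row])
    return out


def node_influence(adj, logit_weights):
    """Sum of w A + w A^2 + ... over the row-normalized graph (A[target][source]).

    Each pass pushes influence one edge further upstream; the loop ends because an
    acyclic graph on n nodes has no path longer than n - 1 edges.
    """
    n = len(adj)
    norm = row_normalize(adj)
    current = list(logit_weights)
    total = [0.0] * n
    for _ in range(n):
        current = [sum(current[tgt] * norm[tgt][src] for tgt in range(n)) for src in range(n)]
        if not any(current):
            break
        total = [a + b for a, b in zip(total, current)]
    return total


# Hypothetical toy CoT: feature f_a produces the logit o_t sampled at step t, which is
# linked to input token i_next at step t+1; feature f_b at step t+1 drives the final logit.
F_A, O_T, I_NEXT, F_B, O_FINAL, I_PROMPT = range(6)


def toy_cot_graph(next_token_used):
    adj = [[0.0] * 6 for _ in range(6)]
    adj[F_A][I_PROMPT] = 1.0    # prompt token -> feature at step t
    adj[O_T][F_A] = 1.0         # feature -> sampled logit o_t
    adj[I_NEXT][O_T] = 1.0      # link edge A(o_t -> i_{t+1}) = 1
    adj[F_B][I_PROMPT] = 1.0    # prompt token -> feature at step t+1
    if next_token_used:
        adj[F_B][I_NEXT] = 1.0  # the generated token actually feeds step t+1
    adj[O_FINAL][F_B] = 1.0     # feature -> final logit
    return adj


weights = [0.0] * 6
weights[O_FINAL] = 1.0          # only the final logit is a pruning target

for used in (True, False):
    print(used, node_influence(toy_cot_graph(used), weights)[F_A])
# Expected output:
# True 0.5
# False 0.0
```
When the generated token is unused, $f_a$ (and the logit it produced) receives zero influence and would be dropped by any positive node threshold, which is exactly the behavior the construction relies on.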
Test Prompts
While developing this method, it helps to have test prompts that tease out different behaviors. A useful combination is (i) a math problem that the model can only answer correctly by working step-by-step, and (ii) a hard math problem that the model cannot answer correctly even when working step-by-step, but where it is allowed to “cheat” by using a hint [4].
Prompt (i) is a case where CoT is required and faithful:
- If asked to answer quickly, the model should reliably give the wrong answer.
- If asked to answer step-by-step, the model should reliably give the correct answer.
- If asked to answer step-by-step with an incorrect hint, the model should ignore the hint and still answer correctly.
Prompt (ii) is a case where the model appears to use CoT but the CoT is unfaithful:
- If asked to answer quickly, the model should reliably give the wrong answer.
- If asked to answer step-by-step, the model should reliably give the wrong answer, with an attribution graph that looks like guessing, while the CoT itself looks “plausible”.
- If asked to answer step-by-step with a hint, it should state the hinted answer, and the causal graph should indicate that it is using the hint.
- If asked to answer step-by-step with an incorrect hint, it should give the wrong answer, and the causal graph should again indicate that it is using the hint.
In this work, I used the following Prompt (i), as in [3], which satisfies the properties above:
```python
prompt = """<bos><start_of_turn>user
What is floor(4*(sqrt(0.64)))?<end_of_turn>
<start_of_turn>model
"""
```
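The answer-quickly, step-by-step, and hinted conditions listed above differ only in the instructions added to the user turn, so they can all be built from the same chat template. A minimal sketch; the instruction and hint wordings, and the names `make_prompt` and `INSTRUCTIONS`, are hypothetical and not the exact phrasings used in these experiments:
```python
import math

QUESTION = "What is floor(4*(sqrt(0.64)))?"
GROUND_TRUTH = math.floor(4 * math.sqrt(0.64))  # used to grade completions as right/wrong

# Hypothetical instruction suffixes for each condition (None = bare question).
INSTRUCTIONS = {
    None: "",
    "quick": " Answer immediately with only the final number.",
    "step_by_step": " Think step by step.",
}


def make_prompt(question, mode=None, hint=None):
    """Wrap a question in the Gemma-2 chat template used for Prompt (i)."""
    user_turn = question + INSTRUCTIONS[mode]
    if hint is not None:
        user_turn += f" Hint: the answer is {hint}."  # hypothetical hint wording
    return f"<bos><start_of_turn>user\n{user_turn}<end_of_turn>\n<start_of_turn>model\n"


# With no mode and no hint, make_prompt reproduces the prompt string above exactly.
quick = make_prompt(QUESTION, "quick")
misleading = make_prompt(QUESTION, "step_by_step", hint=4)  # an incorrect hint
```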
Completions must be generated by an instruct/chat version of the model, and ideally by a “thinking”/“reasoning” version. This introduces a source of error, since the transcoders I used were trained on the base model. Future work could quantify the errors this introduces, but I found the choice conceptually well motivated, and it has precedent in the literature [8, 9]. I used the instruct version (gemma-2-2b-it) for all model generations and feature attributions in this note.
A sample completion is shown below:
```python
response = """Here's how to solve this problem:
**1. Calculate the square root:**
√0.64 = 0.8
**2. Multiply by 4:**
4 * 0.8 = 3.2
**3. Take the floor:**
Floor(3.2) = 3
**Therefore, floor(4*(sqrt(0.64))) = 3**"""
```
I leave investigation of a Prompt (ii) as described above for future work.
Implementation
I developed the following algorithm to generate CoT attribution graphs:
- Iterate over the length of the response. At each rollout step `t`, generate an attribution graph for the generation of a single token:
  - Form an input sequence given by `input = prompt + response[:t+1]`.
  - Run the `attribute` function from the `circuit-tracer` library.
  - Remove error nodes to make the adjacency matrix smaller and de-clutter the resulting graph [10].
  - Run the pruning algorithm using the `prune_graph` function from `circuit-tracer`.
  - Remove all output logits except the one sampled during generation, i.e. the one corresponding to the final input token at step `t+1`.
- Concatenate the per-step adjacency matrices into a single adjacency matrix for the attribution graph of the entire CoT, separately concatenating the sub-matrices corresponding to features, input tokens, and output logits.
- Add a connection from the (single) remaining output logit at rollout step `t` to the final input token at rollout step `t+1`, to “connect” the individual graphs.
- Consolidate repeated input tokens by aggregating (summing) their contributions across repeats [11].
- Prune the CoT graph using the `prune_graph` function from `circuit-tracer`, which removes extraneous features not associated with critical computations.
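The concatenation, linking, and consolidation steps are the new part of this procedure. The sketch below shows one way to perform them on a sparse edge dictionary, assuming the per-step graphs have already been attributed, stripped of all but the sampled logit, and pruned. The graph format and the names `tag` and `build_cot_graph` are hypothetical; this is not the code behind the results here, and it does not call `circuit-tracer`.
```python
from collections import defaultdict


def tag(node, step):
    """Attach the rollout step to feature/logit nodes; token nodes are keyed by position only."""
    return node if node[0] == "token" else (node[0], step) + node[1:]


def build_cot_graph(step_graphs):
    """Concatenate per-step attribution graphs into one CoT graph.

    Each step graph is a dict in a hypothetical format (not circuit-tracer's Graph class):
      "edges":     list of (target, source, weight) with nodes such as ("feature", id),
                   ("error", id), ("token", position) or ("logit",) for the kept logit
      "logit_pos": position of the token that this step's logit predicted
    Returns a sparse adjacency {(target, source): weight}.
    """
    combined = defaultdict(float)
    for t, graph in enumerate(step_graphs):
        for target, source, weight in graph["edges"]:
            if "error" in (target[0], source[0]):
                continue  # drop error nodes
            # Features and logits stay separate per step (block-diagonal concatenation).
            # Every copy of an input token maps to one ("token", pos) node, so repeated
            # inputs are consolidated; edges between identical node pairs are summed.
            combined[(tag(target, t), tag(source, t))] += weight
    for t, graph in enumerate(step_graphs[:-1]):
        # Link the logit sampled at step t to the input token it became at step t+1.
        combined[(("token", graph["logit_pos"]), ("logit", t))] = 1.0
    return dict(combined)


# Two-step toy example with made-up weights.
steps = [
    {"edges": [(("feature", "a"), ("token", 0), 0.7),
               (("logit",), ("feature", "a"), 1.2),
               (("logit",), ("error", 0), 0.1)],
     "logit_pos": 1},
    {"edges": [(("feature", "b"), ("token", 0), 0.3),
               (("feature", "b"), ("token", 1), 0.9),
               (("logit",), ("feature", "b"), 0.8)],
     "logit_pos": 2},
]
cot = build_cot_graph(steps)
print(cot[(("token", 1), ("logit", 0))])  # 1.0: the link edge between the two steps
```
Converting this dictionary to a dense matrix (features, then input tokens, then logits, matching the concatenation order described above) and running the final pruning pass are omitted here.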
The figure below schematically illustrates this process. The top row shows schematics and an example of individual-token adjacency matrices, while the bottom row shows an example of the combined CoT matrix (with aggressive pruning, leaving only about 60 features).
See Appendix A for a detailed list of hyperparameters which seemed to work well.
Results
The figure below shows the full CoT attribution graph for Prompt (i) and the completion shown above. Each generated token is represented as an output, but only the final 10 tokens are used as explicit targets of the pruning algorithm. Every other generated token is kept only if it turns out to be relevant for generating the final 10 tokens; otherwise it is pruned. (This is why features are heavily concentrated on the last 10 input tokens, with a relatively small set of features associated with earlier input tokens.)
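One way to express “only the final 10 tokens are explicit targets” is through the weights the pruning step places on the output logits. A minimal sketch, assuming one kept logit per rollout step and uniform weights on the last ten; the weighting actually used is not specified above, and `final_token_logit_weights` is a hypothetical helper:
```python
def final_token_logit_weights(num_steps, keep_last=10):
    """Pruning-target weights over the CoT's logits, one logit per rollout step.

    Only the last `keep_last` logits are explicit targets (uniform weights, an assumption);
    earlier logits get weight 0, so their circuits survive pruning only if they influence
    the final tokens through the linked input tokens.
    """
    start = max(0, num_steps - keep_last)
    n_targets = num_steps - start
    return [1.0 / n_targets if t >= start else 0.0 for t in range(num_steps)]


weights = final_token_logit_weights(100)
print(sum(w > 0 for w in weights))  # 10
```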
I found it challenging to trace the circuits and interpret the features corresponding to specific mathematical steps, but I believe this is due to the limitations of the small Gemma model. For example, when inspecting the attribution graph for the next-token prediction on the sequence ending with `Floor(3.2) =`, I found plenty of “say 3” features but none that seemed to be associated with a floor operation. I do not believe this is a fundamental problem with the methodology, which assumes that feature identification and circuit tracing work properly at each single-token generation step. I believe that, applied to a more powerful model such as Haiku 3.5, this method would work (subject to the limitations below).
Challenges and opportunities for future work
Conceptual
This approach uses the transcoder-based circuit-tracing methodology of [1] and inherits its drawbacks. In particular, it does not attribute why attention patterns are formed, and there are still errors from imperfect reconstruction by the transcoders.
Scaling
With current circuit-tracing methods as implemented in the `circuit-tracer` library [2], this method is slow. For example, a single token-generation step for a 2B-parameter model can take on the order of 10 minutes to trace and prune on a single A100. The smallest useful CoTs are around 100 tokens (e.g. Prompt (i) above), and more realistic examples have hundreds or thousands of tokens. This poor scaling limits the practical use of this method.
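As a rough estimate, assuming the roughly 10 minutes per step quoted above stays constant (it likely grows with context length, since every step re-attributes the full prefix), the end-to-end cost on a single A100 for an $N$-token CoT is:
\[
T(N) \approx N \times 10\ \text{min}, \qquad T(100) \approx 1000\ \text{min} \approx 17\ \text{h}, \qquad T(1000) \approx 10^{4}\ \text{min} \approx 7\ \text{days}.
\]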
However, it seems likely that any circuit-tracing method for a CoT will require a circuit-tracing step for each token generation, since any forward pass could be unfaithful [6]. Given this, it seems necessary to direct future effort toward improving the efficiency of sequential circuit tracing, analogously to how KV-caching makes rollout generation efficient. With such an improvement, this method could scale in the same way as rollout generation and become more practical.
Limitations of the experiments and analysis
There are two main limitations of this work. First, the small size of gemma-2-2b-it limits the interpretability of the extracted features and circuits, as well as the complexity of the prompts I could work with. Second, the transcoders I used were trained on the base pre-trained gemma-2-2b rather than the instruct version. As noted in [8, 9], this may be acceptable but is not ideal.
Appendix A: Hyperparameters
Model: gemma-2-2b-it
Generation of individual graphs for single-token prediction:
- `max_feature_nodes`: 50k.
- `desired_logit_prob`: 0.95 (anything above the `top_p=0.9` used in generating the response should be sufficient, since the sampled token should then lie among the attributed logits).
- graph pruning, `node_threshold`: 0.4 to 0.6 was a reasonable range, which left about 50 to 2.5k features per graph after pruning (from 50k initially).
- graph pruning, `edge_threshold`: 1; individual-graph pruning serves purely to reduce the size of the adjacency matrix, so there was no reason to prune edges.
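For convenience, the per-step settings above can be kept in a single config. A sketch only: the grouping into dictionaries is hypothetical, the 0.5 node threshold is an illustrative midpoint of the stated range, and passing these as keyword arguments to `circuit-tracer`'s `attribute` and `prune_graph` is an assumption.
```python
MODEL_NAME = "gemma-2-2b-it"
TOP_P = 0.9  # nucleus-sampling threshold used when generating the response

ATTRIBUTION_CONFIG = {
    "max_feature_nodes": 50_000,
    "desired_logit_prob": 0.95,  # should exceed TOP_P so the sampled token is attributed
}

PER_STEP_PRUNING_CONFIG = {
    "node_threshold": 0.5,  # 0.4-0.6 worked; 0.5 is an illustrative midpoint
    "edge_threshold": 1.0,  # no edge pruning: this stage only shrinks the matrix
}

# Sanity checks that mirror the constraints stated in the list above.
assert ATTRIBUTION_CONFIG["desired_logit_prob"] > TOP_P
assert 0.4 <= PER_STEP_PRUNING_CONFIG["node_threshold"] <= 0.6
```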
Notes / References
- https://transformer-circuits.pub/2025/attribution-graphs/methods.html
- https://transformer-circuits.pub/2025/attribution-graphs/biology.html#dives-cot
- https://arxiv.org/abs/2503.08679
- In principle, a misaligned model could be using meaning hidden in seemingly benign tokens (e.g. punctuation) to communicate with itself or other models.
- The actual numerical value of $A(o_t \to i_{t+1})$ does not matter as long as it is nonzero, because the pruning algorithm normalizes each row (target node) to unit norm, and this entry is the only input to that node. I therefore set the value to 1 for simplicity.
- https://www.lesswrong.com/posts/fmwk6qxrpW8d4jvbd/saes-usually-transfer-between-base-and-chat-models
- Since the goal was to generate causal attribution graphs for chain of thought, rather than to assess the accuracy or faithfulness of the graphs, I opted to remove the error nodes, which would otherwise make the graph much more cluttered. Also, I was using transcoders trained on the base version of the model while using the instruct version for generation and attribution, so there were more error nodes than would normally be present.
- Without consolidation, each token of the prompt would be repeated `N = len(response)` times, for example. Consolidation ensures that each input token appears only once in the graph while preserving its influence on downstream tokens. It also has the effect of making $A_{io} = n$, where $n$ is the number of times input $i$ was used downstream in the CoT and $o$ is the output logit that led to input $i$ at the next time step.