nnsight

Interpretable Neural Networks

NNsight (/ɛn.saɪt/) is a package for interpreting and manipulating the internals of models.


Wrap Any PyTorch Model

The NNsight class wraps any PyTorch model and enables tracing of its internals. LanguageModel is the equivalent wrapper for Hugging Face language models such as GPT-2.

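import torch
from collections import OrderedDict

from nnsight import NNsight, LanguageModel

# Example sizes for a small two-layer network
input_size, hidden_dims, output_size = 5, 10, 2

net = torch.nn.Sequential(OrderedDict([
    ('layer1', torch.nn.Linear(input_size, hidden_dims)),
    ('layer2', torch.nn.Linear(hidden_dims, output_size)),
]))

# Wrap an arbitrary PyTorch module
model = NNsight(net)

# or load and wrap a Hugging Face language model
transformer = LanguageModel('openai-community/gpt2')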

Access Hidden States

Expose the inputs and outputs of any module during a forward pass, and save them for use after the trace exits.

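# The toy network takes a tensor, not text: trace it on a random input
with model.trace(torch.rand(1, input_size)):

    hidden_state_output = model.layer1.output.save()
    hidden_state_input = model.layer2.input.save()

    output = model.output.save()

# Saved values are available once the trace block exits
print(hidden_state_output)
print(hidden_state_input)
print(output)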
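For comparison, the same two values can be captured in plain PyTorch with forward hooks. This sketch is not how nnsight is implemented; it rebuilds the toy `Sequential` network with illustrative sizes (5 → 10 → 2) and uses only `register_forward_hook` and `register_forward_pre_hook` from `torch.nn.Module`. Because `layer1` feeds directly into `layer2`, the saved output of the first layer should equal the saved input of the second.

```python
from collections import OrderedDict

import torch

# Illustrative sizes only
input_size, hidden_dims, output_size = 5, 10, 2

net = torch.nn.Sequential(OrderedDict([
    ('layer1', torch.nn.Linear(input_size, hidden_dims)),
    ('layer2', torch.nn.Linear(hidden_dims, output_size)),
]))

saved = {}

def save_output(module, args, output):
    # Forward hook: runs after layer1 has computed its output
    saved['layer1_output'] = output.detach().clone()

def save_input(module, args):
    # Forward pre-hook: runs before layer2 and receives its positional args
    saved['layer2_input'] = args[0].detach().clone()

h1 = net.layer1.register_forward_hook(save_output)
h2 = net.layer2.register_forward_pre_hook(save_input)

with torch.no_grad():
    output = net(torch.rand(1, input_size))

# Remove the hooks so later forward passes are unaffected
h1.remove()
h2.remove()

print(saved['layer1_output'].shape)  # torch.Size([1, 10])
print(torch.equal(saved['layer1_output'], saved['layer2_input']))  # True
```

nnsight hides this hook bookkeeping behind `.output`, `.input`, and `.save()`, so nothing has to be registered or removed by hand.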

Develop Complex Interventions

Edit module outputs during a forward pass and measure the effect on the model's predictions.

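# Use the GPT-2 wrapper (`transformer`); its submodules mirror the
# Hugging Face model, e.g. `.transformer.h` and `.lm_head`
with transformer.trace() as tracer:

    with tracer.invoke('The Eiffel Tower is in the city of'):

        # Ablate: zero the output of the last layer's MLP for this prompt
        transformer.transformer.h[-1].mlp.output[0][:] = 0

        intervention = transformer.lm_head.output.argmax(dim=-1).save()

    with tracer.invoke('The Eiffel Tower is in the city of'):

        # Same prompt with no intervention, as a baseline
        original = transformer.lm_head.output.argmax(dim=-1).save()

# Predicted token IDs at each position, with and without the ablation
print(intervention)
print(original)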
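Both invokes above run as a single batch, with the ablation applied only to the first prompt. A rough plain-PyTorch analogue, again using forward hooks on the toy network rather than nnsight, and with illustrative sizes, zeroes `layer1`'s output for the first row of a two-row batch of identical inputs. A forward hook that returns a tensor replaces the module's output, so the ablated row's final output collapses to `layer2`'s bias while the clean row matches an unmodified forward pass.

```python
from collections import OrderedDict

import torch

torch.manual_seed(0)

# Illustrative sizes only
input_size, hidden_dims, output_size = 5, 10, 2

net = torch.nn.Sequential(OrderedDict([
    ('layer1', torch.nn.Linear(input_size, hidden_dims)),
    ('layer2', torch.nn.Linear(hidden_dims, output_size)),
]))

def ablate_first_row(module, args, output):
    # Returning a tensor from a forward hook replaces the module's output
    patched = output.clone()
    patched[0] = 0  # zero the hidden state of the first batch element only
    return patched

x = torch.rand(1, input_size)
batch = torch.cat([x, x])  # row 0: intervention, row 1: clean baseline

handle = net.layer1.register_forward_hook(ablate_first_row)
with torch.no_grad():
    out = net(batch)
handle.remove()

intervention, original = out[0], out[1]

# Reference pass with no hook installed
with torch.no_grad():
    clean = net(x)[0]

print(torch.allclose(intervention, net.layer2.bias))  # True: Linear(0) == bias
print(torch.allclose(original, clean))  # True: clean row unaffected
```

Keeping the baseline in the same batch, as the two invokes do, means the intervened and clean runs share one forward pass and differ only in the ablated row.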