Weight tying #584

Open · wants to merge 1 commit into main
57 changes: 43 additions & 14 deletions lib/axon/compiler.ex
@@ -316,6 +316,10 @@ defmodule Axon.Compiler do
end)
end

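# Shared parameters carry no dtype of their own, so they skip the type
# check below: an incoming SharedParameter always wins, and a
# SharedParameter template is replaced by the concrete incoming value.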
defp merge_type(_, %Axon.ModelState.SharedParameter{}, value), do: value

defp merge_type(_, _, %Axon.ModelState.SharedParameter{} = shared), do: shared

defp merge_type(key, template, value) do
if Nx.type(template) != Nx.type(value) do
Logger.warning(
@@ -1061,20 +1065,7 @@
# freezing and dtype policy
parameter_inputs =
Enum.map(layer_params, fn %{name: v, frozen: frz} ->
param = params[name][v]

cond do
param != nil ->
safe_policy_cast(maybe_freeze(param, frz), policy, :compute)

true ->
raise ArgumentError,
"parameter #{inspect(v)} for layer: #{inspect(name)} in" <>
" was not present in the given parameter map, this can" <>
" happen if you are using parameters intended for another" <>
" model or did not initialize portions of your model with" <>
" Axon.init/3"
end
resolve_parameter!(params, name, v, frz, policy)
end)

# Reorder the inputs according to the original input ordering
@@ -1291,6 +1282,44 @@
initializer.(shape, type, keys[layer_id][name])
end

defp resolve_parameter!(params, layer_name, param_name, freeze?, policy) do
layer_params =
case params[layer_name] do
nil ->
raise ArgumentError, "layer #{inspect(layer_name)} does not exist in the model state"

%Axon.ModelState.SharedParameter{path: path} ->
get_in(params, path)

map ->
map
end

parameter =
case layer_params[param_name] do
nil ->
raise ArgumentError,
"parameter #{inspect(param_name)} for layer: #{inspect(layer_name)}" <>
" was not present in the given parameter map, this can" <>
" happen if you are using parameters intended for another" <>
" model or did not initialize portions of your model with" <>
" Axon.init/3"

%Axon.ModelState.SharedParameter{path: path} ->
with nil <- get_in(params, path) do
raise ArgumentError,
"shared parameter for #{inspect(param_name)} in layer:" <>
" #{inspect(layer_name)}, references non-existent parameter" <>
" #{inspect(path)}"
end

parameter ->
parameter
end

safe_policy_cast(maybe_freeze(parameter, freeze?), policy, :compute)
end

defp maybe_freeze(param, true), do: Nx.Defn.Kernel.stop_grad(param)
defp maybe_freeze(param, false), do: param

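To make the new lookup concrete, here is a minimal sketch of how a `SharedParameter` path is followed, using plain maps and placeholder atoms instead of real Nx tensors (the `resolve` closure below is illustrative, not part of this PR):

params = %{
  "dense" => %{"kernel" => :dense_kernel},
  "embed" => %{"kernel" => %Axon.ModelState.SharedParameter{path: ["dense", "kernel"]}}
}

# Follow the stored access path back to the source parameter
resolve = fn params, layer, param ->
  case params[layer][param] do
    %Axon.ModelState.SharedParameter{path: path} -> get_in(params, path)
    value -> value
  end
end

resolve.(params, "embed", "kernel")
#=> :dense_kernel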
34 changes: 34 additions & 0 deletions lib/axon/model_state.ex
@@ -196,6 +196,33 @@ defmodule Axon.ModelState do
}
end

@doc """
Ties parameters in the model state together.

Tied parameters are given as a destination path and a source
path. For example, if you want the kernel of an embedding
layer to use the kernel of a dense layer as its source, you
would do:

Axon.ModelState.tie(model_state, ["embedding", "kernel"], ["dense", "kernel"])

You can tie individual parameters or entire layers together:

Axon.ModelState.tie(model_state, ["embedding"], ["dense"])
"""
def tie(model_state, destination, source) do
update_in(model_state, [Access.key!(:data)], fn data ->
shared = Axon.ModelState.SharedParameter.new(source)
[key | rest] = Enum.reverse(destination)

shared = Enum.reduce(rest, %{key => shared}, fn next, acc ->
%{next => acc}
end)

tree_merge(shared, data, fn _, lhs, _ -> lhs end)
end)
end

# Helpers

defp get_paths(map) do
@@ -269,6 +296,10 @@ defmodule Axon.ModelState do
nil ->
Map.put(acc, key, val_lhs)

%Axon.ModelState.SharedParameter{} = val_rhs ->
new_val = fun.(key, val_lhs, val_rhs)
Map.put(acc, key, new_val)

%Nx.Tensor{} = val_rhs ->
new_val = fun.(key, val_lhs, val_rhs)
Map.put(acc, key, new_val)
@@ -321,6 +352,9 @@
{_, %Nx.Tensor{} = tensor}, {count, size} ->
{count + Nx.size(tensor), size + Nx.byte_size(tensor)}

{_, %Axon.ModelState.SharedParameter{}}, {count, size} ->
{count, size}

{_, map}, {count, size} ->
{inner_count, inner_size} = get_param_info(map)
{count + inner_count, size + inner_size}
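For intuition about what `tie/3` writes into the state, here is the intermediate structure it builds for the paths used in the tests below (a sketch; only the tied entry is shown):

shared = Axon.ModelState.SharedParameter.new(["dense", "kernel"])
# Enum.reverse(["embed", "kernel"]) yields ["kernel", "embed"], so the
# reduce nests the reference along the destination path:
nested = %{"embed" => %{"kernel" => shared}}
# tree_merge(nested, data, fn _, lhs, _ -> lhs end) then overwrites only
# data["embed"]["kernel"], leaving every other parameter untouched.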
19 changes: 19 additions & 0 deletions lib/axon/model_state/shared_parameter.ex
@@ -0,0 +1,19 @@
defmodule Axon.ModelState.SharedParameter do
@moduledoc false

# Represents a tied or shared parameter for layers whose
# weights are connected but don't necessarily perform the
# same operation. This derives the Nx.Container protocol
# and contains an access path to the parameter that holds
# the original weight.

@derive {
Nx.Container,
keep: [:path], containers: []
}
defstruct [:path]

def new(path) do
%__MODULE__{path: path}
end
end
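One note on the `@derive`: with `containers: []` and `keep: [:path]`, Nx traversals treat the struct as a tensor-free leaf and carry the path through as static metadata. A small sketch, assuming the derived implementation reconstructs the struct from its kept fields:

shared = Axon.ModelState.SharedParameter.new(["dense", "kernel"])
# No tensor leaves to visit, so the function is never called and the
# struct round-trips unchanged:
{^shared, 0} = Nx.Container.traverse(shared, 0, fn t, acc -> {t, acc + 1} end)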
55 changes: 53 additions & 2 deletions test/axon/compiler_test.exs
@@ -5465,15 +5465,24 @@
end

describe "edge cases" do
test "raises clean error on missing parameter" do
test "raises clean error on missing layer" do
model = Axon.input("input", shape: {nil, 1}) |> Axon.dense(2)
input = Nx.tensor([[1.0]])

assert_raise ArgumentError, ~r/parameter "kernel" for layer:/, fn ->
assert_raise ArgumentError, ~r/layer \"dense_0\" does not exist/, fn ->
Axon.predict(model, ModelState.empty(), input)
end
end

test "raises clean error on missing parameter" do
model = Axon.input("input", shape: {nil, 1}) |> Axon.dense(2)
input = Nx.tensor([[1.0]])

assert_raise ArgumentError, ~r/parameter \"kernel\" for layer:/, fn ->
Axon.predict(model, ModelState.new(%{"dense_0" => %{}}), input)
end
end

test "initializes a non-linear model" do
x = Axon.input("input_0", shape: {nil, 1}) |> Axon.dense(2, name: "dense_0")
y = Axon.input("input_1", shape: {nil, 1}) |> Axon.dense(2, name: "dense_1")
@@ -5718,7 +5727,7 @@
end

describe "inspect values" do
test "prints intermediate layer values to the screen" do

Check failure on line 5730 in test/axon/compiler_test.exs: GitHub Actions / main (25.3.2.6, 1.14.5, USE_EXLA=true): test inspect values prints intermediate layer values to the screen (CompilerTest)
model =
Axon.input("x")
|> Axon.dense(10, name: "foo")
@@ -5739,4 +5748,46 @@
assert out =~ "bar:"
end
end

describe "weight tying" do
test "initializes with shared parameters" do
model =
Axon.input("x")
|> Axon.embedding(32, 32, name: "embed")
|> Axon.dense(32, name: "dense")

init_state =
ModelState.empty()
|> ModelState.tie(["embed", "kernel"], ["dense", "kernel"])

{init_fn, _} = Axon.build(model)
input = Nx.template({1, 4}, :u32)
assert %Axon.ModelState{
         data: %{"embed" => %{"kernel" => %Axon.ModelState.SharedParameter{}}}
       } = init_fn.(input, init_state)
end

test "performs inference with weights tied after initialization" do

Check failure on line 5768 in test/axon/compiler_test.exs: GitHub Actions / main (25.3.2.6, 1.14.5), also with USE_EXLA=true and USE_TORCHX=true: test weight tying performs inference with weights tied after initialization (CompilerTest)
model =
Axon.input("x")
|> Axon.embedding(32, 32, name: "embed")
|> Axon.dense(32, name: "dense")

{init_fn, predict_fn} = Axon.build(model)

%Axon.ModelState{data: %{"dense" => %{"kernel" => k, "bias" => b}}} =
model_state = init_fn.(Nx.template({1, 4}, :u32), ModelState.empty())

model_state =
Axon.ModelState.tie(model_state, ["embed", "kernel"], ["dense", "kernel"])

input = Nx.tensor([[0, 1, 2, 3]])

actual_predict_fn = fn input, kernel, bias ->
input
|> Axon.Layers.embedding(kernel)
|> Axon.Layers.dense(kernel, bias)
end

assert_equal(actual_predict_fn.(input, k, b), predict_fn.(model_state, input))
end
end
end
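A consequence worth flagging for review: because the path is resolved at call time, updating the source parameter also updates everything tied to it. A hypothetical sketch continuing from the second test above:

# Overwrite the source kernel; the tied embedding kernel follows
# automatically on the next call, since ["embed", "kernel"] is
# resolved against the current state.
zeros = Nx.broadcast(0.0, {32, 32})
model_state = put_in(model_state.data["dense"]["kernel"], zeros)
predict_fn.(model_state, input)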