JAX and Flax don't play well together


I want to implement a neural network with multiple LSTM layers stacked on top of each other. I set the hidden states to zeros, as suggested here. When I try to run the code, I get

JaxTransformError: Jax transforms and Flax models cannot be mixed. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.JaxTransformError)

When I try to replace jax.lax.scan with flax.linen.scan, it raises a different error. I'm not sure how to proceed or what is actually going wrong here. The code is attached below. Thanks!

import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Sequence


class LSTMModel(nn.Module):
    lstm_hidden_size: int
    num_lstm_layers: int
    linear_layer_sizes: Sequence[int]
    mean_aggregation: bool

    def initialize_carry(self, batch_size, feature_size=1):
        """Initialize carry states with zeros for all LSTM layers."""
        return [
            (
                # Hidden state (h)
                jnp.zeros((batch_size, self.lstm_hidden_size)),
                # Cell state (c)
                jnp.zeros((batch_size, self.lstm_hidden_size)),
            )
            for _ in range(self.num_lstm_layers)
        ]

    @nn.compact
    def __call__(self, x, carry=None):
        if carry is None:
            raise ValueError(
                "Carry must be initialized explicitly using `initialize_carry`."
            )

        # Expand 2D input to 3D (if necessary)
        if x.ndim == 2:
            # [batch_size, sequence_length] -> [batch_size, sequence_length, 1]
            x = jnp.expand_dims(x, axis=-1)

        # Process through LSTM layers
        for i in range(self.num_lstm_layers):
            lstm_cell = nn.LSTMCell(
                features=self.lstm_hidden_size, name=f'lstm_cell_{i}')

            def step_fn(carry, xt):
                new_carry, yt = lstm_cell(carry, xt)
                return new_carry, yt

            # Use lax.scan to process the sequence
            carry[i], outputs = jax.lax.scan(step_fn, carry[i], x)
            x = outputs  # Update x for the next layer

        # Aggregate outputs
        if self.mean_aggregation:
            x = jnp.mean(x, axis=1)  # Average over the sequence
        else:
            x = x[:, -1, :]  # Use the last output

        # Pass through linear layers
        for size in self.linear_layer_sizes:
            x = nn.Dense(features=size)(x)
            x = nn.elu(x)

        # Final output layer
        x = nn.Dense(features=1)(x)
        return x


# Model hyperparameters
lstm_hidden_size = 64
num_lstm_layers = 2
linear_layer_sizes = [32, 16]
mean_aggregation = False

# Initialize model
model = LSTMModel(
    lstm_hidden_size=lstm_hidden_size,
    num_lstm_layers=num_lstm_layers,
    linear_layer_sizes=linear_layer_sizes,
    mean_aggregation=mean_aggregation
)

# Dummy input: batch of sequences with 10 timesteps
key = jax.random.PRNGKey(0)
# [batch_size, sequence_length, feature_size]
dummy_input = jax.random.normal(key, (32, 10, 1))

# Initialize carry states
carry = model.initialize_carry(
    batch_size=dummy_input.shape[0], feature_size=dummy_input.shape[-1])

# Initialize parameters
params = model.init(key, dummy_input, carry)

# Apply the model
outputs = model.apply(params, dummy_input, carry)

# Should print: [batch_size, 1]
print("Model output shape:", outputs.shape)
Tags: lstm, recurrent-neural-network, jit, jax, flax
1 Answer

Consider using nn.RNN to simplify the code:

lstm = nn.RNN(
  nn.LSTMCell(features=self.lstm_hidden_size),
  name=f'lstm_cell_{i}'
)
outputs = lstm(x)
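
Under the hood, nn.RNN wraps the cell in the lifted nn.scan transform and builds the initial carry from the cell's initialize_carry, so the JAX-transform/Flax-module mixing that triggers the error never happens.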

nn.RNN handles the carry for you. If you really want to manage the carry yourself, you can use return_carry and initial_carry:

lstm = nn.RNN(
  nn.LSTMCell(features=self.lstm_hidden_size),
  return_carry=True, 
  name=f'lstm_cell_{i}'
)
carry[i], outputs = lstm(x, initial_carry=carry[i])
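
For completeness, here is a sketch of the whole model rewritten around nn.RNN, with the manual carry bookkeeping dropped. This is an untested rewrite against the same flax.linen API used above, not code from the original question:

class LSTMModel(nn.Module):
    lstm_hidden_size: int
    num_lstm_layers: int
    linear_layer_sizes: Sequence[int]
    mean_aggregation: bool

    @nn.compact
    def __call__(self, x):
        if x.ndim == 2:
            x = jnp.expand_dims(x, axis=-1)  # [batch, seq] -> [batch, seq, 1]

        # Each nn.RNN initializes its own carry and scans over the time axis
        for i in range(self.num_lstm_layers):
            x = nn.RNN(nn.LSTMCell(features=self.lstm_hidden_size),
                       name=f'lstm_cell_{i}')(x)

        # Aggregate over the sequence
        x = jnp.mean(x, axis=1) if self.mean_aggregation else x[:, -1, :]

        for size in self.linear_layer_sizes:
            x = nn.elu(nn.Dense(features=size)(x))
        return nn.Dense(features=1)(x)


model = LSTMModel(lstm_hidden_size=64, num_lstm_layers=2,
                  linear_layer_sizes=[32, 16], mean_aggregation=False)
params = model.init(key, dummy_input)       # no explicit carry argument
outputs = model.apply(params, dummy_input)  # shape: [batch_size, 1]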