我想实现一个具有多个 LSTM 门依次堆叠的神经网络。我将隐藏状态设置为 0,如此处所建议。当我尝试运行代码时,我得到了
JaxTransformError: Jax transforms and Flax models cannot be mixed. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.JaxTransformError)
当我尝试用 flax.linen.scan 替换 jax.lax.scan 时,它给出了另一个错误。不太确定如何继续或这里实际出了什么问题。代码附在下面。谢谢!
import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Sequence
class LSTMModel(nn.Module):
lstm_hidden_size: int
num_lstm_layers: int
linear_layer_sizes: Sequence[int]
mean_aggregation: bool
def initialize_carry(self, batch_size, feature_size=1):
"""Initialize carry states with zeros for all LSTM layers."""
return [
# Hidden state (h)
jnp.zeros((batch_size, self.lstm_hidden_size)),
# Cell state (c)
jnp.zeros((batch_size, self.lstm_hidden_size)),
for _ in range(self.num_lstm_layers)
def __call__(self, x, carry=None):
if carry is None:
raise ValueError(
"Carry must be initialized explicitly using `initialize_carry`."
# Expand 2D input to 3D (if necessary)
if x.ndim == 2:
# [batch_size, sequence_length] -> [batch_size, sequence_length, 1]
x = jnp.expand_dims(x, axis=-1)
# Process through LSTM layers
for i in range(self.num_lstm_layers):
lstm_cell = nn.LSTMCell(
features=self.lstm_hidden_size, name=f'lstm_cell_{i}')
def step_fn(carry, xt):
new_carry, yt = lstm_cell(carry, xt)
return new_carry, yt
# Use lax.scan to process the sequence
carry[i], outputs = jax.lax.scan(step_fn, carry[i], x)
x = outputs # Update x for the next layer
# Aggregate outputs
if self.mean_aggregation:
x = jnp.mean(x, axis=1) # Average over the sequence
x = x[:, -1, :] # Use the last output
# Pass through linear layers
for size in self.linear_layer_sizes:
x = nn.Dense(features=size)(x)
x = nn.elu(x)
# Final output layer
x = nn.Dense(features=1)(x)
return x
# Model hyperparameters
lstm_hidden_size = 64
num_lstm_layers = 2
linear_layer_sizes = [32, 16]
mean_aggregation = False
# Initialize model
model = LSTMModel(
# Dummy input: batch of sequences with 10 timesteps
key = jax.random.PRNGKey(0)
# [batch_size, sequence_length, feature_size]
dummy_input = jax.random.normal(key, (32, 10, 1))
# Initialize carry states
carry = model.initialize_carry(
batch_size=dummy_input.shape[0], feature_size=dummy_input.shape[-1])
# Initialize parameters
params = model.init(key, dummy_input, carry)
# Apply the model
outputs = model.apply(params, dummy_input, carry)
# Should print: [batch_size, 1]
print("Model output shape:", outputs.shape)
考虑使用 nn.RNN 来简化代码:
lstm = nn.RNN(
outputs = lstm(x)
RNN 将为您处理进位。如果你真的想自己处理携带,你可以使用
和 initial_carry
lstm = nn.RNN(
carry[i], outputs = lstm(x, initial_carry=carry[i])