这是我上一个问题的后续问题。我正在将参数化量子电路实现为量子神经网络,其中优化循环经过了 JIT 编译(jitted)。代码运行正常、没有任何报错,但我发现执行时间方面有一个非常不寻常的现象。
查看下面的代码:
import pennylane as qml
from pennylane import numpy as np
import jax
from jax import numpy as jnp
import optax
from itertools import combinations
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import log_loss
import matplotlib.pyplot as plt
import matplotlib.colors
import warnings
warnings.filterwarnings("ignore")
np.random.seed(42)
import time


def _scale_to_angles(flat_images):
    """Reshape each flattened image to 8x8 and rescale its pixels by 2*pi/16."""
    return np.array([img.reshape([8, 8]) / 16 * 2 * np.pi for img in flat_images])


# Load the digits data set and keep only the samples labelled 2 or 6.
X_digits, y_digits = load_digits(return_X_y=True)
filter_mask = np.isin(y_digits, [2, 6])
X_digits, y_digits = X_digits[filter_mask], y_digits[filter_mask]

# Hold out 10% of the filtered samples for testing (fixed random_state).
X_train, X_test, y_train, y_test = train_test_split(
    X_digits, y_digits, test_size=0.1, random_state=42
)

# Turn every image into an 8x8 grid of rotation angles.
X_train = _scale_to_angles(X_train)
X_test = _scale_to_angles(X_test)

# Map the labels 2 -> -1 and 6 -> +1.
y_train = (y_train - 4) / 2
y_test = (y_test - 4) / 2
def feature_map(features):
    """Encode a 2-D grid of angles into the circuit's qubits.

    Every wire first receives a Hadamard gate.  Each row of ``features`` is
    then applied as an angle embedding, alternating between X rotations
    (even row index) and Z rotations (odd row index).

    Args:
        features: 2-D array of rotation angles.  The length of its rows sets
            the number of wires used (8 for the 8x8 images in this script).
            Previously the Hadamard layer used the row length while the
            embeddings hard-coded 8 wires; both now use the same width.
    """
    n_wires = len(features[0])
    # Put every qubit into an equal superposition.
    for wire in range(n_wires):
        qml.Hadamard(wire)
    # Embed the rows one after another, switching rotation axis each row.
    for row_idx, row in enumerate(features):
        rotation = "Z" if row_idx % 2 else "X"
        qml.AngleEmbedding(features=row, wires=range(n_wires), rotation=rotation)
# Trainable part of the circuit: two RY layers, each followed by a CNOT ring.
def ansatz(params):
    """Apply the parameterized layers on 8 wires.

    ``params[0:8]`` feed the first RY layer and ``params[8:16]`` the second.
    The first CNOT ring runs forward ([7,0], [0,1], ..., [6,7]); the second
    runs backward ([6,7], [5,6], ..., [0,1], [7,0]).
    """
    n_qubits = 8

    for wire in range(n_qubits):
        qml.RY(params[wire], wires=wire)

    # Forward ring: each wire is the target of its predecessor (cyclic).
    for wire in range(n_qubits):
        qml.CNOT(wires=[(wire - 1) % n_qubits, wire])

    for wire in range(n_qubits):
        qml.RY(params[wire + n_qubits], wires=wire)

    # Backward ring: walk the targets from the last wire down to wire 0.
    for step in range(n_qubits):
        target = n_qubits - 1 - step
        qml.CNOT(wires=[(target - 1) % n_qubits, target])
# PennyLane "default.qubit" device with 8 wires, used by the QNode below.
dev = qml.device("default.qubit", wires=8)


@qml.qnode(dev)
def circuit(params, features):
    """Build the full model circuit and measure it.

    Encodes ``features`` with ``feature_map``, applies the trainable
    ``ansatz`` with ``params`` and returns the expectation value of
    Pauli-Z on wire 0.
    """
    feature_map(features)
    ansatz(params)
    z_on_first_wire = qml.PauliZ(0)
    return qml.expval(z_on_first_wire)
def variational_classifier(weights, bias, x):
    """Return the raw model score for one sample: circuit output plus bias."""
    quantum_output = circuit(weights, x)
    return quantum_output + bias
def square_loss(labels, predictions):
    """Mean squared error between ``labels`` and a sequence of predictions.

    ``qml.math`` dispatches on the array type.  The loss therefore works with
    plain arrays as well as with JAX tracers inside ``jax.jit``/``jax.grad``.
    ``pennylane.numpy`` wraps autograd, not JAX, so using its ``mean`` on
    tracers only worked by accident.

    Args:
        labels: Array of target values.
        predictions: Sequence of scalar predictions, stacked before use.

    Returns:
        Scalar mean of the squared residuals.
    """
    residuals = labels - qml.math.stack(predictions)
    return qml.math.mean(residuals ** 2)
def accuracy(labels, predictions):
    """Return the fraction of samples whose prediction has the label's sign."""
    matches = sum(
        np.sign(label) == np.sign(pred) for label, pred in zip(labels, predictions)
    )
    return matches / len(labels)
def _predict(params, X):
    """Evaluate the variational classifier on every sample of ``X``."""
    weights, bias = params["weights"], params["bias"]
    return [variational_classifier(weights, bias, x) for x in X]


def cost(params, X, Y):
    """Square loss of the classifier on the data set (X, Y)."""
    return square_loss(Y, _predict(params, X))


def acc(params, X, Y):
    """Sign-agreement accuracy of the classifier on the data set (X, Y)."""
    return accuracy(Y, _predict(params, X))
# Small random initial weights (16 = two RY layers of 8 qubits in ansatz)
# and a zero bias.
np.random.seed(0)
weights = 0.01 * np.random.randn(16)
bias = jnp.array(0.0)
params = {"weights": weights, "bias": bias}

opt = optax.adam(0.05)
batch_size = 7
num_batch = X_train.shape[0] // batch_size
# NOTE: optimization_jit re-initialises its own optimizer state; this one is
# not passed to it.
opt_state = opt.init(params)

# Keep only whole batches: reshape() would raise if the number of training
# samples were not a multiple of batch_size, so trailing samples are dropped.
# (161 = 23 * 7 here, so nothing is dropped for the current split.)
n_used = num_batch * batch_size
X_batched = X_train[:n_used].reshape([-1, batch_size, 8, 8])
y_batched = y_train[:n_used].reshape([-1, batch_size])
@jax.jit
def update_step_jit(i, args):
    """One Adam step on one mini-batch; body of a ``jax.lax.fori_loop``.

    Args:
        i: Loop index supplied by ``fori_loop``; used only for logging.
        args: Loop carry ``(params, opt_state, data, targets, X_test, y_test,
            X_train, y_train, batch_no, print_training)``.  ``data`` and
            ``targets`` hold the pre-batched training set; ``batch_no``
            selects the batch, cycling through ``num_batch`` batches.

    Returns:
        The carry with updated ``params``/``opt_state`` and ``batch_no + 1``.
        ``X_train``/``y_train`` are carried through unchanged.
    """
    params, opt_state, data, targets, X_test, y_test, X_train, y_train, batch_no, print_training = args
    batch_idx = batch_no % num_batch
    _data = data[batch_idx]
    _targets = targets[batch_idx]

    train_loss, grads = jax.value_and_grad(cost)(params, _data, _targets)
    updates, opt_state = opt.update(grads, opt_state)
    # Only the value of the test loss is needed.  jax.value_and_grad would also
    # build a backward pass through every test circuit whose gradient is
    # thrown away.  It is evaluated with the parameters *before* this update.
    test_loss = cost(params, X_test, y_test)
    params = optax.apply_updates(params, updates)

    log_every = 1  # print every step; raise to log less often

    def print_fn():
        jax.debug.print("Step: {i}, Train Loss: {train_loss}", i=i, train_loss=train_loss)
        jax.debug.print("Step: {i}, Test Loss: {test_loss}", i=i, test_loss=test_loss)

    jax.lax.cond((jnp.mod(i, log_every) == 0) & print_training, print_fn, lambda: None)
    return (params, opt_state, data, targets, X_test, y_test, X_train, y_train, batch_no + 1, print_training)
def optimization_jit(params, data, targets, X_test, y_test, X_train, y_train, print_training=True, num_steps=1):
    """Run ``num_steps`` optimizer steps inside a single compiled loop.

    Args:
        params: Dict with ``"weights"`` and ``"bias"``.
        data, targets: Pre-batched training inputs and labels.
        X_test, y_test: Test set used for the logged test loss.
        X_train, y_train: Carried through the loop unchanged.
        print_training: Whether ``update_step_jit`` prints the losses.
        num_steps: Number of ``fori_loop`` iterations (default 1, the
            previously hard-coded value).  It is a static argument, so each
            distinct value triggers its own compilation.

    Returns:
        The parameters after the last step.
    """
    opt_state = opt.init(params)
    carry = (params, opt_state, data, targets, X_test, y_test, X_train, y_train, 0, print_training)
    carry = jax.lax.fori_loop(0, num_steps, update_step_jit, carry)
    return carry[0]


# num_steps must be static: it fixes the trip count of the compiled loop.
# XLA may, for example, remove a single-iteration loop entirely.
optimization_jit = jax.jit(optimization_jit, static_argnames=("num_steps",))
# JAX dispatches work asynchronously, so block until the trained parameters
# exist before reading the clock.  Otherwise part of the device work spills
# into the next measurement (note the debug prints appearing after
# "Training Done!").  The first call of a jitted function also includes
# tracing + XLA compilation.  To time compilation separately, call
# optimization_jit.lower(...).compile() before running.
start_time = time.perf_counter()
params = optimization_jit(params, X_batched, y_batched, X_test, y_test, X_train, y_train)
params = jax.block_until_ready(params)
print("Training Done! \nTime taken:", time.perf_counter() - start_time)

start_time = time.perf_counter()
var_train_acc = acc(params, X_train, y_train)
print("Training accuracy: ", var_train_acc)
print("Time taken:", time.perf_counter() - start_time)

start_time = time.perf_counter()
var_test_acc = acc(params, X_test, y_test)
print("Testing accuracy: ", var_test_acc)
print("Time taken:", time.perf_counter() - start_time)
请注意,这里的 `jax.lax.fori_loop` 只运行 1 次。
为了验证结果的可重复性,我运行了 3 次,输出如下:
第一次运行的输出:
Training Done!
Time taken: 66.26599097251892
Step: 0, Train Loss: 1.015419602394104
Step: 0, Test Loss: 1.0022056102752686
Training accuracy: 0.5031055900621118
Time taken: 14.183394193649292
Testing accuracy: 0.5277777777777778
Time taken: 1.552431344985962
第二次运行的输出:
Training Done!
Time taken: 62.8515682220459
Step: 0, Train Loss: 1.015419602394104
Step: 0, Test Loss: 1.0022056102752686
Training accuracy: 0.5031055900621118
Time taken: 13.549866199493408
Testing accuracy: 0.5277777777777778
Time taken: 1.5097148418426514
第三次运行的输出:
Training Done!
Time taken: 63.35235905647278
Step: 0, Train Loss: 1.015419602394104
Step: 0, Test Loss: 1.0022056102752686
Training accuracy: 0.5031055900621118
Time taken: 13.52238941192627
Testing accuracy: 0.5277777777777778
Time taken: 1.5074975490570068
然后,我把 `jax.lax.fori_loop` 改为运行 10 次:
# Changed line inside optimization_jit: run the loop body 10 times instead of once.
(params, _, _, _, _, _, _, _, _, _) = jax.lax.fori_loop(0, 10, update_step_jit, args)
令人惊讶的是,执行时间反而显著减少了,输出为:
第一次运行:
Training Done!
Time taken: 49.8694589138031
Step: 0, Train Loss: 1.015419602394104
Step: 0, Test Loss: 1.0022056102752686
Step: 1, Train Loss: 0.934578537940979
Step: 1, Test Loss: 0.9935969114303589
Step: 2, Train Loss: 0.982826828956604
Step: 2, Test Loss: 1.004722237586975
Step: 3, Train Loss: 0.982965350151062
Step: 3, Test Loss: 1.0281261205673218
Step: 4, Train Loss: 1.1700845956802368
Step: 4, Test Loss: 1.0455362796783447
Step: 5, Train Loss: 1.3356019258499146
Step: 5, Test Loss: 1.0411475896835327
Step: 6, Train Loss: 1.2408322095870972
Step: 6, Test Loss: 1.0204349756240845
Step: 7, Train Loss: 0.7292405366897583
Step: 7, Test Loss: 0.9959328770637512
Step: 8, Train Loss: 1.1697252988815308
Step: 8, Test Loss: 0.9822244644165039
Step: 9, Train Loss: 1.015731692314148
Step: 9, Test Loss: 0.9667297005653381
Training accuracy: 0.5217391304347826
Time taken: 13.903431177139282
Testing accuracy: 0.5555555555555556
Time taken: 1.537736177444458
第二次运行:
Training Done!
Time taken: 56.34339928627014
Step: 0, Train Loss: 1.015419602394104
Step: 0, Test Loss: 1.0022056102752686
Step: 1, Train Loss: 0.934578537940979
Step: 1, Test Loss: 0.9935969114303589
Step: 2, Train Loss: 0.982826828956604
Step: 2, Test Loss: 1.004722237586975
Step: 3, Train Loss: 0.982965350151062
Step: 3, Test Loss: 1.0281261205673218
Step: 4, Train Loss: 1.1700845956802368
Step: 4, Test Loss: 1.0455362796783447
Step: 5, Train Loss: 1.3356019258499146
Step: 5, Test Loss: 1.0411475896835327
Step: 6, Train Loss: 1.2408322095870972
Step: 6, Test Loss: 1.0204349756240845
Step: 7, Train Loss: 0.7292405366897583
Step: 7, Test Loss: 0.9959328770637512
Step: 8, Train Loss: 1.1697252988815308
Step: 8, Test Loss: 0.9822244644165039
Step: 9, Train Loss: 1.015731692314148
Step: 9, Test Loss: 0.9667297005653381
Training accuracy: 0.5217391304347826
Time taken: 13.298640727996826
Testing accuracy: 0.5555555555555556
Time taken: 1.4631397724151611
第三次运行:
Training Done!
Time taken: 53.01019215583801
Step: 0, Train Loss: 1.015419602394104
Step: 0, Test Loss: 1.0022056102752686
Step: 1, Train Loss: 0.934578537940979
Step: 1, Test Loss: 0.9935969114303589
Step: 2, Train Loss: 0.982826828956604
Step: 2, Test Loss: 1.004722237586975
Step: 3, Train Loss: 0.982965350151062
Step: 3, Test Loss: 1.0281261205673218
Step: 4, Train Loss: 1.1700845956802368
Step: 4, Test Loss: 1.0455362796783447
Step: 5, Train Loss: 1.3356019258499146
Step: 5, Test Loss: 1.0411475896835327
Step: 6, Train Loss: 1.2408322095870972
Step: 6, Test Loss: 1.0204349756240845
Step: 7, Train Loss: 0.7292405366897583
Step: 7, Test Loss: 0.9959328770637512
Step: 8, Train Loss: 1.1697252988815308
Step: 8, Test Loss: 0.9822244644165039
Step: 9, Train Loss: 1.015731692314148
Step: 9, Test Loss: 0.9667297005653381
Training accuracy: 0.5217391304347826
Time taken: 13.152780055999756
Testing accuracy: 0.5555555555555556
Time taken: 1.4448845386505127
此外,我想减少日志输出,只在每 5 步计算并记录一次 `test_loss`,于是将代码更新为:
@jax.jit
def update_step_jit(i, args):
    """One Adam step on one mini-batch, logging only every ``log_every`` steps.

    Body of a ``jax.lax.fori_loop``.  The carry layout is ``(params,
    opt_state, data, targets, X_test, y_test, X_train, y_train, batch_no,
    print_training)``.  ``batch_no`` selects the batch, cycling through
    ``num_batch`` batches.

    The test loss is computed inside the logging branch, so only on logged
    steps.  Unlike the every-step variant, it is computed with the parameters
    *after* this step's update.  That is why its "Step 0" test loss equals the
    other variant's "Step 1" value.

    Returns:
        The carry with updated ``params``/``opt_state`` and ``batch_no + 1``.
    """
    params, opt_state, data, targets, X_test, y_test, X_train, y_train, batch_no, print_training = args
    batch_idx = batch_no % num_batch
    _data = data[batch_idx]
    _targets = targets[batch_idx]

    train_loss, grads = jax.value_and_grad(cost)(params, _data, _targets)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    # (Accuracy is not differentiable through np.sign, so wrapping `acc` in
    # jax.value_and_grad is not useful; call it directly if it is needed.)

    log_every = 5

    def print_fn():
        # Only the value is needed.  jax.value_and_grad would also trace a
        # discarded backward pass through every test-set circuit into this
        # branch.
        test_loss = cost(params, X_test, y_test)
        jax.debug.print("Step: {i}, Train Loss: {train_loss}", i=i, train_loss=train_loss)
        jax.debug.print("Step: {i}, Test Loss: {test_loss}", i=i, test_loss=test_loss)

    jax.lax.cond((jnp.mod(i, log_every) == 0) & print_training, print_fn, lambda: None)
    return (params, opt_state, data, targets, X_test, y_test, X_train, y_train, batch_no + 1, print_training)
def optimization_jit(params, data, targets, X_test, y_test, X_train, y_train, print_training=True, num_steps=10):
    """Run ``num_steps`` optimizer steps inside a single compiled loop.

    Args:
        params: Dict with ``"weights"`` and ``"bias"``.
        data, targets: Pre-batched training inputs and labels.
        X_test, y_test: Test set used for the logged test loss.
        X_train, y_train: Carried through the loop unchanged.
        print_training: Whether ``update_step_jit`` prints the losses.
        num_steps: Number of ``fori_loop`` iterations (default 10, the
            previously hard-coded value).  It is a static argument, so each
            distinct value triggers its own compilation.

    Returns:
        The parameters after the last step.
    """
    opt_state = opt.init(params)
    carry = (params, opt_state, data, targets, X_test, y_test, X_train, y_train, 0, print_training)
    carry = jax.lax.fori_loop(0, num_steps, update_step_jit, carry)
    return carry[0]


# num_steps must be static: it fixes the trip count of the compiled loop.
optimization_jit = jax.jit(optimization_jit, static_argnames=("num_steps",))
我原以为减少 `print_fn` 的调用次数会缩短运行时间,但事实并非如此,输出为:
第一次运行:
Training Done!
Time taken: 75.2902774810791
Step: 0, Train Loss: 1.015419602394104
Step: 0, Test Loss: 0.9935969114303589
Step: 5, Train Loss: 1.3356019258499146
Step: 5, Test Loss: 1.0204349756240845
Training accuracy: 0.5217391304347826
Time taken: 13.591582536697388
Testing accuracy: 0.5555555555555556
Time taken: 1.6048238277435303
第二次运行:
Training Done!
Time taken: 86.21267819404602
Step: 0, Train Loss: 1.015419602394104
Step: 0, Test Loss: 0.9935969114303589
Step: 5, Train Loss: 1.3356019258499146
Step: 5, Test Loss: 1.0204349756240845
Training accuracy: 0.5217391304347826
Time taken: 13.666489601135254
Testing accuracy: 0.5555555555555556
Time taken: 1.5537452697753906
第三次运行:
Training Done!
Time taken: 90.7916328907013
Step: 0, Train Loss: 1.015419602394104
Step: 0, Test Loss: 0.9935969114303589
Step: 5, Train Loss: 1.3356019258499146
Step: 5, Test Loss: 1.0204349756240845
Training accuracy: 0.5217391304347826
Time taken: 13.21641230583191
Testing accuracy: 0.5555555555555556
Time taken: 1.5349321365356445
不同设置下的运行时间对比如下图所示(图略):
我的问题是:为什么设置 1(`fori_loop` 只运行 1 次)比设置 2(运行 10 次)耗时更长?以及为什么设置 3(每 5 次迭代才调用一次 `print_fn` 函数)始终比设置 2(每次迭代都调用 `print_fn`)花费更多时间?编译器的行为真是令人费解!
我怀疑这里的区别在于:对于长度为 1 的 `fori_loop`,编译器会把其中的 `scan`(循环)完全优化掉;例如:
$ print(jax.jit(lambda x: jax.lax.fori_loop(0, 1, lambda i, x: x * 2, x)).lower(1.0).compile().as_text())
HloModule jit__lambda_, is_scheduled=true, entry_computation_layout={(f32[])->f32[]}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}
ENTRY %main.25 (Arg_0.1: f32[]) -> f32[] {
%Arg_0.1 = f32[] parameter(0), metadata={op_name="x"}
%constant.3 = f32[] constant(2)
ROOT %multiply.1 = f32[] multiply(f32[] %Arg_0.1, f32[] %constant.3), metadata={op_name="jit(<lambda>)/jit(main)/while/body/mul" source_file="<ipython-input-10-f67509edfb4c>" source_line=1}
}
但对于迭代次数不止一次的(非平凡的)for 循环,`scan` 并不会被优化掉:
$ print(jax.jit(lambda x: jax.lax.fori_loop(0, 10, lambda i, x: x * 2, x)).lower(1.0).compile().as_text())
HloModule jit__lambda_, is_scheduled=true, entry_computation_layout={(f32[])->f32[]}, allow_spmd_sharding_propagation_to_parameters={true}, allow_spmd_sharding_propagation_to_output={true}
%region_0.8 (arg_tuple.9: (s32[], f32[])) -> (s32[], f32[]) {
%constant.12 = s32[] constant(1)
%arg_tuple.9 = (s32[], f32[]) parameter(0)
%get-tuple-element.2 = s32[] get-tuple-element((s32[], f32[]) %arg_tuple.9), index=0
%add.14 = s32[] add(s32[] %get-tuple-element.2, s32[] %constant.12), metadata={op_name="jit(<lambda>)/jit(main)/while/body/add" source_file="<ipython-input-11-d89e1a4ad053>" source_line=1}
%constant.0 = f32[] constant(2)
%get-tuple-element.3 = f32[] get-tuple-element((s32[], f32[]) %arg_tuple.9), index=1
%multiply.0 = f32[] multiply(f32[] %get-tuple-element.3, f32[] %constant.0), metadata={op_name="jit(<lambda>)/jit(main)/while/body/mul" source_file="<ipython-input-11-d89e1a4ad053>" source_line=1}
ROOT %tuple.2 = (s32[], f32[]) tuple(s32[] %add.14, f32[] %multiply.0)
}
%region_1.16 (arg_tuple.17: (s32[], f32[])) -> pred[] {
%constant.20 = s32[] constant(10)
%arg_tuple.17 = (s32[], f32[]) parameter(0)
%get-tuple-element.18 = s32[] get-tuple-element((s32[], f32[]) %arg_tuple.17), index=0
ROOT %compare.21 = pred[] compare(s32[] %get-tuple-element.18, s32[] %constant.20), direction=LT, metadata={op_name="jit(<lambda>)/jit(main)/while/cond/lt" source_file="<ipython-input-11-d89e1a4ad053>" source_line=1}
}
ENTRY %main.25 (Arg_0.1: f32[]) -> f32[] {
%Arg_0.1 = f32[] parameter(0), metadata={op_name="x"}
%copy.6 = f32[] copy(f32[] %Arg_0.1)
%constant.2 = s32[] constant(0)
%copy.7 = s32[] copy(s32[] %constant.2)
%tuple = (s32[], f32[]) tuple(s32[] %copy.7, f32[] %copy.6)
%while.22 = (s32[], f32[]) while((s32[], f32[]) %tuple), condition=%region_1.16, body=%region_0.8, metadata={op_name="jit(<lambda>)/jit(main)/while" source_file="<ipython-input-11-d89e1a4ad053>" source_line=1}, backend_config={"known_trip_count":{"n":"10"}}
ROOT %get-tuple-element.24 = f32[] get-tuple-element((s32[], f32[]) %while.22), index=1, metadata={op_name="jit(<lambda>)/jit(main)/while" source_file="<ipython-input-11-d89e1a4ad053>" source_line=1}
}
这样做的结果是,即使对于这里的简单函数,编译器在第二种情况下会产生一些融合,而在第一种情况下则不会;在你的情况下,这些融合可能会导致更快的执行。
完美的编译器永远不会做出这种导致执行变慢的决定,但没有任何编译器是完美的。如果你愿意,可以在 https://github.com/openxla/xla 报告这个问题,不过在此之前,你可能需要先整理出一个更精简的最小复现示例。