[数据集API RecursionError:超过最大递归深度

问题描述 投票:0回答:1

我正在尝试使用tensorflow数据集API和地图功能构建可扩展的最小最大缩放器。

首先,我遍历我的数据集以找到所有特征(3)的最小值和最大值,然后我想使用map函数将最小值/最大值缩放器应用于该数据集。

这是我的简单代码。

import numpy as np
import tensorflow as tf

b = np.array([[1, 2, 3], [4, 5, 6], [7,8,9],[10,11,12]])
b_ds = tf.data.Dataset.from_tensor_slices(b).batch(2)

my_iterator = b_ds.make_one_shot_iterator()

def compute_min_max(i, my_min, my_max):
    new_batch = my_iterator.get_next()
    my_min = tf.minimum(my_min,tf.reduce_min(new_batch, axis=0))
    my_max = tf.maximum(my_max,tf.reduce_max(new_batch, axis=0))
    return [i+1, my_min, my_max]

i = tf.constant(0)
feat_min = tf.Variable([10,10,10],dtype=tf.int64)
feat_max = tf.Variable([0,0,0],dtype=tf.int64)

c = lambda i, min, max: i < 2
b = lambda i, min, max: compute_min_max(i, min, max)
res_i, res_min, res_max = tf.while_loop(c, b, loop_vars=[i, feat_min, feat_max])

def min_max_ds(feat):
    return tf.cast(feat-res_min,dtype=tf.float64)/tf.cast(res_max-res_min, dtype=tf.float64)

minmax_scaled_ds = b_ds.map(min_max_ds)

scaled_batch = minmax_scaled_ds.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    init=tf.global_variables_initializer()
    sess.run(init)
    print(sess.run((res_min, res_max, scaled_batch)))

执行此代码时,我得到一个

[RecursionError:超过最大递归深度

我的猜测是,min_max_ds函数大约每批调用tf.while_loop语句,但是我无法弄清楚如何冻结res_min和res_max,因此它们在min_max_ds函数中用作常量。

python-3.x tensorflow tensorflow-datasets
1个回答
0
投票

也许您可以使用以下方式设置递归深度的上限:

sys.setrecursionlimit(10000)

python3.X中的默认值为1000。可以使用更大的值。

© www.soinside.com 2019 - 2024. All rights reserved.