我正在从专用网络计算梯度并将其应用于另一个主网络。然后,我将主服务器的权重复制到私有服务器上(听起来有些多余,但请耐心等待)。问题在于,每次迭代get_weights都会变慢,我甚至会用光内存。
def work(self, session):
with session.as_default(), session.graph.as_default():
self.private_net = ACNetwork()
state = self.env.reset()
while counter<TOTAL_TR_STEPS:
action_index, action_vector = self.get_action(state)
next_state, reward, done, info = self.env.step(action_index)
....# store the new data : reward, state etc...
if done == True:
# end of episode
state = self.env.reset()
a_grads, c_grads = self.private_net.get_gradients()
self.master.update_from_gradients(a_grads, c_grads)
self._update_worker_net() #this is the slow one
!!!!!!
这是使用get_weights的函数。
def _update_worker_net(self):
self.private_net.actor_t.set_weights(\
self.master.actor_t.get_weights())
self.private_net.critic.set_weights(\
self.master.critic.get_weights())
return
环顾四周,我发现了一条建议使用]的帖子>
K.clear_session()
while块的末尾(在!!!!!!段),因为以某种方式在图上添加了新节点(?!)。但是那个onle返回了一个错误:
AssertionError: Do not use tf.reset_default_graph() to clear nested graphs. If you need a cleared graph, exit the nesting and create a new graph.
是否有更快的重量转移方法?有没有一种方法可以不添加新节点(如果确实发生了这种情况?)
我正在从专用网络计算梯度并将其应用于另一个主网络。然后,我将主服务器的权重复制到私有服务器上(听起来有些多余,但请耐心等待)。 ...
通常会在向图中动态添加新节点时发生这种情况。情况示例: