运行类函数时变量被覆盖

问题描述 投票:0回答:1

我有以下 MWE:

import numpy as np

class PlainLR:

    def __init__(self) -> None:
        
        rng = np.random.default_rng(42)
        init_range = 1.0 / np.sqrt(float(10))
        self.weights = rng.uniform(low=-init_range, high=init_range, size=10)
        self.bias = rng.uniform(low=-init_range, high=init_range, size=1)

        self.plaintext_weights = self.weights
        self.plaintext_bias = self.bias
    
    def get_model_parameters(self):
        return self.plaintext_weights,self.plaintext_bias
    
    def train(self):
        self.weights -= 3* (1 / 5) + self.weights * 10
        self.bias -= 4 * (1 / 5)

plaintextLR = PlainLR()
plaintext_weights, plaintext_bias = plaintextLR.get_model_parameters() # Grab the initial values here
print(plaintext_weights, plaintext_bias)
plaintextLR.train()
print(plaintext_weights, plaintext_bias) # I was expecting here the same output as in the previous print line

有人可以向我解释为什么代码末尾的两条打印行会产生不同的输出吗?我希望当我听到这条线时:

plaintext_weights, plaintext_bias = plaintextLR.get_model_parameters() # Grab model parameters before training the data to supply the exact same parameters to the encrypted LR model for comparison

我将初始化随机值(第7-10行)存储到变量

plaintext_weights, plaintext_bias
中,但似乎通过随后执行该类的函数
train
,它会覆盖我之前存储的初始值。有人可以向我解释为什么以及如何实现仅将初始值存储到变量中吗?

python class variables
1个回答
0
投票

使用

self.weights.copy
self.bias.copy
代替
self.weights
self.bias

此外,您根本不需要定义

self.plaintext_weights
self.plaintext_bias
,因为您只需返回
self.weights.copy
self.bias.copy
即可。

你的代码看起来像这样:

import numpy as np

class PlainLR:

    def __init__(self) -> None:
        
        rng = np.random.default_rng(42)
        init_range = 1.0 / np.sqrt(float(10))
        self.weights = rng.uniform(low=-init_range, high=init_range, size=10)
        self.bias = rng.uniform(low=-init_range, high=init_range, size=1)
    
    def get_model_parameters(self):
        return self.weights.copy(),self.bias.copy()
    
    def train(self):
        self.weights -= 3* (1 / 5) + self.weights * 10
        self.bias -= 4 * (1 / 5)

plaintextLR = PlainLR()
plaintext_weights, plaintext_bias = plaintextLR.get_model_parameters()
print(plaintext_weights, plaintext_bias)
plaintextLR.train()
print(plaintext_weights, plaintext_bias)
© www.soinside.com 2019 - 2024. All rights reserved.