我不熟悉lua。但本文的作者使用了lua。
您能帮我理解这两行是做什么的:
做什么replicate(x,batch_size)
吗?
x = x:resize(x:size(1), 1):expand(x:size(1), batch_size)
做什么?
原始源代码可以在这里找到https://github.com/wojzaremba/lstm/blob/master/data.lua
这基本上可以归结为简单的数学运算,并且可以在割炬手册中查找一些功能。
好吧我很无聊...
replicate(x,batch_size)
中定义的https://github.com/wojzaremba/lstm/blob/master/data.lua
-- Stacks replicated, shifted versions of x_inp
-- into a single matrix of size x_inp:size(1) x batch_size.
local function replicate(x_inp, batch_size)
local s = x_inp:size(1)
local x = torch.zeros(torch.floor(s / batch_size), batch_size)
for i = 1, batch_size do
local start = torch.round((i - 1) * s / batch_size) + 1
local finish = start + x:size(1) - 1
x:sub(1, x:size(1), i, i):copy(x_inp:sub(start, finish))
end
return x
end
此代码正在使用Torch框架。
[x_inp:size(1)
返回Torch tensor(可能是多维矩阵)x_inp
的尺寸1的大小。
请参见https://cornebise.com/torch-doc-template/tensor.html#toc_18
因此x_inp:size(1)
给出了x_inp
中的行数。 x_inp:size(2)
,将为您提供列数...
local x = torch.zeros(torch.floor(s / batch_size), batch_size)
创建一个新的用零填充的二维张量,并为其创建本地引用,命名为x
行数由s
,x_inp
的行数和batch_size
计算得出。因此,对于您的示例输入,结果为floor(11/2) = floor(5.5) = 5
。
您的示例中的列数为2,因为batch_size
为2。
火炬
所以简单地说x
是5x2矩阵
0 0
0 0
0 0
0 0
0 0
以下几行将x_inp
的内容复制到x
。
for i = 1, batch_size do
local start = torch.round((i - 1) * s / batch_size) + 1
local finish = start + x:size(1) - 1
x:sub(1, x:size(1), i, i):copy(x_inp:sub(start, finish))
end
[在第一次运行中,start
的值为1,finish
的值为5
,因为x:size(1)
当然是x
的行数,即5
。1+5-1=5
在第二轮中,start
计算为6,finish
计算为10
因此x_inp
的前5行(您的第一批)被复制到x
的第一列中,第二批将被复制到x
的第二列中
x:sub(1, x:size(1), i, i)
是x
的子张量,在行1到5,列1到1,在第二行行1-5,列2到2(在您的示例中)。因此,只不过是x
请参见https://cornebise.com/torch-doc-template/tensor.html#toc_42
:copy(x_inp:sub(start, finish))
将元素从x_inp
复制到x
的列中。
因此,总而言之,您将输入张量并分成若干批,然后将其存储在张量中,每批中都有一列。
因此x_inp
0
1
2
3
4
5
6
7
8
9
10
和batch_size = 2
x
是
0 5
1 6
2 7
3 8
4 9
进一步:
local function testdataset(batch_size)
local x = load_data(ptb_path .. "ptb.test.txt")
x = x:resize(x:size(1), 1):expand(x:size(1), batch_size)
return x
end
是另一个从文件加载一些数据的函数。此x
与上面的x
无关,只是它们都是张量。
让我们使用一个简单的示例:
x
被
1
2
3
4
和batch_size = 4
x = x:resize(x:size(1), 1):expand(x:size(1), batch_size)
[第一个x
将被调整为4x1,读取为https://cornebise.com/torch-doc-template/tensor.html#toc_36
然后通过将第一行复制3次将其扩展为4x4。
导致x
为张量
1 1 1 1
2 2 2 2
3 3 3 3
4 4 4 4
读取https://cornebise.com/torch-doc-template/tensor.html#toc_49