我有一个在内部循环中频繁调用的函数。看起来像这样:
import qualified Data.Vector.Storable as SV
newtype Timedelta = Timedelta Double
cklsLogDens :: SV.Vector Double -> Timedelta -> Double -> Double -> Double
cklsLogDens p (Timedelta dt) x0 x1 = if si <= 0 then -1e50 else c - 0.5*((x1-mu)/sd)^2
where
al = p `SV.unsafeIndex` 0
be = p `SV.unsafeIndex` 1
si = p `SV.unsafeIndex` 2
xi = p `SV.unsafeIndex` 3
sdt = sqrt dt
mu = x0 + (al + be*x0)*dt
sd = si * (x0 ** xi) * sdt
c = sd `seq` -0.5 * log (2*pi*sd^2)
(使用 Data.Vector.Storable 是因为该函数稍后需要处理来自 C 函数的数据)
GHC 对此进行了很好的优化(据我所知,所有变量和操作都是原语),但看看核心,有一个
let
仍然位于函数体(以前的)内部。我读过here(以及我不记得的其他地方),“让”分配惰性重击,因此可能不利于紧密循环中的性能。我可以摆脱它吗?如果可能的话,我不想将我的函数转换为 20 个 case 语句,但如果这要求太多,我会接受。
这是核心:
$wloop_s4Li [Occ=LoopBreaker]
:: GHC.Prim.Double#
-> GHC.Prim.Int# -> GHC.Prim.Int# -> GHC.Prim.Double#
[LclId, Arity=3, Str=DmdType LLL]
$wloop_s4Li =
\ (ww_X4OR :: GHC.Prim.Double#)
(ww1_X4OW :: GHC.Prim.Int#)
(ww2_X4P1 :: GHC.Prim.Int#) ->
case GHC.Prim.<# ww1_X4OW ww2_X4P1 of _ {
GHC.Types.False -> ww_X4OR;
GHC.Types.True ->
case GHC.Prim.<=## x_a4tg 0.0 of _ {
GHC.Types.False ->
case GHC.Prim.indexDoubleArray#
rb2_a4rT (GHC.Prim.+# rb_a4rR (GHC.Prim.-# ww1_X4OW 1))
of wild17_X4xM { __DEFAULT ->
let {
---- ^^^^ want to get rid off this!
----
----
ipv1_X2S8 [Dmd=Just L] :: GHC.Prim.Double#
[LclId, Str=DmdType]
ipv1_X2S8 =
GHC.Prim.*##
(GHC.Prim.*## x_a4tg (GHC.Prim.**## wild17_X4xM y_a3BN))
(GHC.Prim.sqrtDouble# tpl1_B3) } in
case GHC.Prim.logDouble#
(GHC.Prim.*##
6.283185307179586 (GHC.Prim.*## ipv1_X2S8 ipv1_X2S8))
of wild18_X3Gn { __DEFAULT ->
case GHC.Prim.indexDoubleArray#
rb2_a4rT (GHC.Prim.+# rb_a4rR ww1_X4OW)
of wild19_X4AY { __DEFAULT ->
case GHC.Prim./##
(GHC.Prim.-##
wild19_X4AY
(GHC.Prim.+##
wild17_X4xM
(GHC.Prim.*##
(GHC.Prim.+##
x1_X3GA (GHC.Prim.*## x2_X3cb wild17_X4xM))
tpl1_B3)))
ipv1_X2S8
of wild20_X3x8 { __DEFAULT ->
$wloop_s4Li
(GHC.Prim.+##
ww_X4OR
(GHC.Prim.-##
(GHC.Prim.negateDouble# (GHC.Prim.*## 0.5 wild18_X3Gn))
(GHC.Prim.*##
0.5 (GHC.Prim.*## wild20_X3x8 wild20_X3x8))))
(GHC.Prim.+# ww1_X4OW 1)
ww2_X4P1
}
}
}
};
GHC.Types.True ->
$wloop_s4Li
(GHC.Prim.+## ww_X4OR -1.0e50)
(GHC.Prim.+# ww1_X4OW 1)
ww2_X4P1
}
}; }
(是的,当然,既然你一定会问,我在过早优化上花了太多时间......)
这是带有 NOINLINE 的当前版本
import qualified Data.Vector.Storable as SV
newtype Timedelta = Timedelta Double
cklsLogDens :: SV.Vector Double -> Timedelta -> Double -> Double -> Double
{-# NOINLINE cklsLogDens #-}
cklsLogDens p (Timedelta dt) x0 x1 = si `seq` (if si <= 0 then -1e50 else (sd `seq` (c - 0.5*((x1-mu)/sd)^2)))
where
al = p `SV.unsafeIndex` 0
be = p `SV.unsafeIndex` 1
si = p `SV.unsafeIndex` 2
xi = p `SV.unsafeIndex` 3
sdt = sqrt dt
mu = x0 + (al + be*x0)*dt
sd = si * (x0 ** xi) * sdt
c = sd `seq` (-0.5 * log (2*pi*sd^2))
main = putStrLn . show $ cklsLogDens SV.empty (Timedelta 0.1) 0.1 0.15
对应的核心片段:
Main.cklsLogDens [InlPrag=NOINLINE]
:: Data.Vector.Storable.Vector GHC.Types.Double
-> Main.Timedelta
-> GHC.Types.Double
-> GHC.Types.Double
-> GHC.Types.Double
[GblId, Arity=4, Caf=NoCafRefs, Str=DmdType U(ALL)LLL]
Main.cklsLogDens =
\ (p_atw :: Data.Vector.Storable.Vector GHC.Types.Double)
(ds_dVa :: Main.Timedelta)
(x0_aty :: GHC.Types.Double)
(x1_atz :: GHC.Types.Double) ->
case p_atw
of _ { Data.Vector.Storable.Vector rb_a2ml rb1_a2mm rb2_a2mn ->
case GHC.Prim.readDoubleOffAddr#
@ GHC.Prim.RealWorld rb1_a2mm 2 GHC.Prim.realWorld#
of _ { (# s2_a2nH, x_a2nI #) ->
case GHC.Prim.touch#
@ GHC.ForeignPtr.ForeignPtrContents rb2_a2mn s2_a2nH
of _ { __DEFAULT ->
case GHC.Prim.<=## x_a2nI 0.0 of _ {
GHC.Types.False ->
case x0_aty of _ { GHC.Types.D# x2_a13d ->
case GHC.Prim.readDoubleOffAddr#
@ GHC.Prim.RealWorld rb1_a2mm 3 GHC.Prim.realWorld#
of _ { (# s1_X2oB, x3_X2oD #) ->
case GHC.Prim.touch#
@ GHC.ForeignPtr.ForeignPtrContents rb2_a2mn s1_X2oB
of _ { __DEFAULT ->
case ds_dVa
`cast` (Main.NTCo:Timedelta :: Main.Timedelta ~# GHC.Types.Double)
of _ { GHC.Types.D# x4_a13m ->
let {
--- ^^^^ want to get rid of this!
---
ipv_sYP [Dmd=Just L] :: GHC.Prim.Double#
[LclId, Str=DmdType]
ipv_sYP =
GHC.Prim.*##
(GHC.Prim.*## x_a2nI (GHC.Prim.**## x2_a13d x3_X2oD))
(GHC.Prim.sqrtDouble# x4_a13m) } in
case x1_atz of _ { GHC.Types.D# x5_X14E ->
case GHC.Prim.readDoubleOffAddr#
@ GHC.Prim.RealWorld rb1_a2mm 0 GHC.Prim.realWorld#
of _ { (# s3_X2p2, x6_X2p4 #) ->
case GHC.Prim.touch#
@ GHC.ForeignPtr.ForeignPtrContents rb2_a2mn s3_X2p2
of _ { __DEFAULT ->
case GHC.Prim.readDoubleOffAddr#
@ GHC.Prim.RealWorld rb1_a2mm 1 GHC.Prim.realWorld#
of _ { (# s4_X2pi, x7_X2pk #) ->
case GHC.Prim.touch#
@ GHC.ForeignPtr.ForeignPtrContents rb2_a2mn s4_X2pi
of _ { __DEFAULT ->
case GHC.Prim.logDouble#
(GHC.Prim.*## 6.283185307179586 (GHC.Prim.*## ipv_sYP ipv_sYP))
of wild9_a13D { __DEFAULT ->
case GHC.Prim./##
(GHC.Prim.-##
x5_X14E
(GHC.Prim.+##
x2_a13d
(GHC.Prim.*##
(GHC.Prim.+## x6_X2p4 (GHC.Prim.*## x7_X2pk x2_a13d)) x4_a13m)))
ipv_sYP
of wild10_a13O { __DEFAULT ->
GHC.Types.D#
(GHC.Prim.-##
(GHC.Prim.negateDouble# (GHC.Prim.*## 0.5 wild9_a13D))
(GHC.Prim.*## 0.5 (GHC.Prim.*## wild10_a13O wild10_a13O)))
}
}
}
}
}
}
}
}
}
}
};
GHC.Types.True -> lvl_r2v7
}
}
}
}
Daniel 是对的 - 事实上,所讨论的
let
并没有分配 thunk。这实际上是不可能的,因为像 Double#
这样的原始类型没有堆表示。在所谓的核心准备阶段,这些 let
实际上在转换为 STG 之前先转换为 case
表达式(这是“let
= 分配”规则实际成立的地方)。请参阅 CorePrep.lhs 中对此主题的评论。
这又是准备之前的核心内容(
-ddump-simpl
):
let {
ipv_sPL [Dmd=Just L] :: GHC.Prim.Double#
ipv_sPL =
GHC.Prim.*##
(GHC.Prim.*## x_a160 (GHC.Prim.**## x1_a11G x2_X17h))
(GHC.Prim.sqrtDouble# x3_a11P) } in [...]
这是之后(
-ddump-prep
):
case GHC.Prim.sqrtDouble# x3_s1aU of sat_s1cB { __DEFAULT ->
case GHC.Prim.**## x1_s1aQ x2_s1aR of sat_s1cC { __DEFAULT ->
case GHC.Prim.*## x_s1aC sat_s1cC of sat_s1cD { __DEFAULT ->
case GHC.Prim.*## sat_s1cD sat_s1cB of ipv_s1aW [Dmd=Just L] { __DEFAULT ->
所以实际上没有任何堆分配。
另一方面,请注意,核心准备工作还显式地将每个应用程序包装到
let
或 case
语句中,从而生成相当冗长的代码。这就是为什么 -ddump-simpl
可能被认为是查看 Core 的默认值,尽管它的性能模型实际上稍微令人惊讶。
使用 ghc-7.6.1,我发现
-O
和 -O2
之间没有区别,任何 seq
或爆炸模式也没有区别。let
保留在核心中。
但我怀疑
let
是否真的有害,它绑定一个原始值,而不是装箱值,并且该值此后在三个地方使用。此外,在生成的汇编中,我找不到任何懒惰的thunk的迹象(但由于我对汇编的了解相当有限,所以不要将此视为福音)。
我可以通过引入案例分支来摆脱
let
,
cklsLogDens p (Timedelta dt) x0 x1
= case p `SV.unsafeIndex` 2 of
si | si <= 0 -> -1e50
| otherwise ->
let al = p `SV.unsafeIndex` 0
be = p `SV.unsafeIndex` 1
xi = p `SV.unsafeIndex` 3
sdt = sqrt dt
mu = x0 + (al + be*x0)*dt
in case si*(x0**xi)*sdt of
0 -> 0
sd -> -0.5*log (2*pi*sd^2) - 0.5*((x1-mu)/sd)^2
仅在核心中产生
case
。由于 sd
永远不应该为 0,在循环中,即使是平庸的分支预测器也应该使该分支本质上自由。
但是,我怀疑这是否真的会提高性能。与 0 的比较需要一个寄存器,原始产生的汇编需要更少的间接寻址,并且可以在需要时在寄存器中保留更多值。