我正在尝试在 scala/breeze 中实现以下 python 代码:
import numpy as np
mat = np.random.normal(size=(2, 5))
print(mat)
indexes = np.random.choice(5, replace = False, size = 3)
print(indexes)
mat[0, [indexes]] = 0
print(mat)
# the output:
# ./rowsliceOverwrite.py
[[ 0.30389599 0.84549682 -0.38408994 -1.11550844 -0.28496995]
[-1.55260273 -0.41368681 -0.40455289 0.13054527 -1.43541557]]
[1 3 4]
[[ 0.30389599 0. -0.38408994 0. 0. ]
[-1.55260273 -0.41368681 -0.40455289 0.13054527 -1.43541557]]
目标是直接将单个表达式中 DenseMatrix 行的选定索引归零。
这是我的 scala 尝试,后面是错误消息:
import breeze.linalg.*
import breeze.stats.*
def main(args: Array[String]): Unit =
val mat = DenseMatrix(
(-0.25010575, 0.44800905, 0.13285604, 0.34085698, 0.38346101),
(-1.97209990, 1.37114368, 1.56601999, -0.13052228, 0.86001178)
)
println(mat)
val indexes = IndexedSeq(1, 3, 4)
println(indexes)
mat(0, indexes) = 0.0
println(mat)
# ./rowsliceOverwrite.sc
-- [E007] Type Mismatch Error: C:\Users\philwalk\workspace\tprf_py\.\rowsliceOverwrite.sc:14:9 -----------------------------------------------------------------------------------------------------------------------------------------------
14 | mat(0, indexes) = 0
| ^^^^^^^
| Found: (indexes : IndexedSeq[Int])
| Required: Int
|
| longer explanation available when compiling with `-explain`
1 error found
Errors encountered during compilation
只需3步即可完成:
import breeze.linalg.*
import breeze.stats.*
def main(args: Array[String]): Unit =
val mat = DenseMatrix(
(-0.25010575, 0.44800905, 0.13285604, 0.34085698, 0.38346101),
(-1.97209990, 1.37114368, 1.56601999, -0.13052228, 0.86001178)
)
println(mat)
val indexes = IndexedSeq(1, 3, 4)
println(indexes)
var row0 = mat(0, ::).t
row0(indexes) := 0.0
mat(0, ::) := row0.t
println(mat)
# ./rowsliceOverwrite.sc
-0.25010575 0.44800905 0.13285604 0.34085698 0.38346101
-1.9720999 1.37114368 1.56601999 -0.13052228 0.86001178
Vector(1, 3, 4)
-0.25010575 0.0 0.13285604 0.0 0.0
-1.9720999 1.37114368 1.56601999 -0.13052228 0.86001178
它不像Python代码那样可读,如果有一种方法可以更直接地作为单个表达式来完成它,那就太好了。
尝试了一段时间后,我发现了一个似乎有效的表达方式:
import breeze.linalg.*
import breeze.stats.*
def main(args: Array[String]): Unit =
val mat = DenseMatrix(
(-0.25010575, 0.44800905, 0.13285604, 0.34085698, 0.38346101),
(-1.97209990, 1.37114368, 1.56601999, -0.13052228, 0.86001178)
)
println(mat)
val indexes = IndexedSeq(1, 3, 4)
println(indexes)
mat(0, ::).t(indexes) := 0.0
println(mat)
# the output:
# ./rowsliceOverwrite.sc
-0.25010575 0.44800905 0.13285604 0.34085698 0.38346101
-1.9720999 1.37114368 1.56601999 -0.13052228 0.86001178
Vector(1, 3, 4)
-0.25010575 0.0 0.13285604 0.0 0.0
-1.9720999 1.37114368 1.56601999 -0.13052228 0.86001178