是否可以使用随机索引列表直接覆盖 DenseMatrix 行的选定列

问题描述 投票:0回答:1

我正在尝试在 scala/breeze 中实现以下 python 代码:

import numpy as np
mat = np.random.normal(size=(2, 5))
print(mat)
indexes = np.random.choice(5, replace = False, size = 3)
print(indexes)
mat[0, [indexes]] = 0
print(mat)

# the output:
# ./rowsliceOverwrite.py
[[ 0.30389599  0.84549682 -0.38408994 -1.11550844 -0.28496995]
 [-1.55260273 -0.41368681 -0.40455289  0.13054527 -1.43541557]]
[1 3 4]
[[ 0.30389599  0.         -0.38408994  0.          0.        ]
 [-1.55260273 -0.41368681 -0.40455289  0.13054527 -1.43541557]]

目标是直接将单个表达式中 DenseMatrix 行的选定索引归零。

这是我的 scala 尝试,后面是错误消息:

import breeze.linalg.*
import breeze.stats.*

def main(args: Array[String]): Unit =
  val mat = DenseMatrix(
    (-0.25010575, 0.44800905, 0.13285604,  0.34085698,  0.38346101),
    (-1.97209990, 1.37114368, 1.56601999, -0.13052228,  0.86001178)
  )
  println(mat)
  val indexes = IndexedSeq(1, 3, 4)
  println(indexes)
  mat(0, indexes) = 0.0
  println(mat)

# ./rowsliceOverwrite.sc
-- [E007] Type Mismatch Error: C:\Users\philwalk\workspace\tprf_py\.\rowsliceOverwrite.sc:14:9 -----------------------------------------------------------------------------------------------------------------------------------------------
14 |  mat(0, indexes) = 0
   |         ^^^^^^^
   |         Found:    (indexes : IndexedSeq[Int])
   |         Required: Int
   |
   | longer explanation available when compiling with `-explain`
1 error found
Errors encountered during compilation

只需3步即可完成:

  • 获取行
  • 覆盖转置行的列
  • 将行放回原处
import breeze.linalg.*
import breeze.stats.*

def main(args: Array[String]): Unit =
  val mat = DenseMatrix(
    (-0.25010575, 0.44800905, 0.13285604,  0.34085698,  0.38346101),
    (-1.97209990, 1.37114368, 1.56601999, -0.13052228,  0.86001178)
  )
  println(mat)
  val indexes = IndexedSeq(1, 3, 4)
  println(indexes)
  var row0 = mat(0, ::).t
  row0(indexes) := 0.0
  mat(0, ::) := row0.t
  println(mat)
# ./rowsliceOverwrite.sc
-0.25010575  0.44800905  0.13285604  0.34085698   0.38346101
-1.9720999   1.37114368  1.56601999  -0.13052228  0.86001178
Vector(1, 3, 4)
-0.25010575  0.0         0.13285604  0.0          0.0
-1.9720999   1.37114368  1.56601999  -0.13052228  0.86001178

它不像Python代码那样可读,如果有一种方法可以更直接地作为单个表达式来完成它,那就太好了。

scala slice overwrite scala-breeze
1个回答
0
投票

尝试了一段时间后,我发现了一个似乎有效的表达方式:

import breeze.linalg.*
import breeze.stats.*

def main(args: Array[String]): Unit =
  val mat = DenseMatrix(
    (-0.25010575, 0.44800905, 0.13285604,  0.34085698,  0.38346101),
    (-1.97209990, 1.37114368, 1.56601999, -0.13052228,  0.86001178)
  )
  println(mat)
  val indexes = IndexedSeq(1, 3, 4)
  println(indexes)
  mat(0, ::).t(indexes) := 0.0
  println(mat)

# the output:
# ./rowsliceOverwrite.sc
-0.25010575  0.44800905  0.13285604  0.34085698   0.38346101
-1.9720999   1.37114368  1.56601999  -0.13052228  0.86001178
Vector(1, 3, 4)
-0.25010575  0.0         0.13285604  0.0          0.0
-1.9720999   1.37114368  1.56601999  -0.13052228  0.86001178
© www.soinside.com 2019 - 2024. All rights reserved.