在Spark中构造区分矩阵

问题描述 投票:0回答:1

我正在尝试使用spark构建区分矩阵,并且我很困惑如何以最佳方式进行。我是新来的火花。我举了一个小例子,说明我在下面尝试做什么。

区分矩阵构造示例:

鉴于数据集D:

+----+-----+------+-----+  
| id | a1  |  a2  | a3  |  
+----+-----+------+-----+  
|  1 | yes | high | on  |  
|  2 | no  | high | off |
|  3 | yes | low  | off |
+----+-----+------+-----+

我的区别表是

+-------+----+----+----+
| id,id | a1 | a2 | a3 |
+-------+----+----+----+
| 1,2   |  1 |  0 |  1 |
| 1,3   |  0 |  1 |  1 |
| 2,3   |  1 |  1 |  0 |
+-------+----+----+----+

即,只要属性ai有助于区分一对元组,区别表就有1,否则为0。

我的数据集是巨大的,我试图在spark中做到这一点。以下是我想到的方法:

  1. 使用嵌套for循环迭代RDD(数据集)的所有成员
  2. 使用对原始RDD的cartesian()转换并迭代所有RDD的成员以获得区别表。

我的问题是: 在第一种方法中,spark是否会在内部自动优化嵌套for循环设置以进行并行处理?

在第二种方法中,使用cartesian()会导致额外的存储开销来存储中间RDD。有没有办法避免这种存储开销并获得最终的区分表?

哪种方法更好,是否有其他方法可以有效地构建区分矩阵(空间和时间)?

scala apache-spark rdd
1个回答
0
投票

对于此数据帧:

scala> val df = List((1, "yes", "high", "on" ), (2,  "no", "high", "off"), (3, "yes",  "low", "off") ).toDF("id", "a1", "a2", "a3")
df: org.apache.spark.sql.DataFrame = [id: int, a1: string ... 2 more fields]

scala> df.show
+---+---+----+---+
| id| a1|  a2| a3|
+---+---+----+---+
|  1|yes|high| on|
|  2| no|high|off|
|  3|yes| low|off|
+---+---+----+---+

我们可以通过使用crossJoin来构建笛卡尔积。但是,列名称将是模糊的(我真的不知道如何轻松处理它)。为此做准备,让我们创建第二个数据帧:

scala> val df2 = df.toDF("id_2", "a1_2", "a2_2", "a3_2")
df2: org.apache.spark.sql.DataFrame = [id_2: int, a1_2: string ... 2 more fields]

scala> df2.show
+----+----+----+----+
|id_2|a1_2|a2_2|a3_2|
+----+----+----+----+
|   1| yes|high|  on|
|   2|  no|high| off|
|   3| yes| low| off|
+----+----+----+----+

在这个例子中,我们可以通过使用id < id_2进行过滤来获得组合。

scala> val xp = df.crossJoin(df2)
xp: org.apache.spark.sql.DataFrame = [id: int, a1: string ... 6 more fields]

scala> xp.show
+---+---+----+---+----+----+----+----+
| id| a1|  a2| a3|id_2|a1_2|a2_2|a3_2|
+---+---+----+---+----+----+----+----+
|  1|yes|high| on|   1| yes|high|  on|
|  1|yes|high| on|   2|  no|high| off|
|  1|yes|high| on|   3| yes| low| off|
|  2| no|high|off|   1| yes|high|  on|
|  2| no|high|off|   2|  no|high| off|
|  2| no|high|off|   3| yes| low| off|
|  3|yes| low|off|   1| yes|high|  on|
|  3|yes| low|off|   2|  no|high| off|
|  3|yes| low|off|   3| yes| low| off|
+---+---+----+---+----+----+----+----+


scala> val filtered = xp.filter($"id" < $"id_2")
filtered: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [id: int, a1: string ... 6 more fields]

scala> filtered.show
+---+---+----+---+----+----+----+----+
| id| a1|  a2| a3|id_2|a1_2|a2_2|a3_2|
+---+---+----+---+----+----+----+----+
|  1|yes|high| on|   2|  no|high| off|
|  1|yes|high| on|   3| yes| low| off|
|  2| no|high|off|   3| yes| low| off|
+---+---+----+---+----+----+----+----+

此时问题基本解决了。要获得最终表,我们可以在每个列对上使用when().otherwise()语句,或者像我在这里所做的那样使用UDF:

scala> val dist = udf((a:String, b: String) => if (a != b) 1 else 0)
dist: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function2>,IntegerType,Some(List(StringType, StringType)))

scala> val distinction = filtered.select($"id", $"id_2", dist($"a1", $"a1_2").as("a1"), dist($"a2", $"a2_2").as("a2"), dist($"a3", $"a3_2").as("a3"))
distinction: org.apache.spark.sql.DataFrame = [id: int, id_2: int ... 3 more fields]

scala> distinction.show
+---+----+---+---+---+
| id|id_2| a1| a2| a3|
+---+----+---+---+---+
|  1|   2|  1|  0|  1|
|  1|   3|  0|  1|  1|
|  2|   3|  1|  1|  0|
+---+----+---+---+---+
© www.soinside.com 2019 - 2024. All rights reserved.