Spark Scala - group by array

Problem description:

I am very new to Spark Scala. I would appreciate your help. I have a dataframe:

val df = Seq(
  ("a", "a1", Array("x1","x2")), 
  ("a", "b1", Array("x1")),
  ("a", "c1", Array("x2")),
  ("c", "c3", Array("x2")),
  ("a", "d1", Array("x3")),
  ("a", "e1", Array("x2","x1"))
).toDF("k1", "k2", "k3")

I am looking for a way to group it by k1 and k3 and collect k2 into an array. However, k3 is an array, so the grouping needs to use containment (overlapping elements) rather than an exact match. In other words, I am looking for a result something like this:

k1   k3        k2              count
a    (x1,x2)   (a1,b1,c1,e1)   4
a    (x3)      (d1)            1
c    (x2)      (c3)            1

Can somebody advise how to achieve this?

Thanks in advance!

I would suggest you group by the k1 column, collect a list of structs of k2 and k3, and pass the collected list to a udf function that merges rows whose k3 arrays overlap, accumulating the k2 elements and a count.

Then you can use explode and select expressions to get the desired output.
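The overlap-based merging that the udf performs can be sketched on plain Scala collections, without Spark. This is a simplified illustration (the hypothetical helper name `groupByOverlap` is mine): each (k2, k3) row is folded into the first existing group whose k3 set shares an element with it, otherwise it starts a new group.

```scala
// Simplified sketch of the merging logic: rows whose k3 arrays share at least
// one element end up in the same group, accumulating k2 values and a count.
def groupByOverlap(rows: Seq[(String, Seq[String])]): Seq[(Set[String], Set[String], Int)] =
  rows.foldLeft(Vector.empty[(Set[String], Set[String], Int)]) { (groups, row) =>
    val (k2, k3) = row
    // Find the first group whose k3 set overlaps this row's k3 array
    val idx = groups.indexWhere { case (_, keys, _) => keys.exists(k3.contains) }
    if (idx >= 0) {
      val (k2s, keys, n) = groups(idx)
      groups.updated(idx, (k2s + k2, keys ++ k3, n + 1)) // merge into the group
    } else {
      groups :+ ((Set(k2), k3.toSet, 1)) // start a new group
    }
  }
```

For the "a" rows of the example, this yields one group `(Set(a1,b1,c1,e1), Set(x1,x2), 4)` for the overlapping x1/x2 arrays and a separate group `(Set(d1), Set(x3), 1)` for x3.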

Following is the complete working solution:

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._

val df = Seq(
  ("a", "a1", Array("x1", "x2")),
  ("a", "b1", Array("x1")),
  ("a", "c1", Array("x2")),
  ("c", "c3", Array("x2")),
  ("a", "d1", Array("x3")),
  ("a", "e1", Array("x2", "x1"))
).toDF("k1", "k2", "k3")

// Merges the collected (k2, k3) structs: every row is added to each existing
// group whose k3 set it overlaps; otherwise it starts a new group.
def containsGroupingUdf = udf((arr: Seq[Row]) => {
  var result = Array.empty[(collection.mutable.Set[String], Set[String], Int)]
  for (str <- arr) {
    val k2 = str.getAs[String]("k2")
    val k3 = str.getAs[Seq[String]]("k3").toSet
    var added = false
    for ((res, index) <- result.zipWithIndex) {
      // Overlap check: any shared element between this row's k3 and the group's k3
      if (k3.exists(res._2) || res._2.exists(k3)) {
        result(index) = (res._1 + k2, res._2 ++ k3, res._3 + 1)
        added = true
      }
    }
    if (!added) {
      result = result :+ ((collection.mutable.Set(k2), k3, 1))
    }
  }
  result.map(tuple => (tuple._1.toArray, tuple._2.toArray, tuple._3))
})

df.groupBy("k1")
  .agg(containsGroupingUdf(collect_list(struct(col("k2"), col("k3")))).as("aggregated"))
  .select(col("k1"), explode(col("aggregated")).as("aggregated"))
  .select(col("k1"), col("aggregated._2").as("k3"), col("aggregated._1").as("k2"), col("aggregated._3").as("count"))
  .show(false)

which should give you:

+---+--------+----------------+-----+
|k1 |k3      |k2              |count|
+---+--------+----------------+-----+
|c  |[x2]    |[c3]            |1    |
|a  |[x1, x2]|[b1, e1, c1, a1]|4    |
|a  |[x3]    |[d1]            |1    |
+---+--------+----------------+-----+

I hope the answer is helpful and that you can modify it according to your needs.