小能豆

如何在spark 1.5数据框中计算“自定义运行总计”

py

我将每个 LoanId 的贷款还款历史记录存储在 parquet 文件中,并尝试计算每笔贷款每个期间的“逾期”金额。如果不是计算到期金额的棘手性质,这将是一个简单的窗口分区任务。

如果客户支付的金额低于到期金额,则逾期金额会增加,另一方面,如果客户提前付款,则后续期间将忽略额外付款(下方示例中的第 5 行和第 6 行)。

LoanID  Period  DueAmt  ActualPmt   PastDue
1       1       100     100             0
1       2       100     60              -40
1       3       100     100             -40
1       4       100     200             0   <== This advance payment is not rolled to next period
1       5       100     110             0   <== This advance payment is not rolled to next period
1       6       100     80              -20
1       7       100     60              -60
1       8       100     100             -60
2       1       150     150             0
2       2       150     150             0
2       3       150     150             0
3       1       200     200             0
3       2       200     120             -80
3       3       200     120             -160

为了解决这个问题,我需要按照时间段排序,为每个分区(LoanID)应用自定义函数。

spark中有哪些可用的选项。

简单但复杂似乎使用 DF->RDD->groupby,应用 lambda 转换回数据框。

更优雅的是使用窗口函数的自定义UDAF(在 scala 中?)但找不到它的单一实现示例。


好的,所以我尝试了第一个解决方案,从 Dataframe 到 Pair RDD 再返回

    from pyspark.sql import Row 
    def dueAmt(partition):
        '''
        @type partition:list 
        '''
        #first sort rows
        sp=sorted(partition, key=lambda r: r.Period )
        res=[]
        due=0
        for r in sp:
            due+=r.ActualPmt-r.DueAmt
            if due>0: due=0;
            #row is immutable so we need to create new row with updated value
            d=r.asDict()
            d['CalcDueAmt']=-due
            newRow=Row(**d)
            res.append(newRow)
        return res    

    df = sqlContext.read.format('com.databricks.spark.csv').options(header='true', inferschema='true').load('PmtDueSample2.csv').cache()
    rd1=df.rdd.map(lambda r: (r.LoanID, r ) )
    rd2=rd1.groupByKey()
    rd3=rd2.mapValues(dueAmt)
    rd4=rd3.flatMap(lambda t: t[1] )
    df2=rd4.toDF()

似乎有效。

在此过程中,我实际上发现了 pyspark 实现中的几个错误。

  1. 类 Row 中 _call_ 的实现是错误的。
  2. Row 的构造函数中有一个烦人的错误。由于没有明显的原因,_new_ 对列进行排序,因此在旅程结束时,我得到的表的列按字母顺序排列。这只会让查看最终结果变得更加困难。

阅读 70

收藏
2025-02-20

共1个答案

小能豆

虽然不美观也不高效,但应该能给你一些帮助。让我们从创建和注册表开始:

val df = sc.parallelize(Seq(
  (1, 1, 100, 100), (1, 2, 100, 60), (1, 3, 100, 100),
  (1, 4, 100, 200), (1, 5, 100, 110), (1, 6, 100, 80),
  (1, 7, 100, 60), (1, 8, 100, 100), (2, 1, 150, 150),
  (2, 2, 150, 150), (2, 3, 150, 150), (3, 1, 200, 200),
  (3, 2, 200, 120), (3, 3, 200, 120)
)).toDF("LoanID", "Period", "DueAmt", "ActualPmt")

df.registerTempTable("df")

接下来让我们定义并注册一个 UDF:

case class Record(period: Int, dueAmt: Int, actualPmt: Int, pastDue: Int)

def runningPastDue(idxs: Seq[Int], dues: Seq[Int], pmts: Seq[Int]) = {
  def f(acc: List[(Int, Int, Int, Int)], x: (Int, (Int, Int))) = 
    (acc.head, x) match {
      case ((_, _, _, pastDue), (idx, (due, pmt))) => 
        (idx, due, pmt, (pmt - due + pastDue).min(0)) :: acc
    }

  idxs.zip(dues.zip(pmts))
    .toList
    .sortBy(_._1)
    .foldLeft(List((0, 0, 0, 0)))(f)
    .reverse
    .tail
    .map{ case (i, due, pmt, past) => Record(i, due, pmt, past) }
}

sqlContext.udf.register("runningPastDue", runningPastDue _)

汇总并计算总和:

val aggregated = sqlContext.sql("""
  SELECT LoanID, explode(pmts) pmts FROM (
    SELECT LoanId, 
           runningPastDue(
             collect_list(Period), 
             collect_list(DueAmt), 
             collect_list(ActualPmt)
           ) pmts
    FROM df GROUP BY LoanID) tmp""")

val flattenExprs = List("Period", "DueAmt", "ActualPmt", "PastDue")
  .zipWithIndex
  .map{case (c, i) => col(s"tmp._${i+1}").alias(c)}

最后压平:

val result = aggregated.select($"LoanID" :: flattenExprs: _*)
2025-02-20