自定义开发Spark ML机器学习类 - 1

初窥门径

Spark的MLlib组件内置实现了很多常见的机器学习算法,包括数据抽取,分类,聚类,关联分析,协同过滤等等.
然鹅,内置的算法并不能满足我们所有的需求,所以我们还是经常需要自定义ML算法.

MLlib提供的API分为两类:

  • 1.基于DataFrame的API,属于spark.ml包.
  • 2.基于RDD的API, 属于spark.mllib包.

从Spark 2.0开始,Spark的API全面从RDD转向DataFrame,MLlib也是如此,官网原话如下:

Announcement: DataFrame-based API is primary API

The MLlib RDD-based API is now in maintenance mode.

所以本文将介绍基于DataFrame的自定义ml类编写方法.不涉及具体算法,只讲扩展ml类的方法.

略知一二

官方文档并没有介绍如何自定义ml类,所以只有从源码入手,看看源码里面是怎么实现的.

找一个最简单的内置算法入手,这个算法就是内置的分词器,Tokenizer.

Tokenizer只是简单的将文本以空白部分进行分割,只适合给英文进行分词,所以它的实现及其简短,源码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
package org.apache.spark.ml.feature

import org.apache.spark.annotation.Since
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}

/**
* A tokenizer that converts the input string to lowercase and then splits it by white spaces.
*
* @see [[RegexTokenizer]]
*/
@Since("1.2.0")
class Tokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends UnaryTransformer[String, Seq[String], Tokenizer] with DefaultParamsWritable {

@Since("1.2.0")
def this() = this(Identifiable.randomUID("tok"))

override protected def createTransformFunc: String => Seq[String] = {
_.toLowerCase.split("\\s")
}

override protected def validateInputType(inputType: DataType): Unit = {
require(inputType == StringType, s"Input type must be string type but got $inputType.")
}

override protected def outputDataType: DataType = new ArrayType(StringType, true)

@Since("1.4.1")
override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra)
}

@Since("1.6.0")
object Tokenizer extends DefaultParamsReadable[Tokenizer] {

@Since("1.6.0")
override def load(path: String): Tokenizer = super.load(path)
}

简单分析下源码:

  • Tokenizer继承了UnaryTransformer类.unary是’一元’的意思,也是说这个类实现的是类似一元函数的功能,一个输入变量,一个输出.直接看UnaryTransformer的源码注释:
1
2
3
4
5
/**
* :: DeveloperApi ::
* Abstract class for transformers that take one input column, apply transformation, and output the
* result as a new column.
*/

DeveloperApi表明这是一个开发级API,开发者可以用,不会有权限问题(源码中有很多private[spark]的类,是不允许外部调用的).
注释的大意就是:这是一个为实现transformers准备的抽象类,以一个字段(列)为输入,输出一个新字段(列).

所以实际上就是实现一个Transformer,只是这个Transformer有指定的输入字段和输出字段.

  • UnaryTransformer类中只有两个抽象方法.
    一个是createTransformFunc,是最核心的方法,这个方法需要返回一个函数,这个函数的参数即Transformer的输入字段的值,返回值为Transformer的输出字段的值.看看Tokenizer中的实现,就明白了.

另一个是outputDataType,这个方法用来返回输出字段的类型.

  • validateInputType方法是用来检查输入字段类型的,看需要实现.

  • Tokenizer混入了DefaultParamsWritable特质,使得自己可以被保存.
    对应的object Tokenizer伴生对象,用来读取已保存的Tokenizer.

  • 值得注意的是,Transformer类是PipelineStage类的子类,所以Transformer的子类,包括我们自定义的,是可以直接用在ML Pipelines中的.这就厉害了,说明自定义的算法类,可以无缝与内置机器学习算法打配合,还能利用Pipeline的调优工具(model selection,Cross-Validation等).

初出茅庐

看完源码,基本套路已经明了,不如动手抄一个,不,敲一个.
依葫芦画瓢,实现一个正则提取的Transformer.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import util.matching.Regex

import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.types._

/**
* 正则提取器
* 将匹配指定正则表达式的全部子字符串,提取到array[string]中.
*/
class RegexExtractor(override val uid: String)
extends UnaryTransformer[String, Seq[String], RegexExtractor] {

def this() = this(Identifiable.randomUID("RegexExtractor"))

/**
* 参数:正则表达式
*
* @group param
*/
final val regex = new Param[Regex](this, "RegexExpr", "正则表达式")

/** @group setParam */
def setRegexExpr(value: String): this.type = set(regex, new Regex(value))

override protected def outputDataType: DataType = new ArrayType(StringType, true)

override protected def validateInputType(inputType: DataType): Unit = {
require(inputType == DataTypes.StringType,
s"Input type must be string type but got $inputType."
)
}

override protected def createTransformFunc: String => Seq[String] = {
parseContent
}

/**
* 数据处理
*/
private def parseContent(text: String): Seq[String] = {
if (text == null || text.isEmpty) {
return Seq.empty[String]
}
$(regex).findAllIn(text).toSeq
}

}

这个类结构与Tokenizer源码基本差不多,多用到的Param类,是一个参数的包装类.
作用是self-contained documentation and optionally default value.
其实就是把参数的值,文档,默认值等属性组合成一个类,方便调用.

比如上面定义的regex参数,就可以用$(regex)这样的方式直接调用.

另外在org.apache.spark.ml.param中有很多内置的Param类,可以直接使用.

同时org.apache.spark.ml.param.shared中有很多辅助引入参数的特质,比如HasInputCols特质,你的自定义Transformer只要混入这个特质就拥有了inputCols参数.不过目前shared中特质的作用域是private[ml],也就是说不能直接引用,而是要copy一份代码到自己的项目,并修改作用域才行.
关于这个作用域的问题,有人在spark的jira上提到,提议将其作为DeveloperApi开放出来,我也投了一票表示支持.后来在2017年11月终于resolved,该问题将在Spark2.3.0中解决.详情戳我

粗懂皮毛

自定义的类写好了,该怎么用呢? 当然是跟内置的一样啦.上栗子:

1
2
3
4
5
6
7
8
9
10
11
12
val regex="nidezhengze"

val tranTitle = new RegexExtractor()
.setInputCol("title")
.setOutputCol("title_price_texts")
.setRegexExpr(regex)

val pipeline = new Pipeline().setStages(Array(
tranTitle
))

val matched = pipeline.fit(data).transform(data)

打完收功

到这里,开发简单Transform的套路已经清楚了,不过这里实现的功能比较类似于一个UDF,只能对dataset的一个字段进行处理,而且是逐行处理,并不能根据多行数据进行处理,实现窗口函数类似的功能,而且也没有涉及模型的输出.如果要开发更复杂的算法,甚至进行模型训练,就需要更深入的了解MLlib了,阅读源码是个好途径.

有机会再说.👋