初窥门径 Spark的MLlib组件内置实现了很多常见的机器学习算法,包括数据抽取,分类,聚类,关联分析,协同过滤等等. 然鹅,内置的算法并不能满足我们所有的需求,所以我们还是经常需要自定义ML算法.
MLlib提供的API分为两类:
1.基于DataFrame的API,属于spark.ml包.
2.基于RDD的API, 属于spark.mllib包.
从Spark 2.0开始,Spark的API全面从RDD转向DataFrame,MLlib也是如此,官网原话如下:
Announcement: DataFrame-based API is primary API
The MLlib RDD-based API is now in maintenance mode.
所以本文将介绍基于DataFrame的自定义ml类编写方法.不涉及具体算法,只讲扩展ml类的方法.
略知一二 官方文档并没有介绍如何自定义ml类,所以只有从源码入手,看看源码里面是怎么实现的.
找一个最简单的内置算法入手,这个算法就是内置的分词器,Tokenizer.
Tokenizer只是简单的将文本以空白部分进行分割,只适合给英文进行分词,所以它的实现及其简短,源码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 package org.apache.spark.ml.featureimport org.apache.spark.annotation.Since import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._import org.apache.spark.ml.util._import org.apache.spark.sql.types.{ArrayType , DataType , StringType }@Since ("1.2.0" )class Tokenizer @Since ("1.4.0" ) (@Since ("1.4.0" ) override val uid : String ) extends UnaryTransformer [String , Seq [String ], Tokenizer ] with DefaultParamsWritable { @Since ("1.2.0" ) def this () = this (Identifiable .randomUID("tok" )) override protected def createTransformFunc : String => Seq [String ] = { _.toLowerCase.split("\\s" ) } override protected def validateInputType (inputType: DataType ): Unit = { require(inputType == StringType , s"Input type must be string type but got $inputType ." ) } override protected def outputDataType : DataType = new ArrayType (StringType , true ) @Since ("1.4.1" ) override def copy (extra: ParamMap ): Tokenizer = defaultCopy(extra) } @Since ("1.6.0" )object Tokenizer extends DefaultParamsReadable [Tokenizer ] { @Since ("1.6.0" ) override def load (path: String ): Tokenizer = super .load(path) }
简单分析下源码:
Tokenizer继承了UnaryTransformer类.unary是’一元’的意思,也是说这个类实现的是类似一元函数的功能,一个输入变量,一个输出.直接看UnaryTransformer的源码注释:
DeveloperApi表明这是一个开发级API,开发者可以用,不会有权限问题(源码中有很多private[spark]的类,是不允许外部调用的). 注释的大意就是:这是一个为实现transformers准备的抽象类,以一个字段(列)为输入,输出一个新字段(列).
所以实际上就是实现一个Transformer,只是这个Transformer有指定的输入字段和输出字段.
UnaryTransformer类中只有两个抽象方法. 一个是createTransformFunc,是最核心的方法,这个方法需要返回一个函数,这个函数的参数即Transformer的输入字段的值,返回值为Transformer的输出字段的值.看看Tokenizer中的实现,就明白了.
另一个是outputDataType,这个方法用来返回输出字段的类型.
validateInputType方法是用来检查输入字段类型的,看需要实现.
Tokenizer混入了DefaultParamsWritable特质,使得自己可以被保存. 对应的object Tokenizer伴生对象,用来读取已保存的Tokenizer.
值得注意的是,Transformer类是PipelineStage类的子类,所以Transformer的子类,包括我们自定义的,是可以直接用在ML Pipelines 中的.这就厉害了,说明自定义的算法类,可以无缝与内置机器学习算法打配合,还能利用Pipeline的调优工具(model selection,Cross-Validation等).
初出茅庐 看完源码,基本套路已经明了,不如动手抄一个,不,敲一个. 依葫芦画瓢,实现一个正则提取的Transformer.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 import util.matching.Regex import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.Param import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.types._class RegexExtractor (override val uid: String ) extends UnaryTransformer [String , Seq [String ], RegexExtractor ] { def this () = this (Identifiable .randomUID("RegexExtractor" )) final val regex = new Param [Regex ](this , "RegexExpr" , "正则表达式" ) def setRegexExpr (value: String ): this .type = set(regex, new Regex (value)) override protected def outputDataType : DataType = new ArrayType (StringType , true ) override protected def validateInputType (inputType: DataType ): Unit = { require(inputType == DataTypes .StringType , s"Input type must be string type but got $inputType ." ) } override protected def createTransformFunc : String => Seq [String ] = { parseContent } private def parseContent (text: String ): Seq [String ] = { if (text == null || text.isEmpty) { return Seq .empty[String ] } $(regex).findAllIn(text).toSeq } }
这个类结构与Tokenizer源码基本差不多,多用到的Param类,是一个参数的包装类. 作用是self-contained documentation and optionally default value. 其实就是把参数的值,文档,默认值等属性组合成一个类,方便调用.
比如上面定义的regex参数,就可以用$(regex)这样的方式直接调用.
另外在org.apache.spark.ml.param中有很多内置的Param类,可以直接使用.
同时org.apache.spark.ml.param.shared中有很多辅助引入参数的特质,比如HasInputCols特质,你的自定义Transformer只要混入这个特质就拥有了inputCols参数.不过目前shared中特质的作用域是private[ml],也就是说不能直接引用,而是要copy一份代码到自己的项目,并修改作用域才行. 关于这个作用域的问题,有人在spark的jira上提到,提议将其作为DeveloperApi开放出来,我也投了一票表示支持.后来在2017年11月终于resolved,该问题将在Spark2.3.0中解决.详情戳我
粗懂皮毛 自定义的类写好了,该怎么用呢? 当然是跟内置的一样啦.上栗子:
1 2 3 4 5 6 7 8 9 10 11 12 val regex="nidezhengze" val tranTitle = new RegexExtractor () .setInputCol("title" ) .setOutputCol("title_price_texts" ) .setRegexExpr(regex) val pipeline = new Pipeline ().setStages(Array ( tranTitle )) val matched = pipeline.fit(data).transform(data)
打完收功 到这里,开发简单Transform的套路已经清楚了,不过这里实现的功能比较类似于一个UDF,只能对dataset的一个字段进行处理,而且是逐行处理,并不能根据多行数据进行处理,实现窗口函数类似的功能,而且也没有涉及模型的输出.如果要开发更复杂的算法,甚至进行模型训练,就需要更深入的了解MLlib了,阅读源码是个好途径.
有机会再说.👋