Spark Dynamic UDF Loading Implementation (2)

 def LoadUserDefineUDF(user: String, spark: SparkSession): Unit = {
    val brainUrl: String = PropertiesLoader.getProperties("database.properties").getProperty("brain.url")
    val brainPrefix = brainUrl.substring(0, brainUrl.indexOf("feature-panel/online-feature") - 1)
    val udfURL = s"$brainPrefix/udfWarehouse/findInfoByUserId?userId=$user"
    val simpleHttp = new SimpleHttp
    val result = simpleHttp.fetchResponseText(udfURL)
    logger.info("***********************************************************")
    logger.info(s"result:$result")
    try {
      val resultJson = JSON.parseObject(result)
      val flag = resultJson.getIntValue("code")
      flag match {
        case 0 => LoadUDF(resultJson, spark)
        case _ => logger.error(s"Failed to load the user-defined offline feature UDFs! Reason: ${resultJson.getString("msg")}")
      }
    } catch {
      case e: Exception =>
        logger.error(s"加载用户自定义离线特征处理udf失败!原因:服务器异常!" + e.getMessage, e)
    }
  }
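
For reference, the endpoint is expected to answer with JSON shaped like the sketch below. Only the field names are implied by the parsing code in LoadUserDefineUDF and LoadUDF; the concrete values here are placeholders:

// a hypothetical udfWarehouse response; the field names (code, msg,
// data.data[].udfName/downLoadJarUrl/entryClass/jarName/functionName)
// are the ones the code reads, the values are made up
val sampleResponse =
  """{
    |  "code": 0,
    |  "msg": "ok",
    |  "data": { "data": [ {
    |    "udfName": "str2VecJson",
    |    "downLoadJarUrl": "http://brain-host/udfWarehouse/download/1",
    |    "entryClass": "com.vivo.ai.temp.Method",
    |    "jarName": "data_flow_test-1.0-SNAPSHOT",
    |    "functionName": "str2VecJson"
    |  } ] }
    |}""".stripMargin
assert(JSON.parseObject(sampleResponse).getIntValue("code") == 0)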

  def LoadUDF(jsonObj: JSONObject, spark: SparkSession): Unit = {
    val udfArray = jsonObj.getJSONObject("data").getJSONArray("data")
    val array = mutable.ArrayBuilder.make[URL]()
    logger.info("************************************************************")
    val methodMap = new mutable.HashMap[String, (String, String, String)]()
    for (i <- 0 until udfArray.size()) {
      val udfJson = udfArray.getJSONObject(i)
      val udfName = udfJson.getString("udfName")
      val downLoadJarUrl = udfJson.getString("downLoadJarUrl")
      val entryClass = udfJson.getString("entryClass")
      val jarName = udfJson.getString("jarName") + ".jar"
      val functionName = udfJson.getString("functionName")
      try {
        downLoadJar(downLoadJarUrl, jarName)
        spark.sparkContext.addJar(HdfsPrefix + jarName)
        val url2 = new URL(s"file:./${jarName}")
        logger.info(s"*********加载udf $udfName 成功**********")
        methodMap.put(udfName, (functionName, entryClass, jarName))
        array += url2
      } catch {
        case e: Exception =>
          logger.error(s"$jarName $functionName $entryClass Exception!!!", e.getMessage)
      }
    }
    ScalaGenerateFunctions(array.result())
    methodMap.foreach {
      case (udfName, (functionName, entryClass, jarName)) =>
        try {
          val (fun, inputType, returnType) = ScalaGenerateFunctions.genetateFunction(functionName, entryClass, jarName)
          // inputTypes is only needed by the commented-out ScalaUDF registration variant in the test below
          val inputTypes = Try(List(inputType)).toOption
          spark.udf.register(udfName, fun, returnType)
          logger.info(s"********* Spark registered UDF $udfName successfully **********")
        } catch {
          case e: Exception =>
            logger.error(s"********* Failed to register UDF $udfName with Spark!", e)
        }
    }
  }
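
ScalaGenerateFunctions comes from part (1) of this series. As a rough illustration of what genetateFunction has to do, here is a minimal stand-in that reflectively wraps a String => String method into a UDF1 that spark.udf.register(name, fun, returnType) accepts. It is a sketch under narrow assumptions (no-arg constructor, a single String parameter, String return), not the real implementation:

import java.net.{URL, URLClassLoader}
import org.apache.spark.sql.api.java.UDF1
import org.apache.spark.sql.types.{DataType, StringType}

// hypothetical helper: load className from the given jars and wrap methodName
def generateStringUdf(urls: Array[URL], className: String,
                      methodName: String): (UDF1[String, String], DataType) = {
  val loader = new URLClassLoader(urls, getClass.getClassLoader)
  val clazz = loader.loadClass(className)
  val instance = clazz.getDeclaredConstructor().newInstance()
  val method = clazz.getMethod(methodName, classOf[String])
  val fun = new UDF1[String, String] {
    override def call(v: String): String = method.invoke(instance, v).asInstanceOf[String]
  }
  (fun, StringType)
}

The real genetateFunction additionally derives the input and return DataTypes from the reflected method signature, which is why LoadUDF receives a three-element tuple.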

  def downLoadJar(url: String, jarName: String): Unit = {
    logger.info("*******************************************")
    logger.info(s"****************url:$url**********************")
    val path = "./"
    val file = new File(path)
    if (!file.exists()) {
      // create the target directory if it does not exist
      file.mkdirs()
    }
    var conn: HttpURLConnection = null
    var bis: BufferedInputStream = null
    var bos: BufferedOutputStream = null
    try {
      val httpUrl = new URL(url)
      conn = httpUrl.openConnection().asInstanceOf[HttpURLConnection]
      conn.setRequestMethod("GET")
      conn.setDoInput(true)
      // a plain GET carries no request body, so doOutput stays false
      conn.setUseCaches(false)
      // connect to the resource
      conn.connect()
      // read the jar from the network input stream
      bis = new BufferedInputStream(conn.getInputStream)
      bos = new BufferedOutputStream(new FileOutputStream(path + jarName))
      val buf = new Array[Byte](4096)
      var length = bis.read(buf)
      // write the file to disk
      while (length != -1) {
        bos.write(buf, 0, length)
        length = bis.read(buf)
      }
    } catch {
      case e: Exception =>
        logger.error(s"Failed to download jar $jarName: " + e.getMessage, e)
    } finally {
      // close the streams even when the download fails
      if (bos != null) bos.close()
      if (bis != null) bis.close()
      if (conn != null) conn.disconnect()
    }
  }
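
As an aside, the same download can be written more compactly with java.nio; this alternative sketch is not what the job above uses:

import java.io.InputStream
import java.net.URL
import java.nio.file.{Files, Paths, StandardCopyOption}

def downloadJarNio(url: String, jarName: String): Unit = {
  var in: InputStream = null
  try {
    in = new URL(url).openStream()
    // Files.copy streams the response body to disk and buffers internally
    Files.copy(in, Paths.get(".", jarName), StandardCopyOption.REPLACE_EXISTING)
  } finally {
    if (in != null) in.close()
  }
}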

Below is a unit test:

@Test
  def testStr2VecJson(): Unit = {
    System.setProperty("hadoop.home.dir", "D:\\winutils")
    val conf = new SparkConf().setAppName("test").setMaster("local[2]")//.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    //    val sc = new SparkContext(conf)
    val spark = SparkSession.builder().config(conf).getOrCreate()
    import spark.implicits._
    val data = Array("1", "2")
    val rdd = spark.sparkContext.parallelize(data)
    val df = rdd.toDF("str")
    // reuse the utility class here: E:\adworkSpace\autotask\target
    val url = new URL("file:F:/ad_codes/data_flow_test/target/data_flow_test-1.0-SNAPSHOT.jar")
    val urls = Array(url)
    ScalaGenerateFunctions(urls)


    val className = "com.vivo.ai.temp.Method"
    val methodArray = Array("str2VecJson")

    methodArray.foreach {
      methodName =>
        val (fun, inputType, returnType) = ScalaGenerateFunctions.genetateFunction(methodName, className, "autotask-2.0-SNAPSHOT.jar")
        // only needed by the commented-out ScalaUDF registration path below
        val inputTypes = Try(List(inputType)).toOption

        //def builder(e: Seq[Expression]) = ScalaUDF(fun, returnType, e, inputTypes.getOrElse(Nil), Some(methodName))
        spark.udf.register(methodName, fun, returnType)
      //        def builder(e: Seq[Expression]) = ScalaUDF(function = fun, dataType = returnType, children = e, Seq(true), inputTypes = inputTypes.getOrElse(Nil), udfName = Some(methodName))
      //
      //        spark.sessionState.functionRegistry.registerFunction(new FunctionIdentifier(methodName), builder)
    }
    df.createTempView("strDF")
    df.show()
    spark.sql("select str2VecJson(str) from strDF").show()
  }
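
Once registered, the UDF is usable from the DataFrame API as well as SQL, e.g.:

import org.apache.spark.sql.functions.callUDF
df.select(callUDF("str2VecJson", $"str")).show()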
  
The class com.vivo.ai.temp.Method is defined as follows:

import com.alibaba.fastjson.JSON
import com.vivo.ai.encode.ContinuousEncoder
import com.vivo.ai.encode.util.EncodeEnv
import com.vivo.vector.Vector
import com.alibaba.fastjson.serializer.SerializerFeature
import scala.collection.mutable
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
/**
  * @author liangwt
  * @create 2020/11/16
  * @since 1.0.0
  * @Description :
  */
class Method {
  def method(value: String): String = {
    // user-defined processing method
    value
  }
  def method2(value: Int): String = {
    (value + 100).toString
  }
  def testMap(value: Int): scala.collection.Map[String, String] = {
    scala.collection.Map("1" -> "1")
  }
  def testJMap(value: Int): java.util.Map[String, String] = {
    scala.collection.Map("1" -> "1").asJava
  }

  def testMap2(value: Int): Map[String, String] = {
    Map("1" -> "1")
  }
  def testSet(value: Int) = {
    val set = mutable.Set("1")
    set.asJava
  }
  def testSeq(value: Int) = {
    Seq("1")
  }
  def inputSeq(seq: Seq[Int]): String = {
    "1"
  }
  def inputMap(map: Map[String, Integer]): String = {
    "1"
  }
  def str2VecJson(str: String): String = {
    val userNewsRTFeatureVec: Vector = Vector.builder(24).build()
    var userArrayBuffer = new ArrayBuffer[(Int, Float)]()
    // a real implementation would parse `str` into (index, value) pairs, e.g. by
    // encoding each key via ContinuousEncoder.encode("news_category_v3", k, EncodeEnv.PRD)
    userArrayBuffer += (1 -> 0.1f)
    // sortWith returns a new collection, so the sorted result must be kept
    userArrayBuffer = userArrayBuffer.sortWith((x, y) => x._1 < y._1)
    userNewsRTFeatureVec.setIndices(userArrayBuffer.map(_._1).toArray)
    userNewsRTFeatureVec.setValues(userArrayBuffer.map(_._2).toArray)
    JSON.toJSONString(userNewsRTFeatureVec, SerializerFeature.IgnoreNonFieldGetter)
  }

}
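
When the method is known at compile time, reflection is unnecessary; registering it directly lets Spark derive the input and return types itself:

spark.udf.register("str2VecJson", (s: String) => new Method().str2VecJson(s))

The reflective route in LoadUDF exists only because the user's jar, entry class and method names arrive at runtime.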
