| 注册
请输入搜索内容

热门搜索

Java Linux MySQL PHP JavaScript Hibernate jQuery Nginx
ye34
10年前发布

K-means算法(Spark Demo)

import java.util.Random
import spark.SparkContext
import spark.SparkContext._
import spark.examples.Vector._

/**
 * Naive k-means clustering demo on Spark.
 *
 * Usage: SparkKMeans <master> <file> <dimensions> <k> <iters> [<slices>]
 *
 * Each line of `<file>` is parsed into one point. `k` initial centers get coordinates
 * drawn from [-1, 1) with a fixed random seed (42). The centers are then refined for
 * `<iters>` rounds: assign every point to its nearest center, then move each center to
 * the mean of the points assigned to it.
 */
object SparkKMeans {

  /**
   * Parses one input line into a [[Vector]].
   *
   * Tokens may be separated by any run of whitespace. Leading and trailing whitespace
   * is ignored.
   *
   * @param line a line of numeric coordinates
   * @return the point as a Vector
   * @throws NumberFormatException if a token is not a valid double (this includes an
   *                               empty or all-blank line)
   */
  def parseVector(line: String): Vector =
    new Vector(line.trim.split("\\s+").map(_.toDouble))

  /**
   * Finds the center nearest to `p`.
   *
   * Distance is the sum of squared coordinate differences (`squaredDist`). This ranks
   * centers the same way Euclidean distance does, but avoids the square root.
   * On ties, the lowest index wins.
   *
   * @param p       the point to classify
   * @param centers candidate centers; must be non-empty
   * @return index into `centers` of the closest center
   * @throws IllegalArgumentException if `centers` is empty
   */
  def closestCenter(p: Vector, centers: Array[Vector]): Int = {
    require(centers.nonEmpty, "closestCenter needs at least one center")
    var bestIndex = 0
    var bestDist = p.squaredDist(centers(0))
    for (i <- 1 until centers.length) {
      val dist = p.squaredDist(centers(i))
      if (dist < bestDist) {
        bestDist = dist
        bestIndex = i
      }
    }
    bestIndex
  }

  /**
   * Entry point.
   *
   * @param args master URL, input file, dimensions, k, iterations, and an optional
   *             number of input slices
   */
  def main(args: Array[String]): Unit = {
    // Five arguments are required and the sixth (slices) is optional. The old `< 3`
    // check let args(3), args(4) and args(5) throw ArrayIndexOutOfBoundsException.
    if (args.length < 5) {
      System.err.println("Usage: SparkKMeans <master> <file> <dimensions> <k> <iters> [<slices>]")
      System.exit(1)
    }
    val dimensions = args(2).toInt // dimensionality of every point
    val k = args(3).toInt          // number of clusters
    val iterations = args(4).toInt // number of refinement rounds
    // With k <= 0 the centers array would be empty and closestCenter would fail.
    require(dimensions > 0, "dimensions must be positive")
    require(k > 0, "k must be positive")

    val sc = new SparkContext(args(0), "SparkKMeans")
    try {
      val lines =
        if (args.length > 5) sc.textFile(args(1), args(5).toInt)
        else sc.textFile(args(1))
      // One point per line. Cached because every iteration scans the points again.
      val points = lines.map(parseVector(_)).cache()

      // Random initial centers in [-1, 1)^dimensions. The fixed seed makes runs reproducible.
      val rand = new Random(42)
      val centers = Array.fill(k)(Vector(dimensions, _ => 2 * rand.nextDouble - 1))
      println("Initial centers: " + centers.mkString(", "))

      val time1 = System.currentTimeMillis()
      for (i <- 1 to iterations) {
        println("On iteration " + i)
        // Tag each point with its nearest center's index and a (point, 1) pair. Summing
        // these pairs per key yields (vector sum, point count), from which the mean follows.
        // NOTE: the closure captures the mutable `centers` array. It is only written after
        // collect() returns, so every task of this iteration sees the same centers.
        val newCenters = points
          .map(p => (closestCenter(p, centers), (p, 1)))
          .reduceByKey { case ((sum1, count1), (sum2, count2)) => (sum1 + sum2, count1 + count2) }
          .map { case (id, (sum, count)) => (id, sum / count) }
          .collect()
        // Only clusters that received at least one point appear in newCenters.
        // Every other cluster keeps its previous center.
        for ((id, value) <- newCenters) centers(id) = value
      }
      val time2 = System.currentTimeMillis()

      println("Final centers: " + centers.mkString(", ") + ", time: " + (time2 - time1))
    } finally {
      // Release the SparkContext's resources even if the job fails.
      sc.stop()
    }
  }
}