// Java gradient descent example (java梯度下降)
/**
 * Minimal gradient-descent demo: minimizes the loss {@code f(x) = x^2}, starting from
 * {@code x = 3}.
 */
public class GradientDescent {

    /** Starting point used by {@link #main}. */
    private static final double START_X = 3;

    /** Step size (learning rate) used by {@link #main}. */
    private static final double LEARNING_RATE = 0.01;

    /** {@link #main} stops once the loss is at or below this value. */
    private static final double TOLERANCE = 0.01;

    /** Safety cap on update steps so a bad configuration cannot loop forever. */
    private static final int MAX_ITERATIONS = 10_000;

    /**
     * Runs gradient descent with the default settings. For each update step it prints the loss
     * measured before that step, then prints the final {@code x} and its loss.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        double x = descend(START_X, LEARNING_RATE, TOLERANCE, MAX_ITERATIONS,
                loss -> System.out.println("loss:" + loss));
        System.out.println("x:" + x + " y:" + pow(x));
    }

    /**
     * Minimizes the loss {@link #pow} by repeating the update {@code x -= learningRate * div(x)}
     * until the loss is at or below {@code tolerance}.
     *
     * @param x0            starting point
     * @param learningRate  step size; must be positive and finite
     * @param tolerance     stop as soon as the loss is {@code <= tolerance}; must be {@code >= 0}
     * @param maxIterations maximum number of update steps; must be {@code >= 0}
     * @param onLoss        called once per update step, right after the step is applied, with the
     *                      loss measured before that step; must not be {@code null}
     * @return the first visited point whose loss is {@code <= tolerance}
     * @throws IllegalArgumentException if an argument is out of range (including NaN)
     * @throws IllegalStateException    if the loss becomes NaN or infinite (divergence), or the
     *                                  tolerance is not reached within {@code maxIterations} steps
     */
    public static double descend(double x0, double learningRate, double tolerance,
            int maxIterations, java.util.function.DoubleConsumer onLoss) {
        // Negated comparisons so that NaN arguments are rejected too.
        if (!(learningRate > 0) || Double.isInfinite(learningRate)) {
            throw new IllegalArgumentException(
                    "learningRate must be positive and finite: " + learningRate);
        }
        if (!(tolerance >= 0)) {
            throw new IllegalArgumentException("tolerance must be >= 0: " + tolerance);
        }
        if (maxIterations < 0) {
            throw new IllegalArgumentException("maxIterations must be >= 0: " + maxIterations);
        }
        java.util.Objects.requireNonNull(onLoss, "onLoss");

        double x = x0;
        for (int step = 0; ; step++) {
            double loss = pow(x);
            // Detect divergence explicitly. Once x overflows, Inf - Inf yields NaN, and NaN fails
            // every comparison, so a plain "loss > tolerance" loop would silently stop on NaN.
            if (!Double.isFinite(loss)) {
                throw new IllegalStateException(
                        "gradient descent diverged at x=" + x + " (loss=" + loss + ")");
            }
            if (loss <= tolerance) {
                return x;
            }
            if (step == maxIterations) {
                throw new IllegalStateException("no convergence after " + maxIterations
                        + " steps; x=" + x + ", loss=" + loss);
            }
            x -= learningRate * div(x); // step against the gradient
            onLoss.accept(loss);
        }
    }

    /**
     * Loss function {@code f(x) = x^2}.
     *
     * <p>Despite its name, this is the loss being minimized, not a general power function.
     *
     * @param x the point to evaluate
     * @return {@code x} squared
     */
    public static double pow(double x) {
        return Math.pow(x, 2);
    }

    /**
     * Derivative of the loss {@link #pow}: {@code f'(x) = 2x}.
     *
     * <p>Despite its name, this computes the derivative (gradient), not a division.
     *
     * @param x the point to evaluate
     * @return {@code 2 * x}
     */
    public static double div(double x) {
        return 2 * x;
    }
}