Gradient Descent & Ascent in Java

Dr. Manoj Kumar Yadav
Nov 19, 2023

Iterative optimization adjusts the value of a variable in a given equation until the output meets a certain criterion. This is usually done with an optimization algorithm, and a popular one is the gradient descent method.
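The idea can be written as a one-line update rule. In the sketch below, the symbols x_n (the current estimate) and η (the learning rate) are notation introduced here for illustration; the sign of the step is the only difference between descent and ascent.

```latex
x_{n+1} = x_n - \eta \, f'(x_n) \quad \text{(descent)},
\qquad
x_{n+1} = x_n + \eta \, f'(x_n) \quad \text{(ascent)}
```

For instance, with f(x) = x², f'(x) = 2x, η = 0.1 and a starting point x₀ = 3, descent gives x₁ = 3 - 0.1 · 6 = 2.4 and x₂ = 2.4 - 0.1 · 4.8 = 1.92. Each step multiplies x by 0.8, so the estimates shrink toward the minimum at x = 0.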

Here’s a simple example in Java using gradient descent for a univariate function:

import java.util.function.Function;

public class GradientDescent {

    // Safety cap so that a diverging run cannot loop forever
    private static final int MAX_ITERATIONS = 100_000;
    // Stop once two successive estimates differ by less than this
    private static final double TOLERANCE = 1e-6;

    public static void main(String[] args) {
        // Example: minimize the function f(x) = x^2
        Function<Double, Double> function = x -> x * x;

        double initialGuess = 3.0; // Starting point for the search
        double learningRate = 0.1; // The learning rate determines the size of each step

        // Gradient descent
        double resultGD = gradientDescent(function, initialGuess, learningRate);
        System.out.println("Local minimum found using gradient descent at x = " + resultGD);
        System.out.println("Minimum value of the function = " + function.apply(resultGD));

        // Gradient ascent: f(x) = x^2 has no maximum, so ascent on it would diverge.
        // Use g(x) = -x^2, whose maximum is at x = 0, instead.
        Function<Double, Double> concave = x -> -(x * x);
        double resultGA = gradientAscent(concave, initialGuess, learningRate);
        System.out.println("Local maximum found using gradient ascent at x = " + resultGA);
        System.out.println("Maximum value of the function = " + concave.apply(resultGA));
    }

    // Gradient descent algorithm
    private static double gradientDescent(Function<Double, Double> function, double initialGuess, double learningRate) {
        double x = initialGuess;
        long iteration = 0;

        // Loop until convergence or until the iteration cap is reached
        while (iteration < MAX_ITERATIONS) {
            iteration++;
            double gradient = computeGradient(function, x);
            double nextX = x - learningRate * gradient;

            // Check for convergence (adjust the threshold as needed)
            boolean converged = Math.abs(nextX - x) < TOLERANCE;
            x = nextX;
            if (converged) {
                break;
            }
        }

        System.out.println("Gradient Descent Iteration Count = " + iteration);

        return x;
    }

    // Gradient ascent algorithm
    private static double gradientAscent(Function<Double, Double> function, double initialGuess, double learningRate) {
        double x = initialGuess;
        long iteration = 0;

        // Loop until convergence or until the iteration cap is reached
        while (iteration < MAX_ITERATIONS) {
            iteration++;
            double gradient = computeGradient(function, x);
            double nextX = x + learningRate * gradient; // Note the sign change for ascent

            // Check for convergence (adjust the threshold as needed)
            boolean converged = Math.abs(nextX - x) < TOLERANCE;
            x = nextX;
            if (converged) {
                break;
            }
        }

        System.out.println("Gradient Ascent Iteration Count = " + iteration);

        return x;
    }

    // Estimate the derivative of the function at a given point
    private static double computeGradient(Function<Double, Double> function, double x) {
        double h = 1e-5; // Small step size for numerical differentiation
        // A central difference is more accurate than a one-sided difference for the same h
        return (function.apply(x + h) - function.apply(x - h)) / (2 * h);
    }
}

The code comes from an LLM. It is interesting that such a simple implementation can produce such useful results.

This example minimizes the function f(x) = x², but you can replace the function with your own function. The “gradientDescent” function performs the optimization, and you might need to adjust the learning rate and convergence criteria based on the characteristics of your specific problem.
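To illustrate how the learning rate affects convergence, here is a minimal sketch that runs the same kind of loop on a different function with several learning rates. The class name `LearningRateDemo`, the objective (x - 2)² + 1, the starting point, and the chosen rates are all assumptions made for illustration; they are not part of the original code.

```java
import java.util.function.DoubleUnaryOperator;

class LearningRateDemo {

    // Central-difference estimate of f'(x); the step h is an illustrative choice
    static double derivative(DoubleUnaryOperator f, double x) {
        double h = 1e-5;
        return (f.applyAsDouble(x + h) - f.applyAsDouble(x - h)) / (2 * h);
    }

    // Runs gradient descent and returns {final x, iterations used}.
    // The run stops when a step is smaller than tolerance or the cap is reached.
    static double[] minimize(DoubleUnaryOperator f, double start, double learningRate,
                             double tolerance, int maxIterations) {
        double x = start;
        int iterations = 0;
        while (iterations < maxIterations) {
            iterations++;
            double next = x - learningRate * derivative(f, x);
            boolean converged = Math.abs(next - x) < tolerance;
            x = next;
            if (converged) {
                break;
            }
        }
        return new double[] {x, iterations};
    }

    public static void main(String[] args) {
        // Hypothetical objective with its minimum at x = 2
        DoubleUnaryOperator f = x -> (x - 2) * (x - 2) + 1;
        for (double rate : new double[] {0.01, 0.1, 0.5, 1.0}) {
            double[] result = minimize(f, 5.0, rate, 1e-6, 10_000);
            System.out.printf("rate=%.2f  x=%.6f  iterations=%d%n",
                    rate, result[0], (int) result[1]);
        }
    }
}
```

For this particular function, a small rate should need many more iterations than a moderate one, while a rate of 1.0 makes the estimate jump back and forth across the minimum until the iteration cap stops it. That is why the cap is worth having alongside the convergence threshold.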

The same approach has been used to find the bounds of a function automatically, so that meaningful graphs can be generated without human input.

The code above is designed for smooth, differentiable functions. The function f(x) = 1/x² is undefined, and therefore not differentiable, at x = 0, and it has no minimum to converge to: its value only approaches 0 as |x| grows. The gradient descent algorithm may not converge in such cases.

If you want to handle non-differentiable functions or functions with singularities, you might need a more advanced optimization algorithm, such as one designed for non-smooth or non-convex functions, or you might need to handle the singularity separately (for f(x) = 1/x², the one at x = 0). Gradient descent is not the best choice for all types of functions.
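One simple way to handle a singularity separately is to stop as soon as the function value or the gradient is no longer a finite number, and to report why the loop ended. The sketch below is only an assumption about how such a guard could look; the class `GuardedDescent`, the `Status` enum, and the `Result` holder are hypothetical names introduced here, not part of the original code.

```java
import java.util.function.DoubleUnaryOperator;

class GuardedDescent {

    // Why the loop stopped
    enum Status { CONVERGED, MAX_ITERATIONS, NON_FINITE }

    // Simple holder for the outcome of a run
    static final class Result {
        final double x;
        final Status status;

        Result(double x, Status status) {
            this.x = x;
            this.status = status;
        }
    }

    // Central-difference estimate of f'(x); the step h is an illustrative choice
    static double derivative(DoubleUnaryOperator f, double x) {
        double h = 1e-5;
        return (f.applyAsDouble(x + h) - f.applyAsDouble(x - h)) / (2 * h);
    }

    static Result minimize(DoubleUnaryOperator f, double start, double learningRate,
                           double tolerance, int maxIterations) {
        double x = start;
        for (int i = 0; i < maxIterations; i++) {
            double gradient = derivative(f, x);
            // Stop when the function value or the gradient is infinite or NaN
            if (!Double.isFinite(f.applyAsDouble(x)) || !Double.isFinite(gradient)) {
                return new Result(x, Status.NON_FINITE);
            }
            double next = x - learningRate * gradient;
            if (Math.abs(next - x) < tolerance) {
                return new Result(next, Status.CONVERGED);
            }
            x = next;
        }
        return new Result(x, Status.MAX_ITERATIONS);
    }

    public static void main(String[] args) {
        DoubleUnaryOperator inverseSquare = x -> 1.0 / (x * x);
        // Starting exactly on the singularity
        System.out.println(minimize(inverseSquare, 0.0, 0.1, 1e-6, 10_000).status);
        // Starting next to it: the estimate drifts away from 0 without settling
        System.out.println(minimize(inverseSquare, 1.0, 0.1, 1e-6, 10_000).status);
        // A smooth function for comparison
        System.out.println(minimize(x -> x * x, 3.0, 0.1, 1e-6, 10_000).status);
    }
}
```

With these settings, the run that starts at x = 0 should stop with NON_FINITE, the run that starts at x = 1 should exhaust its iterations because 1/x² keeps decreasing as x grows, and the x² run should converge. A status value like this lets the caller tell a genuine minimum from a run that merely stopped.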


Dr. Manoj Kumar Yadav

Doctor of Business Administration | Sr. Director-Engineering at redBus | Data Engineering | ML | Servers | Serverless | Java | Python | Dart | 3D/2D