The gradient of a scalar-valued function is a vector that points in the direction of the steepest increase of the function. It’s one of the most fundamental ideas in multivariable calculus.
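For a function $f(x, y)$, the gradient is defined as:
$$\nabla f(x, y) = \left(\frac{\partial f}{\partial x}, \frac{\partial f}{\partial y}\right)$$
For a function $f(x, y, z)$:
$$\nabla f(x, y, z) = \left(\frac{\partial f}{\partial x}, \frac{\partial f}{\partial y}, \frac{\partial f}{\partial z}\right)$$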
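The definition works for any number of variables: each component is a partial derivative along one axis. As a rough illustration, the sketch below approximates every component with central finite differences, $(f(p + h e_i) - f(p - h e_i)) / 2h$. The helper name `numerical_gradient` and the step size `h` are assumptions made for this example. It gives a numerical approximation, not an exact derivative.

```python
import numpy as np

def numerical_gradient(f, point, h=1e-6):
    """Hypothetical helper: approximate the gradient of a scalar function f
    at `point` (a 1D sequence) using central differences along each axis."""
    p = np.asarray(point, dtype=float)
    grad = np.zeros_like(p)
    for i in range(p.size):
        step = np.zeros_like(p)
        step[i] = h  # move only along the i-th coordinate
        grad[i] = (f(p + step) - f(p - step)) / (2 * h)
    return grad

# Three variables: f(x, y, z) = x*y + z**2, whose exact gradient is (y, x, 2z)
g = numerical_gradient(lambda p: p[0] * p[1] + p[2] ** 2, [1.0, 2.0, 3.0])
print(g)  # approximately [2. 1. 6.]
```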
Why is the Gradient Important?
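It tells us the direction of the fastest increase of the function, and its magnitude is the rate of increase in that direction.
It is perpendicular (normal) to level curves (in 2D) or level surfaces (in 3D) wherever it is nonzero.
It is used in optimization methods such as gradient descent, which repeatedly steps in the direction of $-\nabla f$ to approach a minimum.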
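To make the optimization use concrete, here is a minimal gradient descent sketch. The test function $f(x, y) = (x - 3)^2 + (y + 1)^2$, its minimum at $(3, -1)$, the learning rate, and the step count are all illustrative assumptions, not values from the text.

```python
import numpy as np

def grad(p):
    # Gradient of the assumed test function f(x, y) = (x - 3)**2 + (y + 1)**2
    x, y = p
    return np.array([2 * (x - 3), 2 * (y + 1)])

def gradient_descent(grad_fn, start, learning_rate=0.1, steps=100):
    """Move against the gradient: p <- p - learning_rate * grad_fn(p)."""
    p = np.asarray(start, dtype=float)
    for _ in range(steps):
        p = p - learning_rate * grad_fn(p)
    return p

minimum = gradient_descent(grad, start=[0.0, 0.0])
print(minimum)  # close to [ 3. -1.]
```

With this learning rate each step shrinks the distance to the minimum by a factor of 0.8, so 100 steps are plenty; a learning rate that is too large for a given function can overshoot and diverge instead.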
Step-by-Step Mathematical Explanation
Let’s say we have a scalar function:
$$f(x, y) = x^2 + y^2$$
Step 1: Compute Partial Derivatives
$$\frac{\partial f}{\partial x} = 2x \quad \text{and} \quad \frac{\partial f}{\partial y} = 2y$$
Step 2: Write the Gradient Vector
$$\nabla f(x, y) = (2x, 2y)$$
At the point $(1, 1)$, the gradient becomes:
$$\nabla f(1, 1) = (2, 2)$$
This vector points in the direction of steepest ascent of the function: directly away from the origin, where $f$ has its minimum.
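One way to check the steepest-ascent claim numerically is through directional derivatives: the rate of change along a unit vector $u$ is $\nabla f \cdot u$. The sketch below samples 360 unit directions at $(1, 1)$ (an arbitrary sampling density chosen for illustration), picks the one with the largest slope, and also evaluates the slope along the tangent to the level circle $x^2 + y^2 = 2$.

```python
import numpy as np

def grad_f(x, y):
    # Gradient of f(x, y) = x**2 + y**2, as derived in Step 2
    return np.array([2 * x, 2 * y])

g = grad_f(1.0, 1.0)  # (2, 2)

# Unit vectors at 1-degree increments around the circle
angles = np.linspace(0, 2 * np.pi, 360, endpoint=False)
directions = np.stack([np.cos(angles), np.sin(angles)], axis=1)
slopes = directions @ g  # directional derivative grad . u for every u

best = directions[np.argmax(slopes)]
print(best)          # approximately [0.7071 0.7071], i.e. g / |g|
print(slopes.max())  # approximately 2.8284, i.e. |g| = 2*sqrt(2)

# Along the tangent of the level circle through (1, 1) the slope is zero
tangent = np.array([-1.0, 1.0]) / np.sqrt(2)
print(tangent @ g)   # 0.0
```

The largest slope occurs along $\nabla f / \lVert \nabla f \rVert$ and equals $\lVert \nabla f \rVert$, while the zero slope along the tangent reflects the gradient being normal to the level curve.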
Example 1
$f(x, y) = x^2 + y^2$
Let’s compute the gradient on a grid of points in the xy-plane and visualize it as a vector field.
```python
import numpy as np
import matplotlib.pyplot as plt

# Define the function
def f(x, y):
    return x**2 + y**2

# Define partial derivatives (gradient components)
def grad_f(x, y):
    df_dx = 2 * x
    df_dy = 2 * y
    return df_dx, df_dy

# Generate a grid
x = np.linspace(-2, 2, 20)
y = np.linspace(-2, 2, 20)
X, Y = np.meshgrid(x, y)

# Compute gradient on the grid
U, V = grad_f(X, Y)

# Plot the vector field (gradient)
plt.figure(figsize=(8, 6))
plt.quiver(X, Y, U, V, color='blue')
plt.title("Gradient Vector Field of f(x, y) = x² + y²")
plt.xlabel('x')
plt.ylabel('y')
plt.grid()
plt.axis('equal')
plt.show()
```
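To sanity-check the hand-derived `grad_f`, you can compare it with NumPy's finite-difference `np.gradient` on the same grid. This is a sketch under one assumption worth stating: `np.meshgrid` uses `'xy'` indexing by default, so y varies along axis 0 and x along axis 1, which is why the y coordinates are passed first. `edge_order=2` is chosen so that the one-sided differences at the grid edges are also exact for a quadratic.

```python
import numpy as np

# Same grid as above; with 'xy' indexing, axis 0 follows y and axis 1 follows x
x = np.linspace(-2, 2, 20)
y = np.linspace(-2, 2, 20)
X, Y = np.meshgrid(x, y)
F = X**2 + Y**2

# np.gradient returns one derivative per axis, in axis order (y first, then x)
dF_dy, dF_dx = np.gradient(F, y, x, edge_order=2)

# Compare with the analytic gradient (2x, 2y)
print(np.allclose(dF_dx, 2 * X), np.allclose(dF_dy, 2 * Y))  # True True
```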
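Example 2
$f(x, y) = x e^{-x^2 - y^2}$
The same workflow applies to a less symmetric function. Here the gradient components come from the product and chain rules.
```python
import numpy as np
import matplotlib.pyplot as plt

# Define the function
def f2(x, y):
    return x * np.exp(-x**2 - y**2)

# Partial derivatives (product rule in x, chain rule in y)
def grad_f2(x, y):
    df_dx = (1 - 2*x**2) * np.exp(-x**2 - y**2)
    df_dy = -2 * x * y * np.exp(-x**2 - y**2)
    return df_dx, df_dy

# Generate grid
x = np.linspace(-2, 2, 20)
y = np.linspace(-2, 2, 20)
X, Y = np.meshgrid(x, y)

# Gradient values
U2, V2 = grad_f2(X, Y)

# Plot
plt.figure(figsize=(8, 6))
plt.quiver(X, Y, U2, V2, color='green')
plt.title("Gradient Vector Field of f(x, y) = x * exp(-x² - y²)")
plt.xlabel('x')
plt.ylabel('y')
plt.grid()
plt.axis('equal')
plt.show()
```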
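The partial derivatives hard-coded in `grad_f2` follow from the product rule (for $x$) and the chain rule (for $y$):

$$
\begin{aligned}
\frac{\partial f}{\partial x} &= e^{-x^2-y^2} + x \cdot (-2x)\, e^{-x^2-y^2} = (1 - 2x^2)\, e^{-x^2-y^2},\\
\frac{\partial f}{\partial y} &= x \cdot (-2y)\, e^{-x^2-y^2} = -2xy\, e^{-x^2-y^2}.
\end{aligned}
$$

Because $e^{-x^2-y^2} > 0$, the gradient vanishes only where $1 - 2x^2 = 0$ and $xy = 0$, that is, at $(\pm 1/\sqrt{2}, 0)$. In the plot, the arrows converge on the maximum at $(1/\sqrt{2}, 0)$ and point away from the minimum at $(-1/\sqrt{2}, 0)$.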