import numpy as np
import random
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# Fix both RNG seeds so repeated runs are reproducible: NumPy's global
# generator drives the synthetic noise in generate_data(), while the stdlib
# generator drives the random two-point sampling in ransac() and msac().
np.random.seed(42)
random.seed(42)
def generate_data():
    """Build a synthetic dataset around the line ``y = 2.5x - 1.0``.

    Two populations are generated on ``x`` in ``[0, 10]``:

    * 60 inliers with zero-mean Gaussian noise (std 0.3);
    * 40 "near-outliers" whose Gaussian noise has mean 1.5 and std 0.5,
      i.e. they sit slightly above the true line.

    Noise is drawn from NumPy's global random generator, inliers first.

    Returns:
        A 4-tuple ``(points, X_all, y_all, y_true)``: ``points`` is a list
        of ``(x, y)`` pairs over both populations, ``X_all`` / ``y_all`` are
        the same data as arrays (inliers followed by near-outliers), and
        ``y_true`` is the noise-free line evaluated on the 60 inlier
        abscissas only.
    """

    def _ground_truth(x):
        # Noise-free model shared by both populations.
        return 2.5 * x - 1.0

    clean_x = np.linspace(0, 10, 60)
    clean_line = _ground_truth(clean_x)
    noisy_clean = clean_line + np.random.normal(0, 0.3, size=clean_x.shape)

    shifted_x = np.linspace(0, 10, 40)
    shifted_y = _ground_truth(shifted_x) + np.random.normal(
        1.5, 0.5, size=shifted_x.shape
    )

    all_x = np.concatenate((clean_x, shifted_x))
    all_y = np.concatenate((noisy_clean, shifted_y))
    return list(zip(all_x, all_y)), all_x, all_y, clean_line
def fit_line(points):
    """Least-squares fit of ``y = a*x + b`` to a collection of points.

    Args:
        points: Iterable-twice collection of items whose index 0 is the x
            coordinate and index 1 the y coordinate.

    Returns:
        ``(a, b)``: slope and intercept of the fitted line.
    """
    x_coords = np.array([pt[0] for pt in points])
    y_coords = np.array([pt[1] for pt in points])
    # A degree-1 polyfit yields the coefficients highest power first.
    slope, intercept = np.polyfit(x_coords, y_coords, 1)
    return slope, intercept
def ransac(points, iterations=200, threshold=2.0):
    """Fit a line with RANSAC, scoring each hypothesis by its inlier count.

    Every iteration draws two entries of ``points`` with ``random.sample``,
    fits a line to them with ``fit_line`` and collects the points whose
    absolute vertical residual is strictly below ``threshold``. The
    hypothesis with the most inliers wins; on ties the earliest one is kept.

    Args:
        points: Sequence of ``(x, y)`` pairs; at least two are required.
        iterations: Number of random two-point samples to draw.
        threshold: Exclusive upper bound on ``|y - (a*x + b)|`` for a point
            to count as an inlier.

    Returns:
        ``(best_model, best_inliers)`` where ``best_model`` is the ``(a, b)``
        pair of the winning line and ``best_inliers`` the list of ``(x, y)``
        points within ``threshold`` of it. ``best_model`` is ``None`` and the
        list is empty if no hypothesis produced any inlier (for example when
        ``iterations`` is 0).

    Raises:
        ValueError: Raised by ``random.sample`` when ``points`` holds fewer
            than two entries.
    """
    best_model = None
    best_inliers = []
    for _ in range(iterations):
        sample = random.sample(points, 2)
        # Two points sharing the same x cannot define y = a*x + b: the
        # least-squares system is rank-deficient and the fit is unusable.
        # Skip the sample; the draw still counts as an iteration, so the
        # random sequence seen by the remaining iterations is unchanged.
        if sample[0][0] == sample[1][0]:
            continue
        a, b = fit_line(sample)
        inliers = [
            (x, y) for x, y in points if abs(y - (a * x + b)) < threshold
        ]
        if len(inliers) > len(best_inliers):
            best_inliers = inliers
            best_model = (a, b)
    return best_model, best_inliers
def msac(points, iterations=200, threshold=2.0):
    """Fit a line with MSAC, scoring each hypothesis by a truncated cost.

    Every iteration draws two entries of ``points`` with ``random.sample``
    and fits a line to them with ``fit_line``. Each point then contributes
    its squared vertical residual when ``|residual| < threshold`` and the
    constant ``threshold ** 2`` otherwise. The hypothesis with the lowest
    total cost wins; on ties the earliest one is kept.

    Args:
        points: Sequence of ``(x, y)`` pairs; at least two are required.
        iterations: Number of random two-point samples to draw.
        threshold: Residual magnitude at which a point's cost is capped.

    Returns:
        ``(best_model, best_cost)`` where ``best_model`` is the ``(a, b)``
        pair of the winning line and ``best_cost`` its total truncated cost.
        ``best_model`` is ``None`` and ``best_cost`` is ``inf`` if no
        hypothesis was evaluated (for example when ``iterations`` is 0).

    Raises:
        ValueError: Raised by ``random.sample`` when ``points`` holds fewer
            than two entries.
    """
    # Loop-invariant cost charged to every point outside the threshold band.
    outlier_cost = threshold ** 2

    def _point_cost(residual):
        # Truncated quadratic: squared error inside the band, flat outside.
        if abs(residual) < threshold:
            return residual ** 2
        return outlier_cost

    best_model = None
    best_cost = float("inf")
    for _ in range(iterations):
        sample = random.sample(points, 2)
        # Two points sharing the same x cannot define y = a*x + b: the
        # least-squares system is rank-deficient and the fit is unusable.
        # Skip the sample; the draw still counts as an iteration, so the
        # random sequence seen by the remaining iterations is unchanged.
        if sample[0][0] == sample[1][0]:
            continue
        a, b = fit_line(sample)
        cost = sum(_point_cost(y - (a * x + b)) for x, y in points)
        if cost < best_cost:
            best_cost = cost
            best_model = (a, b)
    return best_model, best_cost
def mse(model, points):
    """Mean squared vertical error of a line over a set of points.

    Args:
        model: ``(a, b)`` pair describing the line ``y = a*x + b``.
        points: Iterable of ``(x, y)`` pairs.

    Returns:
        The mean of ``(y - (a*x + b)) ** 2`` over all points, as computed by
        ``np.mean``.
    """
    slope, intercept = model
    squared_errors = []
    for px, py in points:
        predicted = slope * px + intercept
        squared_errors.append((py - predicted) ** 2)
    return np.mean(squared_errors)
if __name__ == "__main__":
    # Generate the data, then run both estimators on it. RANSAC runs first
    # so the shared stdlib RNG is consumed in a fixed order.
    points, X_all, y_all, y_true = generate_data()
    ransac_model, ransac_inliers = ransac(points)
    msac_model, msac_cost = msac(points)

    # Report the recovered line parameters next to the ground truth.
    print("=== Model parameters ===")
    print("True line: y = 2.5x - 1.0")
    print(f"RANSAC line: y = {ransac_model[0]:.3f}x + {ransac_model[1]:.3f}")
    print(f"MSAC line: y = {msac_model[0]:.3f}x + {msac_model[1]:.3f}")

    # Mean squared error of each model over the full dataset.
    print("\n=== Error comparison ===")
    ransac_error = mse(ransac_model, points)
    print(f"RANSAC MSE: {ransac_error:.4f}")
    msac_error = mse(msac_model, points)
    print(f"MSAC MSE: {msac_error:.4f}")

    # Scatter the data and overlay the three lines on a common x grid.
    plt.figure(figsize=(8, 5))
    plt.scatter(X_all, y_all, s=10, color="gray", label="data")
    X_plot = np.linspace(0, 10, 100)
    overlays = (
        (2.5 * X_plot - 1, "k--", "Ground truth"),
        (ransac_model[0] * X_plot + ransac_model[1], "b", "RANSAC"),
        (msac_model[0] * X_plot + msac_model[1], "r", "MSAC"),
    )
    for curve, line_format, legend_name in overlays:
        plt.plot(X_plot, curve, line_format, label=legend_name)
    plt.legend()
    plt.title("RANSAC vs MSAC (near-outliers visible)")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.tight_layout()

    # Write the figure to disk and release it.
    plt.savefig("ransac_vs_msac.png", dpi=150)
    plt.close()
    print("\nFigure saved to: ransac_vs_msac.png")