# table_detector.py 5.72 KB
import cv2
import numpy as np
import os

def filter_horizontal_lines(lines_h, img_width, min_h_len_ratio=0.7, tol_y=10):
    """Reduce raw line segments to one representative y per horizontal rule.

    Segments whose endpoints differ by at most 3 px in y are treated as
    horizontal. They are sorted by y and chained into groups: a segment joins
    the current group when its y is within ``tol_y`` of the previous segment's
    y. From every group the longest segment is taken, and it is kept only if
    its length is at least ``min_h_len_ratio * img_width``.

    Groups that are too short are skipped individually; they do not stop the
    scan, so a short stray line never hides the valid rules below it.

    Args:
        lines_h: Iterable of segments where ``item[0]`` unpacks to
            ``(x1, y1, x2, y2)`` (the layout ``cv2.HoughLinesP`` is called
            with in this file), or ``None`` when nothing was detected.
        img_width: Image width in pixels, used to scale the length threshold.
        min_h_len_ratio: Minimum segment length as a fraction of ``img_width``.
        tol_y: Maximum y distance between consecutive segments of one group.

    Returns:
        A pair ``(filtered_lines, line_segments)``: the kept y coordinates in
        ascending order, and the matching ``(x1, x2, y)`` tuples.

    Side effects:
        Prints ``"Detected N rows"`` where N is ``len(filtered_lines) - 1``
        (never negative).
    """
    if lines_h is None:
        return [], []

    # Keep only near-horizontal segments as (y_mid, length, x1, x2).
    candidates = []
    for line in lines_h:
        x1, y1, x2, y2 = line[0]
        if abs(y1 - y2) <= 3:
            y_mid = int(round((y1 + y2) / 2))
            candidates.append((y_mid, abs(x2 - x1), x1, x2))

    # Stable sort on y only, so equal-y segments keep their detection order.
    candidates.sort(key=lambda c: c[0])

    min_len = min_h_len_ratio * img_width
    filtered_lines, line_segments = [], []

    def _keep_longest(group):
        """Record the group's longest segment if it is long enough."""
        y, length, x1, x2 = max(group, key=lambda c: c[1])
        if length >= min_len:
            filtered_lines.append(y)
            line_segments.append((x1, x2, y))

    group = []
    for cand in candidates:
        # A y gap larger than tol_y closes the current group.
        if group and abs(cand[0] - group[-1][0]) > tol_y:
            _keep_longest(group)
            group = []
        group.append(cand)
    if group:
        _keep_longest(group)

    total_rows = max(0, len(filtered_lines) - 1)
    print(f"Detected {total_rows} rows")
    return filtered_lines, line_segments


def _has_vertical_edge(v_lines, x, y_top, y_bottom, tol_x, min_coverage=0.7):
    """Return True if a vertical segment near ``x`` covers enough of the row.

    A segment ``(vx, vy_start, vy_end)`` counts when ``vx`` is within
    ``tol_x`` of ``x`` and its overlap with ``[y_top, y_bottom]`` is at least
    ``min_coverage`` of the row height.
    """
    needed = min_coverage * (y_bottom - y_top)
    return any(
        abs(vx - x) <= tol_x
        and (min(vy_end, y_bottom) - max(vy_start, y_top)) >= needed
        for vx, vy_start, vy_end in v_lines
    )


def _split_row_into_cells(x_pos, y_top, y_bottom, row_idx, v_lines, tol_x):
    """Build the cell dicts of one row, merging columns with no divider."""
    total_cols = len(x_pos) - 1
    # Evaluate each column boundary once per row instead of once per cell.
    has_edge = [
        _has_vertical_edge(v_lines, x, y_top, y_bottom, tol_x) for x in x_pos
    ]

    cells = []
    j = 0
    while j < total_cols:
        if not has_edge[j]:
            # No left border here: not the start of a cell.
            j += 1
            continue

        # Extend to the right across boundaries that have no dividing line.
        col_end = j
        while col_end + 1 < total_cols and not has_edge[col_end + 1]:
            col_end += 1

        if not has_edge[col_end + 1]:
            # Reached the last boundary without a closing line; no later
            # boundary has a line either, so the row holds no more cells.
            break

        col_id = j if col_end == j else f"{j}-{col_end}"
        cells.append({
            "cell": (x_pos[j], y_top, x_pos[col_end + 1], y_bottom),
            "row_idx": row_idx,
            "col_idx": col_id,
        })
        j = col_end + 1
    return cells


def detect_tables(image_path, gap_threshold=50):
    """Detect a ruled table grid in an image and return its cell layout.

    The image is converted to grayscale, blurred, edge-detected with Canny and
    scanned twice with ``cv2.HoughLinesP``: once for long horizontal segments
    (reduced to row boundaries by ``filter_horizontal_lines``) and once for
    long vertical segments (clustered into column boundaries 10 px apart).
    For every row, a cell is emitted between two column boundaries that both
    have a vertical segment covering at least 70% of the row height.
    Boundaries in between that lack such a segment are merged over, producing
    a cell that spans several columns.

    Args:
        image_path: Path of the image to analyse.
        gap_threshold: Currently unused; kept so existing callers that pass
            it continue to work.

    Returns:
        A list that is empty when fewer than two row or column boundaries
        were found, and otherwise holds one dict with keys:
        ``"total_rows"``, ``"total_cols"``, ``"table_box"`` as
        ``(x_min, y_min, x_max, y_max)``, and ``"cells"`` — one list per row
        of dicts with ``"cell"`` (box tuple), ``"row_idx"`` and ``"col_idx"``
        (an int, or ``"start-end"`` for a merged cell).

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot load ``image_path``.

    Side effects:
        Writes an annotated copy of the image next to the input, named
        ``<input without extension>_fix_debug.jpg``.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"Không đọc được ảnh: {image_path}")

    img_height, img_width = img.shape[:2]
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    blur = cv2.GaussianBlur(gray, (3, 3), 0)
    edges = cv2.Canny(blur, 50, 150, apertureSize=3)

    # --- Horizontal lines ---
    lines_h = cv2.HoughLinesP(edges, 1, np.pi/180, threshold=120,
                              minLineLength=int(img_width * 0.6), maxLineGap=20)
    ys, _ = filter_horizontal_lines(lines_h, img_width, min_h_len_ratio=0.7, tol_y=10)
    total_rows = max(0, len(ys) - 1)

    # --- Vertical lines ---
    lines_v = cv2.HoughLinesP(edges, 1, np.pi/180, threshold=100,
                              minLineLength=int(img_height * 0.4), maxLineGap=20)
    v_lines = []
    if lines_v is not None:
        for line in lines_v:
            x1, y1, x2, y2 = line[0]
            if abs(x1 - x2) <= 3:  # near-vertical only
                x_mid = int(round((x1 + x2) / 2))
                v_lines.append((x_mid, min(y1, y2), max(y1, y2)))

    # Cluster x coordinates: a new column boundary starts when x is more than
    # tol_v past the previous boundary.
    tol_v = 10
    x_pos = []
    for x in sorted(v[0] for v in v_lines):
        if not x_pos or x - x_pos[-1] > tol_v:
            x_pos.append(x)

    total_cols = max(0, len(x_pos) - 1)
    tables = []

    if total_rows > 0 and total_cols > 0:
        y_min, y_max = ys[0], ys[-1]
        x_min, x_max = x_pos[0], x_pos[-1]
        table_box = (x_min, y_min, x_max, y_max)

        rows_data = []
        for i in range(total_rows):
            row_cells = _split_row_into_cells(
                x_pos, ys[i], ys[i + 1], i, v_lines, tol_v
            )
            for cell in row_cells:
                left, top, right, bottom = cell["cell"]
                cv2.rectangle(img, (left, top), (right, bottom), (0, 255, 255), 1)
            rows_data.append(row_cells)

        tables.append({
            "total_rows": int(total_rows),
            "total_cols": int(total_cols),
            "table_box": table_box,
            "cells": rows_data
        })
        cv2.rectangle(img, (x_min, y_min), (x_max, y_max), (255, 0, 0), 2)

    debug_path = os.path.splitext(image_path)[0] + "_fix_debug.jpg"
    cv2.imwrite(debug_path, img)

    return tables