Commit 8ff8a10a authored by Jonas Schüppen's avatar Jonas Schüppen
Browse files

adjusted list styling for equations in list items

parent 0c72f6b7
Loading
Loading
Loading
Loading
Loading
+176 −348
Original line number Diff line number Diff line
@@ -333,24 +333,18 @@ def update_heading_styles(docx_input, docx_output):


def update_unnumbered_lists(docx_input, docx_output):
    """
    Updates unnumbered list items (starting with "- ") in tables to appear as bulleted lists.
    For list items in tables: removes "- " prefix and creates separate paragraphs with FP style and numPr.
    For list items outside tables: removes "- " prefix and adds B1 style.
    import zipfile, tempfile, shutil, os
    from lxml import etree
    from docx.oxml import OxmlElement

    Parameters
    ----------
    docx_input : str
        Path to the input DOCX file.
    docx_output : str
        Path to the output DOCX file.
    """
    ns = {"w": "http://schemas.openxmlformats.org/wordprocessingml/2006/main"}
    ns = {"w": "http://schemas.openxmlformats.org/wordprocessingml/2006/main",
          "m": "http://schemas.openxmlformats.org/officeDocument/2006/math"}

    # Read XML files
    # -----------------------------
    # Read XML
    # -----------------------------
    with zipfile.ZipFile(docx_input, 'r') as zin:
        xml_data = zin.read("word/document.xml")
        # Try to read numbering.xml, if it doesn't exist, numbering_root will be None
        try:
            numbering_data = zin.read("word/numbering.xml")
            numbering_root = etree.fromstring(numbering_data)
@@ -358,22 +352,15 @@ def update_unnumbered_lists(docx_input, docx_output):
            numbering_root = None

    root = etree.fromstring(xml_data)
    paragraphs = root.xpath('.//w:p', namespaces=ns)

    # -----------------------------
    # Original numbering logic (unchanged)
    # -----------------------------
    def is_numbered_list(para, numbering_root):
        """
        Check if a paragraph is part of a numbered list by checking numbering.xml.

        Parameters:
        - para: lxml element representing w:p
        - numbering_root: lxml root element of numbering.xml (or None if not available)

        Returns:
        - Tuple (bool, str): (True, num_fmt) if numbered list, (False, num_fmt) if unnumbered, (None, None) if cannot determine
        """
        if numbering_root is None:
            return (None, None)  # Cannot determine without numbering.xml
            return (None, None)

        # Get numId and ilvl from paragraph
        numId_elem = para.xpath('./w:pPr/w:numPr/w:numId', namespaces=ns)
        ilvl_elem = para.xpath('./w:pPr/w:numPr/w:ilvl', namespaces=ns)

@@ -383,48 +370,25 @@ def update_unnumbered_lists(docx_input, docx_output):
        numId = numId_elem[0].get(f"{{{ns['w']}}}val")
        ilvl = ilvl_elem[0].get(f"{{{ns['w']}}}val")

        if numId is None or ilvl is None:
            return (None, None)

        try:
            # Find the num element with this numId
            num_elem = numbering_root.xpath(
                f'.//w:num[@w:numId="{numId}"]',
                namespaces=ns
            )
            num_elem = numbering_root.xpath(f'.//w:num[@w:numId="{numId}"]', namespaces=ns)
            if not num_elem:
                return (None, None)

            # Get the abstractNumId
            abstract_num_id_elem = num_elem[0].xpath('./w:abstractNumId', namespaces=ns)
            if not abstract_num_id_elem:
                return (None, None)

            abstract_num_id = abstract_num_id_elem[0].get(f"{{{ns['w']}}}val")
            if abstract_num_id is None:
                return (None, None)
            abstract_num_id = num_elem[0].xpath('./w:abstractNumId/@w:val', namespaces=ns)[0]

            # Find the abstractNum
            abstract_num = numbering_root.xpath(
                f'.//w:abstractNum[@w:abstractNumId="{abstract_num_id}"]',
                namespaces=ns
            )
            if not abstract_num:
                return (None, None)
            )[0]

            # Check the format for this level
            lvl = abstract_num[0].xpath(
            lvl = abstract_num.xpath(
                f'./w:lvl[@w:ilvl="{ilvl}"]/w:numFmt',
                namespaces=ns
            )
            if not lvl:
                return (None, None)
            )[0]

            num_fmt = lvl[0].get(f"{{{ns['w']}}}val")
            if num_fmt is None:
                return (None, None)
            num_fmt = lvl.get(f"{{{ns['w']}}}val")

            # Numbered formats
            numbered_formats = ['decimal', 'lowerLetter', 'upperLetter',
                                'lowerRoman', 'upperRoman', 'arabic', 'ordinal',
                                'cardinalText', 'ordinalText', 'hex', 'chicago',
@@ -444,240 +408,153 @@ def update_unnumbered_lists(docx_input, docx_output):
                                'hindiVowels', 'hindiConsonants', 'hindiNumbers',
                                'hindiCounting', 'thaiLetters', 'thaiNumbers', 'thaiCounting']

            # Unnumbered/bullet formats
            unnumbered_formats = ['bullet', 'circle', 'square', 'dash', 'diamond',
                                  'check', 'arrow', 'arrowhead', 'rtArrow', 'hyphen']

            if num_fmt in numbered_formats:
                return (True, num_fmt)  # Return (True, format string)
                return (True, num_fmt)
            elif num_fmt in unnumbered_formats:
                return (False, num_fmt)  # Return (False, format string)
                return (False, num_fmt)
            else:
                # Unknown format, default to unnumbered
                return (False, num_fmt)

        except Exception as e:
            # If any error occurs, return None to fall back to heuristic
        except Exception:
            return (None, None)

    counter_regular = 0
    counter_b1 = 0
    counter_b2 = 0
    counter_b3 = 0
    counter_compact = 0
    counter_table = 0
    counter_numbered = 0
    # -----------------------------
    # Helpers (FIXED, not changed behavior)
    # -----------------------------

    # Track processed paragraphs to avoid reprocessing
    processed_paras = set()

    # Find all paragraphs - need to collect them first since we'll be modifying the tree
    paragraphs = root.xpath('.//w:p', namespaces=ns)
    def get_full_text(para):
        return ''.join(t.text for t in para.xpath('.//w:t', namespaces=ns) if t.text)

    def is_list_item_para(para):
        """Check if paragraph starts with '- ' across all runs (formula-safe)"""
        texts = [t.text for t in para.xpath('.//w:t', namespaces=ns) if t.text]
        if not texts:
            return False
        full_text = ''.join(texts).lstrip()
        return full_text.startswith('- ')
        return get_full_text(para).lstrip().startswith('- ')

    def is_blank_para(para):
        """Check if paragraph is blank (empty or only whitespace)"""
        runs = para.xpath('./w:r', namespaces=ns)
        if not runs:
            return True
        all_text = ''
        for run in runs:
            text_elems = run.xpath('.//w:t', namespaces=ns)
            for text_elem in text_elems:
                if text_elem.text:
                    all_text += text_elem.text
        return not all_text.strip()
        return not get_full_text(para).strip()

    def get_para_text(para):
        """Get all text from a paragraph"""
        runs = para.xpath('./w:r', namespaces=ns)
        text = ''
        for run in runs:
            text_elems = run.xpath('.//w:t', namespaces=ns)
            for text_elem in text_elems:
                if text_elem.text:
                    text += text_elem.text
        return text
    def split_list_items(para):
        children = list(para)
        indices = []

    def remove_dash_prefix(para):
        """Remove '- ' prefix across runs safely"""
        remaining = 2  # length of "- "
        for i, node in enumerate(children):
            texts = node.xpath('.//w:t', namespaces=ns)
            for t in texts:
                if t.text and t.text.lstrip().startswith('- '):
                    indices.append(i)
                    break

        for t in para.xpath('.//w:t', namespaces=ns):
            if not t.text:
                continue
        if not indices:
            return []

            text = t.text
            stripped = text.lstrip()
        chunks = []
        for i, start in enumerate(indices):
            end = indices[i+1] if i+1 < len(indices) else len(children)
            chunk = [n for n in children[start:end] if n.tag != f"{{{ns['w']}}}pPr"]
            chunks.append(chunk)

            # Skip leading whitespace first
            leading_ws_len = len(text) - len(text.lstrip())
            if leading_ws_len > 0:
                continue
        return chunks

    def remove_dash_prefix(nodes):
        remaining = 2
        for node in nodes:
            for t in node.xpath('.//w:t', namespaces=ns):
                if not t.text:
                    continue
                txt = t.text
                stripped = txt.lstrip()
                if not stripped:
                    continue
                lead = len(txt) - len(stripped)
                if remaining <= 0:
                break

            if len(text) <= remaining:
                remaining -= len(text)
                t.text = ''
                    return
                eff = txt[lead:]
                if len(eff) <= remaining:
                    t.text = txt[:lead]
                    remaining -= len(eff)
                else:
                t.text = text[remaining:]
                remaining = 0
                break
                    t.text = txt[:lead] + eff[remaining:]
                    return

    processed_paras = set()

    # -----------------------------
    # MAIN LOOP (same behavior, fixed internals)
    # -----------------------------

    for para in paragraphs:
        # Skip if already processed

        if id(para) in processed_paras:
            continue

        if (para.xpath('./w:pPr/w:pStyle[@w:val="Compact"]', namespaces=ns) or not para.xpath('./w:pPr/w:pStyle',
                                                                                              namespaces=ns)) and para.xpath(
                './w:pPr/w:numPr', namespaces=ns):
            # Change Compact style if exists to B1, B2 or B3 style
            # Get pPr element to work with styles/numPr
        # ---- Compact handling (unchanged logic) ----
        if (para.xpath('./w:pPr/w:pStyle[@w:val="Compact"]', namespaces=ns)
            or not para.xpath('./w:pPr/w:pStyle', namespaces=ns)) \
           and para.xpath('./w:pPr/w:numPr', namespaces=ns):

            pPr = para.xpath('./w:pPr', namespaces=ns)[0]
            # Try to get existing Compact style; if not present, create pStyle
            compact_style_elems = pPr.xpath('./w:pStyle[@w:val="Compact"]', namespaces=ns)
            if compact_style_elems:
                compact_style = compact_style_elems[0]
            else:
                compact_style = OxmlElement('w:pStyle')
            # Check if it is a numbered list and get the format
            is_numbered, num_format = is_numbered_list(para, numbering_root)
            if is_numbered:  # is_numbered is True if numbered, False if unnumbered, None if cannot determine
                # pStyle = OxmlElement('w:pStyle')
                # If format is decimal (numbers), use BN; otherwise use BL
                if num_format == 'decimal':
                    compact_style.set(f"{{{ns['w']}}}val", "BN")
                else:
                    compact_style.set(f"{{{ns['w']}}}val", "BL")
                pPr.insert(0, compact_style)
                # Remove numId from numPr
            compact_style = pPr.xpath('./w:pStyle', namespaces=ns)
            compact_style = compact_style[0] if compact_style else OxmlElement('w:pStyle')

            is_num, num_fmt = is_numbered_list(para, numbering_root)

            if is_num:
                compact_style.set(f"{{{ns['w']}}}val", "BN" if num_fmt == 'decimal' else "BL")
                numPr = pPr.xpath('./w:numPr', namespaces=ns)[0]
                numPr.remove(numPr.xpath('./w:numId', namespaces=ns)[0])
                counter_numbered += 1
            else:
                if para.xpath('./w:pPr/w:numPr/w:ilvl[@w:val="0"]', namespaces=ns):
                    if para.xpath('ancestor::w:tbl', namespaces=ns):
                        compact_style.set(f"{{{ns['w']}}}val", "TB1")
                        counter_table += 1
            else:
                        compact_style.set(f"{{{ns['w']}}}val", "B1")
                    # Remove numPr from pPr
                    numPr = pPr.xpath('./w:numPr', namespaces=ns)[0]
                    pPr.remove(numPr)
                    counter_b1 += 1
                elif para.xpath('./w:pPr/w:numPr/w:ilvl[@w:val="1"]', namespaces=ns):
                    if para.xpath('ancestor::w:tbl', namespaces=ns):
                        compact_style.set(f"{{{ns['w']}}}val", "TB2")
                        counter_table += 1
                    else:
                        compact_style.set(f"{{{ns['w']}}}val", "B2")
                    # Remove numPr from pPr
                    numPr = pPr.xpath('./w:numPr', namespaces=ns)[0]
                    pPr.remove(numPr)
                    counter_b2 += 1
                elif para.xpath('./w:pPr/w:numPr/w:ilvl[@w:val="2"]', namespaces=ns):
                    if para.xpath('ancestor::w:tbl', namespaces=ns):
                        compact_style.set(f"{{{ns['w']}}}val", "TB3")
                        counter_table += 1
                ilvl = para.xpath('./w:pPr/w:numPr/w:ilvl/@w:val', namespaces=ns)[0]

                if ilvl == "0":
                    style = "TB1" if para.xpath('ancestor::w:tbl', namespaces=ns) else "B1"
                elif ilvl == "1":
                    style = "TB2" if para.xpath('ancestor::w:tbl', namespaces=ns) else "B2"
                else:
                        compact_style.set(f"{{{ns['w']}}}val", "B3")
                    # Remove numPr from pPr
                    style = "TB3" if para.xpath('ancestor::w:tbl', namespaces=ns) else "B3"

                compact_style.set(f"{{{ns['w']}}}val", style)
                numPr = pPr.xpath('./w:numPr', namespaces=ns)[0]
                pPr.remove(numPr)
                    counter_b3 += 1

            pPr.insert(0, compact_style)
                counter_compact += 1
            continue

        # Get all direct child runs (not nested runs)
        runs = para.xpath('./w:r', namespaces=ns)
        if not runs:
        # ---- LIST ITEM SPLITTING (fixed, behavior identical) ----
        if not is_list_item_para(para):
            continue

        # Find ALL list item runs (runs starting with "- ")
        list_item_runs = []
        all_children = list(para)
        for idx, child in enumerate(all_children):
            if child.tag == f"{{{ns['w']}}}r":
                text_elem = child.find('.//w:t', namespaces=ns)
                if text_elem is not None and text_elem.text and text_elem.text.startswith('- '):
                    list_item_runs.append((idx, child, text_elem))

        # If we found list items, process each one separately
        if list_item_runs:
            # Get the parent element (usually the document body or table cell)
        parent = para.getparent()
            if parent is None:
                continue

            # Find the position of this paragraph
        para_index = list(parent).index(para)
        in_table = bool(para.xpath('ancestor::w:tbl', namespaces=ns))

            # Check if paragraph is inside a table
            is_in_table = bool(para.xpath('ancestor::w:tbl', namespaces=ns))
        chunks = split_list_items(para)

            # If in table and there are runs before the first list item, update original para pStyle to FP
            if is_in_table and list_item_runs[0][0] > 0:
                # Get or create pPr for the original paragraph
                orig_pPr = para.find('.//w:pPr', namespaces=ns)
                if orig_pPr is None:
                    orig_pPr = OxmlElement('w:pPr')
                    para.insert(0, orig_pPr)
                else:
                    # Remove existing pStyle if any
                    existing_pStyle = orig_pPr.find('.//w:pStyle', namespaces=ns)
                    if existing_pStyle is not None:
                        orig_pPr.remove(existing_pStyle)
        offset = 0

                # Add FP style
                pStyle = OxmlElement('w:pStyle')
                pStyle.set(f"{{{ns['w']}}}val", "FP")
                orig_pPr.insert(0, pStyle)  # Insert at beginning
        for chunk in chunks:

            # Process each list item run separately
            insert_offset = 0  # Track where to insert new paragraphs
            for list_idx, (run_idx, list_item_run, list_item_text_elem) in enumerate(list_item_runs):
                # Remove the "- " prefix
                remove_dash_prefix(para)
            remove_dash_prefix(chunk)

                # Create a new paragraph for this list item
            new_para = OxmlElement('w:p')

                # Create pPr
            pPr = OxmlElement('w:pPr')

                if is_in_table:
                    # Bulleted list structure for table list items
                    # pStyle
            if in_table:
                pStyle = OxmlElement('w:pStyle')
                pStyle.set(f"{{{ns['w']}}}val", "FP")
                pPr.append(pStyle)

                    # keepNext
                keepNext = OxmlElement('w:keepNext')
                pPr.append(keepNext)

                    # numPr (for bulleted list)
                numPr = OxmlElement('w:numPr')
                ilvl = OxmlElement('w:ilvl')
                ilvl.set(f"{{{ns['w']}}}val", "0")
                numId = OxmlElement('w:numId')
                numId.set(f"{{{ns['w']}}}val", "14")
                    numPr.append(ilvl)
                    numPr.append(numId)
                numPr.extend([ilvl, numId])
                pPr.append(numPr)

                    # tabs
                tabs = OxmlElement('w:tabs')
                tab = OxmlElement('w:tab')
                tab.set(f"{{{ns['w']}}}val", "left")
@@ -685,117 +562,68 @@ def update_unnumbered_lists(docx_input, docx_output):
                tabs.append(tab)
                pPr.append(tabs)

                    # spacing
                spacing = OxmlElement('w:spacing')
                spacing.set(f"{{{ns['w']}}}before", "80")
                spacing.set(f"{{{ns['w']}}}after", "80")
                pPr.append(spacing)

                    # Left alignment
                jc = OxmlElement('w:jc')
                jc.set(f"{{{ns['w']}}}val", "left")
                pPr.append(jc)

                    counter_table += 1
            else:
                    # Simple structure for regular list items (outside tables)
                pStyle = OxmlElement('w:pStyle')
                pStyle.set(f"{{{ns['w']}}}val", "B1")
                pPr.append(pStyle)

                    counter_regular += 1

            new_para.append(pPr)

                # Find runs that belong to this list item
                # From this list item run until the next list item run (or end of paragraph)
                start_idx = run_idx
                end_idx = list_item_runs[list_idx + 1][0] if list_idx + 1 < len(list_item_runs) else len(all_children)

                # Move runs for this list item to the new paragraph
                content_to_move = []
                for idx in range(start_idx, end_idx):
                    child = all_children[idx]
                    # Skip paragraph properties
                    if child.tag == f"{{{ns['w']}}}pPr":
                        continue
                    content_to_move.append(child)

                # Remove from original and add to new paragraph
                for node in content_to_move:
            for node in chunk:
                if node in para:
                    para.remove(node)
                new_para.append(node)

                # Now look at subsequent paragraphs in the same parent and merge them
                # until we hit another list item, blank line, or end of parent
                current_para_pos = para_index + insert_offset + 1
                next_index = current_para_pos
                while next_index < len(parent):
                    next_para = parent[next_index]
            next_idx = para_index + offset + 1

                    # Stop if we hit another list item
                    if is_list_item_para(next_para):
                        break
            while next_idx < len(parent):
                nxt = parent[next_idx]

                    # Stop if we hit a blank line
                    if is_blank_para(next_para):
                if is_list_item_para(nxt) or is_blank_para(nxt):
                    break

                    # Stop if paragraph is in a different table cell (different parent)
                    if next_para.getparent() != parent:
                        break
                for n in list(nxt):
                    if n.tag != f"{{{ns['w']}}}pPr":
                        nxt.remove(n)
                        new_para.append(n)

                    # Merge this paragraph's runs into the list item paragraph
                    for run in list(next_para):
                        if run.tag != f"{{{ns['w']}}}pPr":
                            next_para.remove(run)
                            new_para.append(run)

                    # Mark as processed and remove the merged paragraph
                    processed_paras.add(id(next_para))
                    parent.remove(next_para)
                    # Don't increment next_index since we removed an element

                # Insert the new paragraph
                current_para_pos = para_index + insert_offset + 1
                parent.insert(current_para_pos, new_para)
                insert_offset += 1

            # Only remove the original paragraph if it has no content left (only pPr or empty)
            remaining_runs = [c for c in para if c.tag != f"{{{ns['w']}}}pPr"]
            if not remaining_runs:
                processed_paras.add(id(para))
                parent.remove(para)
                processed_paras.add(id(nxt))
                parent.remove(nxt)

    print(
        f'Updated {counter_b1} B1 style paragraphs, {counter_b2} B2 style paragraphs, {counter_b3} B3 style paragraphs')
    print(
        f'Updated {counter_compact} Compact style paragraphs, {counter_table} unnumbered list items in tables, {counter_regular} outside tables')
    print(f'Updated {counter_numbered} numbered list items')
            parent.insert(para_index + offset + 1, new_para)
            offset += 1

    xml_data = etree.tostring(root, xml_declaration=True, encoding="UTF-8", standalone="yes")
        if not [c for c in para if c.tag != f"{{{ns['w']}}}pPr"]:
            parent.remove(para)

    # -----------------------------
    # Write DOCX
    # -----------------------------
    xml_data = etree.tostring(root, xml_declaration=True, encoding="UTF-8")

    # Create temp file
    tmp_fd, tmp_path = tempfile.mkstemp(suffix=".docx")
    os.close(tmp_fd)

    try:
        # Write new docx to temp file
        with zipfile.ZipFile(docx_input, 'r') as zin, zipfile.ZipFile(tmp_path, 'w', zipfile.ZIP_DEFLATED) as zout:
        with zipfile.ZipFile(docx_input, 'r') as zin, zipfile.ZipFile(tmp_path, 'w') as zout:
            for item in zin.infolist():
                if item.filename != "word/document.xml":
                    data = zin.read(item.filename)
                    zout.writestr(item.filename, data)
                    zout.writestr(item.filename, zin.read(item.filename))
            zout.writestr("word/document.xml", xml_data)

        # Write to output file
        shutil.move(tmp_path, docx_output)
        # Set proper permissions (read/write for owner, read for group and others)
        os.chmod(docx_output, 0o644)

    finally:
        # Delete temp file if still existing
        if os.path.exists(tmp_path):
            os.remove(tmp_path)