Back

Adding Windows support to killport

How I added Windows support to the killport project


This blog post covers what I learned and how I implemented Windows support for the killport project Here

The pull request these changes were made in Reference Pull Request

Intro

While doing my late night scrolling through GitHub I stumbled upon this cool little project called killport in my GitHub Explore. I took a peek through the code and the README and though this looked like a cool project. I noticed that the project didn’t support Windows so I jotted down a note to begin working on it tomrrow.

The killport project is a command line tool which allows you to kill any process that is using specific ports

Getting started

Fast forward to the next day im ready to start the project. With my very little experience in working with the Windows API I first decided to browser the API docs in search of a function which could provide me a mapping between port numbers and process IDs.

Functions

The functions that I found for this were the following

Now its time to implement the Rust code to interface with these functions

Setting up the Rust

Now in order to use thse functions I will need to make use of the windows-sys crate to provide mappings for it so I added the crate along with the required feature flags for each of the functions.

[target.'cfg(target_os = "windows")'.dependencies.windows-sys]
version = "0.48"
features = [
  "Win32_Foundation",
  "Win32_NetworkManagement_IpHelper",
  "Win32_System_Threading",
  "Win32_Networking_WinSock",
]

I only add this dependency for the Windows platform as thats the only target that will make use of this dependency

Implementation

Because the GetExtendedTcpTable and GetExtendedUdpTable need to provide a list of unknown size back to me after being called I need to do heap allocations for the API to store the structured data.

Retrieving tables

The following code allocates a buffer for the Windows API to store the table in. The first call to GetExtendedTcpTable will fail as the initialize size is zero, After this first call fails the size variable will be changed to the size estimate from the function this change is detected because the call will return a result of ERROR_INSUFFICIENT_BUFFER when this is recieved the buffer memory can be reallocated to the new size estimate.


/// Reads the extended TCP table into memory using the provided address `family`
/// to determine the output type. Returns the memory pointer to the loaded struct
///
/// # Arguments
///
/// * `layout` - The layout of the memory
/// * `family` - The address family type
unsafe fn get_extended_tcp_table(layout: Layout, family: ADDRESS_FAMILY) -> Result<*mut u8, Error> {
    let mut buffer = std::alloc::alloc(layout);

    // Size estimate for resizing the buffer
    let mut size = 0;

    // Result of asking for the TCP table
    let mut result: u32;

    loop {
        // Ask windows for the extended TCP table mapping between TCP ports and PIDs
        result = GetExtendedTcpTable(
            buffer.cast(),
            &mut size,
            1,
            family as u32,
            TCP_TABLE_OWNER_MODULE_ALL,
            0,
        );

        // No error occurred
        if result == NO_ERROR {
            break;
        }

        // Handle buffer too small
        if result == ERROR_INSUFFICIENT_BUFFER {
            // Resize the buffer to the new size
            buffer = std::alloc::realloc(buffer, layout, size as usize);
            continue;
        }

        // Deallocate the buffer memory
        std::alloc::dealloc(buffer, layout);

        // Handle unknown failures
        return Err(std::io::Error::new(
            ErrorKind::Other,
            "Failed to get size estimate for TCP table",
        ));
    }

    Ok(buffer)
}

The UDP implementation of this is pretty much the same but uses GetExtendedUdpTable and UDP_TABLE_OWNER_MODULE instead

Now that we have the backing memory for the table we can cast the memory pointer to the table structure and create a slice from the known length and begin iterating over checking if any of the ports match

/// Searches through the IPv4 extended TCP table for any processes
/// that are listening on the provided `port`. Will append any processes
/// found onto the provided `pids` list
///
/// # Arguments
///
/// * `port` The port to search for
/// * `pids` The list of process IDs to append to
unsafe fn get_process_tcp_v4(port: u16, pids: &mut Vec<u32>) -> Result<(), Error> {
    // Create the memory layout for the table
    let layout = Layout::new::<MIB_TCPTABLE_OWNER_MODULE>();
    let buffer = get_extended_tcp_table(layout, AF_INET)?;

    let tcp_table: *const MIB_TCPTABLE_OWNER_MODULE = buffer.cast();

    // Read the length of the table
    let length = std::ptr::addr_of!((*tcp_table).dwNumEntries).read_unaligned() as usize;

    // Get a pointer to the start of the table
    let table_ptr: *const MIB_TCPROW_OWNER_MODULE = std::ptr::addr_of!((*tcp_table).table).cast();

    // Find the process IDs
    std::slice::from_raw_parts(table_ptr, length)
        .iter()
        .for_each(|element| {
            // Convert the port value
            let local_port: u16 = (element.dwLocalPort as u16).to_be();
            if local_port == port {
                pids.push(element.dwOwningPid)
            }
        });

    // Deallocate the buffer memory
    std::alloc::dealloc(buffer, layout);

    Ok(())
}

I had to make different implementations for IPv4 and IPv6 as the returned structure was different for the different address families

The UDP version of this function is the same just using the UDP structures instead

Killing Processes

Now for killing processes. In order to kill processes I must first obtain the process handle using OpenProcess with terminate access then I can call TerminateProcess the following code handles this

/// Kills a process with the provided process ID
///
/// # Arguments
///
/// * `pid` - The process ID
unsafe fn kill_process(pid: u32) -> Result<(), Error> {
    info!("Killing process with PID {}", pid);

    // Open the process handle with intent to terminate
    let handle = OpenProcess(PROCESS_TERMINATE, 0, pid);
    if (&handle as *const isize).is_null() {
        return Err(std::io::Error::new(
            ErrorKind::Other,
            format!("Failed to obtain handle to process: {}", pid),
        ));
    }

    let result = TerminateProcess(handle, 0);
    if result == 0 {
        let error = GetLastError();
        return Err(std::io::Error::new(
            ErrorKind::Other,
            format!("Failed to terminate process {}: {:#x}", pid, error),
        ));
    }

    Ok(())
}

Program function

Now that I have all the functions required to implement the functionality I can create the function that the program uses to kill by port. The function name and arguments are the same across all the platform implementations. This function finds all the PIDs for the different protocols and address families and collects them into a list of pids which it then iterates over and kills each of the processes

/// Attempts to kill processes listening on the specified `port`.
///
/// Returns a `Result` with `true` if any processes were killed, `false` if no
/// processes were found listening on the port, and an `Error` if the operation
/// failed or the platform is unsupported.
///
/// # Arguments
///
/// * `port` - A u16 value representing the port number.
pub fn kill_processes_by_port(port: u16, _: KillPortSignalOptions) -> Result<bool, Error> {
    let mut pids = Vec::new();

    unsafe { get_process_tcp_v4(port, &mut pids)? }
    unsafe { get_process_tcp_v6(port, &mut pids)? }
    unsafe { get_process_udp_v4(port, &mut pids)? }
    unsafe { get_process_udp_v6(port, &mut pids)? }

    let mut killed = false;

    for pid in pids {
        debug!("Found process with PID {}", pid);
        unsafe { kill_process(pid)? }
        killed = true;
    }

    Ok(killed)
}

All that was left is to add the new module and function into the main program

#[cfg(target_os = "windows")]
mod windows;

#[cfg(target_os = "windows")]
use windows::kill_processes_by_port;

And with this the implementation was complete and working

What I learned

Through working on this project I learned how to allocate heap memory manually in Rust, How to interface with the Windows API, How to navigate the Windows API documentation, how to work with unsafe Rust and dealing with pointers.