You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
261 lines
8.4 KiB
261 lines
8.4 KiB
#!/usr/bin/env python3
|
|
import os
|
|
import sys
|
|
import socket
|
|
import struct
|
|
import json
|
|
import hashlib
|
|
import argparse
|
|
import time
|
|
import zipfile
|
|
import tarfile
|
|
import tempfile
|
|
import shutil
|
|
import threading
|
|
|
|
CHUNK_SIZE = 1024 * 1024 # 1 MiB: max bytes per file read / socket recv in send_file and recv_file
|
|
|
|
|
|
def sha256sum(path):
    """Return the SHA-256 digest of the file at *path* as a lowercase hex string.

    The file is streamed in 1 MiB pieces so it is never loaded into memory
    all at once.
    """
    digest = hashlib.sha256()
    piece_size = 1024 * 1024
    with open(path, "rb") as stream:
        while True:
            piece = stream.read(piece_size)
            if not piece:
                break
            digest.update(piece)
    return digest.hexdigest()
|
|
|
|
|
|
def pack_folder(folder_path):
    """Pack a directory into a temporary archive and return the archive's path.

    On Windows (``os.name == "nt"``) a ``.zip`` is produced; elsewhere a
    ``.tar.gz``. In both formats every entry is stored under a top-level
    directory named after the folder, so both archives extract to the same
    layout. Directories, including empty ones, are stored as entries.

    The archive is written to ``tempfile.gettempdir()`` as
    ``<folder name>.zip`` / ``<folder name>.tar.gz``. An existing file of that
    name is overwritten. The caller is responsible for deleting it.
    """
    tmp_dir = tempfile.gettempdir()
    # abspath() drops any trailing separator. Packing a filesystem root
    # yields an empty basename, so fall back to a fixed name.
    base_name = os.path.basename(os.path.abspath(folder_path)) or "archive"

    if os.name == "nt":
        archive_path = os.path.join(tmp_dir, base_name + ".zip")
        archive_key = os.path.normcase(os.path.abspath(archive_path))
        with zipfile.ZipFile(archive_path, "w", zipfile.ZIP_DEFLATED) as zf:
            for root, _, files in os.walk(folder_path):
                rel_root = os.path.relpath(root, start=folder_path)
                # Same "<base_name>/..." prefix that the tar branch uses.
                arc_root = os.path.normpath(os.path.join(base_name, rel_root))
                # Directory entry, so empty directories survive the round trip.
                zf.write(root, arc_root)
                for name in files:
                    full = os.path.join(root, name)
                    # If the folder contains the temp dir, os.walk would reach
                    # the archive being written. Never add it to itself.
                    if os.path.normcase(os.path.abspath(full)) == archive_key:
                        continue
                    zf.write(full, os.path.join(arc_root, name))
    else:
        archive_path = os.path.join(tmp_dir, base_name + ".tar.gz")
        with tarfile.open(archive_path, "w:gz") as tf:
            tf.add(folder_path, arcname=base_name)

    return archive_path
|
|
|
|
|
|
def unpack_archive(archive_path, outdir):
    """Extract a ``.zip`` or ``.tar.gz`` archive into *outdir*.

    The archive may come from a remote peer, so entries that would be written
    outside *outdir* are refused. Paths with any other extension are ignored,
    and nothing is extracted.

    Raises:
        tarfile.TarError: tar member rejected by the "data" extraction filter
            (Pythons that provide ``tarfile.data_filter``).
        ValueError: tar member rejected by the fallback check (older Pythons).
    """
    if archive_path.endswith(".zip"):
        with zipfile.ZipFile(archive_path, "r") as zf:
            # ZipFile.extractall already strips absolute paths and ".." parts.
            zf.extractall(outdir)
    elif archive_path.endswith(".tar.gz"):
        with tarfile.open(archive_path, "r:gz") as tf:
            if hasattr(tarfile, "data_filter"):
                # The "data" filter rejects absolute paths, "..", links that
                # leave outdir, and special files.
                tf.extractall(outdir, filter="data")
            else:
                _check_tar_members(tf)
                tf.extractall(outdir)


def _check_tar_members(tf):
    """Raise ValueError if any tar member could be written outside the target.

    Conservative fallback for Pythons without tarfile extraction filters.
    Member names and link targets must be relative, drive-free and free of
    ".." components. Device/FIFO members are rejected too.
    """
    def unsafe(path):
        slashed = path.replace("\\", "/")
        return (
            slashed.startswith("/")
            or bool(os.path.splitdrive(path)[0])
            or ".." in slashed.split("/")
        )

    for member in tf.getmembers():
        if unsafe(member.name):
            raise ValueError(f"unsafe path in archive: {member.name!r}")
        if (member.issym() or member.islnk()) and unsafe(member.linkname):
            raise ValueError(
                f"unsafe link in archive: {member.name!r} -> {member.linkname!r}"
            )
        if member.isdev():
            raise ValueError(f"special file in archive: {member.name!r}")
|
|
|
|
def send_file(conn, filepath, rate_limit=None):
    """Send one file over *conn*, resuming at the offset the receiver reports.

    Sender side of the protocol:
      1. send a 4-byte big-endian header length, then a JSON header with
         "filename", "filesize" and "sha256";
      2. read an 8-byte big-endian offset (bytes the receiver already holds);
      3. stream the file from that offset up to the declared filesize.

    Args:
        conn: a connected socket.
        filepath: path of the file to send.
        rate_limit: optional cap on this session's average throughput, in
            bytes/second. Falsy means unlimited.

    Raises:
        ConnectionError: the peer closed before sending the resume offset.
        OSError: socket or file errors.
    """
    filesize = os.path.getsize(filepath)
    checksum = sha256sum(filepath)

    header = {
        "filename": os.path.basename(filepath),
        "filesize": filesize,
        "sha256": checksum,
    }
    header_bytes = json.dumps(header).encode()
    conn.sendall(struct.pack("!I", len(header_bytes)))
    conn.sendall(header_bytes)

    # Read the receiver's resume offset. recv() may return fewer bytes than
    # asked for, so loop until all 8 have arrived.
    offset_raw = b""
    while len(offset_raw) < 8:
        part = conn.recv(8 - len(offset_raw))
        if not part:
            raise ConnectionError("peer closed the connection before sending the resume offset")
        offset_raw += part
    resume_bytes = struct.unpack("!Q", offset_raw)[0]
    print(f" 从 {resume_bytes} 字节处开始续传")

    sent = resume_bytes
    start_time = time.monotonic()
    last_print = start_time

    with open(filepath, "rb") as f:
        f.seek(resume_bytes)
        # Never send more than the header announced. The receiver reads
        # exactly `filesize` bytes.
        while sent < filesize:
            chunk = f.read(min(CHUNK_SIZE, filesize - sent))
            if not chunk:
                break
            conn.sendall(chunk)
            sent += len(chunk)
            # Resumed bytes were not transferred in this session, so they must
            # not count toward the rate limit or the speed estimate.
            session_sent = sent - resume_bytes

            # Rate limiting: sleep until the session average is back at the cap.
            if rate_limit:
                elapsed = time.monotonic() - start_time
                target = session_sent / rate_limit
                if target > elapsed:
                    time.sleep(target - elapsed)

            # Progress report, at most every 0.5 s (and once at the end).
            now = time.monotonic()
            if now - last_print >= 0.5 or sent == filesize:
                elapsed = now - start_time
                progress = sent / filesize * 100
                avg_speed = session_sent / elapsed if elapsed > 0 else 0
                remaining = (filesize - sent) / avg_speed if avg_speed > 0 else 0
                print(f"\r {progress:.2f}% | 已发送 {sent/(1024*1024):.2f}/{filesize/(1024*1024):.2f} MB | "
                      f"平均速率 {avg_speed/(1024*1024):.2f} MB/s | 预计剩余 {remaining:.2f}s", end="")
                last_print = now
                sys.stdout.flush()

    total_time = time.monotonic() - start_time
    print(f"\n 发送完成: {filepath} ({filesize} bytes) | 总耗时: {total_time:.2f}s")
    sys.stdout.flush()
|
|
|
|
|
|
_MAX_HEADER_LEN = 64 * 1024  # upper bound on the JSON header size accepted from a peer


def _recv_exact(conn, size):
    """Return exactly *size* bytes from *conn*.

    socket.recv() may return fewer bytes than requested, so keep reading.
    Raises ConnectionError if the peer closes first.
    """
    buf = bytearray()
    while len(buf) < size:
        part = conn.recv(size - len(buf))
        if not part:
            raise ConnectionError(f"connection closed after {len(buf)} of {size} bytes")
        buf += part
    return bytes(buf)


def recv_file(conn, outdir):
    """Receive one file from *conn* into *outdir*, resuming a partial download.

    Receiver side of the protocol:
      1. read a 4-byte big-endian header length and a JSON header with
         "filename", "filesize" and "sha256";
      2. reply with an 8-byte big-endian offset: the size of any existing
         "<outdir>/<filename>.part";
      3. append the remaining bytes to the ".part" file.

    If the full size arrives and the SHA-256 matches, the ".part" file is
    renamed to the final name, replacing any existing file. If the download is
    incomplete, the ".part" file is kept so a later attempt can resume. If it
    is complete but fails the checksum, it is deleted so the next attempt
    starts over.

    Raises:
        ConnectionError: the peer closed during the header exchange.
        ValueError: malformed, oversized or unsafe header.
        KeyError: a required header field is missing.
    """
    header_len = struct.unpack("!I", _recv_exact(conn, 4))[0]
    if header_len > _MAX_HEADER_LEN:
        raise ValueError(f"header too large: {header_len} bytes")
    header = json.loads(_recv_exact(conn, header_len).decode())

    # The filename comes from the peer. Keep only the final component so names
    # like "../../x" or "/etc/x" cannot escape outdir.
    filename = os.path.basename(str(header["filename"]))
    if filename in ("", ".", ".."):
        raise ValueError(f"invalid filename in header: {header['filename']!r}")
    filesize = header["filesize"]
    if not isinstance(filesize, int) or filesize < 0:
        raise ValueError(f"invalid filesize in header: {filesize!r}")
    checksum = header["sha256"]

    outpath = os.path.join(outdir, filename)
    tmp_path = outpath + ".part"

    received = 0
    if os.path.exists(tmp_path):
        received = os.path.getsize(tmp_path)
        if received > filesize:
            # Leftover from a different, larger file. It can never verify.
            print(f" 未完成文件 {received} 字节超过目标大小 {filesize},重新下载")
            os.remove(tmp_path)
            received = 0
        else:
            print(f" 检测到未完成文件,已接收 {received} 字节,将继续下载...")

    # Tell the sender how many bytes we already have.
    conn.sendall(struct.pack("!Q", received))

    resume_from = received
    start_time = time.monotonic()
    last_print = start_time

    with open(tmp_path, "ab") as f:
        while received < filesize:
            chunk = conn.recv(min(CHUNK_SIZE, filesize - received))
            if not chunk:
                break
            f.write(chunk)
            received += len(chunk)

            # Progress report, at most every 0.5 s (and once at the end).
            now = time.monotonic()
            if now - last_print >= 0.5 or received == filesize:
                elapsed = now - start_time
                progress = received / filesize * 100
                # Speed counts only bytes received in this session.
                avg_speed = (received - resume_from) / elapsed if elapsed > 0 else 0
                remaining = (filesize - received) / avg_speed if avg_speed > 0 else 0
                print(f"\r {progress:.2f}% | 已接收 {received/(1024*1024):.2f}/{filesize/(1024*1024):.2f} MB | "
                      f"平均速率 {avg_speed/(1024*1024):.2f} MB/s | 预计剩余 {remaining:.2f}s", end="")
                sys.stdout.flush()
                last_print = now

    total_time = time.monotonic() - start_time
    print()  # end the progress line

    if received == filesize and sha256sum(tmp_path) == checksum:
        # os.replace also overwrites an existing target on Windows.
        os.replace(tmp_path, outpath)
        print(f"接收完成: {outpath} | 总耗时: {total_time:.2f}s")
        # Automatic extraction (unpack_archive) is intentionally disabled.
    else:
        if received == filesize:
            # Complete but corrupt: resuming onto this data could never succeed.
            os.remove(tmp_path)
        print("文件不完整或校验失败")
    sys.stdout.flush()
|
|
|
|
|
|
|
|
def run_sender(host, port, path, rate, psk):
    """Connect to a receiver, authenticate with *psk* and send *path*.

    A directory is first packed into a temporary archive with pack_folder().
    That archive is deleted afterwards, even if the transfer fails.

    Args:
        host, port: receiver address.
        path: file or directory to send.
        rate: optional rate string such as "500K", "5M" or a bare number of
            bytes/second. None, empty, or a non-positive value means unlimited.
        psk: pre-shared key. It is sent as its UTF-8 bytes NUL-padded to
            32 bytes, so it must fit in 32 bytes.

    Raises:
        ValueError: invalid *rate*, or a PSK longer than 32 bytes.
    """
    # Validate cheap inputs before packing a directory or opening a connection.
    rate_limit = _parse_rate(rate)
    psk_bytes = psk.encode()
    if len(psk_bytes) > 32:
        # The receiver reads exactly 32 bytes. A longer key would desync the stream.
        raise ValueError("PSK must be at most 32 bytes when UTF-8 encoded")

    if os.path.isdir(path):
        filepath = pack_folder(path)
        print(f"目录已打包: {filepath}")
        cleanup = True
    else:
        filepath = path
        cleanup = False

    try:
        with socket.socket() as s:
            s.connect((host, port))
            s.sendall(psk_bytes.ljust(32, b"\0"))  # PSK authentication block
            send_file(s, filepath, rate_limit)
    finally:
        if cleanup:
            os.remove(filepath)


def _parse_rate(rate):
    """Convert a rate string ("500K", "5M", "1G" or "2048") to bytes/second.

    The K/M/G suffixes (case-insensitive) are powers of 1024. A bare number is
    bytes/second. Returns None (no limit) for an empty value or a result <= 0.
    Raises ValueError for anything unparseable.
    """
    if not rate:
        return None
    text = rate.strip()
    multipliers = {"K": 1024, "M": 1024 ** 2, "G": 1024 ** 3}
    factor = multipliers.get(text[-1:].upper())
    number = text[:-1] if factor else text
    try:
        value = float(number) * (factor or 1)
    except ValueError:
        raise ValueError(f"invalid rate {rate!r}; expected e.g. 500K, 5M or 2048") from None
    return value if value > 0 else None
|
|
|
|
|
|
def handle_client(conn, addr, outdir, psk):
    """Authenticate one incoming connection with the PSK, then receive a file.

    The peer must first send exactly 32 bytes: its PSK as UTF-8, right-padded
    with NUL bytes. On mismatch or a short read the connection is dropped.
    Any exception during the transfer is reported, and *conn* is always closed.
    """
    import hmac  # constant-time comparison for the PSK

    try:
        # recv() may return fewer bytes than requested, so loop for all 32.
        token = b""
        while len(token) < 32:
            part = conn.recv(32 - len(token))
            if not part:
                break
            token += part

        # Compare raw padded bytes: this avoids decode errors on garbage input
        # and does not leak timing. A PSK over 32 bytes can never match, as before.
        expected = psk.encode().ljust(32, b"\0")
        if len(token) < 32 or not hmac.compare_digest(token, expected):
            print(f"{addr} PSK 验证失败")
            return

        recv_file(conn, outdir)
    except Exception as e:
        # Thread boundary: report instead of letting the thread die silently.
        print(f"{addr} 处理错误: {e}")
    finally:
        conn.close()
|
|
|
|
|
|
def run_receiver(port, outdir, psk, once=False):
    """Listen on all IPv4 interfaces and handle each client in its own thread.

    Args:
        port: TCP port to listen on.
        outdir: directory where received files are stored.
        psk: pre-shared key clients must present.
        once: handle only the first accepted connection. Wait for its transfer
            to finish, then return.

    The listening socket is closed on return or on any exception.
    """
    with socket.socket() as s:
        s.bind(("0.0.0.0", port))
        s.listen(5)
        print(f"正在监听 {port}...")
        sys.stdout.flush()

        while True:
            conn, addr = s.accept()
            print(f"收到连接: {addr}")
            # One thread per client, so a slow transfer does not block accept().
            t = threading.Thread(target=handle_client, args=(conn, addr, outdir, psk))
            t.daemon = True
            t.start()

            if once:
                print("一次传输模式,退出主循环")
                # The handler is a daemon thread. Returning now would let the
                # interpreter exit and kill the transfer mid-stream.
                t.join()
                break
|
|
|
|
|
|
def main(argv=None):
    """Parse command-line arguments and run the sender or the receiver.

    Args:
        argv: argument list to parse. None means ``sys.argv[1:]``.

    Invalid argument combinations exit with status 2 via ``parser.error``.
    """
    parser = argparse.ArgumentParser(description="P2P 文件/目录传输工具(支持续传)")
    parser.add_argument("--mode", choices=["send", "recv"], required=True)
    parser.add_argument("--host", default="127.0.0.1", help="接收端IP (send模式必需,默认127.0.0.1)")
    parser.add_argument("--port", type=int, help="端口号")
    parser.add_argument("--path", help="要发送的文件/目录 (send模式必需)")
    parser.add_argument("--outdir", help="接收保存目录 (recv模式必需)")
    parser.add_argument("--rate", help="限速,例: 500K, 5M")
    parser.add_argument("--psk", required=True, help="预共享密钥")
    parser.add_argument("--once", action="store_true", help="接收一个文件后退出")
    args = parser.parse_args(argv)

    # Default port depends on the mode.
    # NOTE(review): send defaults to 7001 but recv listens on 6000, so the
    # defaults only meet through a port mapping. Confirm this is intended;
    # otherwise pass the same --port on both sides.
    if args.port is None:
        args.port = 7001 if args.mode == "send" else 6000

    if args.mode == "send":
        if not args.path:
            parser.error("send 模式需要 --path")
        if not os.path.exists(args.path):
            # Fail with a usage error instead of a traceback from getsize().
            parser.error(f"路径不存在: {args.path}")
        run_sender(args.host, args.port, args.path, args.rate, args.psk)
    else:
        if not args.outdir:
            parser.error("recv 模式需要 --outdir")
        os.makedirs(args.outdir, exist_ok=True)
        run_receiver(args.port, args.outdir, args.psk, args.once)


if __name__ == "__main__":
    main()
|