You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

312 lines
9.6 KiB

import logging
import os
from datetime import datetime
from jinja2 import Environment, FileSystemLoader
# from config import *
from db import get_table, get_columns
from utils import *
import argparse
import yaml
# Jinja2 environment shared by render(): template names are looked up in the
# "templates" directory (a relative path, so it depends on the working
# directory the generator is launched from).
env = Environment(
    loader=FileSystemLoader("templates"),
)
def build_fields(table_name):
    """Describe every column of ``table_name`` for the code templates.

    Each entry combines the raw column metadata returned by ``get_columns``
    (``column_name``, ``data_type``, ``column_comment``) with the Java-side
    names and type produced by ``to_camel``, ``to_m_camel`` and
    ``mysql_to_java``.

    :param table_name: name of the database table to inspect
    :return: list of dicts with keys ``java_name``, ``tab_name``, ``tab_type``,
        ``java_get_name``, ``java_type`` and ``comment``, one per column, in
        the order ``get_columns`` yields them
    """
    return [
        {
            "java_name": to_camel(col["column_name"]),
            "tab_name": col["column_name"],
            "tab_type": col["data_type"],
            "java_get_name": to_m_camel(col["column_name"]),
            "java_type": mysql_to_java(col["data_type"]),
            "comment": col["column_comment"],
        }
        for col in get_columns(table_name)
    ]
def render(template_name, out_path, context, overwrite=False, *, environment=None):
    """Render a Jinja2 template and write the result to ``out_path``.

    :param template_name: template file name, resolved by the Jinja2 environment
    :param out_path: path of the file to write
    :param context: mapping whose items are passed to the template as variables
    :param overwrite: whether to replace an already existing file; default False
    :param environment: optional Jinja2 environment to load the template from;
        when omitted the module-level ``env`` is used
    """
    # The file exists and overwriting is not allowed -> leave it untouched.
    if os.path.exists(out_path) and not overwrite:
        logging.info("Skip exists file: %s", out_path)
        return
    jinja_env = env if environment is None else environment
    content = jinja_env.get_template(template_name).render(**context)
    out_dir = os.path.dirname(out_path)
    # os.makedirs("") raises FileNotFoundError, so only create directories
    # when out_path actually has a directory component.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(content)
    logging.info("Generated file: %s", out_path)
def generate(table_names: list[str], model_names: list[str], conf_name: str, over_write: bool):
    """Generate a Java project skeleton from a YAML config and database tables.

    Three groups of files are rendered through ``render``:

    * per-table files (entity, controller, service, service impl, mapper
      interface, mapper XML) for every name in ``table_names``; these honour
      ``over_write``;
    * fixed scaffolding (BaseEntity, MybatisPlusConfig, WebLogAspect, Result,
      utility classes, Application, ApplicationTests, the parent and module
      pom.xml, application.yml, application-dev.yml, logback.xml); these are
      only written when the target file does not exist yet;
    * optional feature modules named in ``model_names`` (``swagger``,
      ``saToken``, ``minio``, ``xxlJob``); also only written when missing.
      Unknown names are logged as a warning and skipped.

    :param table_names: database tables to generate per-table code for
    :param model_names: optional feature modules to generate
    :param conf_name: path of the YAML configuration file
    :param over_write: whether per-table files may replace existing files
    """
    with open(conf_name, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    cfg = resolve_config(cfg)

    # Template variables shared by every rendered file.
    context = {
        "mainModule": cfg["mainModule"],
        "moduleName": cfg["moduleName"],
        "groupId": cfg["groupId"],
        "author": cfg["author"],
        "wechat": cfg["wechat"],
        "date": datetime.now().strftime("%Y-%m-%d"),
        "entityLombokModel": cfg["entityLombokModel"],
        "package": cfg["package"],
        "db": cfg["db"],
        "application": cfg["application"],
        "restControllerStyle": cfg["restControllerStyle"],
    }

    base_dir = cfg["baseDir"]
    output_dir = cfg["outputDir"]
    main_module = cfg["mainModule"]
    module_name = cfg["moduleName"]
    base_package_path = to_path(cfg["package"]["Base"])
    main_base_package_dir = f"{output_dir}/java/{base_package_path}"
    # Root of the generated Java sources (entity/, controller/, common/, ...).
    java_root = f"{base_dir}{main_base_package_dir}/{cfg['package']['Models']}"
    # Resources sit in the same module tree as the module pom.xml, so they get
    # the same base_dir prefix as every other generated path.
    resources_dir = f"{base_dir}{main_module}/{module_name}/src/main/resources"

    # ========= Per-table templates (honour over_write) =========
    for table_name in table_names:
        table = get_table(table_name)
        entity = to_class(table_name)
        # Fresh copy per table so the previous dict is not mutated.
        # NOTE: after the loop `context` still carries the last table's
        # "fields"/"table" entries; the fixed templates below receive it as-is.
        context = dict(context)
        context.update({
            "fields": build_fields(table_name),
            "table": {
                "entity": entity,
                "lowerEntity": lower_first(entity),
                "name": table_name,
                "comment": table["table_comment"],
            },
        })
        table_outputs = (
            ("entity.java.j2", f"{java_root}/entity/{entity}.java"),
            ("controller.java.j2", f"{java_root}/controller/{entity}Controller.java"),
            ("service.java.j2", f"{java_root}/service/{entity}Service.java"),
            ("serviceImpl.java.j2", f"{java_root}/service/impl/{entity}MPJBaseServiceImpl.java"),
            ("mapper.java.j2", f"{java_root}/mapper/{entity}Mapper.java"),
            ("mapper.xml.j2", f"{resources_dir}/mappers/{entity}Mapper.xml"),
        )
        for template_name, out_path in table_outputs:
            render(template_name, out_path, context, over_write)

    # ========= Fixed project scaffolding (never overwritten) =========
    fixed_outputs = (
        ("baseEntity.java.j2", f"{java_root}/entity/BaseEntity.java"),
        ("mybatisPlusConfig.java.j2", f"{java_root}/common/config/MybatisPlusConfig.java"),
        ("webLogAspect.java.j2", f"{java_root}/common/config/WebLogAspect.java"),
        # Base response wrapper.
        ("result.java.j2", f"{java_root}/common/vo/Result.java"),
        # Shared utility classes.
        ("Md5HashUtil.java.j2", f"{java_root}/common/unit/Md5HashUtil.java"),
        ("FilesUtil.java.j2", f"{java_root}/common/unit/FilesUtil.java"),
        # Application entry point and test class.
        ("application.java.j2", f"{base_dir}{main_base_package_dir}/Application.java"),
        ("applicationTests.java.j2",
         f"{base_dir}{output_dir}/test/{base_package_path}/ApplicationTests.java"),
        # Parent pom and module pom.
        ("main.pom.xml.j2", f"{base_dir}{main_module}/pom.xml"),
        ("project.pom.xml.j2", f"{base_dir}{main_module}/{module_name}/pom.xml"),
        # Minimal runtime configuration (only a dev profile) so the project starts.
        ("application.yml.j2", f"{resources_dir}/application.yml"),
        ("application-dev.yml.j2", f"{resources_dir}/application-dev.yml"),
        ("logback.xml.j2", f"{resources_dir}/logback.xml"),
    )
    for template_name, out_path in fixed_outputs:
        render(template_name, out_path, context)

    # ========= Optional feature modules (never overwritten) =========
    module_outputs = {
        "swagger": (
            ("swagger2.java.j2", f"{java_root}/common/config/Swagger2.java"),
        ),
        # Sa-Token: custom global exception handler plus its configuration.
        "saToken": (
            ("globalException.java.j2", f"{java_root}/common/config/GlobalException.java"),
            ("saTokenConfigure.java.j2", f"{java_root}/common/config/SaTokenConfig.java"),
        ),
        "minio": (
            ("MinioConfig.java.j2", f"{java_root}/common/config/MinioConfig.java"),
            ("MinioUpController.java.j2", f"{java_root}/controller/MinioUpController.java"),
            ("MinioUpComponent.java.j2", f"{java_root}/common/unit/MinioUpComponent.java"),
        ),
        # XXL-JOB configuration plus a sample job class.
        "xxlJob": (
            ("xxlJobConfig.java.j2", f"{java_root}/common/config/XxlJobConfig.java"),
            ("testJob.java.j2", f"{java_root}/common/job/TestJob.java"),
        ),
    }
    for model_name in model_names:
        outputs = module_outputs.get(model_name)
        if outputs is None:
            # Previously ignored silently; surface typos on the command line.
            logging.warning("Unknown model, skipped: %s", model_name)
            continue
        for template_name, out_path in outputs:
            render(template_name, out_path, context)
if __name__ == "__main__":
    cli_args = parse_args()
    # --tab and --model are comma-separated lists; surrounding whitespace is
    # trimmed and empty entries (e.g. "a,,b" or a trailing comma) are dropped.
    generate(
        table_names=[name.strip() for name in cli_args.tab.split(",") if name.strip()],
        model_names=[name.strip() for name in cli_args.model.split(",") if name.strip()],
        conf_name=cli_args.conf,
        over_write=cli_args.re,
    )