#!/usr/bin/env python3
# Copyright (C) 2025 Rivos Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import filecmp
import os
import pathlib
import shutil
import subprocess
import sys
import tempfile
from collections import defaultdict

# Proto sets to compile into Rust, one entry per output tree. Keys:
#   "files": protos compiled by this script; each one is run through protoc
#       with the protozero_rust_plugin and produces a `.pz.rs` file.
#   "custom_files": protos that are NOT generated by this script but are
#       still listed as `pub mod` entries in the generated mod.rs files
#       (presumably their `.pz.rs` is maintained by hand — TODO confirm).
#   "path_strip_prefix" / "path_add_prefix": a repo path
#       "<path_strip_prefix>/X" maps to the output path "<path_add_prefix>/X";
#       both values are also forwarded to the plugin as options.
SOURCE_FILES = [
    {
        "files": [
            "protos/perfetto/common/builtin_clock.proto",
            "protos/perfetto/common/semantic_type.proto",
            "protos/perfetto/common/track_event_descriptor.proto",
            "protos/perfetto/config/priority_boost/priority_boost_config.proto",
            "protos/perfetto/config/test_config.proto",
            "protos/perfetto/config/trace_config.proto",
            "protos/perfetto/config/track_event/track_event_config.proto",
            "protos/perfetto/trace/clock_snapshot.proto",
            "protos/perfetto/trace/profiling/profile_common.proto",
            "protos/perfetto/trace/test_event.proto",
            "protos/perfetto/trace/trace.proto",
            "protos/perfetto/trace/track_event/chrome_active_processes.proto",
            "protos/perfetto/trace/track_event/chrome_application_state_info.proto",
            "protos/perfetto/trace/track_event/chrome_compositor_scheduler_state.proto",
            "protos/perfetto/trace/track_event/chrome_content_settings_event_info.proto",
            "protos/perfetto/trace/track_event/chrome_frame_reporter.proto",
            "protos/perfetto/trace/track_event/chrome_histogram_sample.proto",
            "protos/perfetto/trace/track_event/chrome_keyed_service.proto",
            "protos/perfetto/trace/track_event/chrome_latency_info.proto",
            "protos/perfetto/trace/track_event/chrome_legacy_ipc.proto",
            "protos/perfetto/trace/track_event/chrome_message_pump.proto",
            "protos/perfetto/trace/track_event/chrome_mojo_event_info.proto",
            "protos/perfetto/trace/track_event/chrome_process_descriptor.proto",
            "protos/perfetto/trace/track_event/chrome_renderer_scheduler_state.proto",
            "protos/perfetto/trace/track_event/chrome_thread_descriptor.proto",
            "protos/perfetto/trace/track_event/chrome_user_event.proto",
            "protos/perfetto/trace/track_event/chrome_window_handle_event_info.proto",
            "protos/perfetto/trace/track_event/counter_descriptor.proto",
            "protos/perfetto/trace/track_event/debug_annotation.proto",
            "protos/perfetto/trace/track_event/log_message.proto",
            "protos/perfetto/trace/track_event/process_descriptor.proto",
            "protos/perfetto/trace/track_event/screenshot.proto",
            "protos/perfetto/trace/track_event/source_location.proto",
            "protos/perfetto/trace/track_event/task_execution.proto",
            "protos/perfetto/trace/track_event/thread_descriptor.proto",
            "protos/perfetto/trace/track_event/track_descriptor.proto",
            "protos/perfetto/trace/track_event/track_event.proto",
        ],
        "custom_files": [
            "protos/perfetto/common/data_source_descriptor.proto",
            "protos/perfetto/config/data_source_config.proto",
            "protos/perfetto/trace/interned_data/interned_data.proto",
            "protos/perfetto/trace/trace_packet.proto",
        ],
        "path_strip_prefix": "protos/perfetto",
        "path_add_prefix": "contrib/rust-sdk/perfetto/src/protos",
    },
    # GPU-related protos, emitted under contrib/rust-sdk/perfetto-protos-gpu.
    {
        "files": [
            "protos/perfetto/common/gpu_counter_descriptor.proto",
            "protos/perfetto/config/gpu/gpu_counter_config.proto",
            "protos/perfetto/config/gpu/gpu_renderstages_config.proto",
            "protos/perfetto/config/gpu/vulkan_memory_config.proto",
            "protos/perfetto/trace/gpu/gpu_counter_event.proto",
            "protos/perfetto/trace/gpu/gpu_log.proto",
            "protos/perfetto/trace/gpu/gpu_render_stage_event.proto",
            "protos/perfetto/trace/gpu/vulkan_api_event.proto",
            "protos/perfetto/trace/gpu/vulkan_memory_event.proto",
        ],
        # Same custom_files list as the first entry.
        "custom_files": [
            "protos/perfetto/common/data_source_descriptor.proto",
            "protos/perfetto/config/data_source_config.proto",
            "protos/perfetto/trace/interned_data/interned_data.proto",
            "protos/perfetto/trace/trace_packet.proto",
        ],
        "path_strip_prefix": "protos/perfetto",
        "path_add_prefix": "contrib/rust-sdk/perfetto-protos-gpu/src/protos",
    },
]

# Repository root: four directory levels above this script's real path
# (contrib/rust-sdk/tools/gen_rust_protos -> repo root). The ninja and protoc
# subprocesses are run with this as their working directory.
ROOT_DIR = os.path.dirname(
    os.path.dirname(
        os.path.dirname(os.path.dirname(os.path.realpath(__file__)))))
IS_WIN = sys.platform.startswith("win")  # Selects the ".exe" binary suffix.

# Repo-relative path of this script; passed to the plugin as "invoker".
SCRIPT_PATH = "contrib/rust-sdk/tools/gen_rust_protos"


def _built_binary_path(out_directory, name):
  """Return the path of the executable `name` inside the build directory.

  The ".exe" suffix is appended on Windows.

  Raises:
    AssertionError: if the file does not exist. Raised explicitly rather
      than through `assert` so the check is not stripped under `python -O`;
      the message lets main() report it as a one-line error.
  """
  path = os.path.join(out_directory, name) + (".exe" if IS_WIN else "")
  if not os.path.isfile(path):
    raise AssertionError("Expected binary {} does not exist".format(path))
  return path


def protozero_rust_plugin_path(out_directory):
  """Return the path of the built protozero_rust_plugin binary."""
  return _built_binary_path(out_directory, "protozero_rust_plugin")


def protoc_path(out_directory):
  """Return the path of the built protoc binary."""
  return _built_binary_path(out_directory, "protoc")


def call(cmd, *args):
  """Run the Python script `tools/<cmd>` from ROOT_DIR with extra arguments.

  The command line is echoed before it runs.

  Args:
    cmd: name of a script under ROOT_DIR/tools (e.g. "ninja").
    *args: command-line arguments passed through unchanged.

  Raises:
    AssertionError: if the command exits with a non-zero status. Raised
      explicitly (not via `assert False`) so a failure is never silently
      ignored under `python -O`; the original CalledProcessError is chained.
  """
  path = os.path.join("tools", cmd)
  command = ["python3", path] + list(args)
  print("Running", " ".join(command))
  try:
    subprocess.check_call(command, cwd=ROOT_DIR)
  except subprocess.CalledProcessError as e:
    raise AssertionError("Command: {} failed".format(
        " ".join(command))) from e


# Mirrors the output-file naming used by the ProtoZero Rust plugin.
def transform_extension(filename):
  """Return `filename` with a trailing ".proto" replaced by ".pz.rs".

  Names that do not end in ".proto" are returned unchanged.
  """
  proto_ext, rust_ext = ".proto", ".pz.rs"
  if not filename.endswith(proto_ext):
    return filename
  stem = filename[:len(filename) - len(proto_ext)]
  return stem + rust_ext


def generate(
    source,
    outdir,
    protoc_path,
    protozero_rust_plugin_path,
    path_strip_prefix,
    path_add_prefix,
):
  """Compile one .proto file with protoc and the ProtoZero Rust plugin.

  protoc is run from ROOT_DIR with "." as the proto include path, so `source`
  is repo-relative. The plugin writes its output below `outdir` and receives
  the path prefixes plus SCRIPT_PATH (as "invoker") as comma-separated
  key=value options.

  Raises:
    subprocess.CalledProcessError: if protoc exits with a non-zero status.
  """
  plugin_params = (
      ("path_strip_prefix", path_strip_prefix),
      ("path_add_prefix", path_add_prefix),
      ("invoker", SCRIPT_PATH),
  )
  param_string = ",".join(f"{key}={value}" for key, value in plugin_params)
  argv = [
      protoc_path,
      "--proto_path=.",
      f"--plugin=protoc-gen-plugin={protozero_rust_plugin_path}",
      f"--plugin_out={param_string}:{outdir}",
      source,
  ]
  subprocess.check_call(argv, cwd=ROOT_DIR)


def generate_mod(tmpfilename, mods):
  """Write a Rust mod.rs that declares one `pub mod` per entry of `mods`.

  Only the basename of each entry is used. Entries whose basename ends in
  ".proto" become `#[path = "<name>.pz.rs"] pub mod <name>;`; any other entry
  (e.g. a sub-directory) becomes a plain `pub mod <name>;`. Entries are
  emitted in sorted order so the output is deterministic.

  Args:
    tmpfilename: path of the file to create or overwrite.
    mods: iterable of module paths.
  """
  pz_suffix = ".pz.rs"
  # Force "\n" line endings (and UTF-8) so the output is byte-identical on
  # every platform; plain text mode would write "\r\n" on Windows and make
  # the --check-only comparison against the checked-in files fail.
  with open(tmpfilename, "w", encoding="utf-8", newline="\n") as f:
    print(
        """// Copyright (C) 2025 Rivos Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Autogenerated by the gen_rust_protos script.
// DO NOT EDIT.""",
        file=f,
    )
    for mod in sorted(mods):
      modname = transform_extension(os.path.basename(mod))
      if modname.endswith(pz_suffix):
        stem = modname[:-len(pz_suffix)]
        print(f"\n/// `{stem}` protos.", file=f)
        print(f'#[path = "{modname}"]', file=f)
        print(f"pub mod {stem};", file=f)
      else:
        print(f"\n/// `{modname}` protos.", file=f)
        print(f"pub mod {modname};", file=f)


def mods_by_directory(sources, path_strip_prefix):
  """Group module paths by the directory whose mod.rs must declare them.

  Each source has `path_strip_prefix` removed (when present). The file is
  then recorded under its immediate parent directory, that directory under
  its own parent, and so on up the tree; the filesystem root "/" is never
  used as a key.

  Returns:
    A dict mapping directory path (str) to the set of child paths (files or
    sub-directories) it contains.
  """
  root = pathlib.Path("/")
  by_dir = defaultdict(set)
  for source in sources:
    if source.startswith(path_strip_prefix):
      child = source[len(path_strip_prefix):]
    else:
      child = source
    ancestors = [p for p in pathlib.Path(child).parents if p != root]
    for ancestor in ancestors:
      ancestor_str = str(ancestor)
      by_dir[ancestor_str].add(child)
      child = ancestor_str
  return dict(by_dir)


def rust_path_for(path, path_strip_prefix, path_add_prefix):
  """Map a repo path to its location in the Rust output tree.

  A leading `path_strip_prefix` is removed when present, and the result is
  appended to `path_add_prefix` by plain string concatenation.
  """
  stripped = path
  if path.startswith(path_strip_prefix):
    stripped = path[len(path_strip_prefix):]
  return path_add_prefix + stripped


def _sync_file(generated, target, check_only):
  """Install `generated` at `target`, or verify that `target` matches it.

  Args:
    generated: path of the freshly generated file.
    target: path of the corresponding checked-in file.
    check_only: if true, only compare; otherwise create missing parent
      directories and copy `generated` over `target`.

  Raises:
    AssertionError: in check-only mode, when `target` is missing or its
      contents differ from `generated`.
  """
  if not check_only:
    os.makedirs(os.path.dirname(target), exist_ok=True)
    shutil.copyfile(generated, target)
    return
  if not os.path.isfile(target):
    raise AssertionError("Target {} does not exist".format(target))
  # shallow=False forces a content comparison instead of trusting matching
  # os.stat() signatures.
  if not filecmp.cmp(generated, target, shallow=False):
    raise AssertionError("Target {} does not match".format(target))


def main():
  """Regenerate (or, with --check-only, verify) the checked-in Rust protos.

  Builds protoc and protozero_rust_plugin in the OUT build directory. Every
  proto listed in SOURCE_FILES is compiled into a temporary directory, and a
  mod.rs is written for each directory of the module tree. Each result is then
  copied into the source tree, or compared against the checked-in file.

  Returns:
    0 on success, 1 if an error with a message was reported.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument("--check-only", action="store_true")
  parser.add_argument("OUT")
  args = parser.parse_args()
  # tools/ninja runs with cwd=ROOT_DIR, so a relative OUT already means
  # ROOT_DIR/OUT for the build; anchor it there for the binary lookups too.
  out = os.path.join(ROOT_DIR, args.OUT)

  try:
    # Inside the try so a failed build is reported like any other error.
    call("ninja", "-C", out, "protoc", "protozero_rust_plugin")
    protoc = protoc_path(out)
    plugin = protozero_rust_plugin_path(out)

    with tempfile.TemporaryDirectory() as tmpdirname:
      for sources in SOURCE_FILES:
        strip_prefix = sources["path_strip_prefix"]
        add_prefix = sources["path_add_prefix"]

        for source in sources["files"]:
          generate(
              source,
              tmpdirname,
              protoc,
              plugin,
              path_strip_prefix=strip_prefix,
              path_add_prefix=add_prefix,
          )
          tmpfilename = os.path.join(tmpdirname, transform_extension(source))
          # Target paths are repo-relative; resolve them against ROOT_DIR
          # so the script behaves the same from any working directory.
          targetfilename = os.path.join(
              ROOT_DIR,
              transform_extension(
                  rust_path_for(source, strip_prefix, add_prefix)),
          )
          _sync_file(tmpfilename, targetfilename, args.check_only)

        # custom_files are not generated here but still need mod.rs entries.
        modsources = sources["files"] + sources["custom_files"]
        modtmpfilename = os.path.join(tmpdirname, "mod.rs")
        for directory, mods in mods_by_directory(modsources,
                                                 strip_prefix).items():
          generate_mod(modtmpfilename, mods)
          targetmoddir = rust_path_for(directory, strip_prefix, add_prefix)
          targetfilename = os.path.join(ROOT_DIR, targetmoddir, "mod.rs")
          _sync_file(modtmpfilename, targetfilename, args.check_only)

  except AssertionError as e:
    # An AssertionError without a message (e.g. from a bare `assert`) is
    # re-raised with its traceback; one with a message becomes a one-line
    # error.
    if not str(e):
      raise
    print("Error: {}".format(e))
    return 1
  return 0


# Use sys.exit: the bare `exit` builtin is injected by the `site` module and
# is unavailable under `python -S` or in frozen/embedded interpreters.
if __name__ == "__main__":
  sys.exit(main())
