﻿using Clustering_Iris.DataStructures;
using IrisClustering.DataStructures;
using Microsoft.ML;
using Microsoft.ML.Data;
using ScottPlot;
using System.Linq;
using System;
using System.IO;

namespace Clustering_Iris
{
    /// <summary>
    /// Iris clustering demo: trains an ML.NET K-Means model on the four Iris measurements, evaluates it,
    /// compares the resulting clusters with the true species, draws PetalLength/PetalWidth scatter plots,
    /// saves the model to disk and runs a single prediction with the reloaded model.
    /// </summary>
    internal static class Program
    {
        /// <summary>Number of K-Means clusters (one per expected Iris species).</summary>
        private const int ClusterCount = 3;

        /// <summary>Name of the concatenated, normalized feature-vector column.</summary>
        private const string FeaturesColumn = "Features";

        static void Main(string[] args)
        {
            // ML context with a fixed seed for repeatable random operations (split, K-Means init).
            var mlContext = new MLContext(seed: 1);

            // Paths: input data under "Data", trained model written to a "Models" subfolder.
            var dataPath = Path.Combine(AppContext.BaseDirectory, "Data", "iris-full.txt");
            if (!File.Exists(dataPath))
            {
                throw new FileNotFoundException($"Input data file not found: {dataPath}", dataPath);
            }

            var modelDirectory = Path.Combine(AppContext.BaseDirectory, "Models");
            Directory.CreateDirectory(modelDirectory);
            var modelPath = Path.Combine(modelDirectory, "IrisModel.zip");

            // Load the tab-separated file (first line is a header) into an IDataView.
            IDataView fullData = mlContext.Data.LoadFromTextFile<IrisData>(
                path: dataPath,
                hasHeader: true,
                separatorChar: '\t'
            );

            // 80/20 train/test split.
            var split = mlContext.Data.TrainTestSplit(fullData, testFraction: 0.2);
            var trainingData = split.TrainSet;
            var testData = split.TestSet;

            // Concatenate the four numeric measurements into one feature vector, then min-max
            // normalize it so measurements on larger scales do not dominate the K-Means distances.
            var dataProcessPipeline = mlContext.Transforms.Concatenate(
                FeaturesColumn,
                nameof(IrisData.SepalLength),
                nameof(IrisData.SepalWidth),
                nameof(IrisData.PetalLength),
                nameof(IrisData.PetalWidth)
            ).Append(mlContext.Transforms.NormalizeMinMax(FeaturesColumn));

            // K-Means trainer: one cluster per expected species.
            var trainer = mlContext.Clustering.Trainers.KMeans(
                featureColumnName: FeaturesColumn,
                numberOfClusters: ClusterCount
            );

            // Training.
            var trainingPipeline = dataProcessPipeline.Append(trainer);
            var model = trainingPipeline.Fit(trainingData);

            // Evaluate on the held-out test set. No label column is passed, so only the
            // unsupervised metrics printed below are meaningful.
            var predictions = model.Transform(testData);
            var metrics = mlContext.Clustering.Evaluate(
                data: predictions,
                scoreColumnName: "Score",
                featureColumnName: FeaturesColumn
            );

            // Compare clusters with the true species over the whole dataset, then report and plot.
            var scoredAll = model.Transform(fullData);
            PrintMajorityVote(mlContext, scoredAll, k: ClusterCount);
            PrintClusteringMetrics(metrics);
            PlotTruthAndClusters(mlContext, fullData, model);

            // Save the model together with its input schema.
            mlContext.Model.Save(model, trainingData.Schema, modelPath);
            Console.WriteLine($"Model zapisany do: {modelPath}");
            Console.WriteLine("=============== End of training process ===============");

            // Single-example prediction (cluster assignment).
            // NOTE(review): these values look implausible for Iris (e.g. PetalWidth 5.1); they may be a
            // scrambled copy of a real sample. Confirm the intended measurements.
            var sample = new IrisData
            {
                SepalLength = 3.3f,
                SepalWidth = 1.6f,
                PetalLength = 0.2f,
                PetalWidth = 5.1f,
            };

            // Reload the model from disk to exercise the saved artifact.
            var loadedModel = mlContext.Model.Load(modelPath, out _);

            // The prediction engine is IDisposable; release it when Main finishes.
            using var predEngine = mlContext.Model.CreatePredictionEngine<IrisData, IrisPrediction>(loadedModel);
            var pred = predEngine.Predict(sample);

            Console.WriteLine("=============== Single Prediction ===============");
            Console.WriteLine($"Assigned cluster: {pred.SelectedClusterId}");
            Console.WriteLine($"Distances vector length: {pred.Distance.Length}");
            Console.WriteLine("Distances: " + string.Join(", ", pred.Distance));
            Console.WriteLine("================================================");
        }

        /// <summary>Prints the unsupervised clustering metrics computed on the test set.</summary>
        /// <param name="m">Metrics returned by the clustering evaluator.</param>
        private static void PrintClusteringMetrics(ClusteringMetrics m)
        {
            Console.WriteLine("=============== Clustering Metrics ===============");
            Console.WriteLine($"AverageDistance: {m.AverageDistance}");
            Console.WriteLine($"DaviesBouldinIndex: {m.DaviesBouldinIndex}");
            Console.WriteLine("==================================================");
        }

        /// <summary>Maps a numeric species label (0, 1, 2) to its Iris species name.</summary>
        /// <param name="label">Numeric label as stored in the data.</param>
        /// <returns>The species name, or "Label N" for any other value.</returns>
        private static string SpeciesName(int label) => label switch
        {
            0 => "Iris-setosa",
            1 => "Iris-versicolor",
            2 => "Iris-virginica",
            _ => $"Label {label}"
        };

        /// <summary>
        /// Scores <paramref name="fullData"/> with <paramref name="model"/> and saves two PetalLength vs
        /// PetalWidth scatter plots next to the executable: one colored by true species
        /// (iris_truth.png) and one colored by predicted cluster (iris_clusters.png).
        /// </summary>
        /// <param name="mlContext">Context used to enumerate the scored rows.</param>
        /// <param name="fullData">Unscored input data.</param>
        /// <param name="model">Trained pipeline used to assign clusters.</param>
        private static void PlotTruthAndClusters(MLContext mlContext, IDataView fullData, ITransformer model)
        {
            var scored = model.Transform(fullData);
            var rows = mlContext.Data.CreateEnumerable<IrisScoredRow>(scored, reuseRowObject: false).ToList();

            // True species: one series per label.
            var truthPlot = new ScottPlot.Plot(900, 600);
            foreach (var g in rows.GroupBy(r => (int)r.Label).OrderBy(g => g.Key))
            {
                var xs = g.Select(r => (double)r.PetalLength).ToArray();
                var ys = g.Select(r => (double)r.PetalWidth).ToArray();
                truthPlot.AddScatter(xs, ys, label: SpeciesName(g.Key), lineWidth: 0, markerSize: 6);
            }
            SavePetalPlot(truthPlot, "Iris: prawdziwe gatunki (kolor = label)", "iris_truth.png");

            // K-Means: one series per predicted cluster.
            var clusterPlot = new ScottPlot.Plot(900, 600);
            foreach (var g in rows.GroupBy(r => r.PredictedClusterId).OrderBy(g => g.Key))
            {
                var xs = g.Select(r => (double)r.PetalLength).ToArray();
                var ys = g.Select(r => (double)r.PetalWidth).ToArray();
                clusterPlot.AddScatter(xs, ys, label: $"Cluster {g.Key}", lineWidth: 0, markerSize: 6);
            }
            SavePetalPlot(clusterPlot, "Iris: KMeans (kolor = predicted cluster)", "iris_clusters.png");
        }

        /// <summary>
        /// Applies the shared title, PetalLength/PetalWidth axis labels and legend to <paramref name="plt"/>,
        /// saves it under <see cref="AppContext.BaseDirectory"/> and prints the saved path.
        /// </summary>
        /// <param name="plt">Plot whose series have already been added.</param>
        /// <param name="title">Chart title.</param>
        /// <param name="fileName">Output image file name.</param>
        private static void SavePetalPlot(ScottPlot.Plot plt, string title, string fileName)
        {
            plt.Title(title);
            plt.XLabel("PetalLength");
            plt.YLabel("PetalWidth");
            plt.Legend();

            var path = Path.Combine(AppContext.BaseDirectory, fileName);
            plt.SaveFig(path);
            Console.WriteLine($"Zapisano: {path}");
        }

        /// <summary>
        /// Prints, for every predicted cluster, the species that dominates it and its purity. It then
        /// prints a cluster x species count table and the overall majority-vote accuracy: the fraction
        /// of rows whose species equals their cluster's dominant species.
        /// </summary>
        /// <param name="mlContext">Context used to enumerate <paramref name="scoredData"/>.</param>
        /// <param name="scoredData">Model output containing the true label and the predicted cluster id.</param>
        /// <param name="k">Number of clusters; the count table lists cluster ids 1..k.</param>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="k"/> is not positive.</exception>
        private static void PrintMajorityVote(MLContext mlContext, IDataView scoredData, int k)
        {
            // Validate before the (uint) conversion below: a negative k would wrap around to a huge loop bound.
            if (k <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(k), k, "Number of clusters must be positive.");
            }

            var rows = mlContext.Data.CreateEnumerable<IrisScoredRow>(scoredData, reuseRowObject: false).ToList();

            Console.WriteLine("=============== Majority vote: Cluster -> dominant species ===============");

            var summary =
                rows.GroupBy(r => r.PredictedClusterId)
                    .OrderBy(g => g.Key)
                    .Select(g =>
                    {
                        int total = g.Count();

                        // Most frequent species in this cluster. Ties go to the lowest label, so the
                        // result does not depend on row order.
                        var best = g.GroupBy(r => (int)r.Label)
                                    .Select(h => new { Label = h.Key, Count = h.Count() })
                                    .OrderByDescending(x => x.Count)
                                    .ThenBy(x => x.Label)
                                    .First();

                        double purity = (double)best.Count / total;

                        return new
                        {
                            Cluster = g.Key,
                            DominantLabel = best.Label,
                            DominantName = SpeciesName(best.Label),
                            DominantCount = best.Count,
                            Total = total,
                            Purity = purity
                        };
                    })
                    .ToList();

            foreach (var s in summary)
            {
                Console.WriteLine($"Cluster {s.Cluster}: {s.DominantName} ({s.DominantCount}/{s.Total}, purity={s.Purity:P1})");
            }

            Console.WriteLine("==========================================================================");
            Console.WriteLine();
            Console.WriteLine("=============== Cluster x Species counts ===============");

            Console.WriteLine("Cluster | Setosa | Versicolor | Virginica | Total");
            Console.WriteLine("--------------------------------------------------");

            // One table row per cluster id, counting labels 0/1/2. The ids start at 1, matching ML.NET
            // key values (0 is reserved for "missing").
            for (uint clusterId = 1; clusterId <= (uint)k; clusterId++)
            {
                var clusterRows = rows.Where(r => r.PredictedClusterId == clusterId).ToList();
                int c0 = clusterRows.Count(r => (int)r.Label == 0);
                int c1 = clusterRows.Count(r => (int)r.Label == 1);
                int c2 = clusterRows.Count(r => (int)r.Label == 2);
                int total = clusterRows.Count;

                Console.WriteLine($"{clusterId,7} | {c0,6} | {c1,10} | {c2,9} | {total,5}");
            }

            Console.WriteLine("===============================================================================");

            int correct = summary.Sum(s => s.DominantCount);
            Console.WriteLine();

            // Avoid printing NaN when the scored data is empty.
            Console.WriteLine(rows.Count == 0
                ? "Majority-vote accuracy: n/a (no rows)"
                : $"Majority-vote accuracy: {(double)correct / rows.Count:P2}");
        }
    }
}
