Experiments with Markov chains, n-grams, and text generation.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

134 lines
4.8 KiB

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using MarkovGrams.Utilities;
using SBRL.Algorithms;
namespace MarkovGrams
{
    /// <summary>
    /// A weighted Markov chain that generates text from a set of n-grams.
    /// Operates at either character level or word level, depending on <see cref="Mode"/>.
    /// </summary>
    /// <remarks>
    /// Each n-gram's weight is handed to a <see cref="WeightedRandom{T}"/> whenever a candidate
    /// n-gram is picked. <see cref="RandomNgram"/> and <see cref="Generate(int, out float)"/> both
    /// mutate that shared picker, so concurrent calls on one instance would interfere with each other.
    /// </remarks>
    public class WeightedMarkovChain
    {
        /// <summary>
        /// Separator used to count words when running in word-level mode.
        /// Hoisted so the generation loop does not reallocate it on every iteration.
        /// </summary>
        private static readonly char[] WordSeparator = { ' ' };

        /// <summary>
        /// Weighted picker shared by <see cref="RandomNgram"/> and <see cref="Generate(int, out float)"/>.
        /// While non-empty between calls it holds the cached pool of starting n-grams.
        /// </summary>
        private readonly WeightedRandom<string> wrandom = new WeightedRandom<string>();

        /// <summary>
        /// The n-grams that this markov chain currently contains, mapped to their weights.
        /// </summary>
        private readonly Dictionary<string, double> ngrams;

        /// <summary>
        /// Whether to always start generating a new word from an n-gram that starts with
        /// an uppercase letter.
        /// </summary>
        public bool StartOnUppercase = false;

        /// <summary>
        /// The generation mode to use when running the Markov Chain.
        /// </summary>
        /// <remarks>
        /// The input n-grams must have been generated using the same mode specified here.
        /// </remarks>
        public GenerationMode Mode { get; private set; } = GenerationMode.CharacterLevel;

        /// <summary>
        /// Creates a new markov chain from weighted n-grams.
        /// </summary>
        /// <param name="inNgrams">
        /// The n-grams to populate the new markov chain with, mapped to their weights.
        /// The dictionary is stored by reference, not copied.
        /// </param>
        /// <param name="inMode">The generation mode the n-grams were built with.</param>
        /// <exception cref="ArgumentNullException"><paramref name="inNgrams"/> is null.</exception>
        public WeightedMarkovChain(Dictionary<string, double> inNgrams, GenerationMode inMode)
        {
            ngrams = inNgrams ?? throw new ArgumentNullException(nameof(inNgrams));
            Mode = inMode;
        }

        /// <summary>
        /// Creates a new markov chain from n-grams with integer weights.
        /// Each weight is converted to a <see cref="double"/> and copied into a new dictionary.
        /// </summary>
        /// <param name="inNgrams">The n-grams to populate the new markov chain with, mapped to their weights.</param>
        /// <param name="inMode">The generation mode the n-grams were built with.</param>
        /// <exception cref="ArgumentNullException"><paramref name="inNgrams"/> is null.</exception>
        public WeightedMarkovChain(Dictionary<string, int> inNgrams, GenerationMode inMode)
        {
            if (inNgrams == null)
                throw new ArgumentNullException(nameof(inNgrams));
            ngrams = inNgrams.ToDictionary(pair => pair.Key, pair => (double)pair.Value);
            Mode = inMode;
        }

        /// <summary>
        /// Returns a random, weighted n-gram that's currently loaded into this WeightedMarkovChain.
        /// When <see cref="StartOnUppercase"/> is set, only n-grams whose first character is
        /// uppercase are eligible.
        /// </summary>
        /// <remarks>
        /// The pool of eligible n-grams is built only when the shared picker is empty, and is reused
        /// until it is cleared (<see cref="Generate(int, out float)"/> clears it). Changing
        /// <see cref="StartOnUppercase"/> after a direct call therefore has no effect until then.
        /// </remarks>
        /// <returns>A random n-gram from this WeightedMarkovChain's cache of n-grams.</returns>
        /// <exception cref="InvalidOperationException">No n-gram qualifies as a starting n-gram.</exception>
        public string RandomNgram()
        {
            if (wrandom.Count == 0)
            {
                Dictionary<string, double> startingNgrams = StartOnUppercase
                    // Guard against empty keys before inspecting the first character.
                    ? FilterNgrams(key => key.Length > 0 && char.IsUpper(key[0]))
                    : ngrams;
                if (startingNgrams.Count == 0)
                    throw new InvalidOperationException($"Error: No valid starting ngrams were found (StartOnUppercase: {StartOnUppercase}).");
                wrandom.SetContents(startingNgrams);
            }
            return wrandom.Next();
        }

        /// <summary>
        /// Generates a new random string from the currently stored ngrams.
        /// </summary>
        /// <param name="length">
        /// The target length of the string to generate: characters in character-level mode,
        /// space-separated words otherwise.
        /// Note that this is a target, not a fixed value - e.g. passing 2 when the n-gram order is 3 will
        /// result in a string of length 3. Also, depending on the current ngrams this markov chain contains,
        /// it may end up being cut short.
        /// </param>
        /// <returns>A new random string.</returns>
        public string Generate(int length)
        {
            return Generate(length, out _);
        }

        /// <summary>
        /// Generates a new random string from the currently stored ngrams, also reporting how many
        /// candidate continuations were available along the way.
        /// </summary>
        /// <param name="length">
        /// The target length: characters in character-level mode, space-separated words otherwise.
        /// The result may be longer (the starting n-gram is used whole) or shorter (no continuation found).
        /// </param>
        /// <param name="choicePointRatio">
        /// The total number of candidate n-grams seen at each step (including a final step that found
        /// none, if any), divided by the number of n-grams appended plus one. 0 when no candidates were seen.
        /// </param>
        /// <returns>A new random string.</returns>
        /// <exception cref="InvalidOperationException">No n-gram qualifies as a starting n-gram.</exception>
        public string Generate(int length, out float choicePointRatio)
        {
            string result = RandomNgram();
            string lastNgram = result;
            List<int> choiceCounts = new List<int>();
            int appendedCount = 0;
            try
            {
                while ((Mode == GenerationMode.CharacterLevel ? result.Length : result.CountCharInstances(WordSeparator) + 1) < length)
                {
                    wrandom.ClearContents();
                    // The substring that the next ngram in the chain needs to start with
                    string nextStartsWith = Mode == GenerationMode.CharacterLevel
                        ? lastNgram.Substring(1)
                        : string.Join(" ", lastNgram.Split(' ').Skip(1));
                    // Get the possible n-grams we could choose from next
                    Dictionary<string, double> nextNgrams = FilterNgrams(key => key.StartsWithFast(nextStartsWith));
                    choiceCounts.Add(nextNgrams.Count);
                    // If there aren't any choices left, we can't exactly keep adding to the new string any more :-(
                    if (nextNgrams.Count == 0)
                        break;
                    wrandom.SetContents(nextNgrams);
                    string nextNgram = wrandom.Next();
                    // Append only the part of the chosen n-gram that extends the current string
                    if (Mode == GenerationMode.CharacterLevel)
                        result += nextNgram[nextNgram.Length - 1];
                    else
                        result += ' ' + nextNgram.Substring(nextNgram.LastIndexOf(' ') + 1);
                    lastNgram = nextNgram;
                    appendedCount++;
                }
            }
            finally
            {
                // Always reset the shared picker so a later RandomNgram() rebuilds the starting pool
                // instead of drawing from a leftover "next n-gram" candidate set.
                wrandom.ClearContents();
            }
            int totalChoices = choiceCounts.Sum();
            choicePointRatio = totalChoices > 0 ? (float)totalChoices / (float)(appendedCount + 1) : 0;
            return result;
        }

        /// <summary>
        /// Returns a new dictionary holding the loaded n-grams (with their weights) whose key satisfies
        /// <paramref name="keyPredicate"/>. The scan runs in parallel via PLINQ, so the predicate must be
        /// safe to call from multiple threads.
        /// </summary>
        private Dictionary<string, double> FilterNgrams(Func<string, bool> keyPredicate)
        {
            return ngrams.AsParallel()
                .Where(pair => keyPredicate(pair.Key))
                .ToDictionary(pair => pair.Key, pair => pair.Value);
        }
    }
}